Skip to main content

mock_igd/
mock.rs

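//! Mock registration and management: mock definitions, the registry that matches
//! incoming SOAP requests against them, and logs of received SOAP and SSDP requests.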
2
3use crate::action::Action;
4use crate::matcher::{Matcher, SoapRequest};
5use crate::responder::{ResponseBody, Responder};
6use std::net::SocketAddr;
7use std::sync::atomic::{AtomicU32, Ordering};
8use std::sync::Arc;
9use std::time::Instant;
10use tokio::sync::RwLock;
11
12/// A received SOAP request with metadata.
13#[derive(Debug, Clone)]
14pub struct ReceivedRequest {
15    /// The action name (e.g., "GetExternalIPAddress", "AddPortMapping").
16    pub action_name: String,
17    /// The service type from the SOAPAction header.
18    pub service_type: String,
19    /// The parsed request body.
20    pub body: crate::matcher::SoapRequestBody,
21    /// When the request was received (relative to server start).
22    pub timestamp: std::time::Duration,
23}
24
25impl ReceivedRequest {
26    pub(crate) fn from_soap_request(request: &SoapRequest, start_time: Instant) -> Self {
27        ReceivedRequest {
28            action_name: request.action_name.clone(),
29            service_type: request.service_type.clone(),
30            body: request.body.clone(),
31            timestamp: start_time.elapsed(),
32        }
33    }
34}
35
/// Snapshot of an SSDP discovery request (`M-SEARCH`) received by the mock server.
#[derive(Debug, Clone)]
pub struct ReceivedSsdpRequest {
    /// Address the request was sent from.
    pub source: SocketAddr,
    /// Value of the `ST` (search target) header.
    pub search_target: String,
    /// Value of the `MAN` header, e.g. `"ssdp:discover"`.
    pub man: String,
    /// Value of the `MX` header (maximum response delay, in seconds), if present.
    pub mx: Option<u32>,
    /// The request exactly as received.
    pub raw: String,
    /// Time elapsed between server start and receipt of this request.
    pub timestamp: std::time::Duration,
}
52
53/// A registered mock that matches requests and generates responses.
54pub(crate) struct Mock {
55    /// The action matcher.
56    action: Action,
57    /// The responder to use when matched.
58    responder: Responder,
59    /// Priority for matching (higher = checked first).
60    priority: u32,
61    /// Maximum number of times this mock can be matched (None = unlimited).
62    max_times: Option<u32>,
63    /// Number of times this mock has been matched.
64    match_count: AtomicU32,
65}
66
67impl Mock {
68    /// Create a new mock with the given action and responder.
69    pub fn new(action: impl Into<Action>, responder: impl Into<Responder>) -> Self {
70        Mock {
71            action: action.into(),
72            responder: responder.into(),
73            priority: 0,
74            max_times: None,
75            match_count: AtomicU32::new(0),
76        }
77    }
78
79    /// Set the priority of this mock (higher = checked first).
80    pub fn with_priority(mut self, priority: u32) -> Self {
81        self.priority = priority;
82        self
83    }
84
85    /// Limit the number of times this mock can be matched.
86    pub fn times(mut self, n: u32) -> Self {
87        self.max_times = Some(n);
88        self
89    }
90
91    /// Check if this mock matches the given request.
92    pub fn matches(&self, request: &SoapRequest) -> bool {
93        // Check if we've exceeded max_times
94        if let Some(max) = self.max_times {
95            if self.match_count.load(Ordering::SeqCst) >= max {
96                return false;
97            }
98        }
99        self.action.matches(request)
100    }
101
102    /// Generate a response for the given request and increment match count.
103    pub fn respond(&self, request: &SoapRequest) -> ResponseBody {
104        self.match_count.fetch_add(1, Ordering::SeqCst);
105        self.responder.respond(request)
106    }
107
108    /// Get the priority of this mock.
109    pub fn priority(&self) -> u32 {
110        self.priority
111    }
112}
113
114impl std::fmt::Debug for Mock {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        f.debug_struct("Mock")
117            .field("action", &self.action)
118            .field("responder", &self.responder)
119            .field("priority", &self.priority)
120            .field("max_times", &self.max_times)
121            .field("match_count", &self.match_count.load(Ordering::SeqCst))
122            .finish()
123    }
124}
125
126/// Registry of mocks for matching requests.
127pub(crate) struct MockRegistry {
128    mocks: RwLock<Vec<Arc<Mock>>>,
129    received_requests: RwLock<Vec<ReceivedRequest>>,
130    received_ssdp_requests: RwLock<Vec<ReceivedSsdpRequest>>,
131    start_time: Instant,
132}
133
134impl MockRegistry {
135    /// Create a new empty registry.
136    pub fn new() -> Self {
137        MockRegistry {
138            mocks: RwLock::new(Vec::new()),
139            received_requests: RwLock::new(Vec::new()),
140            received_ssdp_requests: RwLock::new(Vec::new()),
141            start_time: Instant::now(),
142        }
143    }
144
145    /// Register a new mock.
146    pub async fn register(&self, mock: Mock) {
147        let mut mocks = self.mocks.write().await;
148        mocks.push(Arc::new(mock));
149        // Sort by priority (highest first)
150        mocks.sort_by(|a, b| b.priority().cmp(&a.priority()));
151    }
152
153    /// Find a mock that matches the given request and generate a response.
154    /// Also records the request.
155    pub async fn find_response(&self, request: &SoapRequest) -> Option<ResponseBody> {
156        // Record the request
157        {
158            let received = ReceivedRequest::from_soap_request(request, self.start_time);
159            let mut requests = self.received_requests.write().await;
160            requests.push(received);
161        }
162
163        let mocks = self.mocks.read().await;
164        for mock in mocks.iter() {
165            if mock.matches(request) {
166                return Some(mock.respond(request));
167            }
168        }
169        None
170    }
171
172    /// Get all received requests.
173    pub async fn received_requests(&self) -> Vec<ReceivedRequest> {
174        let requests = self.received_requests.read().await;
175        requests.clone()
176    }
177
178    /// Clear all registered mocks.
179    pub async fn clear(&self) {
180        let mut mocks = self.mocks.write().await;
181        mocks.clear();
182    }
183
184    /// Clear all received requests.
185    pub async fn clear_received_requests(&self) {
186        let mut requests = self.received_requests.write().await;
187        requests.clear();
188    }
189
190    /// Record a received SSDP request.
191    pub async fn record_ssdp_request(&self, request: ReceivedSsdpRequest) {
192        let mut requests = self.received_ssdp_requests.write().await;
193        requests.push(request);
194    }
195
196    /// Get all received SSDP requests.
197    pub async fn received_ssdp_requests(&self) -> Vec<ReceivedSsdpRequest> {
198        let requests = self.received_ssdp_requests.read().await;
199        requests.clone()
200    }
201
202    /// Clear all received SSDP requests.
203    pub async fn clear_received_ssdp_requests(&self) {
204        let mut requests = self.received_ssdp_requests.write().await;
205        requests.clear();
206    }
207
208    /// Get the start time of the registry.
209    pub fn start_time(&self) -> Instant {
210        self.start_time
211    }
212}