mockforge_sdk/
server.rs

1//! Mock server implementation
2
3use crate::builder::MockServerBuilder;
4use crate::stub::ResponseStub;
5use crate::{Error, Result};
6use axum::Router;
7use mockforge_core::config::{RouteConfig, RouteResponseConfig};
8use mockforge_core::{Config, ServerConfig};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::net::SocketAddr;
12use tokio::task::JoinHandle;
13
14/// A mock server that can be embedded in tests
15#[derive(Debug)]
16pub struct MockServer {
17    port: u16,
18    address: SocketAddr,
19    config: ServerConfig,
20    server_handle: Option<JoinHandle<()>>,
21    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
22    routes: Vec<RouteConfig>,
23}
24
25impl MockServer {
26    /// Create a new mock server builder
27    pub fn new() -> MockServerBuilder {
28        MockServerBuilder::new()
29    }
30
31    /// Create a mock server from configuration
32    pub(crate) async fn from_config(
33        server_config: ServerConfig,
34        _core_config: Config,
35    ) -> Result<Self> {
36        let port = server_config.http.port;
37        let host = server_config.http.host.clone();
38
39        let address: SocketAddr = format!("{}:{}", host, port)
40            .parse()
41            .map_err(|e| Error::InvalidConfig(format!("Invalid address: {}", e)))?;
42
43        Ok(Self {
44            port,
45            address,
46            config: server_config,
47            server_handle: None,
48            shutdown_tx: None,
49            routes: Vec::new(),
50        })
51    }
52
53    /// Start the mock server
54    pub async fn start(&mut self) -> Result<()> {
55        if self.server_handle.is_some() {
56            return Err(Error::ServerAlreadyStarted(self.port));
57        }
58
59        // Build the router from routes
60        let router = self.build_simple_router();
61
62        // Create shutdown channel
63        let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
64        self.shutdown_tx = Some(shutdown_tx);
65
66        let address = self.address;
67
68        // Spawn the server
69        let server_handle = tokio::spawn(async move {
70            let listener = match tokio::net::TcpListener::bind(address).await {
71                Ok(l) => l,
72                Err(e) => {
73                    tracing::error!("Failed to bind to {}: {}", address, e);
74                    return;
75                }
76            };
77
78            tracing::info!("MockForge SDK server listening on {}", address);
79
80            axum::serve(listener, router)
81                .with_graceful_shutdown(async move {
82                    let _ = shutdown_rx.await;
83                })
84                .await
85                .expect("Server error");
86        });
87
88        self.server_handle = Some(server_handle);
89
90        // Wait for the server to be ready by polling health
91        self.wait_for_ready().await?;
92
93        Ok(())
94    }
95
96    /// Wait for the server to be ready
97    async fn wait_for_ready(&self) -> Result<()> {
98        let max_attempts = 50;
99        let delay = tokio::time::Duration::from_millis(100);
100
101        for attempt in 0..max_attempts {
102            // Try to connect to the server
103            let client = reqwest::Client::builder()
104                .timeout(tokio::time::Duration::from_millis(100))
105                .build()
106                .map_err(|e| Error::General(format!("Failed to create HTTP client: {}", e)))?;
107
108            match client.get(format!("{}/health", self.url())).send().await {
109                Ok(response) if response.status().is_success() => return Ok(()),
110                _ => {
111                    if attempt < max_attempts - 1 {
112                        tokio::time::sleep(delay).await;
113                    }
114                }
115            }
116        }
117
118        Err(Error::General(format!(
119            "Server failed to become ready within {}ms",
120            max_attempts * delay.as_millis() as u32
121        )))
122    }
123
124    /// Build a simple router from stored routes
125    fn build_simple_router(&self) -> Router {
126        use axum::http::StatusCode;
127        use axum::routing::{delete, get, post, put};
128        use axum::{response::IntoResponse, Json};
129
130        let mut router = Router::new();
131
132        for route_config in &self.routes {
133            let status = route_config.response.status;
134            let body = route_config.response.body.clone();
135            let headers = route_config.response.headers.clone();
136
137            let handler = move || {
138                let body = body.clone();
139                let headers = headers.clone();
140                async move {
141                    let mut response = Json(body).into_response();
142                    *response.status_mut() = StatusCode::from_u16(status).unwrap();
143
144                    for (key, value) in headers {
145                        if let Ok(header_name) = axum::http::HeaderName::from_bytes(key.as_bytes())
146                        {
147                            if let Ok(header_value) = axum::http::HeaderValue::from_str(&value) {
148                                response.headers_mut().insert(header_name, header_value);
149                            }
150                        }
151                    }
152
153                    response
154                }
155            };
156
157            let path = &route_config.path;
158
159            router = match route_config.method.to_uppercase().as_str() {
160                "GET" => router.route(path, get(handler)),
161                "POST" => router.route(path, post(handler)),
162                "PUT" => router.route(path, put(handler)),
163                "DELETE" => router.route(path, delete(handler)),
164                _ => router,
165            };
166        }
167
168        router
169    }
170
171    /// Stop the mock server
172    pub async fn stop(mut self) -> Result<()> {
173        if let Some(shutdown_tx) = self.shutdown_tx.take() {
174            let _ = shutdown_tx.send(());
175        }
176
177        if let Some(handle) = self.server_handle.take() {
178            let _ = handle.await;
179        }
180
181        Ok(())
182    }
183
184    /// Stub a response for a given method and path
185    pub async fn stub_response(
186        &mut self,
187        method: impl Into<String>,
188        path: impl Into<String>,
189        body: Value,
190    ) -> Result<()> {
191        let stub = ResponseStub::new(method, path, body);
192        self.add_stub(stub).await
193    }
194
195    /// Add a response stub
196    pub async fn add_stub(&mut self, stub: ResponseStub) -> Result<()> {
197        let route_config = RouteConfig {
198            path: stub.path.clone(),
199            method: stub.method,
200            request: None,
201            response: RouteResponseConfig {
202                status: stub.status,
203                headers: stub.headers,
204                body: Some(stub.body),
205            },
206        };
207
208        self.routes.push(route_config);
209
210        Ok(())
211    }
212
213    /// Remove all stubs
214    pub async fn clear_stubs(&mut self) -> Result<()> {
215        self.routes.clear();
216        Ok(())
217    }
218
219    /// Get the server port
220    pub fn port(&self) -> u16 {
221        self.port
222    }
223
224    /// Get the server base URL
225    pub fn url(&self) -> String {
226        format!("http://{}", self.address)
227    }
228
229    /// Check if the server is running
230    pub fn is_running(&self) -> bool {
231        self.server_handle.is_some()
232    }
233}
234
235impl Default for MockServer {
236    fn default() -> Self {
237        Self {
238            port: 0,
239            address: "127.0.0.1:0".parse().unwrap(),
240            config: ServerConfig::default(),
241            server_handle: None,
242            shutdown_tx: None,
243            routes: Vec::new(),
244        }
245    }
246}
247
248// Implement Drop to ensure server is stopped
249impl Drop for MockServer {
250    fn drop(&mut self) {
251        if let Some(shutdown_tx) = self.shutdown_tx.take() {
252            let _ = shutdown_tx.send(());
253        }
254    }
255}