1use crate::builder::MockServerBuilder;
4use crate::stub::ResponseStub;
5use crate::{Error, Result};
6use axum::Router;
7use mockforge_core::config::{RouteConfig, RouteResponseConfig};
8use mockforge_core::{Config, ServerConfig};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::net::SocketAddr;
12use tokio::task::JoinHandle;
13
#[derive(Debug)]
/// An HTTP mock server that serves canned responses registered as stubs.
///
/// [`MockServer::new`] returns a [`MockServerBuilder`]. Responses are
/// registered with `stub_response` / `add_stub`, and the server is launched
/// with `start` and shut down with `stop` (or by dropping the value).
pub struct MockServer {
    /// TCP port of the listen address, as reported by `port()`.
    port: u16,
    /// Socket address the server binds to; `url()` is derived from it.
    address: SocketAddr,
    /// Server configuration this instance was created with.
    /// NOTE(review): not read by any code in this file — confirm it is still needed.
    config: ServerConfig,
    /// Handle of the spawned task running the axum server; `Some` once `start`
    /// has spawned that task.
    server_handle: Option<JoinHandle<()>>,
    /// Sender that triggers graceful shutdown of the server task; taken (and
    /// fired) by `stop` and by `Drop`.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
    /// Registered route stubs; `start` turns them into the axum router.
    routes: Vec<RouteConfig>,
}
24
25impl MockServer {
26 pub fn new() -> MockServerBuilder {
28 MockServerBuilder::new()
29 }
30
31 pub(crate) async fn from_config(
33 server_config: ServerConfig,
34 _core_config: Config,
35 ) -> Result<Self> {
36 let port = server_config.http.port;
37 let host = server_config.http.host.clone();
38
39 let address: SocketAddr = format!("{}:{}", host, port)
40 .parse()
41 .map_err(|e| Error::InvalidConfig(format!("Invalid address: {}", e)))?;
42
43 Ok(Self {
44 port,
45 address,
46 config: server_config,
47 server_handle: None,
48 shutdown_tx: None,
49 routes: Vec::new(),
50 })
51 }
52
53 pub async fn start(&mut self) -> Result<()> {
55 if self.server_handle.is_some() {
56 return Err(Error::ServerAlreadyStarted(self.port));
57 }
58
59 let router = self.build_simple_router();
61
62 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
64 self.shutdown_tx = Some(shutdown_tx);
65
66 let address = self.address;
67
68 let server_handle = tokio::spawn(async move {
70 let listener = match tokio::net::TcpListener::bind(address).await {
71 Ok(l) => l,
72 Err(e) => {
73 tracing::error!("Failed to bind to {}: {}", address, e);
74 return;
75 }
76 };
77
78 tracing::info!("MockForge SDK server listening on {}", address);
79
80 axum::serve(listener, router)
81 .with_graceful_shutdown(async move {
82 let _ = shutdown_rx.await;
83 })
84 .await
85 .expect("Server error");
86 });
87
88 self.server_handle = Some(server_handle);
89
90 self.wait_for_ready().await?;
92
93 Ok(())
94 }
95
96 async fn wait_for_ready(&self) -> Result<()> {
98 let max_attempts = 50;
99 let delay = tokio::time::Duration::from_millis(100);
100
101 for attempt in 0..max_attempts {
102 let client = reqwest::Client::builder()
104 .timeout(tokio::time::Duration::from_millis(100))
105 .build()
106 .map_err(|e| Error::General(format!("Failed to create HTTP client: {}", e)))?;
107
108 match client.get(format!("{}/health", self.url())).send().await {
109 Ok(response) if response.status().is_success() => return Ok(()),
110 _ => {
111 if attempt < max_attempts - 1 {
112 tokio::time::sleep(delay).await;
113 }
114 }
115 }
116 }
117
118 Err(Error::General(format!(
119 "Server failed to become ready within {}ms",
120 max_attempts * delay.as_millis() as u32
121 )))
122 }
123
124 fn build_simple_router(&self) -> Router {
126 use axum::http::StatusCode;
127 use axum::routing::{delete, get, post, put};
128 use axum::{response::IntoResponse, Json};
129
130 let mut router = Router::new();
131
132 for route_config in &self.routes {
133 let status = route_config.response.status;
134 let body = route_config.response.body.clone();
135 let headers = route_config.response.headers.clone();
136
137 let handler = move || {
138 let body = body.clone();
139 let headers = headers.clone();
140 async move {
141 let mut response = Json(body).into_response();
142 *response.status_mut() = StatusCode::from_u16(status).unwrap();
143
144 for (key, value) in headers {
145 if let Ok(header_name) = axum::http::HeaderName::from_bytes(key.as_bytes())
146 {
147 if let Ok(header_value) = axum::http::HeaderValue::from_str(&value) {
148 response.headers_mut().insert(header_name, header_value);
149 }
150 }
151 }
152
153 response
154 }
155 };
156
157 let path = &route_config.path;
158
159 router = match route_config.method.to_uppercase().as_str() {
160 "GET" => router.route(path, get(handler)),
161 "POST" => router.route(path, post(handler)),
162 "PUT" => router.route(path, put(handler)),
163 "DELETE" => router.route(path, delete(handler)),
164 _ => router,
165 };
166 }
167
168 router
169 }
170
171 pub async fn stop(mut self) -> Result<()> {
173 if let Some(shutdown_tx) = self.shutdown_tx.take() {
174 let _ = shutdown_tx.send(());
175 }
176
177 if let Some(handle) = self.server_handle.take() {
178 let _ = handle.await;
179 }
180
181 Ok(())
182 }
183
184 pub async fn stub_response(
186 &mut self,
187 method: impl Into<String>,
188 path: impl Into<String>,
189 body: Value,
190 ) -> Result<()> {
191 let stub = ResponseStub::new(method, path, body);
192 self.add_stub(stub).await
193 }
194
195 pub async fn add_stub(&mut self, stub: ResponseStub) -> Result<()> {
197 let route_config = RouteConfig {
198 path: stub.path.clone(),
199 method: stub.method,
200 request: None,
201 response: RouteResponseConfig {
202 status: stub.status,
203 headers: stub.headers,
204 body: Some(stub.body),
205 },
206 };
207
208 self.routes.push(route_config);
209
210 Ok(())
211 }
212
213 pub async fn clear_stubs(&mut self) -> Result<()> {
215 self.routes.clear();
216 Ok(())
217 }
218
219 pub fn port(&self) -> u16 {
221 self.port
222 }
223
224 pub fn url(&self) -> String {
226 format!("http://{}", self.address)
227 }
228
229 pub fn is_running(&self) -> bool {
231 self.server_handle.is_some()
232 }
233}
234
235impl Default for MockServer {
236 fn default() -> Self {
237 Self {
238 port: 0,
239 address: "127.0.0.1:0".parse().unwrap(),
240 config: ServerConfig::default(),
241 server_handle: None,
242 shutdown_tx: None,
243 routes: Vec::new(),
244 }
245 }
246}
247
248impl Drop for MockServer {
250 fn drop(&mut self) {
251 if let Some(shutdown_tx) = self.shutdown_tx.take() {
252 let _ = shutdown_tx.send(());
253 }
254 }
255}