1use crate::{
5    Error, FailureInjector, ProxyHandler, RecordReplayHandler, RequestFingerprint,
6    ResponsePriority, ResponseSource, Result,
7};
8use axum::http::{HeaderMap, Method, StatusCode, Uri};
9use std::collections::HashMap;
10
11pub struct PriorityHttpHandler {
13    record_replay: RecordReplayHandler,
15    failure_injector: Option<FailureInjector>,
17    proxy_handler: Option<ProxyHandler>,
19    mock_generator: Option<Box<dyn MockGenerator + Send + Sync>>,
21    openapi_spec: Option<crate::openapi::spec::OpenApiSpec>,
23}
24
25pub trait MockGenerator {
27    fn generate_mock_response(
29        &self,
30        fingerprint: &RequestFingerprint,
31        headers: &HeaderMap,
32        body: Option<&[u8]>,
33    ) -> Result<Option<MockResponse>>;
34}
35
/// A synthetic response produced by a [`MockGenerator`].
///
/// `PartialEq`/`Eq` are derived (all fields support them) so generated responses
/// can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MockResponse {
    /// HTTP status code to return.
    pub status_code: u16,
    /// Extra response headers, keyed by header name.
    pub headers: HashMap<String, String>,
    /// Response body text.
    pub body: String,
    /// Content type of `body`.
    pub content_type: String,
}
48
49impl PriorityHttpHandler {
50    pub fn new(
52        record_replay: RecordReplayHandler,
53        failure_injector: Option<FailureInjector>,
54        proxy_handler: Option<ProxyHandler>,
55        mock_generator: Option<Box<dyn MockGenerator + Send + Sync>>,
56    ) -> Self {
57        Self {
58            record_replay,
59            failure_injector,
60            proxy_handler,
61            mock_generator,
62            openapi_spec: None,
63        }
64    }
65
66    pub fn new_with_openapi(
68        record_replay: RecordReplayHandler,
69        failure_injector: Option<FailureInjector>,
70        proxy_handler: Option<ProxyHandler>,
71        mock_generator: Option<Box<dyn MockGenerator + Send + Sync>>,
72        openapi_spec: Option<crate::openapi::spec::OpenApiSpec>,
73    ) -> Self {
74        Self {
75            record_replay,
76            failure_injector,
77            proxy_handler,
78            mock_generator,
79            openapi_spec,
80        }
81    }
82
83    pub async fn process_request(
85        &self,
86        method: &Method,
87        uri: &Uri,
88        headers: &HeaderMap,
89        body: Option<&[u8]>,
90    ) -> Result<PriorityResponse> {
91        let fingerprint = RequestFingerprint::new(method.clone(), uri, headers, body);
92
93        if let Some(recorded_request) =
95            self.record_replay.replay_handler().load_fixture(&fingerprint).await?
96        {
97            let content_type = recorded_request
98                .response_headers
99                .get("content-type")
100                .unwrap_or(&"application/json".to_string())
101                .clone();
102
103            return Ok(PriorityResponse {
104                source: ResponseSource::new(ResponsePriority::Replay, "replay".to_string())
105                    .with_metadata("fixture_path".to_string(), "recorded".to_string()),
106                status_code: recorded_request.status_code,
107                headers: recorded_request.response_headers,
108                body: recorded_request.response_body.into_bytes(),
109                content_type,
110            });
111        }
112
113        if let Some(ref failure_injector) = self.failure_injector {
115            let tags = if let Some(ref spec) = self.openapi_spec {
116                fingerprint.openapi_tags(spec).unwrap_or_else(|| fingerprint.tags())
117            } else {
118                fingerprint.tags()
119            };
120            if let Some((status_code, error_message)) = failure_injector.process_request(&tags) {
121                let error_response = serde_json::json!({
122                    "error": error_message,
123                    "injected_failure": true,
124                    "timestamp": chrono::Utc::now().to_rfc3339()
125                });
126
127                return Ok(PriorityResponse {
128                    source: ResponseSource::new(
129                        ResponsePriority::Fail,
130                        "failure_injection".to_string(),
131                    )
132                    .with_metadata("error_message".to_string(), error_message),
133                    status_code,
134                    headers: HashMap::new(),
135                    body: serde_json::to_string(&error_response)?.into_bytes(),
136                    content_type: "application/json".to_string(),
137                });
138            }
139        }
140
141        if let Some(ref proxy_handler) = self.proxy_handler {
143            if proxy_handler.config.should_proxy(method, uri.path()) {
144                match proxy_handler.proxy_request(method, uri, headers, body).await {
145                    Ok(proxy_response) => {
146                        let mut response_headers = HashMap::new();
147                        for (key, value) in proxy_response.headers.iter() {
148                            let key_str = key.as_str();
149                            if let Ok(value_str) = value.to_str() {
150                                response_headers.insert(key_str.to_string(), value_str.to_string());
151                            }
152                        }
153
154                        let content_type = response_headers
155                            .get("content-type")
156                            .unwrap_or(&"application/json".to_string())
157                            .clone();
158
159                        return Ok(PriorityResponse {
160                            source: ResponseSource::new(
161                                ResponsePriority::Proxy,
162                                "proxy".to_string(),
163                            )
164                            .with_metadata(
165                                "upstream_url".to_string(),
166                                proxy_handler.config.get_upstream_url(uri.path()),
167                            ),
168                            status_code: proxy_response.status_code,
169                            headers: response_headers,
170                            body: proxy_response.body.unwrap_or_default(),
171                            content_type,
172                        });
173                    }
174                    Err(e) => {
175                        tracing::warn!("Proxy request failed: {}", e);
176                        }
178                }
179            }
180        }
181
182        if let Some(ref mock_generator) = self.mock_generator {
184            if let Some(mock_response) =
185                mock_generator.generate_mock_response(&fingerprint, headers, body)?
186            {
187                return Ok(PriorityResponse {
188                    source: ResponseSource::new(ResponsePriority::Mock, "mock".to_string())
189                        .with_metadata("generated_from".to_string(), "openapi_spec".to_string()),
190                    status_code: mock_response.status_code,
191                    headers: mock_response.headers,
192                    body: mock_response.body.into_bytes(),
193                    content_type: mock_response.content_type,
194                });
195            }
196        }
197
198        if self.record_replay.record_handler().should_record(method) {
200            let default_response = serde_json::json!({
202                "message": "Request recorded for future replay",
203                "timestamp": chrono::Utc::now().to_rfc3339(),
204                "fingerprint": fingerprint.to_hash()
205            });
206
207            let response_body = serde_json::to_string(&default_response)?;
208            let status_code = 200;
209
210            self.record_replay
212                .record_handler()
213                .record_request(&fingerprint, status_code, headers, &response_body, None)
214                .await?;
215
216            return Ok(PriorityResponse {
217                source: ResponseSource::new(ResponsePriority::Record, "record".to_string())
218                    .with_metadata("recorded".to_string(), "true".to_string()),
219                status_code,
220                headers: HashMap::new(),
221                body: response_body.into_bytes(),
222                content_type: "application/json".to_string(),
223            });
224        }
225
226        Err(Error::generic("No handler could process the request".to_string()))
228    }
229}
230
231#[derive(Debug, Clone)]
233pub struct PriorityResponse {
234    pub source: ResponseSource,
236    pub status_code: u16,
238    pub headers: HashMap<String, String>,
240    pub body: Vec<u8>,
242    pub content_type: String,
244}
245
246impl PriorityResponse {
247    pub fn to_axum_response(self) -> axum::response::Response {
249        let mut response = axum::response::Response::new(axum::body::Body::from(self.body));
250        *response.status_mut() = StatusCode::from_u16(self.status_code).unwrap_or(StatusCode::OK);
251
252        for (key, value) in self.headers {
254            if let (Ok(header_name), Ok(header_value)) =
255                (key.parse::<axum::http::HeaderName>(), value.parse::<axum::http::HeaderValue>())
256            {
257                response.headers_mut().insert(header_name, header_value);
258            }
259        }
260
261        if !response.headers().contains_key("content-type") {
263            if let Ok(header_value) = self.content_type.parse::<axum::http::HeaderValue>() {
264                response.headers_mut().insert("content-type", header_value);
265            }
266        }
267
268        response
269    }
270}
271
/// A [`MockGenerator`] that answers every request with the same status code and body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimpleMockGenerator {
    /// Status code of every generated response.
    pub default_status: u16,
    /// Body of every generated response.
    pub default_body: String,
}

impl SimpleMockGenerator {
    /// Creates a generator that always responds with `default_status` and `default_body`.
    pub fn new(default_status: u16, default_body: String) -> Self {
        Self {
            default_status,
            default_body,
        }
    }
}
289
290impl MockGenerator for SimpleMockGenerator {
291    fn generate_mock_response(
292        &self,
293        _fingerprint: &RequestFingerprint,
294        _headers: &HeaderMap,
295        _body: Option<&[u8]>,
296    ) -> Result<Option<MockResponse>> {
297        Ok(Some(MockResponse {
298            status_code: self.default_status,
299            headers: HashMap::new(),
300            body: self.default_body.clone(),
301            content_type: "application/json".to_string(),
302        }))
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// With an empty fixtures directory and no failure injector or proxy, the
    /// mock generator is the stage that answers the request.
    #[tokio::test]
    async fn test_priority_chain() {
        let fixtures = TempDir::new().unwrap();
        let record_replay =
            RecordReplayHandler::new(fixtures.path().to_path_buf(), true, true, false);

        let mock_body = r#"{"message": "mock response"}"#.to_string();
        let mock_generator = Box::new(SimpleMockGenerator::new(200, mock_body));

        let handler = PriorityHttpHandler::new_with_openapi(
            record_replay,
            None,
            None,
            Some(mock_generator),
            None,
        );

        let request_method = Method::GET;
        let request_uri = Uri::from_static("/api/test");
        let request_headers = HeaderMap::new();

        let response = handler
            .process_request(&request_method, &request_uri, &request_headers, None)
            .await
            .unwrap();

        assert_eq!(response.status_code, 200);
        assert_eq!(response.source.source_type, "mock");
    }
}