mockforge_http/middleware/
behavioral_cloning.rs1use axum::{
7 body::Body,
8 extract::{Path, State},
9 http::{Request, Response, StatusCode},
10 middleware::Next,
11};
12use mockforge_core::behavioral_cloning::{ProbabilisticModel, SequenceLearner};
13use mockforge_recorder::database::RecorderDatabase;
14use std::collections::HashMap;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::time::sleep;
19use tracing::{debug, trace};
20
21#[derive(Clone)]
23pub struct BehavioralCloningMiddlewareState {
24 pub database_path: Option<PathBuf>,
26 pub enabled: bool,
28 pub model_cache: Arc<
30 tokio::sync::RwLock<
31 HashMap<String, mockforge_core::behavioral_cloning::EndpointProbabilityModel>,
32 >,
33 >,
34}
35
36impl BehavioralCloningMiddlewareState {
37 pub fn new() -> Self {
39 Self {
40 database_path: None,
41 enabled: true,
42 model_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
43 }
44 }
45
46 pub fn with_database_path(path: PathBuf) -> Self {
48 Self {
49 database_path: Some(path),
50 enabled: true,
51 model_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
52 }
53 }
54
55 async fn open_database(&self) -> Option<RecorderDatabase> {
57 let db_path = self.database_path.as_ref().cloned().unwrap_or_else(|| {
58 std::env::current_dir()
59 .unwrap_or_else(|_| PathBuf::from("."))
60 .join("recordings.db")
61 });
62
63 RecorderDatabase::new(&db_path).await.ok()
64 }
65
66 async fn get_probability_model(
68 &self,
69 endpoint: &str,
70 method: &str,
71 ) -> Option<mockforge_core::behavioral_cloning::EndpointProbabilityModel> {
72 let cache_key = format!("{}:{}", method, endpoint);
73
74 {
76 let cache = self.model_cache.read().await;
77 if let Some(model) = cache.get(&cache_key) {
78 return Some(model.clone());
79 }
80 }
81
82 if let Some(db) = self.open_database().await {
84 if let Ok(Some(model)) = db.get_endpoint_probability_model(endpoint, method).await {
85 let mut cache = self.model_cache.write().await;
87 cache.insert(cache_key, model.clone());
88 return Some(model);
89 }
90 }
91
92 None
93 }
94}
95
96impl Default for BehavioralCloningMiddlewareState {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102pub async fn behavioral_cloning_middleware(req: Request<Body>, next: Next) -> Response<Body> {
109 let state = req.extensions().get::<BehavioralCloningMiddlewareState>().cloned();
111
112 let state = match state {
114 Some(s) if s.enabled => s,
115 _ => return next.run(req).await,
116 };
117
118 let method = req.method().to_string();
120 let path = req.uri().path().to_string();
121
122 let model = state.get_probability_model(&path, &method).await;
124
125 if let Some(model) = model {
126 debug!("Applying behavioral cloning to {} {}", method, path);
127
128 let sampled_status = ProbabilisticModel::sample_status_code(&model);
130
131 let sampled_latency = ProbabilisticModel::sample_latency(&model);
133
134 if sampled_latency > 0 {
136 trace!("Applying latency delay: {}ms", sampled_latency);
137 sleep(Duration::from_millis(sampled_latency)).await;
138 }
139
140 let error_pattern = ProbabilisticModel::sample_error_pattern(&model, None);
142
143 if let Some(pattern) = error_pattern {
149 debug!(
150 "Sampled error pattern: {} (probability: {})",
151 pattern.error_type, pattern.probability
152 );
153 }
156
157 let mut response = next.run(req).await;
159
160 if sampled_status != response.status().as_u16() {
162 *response.status_mut() =
163 StatusCode::from_u16(sampled_status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
164 }
165
166 response
167 } else {
168 next.run(req).await
170 }
171}