use axum::{
body::Body,
http::{Request, Response, StatusCode},
middleware::Next,
};
use mockforge_core::behavioral_cloning::ProbabilisticModel;
use mockforge_recorder::database::RecorderDatabase;
use rand::Rng;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, trace};
/// State consulted by `behavioral_cloning_middleware`, which looks it up in the
/// request extensions.
///
/// The type is `Clone`; because `model_cache` is held behind an `Arc`, every
/// clone shares the same underlying cache map.
#[derive(Clone)]
pub struct BehavioralCloningMiddlewareState {
/// Path of the recorder database to load models from. When `None`, the
/// database is looked up as `recordings.db` in the current working directory.
pub database_path: Option<PathBuf>,
/// When `false`, the middleware forwards requests without touching them.
pub enabled: bool,
/// Probability models that have already been loaded, keyed by
/// `"{method}:{endpoint}"`. Guarded by an async read/write lock.
pub model_cache: Arc<
tokio::sync::RwLock<
HashMap<String, mockforge_core::behavioral_cloning::EndpointProbabilityModel>,
>,
>,
}
impl BehavioralCloningMiddlewareState {
pub fn new() -> Self {
Self {
database_path: None,
enabled: true,
model_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
}
}
pub fn with_database_path(path: PathBuf) -> Self {
Self {
database_path: Some(path),
enabled: true,
model_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
}
}
async fn open_database(&self) -> Option<RecorderDatabase> {
let db_path = self.database_path.as_ref().cloned().unwrap_or_else(|| {
std::env::current_dir()
.unwrap_or_else(|_| PathBuf::from("."))
.join("recordings.db")
});
RecorderDatabase::new(&db_path).await.ok()
}
async fn get_probability_model(
&self,
endpoint: &str,
method: &str,
) -> Option<mockforge_core::behavioral_cloning::EndpointProbabilityModel> {
let cache_key = format!("{}:{}", method, endpoint);
{
let cache = self.model_cache.read().await;
if let Some(model) = cache.get(&cache_key) {
return Some(model.clone());
}
}
if let Some(db) = self.open_database().await {
if let Ok(Some(model)) = db.get_endpoint_probability_model(endpoint, method).await {
let mut cache = self.model_cache.write().await;
cache.insert(cache_key, model.clone());
return Some(model);
}
}
None
}
}
impl Default for BehavioralCloningMiddlewareState {
fn default() -> Self {
Self::new()
}
}
pub async fn behavioral_cloning_middleware(req: Request<Body>, next: Next) -> Response<Body> {
let state = req.extensions().get::<BehavioralCloningMiddlewareState>().cloned();
let state = match state {
Some(s) if s.enabled => s,
_ => return next.run(req).await,
};
let method = req.method().to_string();
let path = req.uri().path().to_string();
let model = state.get_probability_model(&path, &method).await;
if let Some(model) = model {
debug!("Applying behavioral cloning to {} {}", method, path);
let sampled_status = ProbabilisticModel::sample_status_code(&model);
let sampled_latency = ProbabilisticModel::sample_latency(&model);
if sampled_latency > 0 {
trace!("Applying latency delay: {}ms", sampled_latency);
sleep(Duration::from_millis(sampled_latency)).await;
}
let error_pattern = ProbabilisticModel::sample_error_pattern(&model, None);
let mut response = next.run(req).await;
if let Some(pattern) = &error_pattern {
debug!(
"Applying error pattern: {} (probability: {})",
pattern.error_type, pattern.probability
);
if let Some(pattern_status) = pattern.status_code {
*response.status_mut() = StatusCode::from_u16(pattern_status)
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
} else if sampled_status != response.status().as_u16() {
*response.status_mut() = StatusCode::from_u16(sampled_status)
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
}
if !pattern.sample_responses.is_empty() {
use axum::body::Body;
let sample_idx = if pattern.sample_responses.len() > 1 {
rand::rng().random_range(0..pattern.sample_responses.len())
} else {
0
};
if let Some(sample_body) = pattern.sample_responses.get(sample_idx) {
if let Ok(json_string) = serde_json::to_string(sample_body) {
*response.body_mut() = Body::from(json_string);
response.headers_mut().insert(
axum::http::header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("application/json"),
);
debug!("Applied error pattern body from sample response");
}
}
}
} else {
if sampled_status != response.status().as_u16() {
*response.status_mut() = StatusCode::from_u16(sampled_status)
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
}
}
response
} else {
next.run(req).await
}
}