mockforge_http/middleware/
behavioral_cloning.rs

1//! Behavioral cloning middleware
2//!
3//! This middleware applies learned behavioral patterns to requests,
4//! including probabilistic status codes, latency, and error patterns.
5
6use axum::{
7    body::Body,
8    extract::{Path, State},
9    http::{Request, Response, StatusCode},
10    middleware::Next,
11};
12use mockforge_core::behavioral_cloning::{ProbabilisticModel, SequenceLearner};
13use mockforge_recorder::database::RecorderDatabase;
14use std::collections::HashMap;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::time::sleep;
19use tracing::{debug, trace};
20
21/// Behavioral cloning middleware state
22#[derive(Clone)]
23pub struct BehavioralCloningMiddlewareState {
24    /// Optional recorder database path
25    pub database_path: Option<PathBuf>,
26    /// Whether behavioral cloning is enabled
27    pub enabled: bool,
28    /// Cache for loaded probability models (to avoid repeated DB queries)
29    pub model_cache: Arc<
30        tokio::sync::RwLock<
31            HashMap<String, mockforge_core::behavioral_cloning::EndpointProbabilityModel>,
32        >,
33    >,
34}
35
36impl BehavioralCloningMiddlewareState {
37    /// Create new middleware state
38    pub fn new() -> Self {
39        Self {
40            database_path: None,
41            enabled: true,
42            model_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
43        }
44    }
45
46    /// Create state with database path
47    pub fn with_database_path(path: PathBuf) -> Self {
48        Self {
49            database_path: Some(path),
50            enabled: true,
51            model_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
52        }
53    }
54
55    /// Open database connection
56    async fn open_database(&self) -> Option<RecorderDatabase> {
57        let db_path = self.database_path.as_ref().cloned().unwrap_or_else(|| {
58            std::env::current_dir()
59                .unwrap_or_else(|_| PathBuf::from("."))
60                .join("recordings.db")
61        });
62
63        RecorderDatabase::new(&db_path).await.ok()
64    }
65
66    /// Get probability model for endpoint (with caching)
67    async fn get_probability_model(
68        &self,
69        endpoint: &str,
70        method: &str,
71    ) -> Option<mockforge_core::behavioral_cloning::EndpointProbabilityModel> {
72        let cache_key = format!("{}:{}", method, endpoint);
73
74        // Check cache first
75        {
76            let cache = self.model_cache.read().await;
77            if let Some(model) = cache.get(&cache_key) {
78                return Some(model.clone());
79            }
80        }
81
82        // Load from database
83        if let Some(db) = self.open_database().await {
84            if let Ok(Some(model)) = db.get_endpoint_probability_model(endpoint, method).await {
85                // Store in cache
86                let mut cache = self.model_cache.write().await;
87                cache.insert(cache_key, model.clone());
88                return Some(model);
89            }
90        }
91
92        None
93    }
94}
95
96impl Default for BehavioralCloningMiddlewareState {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102/// Behavioral cloning middleware
103///
104/// Applies learned behavioral patterns to requests:
105/// - Samples status codes from probability models
106/// - Applies latency based on learned distributions
107/// - Injects error patterns based on learned probabilities
108pub async fn behavioral_cloning_middleware(req: Request<Body>, next: Next) -> Response<Body> {
109    // Extract state from extensions (set by router)
110    let state = req.extensions().get::<BehavioralCloningMiddlewareState>().cloned();
111
112    // If no state or disabled, pass through
113    let state = match state {
114        Some(s) if s.enabled => s,
115        _ => return next.run(req).await,
116    };
117
118    // Extract endpoint and method
119    let method = req.method().to_string();
120    let path = req.uri().path().to_string();
121
122    // Get probability model for this endpoint
123    let model = state.get_probability_model(&path, &method).await;
124
125    if let Some(model) = model {
126        debug!("Applying behavioral cloning to {} {}", method, path);
127
128        // Sample status code
129        let sampled_status = ProbabilisticModel::sample_status_code(&model);
130
131        // Sample latency
132        let sampled_latency = ProbabilisticModel::sample_latency(&model);
133
134        // Apply latency delay
135        if sampled_latency > 0 {
136            trace!("Applying latency delay: {}ms", sampled_latency);
137            sleep(Duration::from_millis(sampled_latency)).await;
138        }
139
140        // Sample error pattern
141        let error_pattern = ProbabilisticModel::sample_error_pattern(&model, None);
142
143        // If we sampled an error status code, we need to modify the response
144        // However, we can't easily modify the response status in middleware
145        // without intercepting it. For now, we'll just apply latency.
146        // Full error injection would require response interception middleware.
147
148        if let Some(pattern) = error_pattern {
149            debug!(
150                "Sampled error pattern: {} (probability: {})",
151                pattern.error_type, pattern.probability
152            );
153            // TODO: Apply error pattern to response
154            // This would require response interception middleware
155        }
156
157        // Continue with request (status code modification would need response interception)
158        let mut response = next.run(req).await;
159
160        // Modify response status if we sampled a different status code
161        if sampled_status != response.status().as_u16() {
162            *response.status_mut() =
163                StatusCode::from_u16(sampled_status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
164        }
165
166        response
167    } else {
168        // No model found, pass through
169        next.run(req).await
170    }
171}