1use axum::extract::{Path, Query, State};
10use axum::response::Json;
11use mockforge_core::behavioral_cloning::types::BehavioralSequence;
12use mockforge_core::behavioral_cloning::{
13 EdgeAmplificationConfig, EdgeAmplifier, EndpointProbabilityModel, ProbabilisticModel,
14 SequenceLearner,
15};
16use mockforge_recorder::database::RecorderDatabase;
17use serde::{Deserialize, Serialize};
18use serde_json::{json, Value};
19use std::collections::HashMap;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23#[derive(Clone)]
25pub struct BehavioralCloningState {
26 pub edge_amplifier: Arc<EdgeAmplifier>,
28 pub database_path: Option<PathBuf>,
31}
32
33impl BehavioralCloningState {
34 pub fn new() -> Self {
36 Self {
37 edge_amplifier: Arc::new(EdgeAmplifier::new()),
38 database_path: None,
39 }
40 }
41
42 pub fn with_database_path(path: PathBuf) -> Self {
44 Self {
45 edge_amplifier: Arc::new(EdgeAmplifier::new()),
46 database_path: Some(path),
47 }
48 }
49
50 async fn open_database(&self) -> Result<RecorderDatabase, String> {
52 let db_path = self.database_path.as_ref().cloned().unwrap_or_else(|| {
53 std::env::current_dir()
54 .unwrap_or_else(|_| PathBuf::from("."))
55 .join("recordings.db")
56 });
57
58 RecorderDatabase::new(&db_path)
59 .await
60 .map_err(|e| format!("Failed to open recorder database: {}", e))
61 }
62}
63
64impl Default for BehavioralCloningState {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70#[derive(Debug, Deserialize)]
72pub struct BuildProbabilityModelRequest {
73 pub endpoint: String,
75 pub method: String,
77 #[serde(default)]
79 pub sample_limit: Option<u32>,
80}
81
82#[derive(Debug, Deserialize)]
84pub struct DiscoverSequencesRequest {
85 #[serde(default)]
87 pub min_requests_per_trace: Option<i32>,
88 #[serde(default = "default_min_frequency")]
90 pub min_frequency: f64,
91}
92
/// Serde default for `DiscoverSequencesRequest::min_frequency`.
fn default_min_frequency() -> f64 {
    const DEFAULT_MIN_FREQUENCY: f64 = 0.1;
    DEFAULT_MIN_FREQUENCY
}
96
97#[derive(Debug, Deserialize)]
99pub struct ApplyAmplificationRequest {
100 pub config: EdgeAmplificationConfig,
102 #[serde(default)]
104 pub endpoint: Option<String>,
105 #[serde(default)]
107 pub method: Option<String>,
108}
109
110#[derive(Debug, Serialize)]
112pub struct ProbabilityModelResponse {
113 pub success: bool,
115 pub model: EndpointProbabilityModel,
117}
118
119#[derive(Debug, Serialize)]
121pub struct SequenceDiscoveryResponse {
122 pub success: bool,
124 pub count: usize,
126 pub sequences: Vec<BehavioralSequence>,
128}
129
130pub async fn build_probability_model(
134 State(state): State<BehavioralCloningState>,
135 Json(request): Json<BuildProbabilityModelRequest>,
136) -> Result<Json<Value>, String> {
137 let db = state.open_database().await?;
139
140 let limit = request.sample_limit.map(|l| l as i32);
142 let exchanges = db
143 .get_exchanges_for_endpoint(&request.endpoint, &request.method, limit)
144 .await
145 .map_err(|e| format!("Failed to query exchanges: {}", e))?;
146
147 if exchanges.is_empty() {
148 return Err(format!(
149 "No recorded traffic found for {} {}",
150 request.method, request.endpoint
151 ));
152 }
153
154 let mut status_codes = Vec::new();
156 let mut latencies_ms = Vec::new();
157 let mut error_responses = Vec::new();
158
159 for (req, resp_opt) in &exchanges {
160 let status_code = if let Some(resp) = resp_opt {
162 resp.status_code as u16
163 } else if let Some(code) = req.status_code {
164 code as u16
165 } else {
166 continue; };
168
169 status_codes.push(status_code);
170
171 if let Some(duration) = req.duration_ms {
173 latencies_ms.push(duration as u64);
174 }
175
176 if status_code >= 400 {
178 if let Some(resp) = resp_opt {
179 if let Some(ref body) = resp.body {
180 if let Ok(json_body) = serde_json::from_str::<Value>(body) {
182 error_responses.push((status_code, json_body));
183 } else {
184 error_responses.push((
186 status_code,
187 json!({
188 "error": body.clone()
189 }),
190 ));
191 }
192 }
193 }
194 }
195 }
196
197 let mut request_payloads = Vec::new();
199 let mut response_payloads = Vec::new();
200
201 for (req, resp_opt) in &exchanges {
202 if let Some(ref body) = req.body {
204 if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
205 request_payloads.push(json);
206 }
207 }
208
209 if let Some(ref resp) = resp_opt {
211 if let Some(ref body) = resp.body {
212 if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
213 response_payloads.push(json);
214 }
215 }
216 }
217 }
218
219 let model = ProbabilisticModel::build_probability_model_from_data(
221 &request.endpoint,
222 &request.method,
223 &status_codes,
224 &latencies_ms,
225 &error_responses,
226 &request_payloads,
227 &response_payloads,
228 );
229
230 db.insert_endpoint_probability_model(&model)
232 .await
233 .map_err(|e| format!("Failed to store probability model: {}", e))?;
234
235 Ok(Json(json!({
236 "success": true,
237 "model": model
238 })))
239}
240
241pub async fn get_probability_model(
245 Path((endpoint, method)): Path<(String, String)>,
246 State(state): State<BehavioralCloningState>,
247) -> Result<Json<Value>, String> {
248 let db = state.open_database().await?;
249
250 let model = db
251 .get_endpoint_probability_model(&endpoint, &method)
252 .await
253 .map_err(|e| format!("Failed to query probability model: {}", e))?
254 .ok_or_else(|| format!("No probability model found for {} {}", method, endpoint))?;
255
256 Ok(Json(json!({
257 "success": true,
258 "model": model
259 })))
260}
261
262pub async fn list_probability_models(
266 State(state): State<BehavioralCloningState>,
267) -> Result<Json<Value>, String> {
268 let db = state.open_database().await?;
269
270 let models = db
271 .get_all_endpoint_probability_models()
272 .await
273 .map_err(|e| format!("Failed to query probability models: {}", e))?;
274
275 Ok(Json(json!({
276 "success": true,
277 "models": models,
278 "count": models.len()
279 })))
280}
281
282pub async fn discover_sequences(
286 State(state): State<BehavioralCloningState>,
287 Json(request): Json<DiscoverSequencesRequest>,
288) -> Result<Json<Value>, String> {
289 let db = state.open_database().await?;
290
291 let trace_groups = db
293 .get_requests_by_trace(request.min_requests_per_trace)
294 .await
295 .map_err(|e| format!("Failed to query traces: {}", e))?;
296
297 if trace_groups.is_empty() {
298 return Ok(Json(json!({
299 "success": true,
300 "count": 0,
301 "sequences": [],
302 "message": "No traces found with sufficient requests"
303 })));
304 }
305
306 let mut sequences: Vec<Vec<(String, String, Option<u64>)>> = Vec::new();
308
309 for (_trace_id, requests) in trace_groups {
310 let mut seq = Vec::new();
311 let mut prev_timestamp = None;
312
313 for req in requests {
314 let delay = if let Some(prev_ts) = prev_timestamp {
316 let duration = req.timestamp.signed_duration_since(prev_ts);
317 Some(duration.num_milliseconds().max(0) as u64)
318 } else {
319 None
320 };
321
322 seq.push((req.path.clone(), req.method.clone(), delay));
323 prev_timestamp = Some(req.timestamp);
324 }
325
326 if !seq.is_empty() {
327 sequences.push(seq);
328 }
329 }
330
331 let learned_sequences =
333 SequenceLearner::learn_sequence_pattern(&sequences, request.min_frequency)
334 .map_err(|e| format!("Failed to learn sequences: {}", e))?;
335
336 for sequence in &learned_sequences {
338 db.insert_behavioral_sequence(sequence)
339 .await
340 .map_err(|e| format!("Failed to store sequence: {}", e))?;
341 }
342
343 Ok(Json(json!({
344 "success": true,
345 "count": learned_sequences.len(),
346 "sequences": learned_sequences
347 })))
348}
349
350pub async fn list_sequences(
354 State(state): State<BehavioralCloningState>,
355) -> Result<Json<Value>, String> {
356 let db = state.open_database().await?;
357
358 let sequences = db
359 .get_behavioral_sequences()
360 .await
361 .map_err(|e| format!("Failed to query sequences: {}", e))?;
362
363 Ok(Json(json!({
364 "success": true,
365 "sequences": sequences,
366 "count": sequences.len()
367 })))
368}
369
370pub async fn get_sequence(
374 Path(sequence_id): Path<String>,
375 State(state): State<BehavioralCloningState>,
376) -> Result<Json<Value>, String> {
377 let db = state.open_database().await?;
378
379 let sequences = db
380 .get_behavioral_sequences()
381 .await
382 .map_err(|e| format!("Failed to query sequences: {}", e))?;
383
384 let sequence = sequences
385 .into_iter()
386 .find(|s| s.id == sequence_id)
387 .ok_or_else(|| format!("Sequence {} not found", sequence_id))?;
388
389 Ok(Json(json!({
390 "success": true,
391 "sequence": sequence
392 })))
393}
394
395pub async fn apply_amplification(
399 State(state): State<BehavioralCloningState>,
400 Json(request): Json<ApplyAmplificationRequest>,
401) -> Result<Json<Value>, String> {
402 if !request.config.enabled {
403 return Ok(Json(json!({
404 "success": true,
405 "message": "Amplification disabled"
406 })));
407 }
408
409 let db = state.open_database().await?;
410
411 let models_to_update = match &request.config.scope {
413 mockforge_core::behavioral_cloning::AmplificationScope::Global => db
414 .get_all_endpoint_probability_models()
415 .await
416 .map_err(|e| format!("Failed to query models: {}", e))?,
417 mockforge_core::behavioral_cloning::AmplificationScope::Endpoint { endpoint, method } => {
418 if let Some(model) = db
419 .get_endpoint_probability_model(endpoint, method)
420 .await
421 .map_err(|e| format!("Failed to query model: {}", e))?
422 {
423 vec![model]
424 } else {
425 return Err(format!("No probability model found for {} {}", method, endpoint));
426 }
427 }
428 mockforge_core::behavioral_cloning::AmplificationScope::Sequence { .. } => {
429 return Err("Sequence-scoped amplification not yet implemented".to_string());
432 }
433 };
434
435 let mut updated_count = 0;
437 for mut model in models_to_update {
438 EdgeAmplifier::apply_amplification(&mut model, &request.config)
439 .map_err(|e| format!("Failed to apply amplification: {}", e))?;
440
441 db.insert_endpoint_probability_model(&model)
443 .await
444 .map_err(|e| format!("Failed to store updated model: {}", e))?;
445
446 updated_count += 1;
447 }
448
449 Ok(Json(json!({
450 "success": true,
451 "updated_models": updated_count,
452 "config": request.config
453 })))
454}
455
456pub async fn get_rare_edges(
460 Path((endpoint, method)): Path<(String, String)>,
461 Query(params): Query<HashMap<String, String>>,
462 State(state): State<BehavioralCloningState>,
463) -> Result<Json<Value>, String> {
464 let db = state.open_database().await?;
465
466 let model = db
467 .get_endpoint_probability_model(&endpoint, &method)
468 .await
469 .map_err(|e| format!("Failed to query model: {}", e))?
470 .ok_or_else(|| format!("No probability model found for {} {}", method, endpoint))?;
471
472 let threshold: f64 = params.get("threshold").and_then(|s| s.parse().ok()).unwrap_or(0.01); let rare_patterns = EdgeAmplifier::identify_rare_edges(&model, threshold);
475
476 Ok(Json(json!({
477 "success": true,
478 "endpoint": endpoint,
479 "method": method,
480 "threshold": threshold,
481 "rare_patterns": rare_patterns
482 })))
483}
484
485pub async fn sample_status_code(
489 Path((endpoint, method)): Path<(String, String)>,
490 State(state): State<BehavioralCloningState>,
491) -> Result<Json<Value>, String> {
492 let db = state.open_database().await?;
493
494 let model = db
495 .get_endpoint_probability_model(&endpoint, &method)
496 .await
497 .map_err(|e| format!("Failed to query model: {}", e))?
498 .ok_or_else(|| format!("No probability model found for {} {}", method, endpoint))?;
499
500 let sampled_code = ProbabilisticModel::sample_status_code(&model);
501
502 Ok(Json(json!({
503 "success": true,
504 "endpoint": endpoint,
505 "method": method,
506 "status_code": sampled_code
507 })))
508}
509
510pub async fn sample_latency(
514 Path((endpoint, method)): Path<(String, String)>,
515 State(state): State<BehavioralCloningState>,
516) -> Result<Json<Value>, String> {
517 let db = state.open_database().await?;
518
519 let model = db
520 .get_endpoint_probability_model(&endpoint, &method)
521 .await
522 .map_err(|e| format!("Failed to query model: {}", e))?
523 .ok_or_else(|| format!("No probability model found for {} {}", method, endpoint))?;
524
525 let sampled_latency = ProbabilisticModel::sample_latency(&model);
526
527 Ok(Json(json!({
528 "success": true,
529 "endpoint": endpoint,
530 "method": method,
531 "latency_ms": sampled_latency
532 })))
533}
534
535pub async fn generate_sequence_scenario(
539 Path(sequence_id): Path<String>,
540 State(state): State<BehavioralCloningState>,
541) -> Result<Json<Value>, String> {
542 let db = state.open_database().await?;
543
544 let sequences = db
545 .get_behavioral_sequences()
546 .await
547 .map_err(|e| format!("Failed to query sequences: {}", e))?;
548
549 let sequence = sequences
550 .into_iter()
551 .find(|s| s.id == sequence_id)
552 .ok_or_else(|| format!("Sequence {} not found", sequence_id))?;
553
554 let scenario = SequenceLearner::generate_sequence_scenario(&sequence);
555
556 Ok(Json(json!({
557 "success": true,
558 "sequence_id": sequence_id,
559 "scenario": scenario
560 })))
561}
562
563pub fn behavioral_cloning_router(state: BehavioralCloningState) -> axum::Router {
565 use axum::routing::{get, post};
566 use axum::Router;
567
568 Router::new()
569 .route(
571 "/api/v1/behavioral-cloning/probability-models",
572 post(build_probability_model).get(list_probability_models),
573 )
574 .route(
575 "/api/v1/behavioral-cloning/probability-models/{endpoint}/{method}",
576 get(get_probability_model),
577 )
578 .route(
579 "/api/v1/behavioral-cloning/probability-models/{endpoint}/{method}/sample/status-code",
580 post(sample_status_code),
581 )
582 .route(
583 "/api/v1/behavioral-cloning/probability-models/{endpoint}/{method}/sample/latency",
584 post(sample_latency),
585 )
586 .route(
588 "/api/v1/behavioral-cloning/sequences",
589 get(list_sequences),
590 )
591 .route(
592 "/api/v1/behavioral-cloning/sequences/discover",
593 post(discover_sequences),
594 )
595 .route(
596 "/api/v1/behavioral-cloning/sequences/{sequence_id}",
597 get(get_sequence),
598 )
599 .route(
600 "/api/v1/behavioral-cloning/sequences/{sequence_id}/scenario",
601 post(generate_sequence_scenario),
602 )
603 .route(
605 "/api/v1/behavioral-cloning/amplification/apply",
606 post(apply_amplification),
607 )
608 .route(
609 "/api/v1/behavioral-cloning/amplification/rare-edges/{endpoint}/{method}",
610 get(get_rare_edges),
611 )
612 .with_state(state)
613}