// mockforge_intelligence/behavioral_cloning/probabilistic_model.rs
use crate::behavioral_cloning::types::{
    EndpointProbabilityModel, ErrorPattern, LatencyDistribution,
};
use std::collections::HashMap;

11pub struct ProbabilisticModel;
13
14impl ProbabilisticModel {
15 pub fn build_probability_model_from_data(
20 endpoint: &str,
21 method: &str,
22 status_codes: &[u16],
23 latencies_ms: &[u64],
24 error_responses: &[(u16, serde_json::Value)],
25 request_payloads: &[serde_json::Value],
26 response_payloads: &[serde_json::Value],
27 ) -> EndpointProbabilityModel {
28 let sample_count = status_codes.len().max(latencies_ms.len()) as u64;
29
30 let mut status_code_counts: HashMap<u16, usize> = HashMap::new();
32 for &code in status_codes {
33 *status_code_counts.entry(code).or_insert(0) += 1;
34 }
35
36 let total_status_codes = status_codes.len() as f64;
37 let status_code_distribution: HashMap<u16, f64> = status_code_counts
38 .into_iter()
39 .map(|(code, count)| (code, count as f64 / total_status_codes))
40 .collect();
41
42 let latency_distribution = if latencies_ms.is_empty() {
44 LatencyDistribution::new(0, 0, 0, 0.0, 0.0, 0, 0)
45 } else {
46 let mut sorted_latencies = latencies_ms.to_vec();
47 sorted_latencies.sort_unstable();
48
49 let len = sorted_latencies.len();
50 let p50_idx = (len as f64 * 0.5) as usize;
51 let p95_idx = (len as f64 * 0.95) as usize;
52 let p99_idx = (len as f64 * 0.99).min((len - 1) as f64) as usize;
53
54 let p50 = sorted_latencies[p50_idx.min(len - 1)];
55 let p95 = sorted_latencies[p95_idx.min(len - 1)];
56 let p99 = sorted_latencies[p99_idx.min(len - 1)];
57
58 let mean = sorted_latencies.iter().sum::<u64>() as f64 / len as f64;
59 let variance = sorted_latencies
60 .iter()
61 .map(|&x| {
62 let diff = x as f64 - mean;
63 diff * diff
64 })
65 .sum::<f64>()
66 / len as f64;
67 let std_dev = variance.sqrt();
68
69 let min = *sorted_latencies.first().unwrap_or(&0);
70 let max = *sorted_latencies.last().unwrap_or(&0);
71
72 LatencyDistribution::new(p50, p95, p99, mean, std_dev, min, max)
73 };
74
75 let mut error_patterns: Vec<ErrorPattern> = Vec::new();
77 let mut error_counts: HashMap<u16, (usize, Vec<serde_json::Value>)> = HashMap::new();
78
79 for (status_code, response_body) in error_responses {
80 if *status_code >= 400 {
81 let entry = error_counts.entry(*status_code).or_insert_with(|| (0, Vec::new()));
82 entry.0 += 1;
83 entry.1.push(response_body.clone());
84 }
85 }
86
87 let total_errors = error_responses.len() as f64;
88 if total_errors > 0.0 {
89 for (status_code, (count, samples)) in error_counts {
90 let probability = count as f64 / total_errors;
91 let mut pattern = ErrorPattern::new(format!("http_{}", status_code), probability);
92 pattern.status_code = Some(status_code);
93 if let Some(sample) = samples.first() {
94 pattern.sample_responses.push(sample.clone());
95 }
96 error_patterns.push(pattern);
97 }
98 }
99
100 let payload_variations =
102 Self::detect_payload_variations(request_payloads, response_payloads, status_codes);
103
104 EndpointProbabilityModel {
105 endpoint: endpoint.to_string(),
106 method: method.to_string(),
107 status_code_distribution,
108 latency_distribution,
109 error_patterns,
110 payload_variations,
111 sample_count,
112 updated_at: chrono::Utc::now(),
113 original_error_probabilities: None,
114 }
115 }
116
117 fn detect_payload_variations(
122 request_payloads: &[serde_json::Value],
123 response_payloads: &[serde_json::Value],
124 status_codes: &[u16],
125 ) -> Vec<crate::behavioral_cloning::types::PayloadVariation> {
126 use crate::behavioral_cloning::types::PayloadVariation;
127 use std::collections::HashMap;
128
129 if response_payloads.is_empty() && request_payloads.is_empty() {
130 return Vec::new();
131 }
132
133 let mut variation_groups: HashMap<String, (usize, serde_json::Value, Option<u16>)> =
135 HashMap::new();
136
137 for (idx, payload) in response_payloads.iter().enumerate() {
139 let status_code = if idx < status_codes.len() {
140 Some(status_codes[idx])
141 } else {
142 None
143 };
144
145 let signature = Self::payload_signature(payload);
147 let key = if let Some(code) = status_code {
148 format!("{}:{}", code, signature)
149 } else {
150 signature.clone()
151 };
152
153 let entry =
154 variation_groups.entry(key).or_insert_with(|| (0, payload.clone(), status_code));
155 entry.0 += 1;
156 }
157
158 for payload in request_payloads {
160 let signature = Self::payload_signature(payload);
161 let key = format!("request:{}", signature);
162
163 let entry = variation_groups.entry(key).or_insert_with(|| (0, payload.clone(), None));
164 entry.0 += 1;
165 }
166
167 let total_samples =
169 variation_groups.values().map(|(count, _, _)| *count).sum::<usize>() as f64;
170 if total_samples == 0.0 {
171 return Vec::new();
172 }
173
174 let mut variations = Vec::new();
175 for (idx, (_key, (count, sample, status_code))) in variation_groups.into_iter().enumerate()
176 {
177 let probability = count as f64 / total_samples;
178 let variation_id = format!("var_{}", idx);
179
180 let mut variation = PayloadVariation {
181 id: variation_id,
182 probability,
183 sample_payload: sample,
184 conditions: None,
185 };
186
187 if let Some(code) = status_code {
189 let mut conditions = HashMap::new();
190 conditions.insert("status_code".to_string(), code.to_string());
191 variation.conditions = Some(conditions);
192 }
193
194 variations.push(variation);
195 }
196
197 variations.sort_by(|a, b| {
199 b.probability.partial_cmp(&a.probability).unwrap_or(std::cmp::Ordering::Equal)
200 });
201
202 variations
203 }
204
205 fn payload_signature(payload: &serde_json::Value) -> String {
210 match payload {
211 serde_json::Value::Object(map) => {
212 let mut keys: Vec<String> = map.keys().cloned().collect();
213 keys.sort();
214 let mut sig_parts = Vec::new();
215 for key in keys {
216 if let Some(value) = map.get(&key) {
217 let value_type = match value {
218 serde_json::Value::Null => "null",
219 serde_json::Value::Bool(_) => "bool",
220 serde_json::Value::Number(_) => "number",
221 serde_json::Value::String(_) => "string",
222 serde_json::Value::Array(_) => "array",
223 serde_json::Value::Object(_) => "object",
224 };
225 sig_parts.push(format!("{}:{}", key, value_type));
226 }
227 }
228 format!("{{{}}}", sig_parts.join(","))
229 }
230 serde_json::Value::Array(arr) => {
231 if arr.is_empty() {
232 "[]".to_string()
233 } else {
234 format!("[{}]", Self::payload_signature(&arr[0]))
236 }
237 }
238 _ => {
239 match payload {
241 serde_json::Value::Null => "null",
242 serde_json::Value::Bool(_) => "bool",
243 serde_json::Value::Number(_) => "number",
244 serde_json::Value::String(_) => "string",
245 _ => "unknown",
246 }
247 .to_string()
248 }
249 }
250 }
251
252 pub fn sample_status_code(model: &EndpointProbabilityModel) -> u16 {
254 use rand::Rng;
255 let mut rng = rand::thread_rng();
256 let random: f64 = rng.gen_range(0.0..1.0);
257
258 let mut cumulative = 0.0;
259 for (status_code, probability) in &model.status_code_distribution {
260 cumulative += probability;
261 if random <= cumulative {
262 return *status_code;
263 }
264 }
265
266 model
268 .status_code_distribution
269 .iter()
270 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
271 .map(|(code, _)| *code)
272 .unwrap_or(200)
273 }
274
275 pub fn sample_latency(model: &EndpointProbabilityModel) -> u64 {
277 use rand::Rng;
278 let mut rng = rand::thread_rng();
279
280 let mean = model.latency_distribution.mean;
282 let std_dev = model.latency_distribution.std_dev;
283
284 let u1: f64 = rng.gen_range(0.0..1.0);
286 let u2: f64 = rng.gen_range(0.0..1.0);
287 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
288 let sample = mean + std_dev * z0;
289
290 sample
292 .max(model.latency_distribution.min as f64)
293 .min(model.latency_distribution.max as f64) as u64
294 }
295
296 pub fn sample_error_pattern<'a>(
298 model: &'a EndpointProbabilityModel,
299 _conditions: Option<&HashMap<String, String>>,
300 ) -> Option<&'a ErrorPattern> {
301 use rand::Rng;
302 let mut rng = rand::thread_rng();
303 let random: f64 = rng.gen_range(0.0..1.0);
304
305 let mut cumulative = 0.0;
306 for pattern in &model.error_patterns {
307 cumulative += pattern.probability;
308 if random <= cumulative {
309 return Some(pattern);
310 }
311 }
312
313 None
314 }
315
316 pub fn update_model(
318 model: &mut EndpointProbabilityModel,
319 status_code: u16,
320 latency_ms: u64,
321 _error_pattern: Option<&ErrorPattern>,
322 ) {
323 let total = model.sample_count as f64;
325 let new_total = total + 1.0;
326
327 for (_code, prob) in model.status_code_distribution.iter_mut() {
329 *prob = (*prob * total) / new_total;
330 }
331
332 let status_prob = model.status_code_distribution.entry(status_code).or_insert(0.0);
333 *status_prob = (*status_prob * total + 1.0) / new_total;
334
335 let latency = latency_ms as f64;
337 let old_mean = model.latency_distribution.mean;
338 let new_mean = (old_mean * total + latency) / new_total;
339 model.latency_distribution.mean = new_mean;
340
341 if total > 0.0 {
345 let old_variance = model.latency_distribution.std_dev.powi(2);
346 let old_m2 = old_variance * total;
347 let new_m2 = old_m2 + (latency - old_mean) * (latency - new_mean);
348 model.latency_distribution.std_dev = (new_m2 / new_total).sqrt();
349 } else {
350 model.latency_distribution.std_dev = 0.0;
351 }
352
353 if latency_ms < model.latency_distribution.min {
355 model.latency_distribution.min = latency_ms;
356 }
357 if latency_ms > model.latency_distribution.max {
358 model.latency_distribution.max = latency_ms;
359 }
360
361 let step = 1.0 / new_total;
365 if latency_ms <= model.latency_distribution.p50 {
366 let delta = (model.latency_distribution.p50 as f64
367 - model.latency_distribution.min as f64)
368 * step;
369 model.latency_distribution.p50 =
370 (model.latency_distribution.p50 as f64 - delta).round() as u64;
371 } else {
372 let delta = (model.latency_distribution.max as f64
373 - model.latency_distribution.p50 as f64)
374 * step;
375 model.latency_distribution.p50 =
376 (model.latency_distribution.p50 as f64 + delta).round() as u64;
377 }
378
379 if latency_ms <= model.latency_distribution.p95 {
380 let delta = (model.latency_distribution.p95 as f64
381 - model.latency_distribution.min as f64)
382 * step
383 * 0.05; model.latency_distribution.p95 =
385 (model.latency_distribution.p95 as f64 - delta).round() as u64;
386 } else {
387 let delta = (model.latency_distribution.max as f64
388 - model.latency_distribution.p95 as f64)
389 * step
390 * 0.95;
391 model.latency_distribution.p95 =
392 (model.latency_distribution.p95 as f64 + delta).round() as u64;
393 }
394
395 if latency_ms <= model.latency_distribution.p99 {
396 let delta = (model.latency_distribution.p99 as f64
397 - model.latency_distribution.min as f64)
398 * step
399 * 0.01;
400 model.latency_distribution.p99 =
401 (model.latency_distribution.p99 as f64 - delta).round() as u64;
402 } else {
403 let delta = (model.latency_distribution.max as f64
404 - model.latency_distribution.p99 as f64)
405 * step
406 * 0.99;
407 model.latency_distribution.p99 =
408 (model.latency_distribution.p99 as f64 + delta).round() as u64;
409 }
410
411 model.sample_count += 1;
412 model.updated_at = chrono::Utc::now();
413 }
414}