//! Probabilistic outcome modeling
//!
//! This module provides functionality to build and use probability models
//! for endpoint behavior, including status codes, latency, and error patterns.

6use crate::behavioral_cloning::types::{
7    EndpointProbabilityModel, ErrorPattern, LatencyDistribution,
8};
9use std::collections::HashMap;
10
/// Probabilistic model builder and sampler.
///
/// A stateless namespace type: every operation is an associated function
/// that takes the model data explicitly, so the struct carries no state.
pub struct ProbabilisticModel;
13
14impl ProbabilisticModel {
15    /// Build a probability model from a list of status codes and latencies
16    ///
17    /// This is a pure function that takes observed data and builds a probability model.
18    /// The caller is responsible for querying the database and providing the data.
19    pub fn build_probability_model_from_data(
20        endpoint: &str,
21        method: &str,
22        status_codes: &[u16],
23        latencies_ms: &[u64],
24        error_responses: &[(u16, serde_json::Value)],
25        request_payloads: &[serde_json::Value],
26        response_payloads: &[serde_json::Value],
27    ) -> EndpointProbabilityModel {
28        let sample_count = status_codes.len().max(latencies_ms.len()) as u64;
29
30        // Calculate status code distribution
31        let mut status_code_counts: HashMap<u16, usize> = HashMap::new();
32        for &code in status_codes {
33            *status_code_counts.entry(code).or_insert(0) += 1;
34        }
35
36        let total_status_codes = status_codes.len() as f64;
37        let status_code_distribution: HashMap<u16, f64> = status_code_counts
38            .into_iter()
39            .map(|(code, count)| (code, count as f64 / total_status_codes))
40            .collect();
41
42        // Calculate latency distribution
43        let latency_distribution = if latencies_ms.is_empty() {
44            LatencyDistribution::new(0, 0, 0, 0.0, 0.0, 0, 0)
45        } else {
46            let mut sorted_latencies = latencies_ms.to_vec();
47            sorted_latencies.sort_unstable();
48
49            let len = sorted_latencies.len();
50            let p50_idx = (len as f64 * 0.5) as usize;
51            let p95_idx = (len as f64 * 0.95) as usize;
52            let p99_idx = (len as f64 * 0.99).min((len - 1) as f64) as usize;
53
54            let p50 = sorted_latencies[p50_idx.min(len - 1)];
55            let p95 = sorted_latencies[p95_idx.min(len - 1)];
56            let p99 = sorted_latencies[p99_idx.min(len - 1)];
57
58            let mean = sorted_latencies.iter().sum::<u64>() as f64 / len as f64;
59            let variance = sorted_latencies
60                .iter()
61                .map(|&x| {
62                    let diff = x as f64 - mean;
63                    diff * diff
64                })
65                .sum::<f64>()
66                / len as f64;
67            let std_dev = variance.sqrt();
68
69            let min = *sorted_latencies.first().unwrap_or(&0);
70            let max = *sorted_latencies.last().unwrap_or(&0);
71
72            LatencyDistribution::new(p50, p95, p99, mean, std_dev, min, max)
73        };
74
75        // Identify error patterns
76        let mut error_patterns: Vec<ErrorPattern> = Vec::new();
77        let mut error_counts: HashMap<u16, (usize, Vec<serde_json::Value>)> = HashMap::new();
78
79        for (status_code, response_body) in error_responses {
80            if *status_code >= 400 {
81                let entry = error_counts.entry(*status_code).or_insert_with(|| (0, Vec::new()));
82                entry.0 += 1;
83                entry.1.push(response_body.clone());
84            }
85        }
86
87        let total_errors = error_responses.len() as f64;
88        if total_errors > 0.0 {
89            for (status_code, (count, samples)) in error_counts {
90                let probability = count as f64 / total_errors;
91                let mut pattern = ErrorPattern::new(format!("http_{}", status_code), probability);
92                pattern.status_code = Some(status_code);
93                if let Some(sample) = samples.first() {
94                    pattern.sample_responses.push(sample.clone());
95                }
96                error_patterns.push(pattern);
97            }
98        }
99
100        // Detect payload variations
101        let payload_variations =
102            Self::detect_payload_variations(request_payloads, response_payloads, status_codes);
103
104        EndpointProbabilityModel {
105            endpoint: endpoint.to_string(),
106            method: method.to_string(),
107            status_code_distribution,
108            latency_distribution,
109            error_patterns,
110            payload_variations,
111            sample_count,
112            updated_at: chrono::Utc::now(),
113            original_error_probabilities: None,
114        }
115    }
116
117    /// Detect payload variations from observed request/response bodies
118    ///
119    /// Groups similar payloads and calculates their probabilities.
120    /// Uses structural similarity (JSON structure) rather than exact matching.
121    fn detect_payload_variations(
122        request_payloads: &[serde_json::Value],
123        response_payloads: &[serde_json::Value],
124        status_codes: &[u16],
125    ) -> Vec<crate::behavioral_cloning::types::PayloadVariation> {
126        use crate::behavioral_cloning::types::PayloadVariation;
127        use std::collections::HashMap;
128
129        if response_payloads.is_empty() && request_payloads.is_empty() {
130            return Vec::new();
131        }
132
133        // Group response payloads by status code and structure
134        let mut variation_groups: HashMap<String, (usize, serde_json::Value, Option<u16>)> =
135            HashMap::new();
136
137        // Process response payloads (grouped by status code)
138        for (idx, payload) in response_payloads.iter().enumerate() {
139            let status_code = if idx < status_codes.len() {
140                Some(status_codes[idx])
141            } else {
142                None
143            };
144
145            // Create a structural signature (normalized JSON structure)
146            let signature = Self::payload_signature(payload);
147            let key = if let Some(code) = status_code {
148                format!("{}:{}", code, signature)
149            } else {
150                signature.clone()
151            };
152
153            let entry =
154                variation_groups.entry(key).or_insert_with(|| (0, payload.clone(), status_code));
155            entry.0 += 1;
156        }
157
158        // Process request payloads (if provided)
159        for payload in request_payloads {
160            let signature = Self::payload_signature(payload);
161            let key = format!("request:{}", signature);
162
163            let entry = variation_groups.entry(key).or_insert_with(|| (0, payload.clone(), None));
164            entry.0 += 1;
165        }
166
167        // Convert groups to PayloadVariation structs
168        let total_samples =
169            variation_groups.values().map(|(count, _, _)| *count).sum::<usize>() as f64;
170        if total_samples == 0.0 {
171            return Vec::new();
172        }
173
174        let mut variations = Vec::new();
175        for (idx, (_key, (count, sample, status_code))) in variation_groups.into_iter().enumerate()
176        {
177            let probability = count as f64 / total_samples;
178            let variation_id = format!("var_{}", idx);
179
180            let mut variation = PayloadVariation {
181                id: variation_id,
182                probability,
183                sample_payload: sample,
184                conditions: None,
185            };
186
187            // Add status code as a condition if present
188            if let Some(code) = status_code {
189                let mut conditions = HashMap::new();
190                conditions.insert("status_code".to_string(), code.to_string());
191                variation.conditions = Some(conditions);
192            }
193
194            variations.push(variation);
195        }
196
197        // Sort by probability (descending)
198        variations.sort_by(|a, b| {
199            b.probability.partial_cmp(&a.probability).unwrap_or(std::cmp::Ordering::Equal)
200        });
201
202        variations
203    }
204
205    /// Create a structural signature for a JSON payload
206    ///
207    /// Normalizes the payload to show only structure (keys, types) without values.
208    /// This allows grouping similar payloads together.
209    fn payload_signature(payload: &serde_json::Value) -> String {
210        match payload {
211            serde_json::Value::Object(map) => {
212                let mut keys: Vec<String> = map.keys().cloned().collect();
213                keys.sort();
214                let mut sig_parts = Vec::new();
215                for key in keys {
216                    if let Some(value) = map.get(&key) {
217                        let value_type = match value {
218                            serde_json::Value::Null => "null",
219                            serde_json::Value::Bool(_) => "bool",
220                            serde_json::Value::Number(_) => "number",
221                            serde_json::Value::String(_) => "string",
222                            serde_json::Value::Array(_) => "array",
223                            serde_json::Value::Object(_) => "object",
224                        };
225                        sig_parts.push(format!("{}:{}", key, value_type));
226                    }
227                }
228                format!("{{{}}}", sig_parts.join(","))
229            }
230            serde_json::Value::Array(arr) => {
231                if arr.is_empty() {
232                    "[]".to_string()
233                } else {
234                    // Use first element's structure as representative
235                    format!("[{}]", Self::payload_signature(&arr[0]))
236                }
237            }
238            _ => {
239                // Primitive value - use type
240                match payload {
241                    serde_json::Value::Null => "null",
242                    serde_json::Value::Bool(_) => "bool",
243                    serde_json::Value::Number(_) => "number",
244                    serde_json::Value::String(_) => "string",
245                    _ => "unknown",
246                }
247                .to_string()
248            }
249        }
250    }
251
252    /// Sample a status code based on learned distribution
253    pub fn sample_status_code(model: &EndpointProbabilityModel) -> u16 {
254        use rand::Rng;
255        let mut rng = rand::thread_rng();
256        let random: f64 = rng.gen_range(0.0..1.0);
257
258        let mut cumulative = 0.0;
259        for (status_code, probability) in &model.status_code_distribution {
260            cumulative += probability;
261            if random <= cumulative {
262                return *status_code;
263            }
264        }
265
266        // Fallback to most common status code
267        model
268            .status_code_distribution
269            .iter()
270            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
271            .map(|(code, _)| *code)
272            .unwrap_or(200)
273    }
274
275    /// Sample latency based on learned distribution
276    pub fn sample_latency(model: &EndpointProbabilityModel) -> u64 {
277        use rand::Rng;
278        let mut rng = rand::thread_rng();
279
280        // Use normal distribution approximation based on mean and std_dev
281        let mean = model.latency_distribution.mean;
282        let std_dev = model.latency_distribution.std_dev;
283
284        // Generate normal distribution sample using Box-Muller transform
285        let u1: f64 = rng.gen_range(0.0..1.0);
286        let u2: f64 = rng.gen_range(0.0..1.0);
287        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
288        let sample = mean + std_dev * z0;
289
290        // Clamp to min/max bounds
291        sample
292            .max(model.latency_distribution.min as f64)
293            .min(model.latency_distribution.max as f64) as u64
294    }
295
296    /// Sample an error pattern based on conditions
297    pub fn sample_error_pattern<'a>(
298        model: &'a EndpointProbabilityModel,
299        _conditions: Option<&HashMap<String, String>>,
300    ) -> Option<&'a ErrorPattern> {
301        use rand::Rng;
302        let mut rng = rand::thread_rng();
303        let random: f64 = rng.gen_range(0.0..1.0);
304
305        let mut cumulative = 0.0;
306        for pattern in &model.error_patterns {
307            cumulative += pattern.probability;
308            if random <= cumulative {
309                return Some(pattern);
310            }
311        }
312
313        None
314    }
315
316    /// Update model incrementally with new observations
317    pub fn update_model(
318        model: &mut EndpointProbabilityModel,
319        status_code: u16,
320        latency_ms: u64,
321        _error_pattern: Option<&ErrorPattern>,
322    ) {
323        // Update status code distribution
324        let total = model.sample_count as f64;
325        let new_total = total + 1.0;
326
327        // Update frequency for observed status code
328        for (_code, prob) in model.status_code_distribution.iter_mut() {
329            *prob = (*prob * total) / new_total;
330        }
331
332        let status_prob = model.status_code_distribution.entry(status_code).or_insert(0.0);
333        *status_prob = (*status_prob * total + 1.0) / new_total;
334
335        // Update latency distribution using Welford's online algorithm for variance
336        let latency = latency_ms as f64;
337        let old_mean = model.latency_distribution.mean;
338        let new_mean = (old_mean * total + latency) / new_total;
339        model.latency_distribution.mean = new_mean;
340
341        // Welford's online variance: update std_dev incrementally
342        // M2(n) = M2(n-1) + (x - old_mean) * (x - new_mean)
343        // variance = M2(n) / n
344        if total > 0.0 {
345            let old_variance = model.latency_distribution.std_dev.powi(2);
346            let old_m2 = old_variance * total;
347            let new_m2 = old_m2 + (latency - old_mean) * (latency - new_mean);
348            model.latency_distribution.std_dev = (new_m2 / new_total).sqrt();
349        } else {
350            model.latency_distribution.std_dev = 0.0;
351        }
352
353        // Update min/max
354        if latency_ms < model.latency_distribution.min {
355            model.latency_distribution.min = latency_ms;
356        }
357        if latency_ms > model.latency_distribution.max {
358            model.latency_distribution.max = latency_ms;
359        }
360
361        // Update percentile estimates using the P-square algorithm approximation.
362        // Move each percentile estimate toward the observed value when the observation
363        // is on the "correct" side, using a step proportional to 1/n for stability.
364        let step = 1.0 / new_total;
365        if latency_ms <= model.latency_distribution.p50 {
366            let delta = (model.latency_distribution.p50 as f64
367                - model.latency_distribution.min as f64)
368                * step;
369            model.latency_distribution.p50 =
370                (model.latency_distribution.p50 as f64 - delta).round() as u64;
371        } else {
372            let delta = (model.latency_distribution.max as f64
373                - model.latency_distribution.p50 as f64)
374                * step;
375            model.latency_distribution.p50 =
376                (model.latency_distribution.p50 as f64 + delta).round() as u64;
377        }
378
379        if latency_ms <= model.latency_distribution.p95 {
380            let delta = (model.latency_distribution.p95 as f64
381                - model.latency_distribution.min as f64)
382                * step
383                * 0.05; // Slower movement for high percentiles
384            model.latency_distribution.p95 =
385                (model.latency_distribution.p95 as f64 - delta).round() as u64;
386        } else {
387            let delta = (model.latency_distribution.max as f64
388                - model.latency_distribution.p95 as f64)
389                * step
390                * 0.95;
391            model.latency_distribution.p95 =
392                (model.latency_distribution.p95 as f64 + delta).round() as u64;
393        }
394
395        if latency_ms <= model.latency_distribution.p99 {
396            let delta = (model.latency_distribution.p99 as f64
397                - model.latency_distribution.min as f64)
398                * step
399                * 0.01;
400            model.latency_distribution.p99 =
401                (model.latency_distribution.p99 as f64 - delta).round() as u64;
402        } else {
403            let delta = (model.latency_distribution.max as f64
404                - model.latency_distribution.p99 as f64)
405                * step
406                * 0.99;
407            model.latency_distribution.p99 =
408                (model.latency_distribution.p99 as f64 + delta).round() as u64;
409        }
410
411        model.sample_count += 1;
412        model.updated_at = chrono::Utc::now();
413    }
414}