Skip to main content

respdiff/
learner.rs

1use crate::types::{ObservationOutcome, ProbeObservation, ProbeVariant, PropertyRole};
2use std::collections::{BTreeMap, HashMap};
3
4#[derive(Clone, Debug)]
5struct ProbeRecord {
6    properties: BTreeMap<String, String>,
7    signature: ObservationSignature,
8}
9
10#[derive(Clone, Debug, PartialEq, Eq, Hash)]
11struct ObservationSignature {
12    outcome: ObservationOutcome,
13    error_key: String,
14    categories: Vec<String>,
15    timing_bucket: u8,
16    return_bucket: u8,
17}
18
19#[derive(Clone, Debug)]
20pub struct DifferentialLearner {
21    history: Vec<ProbeRecord>,
22    property_roles: HashMap<String, PropertyRole>,
23    best_shapes: Vec<BTreeMap<String, String>>,
24    paths: HashMap<ObservationSignature, Vec<BTreeMap<String, String>>>,
25    max_history: usize,
26    analyze_every: usize,
27}
28
29impl Default for DifferentialLearner {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl DifferentialLearner {
36    pub fn new() -> Self {
37        Self {
38            history: Vec::new(),
39            property_roles: HashMap::new(),
40            best_shapes: Vec::new(),
41            paths: HashMap::new(),
42            max_history: 10_000,
43            analyze_every: 100,
44        }
45    }
46
47    pub fn with_analyze_every(mut self, analyze_every: usize) -> Self {
48        self.analyze_every = analyze_every.max(1);
49        self
50    }
51
52    pub fn with_max_history(mut self, max_history: usize) -> Self {
53        self.max_history = max_history.max(2);
54        self
55    }
56
57    pub fn record(
58        &mut self,
59        properties: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
60        observation: ProbeObservation,
61    ) {
62        let properties: BTreeMap<String, String> = properties
63            .into_iter()
64            .map(|(key, value)| (key.into(), value.into()))
65            .collect();
66        if properties.is_empty() {
67            return;
68        }
69
70        let signature = build_signature(&observation);
71        self.paths
72            .entry(signature.clone())
73            .or_default()
74            .push(properties.clone());
75
76        if signature.outcome == ObservationOutcome::Match && self.best_shapes.len() < 50 {
77            self.best_shapes.push(properties.clone());
78        }
79
80        self.history.push(ProbeRecord {
81            properties,
82            signature,
83        });
84
85        if self.history.len() % self.analyze_every == 0 {
86            self.analyze();
87        }
88
89        if self.history.len() > self.max_history {
90            self.compact();
91        }
92    }
93
94    pub fn analyze(&mut self) {
95        let mut property_outcomes: HashMap<
96            String,
97            HashMap<String, HashMap<ObservationOutcome, u32>>,
98        > = HashMap::new();
99
100        for record in &self.history {
101            for (key, value) in &record.properties {
102                let outcomes = property_outcomes
103                    .entry(key.clone())
104                    .or_default()
105                    .entry(value.clone())
106                    .or_default();
107                *outcomes
108                    .entry(record.signature.outcome.clone())
109                    .or_insert(0) += 1;
110            }
111        }
112
113        let total_match_rate = self
114            .history
115            .iter()
116            .filter(|record| record.signature.outcome == ObservationOutcome::Match)
117            .count() as f32
118            / self.history.len().max(1) as f32;
119
120        for (key, values) in property_outcomes {
121            if values.len() < 2 {
122                continue;
123            }
124
125            let mut gate_values = Vec::new();
126            let mut injectable = true;
127
128            for (value, outcomes) in values {
129                let total: u32 = outcomes.values().sum();
130                if total < 3 {
131                    continue;
132                }
133                let matches = outcomes
134                    .get(&ObservationOutcome::Match)
135                    .copied()
136                    .unwrap_or(0);
137                let rate = matches as f32 / total as f32;
138                let threshold = (total_match_rate * 2.0).max(0.1);
139                if rate >= threshold {
140                    gate_values.push(value.clone());
141                }
142                if (rate - total_match_rate).abs() > total_match_rate.max(0.05) {
143                    injectable = false;
144                }
145            }
146
147            if !gate_values.is_empty() {
148                gate_values.sort();
149                gate_values.dedup();
150                self.property_roles
151                    .insert(key, PropertyRole::Gate(gate_values));
152            } else if injectable {
153                self.property_roles.insert(key, PropertyRole::Injectable);
154            }
155        }
156    }
157
158    pub fn property_roles(&self) -> &HashMap<String, PropertyRole> {
159        &self.property_roles
160    }
161
162    pub fn gates_found(&self) -> usize {
163        self.property_roles
164            .values()
165            .filter(|role| matches!(role, PropertyRole::Gate(_)))
166            .count()
167    }
168
169    pub fn injectables_found(&self) -> usize {
170        self.property_roles
171            .values()
172            .filter(|role| matches!(role, PropertyRole::Injectable))
173            .count()
174    }
175
176    pub fn paths_found(&self) -> usize {
177        self.paths.len()
178    }
179
180    pub fn dangerous_path_count(&self) -> usize {
181        self.paths
182            .keys()
183            .filter(|signature| signature.outcome == ObservationOutcome::Match)
184            .count()
185    }
186
187    pub fn generate_variants(&self, payloads: &[impl AsRef<str>]) -> Vec<ProbeVariant> {
188        let mut variants = Vec::new();
189
190        for shape in self.best_shapes.iter().take(5) {
191            let gates: Vec<(&String, &String)> = shape
192                .iter()
193                .filter(|(key, _)| {
194                    matches!(self.property_roles.get(*key), Some(PropertyRole::Gate(_)))
195                })
196                .collect();
197            let injectables: Vec<(&String, &String)> = shape
198                .iter()
199                .filter(|(key, _)| {
200                    matches!(
201                        self.property_roles.get(*key),
202                        Some(PropertyRole::Injectable) | None
203                    )
204                })
205                .collect();
206
207            for payload in payloads.iter().take(5) {
208                let payload = payload.as_ref();
209                let mut properties = BTreeMap::new();
210                for (key, value) in &gates {
211                    properties.insert((*key).clone(), (*value).clone());
212                }
213                for (key, _) in &injectables {
214                    properties.insert((*key).clone(), payload.to_string());
215                }
216                variants.push(ProbeVariant {
217                    properties,
218                    reason: format!("replay successful shape with payload `{payload}`"),
219                });
220            }
221        }
222
223        for (key, role) in &self.property_roles {
224            if let PropertyRole::Gate(known_values) = role {
225                for value in known_values.iter().take(3) {
226                    for variant in [
227                        format!("{value}2"),
228                        value.to_uppercase(),
229                        value.to_lowercase(),
230                        format!("_{value}"),
231                    ] {
232                        let mut properties = BTreeMap::new();
233                        properties.insert(key.clone(), variant);
234                        variants.push(ProbeVariant {
235                            properties,
236                            reason: format!("explore nearby gate value for `{key}`"),
237                        });
238                    }
239                }
240            }
241        }
242
243        variants
244    }
245
246    fn compact(&mut self) {
247        let midpoint = self.history.len() / 2;
248        let mut kept = self.history[midpoint..].to_vec();
249        kept.extend(
250            self.history[..midpoint]
251                .iter()
252                .filter(|record| record.signature.outcome != ObservationOutcome::Silent)
253                .cloned(),
254        );
255        self.history = kept;
256    }
257}
258
259fn build_signature(observation: &ProbeObservation) -> ObservationSignature {
260    let error_key = observation
261        .error
262        .as_deref()
263        .map(|error| error.to_lowercase().chars().take(80).collect::<String>())
264        .unwrap_or_default();
265
266    let timing_bucket = match observation.elapsed.as_micros() {
267        0..=10_000 => 0,
268        10_001..=100_000 => 1,
269        100_001..=1_000_000 => 2,
270        _ => 3,
271    };
272
273    let return_bucket = match observation.return_value.as_deref() {
274        None | Some("") | Some("undefined") | Some("null") => 0,
275        Some(value) if value.starts_with('{') => 1,
276        Some(value) if value.starts_with('[') => 2,
277        Some(value) if value.starts_with('"') => 3,
278        Some(value) if value.parse::<f64>().is_ok() => 4,
279        Some("true") | Some("false") => 5,
280        Some(value) if value.len() > 100 => 6,
281        Some(_) => 7,
282    };
283
284    ObservationSignature {
285        outcome: observation.outcome.clone(),
286        error_key,
287        categories: observation.categories.clone(),
288        timing_bucket,
289        return_bucket,
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A matching 10 ms observation carrying the single category `ok`.
    fn matched() -> ProbeObservation {
        ProbeObservation::matched(Duration::from_millis(10), ["ok"])
    }

    /// A silent 10 ms observation.
    fn silent() -> ProbeObservation {
        ProbeObservation::silent(Duration::from_millis(10))
    }

    /// Builds an owned property map from borrowed key/value pairs.
    fn shape(pairs: &[(&str, &str)]) -> BTreeMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn new_learner_starts_empty() {
        let fresh = DifferentialLearner::new();
        assert!(fresh.history.is_empty());
        assert!(fresh.property_roles.is_empty());
        assert!(fresh.best_shapes.is_empty());
        assert_eq!((fresh.max_history, fresh.analyze_every), (10_000, 100));
    }

    #[test]
    fn with_analyze_every_clamps_to_one() {
        let clamped = DifferentialLearner::new().with_analyze_every(0);
        assert_eq!(clamped.analyze_every, 1);
    }

    #[test]
    fn with_max_history_clamps_to_two() {
        let clamped = DifferentialLearner::new().with_max_history(1);
        assert_eq!(clamped.max_history, 2);
    }

    #[test]
    fn record_ignores_empty_properties() {
        let mut learner = DifferentialLearner::new();
        let nothing: Vec<(String, String)> = Vec::new();
        learner.record(nothing, matched());
        assert!(learner.history.is_empty());
        assert!(learner.paths.is_empty());
    }

    #[test]
    fn analyze_finds_gate_property() {
        let mut learner = DifferentialLearner::new();
        for _round in 0..3 {
            learner.record([("role", "admin")], matched());
            learner.record([("role", "guest")], silent());
        }
        learner.analyze();

        let expected = PropertyRole::Gate(vec![String::from("admin")]);
        assert_eq!(learner.property_roles.get("role"), Some(&expected));
        assert_eq!(learner.gates_found(), 1);
    }

    #[test]
    fn analyze_finds_injectable_property() {
        let mut learner = DifferentialLearner::new();
        let inputs = ["a", "b", "c"]
            .iter()
            .flat_map(|value| std::iter::repeat(*value).take(3));
        for value in inputs {
            learner.record([("input", value)], matched());
        }
        learner.analyze();

        assert_eq!(
            learner.property_roles.get("input"),
            Some(&PropertyRole::Injectable)
        );
        assert_eq!(learner.injectables_found(), 1);
    }

    #[test]
    fn paths_and_dangerous_counts_track_signatures() {
        let mut learner = DifferentialLearner::new();
        learner.record([("role", "admin")], matched());
        learner.record([("role", "guest")], silent());
        assert_eq!(
            (learner.paths_found(), learner.dangerous_path_count()),
            (2, 1)
        );
    }

    #[test]
    fn generate_variants_reuses_best_shapes_and_gate_values() {
        let mut learner = DifferentialLearner::new();
        learner
            .best_shapes
            .push(shape(&[("role", "admin"), ("input", "safe")]));
        learner.property_roles.insert(
            "role".to_owned(),
            PropertyRole::Gate(vec!["admin".to_owned()]),
        );
        learner
            .property_roles
            .insert("input".to_owned(), PropertyRole::Injectable);

        let variants = learner.generate_variants(&["PAYLOAD"]);

        let replayed = variants.iter().any(|candidate| {
            candidate.properties.get("role").map(String::as_str) == Some("admin")
                && candidate.properties.get("input").map(String::as_str) == Some("PAYLOAD")
        });
        assert!(replayed);

        let explores_gate = variants
            .iter()
            .any(|candidate| candidate.reason.contains("gate value"));
        assert!(explores_gate);
    }

    #[test]
    fn compact_discards_old_silent_history_first() {
        let mut learner = DifferentialLearner::new();
        learner.max_history = 2;
        learner.record([("role", "admin")], silent());
        learner.record(
            [("role", "guest")],
            ProbeObservation::error(Duration::from_millis(5), "x"),
        );
        learner.record([("role", "user")], silent());

        assert!(learner.history.len() <= 2);
        let kept_error = learner
            .history
            .iter()
            .any(|entry| entry.signature.outcome == ObservationOutcome::Error);
        assert!(kept_error);
    }

    #[test]
    fn build_signature_buckets_elapsed_and_return_values() {
        let mut observation = matched();
        observation.elapsed = Duration::from_secs(2);
        observation.return_value = Some(r#"{"ok":true}"#.to_string());

        let signature = build_signature(&observation);

        assert_eq!((signature.timing_bucket, signature.return_bucket), (3, 1));
        assert_eq!(signature.categories, vec![String::from("ok")]);
    }
}