1use crate::types::{ObservationOutcome, ProbeObservation, ProbeVariant, PropertyRole};
2use std::collections::{BTreeMap, HashMap};
3
/// One recorded probe kept in the learner's history: the properties that were sent
/// together with the coarse signature of what was observed.
#[derive(Clone, Debug)]
struct ProbeRecord {
    /// Property name -> value pairs sent with the probe (`record` skips empty sets).
    properties: BTreeMap<String, String>,
    /// Summary of the observation, as produced by `build_signature`.
    signature: ObservationSignature,
}
9
/// Coarse fingerprint of a probe observation, built by `build_signature`.
///
/// Probes with equal signatures are grouped together; the derived `Hash`/`Eq` let a
/// signature key the learner's `paths` map.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ObservationSignature {
    /// Outcome copied from the observation.
    outcome: ObservationOutcome,
    /// Lowercased error message truncated to 80 characters; empty when there was no error.
    error_key: String,
    /// Categories copied from the observation.
    categories: Vec<String>,
    /// Elapsed-time bucket: 0 for <= 10 ms, 1 for <= 100 ms, 2 for <= 1 s, 3 for longer.
    timing_bucket: u8,
    /// Rough class of the returned text: 0 = absent/empty/`undefined`/`null`,
    /// 1 = starts with `{`, 2 = starts with `[`, 3 = starts with `"`, 4 = parses as `f64`,
    /// 5 = `true`/`false`, 6 = other text longer than 100 bytes, 7 = anything else.
    return_bucket: u8,
}
18
/// Learns from recorded probes which properties gate a `Match` outcome and which appear
/// not to influence it, and proposes new probe variants from what it has learned.
#[derive(Clone, Debug)]
pub struct DifferentialLearner {
    /// Recorded probes; `record` triggers compaction once the length exceeds `max_history`.
    history: Vec<ProbeRecord>,
    /// Role assigned to each property name by `analyze`.
    property_roles: HashMap<String, PropertyRole>,
    /// Property sets that produced a `Match`, kept (up to a limit) as replay templates.
    best_shapes: Vec<BTreeMap<String, String>>,
    /// Property sets grouped by the observation signature they produced.
    paths: HashMap<ObservationSignature, Vec<BTreeMap<String, String>>>,
    /// History length above which `record` triggers compaction.
    max_history: usize,
    /// `record` runs `analyze` whenever the history length is a multiple of this value.
    analyze_every: usize,
}
28
29impl Default for DifferentialLearner {
30 fn default() -> Self {
31 Self::new()
32 }
33}
34
35impl DifferentialLearner {
36 pub fn new() -> Self {
37 Self {
38 history: Vec::new(),
39 property_roles: HashMap::new(),
40 best_shapes: Vec::new(),
41 paths: HashMap::new(),
42 max_history: 10_000,
43 analyze_every: 100,
44 }
45 }
46
47 pub fn with_analyze_every(mut self, analyze_every: usize) -> Self {
48 self.analyze_every = analyze_every.max(1);
49 self
50 }
51
52 pub fn with_max_history(mut self, max_history: usize) -> Self {
53 self.max_history = max_history.max(2);
54 self
55 }
56
57 pub fn record(
58 &mut self,
59 properties: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
60 observation: ProbeObservation,
61 ) {
62 let properties: BTreeMap<String, String> = properties
63 .into_iter()
64 .map(|(key, value)| (key.into(), value.into()))
65 .collect();
66 if properties.is_empty() {
67 return;
68 }
69
70 let signature = build_signature(&observation);
71 self.paths
72 .entry(signature.clone())
73 .or_default()
74 .push(properties.clone());
75
76 if signature.outcome == ObservationOutcome::Match && self.best_shapes.len() < 50 {
77 self.best_shapes.push(properties.clone());
78 }
79
80 self.history.push(ProbeRecord {
81 properties,
82 signature,
83 });
84
85 if self.history.len() % self.analyze_every == 0 {
86 self.analyze();
87 }
88
89 if self.history.len() > self.max_history {
90 self.compact();
91 }
92 }
93
94 pub fn analyze(&mut self) {
95 let mut property_outcomes: HashMap<
96 String,
97 HashMap<String, HashMap<ObservationOutcome, u32>>,
98 > = HashMap::new();
99
100 for record in &self.history {
101 for (key, value) in &record.properties {
102 let outcomes = property_outcomes
103 .entry(key.clone())
104 .or_default()
105 .entry(value.clone())
106 .or_default();
107 *outcomes
108 .entry(record.signature.outcome.clone())
109 .or_insert(0) += 1;
110 }
111 }
112
113 let total_match_rate = self
114 .history
115 .iter()
116 .filter(|record| record.signature.outcome == ObservationOutcome::Match)
117 .count() as f32
118 / self.history.len().max(1) as f32;
119
120 for (key, values) in property_outcomes {
121 if values.len() < 2 {
122 continue;
123 }
124
125 let mut gate_values = Vec::new();
126 let mut injectable = true;
127
128 for (value, outcomes) in values {
129 let total: u32 = outcomes.values().sum();
130 if total < 3 {
131 continue;
132 }
133 let matches = outcomes
134 .get(&ObservationOutcome::Match)
135 .copied()
136 .unwrap_or(0);
137 let rate = matches as f32 / total as f32;
138 let threshold = (total_match_rate * 2.0).max(0.1);
139 if rate >= threshold {
140 gate_values.push(value.clone());
141 }
142 if (rate - total_match_rate).abs() > total_match_rate.max(0.05) {
143 injectable = false;
144 }
145 }
146
147 if !gate_values.is_empty() {
148 gate_values.sort();
149 gate_values.dedup();
150 self.property_roles
151 .insert(key, PropertyRole::Gate(gate_values));
152 } else if injectable {
153 self.property_roles.insert(key, PropertyRole::Injectable);
154 }
155 }
156 }
157
158 pub fn property_roles(&self) -> &HashMap<String, PropertyRole> {
159 &self.property_roles
160 }
161
162 pub fn gates_found(&self) -> usize {
163 self.property_roles
164 .values()
165 .filter(|role| matches!(role, PropertyRole::Gate(_)))
166 .count()
167 }
168
169 pub fn injectables_found(&self) -> usize {
170 self.property_roles
171 .values()
172 .filter(|role| matches!(role, PropertyRole::Injectable))
173 .count()
174 }
175
176 pub fn paths_found(&self) -> usize {
177 self.paths.len()
178 }
179
180 pub fn dangerous_path_count(&self) -> usize {
181 self.paths
182 .keys()
183 .filter(|signature| signature.outcome == ObservationOutcome::Match)
184 .count()
185 }
186
187 pub fn generate_variants(&self, payloads: &[impl AsRef<str>]) -> Vec<ProbeVariant> {
188 let mut variants = Vec::new();
189
190 for shape in self.best_shapes.iter().take(5) {
191 let gates: Vec<(&String, &String)> = shape
192 .iter()
193 .filter(|(key, _)| {
194 matches!(self.property_roles.get(*key), Some(PropertyRole::Gate(_)))
195 })
196 .collect();
197 let injectables: Vec<(&String, &String)> = shape
198 .iter()
199 .filter(|(key, _)| {
200 matches!(
201 self.property_roles.get(*key),
202 Some(PropertyRole::Injectable) | None
203 )
204 })
205 .collect();
206
207 for payload in payloads.iter().take(5) {
208 let payload = payload.as_ref();
209 let mut properties = BTreeMap::new();
210 for (key, value) in &gates {
211 properties.insert((*key).clone(), (*value).clone());
212 }
213 for (key, _) in &injectables {
214 properties.insert((*key).clone(), payload.to_string());
215 }
216 variants.push(ProbeVariant {
217 properties,
218 reason: format!("replay successful shape with payload `{payload}`"),
219 });
220 }
221 }
222
223 for (key, role) in &self.property_roles {
224 if let PropertyRole::Gate(known_values) = role {
225 for value in known_values.iter().take(3) {
226 for variant in [
227 format!("{value}2"),
228 value.to_uppercase(),
229 value.to_lowercase(),
230 format!("_{value}"),
231 ] {
232 let mut properties = BTreeMap::new();
233 properties.insert(key.clone(), variant);
234 variants.push(ProbeVariant {
235 properties,
236 reason: format!("explore nearby gate value for `{key}`"),
237 });
238 }
239 }
240 }
241 }
242
243 variants
244 }
245
246 fn compact(&mut self) {
247 let midpoint = self.history.len() / 2;
248 let mut kept = self.history[midpoint..].to_vec();
249 kept.extend(
250 self.history[..midpoint]
251 .iter()
252 .filter(|record| record.signature.outcome != ObservationOutcome::Silent)
253 .cloned(),
254 );
255 self.history = kept;
256 }
257}
258
259fn build_signature(observation: &ProbeObservation) -> ObservationSignature {
260 let error_key = observation
261 .error
262 .as_deref()
263 .map(|error| error.to_lowercase().chars().take(80).collect::<String>())
264 .unwrap_or_default();
265
266 let timing_bucket = match observation.elapsed.as_micros() {
267 0..=10_000 => 0,
268 10_001..=100_000 => 1,
269 100_001..=1_000_000 => 2,
270 _ => 3,
271 };
272
273 let return_bucket = match observation.return_value.as_deref() {
274 None | Some("") | Some("undefined") | Some("null") => 0,
275 Some(value) if value.starts_with('{') => 1,
276 Some(value) if value.starts_with('[') => 2,
277 Some(value) if value.starts_with('"') => 3,
278 Some(value) if value.parse::<f64>().is_ok() => 4,
279 Some("true") | Some("false") => 5,
280 Some(value) if value.len() > 100 => 6,
281 Some(_) => 7,
282 };
283
284 ObservationSignature {
285 outcome: observation.outcome.clone(),
286 error_key,
287 categories: observation.categories.clone(),
288 timing_bucket,
289 return_bucket,
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Observation helper: a `Match` after 10 ms carrying the single category `ok`.
    fn matched() -> ProbeObservation {
        let elapsed = Duration::from_millis(10);
        ProbeObservation::matched(elapsed, ["ok"])
    }

    /// Observation helper: a silent result after 10 ms.
    fn silent() -> ProbeObservation {
        let elapsed = Duration::from_millis(10);
        ProbeObservation::silent(elapsed)
    }

    #[test]
    fn new_learner_starts_empty() {
        let fresh = DifferentialLearner::new();
        assert_eq!(fresh.history.len(), 0);
        assert_eq!(fresh.property_roles.len(), 0);
        assert_eq!(fresh.best_shapes.len(), 0);
        assert_eq!(fresh.max_history, 10_000);
        assert_eq!(fresh.analyze_every, 100);
    }

    #[test]
    fn with_analyze_every_clamps_to_one() {
        let clamped = DifferentialLearner::new().with_analyze_every(0);
        assert_eq!(clamped.analyze_every, 1);
    }

    #[test]
    fn with_max_history_clamps_to_two() {
        let clamped = DifferentialLearner::new().with_max_history(1);
        assert_eq!(clamped.max_history, 2);
    }

    #[test]
    fn record_ignores_empty_properties() {
        let mut learner = DifferentialLearner::new();
        let no_properties: Vec<(String, String)> = Vec::new();
        learner.record(no_properties, matched());
        assert_eq!(learner.history.len(), 0);
        assert_eq!(learner.paths.len(), 0);
    }

    #[test]
    fn analyze_finds_gate_property() {
        let mut learner = DifferentialLearner::new();
        for _round in 0..3 {
            learner.record([("role", "admin")], matched());
            learner.record([("role", "guest")], silent());
        }
        learner.analyze();

        let expected = PropertyRole::Gate(vec![String::from("admin")]);
        assert_eq!(learner.property_roles.get("role"), Some(&expected));
        assert_eq!(learner.gates_found(), 1);
    }

    #[test]
    fn analyze_finds_injectable_property() {
        let mut learner = DifferentialLearner::new();
        for value in ["a", "b", "c"] {
            for _repeat in 0..3 {
                learner.record([("input", value)], matched());
            }
        }
        learner.analyze();

        let expected = PropertyRole::Injectable;
        assert_eq!(learner.property_roles.get("input"), Some(&expected));
        assert_eq!(learner.injectables_found(), 1);
    }

    #[test]
    fn paths_and_dangerous_counts_track_signatures() {
        let mut learner = DifferentialLearner::new();
        learner.record([("role", "admin")], matched());
        learner.record([("role", "guest")], silent());
        assert_eq!(
            (learner.paths_found(), learner.dangerous_path_count()),
            (2, 1)
        );
    }

    #[test]
    fn generate_variants_reuses_best_shapes_and_gate_values() {
        let mut learner = DifferentialLearner::new();
        learner.best_shapes.push(
            [("role", "admin"), ("input", "safe")]
                .into_iter()
                .map(|(key, value)| (key.to_string(), value.to_string()))
                .collect(),
        );
        learner.property_roles.insert(
            String::from("role"),
            PropertyRole::Gate(vec![String::from("admin")]),
        );
        learner
            .property_roles
            .insert(String::from("input"), PropertyRole::Injectable);

        let variants = learner.generate_variants(&["PAYLOAD"]);

        let replayed = variants.iter().any(|variant| {
            variant.properties.get("role").map(String::as_str) == Some("admin")
                && variant.properties.get("input").map(String::as_str) == Some("PAYLOAD")
        });
        assert!(replayed);

        let explored = variants
            .iter()
            .any(|variant| variant.reason.contains("gate value"));
        assert!(explored);
    }

    #[test]
    fn compact_discards_old_silent_history_first() {
        let mut learner = DifferentialLearner::new();
        learner.max_history = 2;
        learner.record([("role", "admin")], silent());
        learner.record(
            [("role", "guest")],
            ProbeObservation::error(Duration::from_millis(5), "x"),
        );
        learner.record([("role", "user")], silent());

        assert!(learner.history.len() <= 2);
        let error_survived = learner
            .history
            .iter()
            .any(|record| record.signature.outcome == ObservationOutcome::Error);
        assert!(error_survived);
    }

    #[test]
    fn build_signature_buckets_elapsed_and_return_values() {
        let mut observation = matched();
        observation.elapsed = Duration::from_secs(2);
        observation.return_value = Some(String::from("{\"ok\":true}"));

        let ObservationSignature {
            timing_bucket,
            return_bucket,
            categories,
            ..
        } = build_signature(&observation);

        assert_eq!(timing_bucket, 3);
        assert_eq!(return_bucket, 1);
        assert_eq!(categories, vec![String::from("ok")]);
    }
}