1use crate::hypothesis::{Confidence, HypothesisBoard, HypothesisId};
4use crate::belief::BeliefGraph;
5use crate::errors::Result;
6
/// Tuning parameters for a confidence cascade.
///
/// Consumed by [`compute_cascade`], which reads every field on each
/// breadth-first step.
#[derive(Debug, Clone)]
pub struct PropagationConfig {
    /// Per-hop multiplier: a hypothesis `depth` hops away from the start
    /// receives `new_confidence * decay_factor^depth` (before flooring).
    pub decay_factor: f64,
    /// Floor applied to the decayed value; no computed confidence is ever
    /// lower than this.
    pub min_confidence: f64,
    /// Upper bound on the number of hypotheses a single cascade may reach
    /// before `compute_cascade` aborts with `CascadeError::CascadeTooLarge`.
    pub max_cascade_size: usize,
}

impl Default for PropagationConfig {
    /// Returns the stock configuration: 5% decay per hop, a floor of `0.1`,
    /// and a cascade limit of 10 000 hypotheses.
    fn default() -> Self {
        PropagationConfig {
            decay_factor: 0.95,
            min_confidence: 0.1,
            max_cascade_size: 10_000,
        }
    }
}
27
28#[derive(Clone, Debug)]
30pub struct PropagationResult {
31 pub changes: Vec<ConfidenceChange>,
32 pub cycles_detected: bool,
33 pub normalized_cycles: usize,
34 pub total_affected: usize,
35 pub max_depth: usize,
36}
37
38#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
40pub struct ConfidenceChange {
41 pub hypothesis_id: HypothesisId,
42 pub hypothesis_name: String,
43 pub old_confidence: Confidence,
44 pub new_confidence: Confidence,
45 pub delta: f64,
46 pub depth: usize,
47 pub propagation_path: Vec<HypothesisId>,
48}
49
50#[derive(Debug, thiserror::Error)]
52pub enum CascadeError {
53 #[error("Cascade too large: {size} hypotheses affected, limit is {limit}")]
54 CascadeTooLarge { size: usize, limit: usize },
55 #[error("Hypothesis not found: {0}")]
56 HypothesisNotFound(HypothesisId),
57 #[error("Confidence validation error: {0}")]
58 ConfidenceError(#[from] crate::hypothesis::ConfidenceError),
59 #[error("Cycle normalization failed")]
60 CycleNormalizationFailed,
61 #[error("Graph error: {0}")]
62 GraphError(String),
63}
64
65pub async fn compute_cascade(
67 start: HypothesisId,
68 new_confidence: Confidence,
69 board: &HypothesisBoard,
70 graph: &BeliefGraph,
71 config: &PropagationConfig,
72) -> std::result::Result<PropagationResult, CascadeError> {
73 use std::collections::{HashSet, VecDeque};
74
75 board.get(start).await
77 .map_err(|e| CascadeError::GraphError(e.to_string()))?
78 .ok_or(CascadeError::HypothesisNotFound(start))?;
79
80 let mut changes = Vec::new();
81 let mut visited = HashSet::new();
82 let mut queue = VecDeque::new();
83
84 queue.push_back((start, vec![start], 0));
86 visited.insert(start);
87
88 while let Some((current_id, path, depth)) = queue.pop_front() {
90 if visited.len() > config.max_cascade_size {
92 return Err(CascadeError::CascadeTooLarge {
93 size: visited.len(),
94 limit: config.max_cascade_size,
95 });
96 }
97
98 let hypothesis = board.get(current_id).await
100 .map_err(|e| CascadeError::GraphError(e.to_string()))?
101 .ok_or(CascadeError::HypothesisNotFound(current_id))?;
102
103 let old_confidence = hypothesis.current_confidence();
104 let name = hypothesis.statement().to_string();
105
106 let decay_factor = config.decay_factor.powi(depth as i32);
108 let decayed_value = new_confidence.get() * decay_factor;
109 let decayed_value = decayed_value.max(config.min_confidence);
110 let new_conf_for_hyp = Confidence::new(decayed_value)?;
111
112 let delta = new_conf_for_hyp.get() - old_confidence.get();
114 changes.push(ConfidenceChange {
115 hypothesis_id: current_id,
116 hypothesis_name: name,
117 old_confidence,
118 new_confidence: new_conf_for_hyp,
119 delta,
120 depth,
121 propagation_path: path.clone(),
122 });
123
124 if let Ok(dependents) = graph.dependents(current_id) {
126 for dependent in dependents {
127 if !visited.contains(&dependent) {
128 visited.insert(dependent);
129 let mut new_path = path.clone();
130 new_path.push(dependent);
131 queue.push_back((dependent, new_path, depth + 1));
132 }
133 }
134 }
135 }
136
137 let cycles_detected = graph.detect_cycles();
139 let has_cycles = !cycles_detected.is_empty();
140
141 let max_depth = changes.iter()
142 .map(|c| c.depth)
143 .max()
144 .unwrap_or(0);
145
146 Ok(PropagationResult {
147 changes,
148 cycles_detected: has_cycles,
149 normalized_cycles: 0,
150 total_affected: visited.len(),
151 max_depth,
152 })
153}
154
155pub fn normalize_cycles(
157 changes: &mut [ConfidenceChange],
158 graph: &BeliefGraph,
159) -> std::result::Result<usize, CascadeError> {
160 use std::collections::HashSet;
161
162 let cycles = graph.detect_cycles();
163 let mut normalized_count = 0;
164
165 for cycle in cycles {
166 if cycle.len() <= 1 {
167 continue; }
169
170 let cycle_ids: HashSet<_> = cycle.iter().cloned().collect();
172
173 let cycle_changes: Vec<_> = changes.iter()
175 .filter(|c| cycle_ids.contains(&c.hypothesis_id))
176 .collect();
177
178 if cycle_changes.is_empty() {
179 continue;
180 }
181
182 let avg_confidence: f64 = cycle_changes.iter()
183 .map(|c| c.new_confidence.get())
184 .sum::<f64>() / cycle_changes.len() as f64;
185
186 for change in changes.iter_mut() {
188 if cycle_ids.contains(&change.hypothesis_id) {
189 let avg_conf = Confidence::new(avg_confidence)?;
190 change.new_confidence = avg_conf;
191 change.delta = avg_conf.get() - change.old_confidence.get();
192 }
193 }
194
195 normalized_count += 1;
196 }
197
198 Ok(normalized_count)
199}
200
201pub async fn propagate_confidence(
203 result: PropagationResult,
204 board: &std::sync::Arc<HypothesisBoard>,
205) -> Result<()> {
206 for change in result.changes {
207 if board.get(change.hypothesis_id).await?.is_none() {
209 continue;
210 }
211 board.update_confidence_direct(change.hypothesis_id, change.new_confidence).await?;
212 }
213 Ok(())
214}
215
216pub async fn impact_radius(
218 start: HypothesisId,
219 graph: &BeliefGraph,
220) -> Result<usize> {
221 use std::collections::{HashSet, VecDeque};
222
223 let mut visited = HashSet::new();
224 let mut queue = VecDeque::new();
225
226 queue.push_back(start);
227 visited.insert(start);
228
229 while let Some(current) = queue.pop_front() {
230 if let Ok(dependents) = graph.dependents(current) {
231 for dependent in dependents {
232 if !visited.contains(&dependent) {
233 visited.insert(dependent);
234 queue.push_back(dependent);
235 }
236 }
237 }
238 }
239
240 Ok(visited.len())
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing confidence values.
    const TOLERANCE: f64 = 0.01;

    /// Builds a `Confidence`, panicking on values the type rejects.
    fn conf(value: f64) -> Confidence {
        Confidence::new(value).unwrap()
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {} +/- {}, got {}",
            expected,
            TOLERANCE,
            actual
        );
    }

    /// Creates three hypotheses at confidence 0.5 where B depends on C and
    /// A depends on B. Returns the board, the graph and `[C, B, A]`.
    async fn linear_chain() -> (HypothesisBoard, BeliefGraph, [HypothesisId; 3]) {
        let board = HypothesisBoard::in_memory();
        let mut graph = BeliefGraph::new();

        let h_c = board.propose("C", conf(0.5)).await.unwrap();
        let h_b = board.propose("B", conf(0.5)).await.unwrap();
        let h_a = board.propose("A", conf(0.5)).await.unwrap();

        graph.add_dependency(h_b, h_c).unwrap();
        graph.add_dependency(h_a, h_b).unwrap();

        (board, graph, [h_c, h_b, h_a])
    }

    /// Builds a `ConfidenceChange` moving a hypothesis from 0.5 to
    /// `new_value`, with `delta` derived from the two.
    fn change(
        id: HypothesisId,
        name: &str,
        new_value: f64,
        depth: usize,
        path: Vec<HypothesisId>,
    ) -> ConfidenceChange {
        let old_confidence = conf(0.5);
        let new_confidence = conf(new_value);
        ConfidenceChange {
            hypothesis_id: id,
            hypothesis_name: name.to_string(),
            old_confidence,
            new_confidence,
            delta: new_confidence.get() - old_confidence.get(),
            depth,
            propagation_path: path,
        }
    }

    #[test]
    fn test_propagation_config_default() {
        let config = PropagationConfig::default();
        assert_eq!(config.decay_factor, 0.95);
        assert_eq!(config.min_confidence, 0.1);
        assert_eq!(config.max_cascade_size, 10000);
    }

    #[test]
    fn test_cascade_error_display() {
        let err = CascadeError::CascadeTooLarge { size: 100, limit: 50 };
        let message = err.to_string();
        assert!(message.contains("100"));
        assert!(message.contains("50"));
    }

    #[tokio::test]
    async fn test_compute_cascade_linear_chain() {
        let (board, graph, [h_c, h_b, h_a]) = linear_chain().await;
        let config = PropagationConfig::default();

        let result = compute_cascade(h_c, conf(0.9), &board, &graph, &config)
            .await
            .unwrap();

        assert_eq!(result.total_affected, 3);
        assert_eq!(result.changes.len(), 3);
        assert_eq!(result.max_depth, 2);

        // The start hypothesis is always recorded first.
        let change_c = &result.changes[0];
        let change_b = result.changes.iter().find(|c| c.hypothesis_id == h_b).unwrap();
        let change_a = result.changes.iter().find(|c| c.hypothesis_id == h_a).unwrap();

        assert_eq!(change_c.hypothesis_id, h_c);
        assert_eq!(change_c.depth, 0);
        assert_eq!(change_b.depth, 1);
        assert_eq!(change_a.depth, 2);

        // 0.9, then 0.9 * 0.95, then 0.9 * 0.95^2.
        assert_close(change_c.new_confidence.get(), 0.9);
        assert_close(change_b.new_confidence.get(), 0.855);
        assert_close(change_a.new_confidence.get(), 0.81225);

        assert_eq!(change_a.propagation_path, vec![h_c, h_b, h_a]);
    }

    #[tokio::test]
    async fn test_compute_cascade_min_confidence_floor() {
        let (board, graph, [h_c, h_b, h_a]) = linear_chain().await;
        let config = PropagationConfig {
            decay_factor: 0.5,
            min_confidence: 0.15,
            max_cascade_size: 1000,
        };

        let result = compute_cascade(h_c, conf(0.2), &board, &graph, &config)
            .await
            .unwrap();

        // Undamped values would be 0.1 (B) and 0.05 (A); both must be floored.
        let change_b = result.changes.iter().find(|c| c.hypothesis_id == h_b).unwrap();
        let change_a = result.changes.iter().find(|c| c.hypothesis_id == h_a).unwrap();
        assert!(change_b.new_confidence.get() >= 0.15);
        assert!(change_a.new_confidence.get() >= 0.15);
    }

    #[tokio::test]
    async fn test_cascade_too_large_error() {
        let board = HypothesisBoard::in_memory();
        let mut graph = BeliefGraph::new();

        let h_a = board.propose("A", conf(0.5)).await.unwrap();
        let h_b = board.propose("B", conf(0.5)).await.unwrap();
        graph.add_dependency(h_b, h_a).unwrap();

        let config = PropagationConfig {
            decay_factor: 0.95,
            min_confidence: 0.1,
            max_cascade_size: 1,
        };

        let result = compute_cascade(h_a, conf(0.9), &board, &graph, &config).await;
        match result {
            Err(CascadeError::CascadeTooLarge { size, limit }) => {
                assert_eq!(limit, 1);
                assert!(size > limit);
            }
            other => panic!("expected CascadeTooLarge, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_hypothesis_not_found_error() {
        let board = HypothesisBoard::in_memory();
        let graph = BeliefGraph::new();
        let non_existent = HypothesisId::new();
        let config = PropagationConfig::default();

        let result = compute_cascade(non_existent, conf(0.9), &board, &graph, &config).await;
        assert!(matches!(
            result,
            Err(CascadeError::HypothesisNotFound(id)) if id == non_existent
        ));
    }

    #[test]
    fn test_normalize_cycles() {
        // An empty graph has no cycles, so nothing may be touched.
        let graph = BeliefGraph::new();

        let h_a = HypothesisId::new();
        let h_b = HypothesisId::new();

        let mut changes = vec![
            change(h_a, "A", 0.8, 0, vec![h_a]),
            change(h_b, "B", 0.7, 1, vec![h_a, h_b]),
        ];
        let deltas_before: Vec<f64> = changes.iter().map(|c| c.delta).collect();

        let normalized = normalize_cycles(&mut changes, &graph).unwrap();
        assert_eq!(normalized, 0);

        assert_close(changes[0].new_confidence.get(), 0.8);
        assert_close(changes[1].new_confidence.get(), 0.7);

        let deltas_after: Vec<f64> = changes.iter().map(|c| c.delta).collect();
        assert_eq!(deltas_before, deltas_after);
    }

    #[tokio::test]
    async fn test_propagate_confidence() {
        let board = std::sync::Arc::new(HypothesisBoard::in_memory());

        let h_a = board.propose("A", conf(0.5)).await.unwrap();
        let h_b = board.propose("B", conf(0.5)).await.unwrap();

        let result = PropagationResult {
            changes: vec![
                change(h_a, "A", 0.8, 0, vec![h_a]),
                change(h_b, "B", 0.76, 1, vec![h_a, h_b]),
            ],
            cycles_detected: false,
            normalized_cycles: 0,
            total_affected: 2,
            max_depth: 1,
        };

        propagate_confidence(result, &board).await.unwrap();

        let h_a_updated = board.get(h_a).await.unwrap().unwrap();
        let h_b_updated = board.get(h_b).await.unwrap().unwrap();

        assert_close(h_a_updated.current_confidence().get(), 0.8);
        assert_close(h_b_updated.current_confidence().get(), 0.76);
    }

    #[tokio::test]
    async fn test_impact_radius() {
        let (_board, graph, [h_c, _h_b, h_a]) = linear_chain().await;

        // C is the root of the chain: C, B and A are all reachable.
        let radius = impact_radius(h_c, &graph).await.unwrap();
        assert_eq!(radius, 3);

        // Nothing depends on A, so only A itself is counted.
        let radius_a = impact_radius(h_a, &graph).await.unwrap();
        assert_eq!(radius_a, 1);
    }
}