Skip to main content

forge_reasoning/impact/
propagation.rs

1//! Confidence propagation through dependency graphs
2
3use crate::hypothesis::{Confidence, HypothesisBoard, HypothesisId};
4use crate::belief::BeliefGraph;
5use crate::errors::Result;
6
/// Tuning knobs for [`compute_cascade`].
#[derive(Debug, Clone)]
pub struct PropagationConfig {
    /// Multiplier applied once per dependency level: a hypothesis at depth `d`
    /// is assigned `new_confidence * decay_factor^d`. Default: `0.95`.
    pub decay_factor: f64,
    /// Lower bound applied to every decayed confidence. Default: `0.1`.
    pub min_confidence: f64,
    /// Largest number of hypotheses a single cascade may touch before it is
    /// rejected with `CascadeError::CascadeTooLarge`. Default: `10_000`.
    pub max_cascade_size: usize,
}

impl Default for PropagationConfig {
    /// `decay_factor = 0.95`, `min_confidence = 0.1`, `max_cascade_size = 10_000`.
    fn default() -> Self {
        let decay_factor = 0.95;
        let min_confidence = 0.1;
        let max_cascade_size = 10_000;
        PropagationConfig {
            decay_factor,
            min_confidence,
            max_cascade_size,
        }
    }
}
27
28/// Result of confidence propagation
29#[derive(Clone, Debug)]
30pub struct PropagationResult {
31    pub changes: Vec<ConfidenceChange>,
32    pub cycles_detected: bool,
33    pub normalized_cycles: usize,
34    pub total_affected: usize,
35    pub max_depth: usize,
36}
37
38/// Represents a change in confidence for a single hypothesis
39#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
40pub struct ConfidenceChange {
41    pub hypothesis_id: HypothesisId,
42    pub hypothesis_name: String,
43    pub old_confidence: Confidence,
44    pub new_confidence: Confidence,
45    pub delta: f64,
46    pub depth: usize,
47    pub propagation_path: Vec<HypothesisId>,
48}
49
50/// Error type for cascade operations
51#[derive(Debug, thiserror::Error)]
52pub enum CascadeError {
53    #[error("Cascade too large: {size} hypotheses affected, limit is {limit}")]
54    CascadeTooLarge { size: usize, limit: usize },
55    #[error("Hypothesis not found: {0}")]
56    HypothesisNotFound(HypothesisId),
57    #[error("Confidence validation error: {0}")]
58    ConfidenceError(#[from] crate::hypothesis::ConfidenceError),
59    #[error("Cycle normalization failed")]
60    CycleNormalizationFailed,
61    #[error("Graph error: {0}")]
62    GraphError(String),
63}
64
65/// Compute cascade impact with BFS traversal
66pub async fn compute_cascade(
67    start: HypothesisId,
68    new_confidence: Confidence,
69    board: &HypothesisBoard,
70    graph: &BeliefGraph,
71    config: &PropagationConfig,
72) -> std::result::Result<PropagationResult, CascadeError> {
73    use std::collections::{HashSet, VecDeque};
74
75    // Verify start hypothesis exists
76    board.get(start).await
77        .map_err(|e| CascadeError::GraphError(e.to_string()))?
78        .ok_or(CascadeError::HypothesisNotFound(start))?;
79
80    let mut changes = Vec::new();
81    let mut visited = HashSet::new();
82    let mut queue = VecDeque::new();
83
84    // Start node with depth 0
85    queue.push_back((start, vec![start], 0));
86    visited.insert(start);
87
88    // BFS traversal
89    while let Some((current_id, path, depth)) = queue.pop_front() {
90        // Check cascade size limit
91        if visited.len() > config.max_cascade_size {
92            return Err(CascadeError::CascadeTooLarge {
93                size: visited.len(),
94                limit: config.max_cascade_size,
95            });
96        }
97
98        // Get current hypothesis
99        let hypothesis = board.get(current_id).await
100            .map_err(|e| CascadeError::GraphError(e.to_string()))?
101            .ok_or(CascadeError::HypothesisNotFound(current_id))?;
102
103        let old_confidence = hypothesis.current_confidence();
104        let name = hypothesis.statement().to_string();
105
106        // Compute decayed confidence
107        let decay_factor = config.decay_factor.powi(depth as i32);
108        let decayed_value = new_confidence.get() * decay_factor;
109        let decayed_value = decayed_value.max(config.min_confidence);
110        let new_conf_for_hyp = Confidence::new(decayed_value)?;
111
112        // Create change record
113        let delta = new_conf_for_hyp.get() - old_confidence.get();
114        changes.push(ConfidenceChange {
115            hypothesis_id: current_id,
116            hypothesis_name: name,
117            old_confidence,
118            new_confidence: new_conf_for_hyp,
119            delta,
120            depth,
121            propagation_path: path.clone(),
122        });
123
124        // Get dependents (what depends on current hypothesis)
125        if let Ok(dependents) = graph.dependents(current_id) {
126            for dependent in dependents {
127                if !visited.contains(&dependent) {
128                    visited.insert(dependent);
129                    let mut new_path = path.clone();
130                    new_path.push(dependent);
131                    queue.push_back((dependent, new_path, depth + 1));
132                }
133            }
134        }
135    }
136
137    // Detect cycles
138    let cycles_detected = graph.detect_cycles();
139    let has_cycles = !cycles_detected.is_empty();
140
141    let max_depth = changes.iter()
142        .map(|c| c.depth)
143        .max()
144        .unwrap_or(0);
145
146    Ok(PropagationResult {
147        changes,
148        cycles_detected: has_cycles,
149        normalized_cycles: 0,
150        total_affected: visited.len(),
151        max_depth,
152    })
153}
154
155/// Normalize cycles to average confidence
156pub fn normalize_cycles(
157    changes: &mut [ConfidenceChange],
158    graph: &BeliefGraph,
159) -> std::result::Result<usize, CascadeError> {
160    use std::collections::HashSet;
161
162    let cycles = graph.detect_cycles();
163    let mut normalized_count = 0;
164
165    for cycle in cycles {
166        if cycle.len() <= 1 {
167            continue; // Not a real cycle
168        }
169
170        // Find changes belonging to this cycle
171        let cycle_ids: HashSet<_> = cycle.iter().cloned().collect();
172
173        // Compute average confidence for this cycle
174        let cycle_changes: Vec<_> = changes.iter()
175            .filter(|c| cycle_ids.contains(&c.hypothesis_id))
176            .collect();
177
178        if cycle_changes.is_empty() {
179            continue;
180        }
181
182        let avg_confidence: f64 = cycle_changes.iter()
183            .map(|c| c.new_confidence.get())
184            .sum::<f64>() / cycle_changes.len() as f64;
185
186        // Update all cycle members to average confidence
187        for change in changes.iter_mut() {
188            if cycle_ids.contains(&change.hypothesis_id) {
189                let avg_conf = Confidence::new(avg_confidence)?;
190                change.new_confidence = avg_conf;
191                change.delta = avg_conf.get() - change.old_confidence.get();
192            }
193        }
194
195        normalized_count += 1;
196    }
197
198    Ok(normalized_count)
199}
200
201/// Apply confidence propagation changes
202pub async fn propagate_confidence(
203    result: PropagationResult,
204    board: &std::sync::Arc<HypothesisBoard>,
205) -> Result<()> {
206    for change in result.changes {
207        // Skip if hypothesis was deleted
208        if board.get(change.hypothesis_id).await?.is_none() {
209            continue;
210        }
211        board.update_confidence_direct(change.hypothesis_id, change.new_confidence).await?;
212    }
213    Ok(())
214}
215
216/// Query impact radius without applying changes
217pub async fn impact_radius(
218    start: HypothesisId,
219    graph: &BeliefGraph,
220) -> Result<usize> {
221    use std::collections::{HashSet, VecDeque};
222
223    let mut visited = HashSet::new();
224    let mut queue = VecDeque::new();
225
226    queue.push_back(start);
227    visited.insert(start);
228
229    while let Some(current) = queue.pop_front() {
230        if let Ok(dependents) = graph.dependents(current) {
231            for dependent in dependents {
232                if !visited.contains(&dependent) {
233                    visited.insert(dependent);
234                    queue.push_back(dependent);
235                }
236            }
237        }
238    }
239
240    Ok(visited.len())
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_propagation_config_default() {
        let config = PropagationConfig::default();
        assert_eq!(config.decay_factor, 0.95);
        assert_eq!(config.min_confidence, 0.1);
        assert_eq!(config.max_cascade_size, 10000);
    }

    #[test]
    fn test_cascade_error_display() {
        let err = CascadeError::CascadeTooLarge { size: 100, limit: 50 };
        assert!(err.to_string().contains("100"));
        assert!(err.to_string().contains("50"));
    }

    #[tokio::test]
    async fn test_compute_cascade_linear_chain() {
        let board = HypothesisBoard::in_memory();
        let mut graph = BeliefGraph::new();

        // Create hypotheses: A -> B -> C (A depends on B, B depends on C)
        let h_c = board.propose("C", Confidence::new(0.5).unwrap()).await.unwrap();
        let h_b = board.propose("B", Confidence::new(0.5).unwrap()).await.unwrap();
        let h_a = board.propose("A", Confidence::new(0.5).unwrap()).await.unwrap();

        // Create dependencies
        graph.add_dependency(h_b, h_c).unwrap();
        graph.add_dependency(h_a, h_b).unwrap();

        // Compute cascade from C with new confidence 0.9
        let new_conf = Confidence::new(0.9).unwrap();
        let config = PropagationConfig::default();
        let result = compute_cascade(h_c, new_conf, &board, &graph, &config).await.unwrap();

        // Should affect all 3 hypotheses
        assert_eq!(result.total_affected, 3);
        assert_eq!(result.changes.len(), 3);

        // Check depth assignments; the start is always the first change.
        let change_c = &result.changes[0];
        let change_b = result.changes.iter().find(|c| c.hypothesis_id == h_b).unwrap();
        let change_a = result.changes.iter().find(|c| c.hypothesis_id == h_a).unwrap();

        assert_eq!(change_c.hypothesis_id, h_c);
        assert_eq!(change_c.depth, 0);
        assert_eq!(change_b.depth, 1);
        assert_eq!(change_a.depth, 2);
        assert_eq!(result.max_depth, 2);

        // Check decay application (0.95^1 = 0.95, 0.95^2 = 0.9025)
        assert!((change_b.new_confidence.get() - 0.855).abs() < 0.01); // 0.9 * 0.95
        assert!((change_a.new_confidence.get() - 0.81225).abs() < 0.01); // 0.9 * 0.95^2
    }

    #[tokio::test]
    async fn test_compute_cascade_min_confidence_floor() {
        let board = HypothesisBoard::in_memory();
        let mut graph = BeliefGraph::new();

        // Create hypotheses with deep chain
        let h_c = board.propose("C", Confidence::new(0.5).unwrap()).await.unwrap();
        let h_b = board.propose("B", Confidence::new(0.5).unwrap()).await.unwrap();
        let h_a = board.propose("A", Confidence::new(0.5).unwrap()).await.unwrap();

        graph.add_dependency(h_b, h_c).unwrap();
        graph.add_dependency(h_a, h_b).unwrap();

        // Set low confidence and high decay
        let new_conf = Confidence::new(0.2).unwrap();
        let config = PropagationConfig {
            decay_factor: 0.5,
            min_confidence: 0.15,
            max_cascade_size: 1000,
        };

        let result = compute_cascade(h_c, new_conf, &board, &graph, &config).await.unwrap();

        // C at depth 0 keeps 0.2 (already above the floor).
        let change_c = result.changes.iter().find(|c| c.hypothesis_id == h_c).unwrap();
        assert!((change_c.new_confidence.get() - 0.2).abs() < 0.01);

        // B at depth 1: 0.2 * 0.5 = 0.1, floored to 0.15
        let change_b = result.changes.iter().find(|c| c.hypothesis_id == h_b).unwrap();
        assert!((change_b.new_confidence.get() - 0.15).abs() < 0.01);

        // A at depth 2: 0.2 * 0.5^2 = 0.05, floored to 0.15
        let change_a = result.changes.iter().find(|c| c.hypothesis_id == h_a).unwrap();
        assert!((change_a.new_confidence.get() - 0.15).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_cascade_too_large_error() {
        let board = HypothesisBoard::in_memory();
        let mut graph = BeliefGraph::new();

        // Create a small chain
        let h_a = board.propose("A", Confidence::new(0.5).unwrap()).await.unwrap();
        let h_b = board.propose("B", Confidence::new(0.5).unwrap()).await.unwrap();
        graph.add_dependency(h_b, h_a).unwrap();

        // Set very small cascade limit
        let new_conf = Confidence::new(0.9).unwrap();
        let config = PropagationConfig {
            decay_factor: 0.95,
            min_confidence: 0.1,
            max_cascade_size: 1, // Only 1 hypothesis allowed
        };

        let result = compute_cascade(h_a, new_conf, &board, &graph, &config).await;
        assert!(matches!(result, Err(CascadeError::CascadeTooLarge { .. })));
    }

    #[tokio::test]
    async fn test_zero_cascade_limit_rejects_start() {
        let board = HypothesisBoard::in_memory();
        let graph = BeliefGraph::new();

        let h_a = board.propose("A", Confidence::new(0.5).unwrap()).await.unwrap();

        // Even the start hypothesis exceeds a limit of zero.
        let config = PropagationConfig {
            max_cascade_size: 0,
            ..PropagationConfig::default()
        };
        let new_conf = Confidence::new(0.9).unwrap();

        let result = compute_cascade(h_a, new_conf, &board, &graph, &config).await;
        assert!(matches!(result, Err(CascadeError::CascadeTooLarge { .. })));
    }

    #[tokio::test]
    async fn test_hypothesis_not_found_error() {
        let board = HypothesisBoard::in_memory();
        let graph = BeliefGraph::new();
        let non_existent = HypothesisId::new();

        let new_conf = Confidence::new(0.9).unwrap();
        let config = PropagationConfig::default();

        let result = compute_cascade(non_existent, new_conf, &board, &graph, &config).await;
        assert!(matches!(result, Err(CascadeError::HypothesisNotFound(_))));
    }

    // `normalize_cycles` is synchronous, so this needs no async runtime.
    #[test]
    fn test_normalize_cycles() {
        let graph = BeliefGraph::new();

        // Create changes for testing (no actual cycle in graph)
        let h_a = HypothesisId::new();
        let h_b = HypothesisId::new();

        let mut changes = vec![
            ConfidenceChange {
                hypothesis_id: h_a,
                hypothesis_name: "A".to_string(),
                old_confidence: Confidence::new(0.5).unwrap(),
                new_confidence: Confidence::new(0.8).unwrap(),
                delta: 0.3,
                depth: 0,
                propagation_path: vec![h_a],
            },
            ConfidenceChange {
                hypothesis_id: h_b,
                hypothesis_name: "B".to_string(),
                old_confidence: Confidence::new(0.5).unwrap(),
                new_confidence: Confidence::new(0.7).unwrap(),
                delta: 0.2,
                depth: 1,
                propagation_path: vec![h_a, h_b],
            },
        ];

        // Normalize cycles with empty graph (no cycles to normalize)
        let normalized = normalize_cycles(&mut changes, &graph).unwrap();
        assert_eq!(normalized, 0);

        // Changes should remain unchanged
        assert!((changes[0].new_confidence.get() - 0.8).abs() < 0.01);
        assert!((changes[1].new_confidence.get() - 0.7).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_propagate_confidence() {
        let board = std::sync::Arc::new(HypothesisBoard::in_memory());
        let mut graph = BeliefGraph::new();

        // Create hypotheses
        let h_a = board.propose("A", Confidence::new(0.5).unwrap()).await.unwrap();
        let h_b = board.propose("B", Confidence::new(0.5).unwrap()).await.unwrap();

        graph.add_dependency(h_b, h_a).unwrap();

        // Create propagation result
        let result = PropagationResult {
            changes: vec![
                ConfidenceChange {
                    hypothesis_id: h_a,
                    hypothesis_name: "A".to_string(),
                    old_confidence: Confidence::new(0.5).unwrap(),
                    new_confidence: Confidence::new(0.8).unwrap(),
                    delta: 0.3,
                    depth: 0,
                    propagation_path: vec![h_a],
                },
                ConfidenceChange {
                    hypothesis_id: h_b,
                    hypothesis_name: "B".to_string(),
                    old_confidence: Confidence::new(0.5).unwrap(),
                    new_confidence: Confidence::new(0.76).unwrap(),
                    delta: 0.26,
                    depth: 1,
                    propagation_path: vec![h_a, h_b],
                },
            ],
            cycles_detected: false,
            normalized_cycles: 0,
            total_affected: 2,
            max_depth: 1,
        };

        // Apply propagation
        propagate_confidence(result, &board).await.unwrap();

        // Verify confidences were updated
        let h_a_updated = board.get(h_a).await.unwrap().unwrap();
        let h_b_updated = board.get(h_b).await.unwrap().unwrap();

        assert!((h_a_updated.current_confidence().get() - 0.8).abs() < 0.01);
        assert!((h_b_updated.current_confidence().get() - 0.76).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_impact_radius() {
        let board = HypothesisBoard::in_memory();
        let mut graph = BeliefGraph::new();

        // Create chain: A depends on B, B depends on C
        let h_c = board.propose("C", Confidence::new(0.5).unwrap()).await.unwrap();
        let h_b = board.propose("B", Confidence::new(0.5).unwrap()).await.unwrap();
        let h_a = board.propose("A", Confidence::new(0.5).unwrap()).await.unwrap();

        graph.add_dependency(h_b, h_c).unwrap();
        graph.add_dependency(h_a, h_b).unwrap();

        // Impact radius from C should be 3 (C, B, A all affected)
        let radius = impact_radius(h_c, &graph).await.unwrap();
        assert_eq!(radius, 3);

        // Impact radius from A should be 1 (only A affected)
        let radius_a = impact_radius(h_a, &graph).await.unwrap();
        assert_eq!(radius_a, 1);
    }
}