recall_echo/graph/
confidence.rs1use serde::{Deserialize, Serialize};
7
/// Total pseudo-observation count (alpha + beta) of the Beta prior used by
/// `bayesian_update`.
///
/// A single observation shifts a confidence `c` by `(1 - c) / (PSEUDOCOUNT + 1)`
/// (corroboration) or `c / (PSEUDOCOUNT + 1)` (contradiction), so larger values
/// make individual observations count for less.
const PSEUDOCOUNT: f64 = 10.0_f64;
11
12#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum ExtractionContext {
16 Explicit, Inferred, Speculative, Authoritative, }
21
22impl ExtractionContext {
23 pub fn prior(self) -> f64 {
25 match self {
26 Self::Authoritative => 1.0,
27 Self::Explicit => 0.9,
28 Self::Inferred => 0.6,
29 Self::Speculative => 0.3,
30 }
31 }
32}
33
34impl std::str::FromStr for ExtractionContext {
35 type Err = String;
36
37 fn from_str(s: &str) -> Result<Self, Self::Err> {
38 match s.to_lowercase().as_str() {
39 "explicit" => Ok(Self::Explicit),
40 "inferred" => Ok(Self::Inferred),
41 "speculative" => Ok(Self::Speculative),
42 "authoritative" => Ok(Self::Authoritative),
43 other => Err(format!("unknown extraction context: {}", other)),
44 }
45 }
46}
47
48pub fn bayesian_update(current_confidence: f64, corroborate: bool) -> f64 {
56 let alpha = current_confidence * PSEUDOCOUNT;
57 let beta = PSEUDOCOUNT - alpha;
58
59 if corroborate {
60 (alpha + 1.0) / (alpha + beta + 1.0)
61 } else {
62 alpha / (alpha + beta + 1.0)
63 }
64}
65
/// Combines the confidences of the edges along a path into a single score.
///
/// The result is the product of all edge confidences, which treats the edges
/// as independent. An empty path returns `1.0`, the multiplicative identity.
pub fn path_confidence(edge_confidences: &[f64]) -> f64 {
    edge_confidences
        .iter()
        .fold(1.0, |running, &edge| running * edge)
}
72
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparisons against hand-computed expectations,
    /// which are written to three decimal places.
    const TOLERANCE: f64 = 0.001;

    /// Returns `true` when `a` and `b` differ by less than `TOLERANCE`.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < TOLERANCE
    }

    // With PSEUDOCOUNT = 10, corroboration yields (10c + 1) / 11 and
    // contradiction yields 10c / 11.

    #[test]
    fn bayesian_update_corroborate_0_6() {
        let result = bayesian_update(0.6, true);
        assert!(approx_eq(result, 0.636), "got {}", result);
    }

    #[test]
    fn bayesian_update_contradict_0_6() {
        let result = bayesian_update(0.6, false);
        assert!(approx_eq(result, 0.545), "got {}", result);
    }

    #[test]
    fn bayesian_update_corroborate_0_9() {
        let result = bayesian_update(0.9, true);
        assert!(approx_eq(result, 0.909), "got {}", result);
    }

    #[test]
    fn bayesian_update_contradict_0_9() {
        let result = bayesian_update(0.9, false);
        assert!(approx_eq(result, 0.818), "got {}", result);
    }

    #[test]
    fn bayesian_update_corroborate_0_3() {
        let result = bayesian_update(0.3, true);
        assert!(approx_eq(result, 0.364), "got {}", result);
    }

    #[test]
    fn bayesian_update_contradict_0_3() {
        let result = bayesian_update(0.3, false);
        assert!(approx_eq(result, 0.273), "got {}", result);
    }

    #[test]
    fn bayesian_update_moves_toward_evidence_within_unit_interval() {
        // Includes the 0.0 and 1.0 boundaries, where beta or alpha is zero.
        for &c in &[0.0, 0.3, 0.6, 0.9, 1.0] {
            let up = bayesian_update(c, true);
            let down = bayesian_update(c, false);
            assert!((0.0..=1.0).contains(&up), "corroborate({}) = {}", c, up);
            assert!((0.0..=1.0).contains(&down), "contradict({}) = {}", c, down);
            assert!(up >= c, "corroborate({}) decreased to {}", c, up);
            assert!(down <= c, "contradict({}) increased to {}", c, down);
        }
    }

    #[test]
    fn path_confidence_two_edges() {
        let result = path_confidence(&[0.8, 0.7]);
        assert!(approx_eq(result, 0.56), "got {}", result);
    }

    #[test]
    fn path_confidence_single_edge() {
        assert_eq!(path_confidence(&[0.42]), 0.42);
    }

    #[test]
    fn path_confidence_empty() {
        assert_eq!(path_confidence(&[]), 1.0);
    }

    #[test]
    fn extraction_context_priors() {
        assert_eq!(ExtractionContext::Authoritative.prior(), 1.0);
        assert_eq!(ExtractionContext::Explicit.prior(), 0.9);
        assert_eq!(ExtractionContext::Inferred.prior(), 0.6);
        assert_eq!(ExtractionContext::Speculative.prior(), 0.3);
    }

    #[test]
    fn extraction_context_from_str() {
        assert_eq!(
            "explicit".parse::<ExtractionContext>().unwrap(),
            ExtractionContext::Explicit
        );
        assert_eq!(
            "inferred".parse::<ExtractionContext>().unwrap(),
            ExtractionContext::Inferred
        );
        assert_eq!(
            "speculative".parse::<ExtractionContext>().unwrap(),
            ExtractionContext::Speculative
        );
        assert_eq!(
            "authoritative".parse::<ExtractionContext>().unwrap(),
            ExtractionContext::Authoritative
        );
        assert!("unknown".parse::<ExtractionContext>().is_err());
    }

    #[test]
    fn extraction_context_from_str_is_case_insensitive() {
        assert_eq!(
            "Explicit".parse::<ExtractionContext>().unwrap(),
            ExtractionContext::Explicit
        );
        assert_eq!(
            "INFERRED".parse::<ExtractionContext>().unwrap(),
            ExtractionContext::Inferred
        );
        assert_eq!(
            "SpEcUlAtIvE".parse::<ExtractionContext>().unwrap(),
            ExtractionContext::Speculative
        );
        assert_eq!(
            "Authoritative".parse::<ExtractionContext>().unwrap(),
            ExtractionContext::Authoritative
        );
    }

    #[test]
    fn extraction_context_from_str_error_message() {
        assert_eq!(
            "bogus".parse::<ExtractionContext>().unwrap_err(),
            "unknown extraction context: bogus"
        );
    }
}