Skip to main content

engram/storage/
confidence.rs

//! Confidence decay for cross-references (RML-897)
//!
//! The effective confidence of an unpinned relation decays exponentially with
//! its age (measured from `created_at`); pinned relations keep their stored
//! confidence.
4
5use chrono::{DateTime, Utc};
6use rusqlite::{params, Connection};
7
8use crate::error::Result;
9use crate::types::MemoryId;
10
/// Default half-life, in days, for cross-reference confidence decay.
///
/// With this value an unpinned relation's effective confidence halves every
/// 30 days.
///
/// NOTE(review): the original comment says this is configurable via an
/// environment variable. That lookup is not visible in this part of the file;
/// the functions here take the half-life as an explicit `half_life_days`
/// argument. Confirm where the override is applied.
pub const DEFAULT_HALF_LIFE_DAYS: f32 = 30.0;
13
14/// Calculate decayed confidence based on age
15///
16/// Uses exponential decay: confidence = initial * 0.5^(age_days / half_life)
17pub fn calculate_decayed_confidence(
18    initial_confidence: f32,
19    created_at: DateTime<Utc>,
20    half_life_days: f32,
21) -> f32 {
22    let age_days = (Utc::now() - created_at).num_days() as f32;
23    initial_confidence * 0.5_f32.powf(age_days / half_life_days)
24}
25
26/// Get effective confidence for a cross-reference (considering decay and pinned status)
27pub fn get_effective_confidence(
28    conn: &Connection,
29    from_id: MemoryId,
30    to_id: MemoryId,
31    half_life_days: f32,
32) -> Result<Option<f32>> {
33    let row = conn.query_row(
34        "SELECT confidence, created_at, pinned FROM crossrefs
35         WHERE from_id = ? AND to_id = ? AND valid_to IS NULL",
36        params![from_id, to_id],
37        |row| {
38            let confidence: f32 = row.get(0)?;
39            let created_at: String = row.get(1)?;
40            let pinned: i32 = row.get(2)?;
41            Ok((confidence, created_at, pinned != 0))
42        },
43    );
44
45    match row {
46        Ok((confidence, created_at_str, pinned)) => {
47            if pinned {
48                // Pinned relations don't decay
49                return Ok(Some(confidence));
50            }
51
52            let created_at = DateTime::parse_from_rfc3339(&created_at_str)
53                .map(|dt| dt.with_timezone(&Utc))
54                .unwrap_or_else(|_| Utc::now());
55
56            Ok(Some(calculate_decayed_confidence(
57                confidence,
58                created_at,
59                half_life_days,
60            )))
61        }
62        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
63        Err(e) => Err(e.into()),
64    }
65}
66
67/// Get all cross-references with decayed confidence scores
68pub fn get_related_with_decay(
69    conn: &Connection,
70    memory_id: MemoryId,
71    half_life_days: f32,
72    min_confidence: f32,
73) -> Result<Vec<DecayedCrossRef>> {
74    let mut stmt = conn.prepare(
75        "SELECT from_id, to_id, edge_type, score, confidence, strength,
76                created_at, pinned
77         FROM crossrefs
78         WHERE (from_id = ? OR to_id = ?) AND valid_to IS NULL",
79    )?;
80
81    let now = Utc::now();
82    let mut results = Vec::new();
83
84    let rows = stmt.query_map(params![memory_id, memory_id], |row| {
85        let from_id: MemoryId = row.get(0)?;
86        let to_id: MemoryId = row.get(1)?;
87        let edge_type: String = row.get(2)?;
88        let score: f32 = row.get(3)?;
89        let confidence: f32 = row.get(4)?;
90        let strength: f32 = row.get(5)?;
91        let created_at: String = row.get(6)?;
92        let pinned: i32 = row.get(7)?;
93
94        Ok((
95            from_id,
96            to_id,
97            edge_type,
98            score,
99            confidence,
100            strength,
101            created_at,
102            pinned != 0,
103        ))
104    })?;
105
106    for row in rows {
107        let (from_id, to_id, edge_type, score, confidence, strength, created_at_str, pinned) = row?;
108
109        let effective_confidence = if pinned {
110            confidence
111        } else {
112            let created_at = DateTime::parse_from_rfc3339(&created_at_str)
113                .map(|dt| dt.with_timezone(&Utc))
114                .unwrap_or(now);
115            calculate_decayed_confidence(confidence, created_at, half_life_days)
116        };
117
118        // Filter out low-confidence relations
119        if effective_confidence >= min_confidence {
120            results.push(DecayedCrossRef {
121                from_id,
122                to_id,
123                edge_type,
124                score,
125                original_confidence: confidence,
126                effective_confidence,
127                strength,
128                pinned,
129            });
130        }
131    }
132
133    // Sort by effective score (score * confidence * strength)
134    results.sort_by(|a, b| {
135        let score_a = a.score * a.effective_confidence * a.strength;
136        let score_b = b.score * b.effective_confidence * b.strength;
137        score_b
138            .partial_cmp(&score_a)
139            .unwrap_or(std::cmp::Ordering::Equal)
140    });
141
142    Ok(results)
143}
144
/// Cross-reference with calculated decayed confidence
///
/// Produced by `get_related_with_decay`, one value per selected `crossrefs`
/// row.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DecayedCrossRef {
    /// Value of the row's `from_id` column.
    pub from_id: MemoryId,
    /// Value of the row's `to_id` column.
    pub to_id: MemoryId,
    /// Value of the row's `edge_type` column.
    pub edge_type: String,
    /// Value of the row's `score` column.
    pub score: f32,
    /// Confidence as stored in the row's `confidence` column, before decay.
    pub original_confidence: f32,
    /// Confidence after decay; equal to `original_confidence` for pinned rows.
    pub effective_confidence: f32,
    /// Value of the row's `strength` column.
    pub strength: f32,
    /// `true` when the row's `pinned` column is non-zero (exempt from decay).
    pub pinned: bool,
}
157
158impl DecayedCrossRef {
159    /// Calculate effective score considering all factors
160    pub fn effective_score(&self) -> f32 {
161        self.score * self.effective_confidence * self.strength
162    }
163}
164
165/// Batch update confidence values (for maintenance)
166pub fn refresh_confidence_batch(
167    conn: &Connection,
168    half_life_days: f32,
169    min_confidence: f32,
170) -> Result<RefreshResult> {
171    let now = Utc::now();
172    let now_str = now.to_rfc3339();
173
174    // Get all non-pinned crossrefs
175    let mut stmt = conn.prepare(
176        "SELECT id, confidence, created_at FROM crossrefs
177         WHERE pinned = 0 AND valid_to IS NULL",
178    )?;
179
180    let rows: Vec<(i64, f32, String)> = stmt
181        .query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?
182        .filter_map(|r| r.ok())
183        .collect();
184
185    let mut updated = 0;
186    let mut expired = 0;
187
188    for (id, original_confidence, created_at_str) in rows {
189        let created_at = DateTime::parse_from_rfc3339(&created_at_str)
190            .map(|dt| dt.with_timezone(&Utc))
191            .unwrap_or(now);
192
193        let effective =
194            calculate_decayed_confidence(original_confidence, created_at, half_life_days);
195
196        if effective < min_confidence {
197            // Mark as expired (soft delete)
198            conn.execute(
199                "UPDATE crossrefs SET valid_to = ? WHERE id = ?",
200                params![now_str, id],
201            )?;
202            expired += 1;
203        }
204        updated += 1;
205    }
206
207    Ok(RefreshResult { updated, expired })
208}
209
/// Result of batch confidence refresh
///
/// Returned by `refresh_confidence_batch`.
#[derive(Debug, Clone)]
pub struct RefreshResult {
    /// Number of rows examined by the refresh. It is incremented once per
    /// row, whether or not that row was expired.
    pub updated: i64,
    /// Number of rows soft-deleted (their `valid_to` was set) because the
    /// decayed confidence fell below the minimum.
    pub expired: i64,
}
216
217/// Pin a cross-reference (exempt from decay)
218pub fn pin_crossref(conn: &Connection, from_id: MemoryId, to_id: MemoryId) -> Result<()> {
219    conn.execute(
220        "UPDATE crossrefs SET pinned = 1 WHERE from_id = ? AND to_id = ? AND valid_to IS NULL",
221        params![from_id, to_id],
222    )?;
223    Ok(())
224}
225
226/// Unpin a cross-reference (subject to decay)
227pub fn unpin_crossref(conn: &Connection, from_id: MemoryId, to_id: MemoryId) -> Result<()> {
228    conn.execute(
229        "UPDATE crossrefs SET pinned = 0 WHERE from_id = ? AND to_id = ? AND valid_to IS NULL",
230        params![from_id, to_id],
231    )?;
232    Ok(())
233}
234
235/// Boost confidence of a cross-reference (user interaction)
236pub fn boost_confidence(
237    conn: &Connection,
238    from_id: MemoryId,
239    to_id: MemoryId,
240    boost: f32,
241) -> Result<()> {
242    // Boost is additive, capped at 1.0
243    conn.execute(
244        "UPDATE crossrefs SET confidence = MIN(1.0, confidence + ?)
245         WHERE from_id = ? AND to_id = ? AND valid_to IS NULL",
246        params![boost, from_id, to_id],
247    )?;
248    Ok(())
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// One half-life in the past leaves half of the initial confidence.
    #[test]
    fn test_confidence_decay() {
        let created_at = Utc::now() - chrono::Duration::days(30);
        let decayed = calculate_decayed_confidence(1.0, created_at, 30.0);

        let error = (decayed - 0.5).abs();
        assert!(error < 0.01, "Expected ~0.5, got {}", decayed);
    }

    /// Two half-lives in the past leave a quarter of the initial confidence.
    #[test]
    fn test_confidence_decay_double_half_life() {
        let created_at = Utc::now() - chrono::Duration::days(60);
        let decayed = calculate_decayed_confidence(1.0, created_at, 30.0);

        let error = (decayed - 0.25).abs();
        assert!(error < 0.01, "Expected ~0.25, got {}", decayed);
    }

    /// A relation created just now has not decayed yet.
    #[test]
    fn test_confidence_no_decay_for_new() {
        let created_at = Utc::now();
        let decayed = calculate_decayed_confidence(1.0, created_at, 30.0);

        let error = (decayed - 1.0).abs();
        assert!(error < 0.01, "Expected ~1.0, got {}", decayed);
    }
}