1use chrono::{DateTime, Utc};
6use rusqlite::{params, Connection};
7
8use crate::error::Result;
9use crate::types::MemoryId;
10
/// Default half-life, in days, for cross-reference confidence decay.
///
/// Presumably intended to be passed as the `half_life_days` argument of the decay
/// functions in this module (not referenced in the visible code — confirm at call sites).
pub const DEFAULT_HALF_LIFE_DAYS: f32 = 30.0_f32;
13
14pub fn calculate_decayed_confidence(
18 initial_confidence: f32,
19 created_at: DateTime<Utc>,
20 half_life_days: f32,
21) -> f32 {
22 let age_days = (Utc::now() - created_at).num_days() as f32;
23 initial_confidence * 0.5_f32.powf(age_days / half_life_days)
24}
25
26pub fn get_effective_confidence(
28 conn: &Connection,
29 from_id: MemoryId,
30 to_id: MemoryId,
31 half_life_days: f32,
32) -> Result<Option<f32>> {
33 let row = conn.query_row(
34 "SELECT confidence, created_at, pinned FROM crossrefs
35 WHERE from_id = ? AND to_id = ? AND valid_to IS NULL",
36 params![from_id, to_id],
37 |row| {
38 let confidence: f32 = row.get(0)?;
39 let created_at: String = row.get(1)?;
40 let pinned: i32 = row.get(2)?;
41 Ok((confidence, created_at, pinned != 0))
42 },
43 );
44
45 match row {
46 Ok((confidence, created_at_str, pinned)) => {
47 if pinned {
48 return Ok(Some(confidence));
50 }
51
52 let created_at = DateTime::parse_from_rfc3339(&created_at_str)
53 .map(|dt| dt.with_timezone(&Utc))
54 .unwrap_or_else(|_| Utc::now());
55
56 Ok(Some(calculate_decayed_confidence(
57 confidence,
58 created_at,
59 half_life_days,
60 )))
61 }
62 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
63 Err(e) => Err(e.into()),
64 }
65}
66
67pub fn get_related_with_decay(
69 conn: &Connection,
70 memory_id: MemoryId,
71 half_life_days: f32,
72 min_confidence: f32,
73) -> Result<Vec<DecayedCrossRef>> {
74 let mut stmt = conn.prepare(
75 "SELECT from_id, to_id, edge_type, score, confidence, strength,
76 created_at, pinned
77 FROM crossrefs
78 WHERE (from_id = ? OR to_id = ?) AND valid_to IS NULL",
79 )?;
80
81 let now = Utc::now();
82 let mut results = Vec::new();
83
84 let rows = stmt.query_map(params![memory_id, memory_id], |row| {
85 let from_id: MemoryId = row.get(0)?;
86 let to_id: MemoryId = row.get(1)?;
87 let edge_type: String = row.get(2)?;
88 let score: f32 = row.get(3)?;
89 let confidence: f32 = row.get(4)?;
90 let strength: f32 = row.get(5)?;
91 let created_at: String = row.get(6)?;
92 let pinned: i32 = row.get(7)?;
93
94 Ok((
95 from_id,
96 to_id,
97 edge_type,
98 score,
99 confidence,
100 strength,
101 created_at,
102 pinned != 0,
103 ))
104 })?;
105
106 for row in rows {
107 let (from_id, to_id, edge_type, score, confidence, strength, created_at_str, pinned) = row?;
108
109 let effective_confidence = if pinned {
110 confidence
111 } else {
112 let created_at = DateTime::parse_from_rfc3339(&created_at_str)
113 .map(|dt| dt.with_timezone(&Utc))
114 .unwrap_or(now);
115 calculate_decayed_confidence(confidence, created_at, half_life_days)
116 };
117
118 if effective_confidence >= min_confidence {
120 results.push(DecayedCrossRef {
121 from_id,
122 to_id,
123 edge_type,
124 score,
125 original_confidence: confidence,
126 effective_confidence,
127 strength,
128 pinned,
129 });
130 }
131 }
132
133 results.sort_by(|a, b| {
135 let score_a = a.score * a.effective_confidence * a.strength;
136 let score_b = b.score * b.effective_confidence * b.strength;
137 score_b
138 .partial_cmp(&score_a)
139 .unwrap_or(std::cmp::Ordering::Equal)
140 });
141
142 Ok(results)
143}
144
/// A live cross-reference with its confidence adjusted for age, as returned by
/// [`get_related_with_decay`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DecayedCrossRef {
    /// `from_id` column of the `crossrefs` row.
    pub from_id: MemoryId,
    /// `to_id` column of the `crossrefs` row.
    pub to_id: MemoryId,
    /// `edge_type` column, read as text.
    pub edge_type: String,
    /// Stored `score` column; one factor of [`DecayedCrossRef::effective_score`].
    pub score: f32,
    /// Confidence as stored in the database, before any decay is applied.
    pub original_confidence: f32,
    /// Confidence after time decay. [`get_related_with_decay`] sets this equal to
    /// `original_confidence` when `pinned` is true.
    pub effective_confidence: f32,
    /// Stored `strength` column; one factor of [`DecayedCrossRef::effective_score`].
    pub strength: f32,
    /// Whether the edge is pinned (its `pinned` column is non-zero); pinned edges
    /// skip decay.
    pub pinned: bool,
}
157
158impl DecayedCrossRef {
159 pub fn effective_score(&self) -> f32 {
161 self.score * self.effective_confidence * self.strength
162 }
163}
164
165pub fn refresh_confidence_batch(
167 conn: &Connection,
168 half_life_days: f32,
169 min_confidence: f32,
170) -> Result<RefreshResult> {
171 let now = Utc::now();
172 let now_str = now.to_rfc3339();
173
174 let mut stmt = conn.prepare(
176 "SELECT id, confidence, created_at FROM crossrefs
177 WHERE pinned = 0 AND valid_to IS NULL",
178 )?;
179
180 let rows: Vec<(i64, f32, String)> = stmt
181 .query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?
182 .filter_map(|r| r.ok())
183 .collect();
184
185 let mut updated = 0;
186 let mut expired = 0;
187
188 for (id, original_confidence, created_at_str) in rows {
189 let created_at = DateTime::parse_from_rfc3339(&created_at_str)
190 .map(|dt| dt.with_timezone(&Utc))
191 .unwrap_or(now);
192
193 let effective =
194 calculate_decayed_confidence(original_confidence, created_at, half_life_days);
195
196 if effective < min_confidence {
197 conn.execute(
199 "UPDATE crossrefs SET valid_to = ? WHERE id = ?",
200 params![now_str, id],
201 )?;
202 expired += 1;
203 }
204 updated += 1;
205 }
206
207 Ok(RefreshResult { updated, expired })
208}
209
/// Counters reported by [`refresh_confidence_batch`].
#[derive(Debug, Clone)]
pub struct RefreshResult {
    /// Number of live, unpinned cross-references examined. Every such row is counted
    /// whether or not it was expired. Despite the name, no confidence values are
    /// rewritten.
    pub updated: i64,
    /// Number of cross-references whose decayed confidence fell below the threshold
    /// and were closed by setting `valid_to`.
    pub expired: i64,
}
216
217pub fn pin_crossref(conn: &Connection, from_id: MemoryId, to_id: MemoryId) -> Result<()> {
219 conn.execute(
220 "UPDATE crossrefs SET pinned = 1 WHERE from_id = ? AND to_id = ? AND valid_to IS NULL",
221 params![from_id, to_id],
222 )?;
223 Ok(())
224}
225
226pub fn unpin_crossref(conn: &Connection, from_id: MemoryId, to_id: MemoryId) -> Result<()> {
228 conn.execute(
229 "UPDATE crossrefs SET pinned = 0 WHERE from_id = ? AND to_id = ? AND valid_to IS NULL",
230 params![from_id, to_id],
231 )?;
232 Ok(())
233}
234
235pub fn boost_confidence(
237 conn: &Connection,
238 from_id: MemoryId,
239 to_id: MemoryId,
240 boost: f32,
241) -> Result<()> {
242 conn.execute(
244 "UPDATE crossrefs SET confidence = MIN(1.0, confidence + ?)
245 WHERE from_id = ? AND to_id = ? AND valid_to IS NULL",
246 params![boost, from_id, to_id],
247 )?;
248 Ok(())
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Decayed confidence of an edge that started at 1.0 with a 30-day half-life
    /// and was created `days` days before now.
    fn decayed_after_days(days: i64) -> f32 {
        let created_at = Utc::now() - chrono::Duration::days(days);
        calculate_decayed_confidence(1.0, created_at, 30.0)
    }

    #[test]
    fn test_confidence_decay() {
        // Exactly one half-life has elapsed.
        let got = decayed_after_days(30);
        assert!((got - 0.5).abs() < 0.01, "Expected ~0.5, got {}", got);
    }

    #[test]
    fn test_confidence_decay_double_half_life() {
        // Two half-lives elapsed: 1.0 -> 0.5 -> 0.25.
        let got = decayed_after_days(60);
        assert!((got - 0.25).abs() < 0.01, "Expected ~0.25, got {}", got);
    }

    #[test]
    fn test_confidence_no_decay_for_new() {
        // Created just now: no decay yet.
        let got = decayed_after_days(0);
        assert!((got - 1.0).abs() < 0.01, "Expected ~1.0, got {}", got);
    }
}