1use rusqlite::Connection;
19use serde::{Deserialize, Serialize};
20
21use crate::error::Result;
22
/// DDL for the `utility_feedback` table, which stores one row per recorded
/// retrieval outcome for a memory, plus its two indexes.
///
/// The string holds several statements, so run it with
/// `Connection::execute_batch` rather than `Connection::execute`. Every statement
/// uses `IF NOT EXISTS`, so executing it more than once is harmless.
///
/// `timestamp` defaults to the current UTC time rendered as
/// `YYYY-MM-DDTHH:MM:SS.SSSZ` (RFC 3339). That is the format the tracker's decay
/// logic parses, and it sorts chronologically as plain text.
pub const CREATE_UTILITY_FEEDBACK_TABLE: &str = r#"
CREATE TABLE IF NOT EXISTS utility_feedback (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    memory_id INTEGER NOT NULL,
    was_useful BOOLEAN NOT NULL,
    query TEXT NOT NULL DEFAULT '',
    timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
);
CREATE INDEX IF NOT EXISTS idx_utility_memory ON utility_feedback(memory_id);
CREATE INDEX IF NOT EXISTS idx_utility_timestamp ON utility_feedback(timestamp);
"#;
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct UtilityConfig {
48 pub learning_rate: f64,
50 pub decay_factor: f64,
52 pub initial_score: f64,
54}
55
56impl Default for UtilityConfig {
57 fn default() -> Self {
58 Self {
59 learning_rate: 0.1,
60 decay_factor: 0.95,
61 initial_score: 0.5,
62 }
63 }
64}
65
/// Utility of a single memory, as computed by `UtilityTracker::get_utility`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UtilityScore {
    /// Memory this score belongs to.
    pub memory_id: i64,
    /// Learned score after time decay, clamped to `[0.0, 1.0]`. When the memory
    /// has no feedback this is the configured `initial_score`, unclamped.
    pub score: f64,
    /// Number of feedback rows recorded for this memory.
    pub retrievals: i64,
    /// How many of those feedback rows were marked useful.
    pub useful_count: i64,
    /// `timestamp` of the most recent feedback row; empty when there is none.
    pub last_retrieved: String,
}
79
/// Aggregate view over recorded feedback, as returned by
/// `UtilityTracker::utility_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UtilityStats {
    /// Number of feedback rows counted (only those joined to the requested
    /// workspace when one is given).
    pub total_feedback: i64,
    /// Mean score over all memories that have feedback; `initial_score` when
    /// there are none.
    pub avg_score: f64,
    /// Up to 10 highest-scoring memories as `(memory_id, useful_count)`, best first.
    pub top_useful: Vec<(i64, i64)>,
    /// Up to 10 lowest-scoring memories as `(memory_id, useful_count)`, worst first.
    pub bottom_useful: Vec<(i64, i64)>,
}
92
93pub struct UtilityTracker {
99 pub config: UtilityConfig,
100}
101
102impl UtilityTracker {
103 pub fn new() -> Self {
105 Self {
106 config: UtilityConfig::default(),
107 }
108 }
109
110 pub fn with_config(config: UtilityConfig) -> Self {
112 Self { config }
113 }
114
115 pub fn record_retrieval(
126 &self,
127 conn: &Connection,
128 memory_id: i64,
129 was_useful: bool,
130 query: &str,
131 ) -> Result<()> {
132 conn.execute(
134 "INSERT INTO utility_feedback (memory_id, was_useful, query) VALUES (?1, ?2, ?3)",
135 rusqlite::params![memory_id, was_useful, query],
136 )?;
137
138 Ok(())
142 }
143
144 pub fn get_utility(&self, conn: &Connection, memory_id: i64) -> Result<UtilityScore> {
156 let mut stmt = conn.prepare(
158 "SELECT was_useful, timestamp FROM utility_feedback
159 WHERE memory_id = ?1
160 ORDER BY timestamp ASC, id ASC",
161 )?;
162
163 struct Row {
164 was_useful: bool,
165 timestamp: String,
166 }
167
168 let rows: Vec<Row> = stmt
169 .query_map(rusqlite::params![memory_id], |r| {
170 Ok(Row {
171 was_useful: r.get::<_, bool>(0)?,
172 timestamp: r.get::<_, String>(1)?,
173 })
174 })?
175 .collect::<std::result::Result<Vec<_>, _>>()?;
176
177 if rows.is_empty() {
178 return Ok(UtilityScore {
179 memory_id,
180 score: self.config.initial_score,
181 retrievals: 0,
182 useful_count: 0,
183 last_retrieved: String::new(),
184 });
185 }
186
187 let mut q = self.config.initial_score;
189 let mut useful_count = 0_i64;
190
191 for row in &rows {
192 let reward = if row.was_useful { 1.0 } else { -0.5 };
193 q += self.config.learning_rate * (reward - q);
194 if row.was_useful {
195 useful_count += 1;
196 }
197 }
198
199 let last_retrieved = rows.last().map(|r| r.timestamp.clone()).unwrap_or_default();
201 q = self.apply_decay(q, &last_retrieved);
202 q = q.clamp(0.0, 1.0);
204
205 Ok(UtilityScore {
206 memory_id,
207 score: q,
208 retrievals: rows.len() as i64,
209 useful_count,
210 last_retrieved,
211 })
212 }
213
214 pub fn apply_utility_boost(&self, scores: &mut [(i64, f32)], conn: &Connection) -> Result<()> {
224 for (memory_id, score) in scores.iter_mut() {
225 let utility = self.get_utility(conn, *memory_id)?;
226 let boost = (0.5 + utility.score * 1.5).clamp(0.5, 2.0);
231 *score = (*score * boost as f32).clamp(0.5, 2.0);
232 }
233 Ok(())
234 }
235
236 pub fn batch_decay(&self, conn: &Connection, _config: &UtilityConfig) -> Result<usize> {
250 let mut stmt = conn.prepare("SELECT DISTINCT memory_id FROM utility_feedback")?;
252 let memory_ids: Vec<i64> = stmt
253 .query_map([], |r| r.get::<_, i64>(0))?
254 .collect::<std::result::Result<Vec<_>, _>>()?;
255
256 let mut affected = 0_usize;
257 for memory_id in memory_ids {
258 let scored = self.get_utility(conn, memory_id)?;
259 if (scored.score - self.config.initial_score).abs() >= 0.001 {
262 affected += 1;
263 }
264 }
265 Ok(affected)
266 }
267
268 pub fn utility_stats(
279 &self,
280 conn: &Connection,
281 workspace: Option<&str>,
282 ) -> Result<UtilityStats> {
283 let memory_ids: Vec<i64> = if let Some(ws) = workspace {
285 let mut stmt = conn.prepare(
287 "SELECT DISTINCT uf.memory_id
288 FROM utility_feedback uf
289 INNER JOIN memories m ON m.id = uf.memory_id
290 WHERE m.workspace = ?1",
291 )?;
292 let ids = stmt
293 .query_map(rusqlite::params![ws], |r| r.get::<_, i64>(0))?
294 .collect::<std::result::Result<Vec<_>, _>>()?;
295 ids
296 } else {
297 let mut stmt = conn.prepare("SELECT DISTINCT memory_id FROM utility_feedback")?;
298 let ids = stmt
299 .query_map([], |r| r.get::<_, i64>(0))?
300 .collect::<std::result::Result<Vec<_>, _>>()?;
301 ids
302 };
303
304 let total_feedback: i64 = if let Some(ws) = workspace {
306 conn.query_row(
307 "SELECT COUNT(*) FROM utility_feedback uf
308 INNER JOIN memories m ON m.id = uf.memory_id
309 WHERE m.workspace = ?1",
310 rusqlite::params![ws],
311 |r| r.get(0),
312 )?
313 } else {
314 conn.query_row("SELECT COUNT(*) FROM utility_feedback", [], |r| r.get(0))?
315 };
316
317 if memory_ids.is_empty() {
318 return Ok(UtilityStats {
319 total_feedback,
320 avg_score: self.config.initial_score,
321 top_useful: Vec::new(),
322 bottom_useful: Vec::new(),
323 });
324 }
325
326 let mut scores: Vec<(i64, f64)> = Vec::with_capacity(memory_ids.len());
328 for mid in &memory_ids {
329 let us = self.get_utility(conn, *mid)?;
330 scores.push((*mid, us.score));
331 }
332
333 let avg_score = scores.iter().map(|(_, s)| s).sum::<f64>() / scores.len() as f64;
334
335 let mut sorted_desc = scores.clone();
337 sorted_desc.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
338 let top_useful: Vec<(i64, i64)> = sorted_desc
339 .iter()
340 .take(10)
341 .map(|(mid, _)| {
342 let cnt: i64 = conn
344 .query_row(
345 "SELECT COUNT(*) FROM utility_feedback WHERE memory_id = ?1 AND was_useful = 1",
346 rusqlite::params![mid],
347 |r| r.get(0),
348 )
349 .unwrap_or(0);
350 (*mid, cnt)
351 })
352 .collect();
353
354 let mut sorted_asc = scores.clone();
356 sorted_asc.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
357 let bottom_useful: Vec<(i64, i64)> = sorted_asc
358 .iter()
359 .take(10)
360 .map(|(mid, _)| {
361 let cnt: i64 = conn
362 .query_row(
363 "SELECT COUNT(*) FROM utility_feedback WHERE memory_id = ?1 AND was_useful = 1",
364 rusqlite::params![mid],
365 |r| r.get(0),
366 )
367 .unwrap_or(0);
368 (*mid, cnt)
369 })
370 .collect();
371
372 Ok(UtilityStats {
373 total_feedback,
374 avg_score,
375 top_useful,
376 bottom_useful,
377 })
378 }
379
380 fn apply_decay(&self, score: f64, last_retrieved_ts: &str) -> f64 {
391 if last_retrieved_ts.is_empty() {
392 return score;
393 }
394
395 let parsed = chrono::DateTime::parse_from_rfc3339(last_retrieved_ts)
396 .ok()
397 .map(|dt| dt.with_timezone(&chrono::Utc));
398
399 let Some(last) = parsed else {
400 return score;
401 };
402
403 let now = chrono::Utc::now();
404 let days_elapsed = (now - last).num_seconds() as f64 / 86_400.0;
405
406 if days_elapsed <= 0.0 {
407 return score;
408 }
409
410 score * self.config.decay_factor.powf(days_elapsed)
411 }
412}
413
414impl Default for UtilityTracker {
415 fn default() -> Self {
416 Self::new()
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database containing only the `utility_feedback` schema.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().expect("open in-memory db");
        conn.execute_batch(CREATE_UTILITY_FEEDBACK_TABLE)
            .expect("create table");
        conn
    }

    /// A single useful retrieval is recorded and lifts the score above the start value.
    #[test]
    fn test_record_and_retrieve_utility() {
        let conn = setup();
        let tracker = UtilityTracker::new();

        tracker
            .record_retrieval(&conn, 1, true, "rust async")
            .expect("record");

        let us = tracker.get_utility(&conn, 1).expect("get_utility");

        assert_eq!(us.memory_id, 1);
        assert_eq!(us.retrievals, 1);
        assert_eq!(us.useful_count, 1);
        assert!(!us.last_retrieved.is_empty());
        assert!(
            us.score > tracker.config.initial_score,
            "score {} should be > initial {}",
            us.score,
            tracker.config.initial_score
        );
    }

    /// Repeated useful feedback pushes the score well above the default start.
    #[test]
    fn test_useful_retrievals_boost_score() {
        let conn = setup();
        let tracker = UtilityTracker::new();

        for _ in 0..20 {
            tracker
                .record_retrieval(&conn, 42, true, "query")
                .expect("record");
        }

        let us = tracker.get_utility(&conn, 42).expect("get_utility");

        assert!(
            us.score > 0.7,
            "expected score > 0.7 after 20 useful retrievals, got {}",
            us.score
        );
    }

    /// Repeated not-useful feedback pulls the score below the start value.
    #[test]
    fn test_irrelevant_retrievals_lower_score() {
        let conn = setup();
        let tracker = UtilityTracker::new();

        for _ in 0..20 {
            tracker
                .record_retrieval(&conn, 7, false, "query")
                .expect("record");
        }

        let us = tracker.get_utility(&conn, 7).expect("get_utility");

        assert!(
            us.score < tracker.config.initial_score,
            "expected score < initial ({}) after 20 irrelevant retrievals, got {}",
            tracker.config.initial_score,
            us.score
        );
    }

    /// A memory without feedback reports the configured initial score.
    #[test]
    fn test_initial_score_default_when_no_feedback() {
        let conn = setup();
        let tracker = UtilityTracker::new();

        let us = tracker.get_utility(&conn, 999).expect("get_utility");

        assert_eq!(us.retrievals, 0);
        assert_eq!(us.useful_count, 0);
        assert!(
            (us.score - tracker.config.initial_score).abs() < 1e-9,
            "expected initial score {}, got {}",
            tracker.config.initial_score,
            us.score
        );
        assert!(us.last_retrieved.is_empty());
    }

    /// Feedback from 100 days ago with a 0.5 daily decay leaves almost nothing.
    #[test]
    fn test_temporal_decay_reduces_score() {
        let conn = setup();

        let config = UtilityConfig {
            learning_rate: 0.5,
            decay_factor: 0.5,
            initial_score: 0.5,
        };
        let tracker = UtilityTracker::with_config(config);

        // Same textual format as the column default, so it parses as RFC 3339.
        let past = (chrono::Utc::now() - chrono::Duration::days(100))
            .format("%Y-%m-%dT%H:%M:%S%.3fZ")
            .to_string();
        conn.execute(
            "INSERT INTO utility_feedback (memory_id, was_useful, query, timestamp) VALUES (1, 1, 'q', ?1)",
            rusqlite::params![past],
        )
        .expect("insert");

        let us = tracker.get_utility(&conn, 1).expect("get_utility");

        assert!(
            us.score < 0.1,
            "expected heavily decayed score < 0.1, got {}",
            us.score
        );
    }

    /// A consistently useful memory ends up ranked above a consistently useless one.
    #[test]
    fn test_apply_utility_boost() {
        let conn = setup();
        let tracker = UtilityTracker::new();

        for _ in 0..15 {
            tracker
                .record_retrieval(&conn, 10, true, "q")
                .expect("record");
        }
        for _ in 0..15 {
            tracker
                .record_retrieval(&conn, 20, false, "q")
                .expect("record");
        }

        let mut scores = vec![(10_i64, 0.6_f32), (20_i64, 0.6_f32)];
        tracker
            .apply_utility_boost(&mut scores, &conn)
            .expect("boost");

        let boosted = scores[0].1;
        let demoted = scores[1].1;

        assert!(
            boosted > demoted,
            "useful memory ({boosted}) should score higher than useless one ({demoted})"
        );
    }

    /// Every memory with fresh feedback has moved away from the neutral score.
    #[test]
    fn test_batch_decay_returns_affected_count() {
        let conn = setup();
        let tracker = UtilityTracker::new();

        for mid in [1_i64, 2, 3] {
            tracker
                .record_retrieval(&conn, mid, true, "q")
                .expect("record");
        }

        let config = UtilityConfig::default();
        let count = tracker.batch_decay(&conn, &config).expect("batch_decay");

        assert_eq!(count, 3, "expected 3 affected memories, got {count}");
    }

    /// An empty feedback table affects no memories.
    #[test]
    fn test_batch_decay_without_feedback_is_zero() {
        let conn = setup();
        let tracker = UtilityTracker::new();

        let count = tracker
            .batch_decay(&conn, &UtilityConfig::default())
            .expect("batch_decay");

        assert_eq!(count, 0);
    }

    /// Stats rank the useful memory first and the useless one last.
    #[test]
    fn test_utility_stats() {
        let conn = setup();
        let tracker = UtilityTracker::new();

        for _ in 0..5 {
            tracker
                .record_retrieval(&conn, 1, true, "q")
                .expect("record");
        }
        for _ in 0..5 {
            tracker
                .record_retrieval(&conn, 2, false, "q")
                .expect("record");
        }

        let stats = tracker.utility_stats(&conn, None).expect("stats");

        assert_eq!(stats.total_feedback, 10);
        assert!(
            stats.avg_score > 0.0 && stats.avg_score < 1.0,
            "avg_score out of range: {}",
            stats.avg_score
        );
        assert!(!stats.top_useful.is_empty());
        let top_mid = stats.top_useful[0].0;
        assert_eq!(top_mid, 1, "expected memory 1 on top, got memory {top_mid}");
        assert!(!stats.bottom_useful.is_empty());
        let bottom_mid = stats.bottom_useful[0].0;
        assert_eq!(
            bottom_mid, 2,
            "expected memory 2 at bottom, got memory {bottom_mid}"
        );
    }

    /// The `(memory_id, useful_count)` pairs carry the number of useful rows only.
    #[test]
    fn test_utility_stats_reports_useful_counts() {
        let conn = setup();
        let tracker = UtilityTracker::new();

        for _ in 0..3 {
            tracker
                .record_retrieval(&conn, 1, true, "q")
                .expect("record");
        }
        tracker
            .record_retrieval(&conn, 1, false, "q")
            .expect("record");
        tracker
            .record_retrieval(&conn, 2, false, "q")
            .expect("record");

        let stats = tracker.utility_stats(&conn, None).expect("stats");

        assert_eq!(stats.total_feedback, 5);
        assert_eq!(stats.top_useful, vec![(1, 3), (2, 0)]);
        assert_eq!(stats.bottom_useful, vec![(2, 0), (1, 3)]);
    }

    /// The workspace filter limits both the counts and the ranked memories.
    #[test]
    fn test_utility_stats_workspace_filter() {
        let conn = setup();
        conn.execute_batch(
            "CREATE TABLE memories (id INTEGER PRIMARY KEY, workspace TEXT NOT NULL);
             INSERT INTO memories (id, workspace) VALUES (1, 'alpha'), (2, 'beta');",
        )
        .expect("create memories");
        let tracker = UtilityTracker::new();

        tracker
            .record_retrieval(&conn, 1, true, "q")
            .expect("record");
        tracker
            .record_retrieval(&conn, 1, true, "q")
            .expect("record");
        tracker
            .record_retrieval(&conn, 2, false, "q")
            .expect("record");

        let stats = tracker
            .utility_stats(&conn, Some("alpha"))
            .expect("stats");
        assert_eq!(stats.total_feedback, 2);
        assert_eq!(stats.top_useful, vec![(1, 2)]);
        assert_eq!(stats.bottom_useful, vec![(1, 2)]);

        let empty = tracker
            .utility_stats(&conn, Some("missing"))
            .expect("stats");
        assert_eq!(empty.total_feedback, 0);
        assert!(empty.top_useful.is_empty());
        assert!(empty.bottom_useful.is_empty());
        assert!((empty.avg_score - tracker.config.initial_score).abs() < 1e-9);
    }

    /// One useful retrieval from 0.5 at learning rate 0.1 gives 0.5 + 0.1 * 0.5.
    #[test]
    fn test_q_value_formula_single_useful() {
        let conn = setup();
        let config = UtilityConfig {
            learning_rate: 0.1,
            decay_factor: 1.0,
            initial_score: 0.5,
        };
        let tracker = UtilityTracker::with_config(config);

        tracker
            .record_retrieval(&conn, 1, true, "q")
            .expect("record");

        let us = tracker.get_utility(&conn, 1).expect("get_utility");
        let expected = 0.55;
        assert!(
            (us.score - expected).abs() < 1e-9,
            "expected score {expected}, got {}",
            us.score
        );
    }

    /// One not-useful retrieval from 0.5 at learning rate 0.1 gives 0.5 + 0.1 * (-1.0).
    #[test]
    fn test_q_value_formula_single_not_useful() {
        let conn = setup();
        let config = UtilityConfig {
            learning_rate: 0.1,
            decay_factor: 1.0,
            initial_score: 0.5,
        };
        let tracker = UtilityTracker::with_config(config);

        tracker
            .record_retrieval(&conn, 2, false, "q")
            .expect("record");

        let us = tracker.get_utility(&conn, 2).expect("get_utility");
        let expected = 0.4;
        assert!(
            (us.score - expected).abs() < 1e-9,
            "expected score {expected}, got {}",
            us.score
        );
    }

    /// The boosted score stays inside `[0.5, 2.0]` even for a tiny input score.
    #[test]
    fn test_boost_clamp_bounds() {
        let conn = setup();
        let tracker = UtilityTracker::new();

        tracker
            .record_retrieval(&conn, 100, true, "q")
            .expect("record");

        let mut scores = vec![(100_i64, 0.1_f32)];
        tracker
            .apply_utility_boost(&mut scores, &conn)
            .expect("boost");

        assert!(
            scores[0].1 >= 0.5 && scores[0].1 <= 2.0,
            "boosted score {} is outside [0.5, 2.0]",
            scores[0].1
        );
    }
}