Skip to main content

engram/search/
feedback.rs

1//! Relevance Feedback Loop — RML-1243
2//!
3//! Persists explicit user feedback (useful / irrelevant) for search results and
4//! uses that history to compute per-memory boost factors that can be applied to
5//! subsequent search score vectors.
6
7use rusqlite::{Connection, OptionalExtension};
8use serde::{Deserialize, Serialize};
9
10use crate::error::{EngramError, Result};
11
12// ---------------------------------------------------------------------------
13// DDL
14// ---------------------------------------------------------------------------
15
16/// SQL for creating the `search_feedback` table and its indexes.
17/// Safe to call on an existing database — uses `CREATE TABLE IF NOT EXISTS`.
18pub const CREATE_SEARCH_FEEDBACK_TABLE: &str = r#"
19CREATE TABLE IF NOT EXISTS search_feedback (
20    id INTEGER PRIMARY KEY AUTOINCREMENT,
21    query TEXT NOT NULL,
22    query_embedding_hash TEXT,
23    memory_id INTEGER NOT NULL,
24    signal TEXT NOT NULL CHECK(signal IN ('useful', 'irrelevant')),
25    rank_position INTEGER,
26    original_score REAL,
27    created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
28    workspace TEXT DEFAULT 'default'
29);
30CREATE INDEX IF NOT EXISTS idx_feedback_memory ON search_feedback(memory_id);
31CREATE INDEX IF NOT EXISTS idx_feedback_query ON search_feedback(query);
32CREATE INDEX IF NOT EXISTS idx_feedback_workspace ON search_feedback(workspace);
33"#;
34
35// ---------------------------------------------------------------------------
36// Types
37// ---------------------------------------------------------------------------
38
39/// Feedback signal: whether a search result was helpful.
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
41#[serde(rename_all = "snake_case")]
42pub enum FeedbackSignal {
43    Useful,
44    Irrelevant,
45}
46
47impl FeedbackSignal {
48    fn as_str(self) -> &'static str {
49        match self {
50            FeedbackSignal::Useful => "useful",
51            FeedbackSignal::Irrelevant => "irrelevant",
52        }
53    }
54
55    fn from_str(s: &str) -> Result<Self> {
56        match s {
57            "useful" => Ok(FeedbackSignal::Useful),
58            "irrelevant" => Ok(FeedbackSignal::Irrelevant),
59            other => Err(EngramError::InvalidInput(format!(
60                "unknown feedback signal: {other}"
61            ))),
62        }
63    }
64}
65
66/// A single recorded feedback entry.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct SearchFeedback {
69    pub id: i64,
70    pub query: String,
71    pub query_embedding_hash: Option<String>,
72    pub memory_id: i64,
73    pub signal: FeedbackSignal,
74    pub rank_position: Option<i32>,
75    pub original_score: Option<f32>,
76    pub created_at: String,
77    pub workspace: String,
78}
79
80/// Aggregated feedback statistics for a workspace (or all workspaces).
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct FeedbackStats {
83    pub total_feedback: i64,
84    pub useful_count: i64,
85    pub irrelevant_count: i64,
86    pub useful_ratio: f64,
87    /// Top memories marked useful — `(memory_id, count)`, up to 10 entries.
88    pub top_useful_memories: Vec<(i64, i64)>,
89    /// Top memories marked irrelevant — `(memory_id, count)`, up to 10 entries.
90    pub top_irrelevant_memories: Vec<(i64, i64)>,
91    /// Average rank position of results marked useful.
92    pub avg_useful_rank: Option<f64>,
93    /// Average rank position of results marked irrelevant.
94    pub avg_irrelevant_rank: Option<f64>,
95}
96
97/// Boost factor derived from feedback history for a single memory.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct FeedbackBoost {
100    pub memory_id: i64,
101    /// Multiplier applied to the raw search score.
102    /// `> 1.0` promotes the result, `< 1.0` demotes it, `1.0` is neutral.
103    pub boost_factor: f64,
104    /// Total number of feedback signals this boost is based on.
105    pub signal_count: i64,
106    /// Confidence in the boost estimate (0.0 – 1.0).
107    /// Increases with more signals, capped at 1.0.
108    pub confidence: f64,
109}
110
111// ---------------------------------------------------------------------------
112// Storage functions
113// ---------------------------------------------------------------------------
114
115/// Record a feedback signal for a (query, memory) pair.
116///
117/// Returns the newly created [`SearchFeedback`] row.
118pub fn record_feedback(
119    conn: &Connection,
120    query: &str,
121    memory_id: i64,
122    signal: FeedbackSignal,
123    rank_position: Option<i32>,
124    original_score: Option<f32>,
125    workspace: &str,
126) -> Result<SearchFeedback> {
127    conn.execute(
128        "INSERT INTO search_feedback (query, memory_id, signal, rank_position, original_score, workspace)
129         VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
130        rusqlite::params![
131            query,
132            memory_id,
133            signal.as_str(),
134            rank_position,
135            original_score,
136            workspace,
137        ],
138    )?;
139
140    let id = conn.last_insert_rowid();
141
142    let row = conn.query_row(
143        "SELECT id, query, query_embedding_hash, memory_id, signal,
144                rank_position, original_score, created_at, workspace
145         FROM search_feedback WHERE id = ?1",
146        rusqlite::params![id],
147        row_to_feedback,
148    )?;
149
150    Ok(row)
151}
152
153/// Retrieve all feedback rows for a specific memory.
154pub fn get_feedback_for_memory(conn: &Connection, memory_id: i64) -> Result<Vec<SearchFeedback>> {
155    let mut stmt = conn.prepare(
156        "SELECT id, query, query_embedding_hash, memory_id, signal,
157                rank_position, original_score, created_at, workspace
158         FROM search_feedback
159         WHERE memory_id = ?1
160         ORDER BY created_at DESC",
161    )?;
162
163    let rows = stmt
164        .query_map(rusqlite::params![memory_id], row_to_feedback)?
165        .collect::<std::result::Result<Vec<_>, _>>()?;
166
167    Ok(rows)
168}
169
170/// Retrieve all feedback rows for a specific query string.
171pub fn get_feedback_for_query(conn: &Connection, query: &str) -> Result<Vec<SearchFeedback>> {
172    let mut stmt = conn.prepare(
173        "SELECT id, query, query_embedding_hash, memory_id, signal,
174                rank_position, original_score, created_at, workspace
175         FROM search_feedback
176         WHERE query = ?1
177         ORDER BY created_at DESC",
178    )?;
179
180    let rows = stmt
181        .query_map(rusqlite::params![query], row_to_feedback)?
182        .collect::<std::result::Result<Vec<_>, _>>()?;
183
184    Ok(rows)
185}
186
187/// Delete a single feedback entry by its ID.
188pub fn delete_feedback(conn: &Connection, feedback_id: i64) -> Result<()> {
189    let affected = conn.execute(
190        "DELETE FROM search_feedback WHERE id = ?1",
191        rusqlite::params![feedback_id],
192    )?;
193
194    if affected == 0 {
195        return Err(EngramError::NotFound(feedback_id));
196    }
197
198    Ok(())
199}
200
201/// Compute aggregated feedback statistics.
202///
203/// When `workspace` is `Some`, only feedback rows from that workspace are
204/// included.  Pass `None` to aggregate across all workspaces.
205pub fn feedback_stats(conn: &Connection, workspace: Option<&str>) -> Result<FeedbackStats> {
206    // Helper: execute a query that uses an optional workspace parameter.
207    // When `ws` is Some the query must use `?1` as its sole parameter.
208    let exec_scalar = |sql: &str| -> Result<(i64, i64, i64)> {
209        if let Some(ws) = workspace {
210            Ok(conn.query_row(sql, rusqlite::params![ws], |r| {
211                Ok((r.get(0)?, r.get(1)?, r.get(2)?))
212            })?)
213        } else {
214            Ok(conn.query_row(sql, [], |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)))?)
215        }
216    };
217
218    let exec_pairs = |sql: &str| -> Result<Vec<(i64, i64)>> {
219        if let Some(ws) = workspace {
220            let mut stmt = conn.prepare(sql)?;
221            let v = stmt
222                .query_map(rusqlite::params![ws], |r| Ok((r.get(0)?, r.get(1)?)))?
223                .collect::<std::result::Result<Vec<_>, _>>()?;
224            Ok(v)
225        } else {
226            let mut stmt = conn.prepare(sql)?;
227            let v = stmt
228                .query_map([], |r| Ok((r.get(0)?, r.get(1)?)))?
229                .collect::<std::result::Result<Vec<_>, _>>()?;
230            Ok(v)
231        }
232    };
233
234    let exec_avg = |sql: &str| -> Result<Option<f64>> {
235        let v: Option<f64> = if let Some(ws) = workspace {
236            conn.query_row(sql, rusqlite::params![ws], |r| r.get(0))
237                .optional()?
238                .flatten()
239        } else {
240            conn.query_row(sql, [], |r| r.get(0)).optional()?.flatten()
241        };
242        Ok(v)
243    };
244
245    // Build SQL that filters by workspace when provided.
246    let ws_clause = if workspace.is_some() {
247        "WHERE workspace = ?1"
248    } else {
249        "WHERE 1=1"
250    };
251
252    // --- totals ---
253    let totals_sql = format!(
254        "SELECT
255            COUNT(*),
256            SUM(CASE WHEN signal = 'useful' THEN 1 ELSE 0 END),
257            SUM(CASE WHEN signal = 'irrelevant' THEN 1 ELSE 0 END)
258         FROM search_feedback {ws_clause}"
259    );
260    let (total_feedback, useful_count, irrelevant_count) = exec_scalar(&totals_sql)?;
261
262    let useful_ratio = if total_feedback == 0 {
263        0.0
264    } else {
265        useful_count as f64 / total_feedback as f64
266    };
267
268    // --- top useful memories ---
269    let top_useful_sql = format!(
270        "SELECT memory_id, COUNT(*) AS cnt
271         FROM search_feedback
272         {ws_clause} AND signal = 'useful'
273         GROUP BY memory_id
274         ORDER BY cnt DESC
275         LIMIT 10"
276    );
277    let top_useful_memories = exec_pairs(&top_useful_sql)?;
278
279    // --- top irrelevant memories ---
280    let top_irrelevant_sql = format!(
281        "SELECT memory_id, COUNT(*) AS cnt
282         FROM search_feedback
283         {ws_clause} AND signal = 'irrelevant'
284         GROUP BY memory_id
285         ORDER BY cnt DESC
286         LIMIT 10"
287    );
288    let top_irrelevant_memories = exec_pairs(&top_irrelevant_sql)?;
289
290    // --- average ranks ---
291    let avg_useful_sql = format!(
292        "SELECT AVG(rank_position)
293         FROM search_feedback
294         {ws_clause} AND signal = 'useful' AND rank_position IS NOT NULL"
295    );
296    let avg_useful_rank = exec_avg(&avg_useful_sql)?;
297
298    let avg_irrelevant_sql = format!(
299        "SELECT AVG(rank_position)
300         FROM search_feedback
301         {ws_clause} AND signal = 'irrelevant' AND rank_position IS NOT NULL"
302    );
303    let avg_irrelevant_rank = exec_avg(&avg_irrelevant_sql)?;
304
305    Ok(FeedbackStats {
306        total_feedback,
307        useful_count,
308        irrelevant_count,
309        useful_ratio,
310        top_useful_memories,
311        top_irrelevant_memories,
312        avg_useful_rank,
313        avg_irrelevant_rank,
314    })
315}
316
317/// Compute boost factors for a set of memory IDs.
318///
319/// For each memory, the boost formula is:
320///
321/// ```text
322/// boost = 1.0 + (useful_count - irrelevant_count * 1.5) / (total_count + 5)
323/// ```
324///
325/// The `+5` smoothing term prevents extreme boosts from very few signals.
326/// Confidence is `min(1.0, total_count / 10.0)`.
327///
328/// If `query` is provided, feedback rows whose query text overlaps heavily
329/// with the current query receive a 2× weight in the aggregation (query
330/// similarity is measured as Jaccard overlap on word sets).
331pub fn compute_feedback_boosts(
332    conn: &Connection,
333    memory_ids: &[i64],
334    query: Option<&str>,
335) -> Result<Vec<FeedbackBoost>> {
336    if memory_ids.is_empty() {
337        return Ok(Vec::new());
338    }
339
340    let mut boosts = Vec::with_capacity(memory_ids.len());
341
342    for &memory_id in memory_ids {
343        // Fetch all feedback for this memory.
344        let rows = get_feedback_for_memory(conn, memory_id)?;
345
346        if rows.is_empty() {
347            boosts.push(FeedbackBoost {
348                memory_id,
349                boost_factor: 1.0,
350                signal_count: 0,
351                confidence: 0.0,
352            });
353            continue;
354        }
355
356        // Accumulate weighted counts.
357        let mut weighted_useful = 0.0_f64;
358        let mut weighted_irrelevant = 0.0_f64;
359        let mut weighted_total = 0.0_f64;
360
361        for row in &rows {
362            let weight = if let Some(q) = query {
363                query_similarity_weight(q, &row.query)
364            } else {
365                1.0
366            };
367
368            match row.signal {
369                FeedbackSignal::Useful => weighted_useful += weight,
370                FeedbackSignal::Irrelevant => weighted_irrelevant += weight,
371            }
372            weighted_total += weight;
373        }
374
375        let signal_count = rows.len() as i64;
376        let boost_factor =
377            1.0 + (weighted_useful - weighted_irrelevant * 1.5) / (weighted_total + 5.0);
378        let confidence = (signal_count as f64 / 10.0).min(1.0);
379
380        boosts.push(FeedbackBoost {
381            memory_id,
382            boost_factor,
383            signal_count,
384            confidence,
385        });
386    }
387
388    Ok(boosts)
389}
390
391/// Apply boost factors to a slice of `(memory_id, score)` pairs in-place.
392///
393/// Each score is multiplied by the matching boost factor, clamped to `[0.5, 2.0]`.
394/// Memory IDs with no matching boost entry are left unchanged.
395pub fn apply_feedback_boosts(scores: &mut [(i64, f32)], boosts: &[FeedbackBoost]) {
396    for (memory_id, score) in scores.iter_mut() {
397        if let Some(boost) = boosts.iter().find(|b| b.memory_id == *memory_id) {
398            *score = (*score * boost.boost_factor as f32).clamp(0.5, 2.0);
399        }
400    }
401}
402
403// ---------------------------------------------------------------------------
404// Helpers
405// ---------------------------------------------------------------------------
406
407/// Map a rusqlite row to [`SearchFeedback`].
408fn row_to_feedback(r: &rusqlite::Row<'_>) -> rusqlite::Result<SearchFeedback> {
409    let signal_str: String = r.get(4)?;
410    let signal = FeedbackSignal::from_str(&signal_str).map_err(|_| {
411        rusqlite::Error::FromSqlConversionFailure(
412            4,
413            rusqlite::types::Type::Text,
414            Box::new(std::fmt::Error),
415        )
416    })?;
417
418    Ok(SearchFeedback {
419        id: r.get(0)?,
420        query: r.get(1)?,
421        query_embedding_hash: r.get(2)?,
422        memory_id: r.get(3)?,
423        signal,
424        rank_position: r.get(5)?,
425        original_score: r.get(6)?,
426        created_at: r.get(7)?,
427        workspace: r.get(8)?,
428    })
429}
430
431/// Compute a simple query similarity weight in `[1.0, 2.0]`.
432///
433/// Uses Jaccard overlap on word sets:
434/// - identical queries → 2.0
435/// - no overlap → 1.0
436/// - partial overlap → interpolated
437fn query_similarity_weight(current: &str, historical: &str) -> f64 {
438    let current_words: std::collections::HashSet<&str> = current.split_whitespace().collect();
439    let historical_words: std::collections::HashSet<&str> = historical.split_whitespace().collect();
440
441    if current_words.is_empty() || historical_words.is_empty() {
442        return 1.0;
443    }
444
445    let intersection = current_words.intersection(&historical_words).count();
446    let union = current_words.union(&historical_words).count();
447
448    let jaccard = intersection as f64 / union as f64;
449    // Map [0, 1] → [1.0, 2.0]
450    1.0 + jaccard
451}
452
453// ---------------------------------------------------------------------------
454// Tests
455// ---------------------------------------------------------------------------
456
457#[cfg(test)]
458mod tests {
459    use super::*;
460
461    fn setup() -> Connection {
462        let conn = Connection::open_in_memory().expect("open in-memory db");
463        conn.execute_batch(CREATE_SEARCH_FEEDBACK_TABLE)
464            .expect("create table");
465        conn
466    }
467
468    // 1. Record and retrieve feedback
469    #[test]
470    fn test_record_and_retrieve_feedback() {
471        let conn = setup();
472
473        let fb = record_feedback(
474            &conn,
475            "rust async",
476            42,
477            FeedbackSignal::Useful,
478            Some(1),
479            Some(0.9),
480            "default",
481        )
482        .expect("record");
483
484        assert_eq!(fb.query, "rust async");
485        assert_eq!(fb.memory_id, 42);
486        assert_eq!(fb.signal, FeedbackSignal::Useful);
487        assert_eq!(fb.rank_position, Some(1));
488        assert!((fb.original_score.unwrap() - 0.9).abs() < 1e-5);
489        assert_eq!(fb.workspace, "default");
490        assert!(fb.id > 0);
491    }
492
493    // 2. Record useful signal
494    #[test]
495    fn test_record_useful_signal() {
496        let conn = setup();
497
498        let fb = record_feedback(
499            &conn,
500            "search query",
501            10,
502            FeedbackSignal::Useful,
503            None,
504            None,
505            "ws1",
506        )
507        .expect("record useful");
508
509        assert_eq!(fb.signal, FeedbackSignal::Useful);
510    }
511
512    // 3. Record irrelevant signal
513    #[test]
514    fn test_record_irrelevant_signal() {
515        let conn = setup();
516
517        let fb = record_feedback(
518            &conn,
519            "another query",
520            20,
521            FeedbackSignal::Irrelevant,
522            Some(5),
523            Some(0.3),
524            "ws1",
525        )
526        .expect("record irrelevant");
527
528        assert_eq!(fb.signal, FeedbackSignal::Irrelevant);
529        assert_eq!(fb.rank_position, Some(5));
530    }
531
532    // 4. Stats computation — counts and ratios
533    #[test]
534    fn test_stats_counts_and_ratios() {
535        let conn = setup();
536
537        record_feedback(&conn, "q", 1, FeedbackSignal::Useful, None, None, "ws").unwrap();
538        record_feedback(&conn, "q", 2, FeedbackSignal::Useful, None, None, "ws").unwrap();
539        record_feedback(&conn, "q", 3, FeedbackSignal::Irrelevant, None, None, "ws").unwrap();
540
541        let stats = feedback_stats(&conn, None).expect("stats");
542
543        assert_eq!(stats.total_feedback, 3);
544        assert_eq!(stats.useful_count, 2);
545        assert_eq!(stats.irrelevant_count, 1);
546        assert!((stats.useful_ratio - 2.0 / 3.0).abs() < 1e-9);
547    }
548
549    // 5. Stats with workspace filter
550    #[test]
551    fn test_stats_workspace_filter() {
552        let conn = setup();
553
554        record_feedback(&conn, "q", 1, FeedbackSignal::Useful, None, None, "ws_a").unwrap();
555        record_feedback(&conn, "q", 2, FeedbackSignal::Useful, None, None, "ws_a").unwrap();
556        record_feedback(
557            &conn,
558            "q",
559            3,
560            FeedbackSignal::Irrelevant,
561            None,
562            None,
563            "ws_b",
564        )
565        .unwrap();
566
567        let stats_a = feedback_stats(&conn, Some("ws_a")).expect("stats_a");
568        assert_eq!(stats_a.total_feedback, 2);
569        assert_eq!(stats_a.useful_count, 2);
570        assert_eq!(stats_a.irrelevant_count, 0);
571
572        let stats_b = feedback_stats(&conn, Some("ws_b")).expect("stats_b");
573        assert_eq!(stats_b.total_feedback, 1);
574        assert_eq!(stats_b.useful_count, 0);
575        assert_eq!(stats_b.irrelevant_count, 1);
576    }
577
578    // 6. Boost — mostly useful signals → boost > 1.0
579    #[test]
580    fn test_boost_mostly_useful() {
581        let conn = setup();
582
583        for _ in 0..8 {
584            record_feedback(&conn, "q", 99, FeedbackSignal::Useful, None, None, "ws").unwrap();
585        }
586        record_feedback(&conn, "q", 99, FeedbackSignal::Irrelevant, None, None, "ws").unwrap();
587
588        let boosts = compute_feedback_boosts(&conn, &[99], None).expect("boosts");
589        assert_eq!(boosts.len(), 1);
590        assert!(
591            boosts[0].boost_factor > 1.0,
592            "expected boost > 1.0, got {}",
593            boosts[0].boost_factor
594        );
595    }
596
597    // 7. Boost — mostly irrelevant → boost < 1.0
598    #[test]
599    fn test_boost_mostly_irrelevant() {
600        let conn = setup();
601
602        for _ in 0..8 {
603            record_feedback(&conn, "q", 77, FeedbackSignal::Irrelevant, None, None, "ws").unwrap();
604        }
605        record_feedback(&conn, "q", 77, FeedbackSignal::Useful, None, None, "ws").unwrap();
606
607        let boosts = compute_feedback_boosts(&conn, &[77], None).expect("boosts");
608        assert_eq!(boosts.len(), 1);
609        assert!(
610            boosts[0].boost_factor < 1.0,
611            "expected boost < 1.0, got {}",
612            boosts[0].boost_factor
613        );
614    }
615
616    // 8. Boost — no feedback → boost = 1.0
617    #[test]
618    fn test_boost_no_feedback() {
619        let conn = setup();
620
621        let boosts = compute_feedback_boosts(&conn, &[999], None).expect("boosts");
622        assert_eq!(boosts.len(), 1);
623        assert_eq!(boosts[0].boost_factor, 1.0);
624        assert_eq!(boosts[0].signal_count, 0);
625        assert_eq!(boosts[0].confidence, 0.0);
626    }
627
628    // 9. Boost smoothing prevents extreme values with few signals
629    #[test]
630    fn test_boost_smoothing_prevents_extremes() {
631        let conn = setup();
632
633        // Only 1 useful signal — smoothing (+5) should keep boost moderate.
634        record_feedback(&conn, "q", 55, FeedbackSignal::Useful, None, None, "ws").unwrap();
635
636        let boosts = compute_feedback_boosts(&conn, &[55], None).expect("boosts");
637        // With 1 useful: boost = 1 + (1 - 0) / (1 + 5) = 1 + 1/6 ≈ 1.167
638        let expected = 1.0 + 1.0 / 6.0;
639        assert!((boosts[0].boost_factor - expected).abs() < 1e-9);
640        // Not extreme (e.g., not 2.0)
641        assert!(boosts[0].boost_factor < 1.3);
642    }
643
644    // 10. Apply boosts modifies scores correctly
645    #[test]
646    fn test_apply_boosts_modifies_scores() {
647        let boosts = vec![
648            FeedbackBoost {
649                memory_id: 1,
650                boost_factor: 1.5,
651                signal_count: 5,
652                confidence: 0.5,
653            },
654            // 0.7 * 0.8 = 0.56 — stays above the 0.5 clamp floor
655            FeedbackBoost {
656                memory_id: 2,
657                boost_factor: 0.8,
658                signal_count: 3,
659                confidence: 0.3,
660            },
661        ];
662
663        let mut scores = vec![(1_i64, 0.6_f32), (2_i64, 0.7_f32), (3_i64, 0.4_f32)];
664        apply_feedback_boosts(&mut scores, &boosts);
665
666        // memory 1: 0.6 * 1.5 = 0.9
667        assert!(
668            (scores[0].1 - 0.9_f32).abs() < 1e-5,
669            "score[0] = {}",
670            scores[0].1
671        );
672        // memory 2: 0.7 * 0.8 = 0.56
673        assert!(
674            (scores[1].1 - 0.56_f32).abs() < 1e-4,
675            "score[1] = {}",
676            scores[1].1
677        );
678        // memory 3: no boost entry, unchanged
679        assert!(
680            (scores[2].1 - 0.4_f32).abs() < 1e-5,
681            "score[2] = {}",
682            scores[2].1
683        );
684    }
685
686    // 11. Boost clamping to [0.5, 2.0]
687    #[test]
688    fn test_boost_clamping() {
689        // Very high boost factor → clamped to 2.0
690        let boosts_high = vec![FeedbackBoost {
691            memory_id: 10,
692            boost_factor: 5.0,
693            signal_count: 100,
694            confidence: 1.0,
695        }];
696        let mut scores_high = vec![(10_i64, 0.9_f32)];
697        apply_feedback_boosts(&mut scores_high, &boosts_high);
698        assert!(
699            (scores_high[0].1 - 2.0_f32).abs() < 1e-5,
700            "expected clamp to 2.0, got {}",
701            scores_high[0].1
702        );
703
704        // Very low boost factor → clamped to 0.5
705        let boosts_low = vec![FeedbackBoost {
706            memory_id: 20,
707            boost_factor: 0.1,
708            signal_count: 100,
709            confidence: 1.0,
710        }];
711        let mut scores_low = vec![(20_i64, 0.9_f32)];
712        apply_feedback_boosts(&mut scores_low, &boosts_low);
713        assert!(
714            (scores_low[0].1 - 0.5_f32).abs() < 1e-5,
715            "expected clamp to 0.5, got {}",
716            scores_low[0].1
717        );
718    }
719
720    // 12. Delete feedback
721    #[test]
722    fn test_delete_feedback() {
723        let conn = setup();
724
725        let fb = record_feedback(
726            &conn,
727            "to delete",
728            1,
729            FeedbackSignal::Useful,
730            None,
731            None,
732            "ws",
733        )
734        .expect("record");
735
736        delete_feedback(&conn, fb.id).expect("delete");
737
738        let remaining = get_feedback_for_memory(&conn, 1).expect("get");
739        assert!(remaining.is_empty());
740    }
741
742    // 12b. Delete non-existent feedback returns NotFound
743    #[test]
744    fn test_delete_nonexistent_feedback() {
745        let conn = setup();
746        let result = delete_feedback(&conn, 9999);
747        assert!(matches!(result, Err(EngramError::NotFound(_))));
748    }
749
750    // 13. Query similarity weighting
751    #[test]
752    fn test_query_similarity_weighting() {
753        let conn = setup();
754
755        // Record feedback for two different queries on the same memory.
756        // "rust async runtime" overlaps with "rust async" but not "python web".
757        record_feedback(
758            &conn,
759            "rust async runtime",
760            42,
761            FeedbackSignal::Useful,
762            None,
763            None,
764            "ws",
765        )
766        .unwrap();
767        record_feedback(
768            &conn,
769            "python web framework",
770            42,
771            FeedbackSignal::Irrelevant,
772            None,
773            None,
774            "ws",
775        )
776        .unwrap();
777
778        // Query "rust async" — should weight the useful signal higher → boost > 1.0
779        let boosts_rust =
780            compute_feedback_boosts(&conn, &[42], Some("rust async")).expect("boosts");
781        assert!(
782            boosts_rust[0].boost_factor > 1.0,
783            "expected boost > 1.0 with matching query, got {}",
784            boosts_rust[0].boost_factor
785        );
786
787        // Query "python web" — should weight the irrelevant signal higher → boost < 1.0
788        let boosts_python =
789            compute_feedback_boosts(&conn, &[42], Some("python web")).expect("boosts");
790        assert!(
791            boosts_python[0].boost_factor < 1.0,
792            "expected boost < 1.0 with mismatched query, got {}",
793            boosts_python[0].boost_factor
794        );
795    }
796
797    // Extra: get_feedback_for_query
798    #[test]
799    fn test_get_feedback_for_query() {
800        let conn = setup();
801
802        record_feedback(
803            &conn,
804            "specific query",
805            1,
806            FeedbackSignal::Useful,
807            None,
808            None,
809            "ws",
810        )
811        .unwrap();
812        record_feedback(
813            &conn,
814            "specific query",
815            2,
816            FeedbackSignal::Irrelevant,
817            None,
818            None,
819            "ws",
820        )
821        .unwrap();
822        record_feedback(
823            &conn,
824            "other query",
825            3,
826            FeedbackSignal::Useful,
827            None,
828            None,
829            "ws",
830        )
831        .unwrap();
832
833        let rows = get_feedback_for_query(&conn, "specific query").expect("get");
834        assert_eq!(rows.len(), 2);
835        for r in &rows {
836            assert_eq!(r.query, "specific query");
837        }
838    }
839
840    // Extra: top_useful and top_irrelevant_memories populated correctly
841    #[test]
842    fn test_stats_top_memories() {
843        let conn = setup();
844
845        // memory 1: 3 useful
846        for _ in 0..3 {
847            record_feedback(&conn, "q", 1, FeedbackSignal::Useful, None, None, "ws").unwrap();
848        }
849        // memory 2: 1 useful
850        record_feedback(&conn, "q", 2, FeedbackSignal::Useful, None, None, "ws").unwrap();
851        // memory 3: 2 irrelevant
852        for _ in 0..2 {
853            record_feedback(&conn, "q", 3, FeedbackSignal::Irrelevant, None, None, "ws").unwrap();
854        }
855
856        let stats = feedback_stats(&conn, None).unwrap();
857        assert_eq!(stats.top_useful_memories[0].0, 1);
858        assert_eq!(stats.top_useful_memories[0].1, 3);
859        assert_eq!(stats.top_irrelevant_memories[0].0, 3);
860        assert_eq!(stats.top_irrelevant_memories[0].1, 2);
861    }
862
863    // Extra: average rank computed correctly
864    #[test]
865    fn test_stats_avg_rank() {
866        let conn = setup();
867
868        record_feedback(&conn, "q", 1, FeedbackSignal::Useful, Some(1), None, "ws").unwrap();
869        record_feedback(&conn, "q", 2, FeedbackSignal::Useful, Some(3), None, "ws").unwrap();
870        record_feedback(
871            &conn,
872            "q",
873            3,
874            FeedbackSignal::Irrelevant,
875            Some(10),
876            None,
877            "ws",
878        )
879        .unwrap();
880
881        let stats = feedback_stats(&conn, None).unwrap();
882        // avg useful rank = (1 + 3) / 2 = 2.0
883        assert!((stats.avg_useful_rank.unwrap() - 2.0).abs() < 1e-9);
884        // avg irrelevant rank = 10.0
885        assert!((stats.avg_irrelevant_rank.unwrap() - 10.0).abs() < 1e-9);
886    }
887
888    // Extra: empty memory_ids returns empty vec
889    #[test]
890    fn test_compute_boosts_empty_ids() {
891        let conn = setup();
892        let boosts = compute_feedback_boosts(&conn, &[], None).expect("boosts");
893        assert!(boosts.is_empty());
894    }
895}