use rusqlite::{Connection, OptionalExtension};
use serde::{Deserialize, Serialize};
use crate::error::{EngramError, Result};
/// DDL for the `search_feedback` table and its lookup indexes.
///
/// Every statement uses `IF NOT EXISTS`, so running it repeatedly is harmless.
/// The string holds several statements, so it must be run with
/// `Connection::execute_batch` rather than `execute`.
///
/// `signal` is restricted by a CHECK constraint to the tags produced by
/// `FeedbackSignal` (`'useful'` / `'irrelevant'`); `created_at` defaults to an
/// ISO-8601 UTC timestamp with fractional seconds.
///
/// NOTE(review): `workspace` is nullable (only a DEFAULT is declared), while
/// `SearchFeedback::workspace` is a non-optional `String`; a row inserted with
/// an explicit NULL workspace would fail to load. Consider `NOT NULL` for new DBs.
pub const CREATE_SEARCH_FEEDBACK_TABLE: &str = r#"
CREATE TABLE IF NOT EXISTS search_feedback (
id INTEGER PRIMARY KEY AUTOINCREMENT,
query TEXT NOT NULL,
query_embedding_hash TEXT,
memory_id INTEGER NOT NULL,
signal TEXT NOT NULL CHECK(signal IN ('useful', 'irrelevant')),
rank_position INTEGER,
original_score REAL,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
workspace TEXT DEFAULT 'default'
);
CREATE INDEX IF NOT EXISTS idx_feedback_memory ON search_feedback(memory_id);
CREATE INDEX IF NOT EXISTS idx_feedback_query ON search_feedback(query);
CREATE INDEX IF NOT EXISTS idx_feedback_workspace ON search_feedback(workspace);
"#;
/// A user's judgement of a single search result.
///
/// Serialized by serde in snake_case (`"useful"` / `"irrelevant"`), which is
/// the same text stored in the `signal` column and accepted by its CHECK
/// constraint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FeedbackSignal {
    /// The result was helpful for the query.
    Useful,
    /// The result did not match what the query was looking for.
    Irrelevant,
}
impl FeedbackSignal {
    /// Canonical text tag for this signal, exactly as stored in the `signal`
    /// column.
    fn as_str(self) -> &'static str {
        match self {
            Self::Useful => "useful",
            Self::Irrelevant => "irrelevant",
        }
    }

    /// Parses a stored `signal` tag back into a [`FeedbackSignal`].
    ///
    /// Matching is exact and case-sensitive.
    ///
    /// # Errors
    /// Returns `EngramError::InvalidInput` for any text other than
    /// `"useful"` or `"irrelevant"`.
    fn from_str(s: &str) -> Result<Self> {
        let parsed = match s {
            "useful" => Self::Useful,
            "irrelevant" => Self::Irrelevant,
            unknown => {
                let message = format!("unknown feedback signal: {unknown}");
                return Err(EngramError::InvalidInput(message));
            }
        };
        Ok(parsed)
    }
}
/// One row of the `search_feedback` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchFeedback {
    /// Row id assigned by SQLite (`INTEGER PRIMARY KEY AUTOINCREMENT`).
    pub id: i64,
    /// Query text the feedback was given for.
    pub query: String,
    /// Value of the `query_embedding_hash` column; `None` when unset
    /// (`record_feedback` does not populate it).
    pub query_embedding_hash: Option<String>,
    /// Id of the memory the feedback was recorded against.
    pub memory_id: i64,
    /// Whether the result was judged useful or irrelevant.
    pub signal: FeedbackSignal,
    /// Position of the memory in the result list, if the caller supplied it
    /// (the base, 0 or 1, is not enforced here).
    pub rank_position: Option<i32>,
    /// Score the memory had in the original search, if supplied.
    pub original_score: Option<f32>,
    /// Insertion time as stored by SQLite: UTC text formatted by
    /// `strftime('%Y-%m-%dT%H:%M:%fZ', 'now')` unless set explicitly.
    pub created_at: String,
    /// Workspace the feedback belongs to (column default `'default'`).
    pub workspace: String,
}
/// Aggregate feedback figures produced by `feedback_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedbackStats {
    /// Number of feedback rows considered.
    pub total_feedback: i64,
    /// Rows with the `useful` signal.
    pub useful_count: i64,
    /// Rows with the `irrelevant` signal.
    pub irrelevant_count: i64,
    /// `useful_count / total_feedback`, or `0.0` when there is no feedback.
    pub useful_ratio: f64,
    /// Up to 10 `(memory_id, useful_count)` pairs, highest count first.
    pub top_useful_memories: Vec<(i64, i64)>,
    /// Up to 10 `(memory_id, irrelevant_count)` pairs, highest count first.
    pub top_irrelevant_memories: Vec<(i64, i64)>,
    /// Mean `rank_position` over useful rows that have a rank; `None` if none do.
    pub avg_useful_rank: Option<f64>,
    /// Mean `rank_position` over irrelevant rows that have a rank; `None` if none do.
    pub avg_irrelevant_rank: Option<f64>,
}
/// Feedback-derived ranking adjustment for one memory, produced by
/// `compute_feedback_boosts` and consumed by `apply_feedback_boosts`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedbackBoost {
    /// Memory the adjustment applies to.
    pub memory_id: i64,
    /// Multiplicative factor applied to the memory's score: `1.0` is neutral
    /// (e.g. no feedback), above `1.0` promotes, below demotes.
    pub boost_factor: f64,
    /// Number of feedback rows found for the memory.
    pub signal_count: i64,
    /// `signal_count / 10`, capped at `1.0`.
    pub confidence: f64,
}
/// Inserts a single feedback row and returns it exactly as stored.
///
/// The returned [`SearchFeedback`] is re-read from the database, so it carries
/// the SQLite-assigned `id` and the column-default `created_at`.
/// `query_embedding_hash` is not written here and comes back as `None`.
///
/// # Errors
/// Propagates any SQLite error from the insert or the read-back.
pub fn record_feedback(
    conn: &Connection,
    query: &str,
    memory_id: i64,
    signal: FeedbackSignal,
    rank_position: Option<i32>,
    original_score: Option<f32>,
    workspace: &str,
) -> Result<SearchFeedback> {
    const INSERT_SQL: &str = concat!(
        "INSERT INTO search_feedback (query, memory_id, signal, rank_position, original_score, workspace)\n",
        "VALUES (?1, ?2, ?3, ?4, ?5, ?6)"
    );
    const SELECT_BY_ID_SQL: &str = concat!(
        "SELECT id, query, query_embedding_hash, memory_id, signal,\n",
        "rank_position, original_score, created_at, workspace\n",
        "FROM search_feedback WHERE id = ?1"
    );

    let tag = signal.as_str();
    conn.execute(
        INSERT_SQL,
        rusqlite::params![query, memory_id, tag, rank_position, original_score, workspace],
    )?;

    // Read the row back so database-side defaults are reflected in the result.
    let new_id = conn.last_insert_rowid();
    let stored = conn.query_row(SELECT_BY_ID_SQL, rusqlite::params![new_id], row_to_feedback)?;
    Ok(stored)
}
/// Returns every feedback row recorded for `memory_id`, newest first.
///
/// `created_at` has millisecond resolution, so rows inserted in quick
/// succession can share a timestamp; `id DESC` breaks those ties so the order
/// is deterministic (AUTOINCREMENT ids grow with insertion order).
///
/// # Errors
/// Propagates SQLite errors, including a conversion failure for a row whose
/// `signal` is not a known tag.
pub fn get_feedback_for_memory(conn: &Connection, memory_id: i64) -> Result<Vec<SearchFeedback>> {
    let mut stmt = conn.prepare(
        "SELECT id, query, query_embedding_hash, memory_id, signal,
                rank_position, original_score, created_at, workspace
         FROM search_feedback
         WHERE memory_id = ?1
         ORDER BY created_at DESC, id DESC",
    )?;
    let rows = stmt
        .query_map(rusqlite::params![memory_id], row_to_feedback)?
        .collect::<std::result::Result<Vec<_>, _>>()?;
    Ok(rows)
}
/// Returns every feedback row whose query text equals `query` exactly,
/// newest first.
///
/// Ties on the millisecond-resolution `created_at` are broken by `id DESC`
/// so the order is deterministic (AUTOINCREMENT ids grow with insertion order).
///
/// # Errors
/// Propagates SQLite errors, including a conversion failure for a row whose
/// `signal` is not a known tag.
pub fn get_feedback_for_query(conn: &Connection, query: &str) -> Result<Vec<SearchFeedback>> {
    let mut stmt = conn.prepare(
        "SELECT id, query, query_embedding_hash, memory_id, signal,
                rank_position, original_score, created_at, workspace
         FROM search_feedback
         WHERE query = ?1
         ORDER BY created_at DESC, id DESC",
    )?;
    let rows = stmt
        .query_map(rusqlite::params![query], row_to_feedback)?
        .collect::<std::result::Result<Vec<_>, _>>()?;
    Ok(rows)
}
/// Deletes the feedback row with the given id.
///
/// # Errors
/// Returns `EngramError::NotFound(feedback_id)` when no row has that id, or
/// propagates the SQLite error if the delete itself fails.
pub fn delete_feedback(conn: &Connection, feedback_id: i64) -> Result<()> {
    let removed = conn.execute(
        "DELETE FROM search_feedback WHERE id = ?1",
        rusqlite::params![feedback_id],
    )?;
    match removed {
        0 => Err(EngramError::NotFound(feedback_id)),
        _ => Ok(()),
    }
}
/// Aggregates feedback counts, the most-flagged memories per signal and the
/// average result rank per signal.
///
/// `workspace = Some(ws)` restricts every figure to rows of that workspace;
/// `None` aggregates over the whole table. When no rows match (empty table or
/// unknown workspace), the result has zero counts, `useful_ratio == 0.0`,
/// empty top lists and `None` averages.
///
/// # Errors
/// Propagates any SQLite error from the underlying queries.
pub fn feedback_stats(conn: &Connection, workspace: Option<&str>) -> Result<FeedbackStats> {
    // `?1` appears in the SQL only when a workspace is given, and
    // `params_from_iter` over the `Option` binds exactly one value in that case
    // and none otherwise, so one code path serves both variants.
    let ws_clause = if workspace.is_some() {
        "WHERE workspace = ?1"
    } else {
        "WHERE 1=1"
    };

    let exec_scalar = |sql: &str| -> Result<(i64, i64, i64)> {
        Ok(conn.query_row(sql, rusqlite::params_from_iter(workspace), |r| {
            Ok((r.get(0)?, r.get(1)?, r.get(2)?))
        })?)
    };
    let exec_pairs = |sql: &str| -> Result<Vec<(i64, i64)>> {
        let mut stmt = conn.prepare(sql)?;
        let pairs = stmt
            .query_map(rusqlite::params_from_iter(workspace), |r| {
                Ok((r.get(0)?, r.get(1)?))
            })?
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(pairs)
    };
    let exec_avg = |sql: &str| -> Result<Option<f64>> {
        // AVG over zero rows yields NULL, surfaced here as `None`.
        let avg: Option<f64> = conn
            .query_row(sql, rusqlite::params_from_iter(workspace), |r| r.get(0))
            .optional()?
            .flatten();
        Ok(avg)
    };

    // SUM over zero rows is NULL in SQLite (unlike COUNT), which cannot be read
    // as i64; COALESCE keeps the counts at 0 for an empty selection.
    let totals_sql = format!(
        "SELECT
            COUNT(*),
            COALESCE(SUM(CASE WHEN signal = 'useful' THEN 1 ELSE 0 END), 0),
            COALESCE(SUM(CASE WHEN signal = 'irrelevant' THEN 1 ELSE 0 END), 0)
         FROM search_feedback {ws_clause}"
    );
    let (total_feedback, useful_count, irrelevant_count) = exec_scalar(&totals_sql)?;
    let useful_ratio = if total_feedback == 0 {
        0.0
    } else {
        useful_count as f64 / total_feedback as f64
    };

    // `memory_id` breaks count ties so the top-10 lists are deterministic.
    let top_useful_sql = format!(
        "SELECT memory_id, COUNT(*) AS cnt
         FROM search_feedback
         {ws_clause} AND signal = 'useful'
         GROUP BY memory_id
         ORDER BY cnt DESC, memory_id ASC
         LIMIT 10"
    );
    let top_useful_memories = exec_pairs(&top_useful_sql)?;

    let top_irrelevant_sql = format!(
        "SELECT memory_id, COUNT(*) AS cnt
         FROM search_feedback
         {ws_clause} AND signal = 'irrelevant'
         GROUP BY memory_id
         ORDER BY cnt DESC, memory_id ASC
         LIMIT 10"
    );
    let top_irrelevant_memories = exec_pairs(&top_irrelevant_sql)?;

    let avg_useful_sql = format!(
        "SELECT AVG(rank_position)
         FROM search_feedback
         {ws_clause} AND signal = 'useful' AND rank_position IS NOT NULL"
    );
    let avg_useful_rank = exec_avg(&avg_useful_sql)?;

    let avg_irrelevant_sql = format!(
        "SELECT AVG(rank_position)
         FROM search_feedback
         {ws_clause} AND signal = 'irrelevant' AND rank_position IS NOT NULL"
    );
    let avg_irrelevant_rank = exec_avg(&avg_irrelevant_sql)?;

    Ok(FeedbackStats {
        total_feedback,
        useful_count,
        irrelevant_count,
        useful_ratio,
        top_useful_memories,
        top_irrelevant_memories,
        avg_useful_rank,
        avg_irrelevant_rank,
    })
}
/// Computes a feedback-derived score multiplier for each id in `memory_ids`.
///
/// Output order and length match `memory_ids`. For each memory, every stored
/// feedback row contributes a weight: `1.0` when `query` is `None`, otherwise
/// `query_similarity_weight(query, row.query)` (so rows given for similar
/// queries count more). The factor is
///
/// `1 + (useful - 1.5 * irrelevant) / (total + 5)`
///
/// floored at `0.0`. A memory with no feedback gets exactly `1.0` (neutral),
/// `signal_count == 0` and `confidence == 0.0`.
///
/// # Errors
/// Propagates any error from loading a memory's feedback rows.
pub fn compute_feedback_boosts(
    conn: &Connection,
    memory_ids: &[i64],
    query: Option<&str>,
) -> Result<Vec<FeedbackBoost>> {
    // Irrelevant signals weigh more than useful ones when adjusting the score.
    const IRRELEVANT_PENALTY: f64 = 1.5;
    // Added to the denominator to damp the adjustment when there are few signals.
    const PRIOR_WEIGHT: f64 = 5.0;
    // Number of signals at which `confidence` saturates at 1.0.
    const FULL_CONFIDENCE_SIGNALS: f64 = 10.0;

    let mut boosts = Vec::with_capacity(memory_ids.len());
    for &memory_id in memory_ids {
        let rows = get_feedback_for_memory(conn, memory_id)?;

        let mut weighted_useful = 0.0_f64;
        let mut weighted_irrelevant = 0.0_f64;
        let mut weighted_total = 0.0_f64;
        for row in &rows {
            let weight = query.map_or(1.0, |q| query_similarity_weight(q, &row.query));
            match row.signal {
                FeedbackSignal::Useful => weighted_useful += weight,
                FeedbackSignal::Irrelevant => weighted_irrelevant += weight,
            }
            weighted_total += weight;
        }

        // With no rows every sum is 0.0, so this evaluates to exactly 1.0.
        let raw_factor = 1.0
            + (weighted_useful - weighted_irrelevant * IRRELEVANT_PENALTY)
                / (weighted_total + PRIOR_WEIGHT);
        // The raw formula drops below zero once irrelevant weight is large
        // (e.g. more than 10 unweighted irrelevant signals); a negative
        // multiplier would flip the sign of a score, so floor it at 0.
        let boost_factor = raw_factor.max(0.0);

        let signal_count = rows.len() as i64;
        let confidence = (signal_count as f64 / FULL_CONFIDENCE_SIGNALS).min(1.0);
        boosts.push(FeedbackBoost {
            memory_id,
            boost_factor,
            signal_count,
            confidence,
        });
    }
    Ok(boosts)
}
pub fn apply_feedback_boosts(scores: &mut [(i64, f32)], boosts: &[FeedbackBoost]) {
for (memory_id, score) in scores.iter_mut() {
if let Some(boost) = boosts.iter().find(|b| b.memory_id == *memory_id) {
*score = (*score * boost.boost_factor as f32).clamp(0.5, 2.0);
}
}
}
/// Converts a row into a [`SearchFeedback`].
///
/// The row must have been selected with the columns `id, query,
/// query_embedding_hash, memory_id, signal, rank_position, original_score,
/// created_at, workspace`, in that order.
///
/// # Errors
/// Returns `rusqlite::Error::FromSqlConversionFailure` for column 4 when the
/// stored `signal` text is not a known [`FeedbackSignal`] tag. The wrapped
/// error names the offending value. Any error from reading the other columns
/// is propagated.
fn row_to_feedback(r: &rusqlite::Row<'_>) -> rusqlite::Result<SearchFeedback> {
    let signal_str: String = r.get(4)?;
    let signal = FeedbackSignal::from_str(&signal_str).map_err(|_| {
        // Carry the offending value so the failure is diagnosable; a generic
        // placeholder error would describe an unrelated failure.
        rusqlite::Error::FromSqlConversionFailure(
            4,
            rusqlite::types::Type::Text,
            format!("unknown feedback signal: {signal_str:?}").into(),
        )
    })?;
    Ok(SearchFeedback {
        id: r.get(0)?,
        query: r.get(1)?,
        query_embedding_hash: r.get(2)?,
        memory_id: r.get(3)?,
        signal,
        rank_position: r.get(5)?,
        original_score: r.get(6)?,
        created_at: r.get(7)?,
        workspace: r.get(8)?,
    })
}
/// Weight for a historical feedback row, based on how similar its query is to
/// the current one.
///
/// Returns `1.0 + J`, where `J` is the Jaccard index of the two queries'
/// whitespace-separated word sets (case-sensitive, duplicates collapsed). The
/// result is in `[1.0, 2.0]`. If either query has no words, the neutral
/// weight `1.0` is returned.
fn query_similarity_weight(current: &str, historical: &str) -> f64 {
    type WordSet<'a> = std::collections::HashSet<&'a str>;

    let lhs: WordSet<'_> = current.split_whitespace().collect();
    let rhs: WordSet<'_> = historical.split_whitespace().collect();
    if lhs.is_empty() || rhs.is_empty() {
        return 1.0;
    }

    // |A ∪ B| = |A| + |B| - |A ∩ B|; both sets are non-empty, so it is >= 1.
    let shared = lhs.intersection(&rhs).count();
    let combined = lhs.len() + rhs.len() - shared;
    1.0 + shared as f64 / combined as f64
}
// Unit tests for the search-feedback store. Each test that touches the
// database gets a fresh in-memory SQLite connection initialised from
// `CREATE_SEARCH_FEEDBACK_TABLE`, so tests are independent of each other.
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database and creates the `search_feedback` schema.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().expect("open in-memory db");
        conn.execute_batch(CREATE_SEARCH_FEEDBACK_TABLE)
            .expect("create table");
        conn
    }

    // Round-trip: every value passed to `record_feedback` comes back on the
    // returned row, and SQLite assigns a positive id.
    #[test]
    fn test_record_and_retrieve_feedback() {
        let conn = setup();
        let fb = record_feedback(
            &conn,
            "rust async",
            42,
            FeedbackSignal::Useful,
            Some(1),
            Some(0.9),
            "default",
        )
        .expect("record");
        assert_eq!(fb.query, "rust async");
        assert_eq!(fb.memory_id, 42);
        assert_eq!(fb.signal, FeedbackSignal::Useful);
        assert_eq!(fb.rank_position, Some(1));
        assert!((fb.original_score.unwrap() - 0.9).abs() < 1e-5);
        assert_eq!(fb.workspace, "default");
        assert!(fb.id > 0);
    }

    // Optional rank/score may be omitted.
    #[test]
    fn test_record_useful_signal() {
        let conn = setup();
        let fb = record_feedback(
            &conn,
            "search query",
            10,
            FeedbackSignal::Useful,
            None,
            None,
            "ws1",
        )
        .expect("record useful");
        assert_eq!(fb.signal, FeedbackSignal::Useful);
    }

    #[test]
    fn test_record_irrelevant_signal() {
        let conn = setup();
        let fb = record_feedback(
            &conn,
            "another query",
            20,
            FeedbackSignal::Irrelevant,
            Some(5),
            Some(0.3),
            "ws1",
        )
        .expect("record irrelevant");
        assert_eq!(fb.signal, FeedbackSignal::Irrelevant);
        assert_eq!(fb.rank_position, Some(5));
    }

    // Two useful + one irrelevant => ratio 2/3 across all workspaces.
    #[test]
    fn test_stats_counts_and_ratios() {
        let conn = setup();
        record_feedback(&conn, "q", 1, FeedbackSignal::Useful, None, None, "ws").unwrap();
        record_feedback(&conn, "q", 2, FeedbackSignal::Useful, None, None, "ws").unwrap();
        record_feedback(&conn, "q", 3, FeedbackSignal::Irrelevant, None, None, "ws").unwrap();
        let stats = feedback_stats(&conn, None).expect("stats");
        assert_eq!(stats.total_feedback, 3);
        assert_eq!(stats.useful_count, 2);
        assert_eq!(stats.irrelevant_count, 1);
        assert!((stats.useful_ratio - 2.0 / 3.0).abs() < 1e-9);
    }

    // Passing a workspace restricts every count to that workspace's rows.
    #[test]
    fn test_stats_workspace_filter() {
        let conn = setup();
        record_feedback(&conn, "q", 1, FeedbackSignal::Useful, None, None, "ws_a").unwrap();
        record_feedback(&conn, "q", 2, FeedbackSignal::Useful, None, None, "ws_a").unwrap();
        record_feedback(
            &conn,
            "q",
            3,
            FeedbackSignal::Irrelevant,
            None,
            None,
            "ws_b",
        )
        .unwrap();
        let stats_a = feedback_stats(&conn, Some("ws_a")).expect("stats_a");
        assert_eq!(stats_a.total_feedback, 2);
        assert_eq!(stats_a.useful_count, 2);
        assert_eq!(stats_a.irrelevant_count, 0);
        let stats_b = feedback_stats(&conn, Some("ws_b")).expect("stats_b");
        assert_eq!(stats_b.total_feedback, 1);
        assert_eq!(stats_b.useful_count, 0);
        assert_eq!(stats_b.irrelevant_count, 1);
    }

    // 8 useful vs 1 irrelevant: 1 + (8 - 1.5) / (9 + 5) > 1.
    #[test]
    fn test_boost_mostly_useful() {
        let conn = setup();
        for _ in 0..8 {
            record_feedback(&conn, "q", 99, FeedbackSignal::Useful, None, None, "ws").unwrap();
        }
        record_feedback(&conn, "q", 99, FeedbackSignal::Irrelevant, None, None, "ws").unwrap();
        let boosts = compute_feedback_boosts(&conn, &[99], None).expect("boosts");
        assert_eq!(boosts.len(), 1);
        assert!(
            boosts[0].boost_factor > 1.0,
            "expected boost > 1.0, got {}",
            boosts[0].boost_factor
        );
    }

    // 8 irrelevant vs 1 useful: 1 + (1 - 12) / (9 + 5) < 1.
    #[test]
    fn test_boost_mostly_irrelevant() {
        let conn = setup();
        for _ in 0..8 {
            record_feedback(&conn, "q", 77, FeedbackSignal::Irrelevant, None, None, "ws").unwrap();
        }
        record_feedback(&conn, "q", 77, FeedbackSignal::Useful, None, None, "ws").unwrap();
        let boosts = compute_feedback_boosts(&conn, &[77], None).expect("boosts");
        assert_eq!(boosts.len(), 1);
        assert!(
            boosts[0].boost_factor < 1.0,
            "expected boost < 1.0, got {}",
            boosts[0].boost_factor
        );
    }

    // A memory with no feedback gets the neutral boost.
    #[test]
    fn test_boost_no_feedback() {
        let conn = setup();
        let boosts = compute_feedback_boosts(&conn, &[999], None).expect("boosts");
        assert_eq!(boosts.len(), 1);
        assert_eq!(boosts[0].boost_factor, 1.0);
        assert_eq!(boosts[0].signal_count, 0);
        assert_eq!(boosts[0].confidence, 0.0);
    }

    // One useful signal: the +5 in the denominator limits the boost to
    // 1 + 1 / (1 + 5).
    #[test]
    fn test_boost_smoothing_prevents_extremes() {
        let conn = setup();
        record_feedback(&conn, "q", 55, FeedbackSignal::Useful, None, None, "ws").unwrap();
        let boosts = compute_feedback_boosts(&conn, &[55], None).expect("boosts");
        let expected = 1.0 + 1.0 / 6.0;
        assert!((boosts[0].boost_factor - expected).abs() < 1e-9);
        assert!(boosts[0].boost_factor < 1.3);
    }

    // 0.6 * 1.5 = 0.9, 0.7 * 0.8 = 0.56; memory 3 has no boost and is unchanged.
    #[test]
    fn test_apply_boosts_modifies_scores() {
        let boosts = vec![
            FeedbackBoost {
                memory_id: 1,
                boost_factor: 1.5,
                signal_count: 5,
                confidence: 0.5,
            },
            FeedbackBoost {
                memory_id: 2,
                boost_factor: 0.8,
                signal_count: 3,
                confidence: 0.3,
            },
        ];
        let mut scores = vec![(1_i64, 0.6_f32), (2_i64, 0.7_f32), (3_i64, 0.4_f32)];
        apply_feedback_boosts(&mut scores, &boosts);
        assert!(
            (scores[0].1 - 0.9_f32).abs() < 1e-5,
            "score[0] = {}",
            scores[0].1
        );
        assert!(
            (scores[1].1 - 0.56_f32).abs() < 1e-4,
            "score[1] = {}",
            scores[1].1
        );
        assert!(
            (scores[2].1 - 0.4_f32).abs() < 1e-5,
            "score[2] = {}",
            scores[2].1
        );
    }

    // The clamp applies to the boosted score: 0.9 * 5.0 -> 2.0, 0.9 * 0.1 -> 0.5.
    #[test]
    fn test_boost_clamping() {
        let boosts_high = vec![FeedbackBoost {
            memory_id: 10,
            boost_factor: 5.0,
            signal_count: 100,
            confidence: 1.0,
        }];
        let mut scores_high = vec![(10_i64, 0.9_f32)];
        apply_feedback_boosts(&mut scores_high, &boosts_high);
        assert!(
            (scores_high[0].1 - 2.0_f32).abs() < 1e-5,
            "expected clamp to 2.0, got {}",
            scores_high[0].1
        );
        let boosts_low = vec![FeedbackBoost {
            memory_id: 20,
            boost_factor: 0.1,
            signal_count: 100,
            confidence: 1.0,
        }];
        let mut scores_low = vec![(20_i64, 0.9_f32)];
        apply_feedback_boosts(&mut scores_low, &boosts_low);
        assert!(
            (scores_low[0].1 - 0.5_f32).abs() < 1e-5,
            "expected clamp to 0.5, got {}",
            scores_low[0].1
        );
    }

    #[test]
    fn test_delete_feedback() {
        let conn = setup();
        let fb = record_feedback(
            &conn,
            "to delete",
            1,
            FeedbackSignal::Useful,
            None,
            None,
            "ws",
        )
        .expect("record");
        delete_feedback(&conn, fb.id).expect("delete");
        let remaining = get_feedback_for_memory(&conn, 1).expect("get");
        assert!(remaining.is_empty());
    }

    // Deleting an unknown id reports NotFound instead of silently succeeding.
    #[test]
    fn test_delete_nonexistent_feedback() {
        let conn = setup();
        let result = delete_feedback(&conn, 9999);
        assert!(matches!(result, Err(EngramError::NotFound(_))));
    }

    // Rows given for queries that share words with the current query weigh
    // more. For "rust async", the useful row ("rust async runtime") has weight
    // 1 + 2/3 and the irrelevant row has weight 1, so the boost is > 1. For
    // "python web" the weights swap roles, so the boost is < 1.
    #[test]
    fn test_query_similarity_weighting() {
        let conn = setup();
        record_feedback(
            &conn,
            "rust async runtime",
            42,
            FeedbackSignal::Useful,
            None,
            None,
            "ws",
        )
        .unwrap();
        record_feedback(
            &conn,
            "python web framework",
            42,
            FeedbackSignal::Irrelevant,
            None,
            None,
            "ws",
        )
        .unwrap();
        let boosts_rust =
            compute_feedback_boosts(&conn, &[42], Some("rust async")).expect("boosts");
        assert!(
            boosts_rust[0].boost_factor > 1.0,
            "expected boost > 1.0 with matching query, got {}",
            boosts_rust[0].boost_factor
        );
        let boosts_python =
            compute_feedback_boosts(&conn, &[42], Some("python web")).expect("boosts");
        assert!(
            boosts_python[0].boost_factor < 1.0,
            "expected boost < 1.0 with mismatched query, got {}",
            boosts_python[0].boost_factor
        );
    }

    // Only rows whose query text matches exactly are returned.
    #[test]
    fn test_get_feedback_for_query() {
        let conn = setup();
        record_feedback(
            &conn,
            "specific query",
            1,
            FeedbackSignal::Useful,
            None,
            None,
            "ws",
        )
        .unwrap();
        record_feedback(
            &conn,
            "specific query",
            2,
            FeedbackSignal::Irrelevant,
            None,
            None,
            "ws",
        )
        .unwrap();
        record_feedback(
            &conn,
            "other query",
            3,
            FeedbackSignal::Useful,
            None,
            None,
            "ws",
        )
        .unwrap();
        let rows = get_feedback_for_query(&conn, "specific query").expect("get");
        assert_eq!(rows.len(), 2);
        for r in &rows {
            assert_eq!(r.query, "specific query");
        }
    }

    // Top lists are `(memory_id, count)` pairs ordered by count, highest first.
    #[test]
    fn test_stats_top_memories() {
        let conn = setup();
        for _ in 0..3 {
            record_feedback(&conn, "q", 1, FeedbackSignal::Useful, None, None, "ws").unwrap();
        }
        record_feedback(&conn, "q", 2, FeedbackSignal::Useful, None, None, "ws").unwrap();
        for _ in 0..2 {
            record_feedback(&conn, "q", 3, FeedbackSignal::Irrelevant, None, None, "ws").unwrap();
        }
        let stats = feedback_stats(&conn, None).unwrap();
        assert_eq!(stats.top_useful_memories[0].0, 1);
        assert_eq!(stats.top_useful_memories[0].1, 3);
        assert_eq!(stats.top_irrelevant_memories[0].0, 3);
        assert_eq!(stats.top_irrelevant_memories[0].1, 2);
    }

    // Useful ranks 1 and 3 average to 2; the single irrelevant rank is 10.
    #[test]
    fn test_stats_avg_rank() {
        let conn = setup();
        record_feedback(&conn, "q", 1, FeedbackSignal::Useful, Some(1), None, "ws").unwrap();
        record_feedback(&conn, "q", 2, FeedbackSignal::Useful, Some(3), None, "ws").unwrap();
        record_feedback(
            &conn,
            "q",
            3,
            FeedbackSignal::Irrelevant,
            Some(10),
            None,
            "ws",
        )
        .unwrap();
        let stats = feedback_stats(&conn, None).unwrap();
        assert!((stats.avg_useful_rank.unwrap() - 2.0).abs() < 1e-9);
        assert!((stats.avg_irrelevant_rank.unwrap() - 10.0).abs() < 1e-9);
    }

    #[test]
    fn test_compute_boosts_empty_ids() {
        let conn = setup();
        let boosts = compute_feedback_boosts(&conn, &[], None).expect("boosts");
        assert!(boosts.is_empty());
    }
}