use crate::error::Result;
use crate::types::*;
use rusqlite::{params, Connection, OptionalExtension};
pub fn store_impression(conn: &Connection, imp: &NewImpression) -> Result<ImpressionId> {
let now = crate::db::now();
conn.execute(
"INSERT INTO impressions (domain, observation, valence, timestamp)
VALUES (?1, ?2, ?3, ?4)",
params![imp.domain, imp.observation, imp.valence, now],
)?;
Ok(ImpressionId(conn.last_insert_rowid()))
}
pub fn get_impressions_by_domain(
conn: &Connection,
domain: &str,
limit: u32,
) -> Result<Vec<Impression>> {
let mut stmt = conn.prepare(
"SELECT id, domain, observation, valence, timestamp
FROM impressions WHERE domain = ?1
ORDER BY timestamp DESC LIMIT ?2",
)?;
let rows = stmt.query_map(params![domain, limit], |row| {
Ok(Impression {
id: ImpressionId(row.get(0)?),
domain: row.get(1)?,
observation: row.get(2)?,
valence: row.get(3)?,
timestamp: row.get(4)?,
})
})?;
Ok(rows.filter_map(|r| r.ok()).collect())
}
pub fn count_impressions_by_domain(conn: &Connection, domain: &str) -> Result<u64> {
let count: i64 = conn.query_row(
"SELECT count(*) FROM impressions WHERE domain = ?1",
[domain],
|row| row.get(0),
)?;
Ok(count as u64)
}
pub fn store_preference(
conn: &Connection,
domain: &str,
preference: &str,
confidence: f32,
) -> Result<PreferenceId> {
let now = crate::db::now();
conn.execute(
"INSERT INTO preferences (domain, preference, confidence, evidence_count, first_observed, last_reinforced)
VALUES (?1, ?2, ?3, 1, ?4, ?4)",
params![domain, preference, confidence, now],
)?;
Ok(PreferenceId(conn.last_insert_rowid()))
}
pub fn get_preference(conn: &Connection, id: PreferenceId) -> Result<Preference> {
conn.query_row(
"SELECT id, domain, preference, confidence, evidence_count, first_observed, last_reinforced
FROM preferences WHERE id = ?1",
[id.0],
map_preference,
)
.optional()?
.ok_or_else(|| crate::error::AlayaError::NotFound(format!("preference {}", id.0)))
}
pub fn get_preferences(conn: &Connection, domain: Option<&str>) -> Result<Vec<Preference>> {
let (sql, param): (&str, Option<&str>) = match domain {
Some(d) => (
"SELECT id, domain, preference, confidence, evidence_count, first_observed, last_reinforced
FROM preferences WHERE domain = ?1 ORDER BY confidence DESC",
Some(d),
),
None => (
"SELECT id, domain, preference, confidence, evidence_count, first_observed, last_reinforced
FROM preferences ORDER BY confidence DESC",
None,
),
};
let mut stmt = conn.prepare(sql)?;
let rows = if let Some(d) = param {
stmt.query_map([d], map_preference)?
} else {
stmt.query_map([], map_preference)?
};
Ok(rows.filter_map(|r| r.ok()).collect())
}
fn map_preference(row: &rusqlite::Row<'_>) -> rusqlite::Result<Preference> {
Ok(Preference {
id: PreferenceId(row.get(0)?),
domain: row.get(1)?,
preference: row.get(2)?,
confidence: row.get(3)?,
evidence_count: row.get(4)?,
first_observed: row.get(5)?,
last_reinforced: row.get(6)?,
})
}
pub fn reinforce_preference(
conn: &Connection,
id: PreferenceId,
additional_evidence: u32,
) -> Result<()> {
let now = crate::db::now();
conn.execute(
"UPDATE preferences SET evidence_count = evidence_count + ?2,
last_reinforced = ?3,
confidence = MIN(1.0, confidence + 0.1 * ?2)
WHERE id = ?1",
params![id.0, additional_evidence, now],
)?;
Ok(())
}
pub fn decay_preferences(conn: &Connection, now: i64, half_life_secs: i64) -> Result<u64> {
let changed = conn.execute(
"UPDATE preferences SET confidence = confidence * 0.95
WHERE (?1 - last_reinforced) > ?2 AND confidence > 0.01",
rusqlite::params![now, half_life_secs],
)?;
Ok(changed as u64)
}
pub fn prune_weak_preferences(conn: &Connection, min_confidence: f32) -> Result<u64> {
let deleted = conn.execute(
"DELETE FROM preferences WHERE confidence < ?1",
[min_confidence],
)?;
Ok(deleted as u64)
}
pub fn prune_old_impressions(conn: &Connection, max_age_secs: i64) -> Result<u64> {
let now = crate::db::now();
let cutoff = now - max_age_secs;
let deleted = conn.execute("DELETE FROM impressions WHERE timestamp < ?1", [cutoff])?;
Ok(deleted as u64)
}
pub fn count_preferences(conn: &Connection) -> Result<u64> {
let count: i64 = conn.query_row("SELECT count(*) FROM preferences", [], |row| row.get(0))?;
Ok(count as u64)
}
pub fn count_impressions(conn: &Connection) -> Result<u64> {
let count: i64 = conn.query_row("SELECT count(*) FROM impressions", [], |row| row.get(0))?;
Ok(count as u64)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::open_memory_db;
#[test]
fn test_impressions_crud() {
let conn = open_memory_db().unwrap();
let id = store_impression(
&conn,
&NewImpression {
domain: "communication".to_string(),
observation: "user prefers short answers".to_string(),
valence: 1.0,
},
)
.unwrap();
assert_eq!(id.0, 1);
let imps = get_impressions_by_domain(&conn, "communication", 10).unwrap();
assert_eq!(imps.len(), 1);
assert_eq!(imps[0].observation, "user prefers short answers");
}
#[test]
fn test_preferences_crud() {
let conn = open_memory_db().unwrap();
let id = store_preference(&conn, "style", "concise answers", 0.7).unwrap();
let prefs = get_preferences(&conn, Some("style")).unwrap();
assert_eq!(prefs.len(), 1);
assert_eq!(prefs[0].preference, "concise answers");
reinforce_preference(&conn, id, 2).unwrap();
let prefs = get_preferences(&conn, Some("style")).unwrap();
assert_eq!(prefs[0].evidence_count, 3);
assert!(prefs[0].confidence > 0.7);
}
#[test]
fn test_count_impressions_by_domain() {
let conn = open_memory_db().unwrap();
store_impression(
&conn,
&NewImpression {
domain: "style".to_string(),
observation: "concise".to_string(),
valence: 0.8,
},
)
.unwrap();
store_impression(
&conn,
&NewImpression {
domain: "style".to_string(),
observation: "brief".to_string(),
valence: 0.7,
},
)
.unwrap();
store_impression(
&conn,
&NewImpression {
domain: "other".to_string(),
observation: "something".to_string(),
valence: 0.5,
},
)
.unwrap();
assert_eq!(count_impressions_by_domain(&conn, "style").unwrap(), 2);
assert_eq!(count_impressions_by_domain(&conn, "other").unwrap(), 1);
assert_eq!(
count_impressions_by_domain(&conn, "nonexistent").unwrap(),
0
);
}
#[test]
fn test_decay_preferences() {
let conn = open_memory_db().unwrap();
store_preference(&conn, "style", "concise", 0.8).unwrap();
let future = crate::db::now() + 100_000;
let decayed = decay_preferences(&conn, future, 3600).unwrap(); assert_eq!(decayed, 1);
let prefs = get_preferences(&conn, Some("style")).unwrap();
assert!(prefs[0].confidence < 0.8, "confidence should have decayed");
}
#[test]
fn test_decay_preferences_skips_fresh() {
let conn = open_memory_db().unwrap();
store_preference(&conn, "style", "concise", 0.8).unwrap();
let now = crate::db::now();
let decayed = decay_preferences(&conn, now, 999_999_999).unwrap();
assert_eq!(decayed, 0);
}
#[test]
fn test_prune_weak_preferences() {
let conn = open_memory_db().unwrap();
store_preference(&conn, "style", "concise", 0.8).unwrap();
store_preference(&conn, "other", "weak", 0.02).unwrap();
let pruned = prune_weak_preferences(&conn, 0.05).unwrap();
assert_eq!(pruned, 1);
let remaining = get_preferences(&conn, None).unwrap();
assert_eq!(remaining.len(), 1);
assert_eq!(remaining[0].preference, "concise");
}
#[test]
fn test_prune_old_impressions() {
let conn = open_memory_db().unwrap();
store_impression(
&conn,
&NewImpression {
domain: "style".to_string(),
observation: "recent".to_string(),
valence: 0.8,
},
)
.unwrap();
let pruned = prune_old_impressions(&conn, 999_999_999).unwrap();
assert_eq!(pruned, 0);
let _pruned = prune_old_impressions(&conn, 1).unwrap();
}
#[test]
fn test_get_preferences_all_domains() {
let conn = open_memory_db().unwrap();
store_preference(&conn, "style", "concise", 0.8).unwrap();
store_preference(&conn, "code", "modular", 0.7).unwrap();
let all = get_preferences(&conn, None).unwrap();
assert_eq!(all.len(), 2);
assert!(all[0].confidence >= all[1].confidence);
}
#[test]
fn test_count_impressions_and_preferences() {
let conn = open_memory_db().unwrap();
assert_eq!(count_impressions(&conn).unwrap(), 0);
assert_eq!(count_preferences(&conn).unwrap(), 0);
store_impression(
&conn,
&NewImpression {
domain: "a".to_string(),
observation: "b".to_string(),
valence: 0.5,
},
)
.unwrap();
assert_eq!(count_impressions(&conn).unwrap(), 1);
store_preference(&conn, "a", "b", 0.5).unwrap();
assert_eq!(count_preferences(&conn).unwrap(), 1);
}
#[test]
fn test_get_preference_by_id() {
let conn = open_memory_db().unwrap();
let id = store_preference(&conn, "ui", "dark mode", 0.85).unwrap();
let pref = get_preference(&conn, id).unwrap();
assert_eq!(pref.domain, "ui");
assert_eq!(pref.preference, "dark mode");
assert!((pref.confidence - 0.85).abs() < 0.01);
assert_eq!(pref.evidence_count, 1);
}
#[test]
fn test_get_preference_not_found() {
let conn = open_memory_db().unwrap();
let result = get_preference(&conn, PreferenceId(999));
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
crate::error::AlayaError::NotFound(_)
));
}
#[test]
fn test_reinforce_preference_increases_confidence() {
let conn = open_memory_db().unwrap();
let id = store_preference(&conn, "coding", "modular code", 0.5).unwrap();
reinforce_preference(&conn, id, 3).unwrap();
let pref = get_preference(&conn, id).unwrap();
assert_eq!(pref.evidence_count, 4, "evidence_count should be 1 + 3");
assert!(
pref.confidence > 0.5,
"confidence should increase after reinforce, got {}",
pref.confidence
);
assert!((pref.confidence - 0.8).abs() < 0.01);
}
#[test]
fn test_reinforce_preference_caps_at_one() {
let conn = open_memory_db().unwrap();
let id = store_preference(&conn, "ui", "dark theme", 0.95).unwrap();
reinforce_preference(&conn, id, 10).unwrap(); let pref = get_preference(&conn, id).unwrap();
assert!(
pref.confidence <= 1.0,
"confidence should be capped at 1.0, got {}",
pref.confidence
);
}
#[test]
fn test_get_impressions_by_domain_ordering() {
let conn = open_memory_db().unwrap();
store_impression(
&conn,
&NewImpression {
domain: "x".to_string(),
observation: "old".to_string(),
valence: 0.5,
},
)
.unwrap();
conn.execute(
"UPDATE impressions SET timestamp = 100 WHERE observation = 'old'",
[],
)
.unwrap();
store_impression(
&conn,
&NewImpression {
domain: "x".to_string(),
observation: "new".to_string(),
valence: 0.8,
},
)
.unwrap();
let imps = get_impressions_by_domain(&conn, "x", 10).unwrap();
assert_eq!(imps.len(), 2);
assert_eq!(imps[0].observation, "new");
assert_eq!(imps[1].observation, "old");
}
#[test]
fn test_get_impressions_by_domain_limit() {
let conn = open_memory_db().unwrap();
for i in 0..5 {
store_impression(
&conn,
&NewImpression {
domain: "test".to_string(),
observation: format!("obs {i}"),
valence: 0.5,
},
)
.unwrap();
}
let imps = get_impressions_by_domain(&conn, "test", 3).unwrap();
assert_eq!(imps.len(), 3, "should respect limit of 3");
}
#[test]
fn test_prune_weak_preferences_empty_db() {
let conn = open_memory_db().unwrap();
let pruned = prune_weak_preferences(&conn, 0.1).unwrap();
assert_eq!(pruned, 0, "pruning empty db should return 0");
}
#[test]
fn test_decay_preferences_low_confidence_skipped() {
let conn = open_memory_db().unwrap();
store_preference(&conn, "style", "minimal", 0.01).unwrap();
let future = crate::db::now() + 200_000;
let decayed = decay_preferences(&conn, future, 3600).unwrap();
assert_eq!(
decayed, 0,
"confidence at 0.01 should not be decayed (already at floor)"
);
}
#[test]
fn test_count_impressions_by_domain_empty() {
let conn = open_memory_db().unwrap();
let count = count_impressions_by_domain(&conn, "nonexistent").unwrap();
assert_eq!(count, 0);
}
}