use std::collections::HashMap;
use sqlx::Row as _;
use zeph_db::DbPool;
use crate::types::MessageId;
/// Access count at which the frequency signal saturates.
///
/// Raw counts are clamped to this value before log-scale normalization, so any count at or
/// above it normalizes to a score of exactly `1.0`.
const MAX_ACCESS_COUNT: f64 = 1.0e4;
/// Session-scoped access-frequency signal backed by the `fact_access_log` table.
///
/// Holds only a database handle. Access counts are read from and written to the database on
/// every call rather than being kept in memory.
pub struct AccessFrequencyCache {
    // Pool used for both the count query and the access-log inserts.
    pool: DbPool,
}
impl AccessFrequencyCache {
#[must_use]
pub fn new(pool: DbPool) -> Self {
Self { pool }
}
#[tracing::instrument(
name = "memory.five_signal.access_frequency.load",
skip(self, fact_ids),
fields(fact_count = fact_ids.len())
)]
pub async fn load_for_candidates(
&self,
session_id: &str,
fact_ids: &[MessageId],
) -> Result<HashMap<MessageId, f64>, crate::error::MemoryError> {
tracing::debug!("five_signal: loading access frequencies");
if fact_ids.is_empty() {
return Ok(HashMap::new());
}
let ids: Vec<i64> = fact_ids.iter().map(|id| id.0).collect();
let placeholders: String = ids
.iter()
.enumerate()
.map(|(i, _)| format!("?{}", i + 2))
.collect::<Vec<_>>()
.join(", ");
let sql = format!(
"SELECT fact_id, COUNT(*) as cnt FROM fact_access_log \
WHERE session_id = ?1 AND fact_id IN ({placeholders}) \
GROUP BY fact_id"
);
let mut q = sqlx::query(&sql).bind(session_id);
for id in &ids {
q = q.bind(id);
}
let rows = q
.fetch_all(&self.pool)
.await
.map_err(|e| crate::error::MemoryError::Db(e.into()))?;
let counts: HashMap<i64, i64> = rows
.iter()
.map(|row| (row.get::<i64, _>("fact_id"), row.get::<i64, _>("cnt")))
.collect();
let normalized = fact_ids
.iter()
.map(|id| {
#[expect(clippy::cast_precision_loss)]
let raw = *counts.get(&id.0).unwrap_or(&0) as f64;
let score =
(1.0_f64 + raw.min(MAX_ACCESS_COUNT)).ln() / (1.0 + MAX_ACCESS_COUNT).ln();
(*id, score)
})
.collect();
Ok(normalized)
}
#[tracing::instrument(
name = "memory.five_signal.access_frequency.log",
skip(self, fact_type, session_id),
fields(fact_id = fact_id.0)
)]
pub async fn log_access(&self, fact_id: MessageId, fact_type: &str, session_id: &str) {
tracing::debug!("five_signal: logging access");
let accessed_at = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_or(0, |d| i64::try_from(d.as_secs()).unwrap_or(i64::MAX));
let res = sqlx::query(
"INSERT INTO fact_access_log (fact_id, fact_type, session_id, accessed_at) \
VALUES (?1, ?2, ?3, ?4)",
)
.bind(fact_id.0)
.bind(fact_type)
.bind(session_id)
.bind(accessed_at)
.execute(&self.pool)
.await;
if let Err(e) = res {
tracing::warn!(
fact_id = fact_id.0,
error = %e,
"five_signal: failed to log fact access (non-fatal)"
);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh single-connection in-memory SQLite store and returns its pool.
    async fn test_pool() -> DbPool {
        let store = crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory SQLite failed");
        store.pool().clone()
    }

    /// Inserts one `episodic` access row for `fact_id` in `session` with `accessed_at = 0`.
    async fn record_access(pool: &DbPool, fact_id: i64, session: &str) {
        sqlx::query(
            "INSERT INTO fact_access_log (fact_id, fact_type, session_id, accessed_at) \
             VALUES (?1, 'episodic', ?2, 0)",
        )
        .bind(fact_id)
        .bind(session)
        .execute(pool)
        .await
        .unwrap();
    }

    /// Local mirror of the production log-scale normalization.
    ///
    /// The `normalization_*` tests below pin the properties of this formula. They do not call
    /// the production code path directly.
    fn normalize(raw: f64) -> f64 {
        (1.0 + raw.min(MAX_ACCESS_COUNT)).ln() / (1.0 + MAX_ACCESS_COUNT).ln()
    }

    #[tokio::test]
    async fn load_for_candidates_empty_returns_empty() {
        let cache = AccessFrequencyCache::new(test_pool().await);
        let result = cache.load_for_candidates("s1", &[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn load_for_candidates_no_rows_gives_zero_score() {
        let cache = AccessFrequencyCache::new(test_pool().await);
        let candidates = vec![MessageId(1), MessageId(2)];
        let scores = cache.load_for_candidates("s1", &candidates).await.unwrap();
        assert_eq!(scores.len(), 2);
        for id in [MessageId(1), MessageId(2)] {
            assert!(scores[&id] < f64::EPSILON);
        }
    }

    #[tokio::test]
    async fn load_for_candidates_higher_count_gives_higher_score() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool.clone());
        let session = "test-session";

        record_access(&pool, 10, session).await;
        for _ in 0..5_u8 {
            record_access(&pool, 20, session).await;
        }

        let candidates = vec![MessageId(10), MessageId(20)];
        let scores = cache.load_for_candidates(session, &candidates).await.unwrap();
        let (s10, s20) = (scores[&MessageId(10)], scores[&MessageId(20)]);
        assert!(
            s20 > s10,
            "higher access count must yield higher score: {s20} vs {s10}"
        );
        assert!(s10 > 0.0, "score for fact with 1 access must be > 0");
        assert!(s20 <= 1.0, "score must be capped at 1.0");
    }

    #[tokio::test]
    async fn load_for_candidates_ignores_other_sessions() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool.clone());
        record_access(&pool, 99, "other-session").await;

        let candidates = vec![MessageId(99)];
        let scores = cache
            .load_for_candidates("my-session", &candidates)
            .await
            .unwrap();
        assert!(
            scores[&MessageId(99)] < f64::EPSILON,
            "score must be 0 for different session"
        );
    }

    #[test]
    fn normalization_zero_count() {
        let score = normalize(0.0);
        assert!(score.abs() < 1e-9, "zero access → score 0.0");
    }

    #[test]
    fn normalization_max_count() {
        let score = normalize(MAX_ACCESS_COUNT);
        assert!((score - 1.0).abs() < 1e-9, "max access → score 1.0");
    }

    #[test]
    fn normalization_overflow_clamped() {
        let score = normalize(MAX_ACCESS_COUNT * 2.0);
        assert!((score - 1.0).abs() < 1e-9, "overflow is clamped to 1.0");
    }

    #[test]
    fn normalization_monotone() {
        assert!(normalize(100.0) > normalize(10.0));
    }
}