use crate::error::Result;
use crate::types::*;
use rusqlite::{params, Connection, OptionalExtension};
/// Registers `node` in `node_strengths` with starting values, leaving an existing row untouched.
///
/// A newly inserted row gets storage strength 0.5, retrieval strength 1.0, an access count of 1
/// and `last_accessed` set to `crate::db::now()`. The statement is `INSERT OR IGNORE`, so when a
/// row for the same `(node_type, node_id)` already exists nothing is changed.
///
/// # Errors
/// Propagates any error returned while executing the SQL statement.
pub fn init_strength(conn: &Connection, node: NodeRef) -> Result<()> {
    const INSERT_SQL: &str = concat!(
        "INSERT OR IGNORE INTO node_strengths (node_type, node_id, storage_strength, retrieval_strength, access_count, last_accessed)\n",
        "VALUES (?1, ?2, 0.5, 1.0, 1, ?3)"
    );
    let timestamp = crate::db::now();
    let _inserted_rows = conn.execute(
        INSERT_SQL,
        params![node.type_str(), node.id(), timestamp],
    )?;
    Ok(())
}
/// Reads the stored strength values for `node`.
///
/// When `node` has no row in `node_strengths`, this does not return an error. Instead it returns
/// a default record: storage and retrieval strength 0.5, access count 0, `last_accessed` 0.
///
/// # Errors
/// Propagates SQL errors other than "no row found", including column decoding failures.
// NOTE(review): `dead_code` is presumably allowed because only tests call this so far — confirm.
#[allow(dead_code)]
pub fn get_strength(conn: &Connection, node: NodeRef) -> Result<NodeStrength> {
    const SELECT_SQL: &str = concat!(
        "SELECT storage_strength, retrieval_strength, access_count, last_accessed\n",
        "FROM node_strengths WHERE node_type = ?1 AND node_id = ?2"
    );
    let stored = conn
        .query_row(SELECT_SQL, params![node.type_str(), node.id()], |row| {
            Ok(NodeStrength {
                node,
                storage_strength: row.get(0)?,
                retrieval_strength: row.get(1)?,
                access_count: row.get(2)?,
                last_accessed: row.get(3)?,
            })
        })
        // `optional` turns "query returned no rows" into `None` rather than an error.
        .optional()?;
    match stored {
        Some(strength) => Ok(strength),
        None => Ok(NodeStrength {
            node,
            storage_strength: 0.5,
            retrieval_strength: 0.5,
            access_count: 0,
            last_accessed: 0,
        }),
    }
}
/// Records one access to `node`, creating its strength row if it does not exist yet.
///
/// - No existing row: inserts storage strength 0.6, retrieval strength 1.0, access count 1 and
///   `last_accessed = now`.
/// - Existing row (conflict on `(node_type, node_id)`): moves storage strength 5% of the remaining
///   distance toward 1.0 (capped at 1.0), resets retrieval strength to 1.0, increments the access
///   count and sets `last_accessed = now`.
///
/// # Errors
/// Propagates any error returned while executing the SQL statement.
pub fn on_access(conn: &Connection, node: NodeRef) -> Result<()> {
    const UPSERT_SQL: &str = concat!(
        "INSERT INTO node_strengths (node_type, node_id, storage_strength, retrieval_strength, access_count, last_accessed)\n",
        "VALUES (?1, ?2, 0.6, 1.0, 1, ?3)\n",
        "ON CONFLICT(node_type, node_id) DO UPDATE SET\n",
        "storage_strength = MIN(1.0, storage_strength + 0.05 * (1.0 - storage_strength)),\n",
        "retrieval_strength = 1.0,\n",
        "access_count = access_count + 1,\n",
        "last_accessed = ?3"
    );
    // ?3 is bound once and used both for the inserted row and for the conflict update.
    let timestamp = crate::db::now();
    let _affected = conn.execute(
        UPSERT_SQL,
        params![node.type_str(), node.id(), timestamp],
    )?;
    Ok(())
}
/// Multiplies the retrieval strength of `node` by `factor`, capping the result at 1.0.
///
/// If `node` has no row in `node_strengths`, the `UPDATE` matches nothing and the call does
/// nothing. Only the upper bound is enforced, so a negative `factor` would store a negative value.
///
/// # Errors
/// Propagates any error returned while executing the SQL statement.
// NOTE(review): `dead_code` is presumably allowed because only tests call this so far — confirm.
#[allow(dead_code)]
pub fn boost_retrieval(conn: &Connection, node: NodeRef, factor: f32) -> Result<()> {
    const BOOST_SQL: &str = concat!(
        "UPDATE node_strengths SET retrieval_strength = MIN(1.0, retrieval_strength * ?3)\n",
        "WHERE node_type = ?1 AND node_id = ?2"
    );
    let (kind, id) = (node.type_str(), node.id());
    let _updated = conn.execute(BOOST_SQL, params![kind, id, factor])?;
    Ok(())
}
/// Multiplies the retrieval strength of `node` by `factor`, clamping the result to `[0.0, 1.0]`.
///
/// The function name and the property tests suggest `factor` is normally in `[0, 1)`, which
/// lowers retrieval strength. The clamp keeps the stored value within `[0, 1]`, the range the
/// rest of this module assumes: `boost_retrieval` caps at 1.0 and `on_access` resets to 1.0.
/// This holds even when a caller passes a factor above 1 or below 0.
///
/// If `node` has no row in `node_strengths`, the `UPDATE` matches nothing and the call does
/// nothing.
///
/// # Errors
/// Propagates any error returned while executing the SQL statement.
pub fn suppress_retrieval(conn: &Connection, node: NodeRef, factor: f32) -> Result<()> {
    conn.execute(
        // SQLite's multi-argument MIN/MAX are scalar functions, so this clamps each row's value.
        "UPDATE node_strengths
         SET retrieval_strength = MAX(0.0, MIN(1.0, retrieval_strength * ?3))
         WHERE node_type = ?1 AND node_id = ?2",
        params![node.type_str(), node.id(), factor],
    )?;
    Ok(())
}
/// Applies multiplicative decay by `decay_factor` to the `retrieval_strength` column of
/// `node_strengths`.
///
/// This delegates to `crate::decay::apply_multiplicative_sql` and returns that helper's count.
/// NOTE(review): the count is presumably the number of rows updated, and the helper presumably
/// touches every row — its body is not visible here, so confirm.
///
/// # Errors
/// Propagates whatever error the decay helper returns.
pub fn decay_all_retrieval(conn: &Connection, decay_factor: f32) -> Result<u64> {
    // f32 -> f64 is lossless, so `From` expresses the widening without an `as` cast.
    let factor = f64::from(decay_factor);
    crate::decay::apply_multiplicative_sql(conn, "node_strengths", "retrieval_strength", factor)
}
/// Returns every tracked node whose storage strength is strictly below `storage_thresh` and
/// whose retrieval strength is strictly below `retrieval_thresh`.
///
/// Rows whose `node_type` is not recognised by `NodeRef::from_parts` are skipped.
///
/// # Errors
/// Propagates errors from preparing or running the query, and from decoding any individual row.
/// Row errors are not silently dropped: an archival pass should not skip nodes it failed to read.
pub fn find_archivable(
    conn: &Connection,
    storage_thresh: f32,
    retrieval_thresh: f32,
) -> Result<Vec<NodeRef>> {
    let mut stmt = conn.prepare(
        "SELECT node_type, node_id FROM node_strengths
         WHERE storage_strength < ?1 AND retrieval_strength < ?2",
    )?;
    let rows = stmt.query_map(params![storage_thresh, retrieval_thresh], |row| {
        let ntype: String = row.get(0)?;
        let nid: i64 = row.get(1)?;
        Ok((ntype, nid))
    })?;

    let mut nodes = Vec::new();
    for row in rows {
        // Surface step/decoding failures instead of discarding them.
        let (ntype, nid) = row?;
        // An unknown type tag is deliberately ignored rather than treated as an error.
        if let Some(node) = NodeRef::from_parts(&ntype, nid) {
            nodes.push(node);
        }
    }
    Ok(nodes)
}
// Unit and property tests for the node-strength helpers. Each test opens a fresh database with
// `crate::schema::open_memory_db` (presumably an in-memory SQLite database, per the name — confirm).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::open_memory_db;
    use proptest::prelude::*;
    proptest! {
        // Any number of accesses after initialization must keep both strengths within [0, 1].
        #[test]
        fn prop_on_access_ss_bounded(access_count in 1u32..100) {
            let conn = open_memory_db().unwrap();
            let node = NodeRef::Episode(EpisodeId(1));
            init_strength(&conn, node).unwrap();
            for _ in 0..access_count {
                on_access(&conn, node).unwrap();
            }
            let s = get_strength(&conn, node).unwrap();
            prop_assert!(s.storage_strength >= 0.0, "SS below 0: {}", s.storage_strength);
            prop_assert!(s.storage_strength <= 1.0, "SS above 1: {}", s.storage_strength);
            prop_assert!(s.retrieval_strength >= 0.0, "RS below 0: {}", s.retrieval_strength);
            prop_assert!(s.retrieval_strength <= 1.0, "RS above 1: {}", s.retrieval_strength);
        }
        // Suppressing with a factor drawn from [0, 1) must never make retrieval strength negative.
        #[test]
        fn prop_suppress_keeps_rs_non_negative(factor in 0.0f32..1.0f32) {
            let conn = open_memory_db().unwrap();
            let node = NodeRef::Episode(EpisodeId(1));
            init_strength(&conn, node).unwrap();
            suppress_retrieval(&conn, node, factor).unwrap();
            let s = get_strength(&conn, node).unwrap();
            prop_assert!(s.retrieval_strength >= 0.0, "RS should be >= 0, got {}", s.retrieval_strength);
        }
        // Table-wide decay with a factor drawn from [0, 1) must never make retrieval strength negative.
        #[test]
        fn prop_decay_all_keeps_rs_non_negative(factor in 0.0f32..1.0f32) {
            let conn = open_memory_db().unwrap();
            let node = NodeRef::Episode(EpisodeId(1));
            init_strength(&conn, node).unwrap();
            decay_all_retrieval(&conn, factor).unwrap();
            let s = get_strength(&conn, node).unwrap();
            prop_assert!(s.retrieval_strength >= 0.0, "RS should be >= 0, got {}", s.retrieval_strength);
        }
    }
    // init_strength inserts access_count = 1 and retrieval strength 1.0. One on_access then
    // increments the count and moves storage strength above its initial 0.5.
    #[test]
    fn test_init_and_access() {
        let conn = open_memory_db().unwrap();
        let node = NodeRef::Episode(EpisodeId(1));
        init_strength(&conn, node).unwrap();
        let s = get_strength(&conn, node).unwrap();
        assert_eq!(s.access_count, 1);
        assert!((s.retrieval_strength - 1.0).abs() < 0.01);
        on_access(&conn, node).unwrap();
        let s = get_strength(&conn, node).unwrap();
        assert_eq!(s.access_count, 2);
        assert!(s.storage_strength > 0.5);
    }
    // Suppressing by 0.5 halves the initial retrieval strength (1.0 -> 0.5). A decay with factor
    // 0.9 should then lower it further.
    // NOTE(review): assumes apply_multiplicative_sql multiplies the column by the factor.
    #[test]
    fn test_suppress_and_decay() {
        let conn = open_memory_db().unwrap();
        let node = NodeRef::Episode(EpisodeId(1));
        init_strength(&conn, node).unwrap();
        suppress_retrieval(&conn, node, 0.5).unwrap();
        let s = get_strength(&conn, node).unwrap();
        assert!((s.retrieval_strength - 0.5).abs() < 0.01);
        decay_all_retrieval(&conn, 0.9).unwrap();
        let s = get_strength(&conn, node).unwrap();
        assert!(s.retrieval_strength < 0.5);
    }
    // Boosting by 1.5 from 0.5 gives 0.75, which is below the 1.0 cap, so it is stored as-is.
    #[test]
    fn test_boost_retrieval() {
        let conn = open_memory_db().unwrap();
        let node = NodeRef::Episode(EpisodeId(1));
        init_strength(&conn, node).unwrap();
        suppress_retrieval(&conn, node, 0.5).unwrap();
        let before = get_strength(&conn, node).unwrap();
        assert!((before.retrieval_strength - 0.5).abs() < 0.01);
        boost_retrieval(&conn, node, 1.5).unwrap();
        let after = get_strength(&conn, node).unwrap();
        assert!(after.retrieval_strength > before.retrieval_strength);
        assert!((after.retrieval_strength - 0.75).abs() < 0.01);
    }
    // Boosting the initial 1.0 by 2.0 would give 2.0; the MIN(1.0, ...) in the SQL caps it at 1.0.
    #[test]
    fn test_boost_retrieval_clamps_at_one() {
        let conn = open_memory_db().unwrap();
        let node = NodeRef::Episode(EpisodeId(1));
        init_strength(&conn, node).unwrap();
        boost_retrieval(&conn, node, 2.0).unwrap();
        let s = get_strength(&conn, node).unwrap();
        assert!((s.retrieval_strength - 1.0).abs() < 0.01);
    }
    // Only node1 is pushed below both thresholds: storage 0.05 < 0.1 and retrieval 0.01 < 0.05.
    // node2 keeps its initial 0.5 / 1.0. The raw UPDATE filters by node_id alone, which selects
    // only node1 because both nodes are episodes with distinct ids.
    #[test]
    fn test_find_archivable() {
        let conn = open_memory_db().unwrap();
        let node1 = NodeRef::Episode(EpisodeId(1));
        let node2 = NodeRef::Episode(EpisodeId(2));
        init_strength(&conn, node1).unwrap();
        init_strength(&conn, node2).unwrap();
        suppress_retrieval(&conn, node1, 0.01).unwrap();
        conn.execute(
            "UPDATE node_strengths SET storage_strength = 0.05 WHERE node_id = 1",
            [],
        )
        .unwrap();
        let archivable = find_archivable(&conn, 0.1, 0.05).unwrap();
        assert_eq!(archivable.len(), 1);
        assert_eq!(archivable[0], node1);
    }
    // With no tracked nodes, nothing is archivable.
    #[test]
    fn test_find_archivable_empty() {
        let conn = open_memory_db().unwrap();
        let archivable = find_archivable(&conn, 0.1, 0.05).unwrap();
        assert!(archivable.is_empty());
    }
    // A node that was never initialized yields get_strength's fallback values, not an error.
    #[test]
    fn test_get_strength_default_for_untracked() {
        let conn = open_memory_db().unwrap();
        let node = NodeRef::Episode(EpisodeId(999));
        let s = get_strength(&conn, node).unwrap();
        assert_eq!(s.access_count, 0);
        assert!((s.storage_strength - 0.5).abs() < 0.01);
        assert!((s.retrieval_strength - 0.5).abs() < 0.01);
    }
}