use chrono::Utc;
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use crate::error::{EngramError, Result};
/// SQL that creates the `coactivation_edges` table and its indexes.
///
/// Each row is one edge between two memory ids (`from_id`, `to_id`) with a
/// `strength`, a `coactivation_count` and the `last_coactivated` timestamp
/// (stored as TEXT). Every statement uses `IF NOT EXISTS`, so the script can
/// be executed repeatedly. It contains several statements, so run it with
/// `Connection::execute_batch`.
///
/// NOTE(review): `idx_coact_from` covers the leading column of the
/// `(from_id, to_id)` primary key, whose implicit index can already serve
/// `from_id` lookups — it may be redundant. `last_coactivated`, which
/// `weaken_unused` filters on, has no index; consider one if the table grows.
pub const CREATE_COACTIVATION_EDGES_TABLE: &str = r#"
CREATE TABLE IF NOT EXISTS coactivation_edges (
from_id INTEGER NOT NULL,
to_id INTEGER NOT NULL,
strength REAL NOT NULL DEFAULT 0.0,
coactivation_count INTEGER NOT NULL DEFAULT 0,
last_coactivated TEXT NOT NULL,
PRIMARY KEY (from_id, to_id)
);
CREATE INDEX IF NOT EXISTS idx_coact_from ON coactivation_edges(from_id);
CREATE INDEX IF NOT EXISTS idx_coact_to ON coactivation_edges(to_id);
CREATE INDEX IF NOT EXISTS idx_coact_str ON coactivation_edges(strength DESC);
"#;
/// Tunable parameters for [`CoactivationTracker`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CoactivationConfig {
    /// Step size used when an edge is created or strengthened.
    pub learning_rate: f64,
    /// Intended decay factor for stale edges.
    ///
    /// NOTE(review): `CoactivationTracker::weaken_unused` takes its decay
    /// rate as an argument and does not read this field.
    pub decay_rate: f64,
    /// Edges whose strength drops below this value are deleted by
    /// `CoactivationTracker::weaken_unused`.
    pub min_strength: f64,
}
impl Default for CoactivationConfig {
    /// Learning rate 0.1, decay rate 0.01, prune threshold 0.01.
    fn default() -> Self {
        let learning_rate = 0.1;
        let decay_rate = 0.01;
        let min_strength = 0.01;
        Self {
            learning_rate,
            decay_rate,
            min_strength,
        }
    }
}
/// One row of the `coactivation_edges` table.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CoactivationEdge {
    /// First memory id of the pair.
    pub from_id: i64,
    /// Second memory id of the pair.
    pub to_id: i64,
    /// Edge weight; the tracker's upsert caps it at 1.0.
    pub strength: f64,
    /// Value of the `coactivation_count` column.
    pub count: i64,
    /// Timestamp text of the most recent co-activation.
    pub last_coactivated: String,
}
/// Aggregate statistics over the whole `coactivation_edges` table.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CoactivationReport {
    /// Number of stored edges.
    pub total_edges: i64,
    /// Mean edge strength, or 0.0 when the table is empty.
    pub avg_strength: f64,
    /// Strongest edges as `(from_id, to_id, strength)`, strongest first.
    pub strongest_pairs: Vec<(i64, i64, f64)>,
}
/// Reads and updates co-activation edges stored in SQLite.
///
/// The struct holds only its configuration; every method receives the
/// database connection explicitly.
pub struct CoactivationTracker {
    /// Parameters used by the tracker's update and pruning logic.
    pub config: CoactivationConfig,
}
impl CoactivationTracker {
pub fn new() -> Self {
Self {
config: CoactivationConfig::default(),
}
}
pub fn with_config(config: CoactivationConfig) -> Self {
Self { config }
}
pub fn record_coactivation(
&self,
conn: &Connection,
memory_ids: &[i64],
_session_id: &str,
) -> Result<usize> {
let mut ids: Vec<i64> = memory_ids.to_vec();
ids.sort_unstable();
ids.dedup();
let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
let lr = self.config.learning_rate;
let mut updated = 0usize;
for i in 0..ids.len() {
for j in (i + 1)..ids.len() {
let from_id = ids[i];
let to_id = ids[j];
self.upsert_edge(conn, from_id, to_id, lr, &now)?;
updated += 1;
}
}
Ok(updated)
}
pub fn strengthen(&self, conn: &Connection, from_id: i64, to_id: i64) -> Result<f64> {
let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
self.upsert_edge(conn, from_id, to_id, self.config.learning_rate, &now)?;
let strength: f64 = conn
.query_row(
"SELECT strength FROM coactivation_edges WHERE from_id = ?1 AND to_id = ?2",
params![from_id, to_id],
|row| row.get(0),
)
.map_err(EngramError::Database)?;
Ok(strength)
}
pub fn weaken_unused(
&self,
conn: &Connection,
decay_rate: f64,
min_age_days: u32,
) -> Result<usize> {
let cutoff = Utc::now() - chrono::Duration::days(min_age_days as i64);
let cutoff_str = cutoff.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
let min_strength = self.config.min_strength;
let updated = conn
.execute(
"UPDATE coactivation_edges
SET strength = strength * (1.0 - ?1)
WHERE last_coactivated < ?2",
params![decay_rate, cutoff_str],
)
.map_err(EngramError::Database)?;
let deleted = conn
.execute(
"DELETE FROM coactivation_edges WHERE strength < ?1",
params![min_strength],
)
.map_err(EngramError::Database)?;
Ok(updated + deleted)
}
pub fn get_coactivation_graph(
&self,
conn: &Connection,
memory_id: i64,
) -> Result<Vec<CoactivationEdge>> {
let mut stmt = conn
.prepare(
"SELECT from_id, to_id, strength, coactivation_count, last_coactivated
FROM coactivation_edges
WHERE from_id = ?1 OR to_id = ?1
ORDER BY strength DESC",
)
.map_err(EngramError::Database)?;
let edges = stmt
.query_map(params![memory_id], |row| {
Ok(CoactivationEdge {
from_id: row.get(0)?,
to_id: row.get(1)?,
strength: row.get(2)?,
count: row.get(3)?,
last_coactivated: row.get(4)?,
})
})
.map_err(EngramError::Database)?
.collect::<rusqlite::Result<Vec<_>>>()
.map_err(EngramError::Database)?;
Ok(edges)
}
pub fn suggest_related(
&self,
conn: &Connection,
memory_id: i64,
top_k: usize,
) -> Result<Vec<(i64, f64)>> {
let mut stmt = conn
.prepare(
"SELECT
CASE WHEN from_id = ?1 THEN to_id ELSE from_id END AS neighbor,
strength
FROM coactivation_edges
WHERE from_id = ?1 OR to_id = ?1
ORDER BY strength DESC
LIMIT ?2",
)
.map_err(EngramError::Database)?;
let pairs = stmt
.query_map(params![memory_id, top_k as i64], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f64>(1)?))
})
.map_err(EngramError::Database)?
.collect::<rusqlite::Result<Vec<_>>>()
.map_err(EngramError::Database)?;
Ok(pairs)
}
pub fn report(&self, conn: &Connection) -> Result<CoactivationReport> {
let (total_edges, avg_strength): (i64, f64) = conn
.query_row(
"SELECT COUNT(*), COALESCE(AVG(strength), 0.0) FROM coactivation_edges",
[],
|row| Ok((row.get(0)?, row.get(1)?)),
)
.map_err(EngramError::Database)?;
let mut stmt = conn
.prepare(
"SELECT from_id, to_id, strength
FROM coactivation_edges
ORDER BY strength DESC
LIMIT 10",
)
.map_err(EngramError::Database)?;
let strongest_pairs = stmt
.query_map([], |row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, i64>(1)?,
row.get::<_, f64>(2)?,
))
})
.map_err(EngramError::Database)?
.collect::<rusqlite::Result<Vec<_>>>()
.map_err(EngramError::Database)?;
Ok(CoactivationReport {
total_edges,
avg_strength,
strongest_pairs,
})
}
fn upsert_edge(
&self,
conn: &Connection,
from_id: i64,
to_id: i64,
lr: f64,
now: &str,
) -> Result<()> {
conn.execute(
"INSERT INTO coactivation_edges (from_id, to_id, strength, coactivation_count, last_coactivated)
VALUES (?1, ?2, MIN(1.0, ?3), 1, ?4)
ON CONFLICT (from_id, to_id) DO UPDATE SET
strength = MIN(1.0, strength + ?3 * (1.0 - strength)),
coactivation_count = coactivation_count + 1,
last_coactivated = ?4",
params![from_id, to_id, lr, now],
)
.map_err(EngramError::Database)?;
Ok(())
}
}
impl Default for CoactivationTracker {
fn default() -> Self {
Self::new()
}
}
/// Unit tests for `CoactivationTracker`. Each test runs against its own
/// in-memory SQLite database.
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens an in-memory database and applies `CREATE_COACTIVATION_EDGES_TABLE`.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().expect("open in-memory DB");
        conn.execute_batch(CREATE_COACTIVATION_EDGES_TABLE)
            .expect("create table");
        conn
    }

    /// Tracker with the default configuration.
    fn tracker() -> CoactivationTracker {
        CoactivationTracker::new()
    }

    #[test]
    fn test_record_coactivation_creates_edges() {
        let conn = setup_db();
        let t = tracker();
        // Three ids form three unordered pairs: (1,2), (1,3), (2,3).
        let n = t
            .record_coactivation(&conn, &[1, 2, 3], "session-1")
            .expect("record");
        assert_eq!(n, 3, "should create one edge per unique pair");
        let report = t.report(&conn).expect("report");
        assert_eq!(report.total_edges, 3);
    }

    #[test]
    fn test_strength_increases_with_repeated_coactivation() {
        let conn = setup_db();
        let t = tracker();
        // Re-recording the same pair should move its strength monotonically
        // toward, but never past, 1.0.
        t.record_coactivation(&conn, &[10, 20], "s1")
            .expect("first");
        let s1 = get_strength(&conn, 10, 20);
        t.record_coactivation(&conn, &[10, 20], "s2")
            .expect("second");
        let s2 = get_strength(&conn, 10, 20);
        t.record_coactivation(&conn, &[10, 20], "s3")
            .expect("third");
        let s3 = get_strength(&conn, 10, 20);
        assert!(s1 > 0.0, "first activation must produce positive strength");
        assert!(s2 > s1, "second activation must increase strength");
        assert!(s3 > s2, "third activation must increase strength further");
        assert!(s3 <= 1.0, "strength must be capped at 1.0");
    }

    #[test]
    fn test_weaken_unused_decays_and_prunes() {
        let conn = setup_db();
        // Only `min_strength` matters here: `weaken_unused` takes its decay
        // rate as an argument rather than reading `config.decay_rate`.
        let t = CoactivationTracker::with_config(CoactivationConfig {
            learning_rate: 0.1,
            decay_rate: 0.5,
            min_strength: 0.08,
        });
        // Stale edge: 0.10 * (1 - 0.5) = 0.05 < 0.08, so it should be pruned.
        conn.execute(
            "INSERT INTO coactivation_edges
(from_id, to_id, strength, coactivation_count, last_coactivated)
VALUES (100, 200, 0.10, 1, '2020-01-01T00:00:00.000Z')",
            [],
        )
        .expect("insert old edge");
        // Fresh edge: newer than the 1-day cutoff, so it is neither decayed
        // nor (at 0.50) below the prune threshold.
        let now_str = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
        conn.execute(
            "INSERT INTO coactivation_edges
(from_id, to_id, strength, coactivation_count, last_coactivated)
VALUES (100, 300, 0.50, 5, ?1)",
            params![now_str],
        )
        .expect("insert fresh edge");
        let affected = t.weaken_unused(&conn, 0.5, 1).expect("weaken");
        assert!(affected >= 1, "at least one edge should be affected");
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM coactivation_edges WHERE from_id=100 AND to_id=200",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 0, "sub-threshold edge should be deleted");
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM coactivation_edges WHERE from_id=100 AND to_id=300",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 1, "fresh edge should survive");
    }

    #[test]
    fn test_get_coactivation_graph_returns_sorted_neighbors() {
        let conn = setup_db();
        let t = tracker();
        // Different repeat counts give the three edges distinct strengths.
        for _ in 0..3 {
            t.record_coactivation(&conn, &[1, 2], "s").unwrap();
        }
        for _ in 0..5 {
            t.record_coactivation(&conn, &[1, 3], "s").unwrap();
        }
        for _ in 0..1 {
            t.record_coactivation(&conn, &[1, 4], "s").unwrap();
        }
        let graph = t.get_coactivation_graph(&conn, 1).expect("graph");
        assert_eq!(graph.len(), 3, "node 1 has 3 neighbors");
        for w in graph.windows(2) {
            assert!(
                w[0].strength >= w[1].strength,
                "edges must be sorted by strength desc: {} >= {}",
                w[0].strength,
                w[1].strength
            );
        }
    }

    #[test]
    fn test_suggest_related_returns_top_k() {
        let conn = setup_db();
        let t = tracker();
        // Memory 30 is co-activated most often (5 times), so it should rank first.
        for (neighbor, times) in [(10i64, 1), (20, 3), (30, 5), (40, 2), (50, 4)] {
            for _ in 0..times {
                t.record_coactivation(&conn, &[1, neighbor], "s").unwrap();
            }
        }
        let top3 = t.suggest_related(&conn, 1, 3).expect("suggest");
        assert_eq!(top3.len(), 3, "must return exactly top_k results");
        for w in top3.windows(2) {
            assert!(w[0].1 >= w[1].1, "results must be sorted by strength desc");
        }
        assert_eq!(top3[0].0, 30, "strongest neighbor should be memory 30");
    }

    #[test]
    fn test_report_stats() {
        let conn = setup_db();
        let t = tracker();
        t.record_coactivation(&conn, &[1, 2, 3], "s1").unwrap();
        let report = t.report(&conn).expect("report");
        assert_eq!(report.total_edges, 3);
        assert!(report.avg_strength > 0.0, "avg_strength must be positive");
        assert!(
            report.avg_strength <= 1.0,
            "avg_strength must be at most 1.0"
        );
        assert!(
            !report.strongest_pairs.is_empty(),
            "strongest_pairs must not be empty"
        );
        assert!(
            report.strongest_pairs.len() <= 10,
            "strongest_pairs must have at most 10 entries"
        );
    }

    #[test]
    fn test_empty_graph() {
        let conn = setup_db();
        let t = tracker();
        // All read paths must tolerate an empty table / unknown id.
        let graph = t.get_coactivation_graph(&conn, 999).expect("graph");
        assert!(graph.is_empty(), "no neighbors for unknown memory");
        let related = t.suggest_related(&conn, 999, 5).expect("suggest");
        assert!(related.is_empty(), "no suggestions for unknown memory");
        let report = t.report(&conn).expect("report");
        assert_eq!(report.total_edges, 0);
        assert_eq!(report.avg_strength, 0.0);
        assert!(report.strongest_pairs.is_empty());
    }

    #[test]
    fn test_strengthen_single_edge() {
        let conn = setup_db();
        let t = tracker();
        let s1 = t.strengthen(&conn, 5, 6).expect("strengthen 1");
        let s2 = t.strengthen(&conn, 5, 6).expect("strengthen 2");
        assert!(s1 > 0.0);
        assert!(s2 > s1, "repeated calls must increase strength");
    }

    #[test]
    fn test_single_memory_no_self_loops() {
        let conn = setup_db();
        let t = tracker();
        let n = t
            .record_coactivation(&conn, &[42], "session-x")
            .expect("record single");
        assert_eq!(n, 0, "no pairs from a single memory");
        let report = t.report(&conn).expect("report");
        assert_eq!(report.total_edges, 0);
    }

    #[test]
    fn test_coactivation_count_increments() {
        let conn = setup_db();
        let t = tracker();
        for _ in 0..4 {
            t.record_coactivation(&conn, &[7, 8], "s").unwrap();
        }
        let graph = t.get_coactivation_graph(&conn, 7).expect("graph");
        assert_eq!(graph.len(), 1);
        assert_eq!(graph[0].count, 4, "count should reflect 4 co-activations");
    }

    /// Reads the stored strength of edge `(from_id, to_id)`.
    /// Returns 0.0 if the query fails, including when no such row exists.
    fn get_strength(conn: &Connection, from_id: i64, to_id: i64) -> f64 {
        conn.query_row(
            "SELECT strength FROM coactivation_edges WHERE from_id=?1 AND to_id=?2",
            params![from_id, to_id],
            |r| r.get(0),
        )
        .unwrap_or(0.0)
    }
}