use crate::error::Result;
use crate::vector::confidence::now_unix;
use rusqlite::Connection;
use std::collections::{HashSet, VecDeque};
use tracing::debug;
const MAX_DEPTH: usize = 2;
const ATTENUATION: f64 = 0.5;
pub fn propagate(conn: &Connection, origin_id: i64, base_penalty: f64) -> Result<()> {
let mut queue: VecDeque<(i64, usize)> = VecDeque::new();
let mut visited: HashSet<i64> = HashSet::new();
visited.insert(origin_id);
queue.push_back((origin_id, 0));
while let Some((current_id, depth)) = queue.pop_front() {
if depth >= MAX_DEPTH {
continue;
}
let next_depth = depth + 1;
let actual_penalty = base_penalty * ATTENUATION.powi(next_depth as i32);
let dependents = dependents_of(conn, current_id)?;
for dep_id in dependents {
if visited.contains(&dep_id) {
continue;
}
visited.insert(dep_id);
apply_penalty(conn, dep_id, actual_penalty)?;
debug!(
dep_id,
depth = next_depth,
actual_penalty,
"ripple penalty applied"
);
queue.push_back((dep_id, next_depth));
}
}
Ok(())
}
pub fn add_dependency(
conn: &Connection,
source_id: i64,
target_id: i64,
dep_type: &str,
) -> Result<()> {
conn.execute(
"INSERT INTO kg_dependencies (source_id, target_id, dep_type, created_at) \
VALUES (?1, ?2, ?3, ?4)",
rusqlite::params![source_id, target_id, dep_type, now_unix()],
)?;
Ok(())
}
fn dependents_of(conn: &Connection, target_id: i64) -> Result<Vec<i64>> {
let mut stmt = conn.prepare(
"SELECT source_id FROM kg_dependencies \
WHERE target_id = ?1 AND dep_type = 'depends_on'",
)?;
let ids = stmt
.query_map([target_id], |r| r.get::<_, i64>(0))?
.collect::<std::result::Result<Vec<_>, _>>()?;
Ok(ids)
}
fn apply_penalty(conn: &Connection, entity_id: i64, penalty: f64) -> Result<()> {
let old_conf: f64 = conn.query_row(
"SELECT COALESCE(confidence, 1.0) FROM kg_entities WHERE id = ?1",
[entity_id],
|r| r.get(0),
)?;
let raw_conf = old_conf - penalty;
let new_conf = raw_conf.clamp(0.0, 1.0);
conn.execute(
"UPDATE kg_entities SET confidence = ?1 WHERE id = ?2",
rusqlite::params![new_conf, entity_id],
)?;
conn.execute(
"INSERT INTO kg_confidence_log \
(entity_id, old_value, new_value, reason, created_at) \
VALUES (?1, ?2, ?3, 'ripple', ?4)",
rusqlite::params![entity_id, old_conf, new_conf, now_unix()],
)?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::ensure_schema;
fn setup() -> Connection {
let conn = Connection::open_in_memory().unwrap();
ensure_schema(&conn).unwrap();
conn
}
fn insert_entity(conn: &Connection, name: &str) -> i64 {
conn.execute(
"INSERT INTO kg_entities (entity_type, name, confidence) VALUES ('t', ?1, 1.0)",
[name],
)
.unwrap();
conn.last_insert_rowid()
}
fn get_confidence(conn: &Connection, id: i64) -> f64 {
conn.query_row(
"SELECT confidence FROM kg_entities WHERE id = ?1",
[id],
|r| r.get(0),
)
.unwrap()
}
#[test]
fn add_dependency_inserts_row() {
let conn = setup();
let a = insert_entity(&conn, "A");
let b = insert_entity(&conn, "B");
add_dependency(&conn, a, b, "depends_on").unwrap();
let count: i64 = conn
.query_row("SELECT COUNT(*) FROM kg_dependencies", [], |r| r.get(0))
.unwrap();
assert_eq!(count, 1);
}
#[test]
fn ripple_penalises_direct_dependent() {
let conn = setup();
let b = insert_entity(&conn, "B");
let a = insert_entity(&conn, "A");
add_dependency(&conn, a, b, "depends_on").unwrap();
propagate(&conn, b, 0.4).unwrap();
let conf_a = get_confidence(&conn, a);
assert!((conf_a - 0.8).abs() < 1e-9, "expected 0.8, got {conf_a}");
}
#[test]
fn ripple_attenuates_at_depth_two() {
let conn = setup();
let c_node = insert_entity(&conn, "C"); let b = insert_entity(&conn, "B"); let a = insert_entity(&conn, "A");
add_dependency(&conn, b, a, "depends_on").unwrap();
add_dependency(&conn, c_node, b, "depends_on").unwrap();
propagate(&conn, a, 0.4).unwrap();
let conf_b = get_confidence(&conn, b);
let conf_c = get_confidence(&conn, c_node);
assert!((conf_b - 0.8).abs() < 1e-9, "expected 0.8, got {conf_b}");
assert!((conf_c - 0.9).abs() < 1e-9, "expected 0.9, got {conf_c}");
}
#[test]
fn ripple_logs_changes() {
let conn = setup();
let b = insert_entity(&conn, "B");
let a = insert_entity(&conn, "A");
add_dependency(&conn, a, b, "depends_on").unwrap();
propagate(&conn, b, 0.2).unwrap();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM kg_confidence_log WHERE reason = 'ripple'",
[],
|r| r.get(0),
)
.unwrap();
assert_eq!(count, 1);
}
#[test]
fn ripple_stops_at_max_depth() {
let conn = setup();
let d = insert_entity(&conn, "D");
let c_node = insert_entity(&conn, "C");
let b = insert_entity(&conn, "B");
let a = insert_entity(&conn, "A");
add_dependency(&conn, b, a, "depends_on").unwrap();
add_dependency(&conn, c_node, b, "depends_on").unwrap();
add_dependency(&conn, d, c_node, "depends_on").unwrap();
propagate(&conn, a, 0.4).unwrap();
let conf_d = get_confidence(&conn, d);
assert!((conf_d - 1.0).abs() < 1e-9, "D should be unaffected");
}
#[test]
fn apply_penalty_clamps_to_zero_when_penalty_exceeds_confidence() {
let conn = setup();
let a = insert_entity(&conn, "A");
apply_penalty(&conn, a, 1.5).unwrap();
let conf = get_confidence(&conn, a);
assert!(conf >= 0.0, "confidence must not be negative, got {conf}");
assert!((conf - 0.0).abs() < 1e-9, "expected 0.0, got {conf}");
}
}