use std::collections::HashMap;
use anyhow::{Context, Result};
use rusqlite::{params, Connection, OptionalExtension, Transaction};
use serde::{Deserialize, Serialize};
use crate::memory::chunks::with_connection;
use crate::memory::config::MemoryConfig;
use crate::memory::score::extract::EntityKind;
use crate::memory::score::resolver::CanonicalEntity;
use crate::memory::score::signals::ScoreSignals;
/// Decides whether a resolved entity denotes the user themself.
///
/// Currently a stub: it never classifies an entity as the user, so every
/// `is_user` value this module derives from it is `0`.
fn entity_is_user(_entity: &CanonicalEntity) -> bool {
    // NOTE(review): placeholder until user detection is wired in.
    false
}
/// Decides whether a bare canonical entity id denotes the user themself.
///
/// Currently a stub that always answers `false`, whatever the id.
fn canonical_id_is_user(_canonical_id: &str) -> bool {
    // NOTE(review): placeholder until user detection is wired in.
    false
}
/// One scoring result for a chunk, mirroring a row of `mem_tree_score`.
///
/// Not every field round-trips through the table. The upsert in this module
/// has no column for `signals.llm_importance` or `llm_importance_reason`, and
/// `get_score` fills them with `0.0` and `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreRow {
    /// Chunk identifier; the key `get_score` looks rows up by.
    pub chunk_id: String,
    /// Combined score, stored in the `total` column.
    pub total: f32,
    /// Per-signal components recorded alongside `total`.
    pub signals: ScoreSignals,
    /// Whether the chunk was dropped; stored as a 0/1 integer.
    pub dropped: bool,
    /// Optional free-text explanation, stored in the `reason` column.
    pub reason: Option<String>,
    /// Time the score was computed (presumably Unix epoch milliseconds; TODO confirm).
    pub computed_at_ms: i64,
    /// Presumably the rationale behind `signals.llm_importance`. Defaults to
    /// `None` when missing from serialized input, and is not persisted in
    /// `mem_tree_score`.
    #[serde(default)]
    pub llm_importance_reason: Option<String>,
}
/// Inserts or replaces the score row for `row.chunk_id`, using a connection
/// obtained from `config` for this single write.
///
/// # Errors
/// Propagates connection-acquisition and SQLite errors.
pub fn upsert_score(config: &MemoryConfig, row: &ScoreRow) -> Result<()> {
    with_connection(config, |conn| upsert_score_on_connection(conn, row))
}
/// Inserts or replaces the score row for `row.chunk_id` inside an existing
/// transaction, so the write commits or rolls back with the caller's work.
///
/// # Errors
/// Returns any SQLite error raised by the statement.
pub fn upsert_score_tx(tx: &Transaction<'_>, row: &ScoreRow) -> Result<()> {
    // `Transaction` derefs to `Connection`, so this reuses the single writer.
    // Keeping one copy of the parameter binding means it cannot drift out of
    // sync with `SCORE_UPSERT_SQL`.
    upsert_score_on_connection(tx, row)
}

/// Upsert statement for `mem_tree_score`. Parameters `?1..?11` are bound in
/// the column order listed here by `upsert_score_on_connection`.
///
/// `INSERT OR REPLACE` deletes any conflicting row before inserting. If the
/// table has columns not listed here, they are reset to their defaults.
const SCORE_UPSERT_SQL: &str = "INSERT OR REPLACE INTO mem_tree_score (
chunk_id, total,
token_count_signal, unique_words_signal,
metadata_weight, source_weight, interaction_weight, entity_density,
dropped, reason, computed_at_ms
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)";

/// Writes `row` through any connection, plain or transactional.
///
/// `row.signals.llm_importance` and `row.llm_importance_reason` have no
/// column in the statement, so they are not persisted.
///
/// # Errors
/// Returns any SQLite error raised by the statement.
fn upsert_score_on_connection(conn: &Connection, row: &ScoreRow) -> Result<()> {
    conn.execute(
        SCORE_UPSERT_SQL,
        params![
            row.chunk_id,
            row.total,
            row.signals.token_count,
            row.signals.unique_words,
            row.signals.metadata_weight,
            row.signals.source_weight,
            row.signals.interaction,
            row.signals.entity_density,
            // SQLite has no boolean type; the flag is stored as 0/1.
            i32::from(row.dropped),
            row.reason,
            row.computed_at_ms,
        ],
    )?;
    Ok(())
}
pub fn get_score(config: &MemoryConfig, chunk_id: &str) -> Result<Option<ScoreRow>> {
with_connection(config, |conn| {
conn.query_row(
"SELECT chunk_id, total,
token_count_signal, unique_words_signal,
metadata_weight, source_weight, interaction_weight, entity_density,
dropped, reason, computed_at_ms
FROM mem_tree_score WHERE chunk_id = ?1",
params![chunk_id],
|row| {
Ok(ScoreRow {
chunk_id: row.get(0)?,
total: row.get(1)?,
signals: ScoreSignals {
token_count: row.get(2)?,
unique_words: row.get(3)?,
metadata_weight: row.get(4)?,
source_weight: row.get(5)?,
interaction: row.get(6)?,
entity_density: row.get(7)?,
llm_importance: 0.0,
},
dropped: row.get::<_, i32>(8)? != 0,
reason: row.get(9)?,
computed_at_ms: row.get(10)?,
llm_importance_reason: None,
})
},
)
.optional()
.map_err(anyhow::Error::from)
})
}
/// Maximum number of ids bound into one `IN (...)` query. Presumably chosen to
/// stay under SQLite's host-parameter limit (999 on builds before 3.32.0).
const MAX_FETCH_BATCH: usize = 500;

/// Fetches the stored `total` score for each id in `chunk_ids`.
///
/// Ids with no row in `mem_tree_score` are absent from the returned map. The
/// ids are queried in windows of at most `MAX_FETCH_BATCH`, so each statement
/// binds a bounded number of parameters.
///
/// # Errors
/// Propagates connection errors, plus prepare/query/decode errors annotated
/// with context.
pub fn get_scores_batch(
    config: &MemoryConfig,
    chunk_ids: &[String],
) -> Result<HashMap<String, f32>> {
    if chunk_ids.is_empty() {
        return Ok(HashMap::new());
    }
    with_connection(config, |conn| {
        let mut totals: HashMap<String, f32> = HashMap::with_capacity(chunk_ids.len());
        for batch in chunk_ids.chunks(MAX_FETCH_BATCH) {
            // Numbered placeholders `?1,?2,...` matching this batch's size.
            let mut slots: Vec<String> = Vec::with_capacity(batch.len());
            for position in 1..=batch.len() {
                slots.push(format!("?{position}"));
            }
            let placeholders = slots.join(",");
            let sql = format!(
                "SELECT chunk_id, total FROM mem_tree_score
WHERE chunk_id IN ({placeholders})"
            );
            let mut stmt = conn.prepare(&sql).context("prepare get_scores_batch")?;
            let bound: Vec<&dyn rusqlite::ToSql> = batch
                .iter()
                .map(|id| id as &dyn rusqlite::ToSql)
                .collect();
            let mapped = stmt
                .query_map(bound.as_slice(), |r| {
                    let id: String = r.get(0)?;
                    let total: f32 = r.get(1)?;
                    Ok((id, total))
                })
                .context("query get_scores_batch")?;
            for item in mapped {
                let (id, total) = item.context("decode get_scores_batch row")?;
                totals.insert(id, total);
            }
        }
        Ok(totals)
    })
}
/// Upsert statement for `mem_tree_entity_index`. Parameters `?1..?9` follow
/// the column order listed here.
const ENTITY_INDEX_UPSERT_SQL: &str = "INSERT OR REPLACE INTO mem_tree_entity_index (
entity_id, node_id, node_kind, entity_kind, surface,
score, timestamp_ms, tree_id, is_user
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)";

/// Records one occurrence of `entity` in node `node_id`, replacing any row
/// that conflicts with it.
///
/// # Errors
/// Propagates connection-acquisition and SQLite errors.
pub fn index_entity(
    config: &MemoryConfig,
    entity: &CanonicalEntity,
    node_id: &str,
    node_kind: &str,
    timestamp_ms: i64,
    tree_id: Option<&str>,
) -> Result<()> {
    // Stored as 0/1 because SQLite has no boolean type.
    let is_user_flag = i32::from(entity_is_user(entity));
    with_connection(config, |conn| {
        conn.execute(
            ENTITY_INDEX_UPSERT_SQL,
            params![
                entity.canonical_id,
                node_id,
                node_kind,
                entity.kind.as_str(),
                entity.surface,
                entity.score,
                timestamp_ms,
                tree_id,
                is_user_flag,
            ],
        )?;
        Ok(())
    })
}
/// Records every entity in `entities` as occurring in node `node_id`, all
/// within one transaction.
///
/// Returns `entities.len()`, or `0` without touching the database when the
/// slice is empty.
///
/// # Errors
/// Propagates connection and SQLite errors. If any insert fails, nothing is
/// committed.
pub fn index_entities(
    config: &MemoryConfig,
    entities: &[CanonicalEntity],
    node_id: &str,
    node_kind: &str,
    timestamp_ms: i64,
    tree_id: Option<&str>,
) -> Result<usize> {
    if entities.is_empty() {
        return Ok(0);
    }
    with_connection(config, |conn| {
        // rusqlite rolls a `Transaction` back on drop unless it is committed,
        // so an early `?` return below discards every insert.
        let tx = conn.unchecked_transaction()?;
        {
            let mut insert = tx.prepare(ENTITY_INDEX_UPSERT_SQL)?;
            entities.iter().try_for_each(|entity| {
                insert
                    .execute(params![
                        entity.canonical_id,
                        node_id,
                        node_kind,
                        entity.kind.as_str(),
                        entity.surface,
                        entity.score,
                        timestamp_ms,
                        tree_id,
                        entity_is_user(entity) as i32,
                    ])
                    .map(|_rows| ())
            })?;
        } // `insert` borrows `tx`; it must be dropped before `commit` consumes `tx`.
        tx.commit()?;
        Ok(entities.len())
    })
}
pub fn clear_entity_index_for_node(config: &MemoryConfig, node_id: &str) -> Result<usize> {
with_connection(config, |conn| {
let n = conn.execute(
"DELETE FROM mem_tree_entity_index WHERE node_id = ?1",
params![node_id],
)?;
Ok(n)
})
}
/// Deletes every entity-index row for `node_id` inside the caller's
/// transaction, returning how many were removed.
///
/// # Errors
/// Returns any SQLite error raised by the delete.
pub fn clear_entity_index_for_node_tx(tx: &Transaction<'_>, node_id: &str) -> Result<usize> {
    tx.execute(
        "DELETE FROM mem_tree_entity_index WHERE node_id = ?1",
        params![node_id],
    )
    .map_err(Into::into)
}
pub fn index_summary_entity_ids_tx(
tx: &Transaction<'_>,
entity_ids: &[String],
node_id: &str,
score: f32,
timestamp_ms: i64,
tree_id: Option<&str>,
) -> Result<usize> {
if entity_ids.is_empty() {
return Ok(0);
}
let mut stmt = tx.prepare(ENTITY_INDEX_UPSERT_SQL)?;
for canonical_id in entity_ids {
let entity_kind = match canonical_id.split_once(':') {
Some((kind, _)) => kind,
None => canonical_id.as_str(),
};
stmt.execute(params![
canonical_id,
node_id,
"summary",
entity_kind,
canonical_id,
score,
timestamp_ms,
tree_id,
canonical_id_is_user(canonical_id) as i32,
])?;
}
Ok(entity_ids.len())
}
pub fn index_entities_tx(
tx: &Transaction<'_>,
entities: &[CanonicalEntity],
node_id: &str,
node_kind: &str,
timestamp_ms: i64,
tree_id: Option<&str>,
) -> Result<usize> {
if entities.is_empty() {
return Ok(0);
}
let mut stmt = tx.prepare(ENTITY_INDEX_UPSERT_SQL)?;
for e in entities {
stmt.execute(params![
e.canonical_id,
node_id,
node_kind,
e.kind.as_str(),
e.surface,
e.score,
timestamp_ms,
tree_id,
entity_is_user(e) as i32,
])?;
}
Ok(entities.len())
}
/// One decoded row of `mem_tree_entity_index`: an occurrence of an entity in a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityHit {
    /// Canonical entity id (the `entity_id` column).
    pub entity_id: String,
    /// Node the entity was indexed against.
    pub node_id: String,
    /// Kind of that node, e.g. `"summary"` for rows written by `index_summary_entity_ids_tx`.
    pub node_kind: String,
    /// Entity kind, parsed from the stored `entity_kind` text.
    pub entity_kind: EntityKind,
    /// Surface text recorded for the entity.
    pub surface: String,
    /// Score stored with the row.
    pub score: f32,
    /// Timestamp stored with the row; `lookup_entity` returns newest first.
    pub timestamp_ms: i64,
    /// Tree the node belongs to, if one was recorded.
    pub tree_id: Option<String>,
    /// Whether the entity denotes the user. Defaults to `false` when missing
    /// from serialized input.
    #[serde(default)]
    pub is_user: bool,
}
/// Returns up to `limit` index rows for `entity_id`, newest first.
///
/// `limit` defaults to 100 when `None`. Rows with equal `timestamp_ms` are
/// ordered by `node_id`, so the set of rows kept by `LIMIT` is deterministic
/// rather than dependent on SQLite's scan order.
///
/// # Errors
/// Propagates connection and query errors. Also fails if any stored
/// `entity_kind` cannot be parsed by [`EntityKind::parse`]; one bad row fails
/// the whole lookup.
pub fn lookup_entity(
    config: &MemoryConfig,
    entity_id: &str,
    limit: Option<usize>,
) -> Result<Vec<EntityHit>> {
    // SQLite binds LIMIT as a signed 64-bit integer; saturate instead of wrapping.
    let limit = i64::try_from(limit.unwrap_or(100)).unwrap_or(i64::MAX);
    with_connection(config, |conn| {
        let mut stmt = conn.prepare(
            "SELECT entity_id, node_id, node_kind, entity_kind, surface,
score, timestamp_ms, tree_id, is_user
FROM mem_tree_entity_index
WHERE entity_id = ?1
ORDER BY timestamp_ms DESC, node_id ASC
LIMIT ?2",
        )?;
        let hits = stmt
            .query_map(params![entity_id, limit], |row| {
                let kind_s: String = row.get(3)?;
                // Surface a parse failure as a column-conversion error on column 3.
                let entity_kind = EntityKind::parse(&kind_s).map_err(|e| {
                    rusqlite::Error::FromSqlConversionFailure(
                        3,
                        rusqlite::types::Type::Text,
                        e.into(),
                    )
                })?;
                let is_user_int: i32 = row.get(8)?;
                Ok(EntityHit {
                    entity_id: row.get(0)?,
                    node_id: row.get(1)?,
                    node_kind: row.get(2)?,
                    entity_kind,
                    surface: row.get(4)?,
                    score: row.get(5)?,
                    timestamp_ms: row.get(6)?,
                    tree_id: row.get(7)?,
                    is_user: is_user_int != 0,
                })
            })?
            .collect::<rusqlite::Result<Vec<_>>>()?;
        Ok(hits)
    })
}
/// Lists the distinct entity ids indexed for `node_id`, highest score first.
///
/// Ties on score fall back to the latest `timestamp_ms`, then to `entity_id`
/// ascending. If an entity has several rows for the node, its best score and
/// latest timestamp decide its position.
///
/// # Errors
/// Propagates connection and query errors.
pub fn list_entity_ids_for_node(config: &MemoryConfig, node_id: &str) -> Result<Vec<String>> {
    with_connection(config, |conn| {
        // GROUP BY + aggregates instead of `SELECT DISTINCT ... ORDER BY score`.
        // With DISTINCT, the ORDER BY columns are not part of the result, so
        // an entity with several rows has no well-defined position, and
        // stricter SQL engines reject that form outright.
        let mut stmt = conn.prepare(
            "SELECT entity_id
FROM mem_tree_entity_index
WHERE node_id = ?1
GROUP BY entity_id
ORDER BY MAX(score) DESC, MAX(timestamp_ms) DESC, entity_id ASC",
        )?;
        let ids = stmt
            .query_map(params![node_id], |row| row.get::<_, String>(0))?
            .collect::<rusqlite::Result<Vec<_>>>()?;
        Ok(ids)
    })
}
/// Returns the total number of rows in `mem_tree_entity_index`.
///
/// # Errors
/// Propagates connection-acquisition and SQLite errors.
pub fn count_entity_index(config: &MemoryConfig) -> Result<u64> {
    with_connection(config, |conn| {
        let rows: i64 = conn.query_row(
            "SELECT COUNT(*) FROM mem_tree_entity_index",
            [],
            |r| r.get(0),
        )?;
        // COUNT(*) is never negative; clamp defensively instead of wrapping on cast.
        Ok(u64::try_from(rows).unwrap_or(0))
    })
}
pub fn count_scores(config: &MemoryConfig) -> Result<u64> {
with_connection(config, |conn| {
let n: i64 = conn.query_row("SELECT COUNT(*) FROM mem_tree_score", [], |r| r.get(0))?;
Ok(n.max(0) as u64)
})
}
// Unit tests for this module are compiled only under `cfg(test)` and are
// loaded from `store_tests.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "store_tests.rs"]
mod tests;