use std::collections::BTreeSet;
use std::sync::Arc;
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use crate::memory::chunks::with_connection;
use crate::memory::config::MemoryConfig;
use crate::memory::score::extract::EntityExtractor;
use crate::memory::score::resolver::canonicalise;
use crate::memory::score::store::index_summary_entity_ids_tx;
use crate::memory::tree::hydrate::hydrate_inputs;
use crate::memory::tree::registry::new_summary_id;
use crate::memory::tree::store::{self, Buffer, SummaryNode, Tree};
use crate::memory::tree::summarise::{fallback_summary, Summariser, SummaryContext, SummaryInput};
/// Upper bound on the number of levels `cascade_all_from` may seal in one
/// call; each loop iteration seals at most one level, so this caps the cascade.
const MAX_CASCADE_DEPTH: u32 = 32;
/// How a newly sealed summary node obtains its entity and topic labels.
#[derive(Clone)]
pub enum LabelStrategy {
/// Run the extractor over the generated summary text; entities are the
/// canonicalised ids and topics are the extracted topic labels.
ExtractFromContent(Arc<dyn EntityExtractor>),
/// Use the union of the entities and topics carried by the summarised inputs.
UnionFromChildren,
/// Attach no entities and no topics.
Empty,
}
impl std::fmt::Debug for LabelStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ExtractFromContent(ex) => write!(f, "ExtractFromContent({})", ex.name()),
Self::UnionFromChildren => f.write_str("UnionFromChildren"),
Self::Empty => f.write_str("Empty"),
}
}
}
/// Compute the `(entities, topics)` labels for a summary node being sealed.
///
/// Both returned vectors are deduplicated and in ascending order:
/// - `ExtractFromContent`: runs the extractor over `summary_content`; entities
///   are the canonical ids produced by `canonicalise`, topics are the extracted
///   topic labels.
/// - `UnionFromChildren`: union of the entities/topics already on `inputs`.
/// - `Empty`: two empty vectors.
///
/// # Errors
/// Returns the extractor's error, with context `"seal-time extractor failed"`,
/// when the `ExtractFromContent` extractor fails.
async fn resolve_labels(
    strategy: &LabelStrategy,
    inputs: &[SummaryInput],
    summary_content: &str,
) -> Result<(Vec<String>, Vec<String>)> {
    match strategy {
        LabelStrategy::ExtractFromContent(extractor) => {
            let extracted = extractor
                .extract(summary_content)
                .await
                .context("seal-time extractor failed")?;
            // A BTreeSet both dedupes and orders its contents, so converting it
            // to a Vec already yields sorted output; no extra sort is needed.
            let entities: BTreeSet<String> = canonicalise(&extracted)
                .into_iter()
                .map(|c| c.canonical_id)
                .collect();
            let topics: BTreeSet<String> =
                extracted.topics.into_iter().map(|t| t.label).collect();
            Ok((entities.into_iter().collect(), topics.into_iter().collect()))
        }
        LabelStrategy::UnionFromChildren => {
            let mut entities: BTreeSet<String> = BTreeSet::new();
            let mut topics: BTreeSet<String> = BTreeSet::new();
            for inp in inputs {
                entities.extend(inp.entities.iter().cloned());
                topics.extend(inp.topics.iter().cloned());
            }
            Ok((entities.into_iter().collect(), topics.into_iter().collect()))
        }
        LabelStrategy::Empty => Ok((Vec::new(), Vec::new())),
    }
}
/// A leaf chunk to be appended to a tree's level-0 buffer.
#[derive(Clone, Debug)]
pub struct LeafRef {
/// Chunk id; stored as an item in the level-0 buffer and, once sealed,
/// backlinked from `mem_tree_chunks.parent_summary_id` to its summary.
pub chunk_id: String,
/// Token count added to the level-0 buffer's running `token_sum`.
pub token_count: u32,
/// Timestamp folded into the buffer's `oldest_at` (kept as the minimum).
pub timestamp: DateTime<Utc>,
/// Chunk text. Not read by the functions in this file.
pub content: String,
/// Entity labels for the chunk. Not read by the functions in this file.
pub entities: Vec<String>,
/// Topic labels for the chunk. Not read by the functions in this file.
pub topics: Vec<String>,
/// Chunk score. Not read by the functions in this file.
pub score: f32,
}
/// Append `leaf` to the tree's level-0 buffer, then seal every level whose
/// buffer has reached its threshold, cascading upward.
///
/// Returns the ids of the summaries sealed by the cascade (possibly empty).
///
/// # Errors
/// Propagates failures from buffering or from any seal in the cascade.
pub async fn append_leaf(
    config: &MemoryConfig,
    tree: &Tree,
    leaf: &LeafRef,
    summariser: &dyn Summariser,
    strategy: &LabelStrategy,
) -> Result<Vec<String>> {
    let leaf_tokens = i64::from(leaf.token_count);
    append_to_buffer(config, &tree.id, 0, &leaf.chunk_id, leaf_tokens, leaf.timestamp)?;
    cascade_all_from(config, tree, 0, None, summariser, strategy).await
}
/// Append `leaf` to the tree's level-0 buffer without sealing anything.
///
/// Returns `true` when the level-0 buffer, after the append, meets the seal
/// threshold, so the caller can schedule a cascade later.
///
/// # Errors
/// Propagates failures from buffering or from reading the buffer back.
pub fn append_leaf_deferred(config: &MemoryConfig, tree: &Tree, leaf: &LeafRef) -> Result<bool> {
    let leaf_level: u32 = 0;
    append_to_buffer(
        config,
        &tree.id,
        leaf_level,
        &leaf.chunk_id,
        i64::from(leaf.token_count),
        leaf.timestamp,
    )?;
    let level_zero = store::get_buffer(config, &tree.id, leaf_level)?;
    Ok(should_seal(config, &level_zero))
}
/// Add `item_id` to the buffer at (`tree_id`, `level`) inside one transaction.
///
/// The append is idempotent: if `item_id` is already buffered, nothing is
/// written. Otherwise the id is pushed, `token_delta` is added to `token_sum`
/// (saturating), and `oldest_at` becomes the earlier of its current value and
/// `item_ts`.
///
/// # Errors
/// Propagates connection, read, upsert, or commit failures.
pub fn append_to_buffer(
    config: &MemoryConfig,
    tree_id: &str,
    level: u32,
    item_id: &str,
    token_delta: i64,
    item_ts: DateTime<Utc>,
) -> Result<()> {
    with_connection(config, |conn| {
        let tx = conn.unchecked_transaction()?;
        let mut pending = store::get_buffer_conn(&tx, tree_id, level)?;
        let already_buffered = pending.item_ids.iter().any(|queued| queued == item_id);
        if already_buffered {
            // Nothing has been written, so the uncommitted transaction is just dropped.
            return Ok(());
        }
        pending.item_ids.push(item_id.to_string());
        pending.token_sum = pending.token_sum.saturating_add(token_delta);
        pending.oldest_at = Some(pending.oldest_at.map_or(item_ts, |prev| prev.min(item_ts)));
        store::upsert_buffer_tx(&tx, &pending)?;
        tx.commit()?;
        Ok(())
    })
}
/// Seal buffers level by level, starting at `start_level` and moving upward
/// until a level does not need sealing or `MAX_CASCADE_DEPTH` levels are done.
///
/// When `force_now` is `Some`, the first level (`start_level`) is sealed even
/// if it is below its threshold, provided it is non-empty. Only the presence
/// of `force_now` matters here; its timestamp value is not read. Higher levels
/// always use `should_seal`.
///
/// Returns the ids of all sealed summaries, lowest level first.
///
/// # Errors
/// Propagates buffer-read or seal failures; earlier seals stay committed.
pub async fn cascade_all_from(
    config: &MemoryConfig,
    tree: &Tree,
    start_level: u32,
    force_now: Option<DateTime<Utc>>,
    summariser: &dyn Summariser,
    strategy: &LabelStrategy,
) -> Result<Vec<String>> {
    let mut sealed_ids = Vec::new();
    for depth in 0..MAX_CASCADE_DEPTH {
        let level = start_level + depth;
        let pending = store::get_buffer(config, &tree.id, level)?;
        let forced = depth == 0 && force_now.is_some();
        if pending.is_empty() || !(forced || should_seal(config, &pending)) {
            break;
        }
        let new_id = seal_one_level(config, tree, &pending, summariser, strategy).await?;
        sealed_ids.push(new_id);
    }
    Ok(sealed_ids)
}
/// Decide whether `buf` has reached its seal threshold.
///
/// An empty buffer never seals. A level-0 (leaf) buffer seals once its
/// `token_sum` reaches `config.tree.input_token_budget`. A higher-level buffer
/// seals once it holds at least `config.tree.summary_fanout` items.
pub(crate) fn should_seal(config: &MemoryConfig, buf: &Buffer) -> bool {
    match buf.level {
        _ if buf.is_empty() => false,
        0 => buf.token_sum >= config.tree.input_token_budget as i64,
        _ => (buf.item_ids.len() as u32) >= config.tree.summary_fanout,
    }
}
pub(crate) async fn seal_one_level(
config: &MemoryConfig,
tree: &Tree,
buf: &Buffer,
summariser: &dyn Summariser,
strategy: &LabelStrategy,
) -> Result<String> {
let level = buf.level;
let target_level = level + 1;
let inputs = hydrate_inputs(config, level, &buf.item_ids)?;
if inputs.is_empty() {
anyhow::bail!(
"refused to seal empty buffer tree_id={} level={}",
tree.id,
level
);
}
let time_range_start = inputs
.iter()
.map(|i| i.time_range_start)
.min()
.unwrap_or_else(Utc::now);
let time_range_end = inputs
.iter()
.map(|i| i.time_range_end)
.max()
.unwrap_or_else(Utc::now);
let score = inputs
.iter()
.map(|i| i.score)
.fold(f32::NEG_INFINITY, f32::max)
.max(0.0);
let budget = config.tree.output_token_budget;
let ctx = SummaryContext {
tree_id: &tree.id,
tree_kind: tree.kind,
target_level,
token_budget: budget,
};
let output = match summariser.summarise(&inputs, &ctx).await {
Ok(o) if !o.content.trim().is_empty() => o,
_ => fallback_summary(&inputs, budget),
};
let (node_entities, node_topics) = resolve_labels(strategy, &inputs, &output.content).await?;
let now = Utc::now();
let summary_id = new_summary_id(target_level);
let node = SummaryNode {
id: summary_id.clone(),
tree_id: tree.id.clone(),
tree_kind: tree.kind,
level: target_level,
parent_id: None,
child_ids: buf.item_ids.clone(),
content: output.content,
token_count: output.token_count,
entities: node_entities,
topics: node_topics,
time_range_start,
time_range_end,
score,
sealed_at: now,
deleted: false,
embedding: None,
doc_id: None,
version_ms: None,
};
let signature = crate::memory::chunks::tree_active_signature(config);
let tree_id = tree.id.clone();
let summary_id_for_tx = summary_id.clone();
with_connection(config, move |conn| {
let tx = conn.unchecked_transaction()?;
let current_max: u32 = tx
.query_row(
"SELECT max_level FROM mem_tree_trees WHERE id = ?1",
rusqlite::params![&tree_id],
|r| r.get::<_, i64>(0),
)
.map(|n| n.max(0) as u32)
.context("Failed to read current max_level for tree")?;
store::insert_summary_tx(&tx, &node, &signature)?;
index_summary_entity_ids_tx(
&tx,
&node.entities,
&node.id,
node.score,
now.timestamp_millis(),
Some(&tree_id),
)?;
for child_id in &node.child_ids {
if level == 0 {
tx.execute(
"UPDATE mem_tree_chunks SET parent_summary_id = ?1
WHERE id = ?2 AND parent_summary_id IS NULL",
rusqlite::params![&summary_id_for_tx, child_id],
)
.context("Failed to backlink chunk to parent summary")?;
} else {
tx.execute(
"UPDATE mem_tree_summaries SET parent_id = ?1
WHERE id = ?2 AND parent_id IS NULL",
rusqlite::params![&summary_id_for_tx, child_id],
)
.context("Failed to backlink summary to parent summary")?;
}
}
store::clear_buffer_tx(&tx, &tree_id, level)?;
let mut parent = store::get_buffer_conn(&tx, &tree_id, target_level)?;
parent.item_ids.push(summary_id_for_tx.clone());
parent.token_sum = parent.token_sum.saturating_add(node.token_count as i64);
parent.oldest_at = match parent.oldest_at {
Some(existing) => Some(existing.min(time_range_start)),
None => Some(time_range_start),
};
store::upsert_buffer_tx(&tx, &parent)?;
if target_level > current_max {
store::update_tree_after_seal_tx(&tx, &tree_id, &summary_id_for_tx, target_level, now)?;
} else {
store::refresh_last_sealed_tx(&tx, &tree_id, now)?;
}
tx.commit()?;
Ok(())
})?;
Ok(summary_id)
}
// Test modules, compiled only under `cfg(test)` and loaded from sibling files
// via `#[path]`.
#[cfg(test)]
#[path = "bucket_seal_label_tests.rs"]
mod label_tests;
#[cfg(test)]
#[path = "bucket_seal_tests.rs"]
mod tests;