use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use crate::memory::chunks::approx_token_count;
use crate::memory::tree::store::TreeKind;
/// One source item handed to a [`Summariser`].
///
/// Of these fields, [`fallback_summary`] reads only `content` and `score`;
/// the others are not consumed by the code in this section.
#[derive(Clone, Debug)]
pub struct SummaryInput {
/// Identifier of the item.
pub id: String,
/// Text of the item. `fallback_summary` trims it and skips it when blank.
pub content: String,
/// Token count associated with `content`.
/// NOTE(review): presumably an `approx_token_count` estimate — confirm with the producer.
pub token_count: u32,
/// Entity labels attached to the item (semantics defined by the producer).
pub entities: Vec<String>,
/// Topic labels attached to the item (semantics defined by the producer).
pub topics: Vec<String>,
/// Start of the UTC time range the item is associated with.
pub time_range_start: DateTime<Utc>,
/// End of the UTC time range the item is associated with.
/// NOTE(review): inclusive vs. exclusive is not established here — confirm.
pub time_range_end: DateTime<Utc>,
/// Ranking score; `fallback_summary` emits higher-scored items first.
pub score: f32,
}
/// Per-call parameters passed to [`Summariser::summarise`] alongside the inputs.
#[derive(Clone, Debug)]
pub struct SummaryContext<'a> {
/// Borrowed identifier of the tree this summary is for.
pub tree_id: &'a str,
/// Kind of that tree, as declared by `TreeKind` in `memory::tree::store`.
pub tree_kind: TreeKind,
/// Level the summary targets.
/// NOTE(review): presumably the tree level the output will be stored at — confirm.
pub target_level: u32,
/// Token budget for the output. `ConcatSummariser` forwards it to
/// `fallback_summary`, which clamps the text with `clamp_to_budget`.
pub token_budget: u32,
}
/// Result produced by a [`Summariser`].
///
/// `Default` yields empty text, a zero token count and empty label lists.
#[derive(Clone, Debug, Default)]
pub struct SummaryOutput {
/// Summary text.
pub content: String,
/// Token count reported for `content`; on the `fallback_summary` path this
/// is the value returned by `clamp_to_budget`.
pub token_count: u32,
/// Entity labels for the summary; `fallback_summary` always leaves this empty.
pub entities: Vec<String>,
/// Topic labels for the summary; `fallback_summary` always leaves this empty.
pub topics: Vec<String>,
}
/// Strategy for condensing a batch of [`SummaryInput`]s into one [`SummaryOutput`].
///
/// Implementors must be `Send + Sync`. The `#[async_trait]` attribute lets the
/// trait declare an `async fn`.
#[async_trait]
pub trait Summariser: Send + Sync {
/// Short label for the implementation; the default is `"summariser"`.
fn name(&self) -> &str {
"summariser"
}
/// Summarises `inputs` using the parameters carried by `ctx`.
///
/// # Errors
///
/// Returns an error when the implementation fails to produce a summary;
/// the failure conditions are implementation-specific.
async fn summarise(
&self,
inputs: &[SummaryInput],
ctx: &SummaryContext<'_>,
) -> Result<SummaryOutput>;
}
/// Stateless summariser that concatenates its inputs instead of calling a model.
///
/// It is a unit struct, so every value is interchangeable and copying is free.
#[derive(Debug, Default, Clone, Copy)]
pub struct ConcatSummariser;

impl ConcatSummariser {
    /// Creates the summariser; equivalent to `ConcatSummariser::default()`.
    pub fn new() -> Self {
        ConcatSummariser
    }
}
/// [`Summariser`] implementation that delegates all work to [`fallback_summary`].
#[async_trait]
impl Summariser for ConcatSummariser {
    /// Identifies this implementation as `"concat"`.
    fn name(&self) -> &str {
        "concat"
    }

    /// Concatenates `inputs` via [`fallback_summary`], bounded by
    /// `ctx.token_budget`.
    ///
    /// Only `token_budget` is read from `ctx`. This implementation always
    /// returns `Ok`.
    async fn summarise(
        &self,
        inputs: &[SummaryInput],
        ctx: &SummaryContext<'_>,
    ) -> Result<SummaryOutput> {
        let budget = ctx.token_budget;
        let summary = fallback_summary(inputs, budget);
        Ok(summary)
    }
}
/// Builds a model-free summary by concatenating the input contents.
///
/// Steps:
/// 1. Order the inputs by `score`, highest first. The sort is stable, so
///    inputs with equal scores keep their original relative order. Inputs
///    whose score is NaN are placed after every input with a real score.
/// 2. Trim each input's `content`, drop the blank ones, and prefix each
///    remaining piece with `"— "`.
/// 3. Join the pieces with a blank line and pass the result through
///    [`clamp_to_budget`] with `budget`.
///
/// The returned `entities` and `topics` are always empty.
pub fn fallback_summary(inputs: &[SummaryInput], budget: u32) -> SummaryOutput {
    const PROVENANCE_PREFIX: &str = "— ";

    let mut order: Vec<&SummaryInput> = inputs.iter().collect();
    // `partial_cmp(..).unwrap_or(Equal)` alone is not a total order once a NaN
    // score is present (NaN would equal everything while finite scores still
    // differ), which leaves the sort order unspecified and may make `sort_by`
    // panic. Rank NaN explicitly so the comparator stays consistent.
    order.sort_by(|a, b| match (a.score.is_nan(), b.score.is_nan()) {
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
        // Neither side is NaN here, so `partial_cmp` always returns `Some`;
        // the fallback only avoids an unwrap.
        (false, false) => b
            .score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal),
    });

    let joined = order
        .iter()
        .map(|inp| inp.content.trim())
        .filter(|trimmed| !trimmed.is_empty())
        .map(|trimmed| format!("{PROVENANCE_PREFIX}{trimmed}"))
        .collect::<Vec<_>>()
        .join("\n\n");

    let (content, token_count) = clamp_to_budget(&joined, budget);
    SummaryOutput {
        content,
        token_count,
        entities: Vec::new(),
        topics: Vec::new(),
    }
}
/// Shortens `text` when its estimated token count exceeds `budget`.
///
/// Returns the (possibly shortened) text together with the token estimate
/// that `approx_token_count` reports for it.
///
/// * When the estimate for `text` is within `budget`, the text is returned
///   unchanged.
/// * Otherwise only the first `budget * 4` characters (Unicode scalar values)
///   are kept and the estimate is recomputed on the shortened text.
///
/// NOTE(review): the factor of 4 presumably mirrors the chars-per-token ratio
/// used by `approx_token_count` — confirm against that helper.
pub fn clamp_to_budget(text: &str, budget: u32) -> (String, u32) {
    const CHARS_PER_TOKEN: usize = 4;

    let full_estimate = approx_token_count(text);
    if full_estimate > budget {
        let max_chars = (budget as usize).saturating_mul(CHARS_PER_TOKEN);
        // Byte offset just past the last character to keep; it always lands on
        // a char boundary, so the slice below cannot panic.
        let cut = text
            .char_indices()
            .nth(max_chars)
            .map_or(text.len(), |(offset, _)| offset);
        let shortened = text[..cut].to_owned();
        let estimate = approx_token_count(&shortened);
        (shortened, estimate)
    } else {
        (text.to_owned(), full_estimate)
    }
}
// Unit tests are compiled only for test builds and are loaded from the
// file `summarise_tests.rs` named by the `#[path]` attribute.
#[cfg(test)]
#[path = "summarise_tests.rs"]
mod tests;