use std::sync::Arc;
use async_trait::async_trait;
use entelix_core::ir::Message;
use entelix_core::{ExecutionContext, Result};
use crate::buffer::BufferMemory;
use crate::summary::SummaryMemory;
/// Condenses a batch of conversation messages into one summary string.
///
/// The `Send + Sync + 'static` bounds allow an implementation to be stored
/// behind `Arc<dyn Summarizer>`, which is how `ConsolidatingBufferMemory`
/// holds it.
#[async_trait]
pub trait Summarizer: Send + Sync + 'static {
/// Produces the summary text for `messages`.
///
/// `messages` is passed by value, so the implementation owns the batch.
/// `ctx` is the execution context of the call that triggered the summary.
///
/// # Errors
///
/// Returns whatever error the implementation reports; callers in this file
/// propagate it unchanged.
async fn summarize(&self, messages: Vec<Message>, ctx: &ExecutionContext) -> Result<String>;
}
/// A message buffer that folds its contents into a summary when the
/// buffer reports that consolidation is due.
///
/// All three collaborators are held behind `Arc`, so the same buffer,
/// summary store, or summarizer can also be held by other owners.
pub struct ConsolidatingBufferMemory {
// Recent messages; also answers whether consolidation is due.
buffer: Arc<BufferMemory>,
// Destination for the text produced by the summarizer.
summary: Arc<SummaryMemory>,
// Turns the buffered messages into summary text.
summarizer: Arc<dyn Summarizer>,
}
impl ConsolidatingBufferMemory {
    /// Assembles a consolidating memory from its three shared collaborators.
    ///
    /// The handles are stored as given; neither memory is read or written
    /// during construction.
    pub fn new(
        buffer: Arc<BufferMemory>,
        summary: Arc<SummaryMemory>,
        summarizer: Arc<dyn Summarizer>,
    ) -> Self {
        Self { buffer, summary, summarizer }
    }

    /// Borrows the shared handle to the underlying message buffer.
    pub const fn buffer(&self) -> &Arc<BufferMemory> {
        &self.buffer
    }

    /// Borrows the shared handle to the underlying summary memory.
    pub const fn summary(&self) -> &Arc<SummaryMemory> {
        &self.summary
    }

    /// Appends `message` to the buffer and consolidates if the buffer says so.
    ///
    /// The message is always written first. Afterwards the buffer is asked
    /// whether consolidation is due; if it is, the buffered messages are
    /// summarised and moved into the summary memory.
    ///
    /// # Errors
    ///
    /// Propagates the first error from the buffer, the summarizer, or the
    /// summary memory. The steps run strictly in order, so a failure while
    /// summarising or storing the summary returns before the buffer is
    /// cleared; the message appended by this call has already been written.
    pub async fn append(&self, ctx: &ExecutionContext, message: Message) -> Result<()> {
        self.buffer.append(ctx, message).await?;
        if self.buffer.should_consolidate(ctx).await? {
            self.fold_buffer_into_summary(ctx).await?;
        }
        Ok(())
    }

    /// Summarises everything currently buffered, stores the result, then
    /// empties the buffer and records the consolidation on it.
    async fn fold_buffer_into_summary(&self, ctx: &ExecutionContext) -> Result<()> {
        let pending = self.buffer.messages(ctx).await?;
        let digest = self.summarizer.summarize(pending, ctx).await?;
        self.summary.append(ctx, &digest).await?;
        // Clear only after the summary has been stored, so a failure above
        // leaves the buffered messages in place.
        self.buffer.clear(ctx).await?;
        self.buffer.mark_consolidated_now();
        Ok(())
    }

    /// Returns the messages currently held in the buffer.
    ///
    /// # Errors
    ///
    /// Propagates any error from the buffer.
    pub async fn messages(&self, ctx: &ExecutionContext) -> Result<Vec<Message>> {
        let buffered = self.buffer.messages(ctx).await?;
        Ok(buffered)
    }

    /// Returns the stored summary, or `None` when the summary memory has none.
    ///
    /// # Errors
    ///
    /// Propagates any error from the summary memory.
    pub async fn current_summary(&self, ctx: &ExecutionContext) -> Result<Option<String>> {
        let stored = self.summary.get(ctx).await?;
        Ok(stored)
    }

    /// Clears the buffer first and the summary memory second.
    ///
    /// # Errors
    ///
    /// Propagates the first failure; if clearing the buffer fails, the
    /// summary memory is left untouched.
    pub async fn clear(&self, ctx: &ExecutionContext) -> Result<()> {
        self.buffer.clear(ctx).await?;
        self.summary.clear(ctx).await?;
        Ok(())
    }
}
#[cfg(test)]
// Tests unwrap freely: panicking is the intended failure mode for an unexpected `Err`.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::consolidation::{ConsolidationPolicy, OnMessageCount};
    use crate::namespace::Namespace;
    use crate::store::InMemoryStore;
    use entelix_core::TenantId;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// [`Summarizer`] double that returns a canned reply and counts its calls.
    struct StubSummarizer {
        /// Incremented once per `summarize` call.
        calls: Arc<AtomicUsize>,
        /// Outcome handed back (re-created) on every call.
        reply: Result<String>,
    }

    impl StubSummarizer {
        /// Stub that always succeeds with `reply`; also returns the shared
        /// call counter so the test can inspect it after the stub is moved.
        fn ok(reply: &str) -> (Self, Arc<AtomicUsize>) {
            let calls = Arc::new(AtomicUsize::new(0));
            let stub = Self {
                calls: Arc::clone(&calls),
                reply: Ok(reply.to_owned()),
            };
            (stub, calls)
        }

        /// Stub that always fails with a config error carrying `msg`.
        fn err(msg: &str) -> Self {
            Self {
                calls: Arc::new(AtomicUsize::new(0)),
                reply: Err(entelix_core::Error::config(msg.to_owned())),
            }
        }
    }

    #[async_trait]
    impl Summarizer for StubSummarizer {
        /// Ignores its inputs, bumps the call counter, and replays the canned reply.
        async fn summarize(
            &self,
            _messages: Vec<Message>,
            _ctx: &ExecutionContext,
        ) -> Result<String> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            match &self.reply {
                Ok(text) => Ok(text.clone()),
                Err(e) => Err(clone_error(e)),
            }
        }
    }

    /// Rebuilds an error from a borrowed one so the stub can return it repeatedly.
    ///
    /// Every input comes back as a config error built from the original's
    /// text; variants other than `Config` are therefore not preserved.
    /// NOTE(review): presumably needed because `entelix_core::Error` is not
    /// `Clone` — confirm.
    fn clone_error(e: &entelix_core::Error) -> entelix_core::Error {
        match e {
            entelix_core::Error::Config(c) => entelix_core::Error::config(c.to_string()),
            other => entelix_core::Error::config(other.to_string()),
        }
    }

    /// In-memory buffer under tenant `t` / scope `conv` with the given policy.
    fn make_buffer(max_turns: usize, policy: Arc<dyn ConsolidationPolicy>) -> Arc<BufferMemory> {
        let store = Arc::new(InMemoryStore::<Vec<Message>>::new());
        let namespace = Namespace::new(TenantId::new("t")).with_scope("conv");
        Arc::new(BufferMemory::new(store, namespace, max_turns).with_consolidation_policy(policy))
    }

    /// In-memory summary store under the same tenant `t` / scope `conv`.
    fn make_summary() -> Arc<SummaryMemory> {
        let store = Arc::new(InMemoryStore::<String>::new());
        let namespace = Namespace::new(TenantId::new("t")).with_scope("conv");
        Arc::new(SummaryMemory::new(store, namespace))
    }

    /// Three appends with a policy built from 5 must never reach the summarizer.
    #[tokio::test]
    async fn append_does_not_consolidate_below_threshold() {
        let buf = make_buffer(10, Arc::new(OnMessageCount::new(5)));
        let (stub, calls) = StubSummarizer::ok("summary");
        let mem = ConsolidatingBufferMemory::new(buf, make_summary(), Arc::new(stub));
        let ctx = ExecutionContext::new();

        for i in 0..3 {
            mem.append(&ctx, Message::user(format!("m{i}")))
                .await
                .unwrap();
        }

        assert_eq!(calls.load(Ordering::SeqCst), 0);
        assert_eq!(mem.messages(&ctx).await.unwrap().len(), 3);
        assert!(mem.current_summary(&ctx).await.unwrap().is_none());
    }

    /// Three appends with a policy built from 3 must summarise once, empty
    /// the buffer, store the summary, and stamp the buffer.
    #[tokio::test]
    async fn append_consolidates_when_threshold_reached() {
        let buf = make_buffer(10, Arc::new(OnMessageCount::new(3)));
        let (stub, calls) = StubSummarizer::ok("compressed");
        // `buf` is kept so the consolidation timestamp can be inspected below.
        let mem =
            ConsolidatingBufferMemory::new(Arc::clone(&buf), make_summary(), Arc::new(stub));
        let ctx = ExecutionContext::new();

        for i in 0..3 {
            mem.append(&ctx, Message::user(format!("m{i}")))
                .await
                .unwrap();
        }

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(mem.messages(&ctx).await.unwrap().len(), 0);
        let summary = mem.current_summary(&ctx).await.unwrap().unwrap();
        assert_eq!(summary, "compressed");
        assert!(buf.last_consolidated_at().is_some());
    }

    /// A failing summarizer must surface its error and leave the buffer,
    /// the summary, and the consolidation timestamp untouched.
    #[tokio::test]
    async fn summariser_failure_preserves_buffer() {
        let buf = make_buffer(10, Arc::new(OnMessageCount::new(2)));
        let stub = StubSummarizer::err("summariser down");
        // `buf` is kept so the consolidation timestamp can be inspected below.
        let mem =
            ConsolidatingBufferMemory::new(Arc::clone(&buf), make_summary(), Arc::new(stub));
        let ctx = ExecutionContext::new();

        mem.append(&ctx, Message::user("a")).await.unwrap();
        let err = mem.append(&ctx, Message::user("b")).await.unwrap_err();

        assert!(matches!(err, entelix_core::Error::Config(_)));
        assert_eq!(mem.messages(&ctx).await.unwrap().len(), 2);
        assert!(mem.current_summary(&ctx).await.unwrap().is_none());
        assert!(buf.last_consolidated_at().is_none());
    }
}