use std::sync::Arc;
use async_trait::async_trait;
use serde_json::{json, Value};
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::messages::{HumanMessage, Message};
use crate::middleware::{AgentState, Middleware, Result};
/// Instruction text used as the default `summary_prompt`.
///
/// The rendered transcript of the messages being summarized is appended
/// directly after this text, which is why it ends with a blank line.
const DEFAULT_SUMMARY_PROMPT: &str =
"Summarize the following conversation concisely, preserving all key information, decisions, and context needed to continue the conversation:\n\n";
/// Middleware that replaces older conversation messages with a single
/// model-generated summary message once the history grows too long.
pub struct SummarizationMiddleware {
/// Chat model asked to produce the summary.
model: Arc<dyn BaseChatModel>,
/// Summarization is attempted only when the message count exceeds this value.
pub max_messages: usize,
/// Number of trailing (most recent) messages carried over unchanged.
pub keep_recent: usize,
/// Text placed in front of the rendered transcript sent to the model.
pub summary_prompt: String,
}
impl SummarizationMiddleware {
    /// Creates a middleware that summarizes with `model`.
    ///
    /// Defaults: `max_messages` is 20, `keep_recent` is 5, and the prompt is
    /// [`DEFAULT_SUMMARY_PROMPT`].
    #[must_use]
    pub fn new(model: Arc<dyn BaseChatModel>) -> Self {
        Self {
            model,
            max_messages: 20,
            keep_recent: 5,
            summary_prompt: DEFAULT_SUMMARY_PROMPT.to_string(),
        }
    }

    /// Returns the middleware with the message-count threshold replaced.
    ///
    /// The value is consumed and returned, so the result must be used;
    /// dropping it discards the configuration.
    #[must_use]
    pub fn with_max_messages(self, max_messages: usize) -> Self {
        Self {
            max_messages,
            ..self
        }
    }

    /// Returns the middleware with the number of preserved trailing messages
    /// replaced.
    #[must_use]
    pub fn with_keep_recent(self, keep_recent: usize) -> Self {
        Self {
            keep_recent,
            ..self
        }
    }

    /// Returns the middleware with a custom prompt prefix.
    ///
    /// The rendered transcript is appended directly after `prompt`, so include
    /// any trailing separator (such as a newline) in the prompt itself.
    #[must_use]
    pub fn with_summary_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            summary_prompt: prompt.into(),
            ..self
        }
    }
}
#[async_trait]
impl Middleware for SummarizationMiddleware {
    /// Name identifying this middleware.
    fn name(&self) -> &str {
        "summarization"
    }

    /// Compacts `state["messages"]` before the model is invoked.
    ///
    /// When the array holds more than `max_messages` entries, every message
    /// except the last `keep_recent` is rendered as `"{type}: {content}"`
    /// lines, sent to the summarization model together with `summary_prompt`,
    /// and replaced by one `system` message containing the summary. The kept
    /// messages follow it unchanged.
    ///
    /// The state is left untouched when:
    /// * `messages` is missing or is not an array,
    /// * the array length does not exceed `max_messages`, or
    /// * `keep_recent` covers the whole array, so there is nothing to summarize.
    ///
    /// # Errors
    ///
    /// Returns a `MiddlewareError` when the model call fails or when the model
    /// returns no generations. In both cases the state is not modified, so no
    /// history is lost.
    async fn before_model(&self, state: &mut AgentState) -> Result<()> {
        // Borrow the message array only inside this scope: no reference into
        // `state` is held across the `.await` below, and only the messages that
        // are actually kept get cloned.
        let (conversation_text, recent_messages) = {
            let messages = match state.get("messages").and_then(|v| v.as_array()) {
                Some(arr) => arr,
                None => return Ok(()),
            };
            let msg_count = messages.len();
            if msg_count <= self.max_messages {
                return Ok(());
            }

            let split_point = msg_count.saturating_sub(self.keep_recent);
            if split_point == 0 {
                // Every message falls inside the "keep" window. Summarizing an
                // empty transcript would waste a model call and prepend a
                // meaningless summary, growing the history on every invocation.
                return Ok(());
            }
            let (old_messages, recent) = messages.split_at(split_point);

            let mut transcript = String::new();
            for msg_value in old_messages {
                let role = msg_value
                    .get("type")
                    .and_then(|v| v.as_str())
                    .unwrap_or("unknown");
                transcript.push_str(role);
                transcript.push_str(": ");
                match msg_value.get("content") {
                    Some(Value::String(text)) => transcript.push_str(text),
                    Some(Value::Null) | None => {}
                    // Structured (non-string) content is rendered as compact
                    // JSON so it still reaches the summarizer instead of being
                    // dropped.
                    Some(other) => transcript.push_str(&other.to_string()),
                }
                transcript.push('\n');
            }

            (transcript, recent.to_vec())
        };

        let prompt_text = format!("{}{}", self.summary_prompt, conversation_text);
        let summarization_messages = vec![Message::Human(HumanMessage::new(&prompt_text))];
        let summary_result = self
            .model
            ._generate(&summarization_messages, None)
            .await
            .map_err(|e| {
                crate::agent::DeepAgentError::MiddlewareError(format!(
                    "Summarization model call failed: {e}"
                ))
            })?;

        // Without a generation there is no summary; fail rather than replace
        // the older history with an empty placeholder.
        let summary_text = summary_result
            .generations
            .first()
            .map(|g| g.message.content().text())
            .ok_or_else(|| {
                crate::agent::DeepAgentError::MiddlewareError(
                    "Summarization model returned no generations".to_string(),
                )
            })?;

        let summary_msg = json!({
            "type": "system",
            "content": format!("## Conversation Summary\n{summary_text}")
        });

        let mut new_messages = Vec::with_capacity(recent_messages.len() + 1);
        new_messages.push(summary_msg);
        new_messages.extend(recent_messages);
        state["messages"] = Value::Array(new_messages);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::language_models::fake::FakeMessagesListChatModel;
    use cognis_core::messages::AIMessage;

    /// Builds a fake chat model scripted with a single AI reply.
    fn fake_model(reply: &str) -> Arc<FakeMessagesListChatModel> {
        Arc::new(FakeMessagesListChatModel::new(vec![Message::Ai(
            AIMessage::new(reply),
        )]))
    }

    /// Builds a state holding `n` human messages with content `Message {i}`.
    fn make_state(n: usize) -> Value {
        let messages: Vec<Value> = (0..n)
            .map(|i| {
                json!({
                    "type": "human",
                    "content": format!("Message {i}")
                })
            })
            .collect();
        json!({ "messages": messages })
    }

    #[tokio::test]
    async fn test_messages_under_limit_not_summarized() {
        let mw = SummarizationMiddleware::new(fake_model("This should not be called"))
            .with_max_messages(10)
            .with_keep_recent(3);

        let mut state = make_state(5);
        let original = state.clone();
        mw.before_model(&mut state).await.unwrap();

        let messages = state["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 5);
        // A length check alone would not notice rewritten content.
        assert_eq!(state, original);
    }

    #[tokio::test]
    async fn test_messages_over_limit_get_summarized() {
        let mw = SummarizationMiddleware::new(fake_model(
            "Here is the summary of the conversation.",
        ))
        .with_max_messages(5)
        .with_keep_recent(2);

        let mut state = make_state(8);
        mw.before_model(&mut state).await.unwrap();

        // One summary message followed by the two most recent messages.
        let messages = state["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 3);

        let summary = &messages[0];
        assert_eq!(summary["type"].as_str().unwrap(), "system");
        assert!(summary["content"]
            .as_str()
            .unwrap()
            .contains("summary of the conversation"));

        assert_eq!(messages[1]["content"].as_str().unwrap(), "Message 6");
        assert_eq!(messages[2]["content"].as_str().unwrap(), "Message 7");
    }

    #[tokio::test]
    async fn test_recent_messages_preserved_intact() {
        let mw = SummarizationMiddleware::new(fake_model("Summary content"))
            .with_max_messages(3)
            .with_keep_recent(2);

        let mut state = json!({
            "messages": [
                {"type": "human", "content": "old1"},
                {"type": "human", "content": "old2"},
                {"type": "human", "content": "old3"},
                {"type": "human", "content": "recent1"},
                {"type": "human", "content": "recent2"}
            ]
        });
        mw.before_model(&mut state).await.unwrap();

        let messages = state["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0]["type"].as_str().unwrap(), "system");
        assert_eq!(
            messages[1],
            json!({"type": "human", "content": "recent1"})
        );
        assert_eq!(
            messages[2],
            json!({"type": "human", "content": "recent2"})
        );
    }

    #[tokio::test]
    async fn test_summarization_middleware_name() {
        let mw = SummarizationMiddleware::new(fake_model("x"));
        assert_eq!(mw.name(), "summarization");
    }

    #[tokio::test]
    async fn test_no_messages_key_is_noop() {
        let mw = SummarizationMiddleware::new(fake_model("x"));

        let mut state = json!({});
        mw.before_model(&mut state).await.unwrap();

        assert!(state.get("messages").is_none());
        assert_eq!(state, json!({}));
    }

    #[tokio::test]
    async fn test_exact_limit_not_summarized() {
        let mw = SummarizationMiddleware::new(fake_model("should not be called"))
            .with_max_messages(5)
            .with_keep_recent(2);

        // Exactly `max_messages` entries: the threshold is exclusive.
        let mut state = make_state(5);
        let original = state.clone();
        mw.before_model(&mut state).await.unwrap();

        let messages = state["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 5);
        assert_eq!(state, original);
    }

    #[tokio::test]
    async fn test_custom_summary_prompt() {
        let mw = SummarizationMiddleware::new(fake_model("Custom summary"))
            .with_max_messages(2)
            .with_keep_recent(1)
            .with_summary_prompt("Please provide a brief summary:\n");
        assert_eq!(mw.summary_prompt, "Please provide a brief summary:\n");

        let mut state = make_state(4);
        mw.before_model(&mut state).await.unwrap();

        let messages = state["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["type"].as_str().unwrap(), "system");
        assert!(messages[0]["content"]
            .as_str()
            .unwrap()
            .contains("Custom summary"));
        assert_eq!(messages[1]["content"].as_str().unwrap(), "Message 3");
    }
}