synaptic_middleware/
summarization.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatModel, ChatRequest, Message, SynapticError};
5
6use crate::{AgentMiddleware, ModelRequest};
7
8pub struct SummarizationMiddleware {
15 model: Arc<dyn ChatModel>,
16 max_tokens: usize,
17 token_counter: Box<dyn Fn(&Message) -> usize + Send + Sync>,
18}
19
20impl SummarizationMiddleware {
21 pub fn new(
27 model: Arc<dyn ChatModel>,
28 max_tokens: usize,
29 token_counter: impl Fn(&Message) -> usize + Send + Sync + 'static,
30 ) -> Self {
31 Self {
32 model,
33 max_tokens,
34 token_counter: Box::new(token_counter),
35 }
36 }
37}
38
39#[async_trait]
40impl AgentMiddleware for SummarizationMiddleware {
41 async fn before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
42 let total: usize = request
43 .messages
44 .iter()
45 .map(|m| (self.token_counter)(m))
46 .sum();
47 if total <= self.max_tokens {
48 return Ok(());
49 }
50
51 let half_budget = self.max_tokens / 2;
54 let mut keep_from = request.messages.len();
55 let mut kept_tokens = 0;
56 for (i, msg) in request.messages.iter().enumerate().rev() {
57 let t = (self.token_counter)(msg);
58 if kept_tokens + t > half_budget {
59 break;
60 }
61 kept_tokens += t;
62 keep_from = i;
63 }
64
65 if keep_from == 0 {
66 return Ok(());
68 }
69
70 let to_summarize: Vec<_> = request.messages[..keep_from].to_vec();
71
72 let summary_prompt = Message::human(
74 "Summarize the following conversation concisely, preserving key facts and context:\n\n"
75 .to_string()
76 + &to_summarize
77 .iter()
78 .map(|m| format!("{}: {}", m.role(), m.content()))
79 .collect::<Vec<_>>()
80 .join("\n"),
81 );
82
83 let summary_req = ChatRequest::new(vec![
84 Message::system("You are a conversation summarizer. Output a brief summary."),
85 summary_prompt,
86 ]);
87
88 let summary_resp = self.model.chat(summary_req).await?;
89 let summary_text = summary_resp.message.content().to_string();
90
91 let mut new_messages = vec![Message::system(format!(
93 "[Previous conversation summary]: {summary_text}"
94 ))];
95 new_messages.extend_from_slice(&request.messages[keep_from..]);
96 request.messages = new_messages;
97
98 Ok(())
99 }
100}