//! Summarization middleware — context window management.
//!
//! Mirrors Python `langchain.agents.middleware.summarization`.

use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use serde_json::Value;

use cognis_core::error::Result;
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::messages::Message;

use super::types::{AgentMiddleware, AgentState};

/// How to specify context size thresholds.
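///
/// # Examples
///
/// Illustrative thresholds (the numeric values are arbitrary):
///
/// ```ignore
/// // Trigger at 75% of the assumed context window.
/// let by_fraction = ContextSize::Fraction(0.75);
/// // Trigger at an absolute (estimated) token count.
/// let by_tokens = ContextSize::Tokens(30_000);
/// // Trigger once the conversation exceeds 20 messages.
/// let by_messages = ContextSize::Messages(20);
/// ```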
#[derive(Debug, Clone)]
pub enum ContextSize {
    /// Fraction of the model's max context (0.0 to 1.0).
    Fraction(f64),
    /// Absolute number of tokens.
    Tokens(usize),
    /// Number of messages.
    Messages(usize),
}

/// Configuration for summarization behavior.
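///
/// # Example
///
/// A sketch of a custom configuration; the threshold values are illustrative:
///
/// ```ignore
/// let config = SummarizationConfig {
///     trigger: ContextSize::Tokens(30_000),
///     keep: ContextSize::Messages(8),
///     ..Default::default()
/// };
/// ```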
#[derive(Debug, Clone)]
pub struct SummarizationConfig {
    /// When to trigger summarization.
    pub trigger: ContextSize,
    /// How much context to preserve after summarization.
    pub keep: ContextSize,
    /// System prompt for the summarization model.
    pub summary_prompt: String,
}

impl Default for SummarizationConfig {
    fn default() -> Self {
        Self {
            trigger: ContextSize::Fraction(0.75),
            keep: ContextSize::Messages(10),
            summary_prompt: "You are a conversation summarizer. Summarize the following conversation \
                concisely but thoroughly. Preserve all key information including: decisions made, \
                tool results, important facts, user preferences, and any pending tasks or open questions. \
                Focus on information that would be needed to continue the conversation effectively."
                .into(),
        }
    }
}

/// Middleware that summarizes conversation history when token limits are approached.
///
/// When the conversation exceeds the configured trigger threshold, the middleware
/// splits messages into "to summarize" and "to keep" portions. If a model is
/// provided, it uses the LLM to generate a proper summary. Otherwise, it falls
/// back to concatenating message text.
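///
/// # Example
///
/// A usage sketch; constructing the chat model itself is elided:
///
/// ```ignore
/// let mw = SummarizationMiddleware::new(SummarizationConfig::default())
///     .with_model(model); // `model: Arc<dyn BaseChatModel>`
/// ```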
#[derive(Default)]
pub struct SummarizationMiddleware {
    /// Summarization behavior (trigger threshold, keep size, prompt).
    pub config: SummarizationConfig,
    /// Optional model to use for generating summaries via LLM.
    pub model: Option<Arc<dyn BaseChatModel>>,
}

impl SummarizationMiddleware {
    pub fn new(config: SummarizationConfig) -> Self {
        Self {
            config,
            model: None,
        }
    }

    /// Set the model to use for LLM-based summarization.
    pub fn with_model(mut self, model: Arc<dyn BaseChatModel>) -> Self {
        self.model = Some(model);
        self
    }

    /// Determine the split point between messages to summarize and messages to keep,
    /// ensuring AI/Tool message pairs are not split apart.
    ///
    /// For example, with 12 messages and `keep_count = 4` the boundary starts at
    /// index 8; if that message is a tool result, the boundary walks back so the
    /// AI message that issued the call is kept alongside it.
    fn compute_keep_boundary(&self, messages: &[Message], keep_count: usize) -> usize {
        if messages.len() <= keep_count {
            return 0;
        }
        let mut boundary = messages.len() - keep_count;
        // Walk backward from the boundary to avoid splitting an AI+Tool pair.
        // If the message at the boundary is a Tool message, include the preceding
        // AI message that triggered it in the "keep" portion.
        while boundary > 0 {
            if let Message::Tool(_) = &messages[boundary] {
                // The tool response belongs with its preceding AI message
                boundary -= 1;
            } else {
                break;
            }
        }
        boundary
    }

    /// Build a concatenated summary from messages (fallback when no model is
    /// available), producing one `role: content` line per message.
    fn fallback_summarize(&self, messages: &[Message]) -> String {
        messages
            .iter()
            .map(|m| {
                let role = m.message_type().as_str();
                let content = m.content().text();
                format!("{}: {}", role, content)
            })
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Check whether summarization should trigger based on the current state.
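    ///
    /// A worked example of the default heuristic, using the ~4 chars/token
    /// estimate and the assumed ~100k-char window from the code below:
    ///
    /// ```text
    /// trigger = Fraction(0.75), assumed window = 100_000 chars
    /// threshold = 100_000 * 0.75 = 75_000 chars (~18_750 tokens)
    /// total_chars > 75_000  =>  summarization fires
    /// ```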
    fn should_trigger(&self, state: &AgentState) -> bool {
        match &self.config.trigger {
            ContextSize::Messages(max) => state.messages.len() > *max,
            ContextSize::Tokens(max) => {
                // Estimate: ~4 chars per token
                let est_tokens: usize = state
                    .messages
                    .iter()
                    .map(|m| m.content().text().len() / 4)
                    .sum();
                est_tokens > *max
            }
            ContextSize::Fraction(frac) => {
                // Without exact model context info, use a heuristic:
                // estimate total chars and compare against a reasonable threshold.
                // Assume ~100k chars as a typical context window (~25k tokens).
                let total_chars: usize = state
                    .messages
                    .iter()
                    .map(|m| m.content().text().len())
                    .sum();
                let threshold = (100_000.0 * frac) as usize;
                total_chars > threshold
            }
        }
    }

    /// Determine the number of messages to keep.
    fn keep_count(&self) -> usize {
        match &self.config.keep {
            ContextSize::Messages(n) => *n,
            // Token- and fraction-based keep sizes are approximated with a
            // fixed message count, since exact token budgets are unknown here.
            ContextSize::Tokens(_) | ContextSize::Fraction(_) => 10,
        }
    }
}

#[async_trait]
impl AgentMiddleware for SummarizationMiddleware {
    fn name(&self) -> &str {
        "SummarizationMiddleware"
    }

    /// Returns `Ok(None)` when no summarization is needed; otherwise returns a
    /// state update replacing `messages` with a summary message plus the kept tail.
    async fn before_model(&self, state: &AgentState) -> Result<Option<HashMap<String, Value>>> {
        if !self.should_trigger(state) {
            return Ok(None);
        }
        let keep_count = self.keep_count();
        // Preserve the last N messages, respecting AI/Tool pairs
        let boundary = self.compute_keep_boundary(&state.messages, keep_count);
        if boundary == 0 {
            return Ok(None);
        }
        let to_summarize = &state.messages[..boundary];
        let to_keep = &state.messages[boundary..];
        // Generate summary using LLM if available, otherwise fallback
        let summary_text = if let Some(model) = &self.model {
            // Build messages for the summarization LLM call
            let mut summarize_messages = vec![Message::system(&self.config.summary_prompt)];
            // Include the conversation to summarize as a human message
            let conversation_text = self.fallback_summarize(to_summarize);
            summarize_messages.push(Message::human(format!(
                "Please summarize the following conversation:\n\n{}",
                conversation_text
            )));
            match model.invoke_messages(&summarize_messages, None).await {
                Ok(ai_msg) => ai_msg.base.content.text(),
                Err(_) => {
                    // Fall back to concatenation if the LLM call fails
                    self.fallback_summarize(to_summarize)
                }
            }
        } else {
            self.fallback_summarize(to_summarize)
        };
        // Replace older messages with a summary system message + kept messages
        let summary_msg = Message::system(format!(
            "[Summary of previous conversation]\n{}",
            summary_text
        ));
        let mut new_messages = vec![summary_msg];
        new_messages.extend_from_slice(to_keep);
        let mut updates = HashMap::new();
        updates.insert("messages".into(), serde_json::to_value(&new_messages)?);
        Ok(Some(updates))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_summarization_default() {
        let mw = SummarizationMiddleware::default();
        assert_eq!(mw.name(), "SummarizationMiddleware");
    }
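
    // A sketch of boundary behavior, assuming `Message::human` accepts
    // `String`-like text (as the `format!` usage in `before_model` suggests).
    #[test]
    fn test_keep_boundary_without_tool_messages() {
        let mw = SummarizationMiddleware::default();
        let messages: Vec<Message> = (0..5)
            .map(|i| Message::human(format!("msg {}", i)))
            .collect();
        // Keeping 2 of 5 messages puts the boundary at index 3.
        assert_eq!(mw.compute_keep_boundary(&messages, 2), 3);
        // If everything fits in the keep window, nothing is summarized.
        assert_eq!(mw.compute_keep_boundary(&messages, 5), 0);
    }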

    #[test]
    fn test_context_size_messages() {
        let size = ContextSize::Messages(20);
        match size {
            ContextSize::Messages(n) => assert_eq!(n, 20),
            _ => panic!("Expected Messages variant"),
        }
    }
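
    // Sketches of the fallback paths. The exact role strings produced by
    // `message_type().as_str()` are not assumed, only that the message
    // content round-trips through `content().text()`.
    #[test]
    fn test_fallback_summarize_includes_content() {
        let mw = SummarizationMiddleware::default();
        let messages = vec![Message::human(String::from("hello world"))];
        let summary = mw.fallback_summarize(&messages);
        assert!(summary.contains("hello world"));
    }

    #[test]
    fn test_keep_count_defaults_for_token_based_keep() {
        let mw = SummarizationMiddleware::new(SummarizationConfig {
            keep: ContextSize::Tokens(4_000),
            ..Default::default()
        });
        // Non-message keep sizes fall back to a fixed count of 10.
        assert_eq!(mw.keep_count(), 10);
    }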
}