// ds_api/conversation/summarizer.rs
//! Conversation summarizer trait and built-in implementations.
//!
//! The [`AUTO_SUMMARY_TAG`][crate::raw::request::message::AUTO_SUMMARY_TAG] constant
//! in [`Message`] defines the single source of
//! truth for identifying auto-generated summary messages.
//!
//! # Trait
//!
//! [`Summarizer`] is an async trait with two methods:
//! - [`should_summarize`][Summarizer::should_summarize] — synchronous check on the current history.
//! - [`summarize`][Summarizer::summarize] — async, may perform an API call; mutates history in-place.
//!
//! # Built-in implementations
//!
//! | Type | Strategy |
//! |---|---|
//! | [`LlmSummarizer`] | Calls DeepSeek to produce a semantic summary; **default** for `DeepseekAgent`. |
//! | [`SlidingWindowSummarizer`] | Keeps the last N messages and silently drops the rest; no API call. |

20use std::pin::Pin;
21
22use futures::Future;
23
24use crate::api::{ApiClient, ApiRequest};
25use crate::error::ApiError;
26use crate::raw::request::message::{Message, Role};
27
28// ── Trait ────────────────────────────────────────────────────────────────────
29
/// Decides when and how to compress conversation history.
///
/// [`should_summarize`][Summarizer::should_summarize] receives an immutable
/// slice; [`summarize`][Summarizer::summarize] receives a mutable `Vec` and
/// edits it in place.  Implementors are free to count tokens, count turns,
/// check wall-clock time, or use any other heuristic.
///
/// The trait is object-safe because `summarize` returns a boxed future
/// (`Pin<Box<dyn Future…>>`); you can store it as `Box<dyn Summarizer>`
/// without `async_trait`.
///
/// # Implementing a custom summarizer
///
/// ```no_run
/// use std::pin::Pin;
/// use ds_api::conversation::Summarizer;
/// use ds_api::error::ApiError;
/// use ds_api::raw::request::message::Message;
///
/// /// Drops all history older than `max_turns` turns.  No API call needed.
/// struct TurnLimitSummarizer { max_turns: usize }
///
/// impl Summarizer for TurnLimitSummarizer {
///     fn should_summarize(&self, history: &[Message]) -> bool {
///         history.len() > self.max_turns
///     }
///
///     fn summarize<'a>(
///         &'a self,
///         history: &'a mut Vec<Message>,
///     ) -> Pin<Box<dyn std::future::Future<Output = Result<(), ApiError>> + Send + 'a>> {
///         Box::pin(async move {
///             if history.len() > self.max_turns {
///                 let drop_count = history.len() - self.max_turns;
///                 history.drain(0..drop_count);
///             }
///             Ok(())
///         })
///     }
/// }
///
/// // Use it with an agent:
/// use ds_api::DeepseekAgent;
/// let agent = DeepseekAgent::new("sk-...")
///     .with_summarizer(TurnLimitSummarizer { max_turns: 20 });
/// ```
pub trait Summarizer: Send + Sync {
    /// Return `true` if the history should be summarized before the next API turn.
    ///
    /// This is called synchronously on every user-input push; keep it cheap.
    fn should_summarize(&self, history: &[Message]) -> bool;

    /// Compress `history` in-place, returning an error only for unrecoverable failures.
    ///
    /// On success the history must be shorter (or at most the same length) than before.
    /// Implementations must preserve permanent system prompts — messages whose role
    /// is [`Role::System`] but which are *not* tagged as `[auto-summary]` — because
    /// those are user-provided and must never be removed.  Auto-summary placeholders
    /// (role `System`, `name == Some("[auto-summary]")`) are replaceable.
    fn summarize<'a>(
        &'a self,
        history: &'a mut Vec<Message>,
    ) -> Pin<Box<dyn Future<Output = Result<(), ApiError>> + Send + 'a>>;
}
91
92// ── Helpers ───────────────────────────────────────────────────────────────────
93
94/// Estimate the token count of a slice of messages using a fast character heuristic.
95///
96/// ASCII characters count as 1 char ≈ 0.25 tokens; CJK / multibyte characters are
97/// counted as 4 chars ≈ 1 token.  System messages whose `name` is `[auto-summary]`
98/// are included in the estimate; other system messages (user-provided prompts) are
99/// excluded because they are permanent and we cannot remove them anyway.
100pub(crate) fn estimate_tokens(history: &[Message]) -> usize {
101    history
102        .iter()
103        .filter(|m| {
104            // Always exclude permanent system prompts from the token estimate;
105            // we can't remove them so counting them would trigger summarization
106            // that can never actually free those tokens.
107            if matches!(m.role, Role::System) {
108                // auto-summary placeholders are replaceable → count them
109                m.is_auto_summary()
110            } else {
111                true
112            }
113        })
114        .filter_map(|m| m.content.as_deref())
115        .map(|s| {
116            s.chars()
117                .map(|c| if c.is_ascii() { 1usize } else { 4 })
118                .sum::<usize>()
119        })
120        .sum::<usize>()
121        / 4
122}
123
124/// Partition `history` into (system_prompts, rest), where system prompts are
125/// permanent user-provided system messages (role=System, name≠"[auto-summary]").
126///
127/// Returns the indices of permanent system messages so callers can re-inject
128/// them after compressing the rest.
129fn extract_system_prompts(history: &mut Vec<Message>) -> Vec<Message> {
130    let mut prompts = Vec::new();
131    let mut i = 0;
132    while i < history.len() {
133        let m = &history[i];
134        let is_permanent_system = matches!(m.role, Role::System) && !m.is_auto_summary();
135        if is_permanent_system {
136            prompts.push(history.remove(i));
137            // don't increment i — the next element shifted into position i
138        } else {
139            i += 1;
140        }
141    }
142    prompts
143}
144
145// ── LlmSummarizer ─────────────────────────────────────────────────────────────
146
/// Summarizes older conversation turns by asking DeepSeek to write a concise
/// prose summary, then replaces the compressed turns with a single
/// `Role::System` message containing that summary.
///
/// # Trigger
///
/// Fires when the estimated token count of the **compressible** portion of the
/// history (everything except permanent system prompts) exceeds `token_threshold`.
///
/// # Behavior
///
/// 1. Permanent `Role::System` messages (user-provided via `with_system_prompt`)
///    are extracted and re-prepended after summarization — they are never lost.
/// 2. Any previous `[auto-summary]` system message is included in the text sent
///    to the model so the new summary is cumulative.
/// 3. The `retain_last` most recent non-system turns are kept verbatim; everything
///    older is replaced by the LLM-generated summary.
/// 4. If the API call fails the history is left **unchanged** and the error is
///    returned so the caller can decide whether to abort or continue.
///
/// # Example
///
/// ```no_run
/// use ds_api::{DeepseekAgent, ApiClient};
/// use ds_api::conversation::LlmSummarizer;
///
/// let summarizer = LlmSummarizer::new(ApiClient::new("sk-..."));
/// let agent = DeepseekAgent::new("sk-...")
///     .with_summarizer(summarizer);
/// ```
#[derive(Clone)]
pub struct LlmSummarizer {
    /// Client used exclusively for summary API calls (can share the agent's token).
    client: ApiClient,
    /// Model used for the summarization API call.  Defaults to `"deepseek-chat"`.
    pub(crate) model: String,
    /// Estimated token count above which summarization is triggered.
    pub(crate) token_threshold: usize,
    /// Number of most-recent non-system messages to retain verbatim.  The actual
    /// split point may move forward so a tool response is never separated from
    /// the assistant message that issued the call.
    pub(crate) retain_last: usize,
}
188
189impl LlmSummarizer {
190    /// Create with default thresholds: trigger at ~60 000 tokens, retain last 10 turns.
191    ///
192    /// The summarization call uses `"deepseek-chat"` by default.  Override with
193    /// [`with_model`][LlmSummarizer::with_model] — useful when the agent is
194    /// pointed at an OpenAI-compatible provider and you want the summarizer to
195    /// use the same model.
196    pub fn new(client: ApiClient) -> Self {
197        Self {
198            client,
199            model: "deepseek-chat".to_string(),
200            token_threshold: 60_000,
201            retain_last: 10,
202        }
203    }
204
205    /// Builder: set the model used for the summarization API call.
206    ///
207    /// ```no_run
208    /// use ds_api::{ApiClient, LlmSummarizer};
209    ///
210    /// let summarizer = LlmSummarizer::new(ApiClient::new("sk-..."))
211    ///     .with_model("gpt-4o-mini");
212    /// ```
213    pub fn with_model(mut self, model: impl Into<String>) -> Self {
214        self.model = model.into();
215        self
216    }
217
218    /// Builder: set a custom token threshold.
219    pub fn token_threshold(mut self, n: usize) -> Self {
220        self.token_threshold = n;
221        self
222    }
223
224    /// Builder: set how many recent messages to keep verbatim.
225    pub fn retain_last(mut self, n: usize) -> Self {
226        self.retain_last = n;
227        self
228    }
229}
230
231impl Summarizer for LlmSummarizer {
232    fn should_summarize(&self, history: &[Message]) -> bool {
233        estimate_tokens(history) >= self.token_threshold
234    }
235
236    fn summarize<'a>(
237        &'a self,
238        history: &'a mut Vec<Message>,
239    ) -> Pin<Box<dyn Future<Output = Result<(), ApiError>> + Send + 'a>> {
240        Box::pin(async move {
241            // ── 1. Extract permanent system prompts ──────────────────────────
242            let system_prompts = extract_system_prompts(history);
243
244            // ── 2. Split off the tail we want to keep verbatim ───────────────
245            let retain = self.retain_last.min(history.len());
246            let mut split = history.len().saturating_sub(retain);
247
248            while split < history.len() {
249                let current_is_tool = matches!(history[split].role, Role::Tool);
250                let prev_is_call = if split > 0 {
251                    // 检查前一条是否包含 tool_calls
252                    history[split - 1]
253                        .tool_calls
254                        .as_ref()
255                        .map_or(false, |tc| !tc.is_empty())
256                } else {
257                    false
258                };
259
260                if current_is_tool || prev_is_call {
261                    split += 1;
262                } else {
263                    break;
264                }
265            }
266
267            let tail: Vec<Message> = history.drain(split..).collect();
268
269            // history now contains only the "old" turns (including any previous
270            // [auto-summary] message).
271
272            if history.is_empty() {
273                // Nothing old enough to summarize — just restore everything.
274                history.extend(tail);
275                // re-prepend system prompts
276                for (i, p) in system_prompts.into_iter().enumerate() {
277                    history.insert(i, p);
278                }
279                return Ok(());
280            }
281
282            // ── 3. Build a prompt asking the model for a summary ─────────────
283            //
284            // We format the old turns as a readable transcript and ask for a
285            // concise summary that preserves the most important facts and decisions.
286            let mut transcript = String::new();
287            for msg in &*history {
288                // skip the old auto-summary header line if present — the content
289                // itself is still useful context for the new summary
290                let role_label = match msg.role {
291                    Role::User => "User",
292                    Role::Assistant => "Assistant",
293                    Role::System => "System",
294                    Role::Tool => "Tool",
295                };
296
297                let content_text = msg.content.clone().unwrap_or_else(|| {
298                    msg.tool_calls
299                        .as_ref()
300                        .map(|calls| format!("[Calls Tools: {:?}]", calls))
301                        .unwrap_or_default()
302                });
303
304                if !content_text.is_empty() {
305                    transcript.push_str(&format!("{role_label}: {content_text}\n"));
306                }
307            }
308
309            let summarize_prompt = format!(
310                "Below is a conversation transcript. Write a concise summary (a few sentences \
311                 to a short paragraph) that captures the key context, decisions, and facts \
312                 established so far. The summary will replace the original transcript and be \
313                 read by the same AI assistant as a memory aid — be precise and neutral.\n\n\
314                 Transcript:\n{transcript}"
315            );
316
317            let req = ApiRequest::builder()
318                .with_model(self.model.clone())
319                .add_message(Message::new(Role::User, &summarize_prompt))
320                .max_tokens(512);
321
322            let response = self.client.send(req).await?;
323
324            let summary_text = response
325                .choices
326                .into_iter()
327                .next()
328                .and_then(|c| c.message.content)
329                .unwrap_or_else(|| transcript.clone());
330
331            // ── 4. Replace old turns with the summary message ────────────────
332            history.clear();
333
334            history.push(Message::auto_summary(format!(
335                "Summary of the conversation so far:\n{summary_text}"
336            )));
337
338            // ── 5. Re-attach the verbatim tail and system prompts ────────────
339            history.extend(tail);
340
341            for (i, p) in system_prompts.into_iter().enumerate() {
342                history.insert(i, p);
343            }
344
345            Ok(())
346        })
347    }
348}
349
350// ── SlidingWindowSummarizer ───────────────────────────────────────────────────
351
/// Keeps only the most recent `window` messages and silently discards everything
/// older.  No API call is made.
///
/// Use this when you want predictable, zero-cost context management and are
/// comfortable with the model losing access to earlier turns.
///
/// Permanent `Role::System` messages are always preserved regardless of `window`,
/// and any stale `[auto-summary]` message is removed during summarization.
///
/// # Example
///
/// ```no_run
/// use ds_api::DeepseekAgent;
/// use ds_api::conversation::SlidingWindowSummarizer;
///
/// // Keep the last 20 non-system messages; trigger summarization above 30.
/// let agent = DeepseekAgent::new("sk-...")
///     .with_summarizer(
///         SlidingWindowSummarizer::new(20)
///             .trigger_at(30)
///     );
/// ```
#[derive(Debug, Clone)]
pub struct SlidingWindowSummarizer {
    /// Maximum number of non-system messages to retain after summarization.
    pub(crate) window: usize,
    /// Number of non-system messages above which summarization is triggered.
    /// `None` means the default of `window + 1` (trigger as soon as the window
    /// is exceeded by one).
    pub(crate) trigger_at: Option<usize>,
}
381
382impl SlidingWindowSummarizer {
383    /// Create a summarizer that retains at most `window` non-system messages.
384    ///
385    /// Summarization triggers as soon as the non-system message count exceeds
386    /// `window`.  Use [`trigger_at`][Self::trigger_at] to set a larger trigger
387    /// threshold so the window only slides after a certain amount of growth.
388    pub fn new(window: usize) -> Self {
389        Self {
390            window,
391            trigger_at: None,
392        }
393    }
394
395    /// Builder: set the non-system message count that triggers summarization.
396    ///
397    /// Must be greater than `window`; if set to a value ≤ `window` it is
398    /// silently clamped to `window + 1`.
399    ///
400    /// # Example
401    ///
402    /// ```no_run
403    /// use ds_api::conversation::SlidingWindowSummarizer;
404    ///
405    /// // Retain 20 turns but only start trimming after reaching 40.
406    /// let s = SlidingWindowSummarizer::new(20).trigger_at(40);
407    /// ```
408    pub fn trigger_at(mut self, n: usize) -> Self {
409        self.trigger_at = Some(n.max(self.window + 1));
410        self
411    }
412}
413
414impl Summarizer for SlidingWindowSummarizer {
415    fn should_summarize(&self, history: &[Message]) -> bool {
416        let non_system = history
417            .iter()
418            .filter(|m| !matches!(m.role, Role::System))
419            .count();
420        let threshold = self.trigger_at.unwrap_or(self.window + 1);
421        non_system >= threshold
422    }
423
424    fn summarize<'a>(
425        &'a self,
426        history: &'a mut Vec<Message>,
427    ) -> Pin<Box<dyn Future<Output = Result<(), ApiError>> + Send + 'a>> {
428        Box::pin(async move {
429            // Extract and preserve permanent system prompts.
430            let system_prompts = extract_system_prompts(history);
431
432            // Remove any previous auto-summary messages — they're irrelevant
433            // for a pure sliding window.
434            history.retain(|m| !m.is_auto_summary());
435
436            // Keep only the last `window` non-system messages.
437            if history.len() > self.window {
438                let drop = history.len() - self.window;
439                history.drain(0..drop);
440            }
441
442            // Re-prepend the permanent system prompts at the front.
443            for (i, p) in system_prompts.into_iter().enumerate() {
444                history.insert(i, p);
445            }
446
447            Ok(())
448        })
449    }
450}
451
452// ── Tests ─────────────────────────────────────────────────────────────────────
453
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: build a plain message with the given role and text.
    fn msg(role: Role, text: &str) -> Message {
        Message::new(role, text)
    }

    /// Build a permanent system prompt — no [auto-summary] name tag, so the
    /// summarizers are required to preserve it.
    fn system_prompt(text: &str) -> Message {
        Message::new(Role::System, text)
    }

    // ── estimate_tokens ───────────────────────────────────────────────────────

    #[test]
    fn estimate_tokens_excludes_permanent_system() {
        let history = vec![
            system_prompt("You are a helpful assistant."),
            msg(Role::User, "Hello"),         // 5 chars → 1 token
            msg(Role::Assistant, "Hi there"), // 8 chars → 2 tokens
        ];
        // Only the User + Assistant messages should contribute.
        let est = estimate_tokens(&history);
        assert!(est > 0);
        // "Hello" + "Hi there" = 13 chars / 4 = 3 tokens (integer division)
        assert_eq!(est, 3);
    }

    #[test]
    fn estimate_tokens_includes_auto_summary() {
        // Auto-summary messages are replaceable, so they count toward the estimate.
        let summary = Message::auto_summary("Some prior summary text.");

        let history = vec![summary];
        let est = estimate_tokens(&history);
        assert!(est > 0);
    }

    // ── SlidingWindowSummarizer ───────────────────────────────────────────────

    #[tokio::test]
    async fn sliding_window_trims_to_window() {
        let mut history = vec![
            system_prompt("system"),
            msg(Role::User, "a"),
            msg(Role::Assistant, "b"),
            msg(Role::User, "c"),
            msg(Role::Assistant, "d"),
            msg(Role::User, "e"),
        ];

        let s = SlidingWindowSummarizer::new(2);
        assert!(s.should_summarize(&history));
        s.summarize(&mut history).await.unwrap();

        // system prompt preserved
        assert!(
            history
                .iter()
                .any(|m| matches!(m.role, Role::System) && m.content.as_deref() == Some("system"))
        );

        // at most window non-system messages remain
        let non_sys: Vec<_> = history
            .iter()
            .filter(|m| !matches!(m.role, Role::System))
            .collect();
        assert_eq!(non_sys.len(), 2);

        // the retained messages are the most recent ones
        assert_eq!(non_sys[0].content.as_deref(), Some("d"));
        assert_eq!(non_sys[1].content.as_deref(), Some("e"));
    }

    #[tokio::test]
    async fn sliding_window_preserves_multiple_system_prompts() {
        let mut p1 = system_prompt("prompt one");
        let mut p2 = system_prompt("prompt two");
        // Give them something to distinguish them from auto-summary
        // (NOTE(review): `Message::new` presumably leaves `name` unset already,
        // making these assignments redundant — confirm against Message::new).
        p1.name = None;
        p2.name = None;

        let mut history = vec![
            p1.clone(),
            p2.clone(),
            msg(Role::User, "1"),
            msg(Role::User, "2"),
            msg(Role::User, "3"),
        ];

        let s = SlidingWindowSummarizer::new(1);
        s.summarize(&mut history).await.unwrap();

        // Both prompts survive, in their original order.
        let sys_msgs: Vec<_> = history
            .iter()
            .filter(|m| matches!(m.role, Role::System))
            .collect();
        assert_eq!(sys_msgs.len(), 2);
        assert_eq!(sys_msgs[0].content.as_deref(), Some("prompt one"));
        assert_eq!(sys_msgs[1].content.as_deref(), Some("prompt two"));
    }

    #[tokio::test]
    async fn sliding_window_removes_old_auto_summary() {
        let auto = Message::auto_summary("old summary");

        let mut history = vec![
            system_prompt("permanent"),
            auto,
            msg(Role::User, "a"),
            msg(Role::User, "b"),
            msg(Role::User, "c"),
        ];

        let s = SlidingWindowSummarizer::new(2);
        s.summarize(&mut history).await.unwrap();

        // old auto-summary should be gone
        assert!(!history.iter().any(|m| m.is_auto_summary()));

        // permanent system prompt preserved
        assert!(
            history
                .iter()
                .any(|m| m.content.as_deref() == Some("permanent"))
        );
    }

    #[tokio::test]
    async fn sliding_window_noop_when_within_window() {
        // Summarizing a history that already fits the window changes nothing.
        let mut history = vec![msg(Role::User, "a"), msg(Role::Assistant, "b")];

        let s = SlidingWindowSummarizer::new(4);
        assert!(!s.should_summarize(&history));
        s.summarize(&mut history).await.unwrap();
        assert_eq!(history.len(), 2);
    }

    // ── should_summarize ─────────────────────────────────────────────────────

    #[test]
    fn should_summarize_triggers_at_window_exceeded() {
        let history = vec![
            msg(Role::User, "a"),
            msg(Role::User, "b"),
            msg(Role::User, "c"),
        ];
        // Default trigger is window + 1 = 3, so three messages fire it.
        let s = SlidingWindowSummarizer::new(2);
        assert!(s.should_summarize(&history));

        let short = vec![msg(Role::User, "only")];
        assert!(!s.should_summarize(&short));
    }
}