Skip to main content

rab/agent/
compaction.rs

1use serde::Serialize;
2
3use crate::agent::session::SessionEntry;
4use yoagent::types::AgentMessage;
5
6// ── CompactionSettings ─────────────────────────────────────────────
7
/// Per-session knobs that control automatic context compaction.
///
/// The [`Default`] impl enables compaction, reserves 16 384 tokens of headroom
/// and keeps roughly the most recent 20 000 tokens verbatim.
#[derive(Debug, Clone)]
pub struct CompactionSettings {
    /// Master switch; when `false`, [`should_compact`] always returns `false`.
    pub enabled: bool,
    /// Tokens to reserve for system prompt, tool defs, and the response.
    pub reserve_tokens: u64,
    /// Approximate number of most-recent tokens that are never summarised.
    pub keep_recent_tokens: u64,
}

impl Default for CompactionSettings {
    fn default() -> Self {
        let (reserve_tokens, keep_recent_tokens) = (16_384, 20_000);
        CompactionSettings {
            enabled: true,
            reserve_tokens,
            keep_recent_tokens,
        }
    }
}
27
28// ── Compaction reason ──────────────────────────────────────────────
29
30/// Why compaction was triggered.
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
32#[serde(rename_all = "lowercase")]
33pub enum CompactionReason {
34    /// User manually triggered `/compact`.
35    Manual,
36    /// Context usage exceeded the configured threshold.
37    Threshold,
38    /// Provider returned a context overflow error.
39    Overflow,
40}
41
42impl std::fmt::Display for CompactionReason {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        match self {
45            CompactionReason::Manual => write!(f, "manual"),
46            CompactionReason::Threshold => write!(f, "threshold"),
47            CompactionReason::Overflow => write!(f, "overflow"),
48        }
49    }
50}
51
52// ── Result types ───────────────────────────────────────────────────
53
/// Result of prepare_compaction — what to summarise and what to keep.
///
/// Consumed by `compact`, which turns it into a [`CompactionResult`].
#[derive(Debug, Clone)]
pub struct CompactionPreparation {
    /// ID of the first session entry kept verbatim after compaction. Message
    /// entries in the compaction window before it are summarised (split across
    /// `messages_to_summarize` and `turn_prefix_messages`).
    pub first_kept_entry_id: String,
    /// History messages to summarise (will be replaced by a compaction entry).
    /// When a turn is split, these stop where that turn begins.
    pub messages_to_summarize: Vec<AgentMessage>,
    /// Messages from the start of a split turn up to the cut point; empty
    /// unless `is_split_turn` is set.
    pub turn_prefix_messages: Vec<AgentMessage>,
    /// Whether the cut point falls in the middle of a turn.
    pub is_split_turn: bool,
    /// Estimated context size in tokens before compaction
    /// (computed with `estimate_context_tokens`).
    pub tokens_before: u64,
    /// Summary from the most recent earlier compaction, if any; passed back
    /// to the summariser so it can update rather than restart the summary.
    pub previous_summary: Option<String>,
}
70
/// Result of compact() — ready to append to the session.
#[derive(Debug, Clone, Serialize)]
pub struct CompactionResult {
    /// Summary text returned by the summarisation request(s).
    pub summary: String,
    /// ID of the first entry kept verbatim (copied from the preparation).
    pub first_kept_entry_id: String,
    /// Estimated context tokens before compaction (copied from the preparation).
    pub tokens_before: u64,
    /// Estimated context tokens immediately after compaction is applied.
    pub estimated_tokens_after: u64,
    /// File operation details (`readFiles`, `modifiedFiles`), or `None` when
    /// the summarised messages neither read nor modified any file.
    pub details: Option<serde_json::Value>,
}
82
83// ── Default context windows ────────────────────────────────────────
84
/// Known model context windows (in tokens), keyed by lowercase model-name prefix.
///
/// Lookups pick the *longest* matching prefix, so specific entries such as
/// `"gpt-4o"` win over broader ones such as `"gpt-4"` regardless of their
/// position in this table.
const MODEL_CONTEXT_WINDOWS: &[(&str, u64)] = &[
    ("deepseek", 1_000_000),
    ("claude", 200_000),
    ("gpt-4o", 128_000),
    ("gpt-4", 128_000),
    ("gemini", 1_048_576),
    ("sonnet", 200_000),
    ("haiku", 200_000),
];

/// Context window assumed for models that match no entry in `MODEL_CONTEXT_WINDOWS`.
const DEFAULT_CONTEXT_WINDOW: u64 = 200_000;

/// Look up the context window (in tokens) for a model name.
///
/// Matching is case-insensitive and prefix-based. When several prefixes
/// match, the longest one wins, so table order does not matter. Unknown
/// models fall back to 200 000 tokens.
pub fn get_model_context_window(model: &str) -> u64 {
    let lower = model.to_lowercase();
    MODEL_CONTEXT_WINDOWS
        .iter()
        .filter(|(prefix, _)| lower.starts_with(prefix))
        .max_by_key(|(prefix, _)| prefix.len())
        .map_or(DEFAULT_CONTEXT_WINDOW, |&(_, window)| window)
}
107
108// ── Token estimation ───────────────────────────────────────────────
109
110/// Estimate token count for a single message (chars/4 heuristic, conservative).
111pub fn estimate_tokens(message: &AgentMessage) -> u64 {
112    use yoagent::types::Content;
113
114    let text = crate::agent::types::message_text(message);
115    let mut chars: usize = text.len();
116
117    if let AgentMessage::Llm(yoagent::types::Message::Assistant { content, .. }) = message {
118        // Account for thinking blocks and images in assistant messages (pi-compatible).
119        // text.len() is already counted via message_text above, so only add extra non-text content.
120        for c in content {
121            match c {
122                Content::Text { .. } => {
123                    // Already counted in message_text
124                }
125                Content::Thinking { thinking, .. } => {
126                    chars += thinking.len();
127                }
128                Content::ToolCall {
129                    name, arguments, ..
130                } => {
131                    chars += name.len();
132                    chars += serde_json::to_string(arguments).unwrap_or_default().len();
133                }
134                Content::Image { .. } => {
135                    // Pi estimates 4800 chars per image
136                    chars += 4800;
137                }
138            }
139        }
140    } else if let AgentMessage::Llm(yoagent::types::Message::User { content: c, .. }) = message {
141        // Account for images in user messages (pi-compatible)
142        for c in c {
143            if matches!(c, Content::Image { .. }) {
144                chars += 4800;
145            }
146        }
147    }
148
149    (chars as u64).div_ceil(4)
150}
151
152/// Estimate context tokens for a slice of messages.
153/// Uses recorded usage from the last non-aborted assistant message as the baseline,
154/// then adds estimated tokens for any messages after it.
155pub fn estimate_context_tokens(messages: &[AgentMessage]) -> u64 {
156    let mut last_usage_index = None;
157    for (i, msg) in messages.iter().enumerate().rev() {
158        if let Some(usage) = crate::agent::types::message_usage(msg) {
159            // Skip usage records that are all zeros (e.g. from test helpers)
160            if usage.input > 0 || usage.output > 0 || usage.cache_read > 0 {
161                last_usage_index = Some(i);
162                break;
163            }
164        }
165    }
166
167    if let Some(idx) = last_usage_index {
168        if let Some(usage) = crate::agent::types::message_usage(&messages[idx]) {
169            let usage_tokens = usage.input + usage.output + usage.cache_read;
170            let mut trailing = 0u64;
171            for msg in &messages[idx + 1..] {
172                trailing += estimate_tokens(msg);
173            }
174            usage_tokens + trailing
175        } else {
176            messages.iter().map(estimate_tokens).sum()
177        }
178    } else {
179        messages.iter().map(estimate_tokens).sum()
180    }
181}
182
183// ── shouldCompact ──────────────────────────────────────────────────
184
185/// Determine whether compaction should trigger.
186pub fn should_compact(
187    context_tokens: u64,
188    context_window: u64,
189    settings: &CompactionSettings,
190) -> bool {
191    if !settings.enabled {
192        return false;
193    }
194    context_tokens > context_window.saturating_sub(settings.reserve_tokens)
195}
196
197// ── Cut-point detection ────────────────────────────────────────────
198
199/// Find valid cut-point indices: user and assistant messages (never tool results).
200fn find_valid_cut_points(entries: &[SessionEntry], start: usize, end: usize) -> Vec<usize> {
201    let mut points = Vec::new();
202    for (i, entry) in entries.iter().enumerate().take(end).skip(start) {
203        match entry {
204            SessionEntry::Message(m) => {
205                if crate::agent::types::message_is_user(&m.message)
206                    || crate::agent::types::message_is_assistant(&m.message)
207                {
208                    points.push(i);
209                }
210            }
211            // Pi-compatible: branch_summary and custom_message are valid cut points
212            SessionEntry::BranchSummary(_) | SessionEntry::CustomMessage(_) => {
213                points.push(i);
214            }
215            SessionEntry::ThinkingLevelChange(_)
216            | SessionEntry::ModelChange(_)
217            | SessionEntry::ActiveToolsChange(_)
218            | SessionEntry::Custom(_)
219            | SessionEntry::Label(_)
220            | SessionEntry::SessionInfo(_)
221            | SessionEntry::Compaction(_)
222            | SessionEntry::Leaf(_) => {}
223        }
224    }
225    points
226}
227
228/// Find the user message that starts the turn containing `entry_index`.
229fn find_turn_start_index(
230    entries: &[SessionEntry],
231    entry_index: usize,
232    start: usize,
233) -> Option<usize> {
234    for i in (start..=entry_index).rev() {
235        match &entries[i] {
236            SessionEntry::Message(m) if crate::agent::types::message_is_user(&m.message) => {
237                return Some(i);
238            }
239            // Pi-compatible: branch_summary and custom_message start a turn
240            SessionEntry::BranchSummary(_) | SessionEntry::CustomMessage(_) => return Some(i),
241            _ => {}
242        }
243    }
244    None
245}
246
/// Result of finding the cut point.
///
/// Produced by `find_cut_point` and consumed by `prepare_compaction`.
struct CutPointResult {
    /// Index into the session entries of the first entry kept verbatim.
    first_kept_entry_index: usize,
    /// Index of the entry that opens the turn the cut falls inside, when one
    /// was found; `prepare_compaction` only reads it when `is_split_turn` is set.
    turn_start_index: Option<usize>,
    /// `true` when the cut lands mid-turn, so the turn's opening messages are
    /// summarised separately as a turn prefix.
    is_split_turn: bool,
}
253
254/// Walk backwards from the end, accumulating estimated token sizes,
255/// and find where to cut.
256fn find_cut_point(
257    entries: &[SessionEntry],
258    start: usize,
259    end: usize,
260    keep_recent_tokens: u64,
261) -> CutPointResult {
262    let cut_points = find_valid_cut_points(entries, start, end);
263
264    if cut_points.is_empty() {
265        return CutPointResult {
266            first_kept_entry_index: start,
267            turn_start_index: None,
268            is_split_turn: false,
269        };
270    }
271
272    let mut accumulated = 0u64;
273    let mut cut_index = cut_points[0];
274
275    for i in (start..end).rev() {
276        let tokens = match &entries[i] {
277            SessionEntry::Message(m) => estimate_tokens(&m.message),
278            _ => continue,
279        };
280        accumulated += tokens;
281
282        if accumulated >= keep_recent_tokens {
283            // Find the closest valid cut point at or after this entry
284            for &cp in &cut_points {
285                if cp >= i {
286                    cut_index = cp;
287                    break;
288                }
289            }
290            break;
291        }
292    }
293
294    // Walk backward past non-message entries (label, info, etc.)
295    while cut_index > start {
296        match &entries[cut_index - 1] {
297            SessionEntry::Message(_) | SessionEntry::Compaction(_) => break,
298            _ => cut_index -= 1,
299        }
300    }
301
302    let cut_entry = &entries[cut_index];
303    let is_user_msg = matches!(cut_entry, SessionEntry::Message(m) if crate::agent::types::message_is_user(&m.message));
304    let turn_start = if is_user_msg {
305        None
306    } else {
307        find_turn_start_index(entries, cut_index, start)
308    };
309
310    CutPointResult {
311        first_kept_entry_index: cut_index,
312        turn_start_index: turn_start,
313        is_split_turn: !is_user_msg && turn_start.is_some(),
314    }
315}
316
317// ── prepareCompaction ──────────────────────────────────────────────
318
319/// Analyse the session branch and determine what should be compacted.
320///
321/// Returns `None` when the last entry is already a compaction (nothing new to do).
322pub fn prepare_compaction(
323    entries: &[SessionEntry],
324    settings: &CompactionSettings,
325) -> Option<CompactionPreparation> {
326    // Don't compact if no entries
327    if entries.is_empty() {
328        return None;
329    }
330    // Don't compact if the last entry is already a compaction
331    if let Some(SessionEntry::Compaction(_)) = entries.last() {
332        return None;
333    }
334
335    // Find previous compaction boundary
336    let mut prev_compaction_idx = None;
337    for (i, entry) in entries.iter().enumerate().rev() {
338        if matches!(entry, SessionEntry::Compaction(_)) {
339            prev_compaction_idx = Some(i);
340            break;
341        }
342    }
343
344    let mut previous_summary: Option<String> = None;
345    let boundary_start = if let Some(ci) = prev_compaction_idx {
346        if let SessionEntry::Compaction(c) = &entries[ci] {
347            previous_summary = Some(c.summary.clone());
348            // Find where the previous compaction's kept region starts
349            let kept_idx = entries.iter().position(|e| e.id() == c.first_kept_entry_id);
350            kept_idx.unwrap_or(ci + 1)
351        } else {
352            0
353        }
354    } else {
355        0
356    };
357
358    let boundary_end = entries.len();
359    let context_msgs: Vec<AgentMessage> = entries
360        .iter()
361        .filter_map(|e| match e {
362            SessionEntry::Message(m) => Some(m.message.clone()),
363            SessionEntry::BranchSummary(s) => Some(crate::agent::types::assistant_message(
364                format!("[Branch: from {}] {}", s.from_id, s.summary),
365            )),
366            SessionEntry::CustomMessage(c) => {
367                Some(crate::agent::types::assistant_message(format!(
368                    "[{}] {}",
369                    c.custom_type,
370                    serde_json::to_string(&c.content).unwrap_or_default()
371                )))
372            }
373            _ => None,
374        })
375        .collect();
376
377    let tokens_before = estimate_context_tokens(&context_msgs);
378
379    let cut = find_cut_point(
380        entries,
381        boundary_start,
382        boundary_end,
383        settings.keep_recent_tokens,
384    );
385
386    let first_kept = &entries[cut.first_kept_entry_index];
387    let first_kept_entry_id = first_kept.id().to_string();
388
389    let history_end = if cut.is_split_turn {
390        cut.turn_start_index.unwrap_or(cut.first_kept_entry_index)
391    } else {
392        cut.first_kept_entry_index
393    };
394
395    // Collect messages to summarise
396    let messages_to_summarize: Vec<AgentMessage> = entries[boundary_start..history_end]
397        .iter()
398        .filter_map(|e| match e {
399            SessionEntry::Message(m) => Some(m.message.clone()),
400            _ => None,
401        })
402        .collect();
403
404    // Turn prefix messages (when splitting a turn)
405    let turn_prefix_messages: Vec<AgentMessage> = if cut.is_split_turn {
406        entries[cut.turn_start_index.unwrap_or(0)..cut.first_kept_entry_index]
407            .iter()
408            .filter_map(|e| match e {
409                SessionEntry::Message(m) => Some(m.message.clone()),
410                _ => None,
411            })
412            .collect()
413    } else {
414        vec![]
415    };
416
417    if messages_to_summarize.is_empty() && turn_prefix_messages.is_empty() {
418        return None;
419    }
420
421    Some(CompactionPreparation {
422        first_kept_entry_id,
423        messages_to_summarize,
424        turn_prefix_messages,
425        is_split_turn: cut.is_split_turn,
426        tokens_before,
427        previous_summary,
428    })
429}
430
431// ── Summarization prompts ──────────────────────────────────────────
432
/// Default system prompt for summarisation requests; `compact` uses it unless
/// a `system_prompt_override` is supplied.
const SUMMARIZATION_SYSTEM_PROMPT: &str = "You are a context summarization assistant. Your task is to read a conversation between a user and an AI assistant, then produce a structured summary following the exact format specified.\n\nDo NOT continue the conversation. Do NOT respond to any questions in the conversation. ONLY output the structured summary.";

/// Instructions appended after the conversation when there is no previous
/// compaction summary to build on.
const SUMMARIZATION_PROMPT: &str = "The messages above are a conversation to summarize. Create a structured context checkpoint summary that another LLM will use to continue the work.\n\nUse this EXACT format:\n\n## Goal\n[What is the user trying to accomplish? Can be multiple items if the session covers different tasks.]\n\n## Constraints & Preferences\n- [Any constraints, preferences, or requirements mentioned by user]\n- [Or \"(none)\" if none were mentioned]\n\n## Progress\n### Done\n- [x] [Completed tasks/changes]\n\n### In Progress\n- [ ] [Current work]\n\n### Blocked\n- [Issues preventing progress, if any]\n\n## Key Decisions\n- **[Decision]**: [Brief rationale]\n\n## Next Steps\n1. [Ordered list of what should happen next]\n\n## Critical Context\n- [Any data, examples, or references needed to continue]\n- [Or \"(none)\" if not applicable]\n\nKeep each section concise. Preserve exact file paths, function names, and error messages.";

/// Instructions for an incremental update: the new messages are merged into
/// the previous summary, which the prompt expects inside `<previous-summary>` tags.
const UPDATE_SUMMARIZATION_PROMPT: &str = "The messages above are NEW conversation messages to incorporate into the existing summary provided in <previous-summary> tags.\n\nUpdate the existing structured summary with new information. RULES:\n- PRESERVE all existing information from the previous summary\n- ADD new progress, decisions, and context from the new messages\n- UPDATE the Progress section: move items from \"In Progress\" to \"Done\" when completed\n- UPDATE \"Next Steps\" based on what was accomplished\n- PRESERVE exact file paths, function names, and error messages\n- If something is no longer relevant, you may remove it\n\nUse this EXACT format:\n\n## Goal\n[Preserve existing goals, add new ones if the task expanded]\n\n## Constraints & Preferences\n- [Preserve existing, add new ones discovered]\n\n## Progress\n### Done\n- [x] [Include previously done items AND newly completed items]\n\n### In Progress\n- [ ] [Current work - update based on progress]\n\n### Blocked\n- [Current blockers - remove if resolved]\n\n## Key Decisions\n- **[Decision]**: [Brief rationale] (preserve all previous, add new)\n\n## Next Steps\n1. [Update based on current state]\n\n## Critical Context\n- [Preserve important context, add new if needed]\n\nKeep each section concise. Preserve exact file paths, function names, and error messages.";

/// Instructions for summarising the opening part ("prefix") of a turn that the
/// cut point split; the rest of that turn (the "suffix") is kept verbatim.
const TURN_PREFIX_SUMMARIZATION_PROMPT: &str = r#"This is the PREFIX of a turn that was too large to keep. The SUFFIX (recent work) is retained.

Summarize the prefix to provide context for the retained suffix:

## Original Request
[What did the user ask for?]

## Early Progress
- [Key decisions and work done]

## Context for Suffix
- [Information needed to understand the kept suffix]"#;
451
452// ── File operation extraction ──────────────────────────────────────
453
454/// File operations accumulator, matching pi's FileOperations / createFileOps.
455pub struct FileOps {
456    pub read: std::collections::HashSet<String>,
457    pub written: std::collections::HashSet<String>,
458    pub edited: std::collections::HashSet<String>,
459}
460
461impl FileOps {
462    pub fn new() -> Self {
463        Self {
464            read: std::collections::HashSet::new(),
465            written: std::collections::HashSet::new(),
466            edited: std::collections::HashSet::new(),
467        }
468    }
469
470    /// Extract file ops from a single assistant message (pi-compatible).
471    pub fn extract_from_message(&mut self, msg: &AgentMessage) {
472        if let AgentMessage::Llm(yoagent::types::Message::Assistant { content, .. }) = msg {
473            let tcs = crate::agent::types::content_tool_calls(content);
474            for (_, name, args) in &tcs {
475                // Pi only checks `path` field (not `file_path`)
476                let path = args
477                    .get("path")
478                    .and_then(|v| v.as_str())
479                    .map(|s| s.to_string());
480                let Some(p) = path else { continue };
481                match name.as_str() {
482                    "read" => {
483                        self.read.insert(p);
484                    }
485                    "write" => {
486                        self.written.insert(p);
487                    }
488                    "edit" => {
489                        self.edited.insert(p);
490                    }
491                    _ => {}
492                }
493            }
494        }
495    }
496
497    /// Compute sorted read-only and modified file lists (pi-compatible).
498    pub fn compute_lists(&self) -> (Vec<String>, Vec<String>) {
499        let modified: std::collections::HashSet<String> =
500            self.edited.union(&self.written).cloned().collect();
501        let mut read_only: Vec<String> = self.read.difference(&modified).cloned().collect();
502        read_only.sort();
503        let mut modified_sorted: Vec<String> = modified.into_iter().collect();
504        modified_sorted.sort();
505        (read_only, modified_sorted)
506    }
507
508    /// Serialize to JSON for compaction details (pi-compatible).
509    pub fn to_json_value(&self) -> Option<serde_json::Value> {
510        let (read_files, modified_files) = self.compute_lists();
511        if read_files.is_empty() && modified_files.is_empty() {
512            return None;
513        }
514        Some(serde_json::json!({
515            "readFiles": read_files,
516            "modifiedFiles": modified_files,
517        }))
518    }
519}
520
521impl Default for FileOps {
522    fn default() -> Self {
523        Self::new()
524    }
525}
526
527/// Extract file operations from a list of messages (for compaction details).
528fn extract_file_ops(messages: &[AgentMessage]) -> Option<serde_json::Value> {
529    let mut ops = FileOps::new();
530    for msg in messages {
531        ops.extract_from_message(msg);
532    }
533    ops.to_json_value()
534}
535
536// ── compact ────────────────────────────────────────────────────────
537
538/// Execute compaction: send messages to the provider for summarisation
539/// and return the result ready to append to the session.
540///
541/// `model_config` should be the session's current model configuration.
542/// `thinking_level` controls whether the summarization uses reasoning mode.
543pub async fn compact(
544    preparation: &CompactionPreparation,
545    api_key: &str,
546    model: &str,
547    system_prompt_override: Option<&str>,
548    thinking_level: yoagent::types::ThinkingLevel,
549    model_config: Option<yoagent::provider::model::ModelConfig>,
550) -> Result<CompactionResult, String> {
551    // Serialize messages to summarise into a single text block
552    let mut conversation_text = String::new();
553    for msg in &preparation.messages_to_summarize {
554        conversation_text.push_str(&format_message_for_summary(msg));
555        conversation_text.push('\n');
556    }
557
558    // Build the summarisation prompt
559    let system = system_prompt_override.unwrap_or(SUMMARIZATION_SYSTEM_PROMPT);
560    let mut prompt = String::new();
561    if !conversation_text.is_empty() {
562        prompt.push_str("<conversation>\n");
563        prompt.push_str(&conversation_text);
564        prompt.push_str("\n</conversation>\n\n");
565    }
566
567    // Add previous summary if available (incremental update)
568    if let Some(ref prev) = preparation.previous_summary {
569        prompt.push_str(&format!(
570            "<previous-summary>\n{}\n</previous-summary>\n\n",
571            prev
572        ));
573    }
574
575    if preparation.is_split_turn && !preparation.turn_prefix_messages.is_empty() {
576        // Two-part summary: history + turn prefix
577        let mut history_text = String::new();
578        for msg in &preparation.turn_prefix_messages {
579            history_text.push_str(&format_message_for_summary(msg));
580            history_text.push('\n');
581        }
582        let turn_prompt = format!(
583            "{}\n\n<turn-prefix>\n{}\n</turn-prefix>\n\n{}",
584            prompt, history_text, TURN_PREFIX_SUMMARIZATION_PROMPT
585        );
586        prompt = turn_prompt;
587    } else if preparation.previous_summary.is_some() {
588        prompt.push_str(UPDATE_SUMMARIZATION_PROMPT);
589    } else {
590        prompt.push_str(SUMMARIZATION_PROMPT);
591    }
592
593    // Create a summarisation message
594    let summary_msg = crate::agent::types::user_message(&prompt);
595
596    // Get summary from provider via yoagent
597    let summary_text = summarize_text(
598        api_key,
599        model,
600        system,
601        &[summary_msg],
602        thinking_level,
603        model_config,
604    )
605    .await?;
606
607    // Extract file operations from messages being summarised
608    let mut all_messages = preparation.messages_to_summarize.clone();
609    all_messages.extend(preparation.turn_prefix_messages.clone());
610    let details = extract_file_ops(&all_messages);
611
612    // Estimate tokens after compaction:
613    //   summary text + kept messages (estimated via heuristic)
614    let summary_msg_est = (summary_text.len() as u64).div_ceil(4);
615    let kept_tokens = preparation
616        .tokens_before
617        .saturating_sub(
618            preparation
619                .messages_to_summarize
620                .iter()
621                .map(estimate_tokens)
622                .sum::<u64>(),
623        )
624        .saturating_sub(
625            preparation
626                .turn_prefix_messages
627                .iter()
628                .map(estimate_tokens)
629                .sum::<u64>(),
630        );
631    let estimated_tokens_after = summary_msg_est + kept_tokens;
632
633    // Build the result
634    Ok(CompactionResult {
635        summary: summary_text,
636        first_kept_entry_id: preparation.first_kept_entry_id.clone(),
637        tokens_before: preparation.tokens_before,
638        estimated_tokens_after,
639        details,
640    })
641}
642
643/// Call the provider for a simple text completion (no tools, no streaming).
644///
645/// Format a message for inclusion in the summarisation prompt.
646fn format_message_for_summary(msg: &AgentMessage) -> String {
647    let role_label = if crate::agent::types::message_is_user(msg) {
648        "User"
649    } else if crate::agent::types::message_is_assistant(msg) {
650        "Assistant"
651    } else {
652        "Tool Result"
653    };
654    let content = crate::agent::types::message_text(msg);
655    let mut result = format!("<{}>\n", role_label);
656    result.push_str(&content);
657
658    // Include tool calls for assistant messages
659    if crate::agent::types::message_tool_call_count(msg) > 0
660        && let AgentMessage::Llm(yoagent::types::Message::Assistant { content: c, .. }) = msg
661    {
662        let tcs = crate::agent::types::content_tool_calls(c);
663        if !tcs.is_empty() {
664            result.push_str("\n\nTool calls:\n");
665            for (_, name, args) in &tcs {
666                result.push_str(&format!(
667                    "  - {}: {}\n",
668                    name,
669                    serde_json::to_string(args).unwrap_or_default()
670                ));
671            }
672        }
673    }
674    result.push_str(&format!("\n</{}>", role_label));
675    result
676}
677
678// ── Tests ──────────────────────────────────────────────────────────
679
680#[cfg(test)]
681mod tests {
682    use super::*;
683    use crate::agent::session::{CompactionEntry, MessageEntry};
684    use crate::agent::types::{assistant_message, tool_result_message, user_message};
685    use yoagent::types::{AgentMessage, Content, Message};
686
687    // ── get_model_context_window tests ──────────────────────────────
688
689    #[test]
690    fn test_context_window_known_model() {
691        assert_eq!(get_model_context_window("deepseek-v4-flash"), 1_000_000);
692        assert_eq!(get_model_context_window("claude-sonnet-4"), 200_000);
693        assert_eq!(get_model_context_window("gpt-4o"), 128_000);
694        assert_eq!(get_model_context_window("gemini-2.0-flash"), 1_048_576);
695    }
696
697    #[test]
698    fn test_context_window_unknown_model_falls_back() {
699        assert_eq!(get_model_context_window("unknown-model-42"), 200_000);
700    }
701
702    #[test]
703    fn test_context_window_case_insensitive() {
704        assert_eq!(get_model_context_window("DeepSeek-V4"), 1_000_000);
705        assert_eq!(get_model_context_window("CLAUDE-OPUS"), 200_000);
706    }
707
708    // ── estimate_tokens tests ───────────────────────────────────────
709
710    #[test]
711    fn test_estimate_tokens_empty_message() {
712        let msg = user_message("");
713        assert_eq!(estimate_tokens(&msg), 0);
714    }
715
716    #[test]
717    fn test_estimate_tokens_short_message() {
718        let msg = user_message("hello");
719        // 5 chars / 4 = 2 (div_ceil)
720        assert_eq!(estimate_tokens(&msg), 2);
721    }
722
723    #[test]
724    fn test_estimate_tokens_long_message() {
725        let text = "a".repeat(100);
726        let msg = user_message(&text);
727        // 100 / 4 = 25
728        assert_eq!(estimate_tokens(&msg), 25);
729    }
730
731    #[test]
732    fn test_estimate_tokens_tool_call_includes_arguments() {
733        let content = vec![
734            Content::Text {
735                text: "checking".into(),
736            },
737            Content::ToolCall {
738                id: "call1".into(),
739                name: "read".into(),
740                arguments: serde_json::json!({"path": "/tmp/file.txt"}),
741                provider_metadata: None,
742            },
743        ];
744        let msg = AgentMessage::Llm(Message::Assistant {
745            content,
746            stop_reason: yoagent::types::StopReason::Stop,
747            model: String::new(),
748            provider: String::new(),
749            usage: yoagent::types::Usage::default(),
750            timestamp: 0,
751            error_message: None,
752        });
753        let tokens = estimate_tokens(&msg);
754        // text "checking" (8) + name "read" (4) + args json length >= 17
755        assert!(tokens >= 8, "tokens={}", tokens);
756    }
757
758    // ── estimate_context_tokens tests ───────────────────────────────
759
760    #[test]
761    fn test_estimate_context_tokens_empty() {
762        assert_eq!(estimate_context_tokens(&[]), 0);
763    }
764
765    #[test]
766    fn test_estimate_context_tokens_no_usage_uses_heuristic() {
767        let msgs = vec![user_message("hello"), assistant_message("world")];
768        let tokens = estimate_context_tokens(&msgs);
769        // 5/4 + 5/4 = 2 + 2 = 4
770        assert_eq!(tokens, 4);
771    }
772
773    #[test]
774    fn test_estimate_context_tokens_with_usage_baseline() {
775        let msg_with_usage = AgentMessage::Llm(Message::Assistant {
776            content: vec![Content::Text {
777                text: "response".into(),
778            }],
779            stop_reason: yoagent::types::StopReason::Stop,
780            model: String::new(),
781            provider: String::new(),
782            usage: yoagent::types::Usage {
783                input: 100,
784                output: 50,
785                cache_read: 20,
786                cache_write: 0,
787                total_tokens: 0,
788            },
789            timestamp: 0,
790            error_message: None,
791        });
792        let msgs = vec![
793            user_message("hello"),
794            msg_with_usage,
795            user_message("follow-up"),
796        ];
797        let tokens = estimate_context_tokens(&msgs);
798        // usage: 100 + 50 + 20 = 170 + trailing "follow-up" (9/4=3) = 173
799        assert_eq!(tokens, 173);
800    }
801
802    // ── should_compact tests ────────────────────────────────────────
803
804    #[test]
805    fn test_should_compact_disabled() {
806        let settings = CompactionSettings {
807            enabled: false,
808            reserve_tokens: 16_384,
809            keep_recent_tokens: 20_000,
810        };
811        assert!(!should_compact(999_999, 1_000_000, &settings));
812    }
813
814    #[test]
815    fn test_should_compact_under_threshold() {
816        let settings = CompactionSettings::default();
817        assert!(!should_compact(100_000, 200_000, &settings));
818    }
819
820    #[test]
821    fn test_should_compact_at_threshold() {
822        let settings = CompactionSettings {
823            reserve_tokens: 10_000,
824            keep_recent_tokens: 20_000,
825            ..Default::default()
826        };
827        // context_tokens > context_window - reserve = 190_000
828        assert!(should_compact(190_001, 200_000, &settings));
829        assert!(!should_compact(190_000, 200_000, &settings));
830    }
831
832    #[test]
833    fn test_should_compact_exact_boundary() {
834        let settings = CompactionSettings {
835            enabled: true,
836            reserve_tokens: 0,
837            keep_recent_tokens: 0,
838        };
839        assert!(!should_compact(200_000, 200_000, &settings));
840        assert!(should_compact(200_001, 200_000, &settings));
841    }
842
843    // ── find_valid_cut_points (via prepare_compaction) ──────────────
844
845    /// Build a minimal session entry list for compaction testing.
846    fn make_msg_entry(content: &str) -> SessionEntry {
847        SessionEntry::Message(MessageEntry {
848            id: uuid::Uuid::new_v4().to_string(),
849            parent_id: None,
850            timestamp: String::new(),
851            message: user_message(content),
852        })
853    }
854
855    fn make_asst_entry(content: &str) -> SessionEntry {
856        SessionEntry::Message(MessageEntry {
857            id: uuid::Uuid::new_v4().to_string(),
858            parent_id: None,
859            timestamp: String::new(),
860            message: assistant_message(content),
861        })
862    }
863
864    fn make_compaction_entry(first_kept_id: &str) -> SessionEntry {
865        SessionEntry::Compaction(CompactionEntry {
866            id: uuid::Uuid::new_v4().to_string(),
867            parent_id: None,
868            timestamp: String::new(),
869            summary: "previous summary".into(),
870            first_kept_entry_id: first_kept_id.to_string(),
871            tokens_before: 1000,
872            details: None,
873            from_hook: None,
874        })
875    }
876
877    #[test]
878    fn test_prepare_compaction_empty_entries() {
879        let settings = CompactionSettings::default();
880        assert!(prepare_compaction(&[], &settings).is_none());
881    }
882
883    #[test]
884    fn test_prepare_compaction_last_entry_is_compaction() {
885        let entries = vec![make_msg_entry("hello"), make_compaction_entry("some-id")];
886        let settings = CompactionSettings::default();
887        assert!(prepare_compaction(&entries, &settings).is_none());
888    }
889
890    #[test]
891    fn test_prepare_compaction_returns_preparation() {
892        // Create enough entries that keep_recent_tokens forces a cut
893        let mut entries: Vec<SessionEntry> = (0..10)
894            .map(|i| {
895                make_msg_entry(&format!(
896                    "message {} with enough text to accumulate tokens",
897                    i
898                ))
899            })
900            .collect();
901        // Add some assistant messages too
902        for i in 0..5 {
903            entries.push(make_asst_entry(&format!("response {} with enough text", i)));
904        }
905
906        let settings = CompactionSettings {
907            enabled: true,
908            reserve_tokens: 100_000,
909            keep_recent_tokens: 2, // very small, will cut early
910        };
911        let result = prepare_compaction(&entries, &settings);
912        assert!(result.is_some(), "should return preparation");
913        let prep = result.unwrap();
914        assert!(!prep.messages_to_summarize.is_empty());
915        assert!(!prep.first_kept_entry_id.is_empty());
916        assert!(prep.tokens_before > 0);
917    }
918
919    #[test]
920    fn test_prepare_compaction_with_previous_compaction() {
921        let mut entries: Vec<SessionEntry> = vec![make_msg_entry("old message")];
922
923        // First compaction entry
924        let first_id = entries[0].id().to_string();
925        entries.push(make_compaction_entry(&first_id));
926
927        // New messages after compaction
928        entries.push(make_msg_entry("new message"));
929        entries.push(make_asst_entry("new response"));
930
931        let settings = CompactionSettings {
932            enabled: true,
933            reserve_tokens: 100_000,
934            keep_recent_tokens: 1,
935        };
936        let result = prepare_compaction(&entries, &settings);
937        assert!(result.is_some(), "should compact new messages");
938        let prep = result.unwrap();
939        assert!(prep.previous_summary.is_some());
940        assert_eq!(prep.previous_summary.as_deref(), Some("previous summary"));
941    }
942
943    // ── extract_file_ops tests ──────────────────────────────────────
944
945    fn make_asst_with_tool_call(name: &str, path: &str) -> AgentMessage {
946        AgentMessage::Llm(Message::Assistant {
947            content: vec![
948                Content::Text {
949                    text: "using tool".into(),
950                },
951                Content::ToolCall {
952                    id: "call-1".into(),
953                    name: name.into(),
954                    arguments: serde_json::json!({"path": path}),
955                    provider_metadata: None,
956                },
957            ],
958            stop_reason: yoagent::types::StopReason::ToolUse,
959            model: String::new(),
960            provider: String::new(),
961            usage: yoagent::types::Usage::default(),
962            timestamp: 0,
963            error_message: None,
964        })
965    }
966
967    #[test]
968    fn test_extract_file_ops_empty() {
969        assert!(extract_file_ops(&[]).is_none());
970    }
971
972    #[test]
973    fn test_extract_file_ops_no_tools() {
974        let msgs = vec![user_message("hello"), assistant_message("hi")];
975        assert!(extract_file_ops(&msgs).is_none());
976    }
977
978    #[test]
979    fn test_extract_file_ops_read_and_write() {
980        let msgs = vec![
981            make_asst_with_tool_call("read", "/tmp/a.txt"),
982            make_asst_with_tool_call("read", "/tmp/b.txt"),
983            make_asst_with_tool_call("write", "/tmp/a.txt"),
984        ];
985        let result = extract_file_ops(&msgs).unwrap();
986        let obj = result.as_object().unwrap();
987        let read: Vec<String> = serde_json::from_value(obj["readFiles"].clone()).unwrap();
988        let modified: Vec<String> = serde_json::from_value(obj["modifiedFiles"].clone()).unwrap();
989        // a.txt is both read and modified -> goes only in modified
990        assert_eq!(read, vec!["/tmp/b.txt".to_string()]);
991        assert_eq!(modified, vec!["/tmp/a.txt".to_string()]);
992    }
993
994    #[test]
995    fn test_extract_file_ops_deduplicates() {
996        let msgs = vec![
997            make_asst_with_tool_call("read", "/tmp/x.txt"),
998            make_asst_with_tool_call("read", "/tmp/x.txt"),
999        ];
1000        let result = extract_file_ops(&msgs).unwrap();
1001        let obj = result.as_object().unwrap();
1002        let read: Vec<String> = serde_json::from_value(obj["readFiles"].clone()).unwrap();
1003        assert_eq!(read.len(), 1);
1004    }
1005
1006    // ── format_message_for_summary tests ────────────────────────────
1007
1008    #[test]
1009    fn test_format_user_message() {
1010        let msg = user_message("hello world");
1011        let formatted = format_message_for_summary(&msg);
1012        assert!(formatted.contains("<User>"));
1013        assert!(formatted.contains("hello world"));
1014        assert!(formatted.contains("</User>"));
1015    }
1016
1017    #[test]
1018    fn test_format_assistant_message_with_tool_calls() {
1019        let msg = make_asst_with_tool_call("edit", "/tmp/f.py");
1020        let formatted = format_message_for_summary(&msg);
1021        assert!(formatted.contains("<Assistant>"));
1022        assert!(formatted.contains("using tool"));
1023        assert!(formatted.contains("Tool calls"));
1024        assert!(formatted.contains("edit"));
1025    }
1026
1027    #[test]
1028    fn test_format_tool_result_message() {
1029        let msg = tool_result_message("call-1", "bash", "command output", false);
1030        let formatted = format_message_for_summary(&msg);
1031        assert!(formatted.contains("Tool Result"));
1032        assert!(formatted.contains("command output"));
1033    }
1034}
1035
1036// ── Summarization helper (shared with branch_summary) ──
1037
1038/// Call yoagent's provider for a simple text completion (no tools, no streaming).
1039///
1040/// Uses the provided `model_config` (base URL, compat flags, etc.) and `thinking_level`
1041/// instead of hardcoded values. When `model_config` is None, falls back to the default
1042/// OpenCode Go endpoint for backward compatibility.
1043pub async fn summarize_text(
1044    api_key: &str,
1045    model: &str,
1046    system_prompt: &str,
1047    messages: &[AgentMessage],
1048    thinking_level: yoagent::types::ThinkingLevel,
1049    model_config: Option<yoagent::provider::model::ModelConfig>,
1050) -> Result<String, String> {
1051    use yoagent::provider::StreamProvider;
1052    use yoagent::provider::traits::StreamConfig;
1053
1054    let yoagent_messages: Vec<yoagent::types::Message> = messages
1055        .iter()
1056        .filter_map(|m| match m {
1057            AgentMessage::Llm(msg) => Some(msg.clone()),
1058            AgentMessage::Extension(_) => None,
1059        })
1060        .collect();
1061
1062    // Use provided model config, or fall back to hardcoded OpenCode Go for backward compat
1063    let model_config = model_config.unwrap_or_else(|| crate::agent::base_model_config(model));
1064
1065    let retry_config = yoagent::RetryConfig::default();
1066
1067    for attempt in 0..=retry_config.max_retries {
1068        let config = StreamConfig {
1069            model: model.to_string(),
1070            system_prompt: system_prompt.to_string(),
1071            messages: yoagent_messages.clone(),
1072            tools: vec![],
1073            thinking_level,
1074            api_key: api_key.to_string(),
1075            max_tokens: Some(2048),
1076            temperature: Some(0.3),
1077            model_config: Some(model_config.clone()),
1078            cache_config: yoagent::types::CacheConfig::default(),
1079        };
1080
1081        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1082        let cancel = tokio_util::sync::CancellationToken::new();
1083
1084        tokio::spawn(async move {
1085            let _ = yoagent::provider::OpenAiCompatProvider
1086                .stream(config, tx, cancel)
1087                .await;
1088        });
1089
1090        let mut text = String::new();
1091        let mut last_error: Option<String> = None;
1092
1093        while let Some(event) = rx.recv().await {
1094            match event {
1095                yoagent::provider::traits::StreamEvent::TextDelta { delta, .. } => {
1096                    text.push_str(&delta);
1097                }
1098                yoagent::provider::traits::StreamEvent::Done { message } => {
1099                    if let yoagent::types::Message::Assistant { content, .. } = &message {
1100                        for c in content {
1101                            if let yoagent::types::Content::Text { text: t } = c
1102                                && text.is_empty()
1103                            {
1104                                text = t.clone();
1105                            }
1106                        }
1107                    }
1108                    break;
1109                }
1110                yoagent::provider::traits::StreamEvent::Error { .. } => {
1111                    last_error = Some("Provider returned error".to_string());
1112                    break;
1113                }
1114                _ => {}
1115            }
1116        }
1117
1118        if let Some(err) = last_error {
1119            if attempt < retry_config.max_retries {
1120                let delay = retry_config.delay_for_attempt(attempt + 1);
1121                tokio::time::sleep(delay).await;
1122                continue;
1123            }
1124            return Err(err);
1125        }
1126        return Ok(text);
1127    }
1128
1129    unreachable!()
1130}