Skip to main content

rab/agent/
agent_session.rs

1use crate::agent::branch_summary::{collect_entries_for_branch_summary, generate_branch_summary};
2use crate::agent::compaction::{
3    self, CompactionReason, CompactionResult, CompactionSettings, compact, prepare_compaction,
4};
5use crate::agent::extension::Extension;
6use crate::agent::session::SessionManager;
7use crate::agent::types::{message_text, user_message};
8use std::sync::Arc;
9
10use crate::provider::ProviderRegistry;
11use yoagent::types::AgentMessage;
12
13// ── Compaction lifecycle events ─────────────────────────────────────
14
/// Events emitted during the compaction lifecycle.
/// Matches pi's `compaction_start` / `compaction_end` event semantics.
///
/// Listeners registered through `AgentSession::on_compaction_event` receive
/// these by shared reference.
#[derive(Debug, Clone)]
pub enum CompactionEvent {
    /// Compaction has started with the given reason.
    Start { reason: CompactionReason },
    /// A compaction run has ended. This is emitted both when the run completed
    /// and when it did not (inspect `aborted` / `error_message`).
    End {
        /// Why the compaction run was triggered.
        reason: CompactionReason,
        /// Outcome of the run. When the run did not complete this is a
        /// placeholder with an empty summary and `estimated_tokens_after == 0`.
        result: CompactionResult,
        /// `true` when the run was cancelled before producing a summary.
        aborted: bool,
        /// `true` when the caller intends to retry the agent turn afterwards.
        will_retry: bool,
        /// Human-readable explanation when the run did not complete.
        error_message: Option<String>,
    },
}
30
/// Callback for compaction lifecycle events.
///
/// The callback is boxed so listeners of different concrete types can be stored
/// together, and receives each [`CompactionEvent`] by shared reference.
pub type CompactionEventCallback = Box<dyn Fn(&CompactionEvent) + Send + Sync>;
33
/// Bridges the agent loop events and session persistence.
///
/// Handles:
/// - Event-driven message persistence (persist tool results as they arrive)
/// - Automatic model/thinking/tool change detection and persistence
/// - Context compaction (threshold, overflow and manual) with lifecycle events
/// - Branch summarization when the session leaf is moved
pub struct AgentSession {
    /// The session manager (owns Session + flush logic).
    mgr: SessionManager,
    /// Last known model for change detection, stored as `(provider, model_id)`.
    last_model: Option<(String, String)>,
    /// Last known thinking level for change detection.
    /// An empty string means no level has been recorded yet.
    last_thinking_level: String,
    /// Last known active tool names for change detection.
    last_active_tools: Option<Vec<String>>,
    /// Compaction settings (default: enabled).
    compaction_settings: CompactionSettings,
    /// Model context window in tokens (for shouldCompact check).
    context_window: u64,
    /// Model name to use for compaction LLM calls.
    /// While this is empty, compaction and branch summarization are skipped.
    model_name: String,
    /// API key for compaction LLM calls.
    /// While this is `None`, compaction and branch summarization are skipped.
    compaction_api_key: Option<String>,
    /// Model configuration for compaction LLM calls (base URL, compat flags, etc.).
    model_config: Option<yoagent::provider::model::ModelConfig>,
    /// Current thinking level from the session (for compaction summarization).
    /// Refreshed by `sync_thinking_level`.
    thinking_level: yoagent::types::ThinkingLevel,
    /// Registered extensions (for compaction hooks).
    extensions: Vec<Box<dyn Extension>>,
    /// Lifecycle event listeners.
    event_listeners: Vec<CompactionEventCallback>,
    /// Whether overflow recovery has already been attempted (prevents loops).
    overflow_recovery_attempted: bool,
    /// Cancellation token for in-progress compaction (pi-compatible abort).
    /// Replaced with a fresh token at the start of every compaction run.
    compaction_cancel: crate::agent::extension::Cancel,
    /// Provider registry for resolving model cost configs per message (pi-style).
    registry: Option<Arc<ProviderRegistry>>,
}
71
72impl AgentSession {
73    /// Create a new AgentSession from a SessionManager (pi-compatible: keeps the manager).
74    pub fn new(mgr: SessionManager) -> Self {
75        // Snapshot current metadata from the session context for change detection.
76        let ctx = mgr.session().build_context();
77
78        // If the session has no thinking level change entries, set last_thinking_level
79        // to empty so the first on_thinking_level_change always detects a change.
80        let has_thinking_entries = !mgr
81            .session()
82            .find_entries("thinking_level_change")
83            .is_empty();
84        let last_thinking_level = if has_thinking_entries {
85            ctx.thinking_level
86        } else {
87            String::new()
88        };
89
90        Self {
91            mgr,
92            last_model: ctx.model,
93            last_thinking_level,
94            last_active_tools: ctx.active_tool_names,
95            compaction_settings: CompactionSettings::default(),
96            context_window: 200_000,
97            model_name: String::new(),
98            compaction_api_key: None,
99            model_config: None,
100            thinking_level: yoagent::types::ThinkingLevel::Off,
101            extensions: Vec::new(),
102            event_listeners: Vec::new(),
103            overflow_recovery_attempted: false,
104            compaction_cancel: crate::agent::extension::Cancel::new(),
105            registry: None,
106        }
107    }
108
109    // ── Static factory methods ─────────────────────────────────
110
111    /// Create a new persisted session.
112    pub fn create(cwd: &std::path::Path, session_dir: Option<&std::path::Path>) -> Self {
113        Self::new(SessionManager::create(cwd, session_dir))
114    }
115
116    /// Open a specific session file.
117    pub fn open(
118        path: &std::path::Path,
119        session_dir: Option<&std::path::Path>,
120        cwd_override: Option<&std::path::Path>,
121    ) -> Self {
122        Self::new(SessionManager::open(path, session_dir, cwd_override))
123    }
124
125    /// Create an in-memory session (no persistence).
126    pub fn in_memory(cwd: &std::path::Path) -> Self {
127        Self::new(SessionManager::in_memory(cwd))
128    }
129
130    /// Continue most recent session or create new.
131    pub fn continue_recent(cwd: &std::path::Path, session_dir: Option<&std::path::Path>) -> Self {
132        Self::new(SessionManager::continue_recent(cwd, session_dir))
133    }
134
135    /// Fork a session from another project directory.
136    pub fn fork_from(
137        source_path: &std::path::Path,
138        target_cwd: &std::path::Path,
139        session_dir: Option<&std::path::Path>,
140        options: Option<&crate::agent::session::NewSessionOptions>,
141    ) -> std::io::Result<Self> {
142        SessionManager::fork_from(source_path, target_cwd, session_dir, options).map(Self::new)
143    }
144
145    /// Configure compaction with API key, model, context window, and model config.
146    pub fn set_compaction_config(
147        &mut self,
148        api_key: String,
149        model_name: &str,
150        context_window: u64,
151        model_config: Option<yoagent::provider::model::ModelConfig>,
152    ) {
153        self.compaction_api_key = Some(api_key);
154        self.model_name = model_name.to_string();
155        self.context_window = context_window;
156        self.model_config = model_config;
157    }
158
159    /// Enable or disable auto-compaction.
160    pub fn set_auto_compact(&mut self, enabled: bool) {
161        self.compaction_settings.enabled = enabled;
162    }
163
164    /// Set the provider registry for per-message cost computation (pi-style).
165    pub fn set_registry(&mut self, registry: Arc<ProviderRegistry>) {
166        self.registry = Some(registry);
167    }
168
169    /// Sync the thinking level from the session context.
170    /// Should be called after the session context changes.
171    pub fn sync_thinking_level(&mut self) {
172        let ctx = self.mgr.session().build_context();
173        let level_str = ctx.thinking_level.to_lowercase();
174        self.thinking_level = match level_str.as_str() {
175            "off" => yoagent::types::ThinkingLevel::Off,
176            "minimal" => yoagent::types::ThinkingLevel::Minimal,
177            "low" => yoagent::types::ThinkingLevel::Low,
178            "medium" => yoagent::types::ThinkingLevel::Medium,
179            "high" => yoagent::types::ThinkingLevel::High,
180            _ => yoagent::types::ThinkingLevel::Off,
181        };
182    }
183
184    /// Get the current compaction settings (mutable, for modification).
185    pub fn compaction_settings_mut(&mut self) -> &mut CompactionSettings {
186        &mut self.compaction_settings
187    }
188
189    /// Get the current compaction settings.
190    pub fn compaction_settings(&self) -> &CompactionSettings {
191        &self.compaction_settings
192    }
193
194    /// Set the list of extensions (for compaction hooks).
195    pub fn set_extensions(&mut self, extensions: Vec<Box<dyn Extension>>) {
196        self.extensions = extensions;
197    }
198
199    /// Abort any in-progress compaction (matching pi's `abortCompaction()`).
200    /// The cancellation will be picked up by extension hooks on their next
201    /// `cancel.is_cancelled()` check.
202    pub fn abort_compaction(&self) {
203        self.compaction_cancel.cancel();
204    }
205
206    /// Register a compaction lifecycle event listener.
207    pub fn on_compaction_event(&mut self, callback: CompactionEventCallback) {
208        self.event_listeners.push(callback);
209    }
210
211    /// Emit a compaction event to all registered listeners.
212    fn emit_compaction_event(&self, event: &CompactionEvent) {
213        for listener in &self.event_listeners {
214            listener(event);
215        }
216    }
217
218    /// Reset overflow recovery state (called when starting a new turn).
219    /// Pi-compatible: reset overflow recovery when a user message arrives
220    /// (matches pi's _overflowRecoveryAttempted reset in message_start for user role).
221    pub fn reset_overflow_recovery(&mut self) {
222        self.overflow_recovery_attempted = false;
223        self.compaction_cancel = crate::agent::extension::Cancel::new();
224    }
225
226    /// Check if a provider error indicates context overflow.
227    /// Matches pi's context overflow detection patterns.
228    pub fn is_context_overflow_error(msg: &AgentMessage) -> bool {
229        let text = message_text(msg);
230        let lower = text.to_lowercase();
231        // Pi-compatible: detect HTTP 413, "prompt too long", "context_length_exceeded", etc.
232        lower.contains("413")
233            || lower.contains("request_too_large")
234            || lower.contains("prompt too long")
235            || lower.contains("context_length_exceeded")
236            || lower.contains("context overflow")
237            || lower.contains("max context length")
238            || lower.contains("exceeded max tokens")
239            || lower.contains("maximum context length")
240    }
241
242    // ── Accessors ─────────────────────────────────────────────────
243
244    /// Borrow the underlying session manager.
245    /// Borrow the underlying Session.
246    pub fn session(&self) -> &crate::agent::session::Session {
247        self.mgr.session()
248    }
249
250    /// Mutably borrow the underlying Session.
251    pub fn session_mut(&mut self) -> &mut crate::agent::session::Session {
252        self.mgr.session_mut()
253    }
254
255    /// Consume and return the inner Session.
256    pub fn into_session(self) -> crate::agent::session::Session {
257        self.mgr.into_session()
258    }
259
260    /// Flush is handled automatically by `SessionManager` on every `append_message`.
261    /// Call this to force an early flush (e.g. before saving state externally).
262    pub fn ensure_flushed(&mut self) {
263        self.mgr.ensure_flushed();
264    }
265
266    // ── App-level accessors ────────────────────────────────────
267
268    pub fn cwd(&self) -> &std::path::Path {
269        self.mgr.cwd()
270    }
271
272    pub fn session_dir(&self) -> &std::path::Path {
273        self.mgr.session_dir()
274    }
275
276    pub fn is_persisted(&self) -> bool {
277        self.mgr.is_persisted()
278    }
279
280    pub fn session_id(&self) -> String {
281        self.mgr.session().session_id()
282    }
283
284    pub fn session_file(&self) -> Option<std::path::PathBuf> {
285        self.mgr.session().session_file()
286    }
287
288    pub fn session_name(&self) -> Option<String> {
289        self.mgr.session().session_name()
290    }
291
292    // ── Model / thinking / tool change tracking ─────────────────
293
294    /// Persist a model change if it differs from the last known model.
295    /// Pi-compatible: writes immediately to the session.
296    pub fn on_model_change(&mut self, provider: &str, model_id: &str) -> bool {
297        let new = (provider.to_string(), model_id.to_string());
298        if self.last_model.as_ref() != Some(&new) {
299            self.mgr
300                .session_mut()
301                .append_model_change(provider, model_id);
302            self.last_model = Some(new);
303            true
304        } else {
305            false
306        }
307    }
308
309    /// Persist a thinking level change if it differs from the last known level.
310    /// Pi-compatible: writes immediately to the session.
311    pub fn on_thinking_level_change(&mut self, level: &str) -> bool {
312        if self.last_thinking_level != level {
313            self.mgr.session_mut().append_thinking_level_change(level);
314            self.last_thinking_level = level.to_string();
315            true
316        } else {
317            false
318        }
319    }
320
321    /// Persist an active tools change if it differs from the last known set.
322    /// Pi-compatible: writes immediately to the session.
323    pub fn on_active_tools_change(&mut self, tools: &[String]) -> bool {
324        let tools_vec = tools.to_vec();
325        if self.last_active_tools.as_ref() != Some(&tools_vec) {
326            self.mgr
327                .session_mut()
328                .append_active_tools_change(&tools_vec);
329            self.last_active_tools = Some(tools_vec);
330            true
331        } else {
332            false
333        }
334    }
335
336    // ── User message submission ───────────────────────────────────
337
338    /// Reset the session (creates a new empty session) and clear
339    /// all tracked state so the new session starts fresh.
340    pub fn new_session(&mut self) {
341        self.mgr.new_session(None);
342        self.last_model = None;
343        self.last_thinking_level = String::new();
344        self.last_active_tools = None;
345        self.compaction_cancel = crate::agent::extension::Cancel::new();
346    }
347
348    /// Append a user message to the session (pi-compatible: persists immediately).
349    /// Returns the entry id.
350    pub fn send_user_message(&mut self, content: &str) -> String {
351        let msg = user_message(content);
352        self.mgr.append_message(&msg)
353    }
354
355    /// Append a user message (pre-constructed) to the session.
356    /// Returns the entry id.
357    pub fn send_user_message_obj(&mut self, msg: &AgentMessage) -> String {
358        self.mgr.append_message(msg)
359    }
360
361    // ── Event-driven persistence ──────────────────────────────────
362
363    /// Process an agent event for automatic persistence (pi-compatible).
364    ///
365    /// Pi persists every message (user, assistant, tool result, custom) immediately
366    /// on `message_end`, not deferred to `agent_end`. Extension messages use
367    /// `custom_message` entries (excluded from LLM context); all others use regular
368    /// `message` entries.
369    ///
370    /// Call this from your agent event handler.
371    pub fn on_agent_event(&mut self, event: &yoagent::types::AgentEvent) {
372        // Pi-compatible: persist every message immediately on message_end
373        if let yoagent::types::AgentEvent::MessageEnd { message } = event {
374            // Pi-compatible: reset overflow recovery when a user message arrives
375            // (matches pi's _overflowRecoveryAttempted reset in message_start for user role).
376            if crate::agent::types::message_is_user(message) {
377                self.reset_overflow_recovery();
378            }
379            // Pi-compatible: persist every message immediately on message_end.
380            // Extension messages use custom_message entries (excluded from LLM context);
381            // all others use regular messages.
382            if crate::agent::types::message_is_extension(message) {
383                self.persist_extension_message(message);
384            } else {
385                self.mgr.append_message(message);
386            }
387        }
388    }
389
390    // ── Compaction ────────────────────────────────────────────────
391
392    /// Check if compaction should run and execute it if needed.
393    /// Should be called after the agent finishes a turn (after on_agent_end).
394    /// Returns `true` if compaction was performed.
395    pub async fn check_auto_compact(&mut self) -> Result<bool, String> {
396        Ok(self
397            ._run_compaction(CompactionReason::Threshold, None, false)
398            .await?
399            .is_some())
400    }
401
402    /// Run compaction after a context overflow error.
403    /// If `will_retry` is true, the agent turn will be retried after compaction.
404    /// Returns `Ok(true)` if compaction was performed, `Ok(false)` if recovery already attempted.
405    pub async fn check_overflow_compact(&mut self, will_retry: bool) -> Result<bool, String> {
406        if self.overflow_recovery_attempted {
407            return Ok(false);
408        }
409        self.overflow_recovery_attempted = true;
410        Ok(self
411            ._run_compaction(CompactionReason::Overflow, None, will_retry)
412            .await?
413            .is_some())
414    }
415
416    /// Run compaction manually (ignores auto-compact setting).
417    /// Returns the compaction summary text, or an error message.
418    pub async fn run_manual_compact(
419        &mut self,
420        custom_instructions: Option<&str>,
421    ) -> Result<String, String> {
422        let result = self
423            ._run_compaction(CompactionReason::Manual, custom_instructions, false)
424            .await?;
425        Ok(result.map(|r| r.summary).unwrap_or_default())
426    }
427
428    /// Internal: run compaction with the given reason, emitting lifecycle events.
429    /// Returns the CompactionResult if compaction was performed, or None if skipped.
430    async fn _run_compaction(
431        &mut self,
432        reason: CompactionReason,
433        custom_instructions: Option<&str>,
434        will_retry: bool,
435    ) -> Result<Option<CompactionResult>, String> {
436        // For threshold compaction, check if auto-compact is enabled
437        if reason == CompactionReason::Threshold && !self.compaction_settings.enabled {
438            return Ok(None);
439        }
440
441        if self.compaction_api_key.is_none() || self.model_name.is_empty() {
442            return Ok(None);
443        }
444
445        // Create a fresh cancellation token for this compaction run
446        // (pi-compatible: matches AbortController per compaction call)
447        self.compaction_cancel = crate::agent::extension::Cancel::new();
448        let cancel = self.compaction_cancel.clone();
449
450        // Emit compaction_start
451        self.emit_compaction_event(&CompactionEvent::Start { reason });
452
453        // Check for cancellation before proceeding
454        if cancel.is_cancelled() {
455            return Ok(None);
456        }
457
458        let entries = self.mgr.get_entries();
459
460        // Check threshold for auto-compact
461        if reason == CompactionReason::Threshold {
462            let context_msgs = self.mgr.session().build_context().messages;
463            let context_tokens = compaction::estimate_context_tokens(&context_msgs);
464            if !compaction::should_compact(
465                context_tokens,
466                self.context_window,
467                &self.compaction_settings,
468            ) {
469                return Ok(None);
470            }
471        }
472
473        let Some(prep) = prepare_compaction(&entries, &self.compaction_settings) else {
474            return Ok(None);
475        };
476
477        // Extension hooks: before_compact
478        let mut from_hook = false;
479        let mut hook_summary: Option<String> = None;
480        let mut hook_details: Option<serde_json::Value> = None;
481
482        for ext in &self.extensions {
483            if cancel.is_cancelled() {
484                break;
485            }
486            if let Some(result) = ext.before_compact(
487                &prep.first_kept_entry_id,
488                prep.tokens_before,
489                &reason.to_string(),
490                &cancel,
491            ) {
492                if result.cancel {
493                    self.emit_compaction_event(&CompactionEvent::End {
494                        reason,
495                        aborted: true,
496                        will_retry: false,
497                        error_message: Some("Compaction cancelled by extension".to_string()),
498                        result: CompactionResult {
499                            summary: String::new(),
500                            first_kept_entry_id: prep.first_kept_entry_id.clone(),
501                            tokens_before: prep.tokens_before,
502                            estimated_tokens_after: 0,
503                            details: None,
504                        },
505                    });
506                    return Ok(None);
507                }
508                if result.summary.is_some() {
509                    hook_summary = result.summary;
510                    hook_details = result.details;
511                    from_hook = true;
512                    break;
513                }
514            }
515        }
516
517        let result = if let Some(summary) = hook_summary {
518            // Extension provided custom summary
519            CompactionResult {
520                summary,
521                first_kept_entry_id: prep.first_kept_entry_id.clone(),
522                tokens_before: prep.tokens_before,
523                estimated_tokens_after: 0, // will be computed after append
524                details: hook_details,
525            }
526        } else {
527            // Call provider for summarization
528            let api_key = self.compaction_api_key.as_ref().unwrap();
529            compact(
530                &prep,
531                api_key,
532                &self.model_name,
533                custom_instructions,
534                self.thinking_level,
535                self.model_config.clone(),
536            )
537            .await?
538        };
539
540        // Append the compaction entry to the session
541        self.mgr.session_mut().append_compaction(
542            &result.summary,
543            &result.first_kept_entry_id,
544            result.tokens_before,
545            result.details.clone(),
546            Some(from_hook),
547        );
548
549        // Compute estimated tokens after compaction
550        let context_after = self.mgr.session().build_context().messages;
551        let estimated_tokens_after = compaction::estimate_context_tokens(&context_after);
552
553        let final_result = CompactionResult {
554            estimated_tokens_after,
555            ..result
556        };
557
558        // Extension hooks: after_compact
559        for ext in &self.extensions {
560            if cancel.is_cancelled() {
561                break;
562            }
563            ext.after_compact(
564                &final_result.summary,
565                &final_result.first_kept_entry_id,
566                final_result.tokens_before,
567                final_result.estimated_tokens_after,
568                from_hook,
569                &reason.to_string(),
570                &cancel,
571            );
572        }
573
574        // Emit compaction_end
575        self.emit_compaction_event(&CompactionEvent::End {
576            reason,
577            result: final_result.clone(),
578            aborted: false,
579            will_retry,
580            error_message: None,
581        });
582
583        Ok(Some(final_result))
584    }
585
586    // ── Branch summarization ───────────────────────────────────────
587
588    /// Summarise the abandoned branch when navigating to a different node.
589    ///
590    /// Collects entries between `old_leaf_id` and the common ancestor with
591    /// `target_id`, summarises them via the provider, and appends a
592    /// `BranchSummaryEntry` to the session.
593    ///
594    /// Returns the summary text, or an error message.
595    pub async fn summarize_branch_navigation(
596        &mut self,
597        old_leaf_id: Option<&str>,
598        target_id: &str,
599    ) -> Result<String, String> {
600        if self.compaction_api_key.is_none() || self.model_name.is_empty() {
601            return Err("No provider configured for summarization".to_string());
602        }
603
604        let (entries, _common_ancestor) =
605            collect_entries_for_branch_summary(self.session(), old_leaf_id, target_id);
606
607        if entries.is_empty() {
608            return Err("No abandoned entries to summarize".to_string());
609        }
610
611        let api_key = self.compaction_api_key.as_ref().unwrap();
612        generate_branch_summary(
613            self.mgr.session_mut(),
614            &entries,
615            target_id,
616            api_key,
617            &self.model_name,
618            self.thinking_level,
619            self.model_config.clone(),
620        )
621        .await
622    }
623
624    /// Move the leaf pointer to an earlier entry (starts a new branch).
625    /// Optionally summarizes the abandoned path if a provider is configured.
626    /// Returns the branch summary text if summarization was performed.
627    pub async fn set_branch(&mut self, branch_from_id: &str) -> Result<Option<String>, String> {
628        let old_leaf = self.mgr.session().get_leaf_id();
629
630        let summary = if self.compaction_api_key.is_some()
631            && !self.model_name.is_empty()
632            && let Some(ref old) = old_leaf
633            && old != branch_from_id
634        {
635            // Summarize the abandoned path
636            match self
637                .summarize_branch_navigation(Some(old), branch_from_id)
638                .await
639            {
640                Ok(s) => Some(s),
641                Err(e) => {
642                    // Non-fatal: still allow the branch move
643                    eprintln!("Warning: branch summarization failed: {}", e);
644                    None
645                }
646            }
647        } else {
648            None
649        };
650
651        self.mgr
652            .session_mut()
653            .set_leaf_id(Some(branch_from_id))
654            .map_err(|e| format!("Failed to set branch: {}", e))?;
655
656        Ok(summary)
657    }
658
659    /// Persist a tool result message (public so the agent loop can persist crash-safely).
660    /// Deduplicates by tool_call_id.
661    /// Persist an Extension message as a `custom_message` session entry (pi-compatible).
662    /// Extension messages are NOT persisted as regular messages — they use the
663    /// `custom_message` entry type which supports `custom_type`, `display`, and `details`.
664    pub fn persist_extension_message(&mut self, msg: &AgentMessage) {
665        let Some(kind) = crate::agent::types::message_extension_kind(msg) else {
666            return;
667        };
668        let text = crate::agent::types::message_extension_text(msg)
669            .unwrap_or_else(|| crate::agent::types::message_text(msg));
670        let content = serde_json::json!({"text": text});
671        self.mgr
672            .session_mut()
673            .append_custom_message_entry(kind, content, true, None);
674    }
675}