Skip to main content

imp_core/
imp_session.rs

1//! High-level session API for driving imp programmatically.
2//!
3//! `ImpSession` is the primary public interface for embedding imp in other
4//! Rust programs, building custom UIs, or driving agents from orchestrators.
5//! It wires together config, auth, model resolution, agent construction,
6//! session persistence, and the event stream — eliminating the boilerplate
7//! that each run mode (interactive, print, headless, RPC) otherwise
8//! duplicates.
9//!
10//! # Example
11//!
12//! ```no_run
13//! use imp_core::imp_session::{ImpSession, SessionOptions, SessionChoice};
14//!
15//! # async fn example() -> imp_core::Result<()> {
16//! let mut session = ImpSession::create(SessionOptions {
17//!     cwd: std::env::current_dir()?,
18//!     ..Default::default()
19//! }).await?;
20//!
21//! session.prompt("What files are in the current directory?").await?;
22//!
23//! while let Some(event) = session.recv_event().await {
24//!     println!("{event:?}");
25//! }
26//! # Ok(())
27//! # }
28//! ```
29
30use std::collections::VecDeque;
31use std::path::PathBuf;
32use std::sync::Arc;
33
34use tokio::sync::mpsc;
35use tokio::task::JoinHandle;
36
37use imp_llm::auth::{ApiKey, AuthStore};
38use imp_llm::model::{ModelMeta, ModelRegistry};
39use imp_llm::providers::create_provider;
40use imp_llm::{Model, ThinkingLevel};
41
42use crate::agent::{Agent, AgentCommand, AgentEvent, AgentHandle};
43use crate::builder::AgentBuilder;
44use crate::config::{AgentMode, Config};
45use crate::error::{Error, Result};
46use crate::session::{SessionCheckpointRecord, SessionEntry, SessionManager};
47use crate::storage;
48use crate::system_prompt::{Fact, TaskContext};
49use crate::ui::UserInterface;
50
51// ── Options ─────────────────────────────────────────────────────
52
/// How to initialize the session file.
///
/// Passed through `SessionOptions::session`; the derived default is
/// [`SessionChoice::New`].
#[derive(Debug, Clone, Default)]
pub enum SessionChoice {
    /// Fresh session, persisted to disk.
    #[default]
    New,
    /// No persistence.
    InMemory,
    /// Continue the most recent session for the working directory.
    /// `ImpSession::create` starts a fresh persisted session instead when
    /// no recent session is found.
    Continue,
    /// Open a specific session file.
    Open(PathBuf),
}
66
67use crate::tools::LuaToolLoader;
68
/// Configuration for creating an `ImpSession`.
///
/// All fields have sensible defaults — only `cwd` is typically required.
/// Use struct-update syntax (`..Default::default()`) to fill in the rest.
pub struct SessionOptions {
    /// Working directory. Tools resolve paths relative to this.
    /// Also used to locate project config and the session to continue.
    pub cwd: PathBuf,

    /// Prebuilt model override for deterministic tests or embedded callers.
    /// When set, ImpSession skips runtime model/provider/auth resolution,
    /// so `model`, `provider`, and `api_key` are not consulted and the agent
    /// is built with an empty API key.
    pub model_override: Option<Model>,

    /// Model hint — alias ("sonnet") or full ID. Resolved against the
    /// model registry. Falls back to config, then "sonnet".
    pub model: Option<String>,

    /// Provider override. Usually auto-detected from the model.
    pub provider: Option<String>,

    /// Runtime API key override (not persisted).
    /// Registered on the auth store as a runtime key for the resolved provider.
    pub api_key: Option<String>,

    /// Thinking level override. Replaces the config's thinking level.
    pub thinking: Option<ThinkingLevel>,

    /// Agent mode (full, worker, orchestrator, …). Replaces the config's mode.
    pub mode: Option<AgentMode>,

    /// Maximum turns before the agent stops.
    pub max_turns: Option<u32>,

    /// Max output tokens per response.
    pub max_tokens: Option<u32>,

    /// Replace the assembled system prompt entirely.
    pub system_prompt: Option<String>,

    /// Run without tools: every tool is removed from the agent after it is built.
    pub no_tools: bool,

    /// Session persistence strategy.
    pub session: SessionChoice,

    /// Task context for headless / unit mode.
    pub task: Option<TaskContext>,

    /// Task-specific facts to inject into the system prompt.
    /// An empty list is not forwarded to the agent builder.
    pub facts: Vec<Fact>,

    /// Lua extension loader. Called after native tools are registered.
    /// The binary crate typically provides this; library callers can
    /// pass `None` to skip Lua extensions.
    pub lua_loader: Option<LuaToolLoader>,

    /// Custom UI implementation. Defaults to `NullInterface`.
    pub ui: Option<Arc<dyn UserInterface>>,

    /// Path to auth.json. Defaults to `~/.config/imp/auth.json`.
    /// NOTE(review): the actual default comes from the `storage` module's
    /// global auth path helpers — confirm the documented location still matches.
    pub auth_path: Option<PathBuf>,

    /// Pre-assembled context messages injected before the first prompt.
    /// Built by `context_prefill::assemble_context()` at dispatch time.
    /// The agent starts with these files already in its cached prefix.
    pub context_prefill: Vec<imp_llm::Message>,
}
133
134impl Default for SessionOptions {
135    fn default() -> Self {
136        Self {
137            cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
138            model_override: None,
139            model: None,
140            provider: None,
141            api_key: None,
142            thinking: None,
143            mode: None,
144            max_turns: None,
145            max_tokens: None,
146            system_prompt: None,
147            no_tools: false,
148            session: SessionChoice::default(),
149            task: None,
150            facts: Vec::new(),
151            lua_loader: None,
152            ui: None,
153            auth_path: None,
154            context_prefill: Vec::new(),
155        }
156    }
157}
158
/// Caller-supplied inputs for `resolve_runtime_connection`.
///
/// Borrows its strings from the caller, so it is cheap to build per call.
#[derive(Debug, Clone)]
pub struct RuntimeConnectionIntent<'a> {
    /// Explicit model hint (alias or full ID); takes precedence over `config_model`.
    pub model_hint: Option<&'a str>,
    /// Model from config, used when no explicit hint was given.
    pub config_model: Option<&'a str>,
    /// Explicit provider; when set it is used instead of the model's own provider.
    pub provider_override: Option<&'a str>,
    /// Whether the caller supplied a runtime API key. When `true`, the
    /// OAuth-based route substitutions are not applied.
    pub api_key_override_present: bool,
}
166
/// Outcome of `resolve_runtime_connection`: which model to run and through
/// which provider route.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedRuntimeConnection {
    /// Resolved model ID (may differ from the hint, e.g. after alias
    /// resolution or an OAuth route substitution).
    pub model_id: String,
    /// Name of the provider route to use for this model.
    pub provider_name: String,
}
172
173/// Resolve the model-first runtime connection (model id + provider route/surface)
174/// shared by CLI and session startup.
175pub fn resolve_runtime_connection(
176    intent: RuntimeConnectionIntent<'_>,
177    auth_store: &AuthStore,
178    registry: &ModelRegistry,
179) -> std::result::Result<ResolvedRuntimeConnection, String> {
180    let model_hint = intent
181        .model_hint
182        .or(intent.config_model)
183        .unwrap_or("sonnet");
184
185    let meta = registry
186        .resolve_meta(model_hint, intent.provider_override)
187        .ok_or_else(|| format!("Unknown model: {model_hint}"))?;
188
189    let provider_name = intent
190        .provider_override
191        .unwrap_or(&meta.provider)
192        .to_string();
193
194    if let Some(oauth_route) = auth_preferred_oauth_route(
195        intent.provider_override,
196        intent.api_key_override_present,
197        auth_store,
198        registry,
199        &meta,
200        &provider_name,
201    ) {
202        return Ok(oauth_route);
203    }
204
205    Ok(ResolvedRuntimeConnection {
206        model_id: meta.id.clone(),
207        provider_name,
208    })
209}
210
211// ── ImpSession ──────────────────────────────────────────────────
212
/// A fully wired agent session.
///
/// Manages the lifecycle of a single agent: config resolution, model
/// selection, session persistence, and the event/command channels.
pub struct ImpSession {
    /// The agent while idle. `None` while a run is in flight: `prompt()` moves
    /// the agent into the background task, which hands it back on completion.
    agent: Option<Agent>,
    /// Command sender and event receiver connected to the agent.
    handle: AgentHandle,
    /// Session history and persistence.
    session_mgr: SessionManager,
    /// Resolved config, including overrides applied at creation.
    config: Config,
    /// The model currently selected for this session.
    model: Model,
    /// Credential store used to resolve API keys.
    auth_store: AuthStore,
    /// Registry used to resolve model hints.
    model_registry: ModelRegistry,
    /// Working directory the session was created with.
    cwd: PathBuf,
    /// Task handle for the currently running agent loop, if any.
    agent_task: Option<JoinHandle<(Agent, Result<()>)>>,
    /// Outcome of the most recent run, held until `wait()` consumes it.
    completed_run_result: Option<Result<()>>,
    /// Persistence warnings queued for delivery as `AgentEvent::Error`
    /// through `recv_event()`.
    pending_persistence_errors: VecDeque<String>,
    /// Context prefill messages, injected once before the first prompt.
    context_prefill: Vec<imp_llm::Message>,
    /// Set once the prefill has been injected so it is not repeated.
    context_prefill_injected: bool,
}
234
235impl ImpSession {
236    /// Create a new session by resolving config, auth, model, and tools.
237    ///
238    /// This is the main factory — mirrors pi's `createAgentSession()`.
239    pub async fn create(options: SessionOptions) -> Result<Self> {
240        let cwd = options.cwd.clone();
241
242        let _ = storage::reconcile_legacy_into_global_root();
243
244        // 1. Load config (user + project, merged)
245        let mut config = Config::resolve(&Config::user_config_dir(), Some(&cwd))?;
246
247        // Apply option overrides
248        if let Some(thinking) = options.thinking {
249            config.thinking = Some(thinking);
250        }
251        if let Some(mode) = options.mode {
252            config.mode = mode;
253        }
254
255        // 2. Resolve auth
256        let auth_path = options
257            .auth_path
258            .clone()
259            .or_else(storage::existing_global_auth_path)
260            .unwrap_or_else(storage::global_auth_path);
261        let mut auth_store =
262            AuthStore::load(&auth_path).unwrap_or_else(|_| AuthStore::new(auth_path));
263
264        if let Some(ref key) = options.api_key {
265            // We'll set this after we know the provider name
266            // Store it temporarily
267            let _ = key; // handled below
268        }
269
270        // 3. Resolve model + provider route
271        let model_registry = ModelRegistry::with_builtins();
272        let (model, _provider_name, api_key) = if let Some(model) = options.model_override.as_ref()
273        {
274            (
275                clone_model(model),
276                model.meta.provider.clone(),
277                String::new(),
278            )
279        } else {
280            let runtime_connection = resolve_runtime_connection(
281                RuntimeConnectionIntent {
282                    model_hint: options.model.as_deref(),
283                    config_model: config.model.as_deref(),
284                    provider_override: options.provider.as_deref(),
285                    api_key_override_present: options.api_key.is_some(),
286                },
287                &auth_store,
288                &model_registry,
289            )
290            .map_err(Error::Config)?;
291
292            let meta = model_registry
293                .resolve_meta(
294                    &runtime_connection.model_id,
295                    Some(&runtime_connection.provider_name),
296                )
297                .ok_or_else(|| {
298                    Error::Config(format!(
299                        "Unknown model/provider route: {} via {}",
300                        runtime_connection.model_id, runtime_connection.provider_name
301                    ))
302                })?;
303
304            let provider_name = runtime_connection.provider_name.clone();
305
306            if let Some(ref key) = options.api_key {
307                auth_store.set_runtime_key(&provider_name, key.clone());
308            }
309
310            let provider = create_provider(&provider_name)
311                .ok_or_else(|| Error::Config(format!("Unknown provider: {provider_name}")))?;
312
313            let api_key = resolve_api_key(&mut auth_store, &provider_name).await?;
314            (
315                Model {
316                    meta,
317                    provider: Arc::from(provider),
318                },
319                provider_name,
320                api_key,
321            )
322        };
323
324        // 5. Build agent
325        let mut builder =
326            AgentBuilder::new(config.clone(), cwd.clone(), clone_model(&model), api_key);
327
328        if let Some(task) = &options.task {
329            builder = builder.task(task.clone());
330        }
331        if !options.facts.is_empty() {
332            builder = builder.facts(options.facts.clone());
333        }
334        if let Some(prompt) = &options.system_prompt {
335            builder = builder.system_prompt(prompt.clone());
336        }
337        if let Some(lua_loader) = options.lua_loader {
338            builder = builder.lua_tool_loader(move |policy, tools| lua_loader(policy, tools));
339        }
340
341        let (mut agent, handle) = builder.build()?;
342
343        if options.no_tools {
344            agent.tools.retain(|_| false);
345        }
346
347        if options.no_tools {
348            agent.thinking_level = config.thinking.unwrap_or(ThinkingLevel::Off);
349            if let Some(max_turns) = options.max_turns.or(config.max_turns) {
350                agent.max_turns = max_turns;
351            }
352            if let Some(max_tokens) = options.max_tokens.or(config.max_tokens) {
353                agent.max_tokens = Some(max_tokens);
354            }
355        } else {
356            if let Some(max_turns) = options.max_turns {
357                agent.max_turns = max_turns;
358            }
359            if let Some(max_tokens) = options.max_tokens {
360                agent.max_tokens = Some(max_tokens);
361            }
362        }
363        if let Some(ui) = &options.ui {
364            agent.ui = Arc::clone(ui);
365        }
366
367        // 6. Set up session persistence
368        let session_dir = storage::global_sessions_dir();
369        let session_mgr = match options.session {
370            SessionChoice::New => SessionManager::new(&cwd, &session_dir)?,
371            SessionChoice::InMemory => SessionManager::in_memory(),
372            SessionChoice::Continue => SessionManager::continue_recent(&cwd, &session_dir)?
373                .unwrap_or_else(|| SessionManager::new(&cwd, &session_dir).unwrap()),
374            SessionChoice::Open(ref path) => SessionManager::open(path)?,
375        };
376
377        Ok(Self {
378            agent: Some(agent),
379            handle,
380            session_mgr,
381            config,
382            model,
383            auth_store,
384            model_registry,
385            cwd,
386            context_prefill: options.context_prefill,
387            context_prefill_injected: false,
388            agent_task: None,
389            completed_run_result: None,
390            pending_persistence_errors: VecDeque::new(),
391        })
392    }
393
394    // ── Prompting ───────────────────────────────────────────────
395
396    /// Send a prompt and run the agent loop.
397    ///
398    /// The agent runs on a background task. Use [`recv_event`] to consume
399    /// events, and [`steer`] / [`follow_up`] / [`cancel`] to control it.
400    ///
401    /// Returns an error if the agent is already running.
402    pub async fn prompt(&mut self, text: &str) -> Result<()> {
403        if self.agent_task.is_some() {
404            return Err(Error::Config(
405                "Agent is already running. Cancel or wait for it to finish.".into(),
406            ));
407        }
408
409        self.completed_run_result = None;
410        self.pending_persistence_errors.clear();
411
412        // Persist user message to session
413        let msg_id = uuid::Uuid::new_v4().to_string();
414        let _ = self.session_mgr.append(SessionEntry::Message {
415            id: msg_id,
416            parent_id: None,
417            message: imp_llm::Message::user(text),
418        });
419
420        // Load prior messages from session history into agent
421        let mut agent = self
422            .agent
423            .take()
424            .ok_or_else(|| Error::Config("Agent already consumed".into()))?;
425
426        let mut history: Vec<imp_llm::Message> = self.session_mgr.get_active_messages();
427
428        // The prompt was already appended to session history so resume/tree state
429        // is correct, but Agent::run() will push the active prompt itself. Remove
430        // the just-appended trailing user message to avoid duplicating it in the
431        // model context for this run.
432        if matches!(
433            history.last(),
434            Some(imp_llm::Message::User(user))
435                if matches!(
436                    user.content.as_slice(),
437                    [imp_llm::ContentBlock::Text { text: last_text }] if last_text == text
438                )
439        ) {
440            history.pop();
441        }
442
443        // Inject context prefill (once, before the first prompt). These messages
444        // form the cached prefix: file contents the agent needs, assembled at
445        // dispatch time by context_prefill::assemble_context(). Subsequent turns
446        // get cache_read on this prefix instead of re-reading files.
447        if !self.context_prefill_injected && !self.context_prefill.is_empty() {
448            for msg in &self.context_prefill {
449                history.push(msg.clone());
450            }
451            // Assistant acknowledgment to maintain user/assistant alternation
452            history.push(imp_llm::Message::Assistant(imp_llm::AssistantMessage {
453                content: vec![imp_llm::ContentBlock::Text {
454                    text: "Context loaded. Ready to work.".into(),
455                }],
456                usage: None,
457                stop_reason: imp_llm::StopReason::EndTurn,
458                timestamp: imp_llm::now(),
459            }));
460            self.context_prefill_injected = true;
461        }
462
463        // Replace agent messages with session history. Agent::run() will append
464        // the active prompt as the next user message.
465        agent.messages = history;
466
467        let prompt = text.to_string();
468        let task = tokio::spawn(async move {
469            let result = agent.run(prompt).await;
470            (agent, result)
471        });
472        self.agent_task = Some(task);
473
474        Ok(())
475    }
476
    /// Send a prompt and block until the agent finishes.
    ///
    /// Equivalent to calling `prompt()` followed by `wait()`.
    /// Events are still emitted via [`recv_event`], but this method
    /// does not return until the agent loop completes.
    ///
    /// # Errors
    ///
    /// Returns the error from starting the run, or otherwise the result of
    /// the run itself.
    pub async fn prompt_and_wait(&mut self, text: &str) -> Result<()> {
        self.prompt(text).await?;
        self.wait().await
    }
485
486    /// Wait for the running agent to finish.
487    pub async fn wait(&mut self) -> Result<()> {
488        if let Some(task) = self.agent_task.take() {
489            let (agent, result) = task
490                .await
491                .map_err(|e| Error::Config(format!("Agent task panicked: {e}")))?;
492            self.agent = Some(agent);
493            self.completed_run_result = Some(result);
494            self.drain_pending_events_for_persistence();
495        }
496
497        if let Some(result) = self.completed_run_result.take() {
498            return result;
499        }
500
501        Ok(())
502    }
503
504    /// Interrupt the agent: delivered after the current tool finishes,
505    /// remaining queued tools are skipped.
506    pub async fn steer(&self, text: &str) -> Result<()> {
507        self.handle
508            .command_tx
509            .send(AgentCommand::Steer(text.into()))
510            .await
511            .map_err(|_| Error::Config("Agent not running".into()))
512    }
513
514    /// Follow-up: delivered only after the agent finishes all current work.
515    pub async fn follow_up(&self, text: &str) -> Result<()> {
516        self.handle
517            .command_tx
518            .send(AgentCommand::FollowUp(text.into()))
519            .await
520            .map_err(|_| Error::Config("Agent not running".into()))
521    }
522
523    /// Cancel the current agent run.
524    pub async fn cancel(&self) -> Result<()> {
525        self.handle
526            .command_tx
527            .send(AgentCommand::Cancel)
528            .await
529            .map_err(|_| Error::Config("Agent not running".into()))
530    }
531
    /// Force-abort the current agent task when graceful cancellation does not finish.
    ///
    /// Does nothing when no run is in flight. Otherwise the background task
    /// is aborted and the run result is recorded as `Error::Cancelled`, which
    /// the next `wait()` returns.
    ///
    /// The agent was moved into the aborted task and is not handed back, so
    /// later `prompt()` calls on this session fail with
    /// "Agent already consumed".
    pub fn abort(&mut self) {
        if let Some(task) = self.agent_task.take() {
            task.abort();
            self.completed_run_result = Some(Err(Error::Cancelled));
        }
    }
539
540    // ── Events ──────────────────────────────────────────────────
541
542    /// Receive the next event from the agent.
543    ///
544    /// Returns `None` when the agent has finished and all events have
545    /// been consumed.
546    pub async fn recv_event(&mut self) -> Option<AgentEvent> {
547        if let Some(error) = self.take_persistence_error() {
548            return Some(AgentEvent::Error { error });
549        }
550
551        if self.agent_task.is_none() && self.completed_run_result.is_some() {
552            return None;
553        }
554
555        let event = self.handle.event_rx.recv().await?;
556        let events = self.persist_event_entries(&event);
557
558        if matches!(event, AgentEvent::AgentEnd { .. }) {
559            if let Some(task) = self.agent_task.take() {
560                match task.await {
561                    Ok((agent, result)) => {
562                        self.agent = Some(agent);
563                        self.completed_run_result = Some(result);
564                    }
565                    Err(join_error) => {
566                        self.push_persistence_error(
567                            events,
568                            format!("agent task panicked: {join_error}"),
569                        );
570                    }
571                }
572            }
573        }
574
575        Some(event)
576    }
577
    /// Get mutable access to the raw event receiver.
    ///
    /// Use this when you need `select!` or other channel combinators.
    ///
    /// Events read directly from the receiver bypass the bookkeeping that
    /// `recv_event()` performs: they are not persisted to the session, and
    /// the finished background task is not joined on `AgentEnd`.
    pub fn event_rx(&mut self) -> &mut mpsc::Receiver<AgentEvent> {
        &mut self.handle.event_rx
    }
584
585    // ── Model ───────────────────────────────────────────────────
586
587    /// Switch the model for subsequent prompts.
588    ///
589    /// The change takes effect on the next `prompt()` call.
590    pub async fn set_model(&mut self, hint: &str) -> Result<()> {
591        let meta = self
592            .model_registry
593            .resolve_meta(hint, None)
594            .ok_or_else(|| Error::Config(format!("Unknown model: {hint}")))?;
595
596        let provider_name = meta.provider.clone();
597        let provider = create_provider(&provider_name)
598            .ok_or_else(|| Error::Config(format!("Unknown provider: {provider_name}")))?;
599        let api_key = resolve_api_key(&mut self.auth_store, &provider_name).await?;
600
601        self.model = Model {
602            meta,
603            provider: Arc::from(provider),
604        };
605
606        // If we still have the agent (not currently running), update it
607        if let Some(ref mut agent) = self.agent {
608            agent.model = clone_model(&self.model);
609            agent.api_key = api_key;
610        }
611
612        Ok(())
613    }
614
615    /// Set the thinking level for subsequent prompts.
616    pub fn set_thinking(&mut self, level: ThinkingLevel) {
617        self.config.thinking = Some(level);
618        if let Some(ref mut agent) = self.agent {
619            agent.thinking_level = level;
620        }
621    }
622
623    // ── Accessors ───────────────────────────────────────────────
624
    /// The current model, as selected at creation or by the last `set_model()`.
    pub fn model(&self) -> &Model {
        &self.model
    }

    /// The resolved config, including overrides applied at creation and by
    /// `set_thinking()`.
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// The session manager (tree, entries, persistence).
    pub fn session_manager(&self) -> &SessionManager {
        &self.session_mgr
    }

    /// Mutable access to the session manager.
    pub fn session_manager_mut(&mut self) -> &mut SessionManager {
        &mut self.session_mgr
    }

    /// The working directory the session was created with.
    pub fn cwd(&self) -> &PathBuf {
        &self.cwd
    }

    /// The auth store (for checking credentials, OAuth status, etc).
    pub fn auth_store(&self) -> &AuthStore {
        &self.auth_store
    }

    /// Mutable access to the auth store.
    pub fn auth_store_mut(&mut self) -> &mut AuthStore {
        &mut self.auth_store
    }

    /// The model registry used to resolve model hints.
    pub fn model_registry(&self) -> &ModelRegistry {
        &self.model_registry
    }

    /// Whether the agent is currently running a prompt.
    ///
    /// True from `prompt()` until the background task has been joined by
    /// `wait()` or `recv_event()`, or removed by `abort()`.
    pub fn is_running(&self) -> bool {
        self.agent_task.is_some()
    }

    /// Get the raw command sender for advanced use cases.
    pub fn command_tx(&self) -> &mpsc::Sender<AgentCommand> {
        &self.handle.command_tx
    }
674
675    fn persist_event_entries(&mut self, event: &AgentEvent) -> Vec<&'static str> {
676        let persisted = match self
677            .session_mgr
678            .persist_agent_event_entries(&self.model, event)
679        {
680            Ok(persisted) => persisted,
681            Err(error) => {
682                self.push_persistence_error(
683                    Vec::new(),
684                    format!("failed to persist agent event entries: {error}"),
685                );
686                Vec::new()
687            }
688        };
689
690        if let Some(agent) = self.agent.as_ref() {
691            if let Err(error) =
692                persist_checkpoint_records(&mut self.session_mgr, &agent.checkpoint_state)
693            {
694                self.push_persistence_error(
695                    persisted.clone(),
696                    format!("failed to persist checkpoint records: {error}"),
697                );
698            }
699        }
700
701        persisted
702    }
703
704    fn drain_pending_events_for_persistence(&mut self) {
705        while let Ok(event) = self.handle.event_rx.try_recv() {
706            self.persist_event_entries(&event);
707        }
708    }
709
710    fn push_persistence_error(&mut self, persisted: Vec<&'static str>, error: String) {
711        let prefix = if persisted.is_empty() {
712            "session persistence warning".to_string()
713        } else {
714            format!("session persistence warning after {}", persisted.join(", "))
715        };
716        self.pending_persistence_errors
717            .push_back(format!("{prefix}: {error}"));
718    }
719
    /// Remove and return the oldest queued persistence warning, if any.
    fn take_persistence_error(&mut self) -> Option<String> {
        self.pending_persistence_errors.pop_front()
    }
723}
724// ── Helpers ─────────────────────────────────────────────────────
725
726/// Resolve the API key for a provider, handling OAuth refresh.
727async fn resolve_api_key(auth_store: &mut AuthStore, provider: &str) -> Result<ApiKey> {
728    let result = match provider {
729        "openai-codex" => auth_store.resolve_chatgpt_oauth().await,
730        "anthropic" | "kimi-code" => auth_store.resolve_with_refresh(provider).await,
731        _ => auth_store.resolve(provider),
732    };
733    result.map_err(|e| Error::Config(format!("Auth failed for {provider}: {e}")))
734}
735
736fn auth_preferred_oauth_route(
737    provider_override: Option<&str>,
738    api_key_override_present: bool,
739    auth_store: &AuthStore,
740    registry: &ModelRegistry,
741    meta: &ModelMeta,
742    provider_name: &str,
743) -> Option<ResolvedRuntimeConnection> {
744    if should_use_openai_chatgpt_route(
745        provider_override,
746        api_key_override_present,
747        auth_store,
748        registry,
749        &meta.id,
750        provider_name,
751    ) {
752        return Some(ResolvedRuntimeConnection {
753            model_id: meta.id.clone(),
754            provider_name: "openai-codex".to_string(),
755        });
756    }
757
758    if should_use_kimi_code_route(
759        provider_override,
760        api_key_override_present,
761        auth_store,
762        registry,
763        meta,
764        provider_name,
765    ) {
766        return Some(ResolvedRuntimeConnection {
767            model_id: "kimi2.6".to_string(),
768            provider_name: "kimi-code".to_string(),
769        });
770    }
771
772    None
773}
774fn should_use_openai_chatgpt_route(
775    provider_override: Option<&str>,
776    api_key_override_present: bool,
777    auth_store: &AuthStore,
778    registry: &ModelRegistry,
779    model_id: &str,
780    provider_name: &str,
781) -> bool {
782    let provider_allows_fallback = match provider_override {
783        None => true,
784        Some("openai") => true,
785        Some(_) => false,
786    };
787
788    provider_allows_fallback
789        && !api_key_override_present
790        && provider_name == "openai"
791        && auth_store.resolve_api_key_only("openai").is_err()
792        && (auth_store.get_oauth("openai").is_some()
793            || auth_store.get_oauth("openai-codex").is_some())
794        && codex_supports_model(registry, model_id)
795}
796
797fn should_use_kimi_code_route(
798    provider_override: Option<&str>,
799    api_key_override_present: bool,
800    auth_store: &AuthStore,
801    registry: &ModelRegistry,
802    meta: &ModelMeta,
803    provider_name: &str,
804) -> bool {
805    let provider_allows_fallback = match provider_override {
806        None => true,
807        Some("moonshot") => true,
808        Some("kimi-code") => true,
809        Some(_) => false,
810    };
811
812    provider_allows_fallback
813        && !api_key_override_present
814        && provider_name == "moonshot"
815        && auth_store.resolve_api_key_only("moonshot").is_err()
816        && auth_store.get_oauth("kimi-code").is_some()
817        && registry.find("kimi2.6").is_some()
818        && is_kimi_moonshot_model(&meta.id)
819}
820
/// Whether `model_id` is one of the Kimi model IDs eligible for the
/// `kimi-code` route. The comparison is exact and case-sensitive.
fn is_kimi_moonshot_model(model_id: &str) -> bool {
    const KIMI_MOONSHOT_MODEL_IDS: [&str; 6] = [
        "kimi-k2.6",
        "kimi-k2.5",
        "kimi-k2-0905-preview",
        "kimi-k2-turbo-preview",
        "kimi-k2-thinking",
        "kimi-k2-thinking-turbo",
    ];

    KIMI_MOONSHOT_MODEL_IDS.contains(&model_id)
}
832fn clone_model(model: &Model) -> Model {
833    Model {
834        meta: model.meta.clone(),
835        provider: Arc::clone(&model.provider),
836    }
837}
838
839fn persist_checkpoint_records(
840    session_mgr: &mut SessionManager,
841    checkpoint_state: &crate::tools::CheckpointState,
842) -> Result<Vec<String>> {
843    let existing: std::collections::HashSet<String> = session_mgr
844        .checkpoint_records()
845        .into_iter()
846        .map(|record| record.checkpoint_id)
847        .collect();
848
849    let mut persisted = Vec::new();
850    for record in checkpoint_state.checkpoints() {
851        if existing.contains(&record.id) {
852            continue;
853        }
854        session_mgr.append_checkpoint_record(SessionCheckpointRecord {
855            version: crate::session::CHECKPOINT_RECORD_VERSION,
856            checkpoint_id: record.id.clone(),
857            created_at: record.created_at,
858            label: record.label.clone(),
859            files: record
860                .files
861                .iter()
862                .map(|path| path.to_string_lossy().to_string())
863                .collect(),
864        })?;
865        persisted.push(record.id);
866    }
867
868    Ok(persisted)
869}
870
871fn codex_supports_model(_registry: &ModelRegistry, model_id: &str) -> bool {
872    imp_llm::model::builtin_openai_codex_models()
873        .iter()
874        .any(|m| m.id == model_id)
875}
876
877#[cfg(test)]
878mod tests {
879    use super::*;
880    use imp_llm::{
881        auth::{ApiKey, AuthStore},
882        model::{Capabilities, ModelPricing},
883        provider::{Context, Provider, RequestOptions},
884        AssistantMessage, ContentBlock, ModelMeta, StopReason, StreamEvent, Usage,
885    };
886    use serde_json::json;
887    use tempfile::TempDir;
888
    /// Test provider that never produces any stream events.
    struct NoopProvider {
        /// Models reported by `Provider::models`.
        models: Vec<ModelMeta>,
    }
892
    /// Test provider that replays one scripted list of stream events.
    struct SingleResponseProvider {
        /// Models reported by `Provider::models`.
        models: Vec<ModelMeta>,
        /// Scripted events, taken by the first `stream` call; later calls
        /// find `None` and produce an empty stream.
        events: std::sync::Mutex<Option<Vec<imp_llm::Result<StreamEvent>>>>,
    }
897
898    #[async_trait::async_trait]
899    impl Provider for NoopProvider {
900        fn stream(
901            &self,
902            _model: &Model,
903            _context: Context,
904            _options: RequestOptions,
905            _api_key: &str,
906        ) -> std::pin::Pin<Box<dyn futures_core::Stream<Item = imp_llm::Result<StreamEvent>> + Send>>
907        {
908            Box::pin(futures::stream::empty())
909        }
910
911        async fn resolve_auth(&self, _auth: &AuthStore) -> imp_llm::Result<ApiKey> {
912            Ok(String::new())
913        }
914
915        fn id(&self) -> &str {
916            "noop"
917        }
918
919        fn models(&self) -> &[ModelMeta] {
920            &self.models
921        }
922    }
923
924    #[async_trait::async_trait]
925    impl Provider for SingleResponseProvider {
926        fn stream(
927            &self,
928            _model: &Model,
929            _context: Context,
930            _options: RequestOptions,
931            _api_key: &str,
932        ) -> std::pin::Pin<Box<dyn futures_core::Stream<Item = imp_llm::Result<StreamEvent>> + Send>>
933        {
934            let events = self
935                .events
936                .lock()
937                .expect("single response provider lock")
938                .take()
939                .unwrap_or_default();
940            Box::pin(futures::stream::iter(events))
941        }
942
943        async fn resolve_auth(&self, _auth: &AuthStore) -> imp_llm::Result<ApiKey> {
944            Ok(String::new())
945        }
946
947        fn id(&self) -> &str {
948            "single-response"
949        }
950
951        fn models(&self) -> &[ModelMeta] {
952            &self.models
953        }
954    }
955
956    fn test_model() -> Model {
957        let meta = ModelMeta {
958            id: "test-model".into(),
959            provider: "test-provider".into(),
960            name: "Test Model".into(),
961            context_window: 8192,
962            max_output_tokens: 2048,
963            pricing: ModelPricing {
964                input_per_mtok: 2.0,
965                output_per_mtok: 4.0,
966                cache_read_per_mtok: 0.5,
967                cache_write_per_mtok: 1.0,
968            },
969            capabilities: Capabilities {
970                reasoning: false,
971                images: false,
972                tool_use: true,
973            },
974        };
975        Model {
976            meta: meta.clone(),
977            provider: Arc::new(NoopProvider { models: vec![meta] }),
978        }
979    }
980
981    fn test_model_with_events(events: Vec<imp_llm::Result<StreamEvent>>) -> Model {
982        let meta = ModelMeta {
983            id: "test-model".into(),
984            provider: "test-provider".into(),
985            name: "Test Model".into(),
986            context_window: 8192,
987            max_output_tokens: 2048,
988            pricing: ModelPricing {
989                input_per_mtok: 2.0,
990                output_per_mtok: 4.0,
991                cache_read_per_mtok: 0.5,
992                cache_write_per_mtok: 1.0,
993            },
994            capabilities: Capabilities {
995                reasoning: false,
996                images: false,
997                tool_use: true,
998            },
999        };
1000        Model {
1001            meta: meta.clone(),
1002            provider: Arc::new(SingleResponseProvider {
1003                models: vec![meta],
1004                events: std::sync::Mutex::new(Some(events)),
1005            }),
1006        }
1007    }
1008
1009    fn test_assistant_message(timestamp: u64, usage: Option<Usage>) -> AssistantMessage {
1010        AssistantMessage {
1011            content: vec![ContentBlock::Text {
1012                text: "done".into(),
1013            }],
1014            usage,
1015            stop_reason: StopReason::EndTurn,
1016            timestamp,
1017        }
1018    }
1019
1020    #[test]
1021    fn session_options_default_is_sensible() {
1022        let opts = SessionOptions::default();
1023        assert!(opts.model.is_none());
1024        assert!(opts.max_tokens.is_none());
1025        assert!(!opts.no_tools);
1026        assert!(matches!(opts.session, SessionChoice::New));
1027    }
1028
1029    #[test]
1030    fn resolve_runtime_connection_prefers_openai_chatgpt_route_when_oauth_exists() {
1031        let dir = tempfile::tempdir().unwrap();
1032        let auth_path = dir.path().join("auth.json");
1033        let mut auth_store = AuthStore::new(auth_path);
1034        auth_store
1035            .store(
1036                "openai",
1037                imp_llm::auth::StoredCredential::OAuth(imp_llm::auth::OAuthCredential {
1038                    access_token: "oauth-token".into(),
1039                    refresh_token: "refresh-token".into(),
1040                    expires_at: imp_llm::now() + 3600,
1041                }),
1042            )
1043            .unwrap();
1044        let registry = ModelRegistry::with_builtins();
1045
1046        let resolved = resolve_runtime_connection(
1047            RuntimeConnectionIntent {
1048                model_hint: Some("gpt-5.4"),
1049                config_model: None,
1050                provider_override: Some("openai"),
1051                api_key_override_present: false,
1052            },
1053            &auth_store,
1054            &registry,
1055        )
1056        .unwrap();
1057
1058        assert_eq!(resolved.model_id, "gpt-5.4");
1059        assert_eq!(resolved.provider_name, "openai-codex");
1060    }
1061
1062    #[test]
1063    fn resolve_runtime_connection_respects_forced_non_openai_provider() {
1064        let auth_path = PathBuf::from("/tmp/nonexistent-auth.json");
1065        let auth_store = AuthStore::new(auth_path);
1066        let registry = ModelRegistry::with_builtins();
1067
1068        let resolved = resolve_runtime_connection(
1069            RuntimeConnectionIntent {
1070                model_hint: Some("gpt-5.4"),
1071                config_model: None,
1072                provider_override: Some("anthropic"),
1073                api_key_override_present: false,
1074            },
1075            &auth_store,
1076            &registry,
1077        )
1078        .unwrap();
1079
1080        assert_eq!(resolved.provider_name, "anthropic");
1081    }
1082
1083    #[test]
1084    fn resolve_runtime_connection_does_not_switch_when_model_is_not_codex_supported() {
1085        let dir = tempfile::tempdir().unwrap();
1086        let auth_path = dir.path().join("auth.json");
1087        let mut auth_store = AuthStore::new(auth_path);
1088        auth_store
1089            .store(
1090                "openai",
1091                imp_llm::auth::StoredCredential::OAuth(imp_llm::auth::OAuthCredential {
1092                    access_token: "oauth-token".into(),
1093                    refresh_token: "refresh-token".into(),
1094                    expires_at: imp_llm::now() + 3600,
1095                }),
1096            )
1097            .unwrap();
1098        let registry = ModelRegistry::with_builtins();
1099
1100        let resolved = resolve_runtime_connection(
1101            RuntimeConnectionIntent {
1102                model_hint: Some("gpt-4o"),
1103                config_model: None,
1104                provider_override: Some("openai"),
1105                api_key_override_present: false,
1106            },
1107            &auth_store,
1108            &registry,
1109        )
1110        .unwrap();
1111
1112        assert_eq!(resolved.model_id, "gpt-4o");
1113        assert_eq!(resolved.provider_name, "openai");
1114    }
1115
1116    #[test]
1117    fn resolve_runtime_connection_does_not_switch_when_api_key_override_is_present() {
1118        let dir = tempfile::tempdir().unwrap();
1119        let auth_path = dir.path().join("auth.json");
1120        let mut auth_store = AuthStore::new(auth_path);
1121        auth_store
1122            .store(
1123                "openai",
1124                imp_llm::auth::StoredCredential::OAuth(imp_llm::auth::OAuthCredential {
1125                    access_token: "oauth-token".into(),
1126                    refresh_token: "refresh-token".into(),
1127                    expires_at: imp_llm::now() + 3600,
1128                }),
1129            )
1130            .unwrap();
1131        let registry = ModelRegistry::with_builtins();
1132
1133        let resolved = resolve_runtime_connection(
1134            RuntimeConnectionIntent {
1135                model_hint: Some("gpt-5.4"),
1136                config_model: None,
1137                provider_override: None,
1138                api_key_override_present: true,
1139            },
1140            &auth_store,
1141            &registry,
1142        )
1143        .unwrap();
1144
1145        assert_eq!(resolved.model_id, "gpt-5.4");
1146        assert_eq!(resolved.provider_name, "openai");
1147    }
1148
1149    #[test]
1150    fn resolve_runtime_connection_prefers_kimi_code_route_when_oauth_exists_without_api_key() {
1151        let dir = tempfile::tempdir().unwrap();
1152        let auth_path = dir.path().join("auth.json");
1153        let mut auth_store = AuthStore::new(auth_path);
1154        auth_store
1155            .store(
1156                "kimi-code",
1157                imp_llm::auth::StoredCredential::OAuth(imp_llm::auth::OAuthCredential {
1158                    access_token: "oauth-token".into(),
1159                    refresh_token: "refresh-token".into(),
1160                    expires_at: imp_llm::now() + 3600,
1161                }),
1162            )
1163            .unwrap();
1164        let registry = ModelRegistry::with_builtins();
1165
1166        let resolved = resolve_runtime_connection(
1167            RuntimeConnectionIntent {
1168                model_hint: Some("kimi"),
1169                config_model: None,
1170                provider_override: None,
1171                api_key_override_present: false,
1172            },
1173            &auth_store,
1174            &registry,
1175        )
1176        .unwrap();
1177
1178        assert_eq!(resolved.model_id, "kimi2.6");
1179        assert_eq!(resolved.provider_name, "kimi-code");
1180    }
1181
1182    #[test]
1183    fn resolve_runtime_connection_keeps_moonshot_kimi_when_api_key_exists() {
1184        let dir = tempfile::tempdir().unwrap();
1185        let auth_path = dir.path().join("auth.json");
1186        let mut auth_store = AuthStore::new(auth_path);
1187        auth_store
1188            .store(
1189                "moonshot",
1190                imp_llm::auth::StoredCredential::ApiKey {
1191                    key: "sk-moonshot".into(),
1192                },
1193            )
1194            .unwrap();
1195        auth_store
1196            .store(
1197                "kimi-code",
1198                imp_llm::auth::StoredCredential::OAuth(imp_llm::auth::OAuthCredential {
1199                    access_token: "oauth-token".into(),
1200                    refresh_token: "refresh-token".into(),
1201                    expires_at: imp_llm::now() + 3600,
1202                }),
1203            )
1204            .unwrap();
1205        let registry = ModelRegistry::with_builtins();
1206
1207        let resolved = resolve_runtime_connection(
1208            RuntimeConnectionIntent {
1209                model_hint: Some("kimi"),
1210                config_model: None,
1211                provider_override: None,
1212                api_key_override_present: false,
1213            },
1214            &auth_store,
1215            &registry,
1216        )
1217        .unwrap();
1218
1219        assert_eq!(resolved.model_id, "kimi-k2.6");
1220        assert_eq!(resolved.provider_name, "moonshot");
1221    }
1222
1223    #[tokio::test]
1224    async fn no_tools_session_surfaces_auth_failure_instead_of_empty_api_key() {
1225        let tmp = TempDir::new().unwrap();
1226        let cwd = tmp.path().join("project");
1227        let auth_path = tmp.path().join("auth.json");
1228        std::fs::create_dir_all(&cwd).unwrap();
1229
1230        let result = ImpSession::create(SessionOptions {
1231            cwd: cwd.clone(),
1232            auth_path: Some(auth_path),
1233            provider: Some("openai-codex".into()),
1234            model: Some("gpt-5.4".into()),
1235            no_tools: true,
1236            session: SessionChoice::InMemory,
1237            ..Default::default()
1238        })
1239        .await;
1240
1241        match result {
1242            Ok(_) => panic!("missing auth should fail clearly"),
1243            Err(Error::Config(message)) => {
1244                assert!(message.contains("Auth failed for openai-codex"));
1245                assert!(!message.contains("Incorrect API key provided: ''"));
1246            }
1247            Err(other) => panic!("expected config error, got {other:?}"),
1248        }
1249    }
1250
1251    #[tokio::test]
1252    async fn no_tools_session_builds_assembled_system_prompt_when_task_present() {
1253        let tmp = TempDir::new().unwrap();
1254        let cwd = tmp.path().join("project");
1255        let auth_path = tmp.path().join("auth.json");
1256        std::fs::create_dir_all(&cwd).unwrap();
1257
1258        let mut auth_store = AuthStore::new(auth_path.clone());
1259        auth_store
1260            .store(
1261                "openai",
1262                imp_llm::auth::StoredCredential::OAuth(imp_llm::auth::OAuthCredential {
1263                    access_token: "oauth-token".into(),
1264                    refresh_token: "refresh-token".into(),
1265                    expires_at: imp_llm::now() + 3600,
1266                }),
1267            )
1268            .unwrap();
1269
1270        let session = ImpSession::create(SessionOptions {
1271            cwd: cwd.clone(),
1272            auth_path: Some(auth_path),
1273            provider: Some("openai".into()),
1274            model: Some("gpt-5.4".into()),
1275            no_tools: true,
1276            session: SessionChoice::InMemory,
1277            task: Some(TaskContext {
1278                title: "Test task".into(),
1279                description: "Verify headless prompt assembly".into(),
1280                acceptance: Some("Prompt includes task guidance".into()),
1281                verify: None,
1282                notes: None,
1283                attempts: vec![],
1284                dependencies: vec![],
1285                decisions: vec![],
1286                context_paths: vec![],
1287                constraints: vec![],
1288            }),
1289            ..Default::default()
1290        })
1291        .await
1292        .expect("no-tools session should build with saved auth");
1293
1294        let prompt = session
1295            .agent
1296            .as_ref()
1297            .expect("agent present")
1298            .system_prompt
1299            .clone();
1300        assert!(!prompt.trim().is_empty());
1301        assert!(prompt.contains("Test task"));
1302        assert!(prompt.contains("Verify headless prompt assembly"));
1303    }
1304
1305    #[tokio::test]
1306    async fn recv_event_returns_none_after_agent_end_even_if_sender_is_still_owned() {
1307        let tmp = TempDir::new().unwrap();
1308        let cwd = tmp.path().join("project");
1309        let (agent, handle) = Agent::new(
1310            clone_model(&test_model_with_events(vec![Ok(StreamEvent::MessageEnd {
1311                message: AssistantMessage {
1312                    content: vec![ContentBlock::Text {
1313                        text: "done".into(),
1314                    }],
1315                    usage: None,
1316                    stop_reason: StopReason::EndTurn,
1317                    timestamp: 1,
1318                },
1319            })])),
1320            cwd.clone(),
1321        );
1322
1323        let mut session = ImpSession {
1324            agent: Some(agent),
1325            handle,
1326            session_mgr: SessionManager::in_memory(),
1327            config: Config::default(),
1328            model: test_model_with_events(vec![Ok(StreamEvent::MessageEnd {
1329                message: AssistantMessage {
1330                    content: vec![ContentBlock::Text {
1331                        text: "done".into(),
1332                    }],
1333                    usage: None,
1334                    stop_reason: StopReason::EndTurn,
1335                    timestamp: 1,
1336                },
1337            })]),
1338            auth_store: AuthStore::new(tmp.path().join("auth.json")),
1339            model_registry: ModelRegistry::with_builtins(),
1340            cwd,
1341            agent_task: None,
1342            completed_run_result: None,
1343            pending_persistence_errors: VecDeque::new(),
1344            context_prefill: Vec::new(),
1345            context_prefill_injected: false,
1346        };
1347
1348        session.prompt("latest").await.unwrap();
1349        while let Some(event) = session.recv_event().await {
1350            if matches!(event, AgentEvent::AgentEnd { .. }) {
1351                break;
1352            }
1353        }
1354
1355        let next = tokio::time::timeout(std::time::Duration::from_secs(1), session.recv_event())
1356            .await
1357            .expect("recv_event should not hang after agent end");
1358        assert!(next.is_none());
1359
1360        session.wait().await.unwrap();
1361    }
1362
1363    #[tokio::test]
1364    async fn abort_marks_wait_as_cancelled() {
1365        let tmp = TempDir::new().unwrap();
1366        let cwd = tmp.path().join("project");
1367        let (agent, handle) = Agent::new(
1368            test_model_with_events(vec![Ok(StreamEvent::MessageEnd {
1369                message: AssistantMessage {
1370                    content: vec![ContentBlock::Text {
1371                        text: "done".into(),
1372                    }],
1373                    usage: None,
1374                    stop_reason: StopReason::EndTurn,
1375                    timestamp: 1,
1376                },
1377            })]),
1378            cwd.clone(),
1379        );
1380        let mut session = ImpSession {
1381            agent: Some(agent),
1382            handle,
1383            session_mgr: SessionManager::in_memory(),
1384            config: Config::default(),
1385            model: test_model_with_events(vec![Ok(StreamEvent::MessageEnd {
1386                message: AssistantMessage {
1387                    content: vec![ContentBlock::Text {
1388                        text: "done".into(),
1389                    }],
1390                    usage: None,
1391                    stop_reason: StopReason::EndTurn,
1392                    timestamp: 1,
1393                },
1394            })]),
1395            auth_store: AuthStore::new(tmp.path().join("auth.json")),
1396            model_registry: ModelRegistry::with_builtins(),
1397            cwd,
1398            agent_task: Some(tokio::spawn(async move {
1399                tokio::time::sleep(std::time::Duration::from_secs(60)).await;
1400                (
1401                    Agent::new(
1402                        test_model_with_events(vec![Ok(StreamEvent::MessageEnd {
1403                            message: AssistantMessage {
1404                                content: vec![ContentBlock::Text {
1405                                    text: "done".into(),
1406                                }],
1407                                usage: None,
1408                                stop_reason: StopReason::EndTurn,
1409                                timestamp: 1,
1410                            },
1411                        })]),
1412                        PathBuf::from("/tmp"),
1413                    )
1414                    .0,
1415                    Ok(()),
1416                )
1417            })),
1418            completed_run_result: None,
1419            pending_persistence_errors: VecDeque::new(),
1420            context_prefill: Vec::new(),
1421            context_prefill_injected: false,
1422        };
1423
1424        session.abort();
1425        let result = session.wait().await;
1426        assert!(matches!(result, Err(Error::Cancelled)));
1427    }
1428
1429    #[tokio::test]
1430    async fn prompt_uses_session_history_without_duplicate_active_prompt() {
1431        let tmp = TempDir::new().unwrap();
1432        let cwd = tmp.path().join("project");
1433        let session_dir = tmp.path().join("sessions");
1434        let model = test_model_with_events(vec![Ok(StreamEvent::MessageEnd {
1435            message: AssistantMessage {
1436                content: vec![ContentBlock::Text {
1437                    text: "done".into(),
1438                }],
1439                usage: None,
1440                stop_reason: StopReason::EndTurn,
1441                timestamp: 42,
1442            },
1443        })]);
1444        let mut session_mgr = SessionManager::new(&cwd, &session_dir).unwrap();
1445        session_mgr
1446            .append(SessionEntry::Message {
1447                id: "existing-user".into(),
1448                parent_id: None,
1449                message: imp_llm::Message::user("earlier"),
1450            })
1451            .unwrap();
1452
1453        let (agent, handle) = Agent::new(clone_model(&model), cwd.clone());
1454        let mut session = ImpSession {
1455            agent: Some(agent),
1456            handle,
1457            session_mgr,
1458            config: Config::default(),
1459            model,
1460            auth_store: AuthStore::new(tmp.path().join("auth.json")),
1461            model_registry: ModelRegistry::with_builtins(),
1462            cwd,
1463            agent_task: None,
1464            completed_run_result: None,
1465            pending_persistence_errors: VecDeque::new(),
1466            context_prefill: Vec::new(),
1467            context_prefill_injected: false,
1468        };
1469
1470        session.prompt("latest").await.unwrap();
1471        while let Some(event) = session.recv_event().await {
1472            if matches!(event, AgentEvent::AgentEnd { .. }) {
1473                break;
1474            }
1475        }
1476        session.wait().await.unwrap();
1477
1478        let messages: Vec<_> = session.session_mgr.get_active_messages();
1479        assert_eq!(messages.len(), 3);
1480        match &messages[0] {
1481            imp_llm::Message::User(user) => match user.content.as_slice() {
1482                [ContentBlock::Text { text }] => assert_eq!(text, "earlier"),
1483                other => panic!("unexpected user content: {other:?}"),
1484            },
1485            other => panic!("unexpected message: {other:?}"),
1486        }
1487        match &messages[1] {
1488            imp_llm::Message::User(user) => match user.content.as_slice() {
1489                [ContentBlock::Text { text }] => assert_eq!(text, "latest"),
1490                other => panic!("unexpected user content: {other:?}"),
1491            },
1492            other => panic!("unexpected message: {other:?}"),
1493        }
1494        match &messages[2] {
1495            imp_llm::Message::Assistant(assistant) => match assistant.content.as_slice() {
1496                [ContentBlock::Text { text }] => assert_eq!(text, "done"),
1497                other => panic!("unexpected assistant content: {other:?}"),
1498            },
1499            other => panic!("unexpected message: {other:?}"),
1500        }
1501    }
1502
1503    #[tokio::test]
1504    async fn prompt_uses_compacted_active_history_for_follow_up_turns() {
1505        let tmp = TempDir::new().unwrap();
1506        let cwd = tmp.path().join("project");
1507        let session_dir = tmp.path().join("sessions");
1508        let model = test_model_with_events(vec![Ok(StreamEvent::MessageEnd {
1509            message: AssistantMessage {
1510                content: vec![ContentBlock::Text {
1511                    text: "follow-up done".into(),
1512                }],
1513                usage: None,
1514                stop_reason: StopReason::EndTurn,
1515                timestamp: 99,
1516            },
1517        })]);
1518        let mut session_mgr = SessionManager::new(&cwd, &session_dir).unwrap();
1519        session_mgr
1520            .append(SessionEntry::Message {
1521                id: "u1".into(),
1522                parent_id: None,
1523                message: imp_llm::Message::user("older request"),
1524            })
1525            .unwrap();
1526        session_mgr
1527            .append(SessionEntry::Message {
1528                id: "a1".into(),
1529                parent_id: None,
1530                message: imp_llm::Message::Assistant(AssistantMessage {
1531                    content: vec![ContentBlock::Text {
1532                        text: "older answer".into(),
1533                    }],
1534                    usage: None,
1535                    stop_reason: StopReason::EndTurn,
1536                    timestamp: 1,
1537                }),
1538            })
1539            .unwrap();
1540        session_mgr
1541            .append(SessionEntry::Message {
1542                id: "u2".into(),
1543                parent_id: None,
1544                message: imp_llm::Message::user("recent request"),
1545            })
1546            .unwrap();
1547        session_mgr
1548            .append(SessionEntry::Compaction {
1549                id: "c1".into(),
1550                parent_id: None,
1551                summary: "[CONTEXT COMPACTION] compacted summary".into(),
1552                first_kept_id: "u2".into(),
1553                tokens_before: 100,
1554                tokens_after: 40,
1555            })
1556            .unwrap();
1557
1558        let (agent, handle) = Agent::new(clone_model(&model), cwd.clone());
1559        let mut session = ImpSession {
1560            agent: Some(agent),
1561            handle,
1562            session_mgr,
1563            config: Config::default(),
1564            model,
1565            auth_store: AuthStore::new(tmp.path().join("auth.json")),
1566            model_registry: ModelRegistry::with_builtins(),
1567            cwd,
1568            agent_task: None,
1569            completed_run_result: None,
1570            pending_persistence_errors: VecDeque::new(),
1571            context_prefill: Vec::new(),
1572            context_prefill_injected: false,
1573        };
1574
1575        session.prompt("new follow-up").await.unwrap();
1576        while let Some(event) = session.recv_event().await {
1577            if matches!(event, AgentEvent::AgentEnd { .. }) {
1578                break;
1579            }
1580        }
1581        session.wait().await.unwrap();
1582
1583        let messages = session.session_mgr.get_active_messages();
1584        assert_eq!(messages.len(), 4);
1585        match &messages[0] {
1586            imp_llm::Message::User(user) => match user.content.as_slice() {
1587                [ContentBlock::Text { text }] => assert!(text.contains("CONTEXT COMPACTION")),
1588                other => panic!("unexpected summary content: {other:?}"),
1589            },
1590            other => panic!("unexpected message: {other:?}"),
1591        }
1592        match &messages[1] {
1593            imp_llm::Message::User(user) => match user.content.as_slice() {
1594                [ContentBlock::Text { text }] => assert_eq!(text, "recent request"),
1595                other => panic!("unexpected recent user content: {other:?}"),
1596            },
1597            other => panic!("unexpected message: {other:?}"),
1598        }
1599        match &messages[2] {
1600            imp_llm::Message::User(user) => match user.content.as_slice() {
1601                [ContentBlock::Text { text }] => assert_eq!(text, "new follow-up"),
1602                other => panic!("unexpected follow-up content: {other:?}"),
1603            },
1604            other => panic!("unexpected message: {other:?}"),
1605        }
1606    }
1607
1608    #[test]
1609    fn persist_event_entries_writes_assistant_and_canonical_usage() {
1610        let tmp = TempDir::new().unwrap();
1611        let cwd = tmp.path().join("project");
1612        let session_dir = tmp.path().join("sessions");
1613        let model = test_model();
1614        let session_mgr = SessionManager::new(&cwd, &session_dir).unwrap();
1615        let (_agent, handle) = Agent::new(clone_model(&model), cwd.clone());
1616
1617        let mut session = ImpSession {
1618            agent: None,
1619            handle,
1620            session_mgr,
1621            config: Config::default(),
1622            model,
1623            auth_store: AuthStore::new(tmp.path().join("auth.json")),
1624            model_registry: ModelRegistry::with_builtins(),
1625            cwd,
1626            agent_task: None,
1627            completed_run_result: None,
1628            pending_persistence_errors: VecDeque::new(),
1629            context_prefill: Vec::new(),
1630            context_prefill_injected: false,
1631        };
1632
1633        let message = test_assistant_message(
1634            123,
1635            Some(Usage {
1636                input_tokens: 1_000,
1637                output_tokens: 250,
1638                cache_read_tokens: 100,
1639                cache_write_tokens: 50,
1640            }),
1641        );
1642
1643        let persisted = session.persist_event_entries(&AgentEvent::TurnEnd {
1644            index: 2,
1645            message: message.clone(),
1646            mana_review: crate::mana_review::TurnManaReview::no_change(2),
1647        });
1648
1649        assert_eq!(persisted, vec!["assistant message", "canonical usage"]);
1650
1651        let usage_records = session.session_mgr.usage_records();
1652        assert_eq!(usage_records.len(), 1);
1653        let record = &usage_records[0];
1654        assert_eq!(record.turn_index, Some(2));
1655        assert_eq!(record.provider.as_deref(), Some("test-provider"));
1656        assert_eq!(record.model.as_deref(), Some("test-model"));
1657        assert!(record.request_id.starts_with("assistant:"));
1658        assert!(record.assistant_message_id.is_some());
1659        let cost = record.cost.as_ref().unwrap();
1660        assert!((cost.input - 0.002).abs() < 1e-12);
1661        assert!((cost.output - 0.001).abs() < 1e-12);
1662        assert!((cost.cache_read - 0.00005).abs() < 1e-12);
1663        assert!((cost.cache_write - 0.00005).abs() < 1e-12);
1664        assert!((cost.total - 0.0031).abs() < 1e-12);
1665    }
1666
1667    #[test]
1668    fn persist_event_entries_skips_usage_record_when_usage_missing() {
1669        let tmp = TempDir::new().unwrap();
1670        let cwd = tmp.path().join("project");
1671        let session_dir = tmp.path().join("sessions");
1672        let model = test_model();
1673        let session_mgr = SessionManager::new(&cwd, &session_dir).unwrap();
1674        let (_agent, handle) = Agent::new(clone_model(&model), cwd.clone());
1675
1676        let mut session = ImpSession {
1677            agent: None,
1678            handle,
1679            session_mgr,
1680            config: Config::default(),
1681            model,
1682            auth_store: AuthStore::new(tmp.path().join("auth.json")),
1683            model_registry: ModelRegistry::with_builtins(),
1684            cwd,
1685            agent_task: None,
1686            completed_run_result: None,
1687            pending_persistence_errors: VecDeque::new(),
1688            context_prefill: Vec::new(),
1689            context_prefill_injected: false,
1690        };
1691
1692        let persisted = session.persist_event_entries(&AgentEvent::TurnEnd {
1693            index: 0,
1694            message: test_assistant_message(456, None),
1695            mana_review: crate::mana_review::TurnManaReview::no_change(0),
1696        });
1697
1698        assert_eq!(persisted, vec!["assistant message"]);
1699        assert!(session.session_mgr.usage_records().is_empty());
1700    }
1701
1702    #[test]
1703    fn persist_event_entries_writes_tool_results() {
1704        let tmp = TempDir::new().unwrap();
1705        let cwd = tmp.path().join("project");
1706        let session_dir = tmp.path().join("sessions");
1707        let model = test_model();
1708        let session_mgr = SessionManager::new(&cwd, &session_dir).unwrap();
1709        let (agent, handle) = Agent::new(clone_model(&model), cwd.clone());
1710        std::fs::create_dir_all(&cwd).unwrap();
1711        let file = cwd.join("tracked.rs");
1712        std::fs::write(&file, "original").unwrap();
1713        let checkpoint = agent
1714            .checkpoint_state
1715            .snapshot_paths(
1716                std::slice::from_ref(&file),
1717                Some("before tool result".into()),
1718            )
1719            .unwrap()
1720            .unwrap();
1721        std::fs::write(&file, "modified").unwrap();
1722
1723        let mut session = ImpSession {
1724            agent: Some(agent),
1725            handle,
1726            session_mgr,
1727            config: Config::default(),
1728            model,
1729            auth_store: AuthStore::new(tmp.path().join("auth.json")),
1730            model_registry: ModelRegistry::with_builtins(),
1731            cwd,
1732            agent_task: None,
1733            completed_run_result: None,
1734            pending_persistence_errors: VecDeque::new(),
1735            context_prefill: Vec::new(),
1736            context_prefill_injected: false,
1737        };
1738
1739        let persisted = session.persist_event_entries(&AgentEvent::ToolExecutionEnd {
1740            tool_call_id: "call-1".into(),
1741            result: imp_llm::ToolResultMessage {
1742                tool_call_id: "call-1".into(),
1743                tool_name: "bash".into(),
1744                content: vec![ContentBlock::Text { text: "ok".into() }],
1745                is_error: false,
1746                details: json!({"exit_code": 0}),
1747                timestamp: 999,
1748            },
1749        });
1750
1751        assert_eq!(persisted, vec!["tool result"]);
1752        assert!(session.session_mgr.entries().iter().any(|entry| matches!(
1753            entry,
1754            SessionEntry::Message {
1755                message: imp_llm::Message::ToolResult(_),
1756                ..
1757            }
1758        )));
1759        let checkpoints = session.session_mgr.checkpoint_records();
1760        assert_eq!(checkpoints.len(), 1);
1761        assert_eq!(checkpoints[0].checkpoint_id, checkpoint.id);
1762        let restored = session
1763            .session_mgr
1764            .restore_checkpoint(
1765                session
1766                    .agent
1767                    .as_ref()
1768                    .expect("agent retained for persistence test")
1769                    .checkpoint_state
1770                    .as_ref(),
1771                &checkpoints[0].checkpoint_id,
1772            )
1773            .unwrap();
1774        assert_eq!(restored, vec![file.clone()]);
1775        assert_eq!(std::fs::read_to_string(&file).unwrap(), "original");
1776    }
1777}