Skip to main content

pi/
sdk.rs

1//! Stable SDK-facing API surface for embedding Pi as a library.
2//!
3//! This module is the supported entry point for external library consumers.
4//! Prefer importing from `pi::sdk` instead of deep internal modules.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use pi::sdk::{AgentEvent, Message, ToolDefinition};
10//!
11//! let _events: Vec<AgentEvent> = Vec::new();
12//! let _messages: Vec<Message> = Vec::new();
13//! let _tools: Vec<ToolDefinition> = Vec::new();
14//! ```
15//!
16//! Internal implementation types are intentionally not part of this surface.
17//!
18//! ```compile_fail
19//! use pi::sdk::RpcSharedState;
20//! ```
21
22use crate::app;
23use crate::auth::AuthStorage;
24use crate::cli::Cli;
25use crate::compaction::ResolvedCompactionSettings;
26use crate::models::default_models_path;
27use crate::provider::ThinkingBudgets;
28use crate::providers;
29use clap::Parser;
30use serde::{Deserialize, Serialize, de::DeserializeOwned};
31use serde_json::{Map, Value};
32use std::collections::HashMap;
33use std::io::{BufRead, BufReader, BufWriter, Write};
34use std::path::{Path, PathBuf};
35use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
36use std::sync::Arc;
37use std::sync::atomic::{AtomicU64, Ordering};
38
39pub use crate::agent::{
40    AbortHandle, AbortSignal, Agent, AgentConfig, AgentEvent, AgentSession, QueueMode,
41};
42pub use crate::config::Config;
43pub use crate::error::{Error, Result};
44pub use crate::extensions::{ExtensionManager, ExtensionPolicy, ExtensionRegion};
45pub use crate::model::{
46    AssistantMessage, ContentBlock, Cost, CustomMessage, ImageContent, Message, StopReason,
47    StreamEvent, TextContent, ThinkingContent, ToolCall, ToolResultMessage, Usage, UserContent,
48    UserMessage,
49};
50pub use crate::models::{ModelEntry, ModelRegistry};
51pub use crate::provider::{
52    Context as ProviderContext, InputType, Model, ModelCost, Provider, StreamOptions,
53    ThinkingBudgets as ProviderThinkingBudgets, ToolDef,
54};
55pub use crate::session::Session;
56pub use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
57pub use crate::model::ThinkingLevel;
58
59/// Stable alias for model-exposed tool schema definitions.
60pub type ToolDefinition = ToolDef;
61
62// ============================================================================
63// Tool Factory Functions
64// ============================================================================
65
66use crate::tools::{
67    BashTool, EditTool, FindTool, GrepTool, HashlineEditTool, LsTool, ReadTool, WriteTool,
68};
69
70/// All built-in tool names.
71pub const BUILTIN_TOOL_NAMES: &[&str] = &[
72    "read",
73    "bash",
74    "edit",
75    "write",
76    "grep",
77    "find",
78    "ls",
79    "hashline_edit",
80];
81
82/// Create a read tool configured for `cwd`.
83pub fn create_read_tool(cwd: &Path) -> Box<dyn Tool> {
84    Box::new(ReadTool::new(cwd))
85}
86
87/// Create a bash tool configured for `cwd`.
88pub fn create_bash_tool(cwd: &Path) -> Box<dyn Tool> {
89    Box::new(BashTool::new(cwd))
90}
91
92/// Create an edit tool configured for `cwd`.
93pub fn create_edit_tool(cwd: &Path) -> Box<dyn Tool> {
94    Box::new(EditTool::new(cwd))
95}
96
97/// Create a write tool configured for `cwd`.
98pub fn create_write_tool(cwd: &Path) -> Box<dyn Tool> {
99    Box::new(WriteTool::new(cwd))
100}
101
102/// Create a grep tool configured for `cwd`.
103pub fn create_grep_tool(cwd: &Path) -> Box<dyn Tool> {
104    Box::new(GrepTool::new(cwd))
105}
106
107/// Create a find tool configured for `cwd`.
108pub fn create_find_tool(cwd: &Path) -> Box<dyn Tool> {
109    Box::new(FindTool::new(cwd))
110}
111
112/// Create an ls tool configured for `cwd`.
113pub fn create_ls_tool(cwd: &Path) -> Box<dyn Tool> {
114    Box::new(LsTool::new(cwd))
115}
116
117/// Create a hashline edit tool configured for `cwd`.
118pub fn create_hashline_edit_tool(cwd: &Path) -> Box<dyn Tool> {
119    Box::new(HashlineEditTool::new(cwd))
120}
121
122/// Create all built-in tools configured for `cwd`.
123pub fn create_all_tools(cwd: &Path) -> Vec<Box<dyn Tool>> {
124    vec![
125        create_read_tool(cwd),
126        create_bash_tool(cwd),
127        create_edit_tool(cwd),
128        create_write_tool(cwd),
129        create_grep_tool(cwd),
130        create_find_tool(cwd),
131        create_ls_tool(cwd),
132        create_hashline_edit_tool(cwd),
133    ]
134}
135
136/// Convert a [`Tool`] into its [`ToolDefinition`] schema.
137pub fn tool_to_definition(tool: &dyn Tool) -> ToolDefinition {
138    ToolDefinition {
139        name: tool.name().to_string(),
140        description: tool.description().to_string(),
141        parameters: tool.parameters(),
142    }
143}
144
145/// Return [`ToolDefinition`] schemas for all built-in tools.
146pub fn all_tool_definitions(cwd: &Path) -> Vec<ToolDefinition> {
147    create_all_tools(cwd)
148        .iter()
149        .map(|t| tool_to_definition(t.as_ref()))
150        .collect()
151}
152
153// ============================================================================
154// Streaming Callbacks and Tool Hooks
155// ============================================================================
156
/// Handle naming one registered event listener.
///
/// [`AgentSessionHandle::subscribe`] hands one out, and passing it back to
/// [`AgentSessionHandle::unsubscribe`] removes that listener. The wrapped counter
/// value is deliberately private.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct SubscriptionId(u64);
163
164/// Callback invoked when a tool execution starts.
165///
166/// Arguments: `(tool_name, input_args)`.
167pub type OnToolStart = Arc<dyn Fn(&str, &Value) + Send + Sync>;
168
169/// Callback invoked when a tool execution ends.
170///
171/// Arguments: `(tool_name, output, is_error)`.
172pub type OnToolEnd = Arc<dyn Fn(&str, &ToolOutput, bool) + Send + Sync>;
173
174/// Callback invoked for every raw provider [`StreamEvent`].
175///
176/// This gives SDK consumers direct access to the low-level streaming protocol
177/// before events are wrapped into [`AgentEvent::MessageUpdate`].
178pub type OnStreamEvent = Arc<dyn Fn(&StreamEvent) + Send + Sync>;
179
180pub type EventSubscriber = Arc<dyn Fn(AgentEvent) + Send + Sync>;
181type EventSubscribers = HashMap<SubscriptionId, EventSubscriber>;
182
183/// Collection of session-level event listeners.
184///
185/// These are registered once and invoked for every prompt throughout the
186/// session lifetime, in contrast to per-prompt callbacks on
187/// [`AgentSessionHandle::prompt`].
188#[derive(Clone, Default)]
189pub struct EventListeners {
190    next_id: Arc<AtomicU64>,
191    subscribers: Arc<std::sync::Mutex<EventSubscribers>>,
192    pub on_tool_start: Option<OnToolStart>,
193    pub on_tool_end: Option<OnToolEnd>,
194    pub on_stream_event: Option<OnStreamEvent>,
195}
196
197impl EventListeners {
198    fn new() -> Self {
199        Self {
200            next_id: Arc::new(AtomicU64::new(1)),
201            subscribers: Arc::new(std::sync::Mutex::new(HashMap::new())),
202            on_tool_start: None,
203            on_tool_end: None,
204            on_stream_event: None,
205        }
206    }
207
208    /// Register a session-level event listener.
209    pub fn subscribe(&self, listener: EventSubscriber) -> SubscriptionId {
210        let id = SubscriptionId(self.next_id.fetch_add(1, Ordering::Relaxed));
211        let mut subs = self
212            .subscribers
213            .lock()
214            .unwrap_or_else(std::sync::PoisonError::into_inner);
215        subs.insert(id, listener);
216        id
217    }
218
219    /// Remove a previously registered listener.
220    pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
221        let mut subs = self
222            .subscribers
223            .lock()
224            .unwrap_or_else(std::sync::PoisonError::into_inner);
225        subs.remove(&id).is_some()
226    }
227
228    /// Dispatch an [`AgentEvent`] to all registered subscribers.
229    pub fn notify(&self, event: &AgentEvent) {
230        let listeners: Vec<_> = {
231            let subs = self
232                .subscribers
233                .lock()
234                .unwrap_or_else(std::sync::PoisonError::into_inner);
235            subs.values().cloned().collect()
236        };
237        for listener in listeners {
238            listener(event.clone());
239        }
240    }
241
242    /// Dispatch tool-start to the typed hook (if set).
243    pub fn notify_tool_start(&self, tool_name: &str, args: &Value) {
244        if let Some(cb) = &self.on_tool_start {
245            cb(tool_name, args);
246        }
247    }
248
249    /// Dispatch tool-end to the typed hook (if set).
250    pub fn notify_tool_end(&self, tool_name: &str, output: &ToolOutput, is_error: bool) {
251        if let Some(cb) = &self.on_tool_end {
252            cb(tool_name, output, is_error);
253        }
254    }
255
256    /// Dispatch a raw stream event (if hook is set).
257    pub fn notify_stream_event(&self, event: &StreamEvent) {
258        if let Some(cb) = &self.on_stream_event {
259            cb(event);
260        }
261    }
262}
263
264impl std::fmt::Debug for EventListeners {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        let count = self.subscribers.lock().map_or(0, |s| s.len());
267        let next_id = self.next_id.load(Ordering::Relaxed);
268        f.debug_struct("EventListeners")
269            .field("subscriber_count", &count)
270            .field("next_id", &next_id)
271            .field("has_on_tool_start", &self.on_tool_start.is_some())
272            .field("has_on_tool_end", &self.on_tool_end.is_some())
273            .field("has_on_stream_event", &self.on_stream_event.is_some())
274            .finish()
275    }
276}
277
278/// SDK session construction options.
279///
280/// These options provide the programmatic equivalent of the core CLI startup
281/// path used in `src/main.rs`.
282#[derive(Clone)]
283pub struct SessionOptions {
284    pub provider: Option<String>,
285    pub model: Option<String>,
286    pub api_key: Option<String>,
287    pub thinking: Option<crate::model::ThinkingLevel>,
288    pub system_prompt: Option<String>,
289    pub append_system_prompt: Option<String>,
290    pub enabled_tools: Option<Vec<String>>,
291    pub working_directory: Option<PathBuf>,
292    pub no_session: bool,
293    pub session_path: Option<PathBuf>,
294    pub session_dir: Option<PathBuf>,
295    pub extension_paths: Vec<PathBuf>,
296    pub extension_policy: Option<String>,
297    pub repair_policy: Option<String>,
298    pub include_cwd_in_prompt: bool,
299    pub max_tool_iterations: usize,
300
301    /// Session-level event listener invoked for every [`AgentEvent`].
302    ///
303    /// Unlike the per-prompt callback passed to [`AgentSessionHandle::prompt`],
304    /// this fires for all prompts throughout the session lifetime.
305    pub on_event: Option<Arc<dyn Fn(AgentEvent) + Send + Sync>>,
306
307    /// Typed callback invoked when tool execution starts.
308    pub on_tool_start: Option<OnToolStart>,
309
310    /// Typed callback invoked when tool execution ends.
311    pub on_tool_end: Option<OnToolEnd>,
312
313    /// Callback for raw provider [`StreamEvent`]s.
314    pub on_stream_event: Option<OnStreamEvent>,
315}
316
317impl Default for SessionOptions {
318    fn default() -> Self {
319        Self {
320            provider: None,
321            model: None,
322            api_key: None,
323            thinking: None,
324            system_prompt: None,
325            append_system_prompt: None,
326            enabled_tools: None,
327            working_directory: None,
328            no_session: true,
329            session_path: None,
330            session_dir: None,
331            extension_paths: Vec::new(),
332            extension_policy: None,
333            repair_policy: None,
334            include_cwd_in_prompt: true,
335            max_tool_iterations: 50,
336            on_event: None,
337            on_tool_start: None,
338            on_tool_end: None,
339            on_stream_event: None,
340        }
341    }
342}
343
344/// Lightweight handle for programmatic embedding.
345///
346/// This wraps `AgentSession` and exposes high-level request methods while still
347/// allowing access to the underlying session when needed.
348///
349/// Session-level event listeners can be registered via [`Self::subscribe`] or
350/// by providing callbacks on [`SessionOptions`].  These fire for **every**
351/// prompt, in addition to the per-prompt `on_event` callback.
352pub struct AgentSessionHandle {
353    session: AgentSession,
354    listeners: EventListeners,
355}
356
357/// Snapshot of the current agent session state.
358#[derive(Debug, Clone, PartialEq, Eq)]
359pub struct AgentSessionState {
360    pub session_id: Option<String>,
361    pub provider: String,
362    pub model_id: String,
363    pub thinking_level: Option<crate::model::ThinkingLevel>,
364    pub save_enabled: bool,
365    pub message_count: usize,
366}
367
368/// Prompt completion payload returned by `SessionTransport`.
369#[derive(Debug, Clone)]
370pub enum SessionPromptResult {
371    InProcess(AssistantMessage),
372    RpcEvents(Vec<Value>),
373}
374
375/// Event wrapper used by the unified `SessionTransport` callback.
376#[derive(Debug, Clone)]
377pub enum SessionTransportEvent {
378    InProcess(AgentEvent),
379    Rpc(Value),
380}
381
382/// Unified session state snapshot across in-process and RPC transports.
383#[derive(Debug, Clone, PartialEq)]
384pub enum SessionTransportState {
385    InProcess(AgentSessionState),
386    Rpc(Box<RpcSessionState>),
387}
388
389/// Model metadata exposed by RPC APIs.
390#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
391#[serde(rename_all = "camelCase")]
392pub struct RpcModelInfo {
393    pub id: String,
394    pub name: String,
395    pub api: String,
396    pub provider: String,
397    #[serde(default)]
398    pub base_url: String,
399    #[serde(default)]
400    pub reasoning: bool,
401    #[serde(default)]
402    pub input: Vec<InputType>,
403    #[serde(default)]
404    pub context_window: u32,
405    #[serde(default)]
406    pub max_tokens: u32,
407    #[serde(default)]
408    pub cost: Option<ModelCost>,
409}
410
411/// Session state payload returned by RPC `get_state`.
412#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
413#[serde(rename_all = "camelCase")]
414#[allow(clippy::struct_excessive_bools)]
415pub struct RpcSessionState {
416    #[serde(default)]
417    pub model: Option<RpcModelInfo>,
418    #[serde(default)]
419    pub thinking_level: String,
420    #[serde(default)]
421    pub is_streaming: bool,
422    #[serde(default)]
423    pub is_compacting: bool,
424    #[serde(default)]
425    pub steering_mode: String,
426    #[serde(default)]
427    pub follow_up_mode: String,
428    #[serde(default)]
429    pub session_file: Option<String>,
430    #[serde(default)]
431    pub session_id: String,
432    #[serde(default)]
433    pub session_name: Option<String>,
434    #[serde(default)]
435    pub auto_compaction_enabled: bool,
436    #[serde(default)]
437    pub auto_retry_enabled: bool,
438    #[serde(default)]
439    pub message_count: usize,
440    #[serde(default)]
441    pub pending_message_count: usize,
442    #[serde(default)]
443    pub durability_mode: String,
444}
445
446/// Session-level token aggregates returned by RPC `get_session_stats`.
447#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
448#[serde(rename_all = "camelCase")]
449pub struct RpcTokenStats {
450    pub input: u64,
451    pub output: u64,
452    pub cache_read: u64,
453    pub cache_write: u64,
454    pub total: u64,
455}
456
457/// Session stats payload returned by RPC `get_session_stats`.
458#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
459#[serde(rename_all = "camelCase")]
460pub struct RpcSessionStats {
461    #[serde(default)]
462    pub session_file: Option<String>,
463    pub session_id: String,
464    pub user_messages: u64,
465    pub assistant_messages: u64,
466    pub tool_calls: u64,
467    pub tool_results: u64,
468    pub total_messages: u64,
469    pub tokens: RpcTokenStats,
470    pub cost: f64,
471}
472
473/// Result payload for `new_session` and `switch_session`.
474#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
475pub struct RpcCancelledResult {
476    pub cancelled: bool,
477}
478
479/// Result payload returned by RPC `cycle_model`.
480#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
481#[serde(rename_all = "camelCase")]
482pub struct RpcCycleModelResult {
483    pub model: RpcModelInfo,
484    pub thinking_level: crate::model::ThinkingLevel,
485    pub is_scoped: bool,
486}
487
488/// Result payload returned by RPC `cycle_thinking_level`.
489#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
490pub struct RpcThinkingLevelResult {
491    pub level: crate::model::ThinkingLevel,
492}
493
494/// Bash execution result returned by RPC `bash`.
495#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
496#[serde(rename_all = "camelCase")]
497pub struct RpcBashResult {
498    pub output: String,
499    pub exit_code: i32,
500    pub cancelled: bool,
501    pub truncated: bool,
502    pub full_output_path: Option<String>,
503}
504
505/// Compaction result returned by RPC `compact`.
506#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
507#[serde(rename_all = "camelCase")]
508pub struct RpcCompactionResult {
509    pub summary: String,
510    pub first_kept_entry_id: String,
511    pub tokens_before: u64,
512    #[serde(default)]
513    pub details: Value,
514}
515
516/// Result payload returned by RPC `fork`.
517#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
518pub struct RpcForkResult {
519    pub text: String,
520    pub cancelled: bool,
521}
522
523/// Forkable message entry returned by RPC `get_fork_messages`.
524#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
525#[serde(rename_all = "camelCase")]
526pub struct RpcForkMessage {
527    pub entry_id: String,
528    pub text: String,
529}
530
531/// Slash command metadata returned by RPC `get_commands`.
532#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
533pub struct RpcCommandInfo {
534    pub name: String,
535    #[serde(default)]
536    pub description: Option<String>,
537    pub source: String,
538    #[serde(default)]
539    pub location: Option<String>,
540    #[serde(default)]
541    pub path: Option<String>,
542}
543
544/// Export HTML response payload.
545#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
546pub struct RpcExportHtmlResult {
547    pub path: String,
548}
549
550/// Last-assistant-text response payload.
551#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
552pub struct RpcLastAssistantText {
553    pub text: Option<String>,
554}
555
556/// Extension UI response payload used by RPC `extension_ui_response`.
557#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
558#[serde(tag = "kind", rename_all = "snake_case")]
559pub enum RpcExtensionUiResponse {
560    Value { value: Value },
561    Confirmed { confirmed: bool },
562    Cancelled,
563}
564
/// How to launch the RPC subprocess for SDK callers that prefer a process boundary.
#[derive(Clone, Debug)]
pub struct RpcTransportOptions {
    /// Executable to launch; a bare name is resolved the way `Command::new` resolves it.
    pub binary_path: PathBuf,
    /// Arguments passed to the executable.
    pub args: Vec<String>,
    /// Working directory for the child; `None` inherits the parent's.
    pub cwd: Option<PathBuf>,
}

impl Default for RpcTransportOptions {
    /// `pi --mode rpc`, run in the caller's current directory.
    fn default() -> Self {
        let args = Vec::from(["--mode", "rpc"].map(String::from));
        Self {
            binary_path: PathBuf::from("pi"),
            args,
            cwd: None,
        }
    }
}
582
583/// Subprocess-backed SDK transport for `pi --mode rpc`.
584pub struct RpcTransportClient {
585    child: Child,
586    stdin: BufWriter<ChildStdin>,
587    stdout: BufReader<ChildStdout>,
588    next_request_id: u64,
589}
590
591/// Unified adapter over in-process and subprocess-backed session control.
592pub enum SessionTransport {
593    InProcess(Box<AgentSessionHandle>),
594    RpcSubprocess(RpcTransportClient),
595}
596
597impl SessionTransport {
598    pub async fn in_process(options: SessionOptions) -> Result<Self> {
599        create_agent_session(options)
600            .await
601            .map(Box::new)
602            .map(Self::InProcess)
603    }
604
605    pub fn rpc_subprocess(options: RpcTransportOptions) -> Result<Self> {
606        RpcTransportClient::connect(options).map(Self::RpcSubprocess)
607    }
608
609    #[allow(clippy::missing_const_for_fn)]
610    pub fn as_in_process_mut(&mut self) -> Option<&mut AgentSessionHandle> {
611        match self {
612            Self::InProcess(handle) => Some(handle.as_mut()),
613            Self::RpcSubprocess(_) => None,
614        }
615    }
616
617    #[allow(clippy::missing_const_for_fn)]
618    pub fn as_rpc_mut(&mut self) -> Option<&mut RpcTransportClient> {
619        match self {
620            Self::InProcess(_) => None,
621            Self::RpcSubprocess(client) => Some(client),
622        }
623    }
624
625    /// Send one prompt over whichever transport is active.
626    ///
627    /// - In-process mode returns the final assistant message.
628    /// - RPC mode waits for `agent_end` and returns collected raw events.
629    pub async fn prompt(
630        &mut self,
631        input: impl Into<String>,
632        on_event: impl Fn(SessionTransportEvent) + Send + Sync + 'static,
633    ) -> Result<SessionPromptResult> {
634        let input = input.into();
635        let on_event = Arc::new(on_event);
636        match self {
637            Self::InProcess(handle) => {
638                let on_event = Arc::clone(&on_event);
639                let assistant = handle
640                    .prompt(input, move |event| {
641                        (on_event)(SessionTransportEvent::InProcess(event));
642                    })
643                    .await?;
644                Ok(SessionPromptResult::InProcess(assistant))
645            }
646            Self::RpcSubprocess(client) => {
647                let events = client.prompt(input).await?;
648                for event in events.iter().cloned() {
649                    (on_event)(SessionTransportEvent::Rpc(event));
650                }
651                Ok(SessionPromptResult::RpcEvents(events))
652            }
653        }
654    }
655
656    /// Return a state snapshot from the active transport.
657    pub async fn state(&mut self) -> Result<SessionTransportState> {
658        match self {
659            Self::InProcess(handle) => handle.state().await.map(SessionTransportState::InProcess),
660            Self::RpcSubprocess(client) => client
661                .get_state()
662                .await
663                .map(Box::new)
664                .map(SessionTransportState::Rpc),
665        }
666    }
667
668    /// Update provider/model for the active transport.
669    pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<()> {
670        match self {
671            Self::InProcess(handle) => handle.set_model(provider, model_id).await,
672            Self::RpcSubprocess(client) => {
673                let _ = client.set_model(provider, model_id).await?;
674                Ok(())
675            }
676        }
677    }
678
679    /// Shut down transport resources (best effort for in-process, explicit for RPC).
680    pub fn shutdown(&mut self) -> Result<()> {
681        match self {
682            Self::InProcess(_) => Ok(()),
683            Self::RpcSubprocess(client) => client.shutdown(),
684        }
685    }
686}
687
688impl RpcTransportClient {
689    pub fn connect(options: RpcTransportOptions) -> Result<Self> {
690        let mut command = Command::new(&options.binary_path);
691        command
692            .args(&options.args)
693            .stdin(Stdio::piped())
694            .stdout(Stdio::piped())
695            .stderr(Stdio::inherit());
696        if let Some(cwd) = options.cwd {
697            command.current_dir(cwd);
698        }
699
700        let mut child = command.spawn().map_err(|err| {
701            Error::config(format!(
702                "Failed to spawn RPC subprocess {}: {err}",
703                options.binary_path.display()
704            ))
705        })?;
706        let stdin = child
707            .stdin
708            .take()
709            .ok_or_else(|| Error::config("RPC subprocess stdin is not piped"))?;
710        let stdout = child
711            .stdout
712            .take()
713            .ok_or_else(|| Error::config("RPC subprocess stdout is not piped"))?;
714
715        Ok(Self {
716            child,
717            stdin: BufWriter::new(stdin),
718            stdout: BufReader::new(stdout),
719            next_request_id: 1,
720        })
721    }
722
723    pub async fn request(&mut self, command: &str, payload: Map<String, Value>) -> Result<Value> {
724        let request_id = self.next_request_id();
725        let mut command_payload = Map::new();
726        command_payload.insert("type".to_string(), Value::String(command.to_string()));
727        command_payload.insert("id".to_string(), Value::String(request_id.clone()));
728        command_payload.extend(payload);
729
730        self.write_json_line(&Value::Object(command_payload))?;
731        self.wait_for_response(&request_id, command)
732    }
733
734    fn parse_response_data<T: DeserializeOwned>(data: Value, command: &str) -> Result<T> {
735        serde_json::from_value(data).map_err(|err| {
736            Error::api(format!(
737                "Failed to decode RPC `{command}` response payload: {err}"
738            ))
739        })
740    }
741
742    async fn request_typed<T: DeserializeOwned>(
743        &mut self,
744        command: &str,
745        payload: Map<String, Value>,
746    ) -> Result<T> {
747        let data = self.request(command, payload).await?;
748        Self::parse_response_data(data, command)
749    }
750
751    async fn request_no_data(&mut self, command: &str, payload: Map<String, Value>) -> Result<()> {
752        let _ = self.request(command, payload).await?;
753        Ok(())
754    }
755
756    pub async fn steer(&mut self, message: impl Into<String>) -> Result<()> {
757        let mut payload = Map::new();
758        payload.insert("message".to_string(), Value::String(message.into()));
759        self.request_no_data("steer", payload).await
760    }
761
762    pub async fn follow_up(&mut self, message: impl Into<String>) -> Result<()> {
763        let mut payload = Map::new();
764        payload.insert("message".to_string(), Value::String(message.into()));
765        self.request_no_data("follow_up", payload).await
766    }
767
768    pub async fn abort(&mut self) -> Result<()> {
769        self.request_no_data("abort", Map::new()).await
770    }
771
772    pub async fn new_session(
773        &mut self,
774        parent_session: Option<&Path>,
775    ) -> Result<RpcCancelledResult> {
776        let mut payload = Map::new();
777        if let Some(parent_session) = parent_session {
778            payload.insert(
779                "parentSession".to_string(),
780                Value::String(parent_session.display().to_string()),
781            );
782        }
783        self.request_typed("new_session", payload).await
784    }
785
786    pub async fn get_state(&mut self) -> Result<RpcSessionState> {
787        self.request_typed("get_state", Map::new()).await
788    }
789
790    pub async fn get_session_stats(&mut self) -> Result<RpcSessionStats> {
791        self.request_typed("get_session_stats", Map::new()).await
792    }
793
794    pub async fn get_messages(&mut self) -> Result<Vec<Value>> {
795        #[derive(Deserialize)]
796        struct MessagesPayload {
797            messages: Vec<Value>,
798        }
799        let payload: MessagesPayload = self.request_typed("get_messages", Map::new()).await?;
800        Ok(payload.messages)
801    }
802
803    pub async fn get_available_models(&mut self) -> Result<Vec<RpcModelInfo>> {
804        #[derive(Deserialize)]
805        struct ModelsPayload {
806            models: Vec<RpcModelInfo>,
807        }
808        let payload: ModelsPayload = self
809            .request_typed("get_available_models", Map::new())
810            .await?;
811        Ok(payload.models)
812    }
813
814    pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<RpcModelInfo> {
815        let mut payload = Map::new();
816        payload.insert("provider".to_string(), Value::String(provider.to_string()));
817        payload.insert("modelId".to_string(), Value::String(model_id.to_string()));
818        self.request_typed("set_model", payload).await
819    }
820
821    pub async fn cycle_model(&mut self) -> Result<Option<RpcCycleModelResult>> {
822        self.request_typed("cycle_model", Map::new()).await
823    }
824
825    pub async fn set_thinking_level(&mut self, level: crate::model::ThinkingLevel) -> Result<()> {
826        let mut payload = Map::new();
827        payload.insert("level".to_string(), Value::String(level.to_string()));
828        self.request_no_data("set_thinking_level", payload).await
829    }
830
831    pub async fn cycle_thinking_level(&mut self) -> Result<Option<RpcThinkingLevelResult>> {
832        self.request_typed("cycle_thinking_level", Map::new()).await
833    }
834
835    pub async fn set_steering_mode(&mut self, mode: &str) -> Result<()> {
836        let mut payload = Map::new();
837        payload.insert("mode".to_string(), Value::String(mode.to_string()));
838        self.request_no_data("set_steering_mode", payload).await
839    }
840
841    pub async fn set_follow_up_mode(&mut self, mode: &str) -> Result<()> {
842        let mut payload = Map::new();
843        payload.insert("mode".to_string(), Value::String(mode.to_string()));
844        self.request_no_data("set_follow_up_mode", payload).await
845    }
846
847    pub async fn set_auto_compaction(&mut self, enabled: bool) -> Result<()> {
848        let mut payload = Map::new();
849        payload.insert("enabled".to_string(), Value::Bool(enabled));
850        self.request_no_data("set_auto_compaction", payload).await
851    }
852
853    pub async fn set_auto_retry(&mut self, enabled: bool) -> Result<()> {
854        let mut payload = Map::new();
855        payload.insert("enabled".to_string(), Value::Bool(enabled));
856        self.request_no_data("set_auto_retry", payload).await
857    }
858
859    pub async fn abort_retry(&mut self) -> Result<()> {
860        self.request_no_data("abort_retry", Map::new()).await
861    }
862
863    pub async fn set_session_name(&mut self, name: impl Into<String>) -> Result<()> {
864        let mut payload = Map::new();
865        payload.insert("name".to_string(), Value::String(name.into()));
866        self.request_no_data("set_session_name", payload).await
867    }
868
869    pub async fn get_last_assistant_text(&mut self) -> Result<Option<String>> {
870        let payload: RpcLastAssistantText = self
871            .request_typed("get_last_assistant_text", Map::new())
872            .await?;
873        Ok(payload.text)
874    }
875
876    pub async fn export_html(&mut self, output_path: Option<&Path>) -> Result<RpcExportHtmlResult> {
877        let mut payload = Map::new();
878        if let Some(path) = output_path {
879            payload.insert(
880                "outputPath".to_string(),
881                Value::String(path.display().to_string()),
882            );
883        }
884        self.request_typed("export_html", payload).await
885    }
886
887    pub async fn bash(&mut self, command: impl Into<String>) -> Result<RpcBashResult> {
888        let mut payload = Map::new();
889        payload.insert("command".to_string(), Value::String(command.into()));
890        self.request_typed("bash", payload).await
891    }
892
893    pub async fn abort_bash(&mut self) -> Result<()> {
894        self.request_no_data("abort_bash", Map::new()).await
895    }
896
897    pub async fn compact(&mut self) -> Result<RpcCompactionResult> {
898        self.compact_with_instructions(None).await
899    }
900
901    pub async fn compact_with_instructions(
902        &mut self,
903        custom_instructions: Option<&str>,
904    ) -> Result<RpcCompactionResult> {
905        let mut payload = Map::new();
906        if let Some(custom_instructions) = custom_instructions {
907            payload.insert(
908                "customInstructions".to_string(),
909                Value::String(custom_instructions.to_string()),
910            );
911        }
912        self.request_typed("compact", payload).await
913    }
914
915    pub async fn switch_session(&mut self, session_path: &Path) -> Result<RpcCancelledResult> {
916        let mut payload = Map::new();
917        payload.insert(
918            "sessionPath".to_string(),
919            Value::String(session_path.display().to_string()),
920        );
921        self.request_typed("switch_session", payload).await
922    }
923
924    pub async fn fork(&mut self, entry_id: impl Into<String>) -> Result<RpcForkResult> {
925        let mut payload = Map::new();
926        payload.insert("entryId".to_string(), Value::String(entry_id.into()));
927        self.request_typed("fork", payload).await
928    }
929
930    pub async fn get_fork_messages(&mut self) -> Result<Vec<RpcForkMessage>> {
931        #[derive(Deserialize)]
932        struct ForkMessagesPayload {
933            messages: Vec<RpcForkMessage>,
934        }
935        let payload: ForkMessagesPayload =
936            self.request_typed("get_fork_messages", Map::new()).await?;
937        Ok(payload.messages)
938    }
939
940    pub async fn get_commands(&mut self) -> Result<Vec<RpcCommandInfo>> {
941        #[derive(Deserialize)]
942        struct CommandsPayload {
943            commands: Vec<RpcCommandInfo>,
944        }
945        let payload: CommandsPayload = self.request_typed("get_commands", Map::new()).await?;
946        Ok(payload.commands)
947    }
948
949    pub async fn extension_ui_response(
950        &mut self,
951        request_id: &str,
952        response: RpcExtensionUiResponse,
953    ) -> Result<bool> {
954        #[derive(Deserialize)]
955        struct ExtensionUiResolvedPayload {
956            resolved: bool,
957        }
958
959        let mut payload = Map::new();
960        payload.insert(
961            "requestId".to_string(),
962            Value::String(request_id.to_string()),
963        );
964
965        match response {
966            RpcExtensionUiResponse::Value { value } => {
967                payload.insert("value".to_string(), value);
968            }
969            RpcExtensionUiResponse::Confirmed { confirmed } => {
970                payload.insert("confirmed".to_string(), Value::Bool(confirmed));
971            }
972            RpcExtensionUiResponse::Cancelled => {
973                payload.insert("cancelled".to_string(), Value::Bool(true));
974            }
975        }
976
977        let response: Option<ExtensionUiResolvedPayload> =
978            self.request_typed("extension_ui_response", payload).await?;
979        Ok(response.is_none_or(|payload| payload.resolved))
980    }
981
982    pub async fn prompt(&mut self, message: impl Into<String>) -> Result<Vec<Value>> {
983        self.prompt_with_options(message, None, None).await
984    }
985
986    pub async fn prompt_with_options(
987        &mut self,
988        message: impl Into<String>,
989        images: Option<Vec<ImageContent>>,
990        streaming_behavior: Option<&str>,
991    ) -> Result<Vec<Value>> {
992        let request_id = self.next_request_id();
993        let mut payload = Map::new();
994        payload.insert("type".to_string(), Value::String("prompt".to_string()));
995        payload.insert("id".to_string(), Value::String(request_id.clone()));
996        payload.insert("message".to_string(), Value::String(message.into()));
997        if let Some(images) = images {
998            payload.insert(
999                "images".to_string(),
1000                serde_json::to_value(images).map_err(|err| Error::Json(Box::new(err)))?,
1001            );
1002        }
1003        if let Some(streaming_behavior) = streaming_behavior {
1004            payload.insert(
1005                "streamingBehavior".to_string(),
1006                Value::String(streaming_behavior.to_string()),
1007            );
1008        }
1009        let payload = Value::Object(payload);
1010        self.write_json_line(&payload)?;
1011
1012        let mut saw_ack = false;
1013        let mut events = Vec::new();
1014        loop {
1015            let item = self.read_json_line()?;
1016            let item_type = item.get("type").and_then(Value::as_str);
1017            if item_type == Some("response") {
1018                if item.get("id").and_then(Value::as_str) != Some(request_id.as_str()) {
1019                    continue;
1020                }
1021                let success = item
1022                    .get("success")
1023                    .and_then(Value::as_bool)
1024                    .unwrap_or(false);
1025                if !success {
1026                    return Err(rpc_error_from_response(&item, "prompt"));
1027                }
1028                saw_ack = true;
1029                continue;
1030            }
1031
1032            if saw_ack {
1033                let reached_end = item_type == Some("agent_end");
1034                events.push(item);
1035                if reached_end {
1036                    return Ok(events);
1037                }
1038            }
1039        }
1040    }
1041
1042    pub fn shutdown(&mut self) -> Result<()> {
1043        if self
1044            .child
1045            .try_wait()
1046            .map_err(|err| Error::Io(Box::new(err)))?
1047            .is_none()
1048        {
1049            self.child.kill().map_err(|err| Error::Io(Box::new(err)))?;
1050        }
1051        let _ = self.child.wait();
1052        Ok(())
1053    }
1054
1055    fn next_request_id(&mut self) -> String {
1056        let id = format!("rpc-{}", self.next_request_id);
1057        self.next_request_id = self.next_request_id.saturating_add(1);
1058        id
1059    }
1060
1061    fn write_json_line(&mut self, payload: &Value) -> Result<()> {
1062        let encoded = serde_json::to_string(payload).map_err(|err| Error::Json(Box::new(err)))?;
1063        self.stdin
1064            .write_all(encoded.as_bytes())
1065            .map_err(|err| Error::Io(Box::new(err)))?;
1066        self.stdin
1067            .write_all(b"\n")
1068            .map_err(|err| Error::Io(Box::new(err)))?;
1069        self.stdin.flush().map_err(|err| Error::Io(Box::new(err)))?;
1070        Ok(())
1071    }
1072
1073    fn read_json_line(&mut self) -> Result<Value> {
1074        let mut line = String::new();
1075        let read = self
1076            .stdout
1077            .read_line(&mut line)
1078            .map_err(|err| Error::Io(Box::new(err)))?;
1079        if read == 0 {
1080            return Err(Error::api(
1081                "RPC subprocess exited before sending a response",
1082            ));
1083        }
1084        serde_json::from_str(line.trim_end()).map_err(|err| Error::Json(Box::new(err)))
1085    }
1086
1087    fn wait_for_response(&mut self, request_id: &str, command: &str) -> Result<Value> {
1088        loop {
1089            let item = self.read_json_line()?;
1090            let Some(item_type) = item.get("type").and_then(Value::as_str) else {
1091                continue;
1092            };
1093            if item_type != "response" {
1094                continue;
1095            }
1096            if item.get("id").and_then(Value::as_str) != Some(request_id) {
1097                continue;
1098            }
1099            if item.get("command").and_then(Value::as_str) != Some(command) {
1100                continue;
1101            }
1102
1103            let success = item
1104                .get("success")
1105                .and_then(Value::as_bool)
1106                .unwrap_or(false);
1107            if success {
1108                return Ok(item.get("data").cloned().unwrap_or(Value::Null));
1109            }
1110            return Err(rpc_error_from_response(&item, command));
1111        }
1112    }
1113}
1114
1115impl Drop for RpcTransportClient {
1116    fn drop(&mut self) {
1117        let _ = self.shutdown();
1118    }
1119}
1120
1121fn rpc_error_from_response(response: &Value, command: &str) -> Error {
1122    let error = response
1123        .get("error")
1124        .and_then(Value::as_str)
1125        .unwrap_or("RPC command failed");
1126    Error::api(format!("RPC {command} failed: {error}"))
1127}
1128
/// SDK-facing operations on an in-process agent session.
///
/// Prompt and continue methods route events through `make_combined_callback`,
/// so typed hooks, session-level subscribers and the per-call callback all
/// observe the same event stream.
impl AgentSessionHandle {
    /// Create a handle from a pre-built `AgentSession` with custom listeners.
    ///
    /// This is useful for tests and advanced embedding scenarios where
    /// the full `create_agent_session()` flow is not needed.
    pub const fn from_session_with_listeners(
        session: AgentSession,
        listeners: EventListeners,
    ) -> Self {
        Self { session, listeners }
    }

    /// Send one user prompt through the agent loop.
    ///
    /// The `on_event` callback receives events for this prompt only.
    /// Session-level listeners registered via [`Self::subscribe`] or
    /// [`SessionOptions`] callbacks also fire for every event.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying `AgentSession::run_text`.
    pub async fn prompt(
        &mut self,
        input: impl Into<String>,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<AssistantMessage> {
        let combined = self.make_combined_callback(on_event);
        self.session.run_text(input.into(), combined).await
    }

    /// Send one user prompt through the agent loop with an explicit abort signal.
    ///
    /// Same event routing as [`Self::prompt`]. Pair `abort_signal` with an
    /// [`AbortHandle`] (for example from [`Self::new_abort_handle`]) to cancel.
    pub async fn prompt_with_abort(
        &mut self,
        input: impl Into<String>,
        abort_signal: AbortSignal,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<AssistantMessage> {
        let combined = self.make_combined_callback(on_event);
        self.session
            .run_text_with_abort(input.into(), Some(abort_signal), combined)
            .await
    }

    /// Continue the current agent loop without adding a new user prompt.
    ///
    /// This is useful for retry/continuation flows where session history or
    /// injected messages should drive the next turn without synthesizing a new
    /// user message through [`Self::prompt`].
    pub async fn continue_turn(
        &mut self,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<AssistantMessage> {
        let combined = self.make_combined_callback(on_event);
        // NOTE(review): presumably re-applies the provider/model selection stored
        // in the session header before running; confirm against AgentSession.
        self.session
            .sync_runtime_selection_from_session_header()
            .await?;
        // Continue directly on the agent: no user message is appended here.
        self.session
            .agent
            .run_continue_with_abort(None, combined)
            .await
    }

    /// Continue the current agent loop with an explicit abort signal.
    ///
    /// Same as [`Self::continue_turn`], but cancellable through `abort_signal`.
    pub async fn continue_turn_with_abort(
        &mut self,
        abort_signal: AbortSignal,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<AssistantMessage> {
        let combined = self.make_combined_callback(on_event);
        self.session
            .sync_runtime_selection_from_session_header()
            .await?;
        self.session
            .agent
            .run_continue_with_abort(Some(abort_signal), combined)
            .await
    }

    /// Create a new abort handle/signal pair for prompt cancellation.
    pub fn new_abort_handle() -> (AbortHandle, AbortSignal) {
        AbortHandle::new()
    }

    /// Register a session-level event listener.
    ///
    /// The listener fires for every [`AgentEvent`] across all future prompts
    /// until removed via [`Self::unsubscribe`].
    ///
    /// Returns a [`SubscriptionId`] that can be used to remove the listener.
    pub fn subscribe(
        &self,
        listener: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> SubscriptionId {
        self.listeners.subscribe(Arc::new(listener))
    }

    /// Remove a previously registered event listener.
    ///
    /// Returns `true` if the listener was found and removed.
    pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
        self.listeners.unsubscribe(id)
    }

    /// Access the session-level event listeners.
    pub const fn listeners(&self) -> &EventListeners {
        &self.listeners
    }

    /// Mutable access to session-level event listeners.
    ///
    /// Allows updating typed hooks (`on_tool_start`, `on_tool_end`,
    /// `on_stream_event`) after session creation.
    pub const fn listeners_mut(&mut self) -> &mut EventListeners {
        &mut self.listeners
    }

    // -----------------------------------------------------------------
    // Extensions & Capability Policy
    // -----------------------------------------------------------------

    /// Whether this session has extensions loaded.
    pub const fn has_extensions(&self) -> bool {
        self.session.extensions.is_some()
    }

    /// Return a reference to the extension manager (if extensions are loaded).
    pub fn extension_manager(&self) -> Option<&ExtensionManager> {
        self.session
            .extensions
            .as_ref()
            .map(ExtensionRegion::manager)
    }

    /// Return a reference to the extension region (if extensions are loaded).
    ///
    /// The region wraps the extension manager with lifecycle management.
    pub const fn extension_region(&self) -> Option<&ExtensionRegion> {
        self.session.extensions.as_ref()
    }

    // -----------------------------------------------------------------
    // Provider & Model
    // -----------------------------------------------------------------

    /// Return the active provider/model pair.
    ///
    /// Read from the live agent's provider as `(provider name, model id)`.
    pub fn model(&self) -> (String, String) {
        let provider = self.session.agent.provider();
        (provider.name().to_string(), provider.model_id().to_string())
    }

    /// Update the active provider/model pair and persist it to session metadata.
    pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<()> {
        self.session.set_provider_model(provider, model_id).await
    }

    /// Return the currently configured thinking level.
    ///
    /// Read from the agent's stream options; `None` when no level is set.
    pub const fn thinking_level(&self) -> Option<crate::model::ThinkingLevel> {
        self.session.agent.stream_options().thinking_level
    }

    /// Alias for thinking level access, matching the SDK naming style.
    pub const fn thinking(&self) -> Option<crate::model::ThinkingLevel> {
        self.thinking_level()
    }

    /// Update thinking level and persist it to session metadata.
    ///
    /// The requested `level` is passed through `clamp_thinking_level_for_model`
    /// for the effective model of the current session path, so the stored and
    /// applied level may differ from the request. The session header is always
    /// updated. A thinking-level-change entry is appended, and the session
    /// persisted, only when the effective level differs from the path's
    /// current one.
    ///
    /// # Errors
    ///
    /// Returns a session error if the session lock cannot be acquired, and
    /// propagates errors from `persist_session`.
    pub async fn set_thinking_level(&mut self, level: crate::model::ThinkingLevel) -> Result<()> {
        let cx = crate::agent_cx::AgentCx::for_request();
        // The session lock is held only inside this block; it is released
        // before stream options are mutated and before persisting.
        let (effective_level, changed) = {
            let mut guard = self
                .session
                .session
                .lock(cx.cx())
                .await
                .map_err(|e| Error::session(e.to_string()))?;
            // Prefer the model recorded for the current path; fall back to the live provider.
            let (provider_id, model_id) = guard
                .effective_model_for_current_path()
                .unwrap_or_else(|| self.model());
            let effective_level =
                self.session
                    .clamp_thinking_level_for_model(&provider_id, &model_id, level);
            let level_string = effective_level.to_string();
            let changed = guard.effective_thinking_level_for_current_path().as_deref()
                != Some(level_string.as_str());
            guard.set_model_header(None, None, Some(level_string.clone()));
            if changed {
                guard.append_thinking_level_change(level_string);
            }
            (effective_level, changed)
        };
        self.session.agent.stream_options_mut().thinking_level = Some(effective_level);
        if changed {
            self.session.persist_session().await
        } else {
            Ok(())
        }
    }

    /// Return all model messages for the current session path.
    ///
    /// # Errors
    ///
    /// Returns a session error if the session lock cannot be acquired.
    pub async fn messages(&self) -> Result<Vec<Message>> {
        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = self
            .session
            .session
            .lock(cx.cx())
            .await
            .map_err(|e| Error::session(e.to_string()))?;
        Ok(guard.to_messages_for_current_path())
    }

    /// Return a lightweight state snapshot.
    ///
    /// Provider, model, thinking level and save flag come from the live agent
    /// and are read before the session lock is taken. The session id and
    /// message count come from the locked session.
    ///
    /// # Errors
    ///
    /// Returns a session error if the session lock cannot be acquired.
    pub async fn state(&self) -> Result<AgentSessionState> {
        let (provider, model_id) = self.model();
        let thinking_level = self.thinking_level();
        let save_enabled = self.session.save_enabled();
        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = self
            .session
            .session
            .lock(cx.cx())
            .await
            .map_err(|e| Error::session(e.to_string()))?;
        let session_id = Some(guard.header.id.clone());
        let message_count = guard.to_messages_for_current_path().len();

        Ok(AgentSessionState {
            session_id,
            provider,
            model_id,
            thinking_level,
            save_enabled,
            message_count,
        })
    }

    /// Trigger an immediate compaction pass (if compaction is enabled).
    ///
    /// Note that `on_event` is passed straight to `compact_now`; it is not
    /// combined with session-level listeners or typed hooks.
    pub async fn compact(
        &mut self,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<()> {
        self.session.compact_now(on_event).await
    }

    /// Access the underlying `AgentSession`.
    pub const fn session(&self) -> &AgentSession {
        &self.session
    }

    /// Mutable access to the underlying `AgentSession`.
    pub const fn session_mut(&mut self) -> &mut AgentSession {
        &mut self.session
    }

    /// Consume the handle and return the inner `AgentSession`.
    pub fn into_inner(self) -> AgentSession {
        self.session
    }

    /// Build a combined callback that fans out to the per-prompt callback,
    /// session-level subscribers, and typed hooks.
    ///
    /// Dispatch order for each event is fixed:
    /// 1. typed hooks (tool start/end, stream event);
    /// 2. generic session-level listeners;
    /// 3. the per-prompt callback, which receives the event by value.
    ///
    /// The listeners are cloned at call time and moved into the closure.
    fn make_combined_callback(
        &self,
        per_prompt: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> impl Fn(AgentEvent) + Send + Sync + 'static {
        let listeners = self.listeners.clone();
        move |event: AgentEvent| {
            // Typed tool hooks — fire before generic listeners.
            match &event {
                AgentEvent::ToolExecutionStart {
                    tool_name, args, ..
                } => {
                    listeners.notify_tool_start(tool_name, args);
                }
                AgentEvent::ToolExecutionEnd {
                    tool_name,
                    result,
                    is_error,
                    ..
                } => {
                    listeners.notify_tool_end(tool_name, result, *is_error);
                }
                AgentEvent::MessageUpdate {
                    assistant_message_event,
                    ..
                } => {
                    // Forward raw stream events from the nested
                    // `AssistantMessageEvent` when possible.
                    if let Some(stream_ev) =
                        stream_event_from_assistant_message_event(assistant_message_event)
                    {
                        listeners.notify_stream_event(&stream_ev);
                    }
                }
                _ => {}
            }

            // Session-level generic subscribers.
            listeners.notify(&event);

            // Per-prompt callback.
            per_prompt(event);
        }
    }
}
1429
1430/// Extract a raw [`StreamEvent`] equivalent from an [`AssistantMessageEvent`].
1431///
1432/// This lets the typed `on_stream_event` hook fire with the low-level provider
1433/// protocol event rather than the wrapped agent-level event.
1434fn stream_event_from_assistant_message_event(
1435    event: &crate::model::AssistantMessageEvent,
1436) -> Option<StreamEvent> {
1437    use crate::model::AssistantMessageEvent as AME;
1438    match event {
1439        AME::TextStart { content_index, .. } => Some(StreamEvent::TextStart {
1440            content_index: *content_index,
1441        }),
1442        AME::TextDelta {
1443            content_index,
1444            delta,
1445            ..
1446        } => Some(StreamEvent::TextDelta {
1447            content_index: *content_index,
1448            delta: delta.clone(),
1449        }),
1450        AME::TextEnd {
1451            content_index,
1452            content,
1453            ..
1454        } => Some(StreamEvent::TextEnd {
1455            content_index: *content_index,
1456            content: content.clone(),
1457        }),
1458        AME::ThinkingStart { content_index, .. } => Some(StreamEvent::ThinkingStart {
1459            content_index: *content_index,
1460        }),
1461        AME::ThinkingDelta {
1462            content_index,
1463            delta,
1464            ..
1465        } => Some(StreamEvent::ThinkingDelta {
1466            content_index: *content_index,
1467            delta: delta.clone(),
1468        }),
1469        AME::ThinkingEnd {
1470            content_index,
1471            content,
1472            ..
1473        } => Some(StreamEvent::ThinkingEnd {
1474            content_index: *content_index,
1475            content: content.clone(),
1476        }),
1477        AME::ToolCallStart { content_index, .. } => Some(StreamEvent::ToolCallStart {
1478            content_index: *content_index,
1479        }),
1480        AME::ToolCallDelta {
1481            content_index,
1482            delta,
1483            ..
1484        } => Some(StreamEvent::ToolCallDelta {
1485            content_index: *content_index,
1486            delta: delta.clone(),
1487        }),
1488        AME::ToolCallEnd {
1489            content_index,
1490            tool_call,
1491            ..
1492        } => Some(StreamEvent::ToolCallEnd {
1493            content_index: *content_index,
1494            tool_call: tool_call.clone(),
1495        }),
1496        AME::Done { reason, message } => Some(StreamEvent::Done {
1497            reason: *reason,
1498            message: (**message).clone(),
1499        }),
1500        AME::Error { reason, error } => Some(StreamEvent::Error {
1501            reason: *reason,
1502            error: (**error).clone(),
1503        }),
1504        AME::Start { .. } => None,
1505    }
1506}
1507
/// Resolve `path` against `cwd` unless it is already absolute.
fn resolve_path_for_cwd(path: &Path, cwd: &Path) -> PathBuf {
    if path.is_relative() {
        cwd.join(path)
    } else {
        path.to_path_buf()
    }
}
1515
1516fn build_stream_options_with_optional_key(
1517    config: &Config,
1518    api_key: Option<String>,
1519    selection: &app::ModelSelection,
1520    session: &Session,
1521) -> StreamOptions {
1522    let mut options = StreamOptions {
1523        api_key,
1524        headers: selection.model_entry.headers.clone(),
1525        session_id: Some(session.header.id.clone()),
1526        thinking_level: Some(selection.thinking_level),
1527        ..Default::default()
1528    };
1529
1530    if let Some(budgets) = &config.thinking_budgets {
1531        let defaults = ThinkingBudgets::default();
1532        options.thinking_budgets = Some(ThinkingBudgets {
1533            minimal: budgets.minimal.unwrap_or(defaults.minimal),
1534            low: budgets.low.unwrap_or(defaults.low),
1535            medium: budgets.medium.unwrap_or(defaults.medium),
1536            high: budgets.high.unwrap_or(defaults.high),
1537            xhigh: budgets.xhigh.unwrap_or(defaults.xhigh),
1538        });
1539    }
1540
1541    options
1542}
1543
1544/// Create a fully configured embeddable agent session.
1545///
1546/// This is the programmatic entrypoint for non-CLI consumers that want to run
1547/// Pi sessions in-process.
1548#[allow(clippy::too_many_lines)]
1549pub async fn create_agent_session(options: SessionOptions) -> Result<AgentSessionHandle> {
1550    let process_cwd =
1551        std::env::current_dir().map_err(|e| Error::config(format!("cwd lookup failed: {e}")))?;
1552    let cwd = options.working_directory.as_deref().map_or_else(
1553        || process_cwd.clone(),
1554        |path| resolve_path_for_cwd(path, &process_cwd),
1555    );
1556
1557    let mut cli = Cli::try_parse_from(["pi"])
1558        .map_err(|e| Error::validation(format!("CLI init failed: {e}")))?;
1559    cli.no_session = options.no_session;
1560    cli.provider = options.provider.clone();
1561    cli.model = options.model.clone();
1562    cli.api_key = options.api_key.clone();
1563    cli.system_prompt = options.system_prompt.clone();
1564    cli.append_system_prompt = options.append_system_prompt.clone();
1565    cli.hide_cwd_in_prompt = !options.include_cwd_in_prompt;
1566    cli.thinking = options.thinking.map(|t| t.to_string());
1567    cli.session = options
1568        .session_path
1569        .as_ref()
1570        .map(|p| p.to_string_lossy().to_string());
1571    cli.session_dir = options
1572        .session_dir
1573        .as_ref()
1574        .map(|p| p.to_string_lossy().to_string());
1575    if let Some(enabled_tools) = &options.enabled_tools {
1576        if enabled_tools.is_empty() {
1577            cli.no_tools = true;
1578        } else {
1579            cli.no_tools = false;
1580            cli.tools = enabled_tools.join(",");
1581        }
1582    }
1583
1584    let config = Config::load()?;
1585
1586    let mut auth = AuthStorage::load_async(Config::auth_path()).await?;
1587    auth.refresh_expired_oauth_tokens().await?;
1588
1589    let global_dir = Config::global_dir();
1590    let package_dir = Config::package_dir();
1591    let models_path = default_models_path(&global_dir);
1592    let model_registry = ModelRegistry::load(&auth, Some(models_path));
1593
1594    let mut session = Session::new(&cli, &config).await?;
1595    let scoped_patterns = if let Some(models_arg) = &cli.models {
1596        app::parse_models_arg(models_arg)
1597    } else {
1598        config.enabled_models.clone().unwrap_or_default()
1599    };
1600    let scoped_models = if scoped_patterns.is_empty() {
1601        Vec::new()
1602    } else {
1603        app::resolve_model_scope(&scoped_patterns, &model_registry, cli.api_key.is_some())
1604    };
1605
1606    let selection = app::select_model_and_thinking(
1607        &cli,
1608        &config,
1609        &session,
1610        &model_registry,
1611        &scoped_models,
1612        &global_dir,
1613    )
1614    .map_err(|err| Error::validation(err.to_string()))?;
1615    app::update_session_for_selection(&mut session, &selection);
1616
1617    let enabled_tools_owned = cli
1618        .enabled_tools()
1619        .into_iter()
1620        .map(str::to_string)
1621        .collect::<Vec<_>>();
1622    let enabled_tools = enabled_tools_owned
1623        .iter()
1624        .map(String::as_str)
1625        .collect::<Vec<_>>();
1626
1627    let system_prompt = app::build_system_prompt(
1628        &cli,
1629        &cwd,
1630        &enabled_tools,
1631        None,
1632        &global_dir,
1633        &package_dir,
1634        std::env::var_os("PI_TEST_MODE").is_some(),
1635        options.include_cwd_in_prompt,
1636    )
1637    .map_err(|err| Error::validation(err.to_string()))?;
1638
1639    let provider = providers::create_provider(&selection.model_entry, None)
1640        .map_err(|e| Error::provider("sdk", e.to_string()))?;
1641
1642    let api_key = app::resolve_api_key(&auth, &cli, &selection.model_entry)
1643        .map_err(|err| Error::validation(err.to_string()))?;
1644
1645    let stream_options =
1646        build_stream_options_with_optional_key(&config, api_key, &selection, &session);
1647
1648    let agent_config = AgentConfig {
1649        system_prompt: Some(system_prompt),
1650        max_tool_iterations: options.max_tool_iterations,
1651        stream_options,
1652        block_images: config.image_block_images(),
1653        fail_closed_hooks: config.fail_closed_hooks(),
1654    };
1655
1656    let tools = ToolRegistry::new(&enabled_tools, &cwd, Some(&config));
1657    let session_arc = Arc::new(asupersync::sync::Mutex::new(session));
1658
1659    let context_window_tokens = if selection.model_entry.model.context_window == 0 {
1660        ResolvedCompactionSettings::default().context_window_tokens
1661    } else {
1662        selection.model_entry.model.context_window
1663    };
1664    let compaction_settings = ResolvedCompactionSettings {
1665        enabled: config.compaction_enabled(),
1666        reserve_tokens: config.compaction_reserve_tokens(),
1667        keep_recent_tokens: config.compaction_keep_recent_tokens(),
1668        context_window_tokens,
1669    };
1670
1671    let mut agent_session = AgentSession::new(
1672        Agent::new(provider, tools, agent_config),
1673        Arc::clone(&session_arc),
1674        !cli.no_session,
1675        compaction_settings,
1676    );
1677
1678    if !options.extension_paths.is_empty() {
1679        let extension_paths = options
1680            .extension_paths
1681            .iter()
1682            .map(|path| resolve_path_for_cwd(path, &cwd))
1683            .collect::<Vec<_>>();
1684        let resolved_ext_policy =
1685            config.resolve_extension_policy_with_metadata(options.extension_policy.as_deref());
1686        let resolved_repair_policy =
1687            config.resolve_repair_policy_with_metadata(options.repair_policy.as_deref());
1688
1689        agent_session
1690            .enable_extensions_with_policy(
1691                &enabled_tools,
1692                &cwd,
1693                Some(&config),
1694                &extension_paths,
1695                Some(resolved_ext_policy.policy),
1696                Some(resolved_repair_policy.effective_mode),
1697                None,
1698            )
1699            .await?;
1700    }
1701
1702    agent_session.set_model_registry(model_registry.clone());
1703    agent_session.set_auth_storage(auth);
1704
1705    let history = {
1706        let cx = crate::agent_cx::AgentCx::for_request();
1707        let guard = session_arc
1708            .lock(cx.cx())
1709            .await
1710            .map_err(|e| Error::session(e.to_string()))?;
1711        guard.to_messages_for_current_path()
1712    };
1713    if !history.is_empty() {
1714        agent_session.agent.replace_messages(history);
1715    }
1716
1717    let mut listeners = EventListeners::new();
1718    if let Some(on_event) = options.on_event {
1719        listeners.subscribe(on_event);
1720    }
1721    listeners.on_tool_start = options.on_tool_start;
1722    listeners.on_tool_end = options.on_tool_end;
1723    listeners.on_stream_event = options.on_stream_event;
1724
1725    Ok(AgentSessionHandle {
1726        session: agent_session,
1727        listeners,
1728    })
1729}
1730
1731#[cfg(test)]
1732mod tests {
1733    use super::*;
1734    use asupersync::runtime::RuntimeBuilder;
1735    use asupersync::runtime::reactor::create_reactor;
1736    use asupersync::sync::Mutex as AsyncMutex;
1737    use std::sync::{Arc, Mutex};
1738    use tempfile::tempdir;
1739
1740    fn run_async<F>(future: F) -> F::Output
1741    where
1742        F: std::future::Future,
1743    {
1744        let reactor = create_reactor().expect("create reactor");
1745        let runtime = RuntimeBuilder::current_thread()
1746            .with_reactor(reactor)
1747            .build()
1748            .expect("build runtime");
1749        runtime.block_on(future)
1750    }
1751
1752    #[test]
1753    fn create_agent_session_default_succeeds() {
1754        let tmp = tempdir().expect("tempdir");
1755        let options = SessionOptions {
1756            working_directory: Some(tmp.path().to_path_buf()),
1757            no_session: true,
1758            ..SessionOptions::default()
1759        };
1760
1761        let handle = run_async(create_agent_session(options)).expect("create session");
1762        let provider = handle.session().agent.provider();
1763        assert!(!provider.name().is_empty());
1764        assert!(!provider.model_id().is_empty());
1765        assert_eq!(handle.model().0, provider.name());
1766        assert_eq!(handle.model().1, provider.model_id());
1767    }
1768
1769    #[test]
1770    fn create_agent_session_respects_provider_model_and_clamps_thinking() {
1771        let tmp = tempdir().expect("tempdir");
1772        let options = SessionOptions {
1773            provider: Some("openai".to_string()),
1774            model: Some("gpt-4o".to_string()),
1775            api_key: Some("dummy-key".to_string()),
1776            thinking: Some(crate::model::ThinkingLevel::Low),
1777            working_directory: Some(tmp.path().to_path_buf()),
1778            no_session: true,
1779            ..SessionOptions::default()
1780        };
1781
1782        let handle = run_async(create_agent_session(options)).expect("create session");
1783        let provider = handle.session().agent.provider();
1784        assert_eq!(provider.name(), "openai");
1785        assert_eq!(provider.model_id(), "gpt-4o");
1786        assert_eq!(
1787            handle.session().agent.stream_options().thinking_level,
1788            Some(crate::model::ThinkingLevel::Off)
1789        );
1790    }
1791
1792    #[test]
1793    fn create_agent_session_no_session_keeps_ephemeral_state() {
1794        let tmp = tempdir().expect("tempdir");
1795        let options = SessionOptions {
1796            working_directory: Some(tmp.path().to_path_buf()),
1797            no_session: true,
1798            ..SessionOptions::default()
1799        };
1800
1801        let handle = run_async(create_agent_session(options)).expect("create session");
1802        assert!(!handle.session().save_enabled());
1803
1804        let path_is_none = run_async(async {
1805            let cx = crate::agent_cx::AgentCx::for_request();
1806            let guard = handle
1807                .session()
1808                .session
1809                .lock(cx.cx())
1810                .await
1811                .expect("lock session");
1812            guard.path.is_none()
1813        });
1814        assert!(path_is_none);
1815    }
1816
1817    #[test]
1818    fn from_session_with_listeners_set_model_switches_provider_model() {
1819        let dir = tempdir().expect("tempdir");
1820        let auth_path = dir.path().join("auth.json");
1821        let mut auth = AuthStorage::load(auth_path).expect("load auth");
1822        auth.set(
1823            "anthropic",
1824            crate::auth::AuthCredential::ApiKey {
1825                key: "anthropic-key".to_string(),
1826            },
1827        );
1828        auth.set(
1829            "openai",
1830            crate::auth::AuthCredential::ApiKey {
1831                key: "openai-key".to_string(),
1832            },
1833        );
1834
1835        let registry = ModelRegistry::load(&auth, None);
1836        let entry = registry
1837            .find("anthropic", "claude-sonnet-4-5")
1838            .expect("anthropic model in registry");
1839        let provider = providers::create_provider(&entry, None).expect("create anthropic provider");
1840        let tools = crate::tools::ToolRegistry::new(&[], std::path::Path::new("."), None);
1841        let agent = Agent::new(
1842            provider,
1843            tools,
1844            AgentConfig {
1845                system_prompt: None,
1846                max_tool_iterations: 50,
1847                stream_options: StreamOptions::default(),
1848                block_images: false,
1849                fail_closed_hooks: false,
1850            },
1851        );
1852
1853        let mut session = Session::in_memory();
1854        session.header.provider = Some("anthropic".to_string());
1855        session.header.model_id = Some("claude-sonnet-4-5".to_string());
1856
1857        let mut agent_session = AgentSession::new(
1858            agent,
1859            Arc::new(AsyncMutex::new(session)),
1860            false,
1861            ResolvedCompactionSettings::default(),
1862        );
1863        agent_session.set_model_registry(registry);
1864        agent_session.set_auth_storage(auth);
1865
1866        let mut handle =
1867            AgentSessionHandle::from_session_with_listeners(agent_session, EventListeners::new());
1868        run_async(handle.set_model("openai", "gpt-4o")).expect("set model");
1869        let provider = handle.session().agent.provider();
1870        assert_eq!(provider.name(), "openai");
1871        assert_eq!(provider.model_id(), "gpt-4o");
1872    }
1873
1874    #[test]
1875    fn create_agent_session_set_thinking_level_clamps_and_dedupes_history() {
1876        let tmp = tempdir().expect("tempdir");
1877        let options = SessionOptions {
1878            provider: Some("openai".to_string()),
1879            model: Some("gpt-4o".to_string()),
1880            api_key: Some("dummy-key".to_string()),
1881            working_directory: Some(tmp.path().to_path_buf()),
1882            no_session: true,
1883            ..SessionOptions::default()
1884        };
1885
1886        let mut handle = run_async(create_agent_session(options)).expect("create session");
1887        run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
1888            .expect("set thinking");
1889        run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
1890            .expect("reapply thinking");
1891
1892        assert_eq!(
1893            handle.session().agent.stream_options().thinking_level,
1894            Some(crate::model::ThinkingLevel::Off)
1895        );
1896
1897        let thinking_changes = run_async(async {
1898            let cx = crate::agent_cx::AgentCx::for_request();
1899            let guard = handle
1900                .session()
1901                .session
1902                .lock(cx.cx())
1903                .await
1904                .expect("lock session");
1905            assert_eq!(guard.header.thinking_level.as_deref(), Some("off"));
1906            guard
1907                .entries_for_current_path()
1908                .iter()
1909                .filter(|entry| {
1910                    matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_))
1911                })
1912                .count()
1913        });
1914        assert_eq!(thinking_changes, 1);
1915    }
1916
1917    #[test]
1918    fn from_session_with_listeners_set_thinking_level_uses_session_header_target() {
1919        let dir = tempdir().expect("tempdir");
1920        let auth_path = dir.path().join("auth.json");
1921        let auth = crate::auth::AuthStorage::load(auth_path).expect("load auth");
1922        let mut registry = ModelRegistry::load(&auth, None);
1923        registry.merge_entries(vec![ModelEntry {
1924            model: Model {
1925                id: "plain-model".to_string(),
1926                name: "Plain Model".to_string(),
1927                api: "openai-completions".to_string(),
1928                provider: "acme".to_string(),
1929                base_url: "https://example.invalid/v1".to_string(),
1930                reasoning: false,
1931                input: vec![InputType::Text],
1932                cost: ModelCost {
1933                    input: 0.0,
1934                    output: 0.0,
1935                    cache_read: 0.0,
1936                    cache_write: 0.0,
1937                },
1938                context_window: 128_000,
1939                max_tokens: 8_192,
1940                headers: HashMap::new(),
1941            },
1942            api_key: None,
1943            headers: HashMap::new(),
1944            auth_header: false,
1945            compat: None,
1946            oauth_config: None,
1947        }]);
1948        let entry = registry
1949            .find("anthropic", "claude-sonnet-4-5")
1950            .expect("anthropic model in registry");
1951        let provider = providers::create_provider(&entry, None).expect("create anthropic provider");
1952        let tools = crate::tools::ToolRegistry::new(&[], std::path::Path::new("."), None);
1953        let agent = Agent::new(
1954            provider,
1955            tools,
1956            AgentConfig {
1957                system_prompt: None,
1958                max_tool_iterations: 50,
1959                stream_options: StreamOptions::default(),
1960                block_images: false,
1961                fail_closed_hooks: false,
1962            },
1963        );
1964
1965        let mut session = Session::in_memory();
1966        session.header.provider = Some("acme".to_string());
1967        session.header.model_id = Some("plain-model".to_string());
1968
1969        let mut agent_session = AgentSession::new(
1970            agent,
1971            Arc::new(AsyncMutex::new(session)),
1972            false,
1973            ResolvedCompactionSettings::default(),
1974        );
1975        agent_session.set_model_registry(registry);
1976
1977        let mut handle =
1978            AgentSessionHandle::from_session_with_listeners(agent_session, EventListeners::new());
1979        run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
1980            .expect("set thinking");
1981
1982        assert_eq!(
1983            handle.session().agent.stream_options().thinking_level,
1984            Some(crate::model::ThinkingLevel::Off)
1985        );
1986        assert_eq!(handle.model().0, "anthropic");
1987        assert_eq!(handle.model().1, "claude-sonnet-4-5");
1988    }
1989
1990    #[test]
1991    fn compact_without_history_is_noop() {
1992        let tmp = tempdir().expect("tempdir");
1993        let options = SessionOptions {
1994            working_directory: Some(tmp.path().to_path_buf()),
1995            no_session: true,
1996            ..SessionOptions::default()
1997        };
1998
1999        let mut handle = run_async(create_agent_session(options)).expect("create session");
2000        let events = Arc::new(Mutex::new(Vec::new()));
2001        let events_for_callback = Arc::clone(&events);
2002        run_async(handle.compact(move |event| {
2003            events_for_callback
2004                .lock()
2005                .expect("compact callback lock")
2006                .push(event);
2007        }))
2008        .expect("compact");
2009
2010        assert!(
2011            events
2012                .lock()
2013                .unwrap_or_else(std::sync::PoisonError::into_inner)
2014                .is_empty(),
2015            "expected no compaction lifecycle events for empty session"
2016        );
2017    }
2018
2019    #[test]
2020    fn resolve_path_for_cwd_uses_cwd_for_relative_paths() {
2021        let cwd = Path::new("/tmp/pi-sdk-cwd");
2022        assert_eq!(
2023            resolve_path_for_cwd(Path::new("relative/file.txt"), cwd),
2024            PathBuf::from("/tmp/pi-sdk-cwd/relative/file.txt")
2025        );
2026        assert_eq!(
2027            resolve_path_for_cwd(Path::new("/etc/hosts"), cwd),
2028            PathBuf::from("/etc/hosts")
2029        );
2030    }
2031
2032    // =====================================================================
2033    // EventListeners tests
2034    // =====================================================================
2035
2036    #[test]
2037    fn event_listeners_subscribe_and_notify() {
2038        let listeners = EventListeners::new();
2039        let received = Arc::new(Mutex::new(Vec::new()));
2040
2041        let recv_clone = Arc::clone(&received);
2042        let id = listeners.subscribe(Arc::new(move |event| {
2043            recv_clone
2044                .lock()
2045                .unwrap_or_else(std::sync::PoisonError::into_inner)
2046                .push(event);
2047        }));
2048
2049        let event = AgentEvent::AgentStart {
2050            session_id: "test-123".into(),
2051        };
2052        listeners.notify(&event);
2053
2054        let events = received
2055            .lock()
2056            .unwrap_or_else(std::sync::PoisonError::into_inner);
2057        assert_eq!(events.len(), 1);
2058
2059        // Verify unsubscribe
2060        drop(events);
2061        assert!(listeners.unsubscribe(id));
2062        listeners.notify(&AgentEvent::AgentStart {
2063            session_id: "test-456".into(),
2064        });
2065        assert_eq!(
2066            received
2067                .lock()
2068                .unwrap_or_else(std::sync::PoisonError::into_inner)
2069                .len(),
2070            1
2071        );
2072    }
2073
2074    #[test]
2075    fn event_listeners_unsubscribe_nonexistent_returns_false() {
2076        let listeners = EventListeners::new();
2077        assert!(!listeners.unsubscribe(SubscriptionId(999)));
2078    }
2079
2080    #[test]
2081    fn event_listeners_multiple_subscribers() {
2082        let listeners = EventListeners::new();
2083        let count_a = Arc::new(Mutex::new(0u32));
2084        let count_b = Arc::new(Mutex::new(0u32));
2085
2086        let ca = Arc::clone(&count_a);
2087        listeners.subscribe(Arc::new(move |_| {
2088            *ca.lock().unwrap_or_else(std::sync::PoisonError::into_inner) += 1;
2089        }));
2090
2091        let cb = Arc::clone(&count_b);
2092        listeners.subscribe(Arc::new(move |_| {
2093            *cb.lock().unwrap_or_else(std::sync::PoisonError::into_inner) += 1;
2094        }));
2095
2096        listeners.notify(&AgentEvent::AgentStart {
2097            session_id: "s".into(),
2098        });
2099
2100        assert_eq!(
2101            *count_a
2102                .lock()
2103                .unwrap_or_else(std::sync::PoisonError::into_inner),
2104            1
2105        );
2106        assert_eq!(
2107            *count_b
2108                .lock()
2109                .unwrap_or_else(std::sync::PoisonError::into_inner),
2110            1
2111        );
2112    }
2113
2114    #[test]
2115    fn event_listeners_tool_hooks_fire() {
2116        let listeners = EventListeners::new();
2117        let starts = Arc::new(Mutex::new(Vec::new()));
2118        let ends = Arc::new(Mutex::new(Vec::new()));
2119
2120        let s = Arc::clone(&starts);
2121        let mut listeners = listeners;
2122        listeners.on_tool_start = Some(Arc::new(move |name, args| {
2123            s.lock()
2124                .expect("lock")
2125                .push((name.to_string(), args.clone()));
2126        }));
2127
2128        let e = Arc::clone(&ends);
2129        listeners.on_tool_end = Some(Arc::new(move |name, _output, is_error| {
2130            e.lock()
2131                .unwrap_or_else(std::sync::PoisonError::into_inner)
2132                .push((name.to_string(), is_error));
2133        }));
2134
2135        let args = serde_json::json!({"path": "/foo"});
2136        listeners.notify_tool_start("bash", &args);
2137        let output = ToolOutput {
2138            content: vec![ContentBlock::Text(TextContent::new("ok"))],
2139            details: None,
2140            is_error: false,
2141        };
2142        listeners.notify_tool_end("bash", &output, false);
2143
2144        {
2145            let s = starts
2146                .lock()
2147                .unwrap_or_else(std::sync::PoisonError::into_inner);
2148            assert_eq!(s.len(), 1);
2149            assert_eq!(s[0].0, "bash");
2150            drop(s);
2151        }
2152
2153        {
2154            let e = ends
2155                .lock()
2156                .unwrap_or_else(std::sync::PoisonError::into_inner);
2157            assert_eq!(e.len(), 1);
2158            assert_eq!(e[0].0, "bash");
2159            assert!(!e[0].1);
2160            drop(e);
2161        }
2162    }
2163
2164    #[test]
2165    fn event_listeners_stream_event_hook_fires() {
2166        let mut listeners = EventListeners::new();
2167        let received = Arc::new(Mutex::new(Vec::new()));
2168
2169        let r = Arc::clone(&received);
2170        listeners.on_stream_event = Some(Arc::new(move |ev| {
2171            r.lock()
2172                .unwrap_or_else(std::sync::PoisonError::into_inner)
2173                .push(format!("{ev:?}"));
2174        }));
2175
2176        let event = StreamEvent::TextDelta {
2177            content_index: 0,
2178            delta: "hello".to_string(),
2179        };
2180        listeners.notify_stream_event(&event);
2181
2182        assert_eq!(
2183            received
2184                .lock()
2185                .unwrap_or_else(std::sync::PoisonError::into_inner)
2186                .len(),
2187            1
2188        );
2189    }
2190
2191    #[test]
2192    fn session_options_on_event_wired_into_listeners() {
2193        let received = Arc::new(Mutex::new(Vec::new()));
2194        let r = Arc::clone(&received);
2195        let tmp = tempdir().expect("tempdir");
2196
2197        let options = SessionOptions {
2198            working_directory: Some(tmp.path().to_path_buf()),
2199            no_session: true,
2200            on_event: Some(Arc::new(move |event| {
2201                r.lock()
2202                    .unwrap_or_else(std::sync::PoisonError::into_inner)
2203                    .push(format!("{event:?}"));
2204            })),
2205            ..SessionOptions::default()
2206        };
2207
2208        let handle = run_async(create_agent_session(options)).expect("create session");
2209        // Verify the listener was registered
2210        let count = handle
2211            .listeners()
2212            .subscribers
2213            .lock()
2214            .unwrap_or_else(std::sync::PoisonError::into_inner)
2215            .len();
2216        assert_eq!(
2217            count, 1,
2218            "on_event from SessionOptions should register one subscriber"
2219        );
2220    }
2221
2222    #[test]
2223    fn subscribe_unsubscribe_on_handle() {
2224        let tmp = tempdir().expect("tempdir");
2225        let options = SessionOptions {
2226            working_directory: Some(tmp.path().to_path_buf()),
2227            no_session: true,
2228            ..SessionOptions::default()
2229        };
2230
2231        let handle = run_async(create_agent_session(options)).expect("create session");
2232        let id = handle.subscribe(|_event| {});
2233        assert_eq!(
2234            handle
2235                .listeners()
2236                .subscribers
2237                .lock()
2238                .unwrap_or_else(std::sync::PoisonError::into_inner)
2239                .len(),
2240            1
2241        );
2242
2243        assert!(handle.unsubscribe(id));
2244        assert_eq!(
2245            handle
2246                .listeners()
2247                .subscribers
2248                .lock()
2249                .unwrap_or_else(std::sync::PoisonError::into_inner)
2250                .len(),
2251            0
2252        );
2253
2254        // Double unsubscribe returns false
2255        assert!(!handle.unsubscribe(id));
2256    }
2257
2258    #[test]
2259    fn stream_event_from_assistant_message_event_converts_text_delta() {
2260        use crate::model::AssistantMessageEvent as AME;
2261
2262        let partial = Arc::new(AssistantMessage {
2263            content: Vec::new(),
2264            api: String::new(),
2265            provider: String::new(),
2266            model: String::new(),
2267            usage: Usage::default(),
2268            stop_reason: StopReason::Stop,
2269            error_message: None,
2270            timestamp: 0,
2271        });
2272        let ame = AME::TextDelta {
2273            content_index: 2,
2274            delta: "chunk".to_string(),
2275            partial,
2276        };
2277        let result = stream_event_from_assistant_message_event(&ame);
2278        assert!(result.is_some());
2279        match result.unwrap() {
2280            StreamEvent::TextDelta {
2281                content_index,
2282                delta,
2283            } => {
2284                assert_eq!(content_index, 2);
2285                assert_eq!(delta, "chunk");
2286            }
2287            other => unreachable!("expected TextDelta, got {other:?}"),
2288        }
2289    }
2290
2291    #[test]
2292    fn stream_event_from_assistant_message_event_start_returns_none() {
2293        use crate::model::AssistantMessageEvent as AME;
2294
2295        let partial = Arc::new(AssistantMessage {
2296            content: Vec::new(),
2297            api: String::new(),
2298            provider: String::new(),
2299            model: String::new(),
2300            usage: Usage::default(),
2301            stop_reason: StopReason::Stop,
2302            error_message: None,
2303            timestamp: 0,
2304        });
2305        let ame = AME::Start { partial };
2306        assert!(stream_event_from_assistant_message_event(&ame).is_none());
2307    }
2308
2309    #[test]
2310    fn event_listeners_debug_impl() {
2311        let listeners = EventListeners::new();
2312        let debug = format!("{listeners:?}");
2313        assert!(debug.contains("subscriber_count"));
2314        assert!(debug.contains("has_on_tool_start"));
2315    }
2316
2317    // =====================================================================
2318    // Extension convenience method tests
2319    // =====================================================================
2320
2321    #[test]
2322    fn has_extensions_false_by_default() {
2323        let tmp = tempdir().expect("tempdir");
2324        let options = SessionOptions {
2325            working_directory: Some(tmp.path().to_path_buf()),
2326            no_session: true,
2327            ..SessionOptions::default()
2328        };
2329
2330        let handle = run_async(create_agent_session(options)).expect("create session");
2331        assert!(
2332            !handle.has_extensions(),
2333            "session without extension_paths should have no extensions"
2334        );
2335        assert!(handle.extension_manager().is_none());
2336        assert!(handle.extension_region().is_none());
2337    }
2338
2339    // =====================================================================
2340    // Tool factory function tests
2341    // =====================================================================
2342
2343    #[test]
2344    fn create_read_tool_has_correct_name() {
2345        let tmp = tempdir().expect("tempdir");
2346        let tool = super::create_read_tool(tmp.path());
2347        assert_eq!(tool.name(), "read");
2348        assert!(!tool.description().is_empty());
2349        let params = tool.parameters();
2350        assert!(params.is_object(), "parameters should be a JSON object");
2351    }
2352
2353    #[test]
2354    fn create_bash_tool_has_correct_name() {
2355        let tmp = tempdir().expect("tempdir");
2356        let tool = super::create_bash_tool(tmp.path());
2357        assert_eq!(tool.name(), "bash");
2358        assert!(!tool.description().is_empty());
2359    }
2360
2361    #[test]
2362    fn create_edit_tool_has_correct_name() {
2363        let tmp = tempdir().expect("tempdir");
2364        let tool = super::create_edit_tool(tmp.path());
2365        assert_eq!(tool.name(), "edit");
2366    }
2367
2368    #[test]
2369    fn create_write_tool_has_correct_name() {
2370        let tmp = tempdir().expect("tempdir");
2371        let tool = super::create_write_tool(tmp.path());
2372        assert_eq!(tool.name(), "write");
2373    }
2374
2375    #[test]
2376    fn create_grep_tool_has_correct_name() {
2377        let tmp = tempdir().expect("tempdir");
2378        let tool = super::create_grep_tool(tmp.path());
2379        assert_eq!(tool.name(), "grep");
2380    }
2381
2382    #[test]
2383    fn create_find_tool_has_correct_name() {
2384        let tmp = tempdir().expect("tempdir");
2385        let tool = super::create_find_tool(tmp.path());
2386        assert_eq!(tool.name(), "find");
2387    }
2388
2389    #[test]
2390    fn create_ls_tool_has_correct_name() {
2391        let tmp = tempdir().expect("tempdir");
2392        let tool = super::create_ls_tool(tmp.path());
2393        assert_eq!(tool.name(), "ls");
2394    }
2395
2396    #[test]
2397    fn create_all_tools_returns_eight() {
2398        let tmp = tempdir().expect("tempdir");
2399        let tools = super::create_all_tools(tmp.path());
2400        assert_eq!(tools.len(), 8, "should create all 8 built-in tools");
2401
2402        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
2403        for expected in BUILTIN_TOOL_NAMES {
2404            assert!(names.contains(expected), "missing tool: {expected}");
2405        }
2406    }
2407
2408    #[test]
2409    fn tool_to_definition_preserves_schema() {
2410        let tmp = tempdir().expect("tempdir");
2411        let tool = super::create_read_tool(tmp.path());
2412        let def = super::tool_to_definition(tool.as_ref());
2413        assert_eq!(def.name, "read");
2414        assert!(!def.description.is_empty());
2415        assert!(def.parameters.is_object());
2416        assert!(
2417            def.parameters.get("properties").is_some(),
2418            "schema should have properties"
2419        );
2420    }
2421
2422    #[test]
2423    fn all_tool_definitions_returns_eight_schemas() {
2424        let tmp = tempdir().expect("tempdir");
2425        let defs = super::all_tool_definitions(tmp.path());
2426        assert_eq!(defs.len(), 8);
2427
2428        for def in &defs {
2429            assert!(!def.name.is_empty());
2430            assert!(!def.description.is_empty());
2431            assert!(def.parameters.is_object());
2432        }
2433    }
2434
2435    #[test]
2436    fn builtin_tool_names_matches_create_all() {
2437        let tmp = tempdir().expect("tempdir");
2438        let tools = super::create_all_tools(tmp.path());
2439        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
2440        assert_eq!(
2441            names.as_slice(),
2442            BUILTIN_TOOL_NAMES,
2443            "create_all_tools order should match BUILTIN_TOOL_NAMES"
2444        );
2445    }
2446
2447    #[test]
2448    fn tool_registry_from_factory_tools() {
2449        let tmp = tempdir().expect("tempdir");
2450        let tools = super::create_all_tools(tmp.path());
2451        let registry = ToolRegistry::from_tools(tools);
2452        assert!(registry.get("read").is_some());
2453        assert!(registry.get("bash").is_some());
2454        assert!(registry.get("nonexistent").is_none());
2455    }
2456}