1use crate::app;
23use crate::auth::AuthStorage;
24use crate::cli::Cli;
25use crate::compaction::ResolvedCompactionSettings;
26use crate::models::default_models_path;
27use crate::provider::ThinkingBudgets;
28use crate::providers;
29use clap::Parser;
30use serde::{Deserialize, Serialize, de::DeserializeOwned};
31use serde_json::{Map, Value};
32use std::collections::HashMap;
33use std::io::{BufRead, BufReader, BufWriter, Write};
34use std::path::{Path, PathBuf};
35use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
36use std::sync::Arc;
37use std::sync::atomic::{AtomicU64, Ordering};
38
39pub use crate::agent::{
40 AbortHandle, AbortSignal, Agent, AgentConfig, AgentEvent, AgentSession, QueueMode,
41};
42pub use crate::config::Config;
43pub use crate::error::{Error, Result};
44pub use crate::extensions::{ExtensionManager, ExtensionPolicy, ExtensionRegion};
45pub use crate::model::ThinkingLevel;
46pub use crate::model::{
47 AssistantMessage, ContentBlock, Cost, CustomMessage, ImageContent, Message, StopReason,
48 StreamEvent, TextContent, ThinkingContent, ToolCall, ToolResultMessage, Usage, UserContent,
49 UserMessage,
50};
51pub use crate::models::{ModelEntry, ModelRegistry};
52pub use crate::provider::{
53 Context as ProviderContext, InputType, Model, ModelCost, Provider, StreamOptions,
54 ThinkingBudgets as ProviderThinkingBudgets, ToolDef,
55};
56pub use crate::session::Session;
57pub use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
58
/// SDK-facing name for the provider-level [`ToolDef`] type.
pub type ToolDefinition = ToolDef;
61
62use crate::tools::{
67 BashTool, EditTool, FindTool, GrepTool, HashlineEditTool, LsTool, ReadTool, WriteTool,
68};
69
/// Names of the built-in tools.
///
/// There is one entry per `create_*_tool` factory in this module, listed in
/// the same order in which [`create_all_tools`] calls those factories.
pub const BUILTIN_TOOL_NAMES: &[&str] = &[
    "read",
    "bash",
    "edit",
    "write",
    "grep",
    "find",
    "ls",
    "hashline_edit",
];
81
82pub fn create_read_tool(cwd: &Path) -> Box<dyn Tool> {
84 Box::new(ReadTool::new(cwd))
85}
86
87pub fn create_bash_tool(cwd: &Path) -> Box<dyn Tool> {
89 Box::new(BashTool::new(cwd))
90}
91
92pub fn create_edit_tool(cwd: &Path) -> Box<dyn Tool> {
94 Box::new(EditTool::new(cwd))
95}
96
97pub fn create_write_tool(cwd: &Path) -> Box<dyn Tool> {
99 Box::new(WriteTool::new(cwd))
100}
101
102pub fn create_grep_tool(cwd: &Path) -> Box<dyn Tool> {
104 Box::new(GrepTool::new(cwd))
105}
106
107pub fn create_find_tool(cwd: &Path) -> Box<dyn Tool> {
109 Box::new(FindTool::new(cwd))
110}
111
112pub fn create_ls_tool(cwd: &Path) -> Box<dyn Tool> {
114 Box::new(LsTool::new(cwd))
115}
116
117pub fn create_hashline_edit_tool(cwd: &Path) -> Box<dyn Tool> {
119 Box::new(HashlineEditTool::new(cwd))
120}
121
122pub fn create_all_tools(cwd: &Path) -> Vec<Box<dyn Tool>> {
124 vec![
125 create_read_tool(cwd),
126 create_bash_tool(cwd),
127 create_edit_tool(cwd),
128 create_write_tool(cwd),
129 create_grep_tool(cwd),
130 create_find_tool(cwd),
131 create_ls_tool(cwd),
132 create_hashline_edit_tool(cwd),
133 ]
134}
135
136pub fn tool_to_definition(tool: &dyn Tool) -> ToolDefinition {
138 ToolDefinition {
139 name: tool.name().to_string(),
140 description: tool.description().to_string(),
141 parameters: tool.parameters(),
142 }
143}
144
145pub fn all_tool_definitions(cwd: &Path) -> Vec<ToolDefinition> {
147 create_all_tools(cwd)
148 .iter()
149 .map(|t| tool_to_definition(t.as_ref()))
150 .collect()
151}
152
/// Opaque identifier of a registered event subscriber.
///
/// Returned by `EventListeners::subscribe` and accepted by
/// `EventListeners::unsubscribe`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SubscriptionId(u64);
163
/// Shared callback invoked with a tool name and its JSON arguments.
pub type OnToolStart = Arc<dyn Fn(&str, &Value) + Send + Sync>;

/// Shared callback invoked with a tool name, its output, and an error flag.
pub type OnToolEnd = Arc<dyn Fn(&str, &ToolOutput, bool) + Send + Sync>;

/// Shared callback invoked with a borrowed [`StreamEvent`].
pub type OnStreamEvent = Arc<dyn Fn(&StreamEvent) + Send + Sync>;

/// Shared callback that receives an owned [`AgentEvent`].
pub type EventSubscriber = Arc<dyn Fn(AgentEvent) + Send + Sync>;
/// Map from subscription id to the subscriber registered under it.
type EventSubscribers = HashMap<SubscriptionId, EventSubscriber>;
182
/// Collection of event callbacks attached to a session.
///
/// The subscriber map and the id counter live behind `Arc`, so clones of this
/// value share them; the three optional typed hooks are cloned as `Arc`
/// handles.
#[derive(Clone, Default)]
pub struct EventListeners {
    /// Counter that supplies the next [`SubscriptionId`] value.
    next_id: Arc<AtomicU64>,
    /// Subscribers keyed by the id handed out by `subscribe`.
    subscribers: Arc<std::sync::Mutex<EventSubscribers>>,
    /// Optional hook called by `notify_tool_start`.
    pub on_tool_start: Option<OnToolStart>,
    /// Optional hook called by `notify_tool_end`.
    pub on_tool_end: Option<OnToolEnd>,
    /// Optional hook called by `notify_stream_event`.
    pub on_stream_event: Option<OnStreamEvent>,
}
196
197impl EventListeners {
198 fn new() -> Self {
199 Self {
200 next_id: Arc::new(AtomicU64::new(1)),
201 subscribers: Arc::new(std::sync::Mutex::new(HashMap::new())),
202 on_tool_start: None,
203 on_tool_end: None,
204 on_stream_event: None,
205 }
206 }
207
208 pub fn subscribe(&self, listener: EventSubscriber) -> SubscriptionId {
210 let id = SubscriptionId(self.next_id.fetch_add(1, Ordering::Relaxed));
211 let mut subs = self
212 .subscribers
213 .lock()
214 .unwrap_or_else(std::sync::PoisonError::into_inner);
215 subs.insert(id, listener);
216 id
217 }
218
219 pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
221 let mut subs = self
222 .subscribers
223 .lock()
224 .unwrap_or_else(std::sync::PoisonError::into_inner);
225 subs.remove(&id).is_some()
226 }
227
228 pub fn notify(&self, event: &AgentEvent) {
230 let listeners: Vec<_> = {
231 let subs = self
232 .subscribers
233 .lock()
234 .unwrap_or_else(std::sync::PoisonError::into_inner);
235 subs.values().cloned().collect()
236 };
237 for listener in listeners {
238 listener(event.clone());
239 }
240 }
241
242 pub fn notify_tool_start(&self, tool_name: &str, args: &Value) {
244 if let Some(cb) = &self.on_tool_start {
245 cb(tool_name, args);
246 }
247 }
248
249 pub fn notify_tool_end(&self, tool_name: &str, output: &ToolOutput, is_error: bool) {
251 if let Some(cb) = &self.on_tool_end {
252 cb(tool_name, output, is_error);
253 }
254 }
255
256 pub fn notify_stream_event(&self, event: &StreamEvent) {
258 if let Some(cb) = &self.on_stream_event {
259 cb(event);
260 }
261 }
262}
263
264impl std::fmt::Debug for EventListeners {
265 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266 let count = self.subscribers.lock().map_or(0, |s| s.len());
267 let next_id = self.next_id.load(Ordering::Relaxed);
268 f.debug_struct("EventListeners")
269 .field("subscriber_count", &count)
270 .field("next_id", &next_id)
271 .field("has_on_tool_start", &self.on_tool_start.is_some())
272 .field("has_on_tool_end", &self.on_tool_end.is_some())
273 .field("has_on_stream_event", &self.on_stream_event.is_some())
274 .finish()
275 }
276}
277
/// Options for creating an agent session.
///
/// Passed to `create_agent_session` (for example through
/// `SessionTransport::in_process`). The `Default` impl leaves every optional
/// field unset, sets `no_session` and `include_cwd_in_prompt` to `true`, and
/// takes `max_tool_iterations` from
/// `crate::agent::resolved_max_tool_iterations_default()`.
#[derive(Clone)]
pub struct SessionOptions {
    /// Provider name, if explicitly chosen.
    pub provider: Option<String>,
    /// Model identifier, if explicitly chosen.
    pub model: Option<String>,
    /// API key, if explicitly supplied.
    pub api_key: Option<String>,
    /// Thinking level, if explicitly chosen.
    pub thinking: Option<crate::model::ThinkingLevel>,
    /// System prompt text, if supplied.
    pub system_prompt: Option<String>,
    /// Additional system prompt text, if supplied.
    pub append_system_prompt: Option<String>,
    /// Names of enabled tools, if restricted.
    pub enabled_tools: Option<Vec<String>>,
    /// Working directory, if overridden.
    pub working_directory: Option<PathBuf>,
    /// Session-less flag; `true` by default.
    pub no_session: bool,
    /// Path of a session file, if supplied.
    pub session_path: Option<PathBuf>,
    /// Directory for session files, if supplied.
    pub session_dir: Option<PathBuf>,
    /// Paths of extensions; empty by default.
    pub extension_paths: Vec<PathBuf>,
    /// Extension policy name, if supplied.
    pub extension_policy: Option<String>,
    /// Repair policy name, if supplied.
    pub repair_policy: Option<String>,
    /// Whether the working directory is included in the prompt; `true` by default.
    pub include_cwd_in_prompt: bool,
    /// Limit on tool iterations.
    pub max_tool_iterations: usize,

    /// Custom [`ToolFactory`] used to build the tool registry, if supplied.
    pub tool_factory: Option<Arc<dyn ToolFactory>>,

    /// Callback receiving owned [`AgentEvent`] values, if supplied.
    pub on_event: Option<Arc<dyn Fn(AgentEvent) + Send + Sync>>,

    /// Tool-start callback, if supplied.
    pub on_tool_start: Option<OnToolStart>,

    /// Tool-end callback, if supplied.
    pub on_tool_end: Option<OnToolEnd>,

    /// Stream-event callback, if supplied.
    pub on_stream_event: Option<OnStreamEvent>,
}
334
335impl Default for SessionOptions {
336 fn default() -> Self {
337 Self {
338 provider: None,
339 model: None,
340 api_key: None,
341 thinking: None,
342 system_prompt: None,
343 append_system_prompt: None,
344 enabled_tools: None,
345 working_directory: None,
346 no_session: true,
347 session_path: None,
348 session_dir: None,
349 extension_paths: Vec::new(),
350 extension_policy: None,
351 repair_policy: None,
352 include_cwd_in_prompt: true,
353 max_tool_iterations: crate::agent::resolved_max_tool_iterations_default(),
354 tool_factory: None,
355 on_event: None,
356 on_tool_start: None,
357 on_tool_end: None,
358 on_stream_event: None,
359 }
360 }
361}
362
/// Pluggable constructor for a session's [`ToolRegistry`].
///
/// Implementations must be `Send + Sync`.
pub trait ToolFactory: Send + Sync {
    /// Builds a registry from the enabled tool names, the working directory
    /// and the configuration.
    fn create_tool_registry(&self, enabled: &[&str], cwd: &Path, config: &Config) -> ToolRegistry;
}
386
387pub fn default_tool_registry(enabled: &[&str], cwd: &Path, config: &Config) -> ToolRegistry {
395 ToolRegistry::new(enabled, cwd, Some(config))
396}
397
/// Handle that pairs an [`AgentSession`] with its [`EventListeners`].
pub struct AgentSessionHandle {
    /// The wrapped session.
    session: AgentSession,
    /// Listeners attached to this handle.
    listeners: EventListeners,
}
410
/// Snapshot of an in-process session's state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentSessionState {
    /// Session identifier, when there is one.
    pub session_id: Option<String>,
    /// Provider name.
    pub provider: String,
    /// Model identifier.
    pub model_id: String,
    /// Thinking level, when one is set.
    pub thinking_level: Option<crate::model::ThinkingLevel>,
    /// Whether saving is enabled.
    pub save_enabled: bool,
    /// Number of messages.
    pub message_count: usize,
}
421
/// Result of `SessionTransport::prompt`, one variant per transport.
#[derive(Debug, Clone)]
pub enum SessionPromptResult {
    /// Assistant message returned by an in-process session.
    InProcess(AssistantMessage),
    /// Raw JSON events collected from the RPC subprocess.
    RpcEvents(Vec<Value>),
}
428
/// Event handed to the `SessionTransport::prompt` callback.
#[derive(Debug, Clone)]
pub enum SessionTransportEvent {
    /// Typed event from an in-process session.
    InProcess(AgentEvent),
    /// Raw JSON event from the RPC subprocess.
    Rpc(Value),
}
435
/// Result of `SessionTransport::state`, one variant per transport.
#[derive(Debug, Clone, PartialEq)]
pub enum SessionTransportState {
    /// State of an in-process session.
    InProcess(AgentSessionState),
    /// State reported by the RPC subprocess (boxed).
    Rpc(Box<RpcSessionState>),
}
442
/// Model description exchanged over RPC.
///
/// Serialized with camelCase keys. `id`, `name`, `api` and `provider` are
/// required when deserializing; the remaining fields fall back to their
/// `Default` value when absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RpcModelInfo {
    pub id: String,
    pub name: String,
    pub api: String,
    pub provider: String,
    #[serde(default)]
    pub base_url: String,
    #[serde(default)]
    pub reasoning: bool,
    #[serde(default)]
    pub input: Vec<InputType>,
    #[serde(default)]
    pub context_window: u32,
    #[serde(default)]
    pub max_tokens: u32,
    #[serde(default)]
    pub cost: Option<ModelCost>,
}
464
/// Session state decoded from the `get_state` RPC response.
///
/// Serialized with camelCase keys; every field falls back to its `Default`
/// value when absent from the payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
#[allow(clippy::struct_excessive_bools)]
pub struct RpcSessionState {
    #[serde(default)]
    pub model: Option<RpcModelInfo>,
    #[serde(default)]
    pub thinking_level: String,
    #[serde(default)]
    pub is_streaming: bool,
    #[serde(default)]
    pub is_compacting: bool,
    #[serde(default)]
    pub steering_mode: String,
    #[serde(default)]
    pub follow_up_mode: String,
    #[serde(default)]
    pub session_file: Option<String>,
    #[serde(default)]
    pub session_id: String,
    #[serde(default)]
    pub session_name: Option<String>,
    #[serde(default)]
    pub auto_compaction_enabled: bool,
    #[serde(default)]
    pub auto_retry_enabled: bool,
    #[serde(default)]
    pub message_count: usize,
    #[serde(default)]
    pub pending_message_count: usize,
    #[serde(default)]
    pub durability_mode: String,
}
499
/// Token counters embedded in [`RpcSessionStats`]; camelCase on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcTokenStats {
    pub input: u64,
    pub output: u64,
    pub cache_read: u64,
    pub cache_write: u64,
    pub total: u64,
}
510
/// Statistics decoded from the `get_session_stats` RPC response.
///
/// Serialized with camelCase keys; only `session_file` may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RpcSessionStats {
    #[serde(default)]
    pub session_file: Option<String>,
    pub session_id: String,
    pub user_messages: u64,
    pub assistant_messages: u64,
    pub tool_calls: u64,
    pub tool_results: u64,
    pub total_messages: u64,
    pub tokens: RpcTokenStats,
    pub cost: f64,
}
526
/// Payload of the `new_session` and `switch_session` RPC responses.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcCancelledResult {
    /// Cancellation flag reported by the subprocess.
    pub cancelled: bool,
}
532
/// Payload of the `cycle_model` RPC response; camelCase on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RpcCycleModelResult {
    pub model: RpcModelInfo,
    pub thinking_level: crate::model::ThinkingLevel,
    pub is_scoped: bool,
}
541
/// Payload of the `cycle_thinking_level` RPC response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcThinkingLevelResult {
    pub level: crate::model::ThinkingLevel,
}
547
/// Payload of the `bash` RPC response; camelCase on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcBashResult {
    pub output: String,
    pub exit_code: i32,
    pub cancelled: bool,
    pub truncated: bool,
    pub full_output_path: Option<String>,
}
558
/// Payload of the `compact` RPC response; camelCase on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcCompactionResult {
    pub summary: String,
    pub first_kept_entry_id: String,
    pub tokens_before: u64,
    /// Free-form JSON; `Value::Null` when absent from the payload.
    #[serde(default)]
    pub details: Value,
}
569
/// Payload of the `fork` RPC response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcForkResult {
    pub text: String,
    pub cancelled: bool,
}
576
/// One element of the `get_fork_messages` RPC response; camelCase on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcForkMessage {
    pub entry_id: String,
    pub text: String,
}
584
/// One element of the `get_commands` RPC response.
///
/// `name` and `source` are required; the optional fields default to `None`
/// when absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcCommandInfo {
    pub name: String,
    #[serde(default)]
    pub description: Option<String>,
    pub source: String,
    #[serde(default)]
    pub location: Option<String>,
    #[serde(default)]
    pub path: Option<String>,
}
597
/// Payload of the `export_html` RPC response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcExportHtmlResult {
    pub path: String,
}
603
/// Payload of the `get_last_assistant_text` RPC response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcLastAssistantText {
    pub text: Option<String>,
}
609
/// Answer passed to `RpcTransportClient::extension_ui_response`.
///
/// When serialized with serde it is internally tagged by a `kind` field with
/// snake_case variant names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum RpcExtensionUiResponse {
    /// Answer carrying an arbitrary JSON value.
    Value { value: Value },
    /// Answer carrying a boolean confirmation.
    Confirmed { confirmed: bool },
    /// The request was cancelled.
    Cancelled,
}
618
/// How to launch the RPC subprocess.
#[derive(Debug, Clone)]
pub struct RpcTransportOptions {
    /// Executable to spawn.
    pub binary_path: PathBuf,
    /// Command-line arguments passed to the executable.
    pub args: Vec<String>,
    /// Working directory for the subprocess; `None` leaves it unset.
    pub cwd: Option<PathBuf>,
}

impl Default for RpcTransportOptions {
    /// Defaults to running `pi --mode rpc` without overriding the working
    /// directory.
    fn default() -> Self {
        let args = Vec::from(["--mode", "rpc"].map(String::from));
        Self {
            binary_path: PathBuf::from("pi"),
            args,
            cwd: None,
        }
    }
}
636
/// Client for a spawned subprocess that exchanges one JSON value per line
/// over the child's stdin and stdout.
pub struct RpcTransportClient {
    /// The spawned subprocess.
    child: Child,
    /// Buffered writer over the child's stdin.
    stdin: BufWriter<ChildStdin>,
    /// Buffered reader over the child's stdout.
    stdout: BufReader<ChildStdout>,
    /// Number used to build the next request id.
    next_request_id: u64,
}
644
/// A session reachable either in-process or through an RPC subprocess.
pub enum SessionTransport {
    /// Session running in this process (boxed handle).
    InProcess(Box<AgentSessionHandle>),
    /// Session running in a spawned subprocess.
    RpcSubprocess(RpcTransportClient),
}
650
651impl SessionTransport {
652 pub async fn in_process(options: SessionOptions) -> Result<Self> {
653 create_agent_session(options)
654 .await
655 .map(Box::new)
656 .map(Self::InProcess)
657 }
658
659 pub fn rpc_subprocess(options: RpcTransportOptions) -> Result<Self> {
660 RpcTransportClient::connect(options).map(Self::RpcSubprocess)
661 }
662
663 #[allow(clippy::missing_const_for_fn)]
664 pub fn as_in_process_mut(&mut self) -> Option<&mut AgentSessionHandle> {
665 match self {
666 Self::InProcess(handle) => Some(handle.as_mut()),
667 Self::RpcSubprocess(_) => None,
668 }
669 }
670
671 #[allow(clippy::missing_const_for_fn)]
672 pub fn as_rpc_mut(&mut self) -> Option<&mut RpcTransportClient> {
673 match self {
674 Self::InProcess(_) => None,
675 Self::RpcSubprocess(client) => Some(client),
676 }
677 }
678
679 pub async fn prompt(
684 &mut self,
685 input: impl Into<String>,
686 on_event: impl Fn(SessionTransportEvent) + Send + Sync + 'static,
687 ) -> Result<SessionPromptResult> {
688 let input = input.into();
689 let on_event = Arc::new(on_event);
690 match self {
691 Self::InProcess(handle) => {
692 let on_event = Arc::clone(&on_event);
693 let assistant = handle
694 .prompt(input, move |event| {
695 (on_event)(SessionTransportEvent::InProcess(event));
696 })
697 .await?;
698 Ok(SessionPromptResult::InProcess(assistant))
699 }
700 Self::RpcSubprocess(client) => {
701 let events = client.prompt(input).await?;
702 for event in events.iter().cloned() {
703 (on_event)(SessionTransportEvent::Rpc(event));
704 }
705 Ok(SessionPromptResult::RpcEvents(events))
706 }
707 }
708 }
709
710 pub async fn state(&mut self) -> Result<SessionTransportState> {
712 match self {
713 Self::InProcess(handle) => handle.state().await.map(SessionTransportState::InProcess),
714 Self::RpcSubprocess(client) => client
715 .get_state()
716 .await
717 .map(Box::new)
718 .map(SessionTransportState::Rpc),
719 }
720 }
721
722 pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<()> {
724 match self {
725 Self::InProcess(handle) => handle.set_model(provider, model_id).await,
726 Self::RpcSubprocess(client) => {
727 let _ = client.set_model(provider, model_id).await?;
728 Ok(())
729 }
730 }
731 }
732
733 pub fn shutdown(&mut self) -> Result<()> {
735 match self {
736 Self::InProcess(_) => Ok(()),
737 Self::RpcSubprocess(client) => client.shutdown(),
738 }
739 }
740}
741
742impl RpcTransportClient {
743 pub fn connect(options: RpcTransportOptions) -> Result<Self> {
744 let mut command = Command::new(&options.binary_path);
745 command
746 .args(&options.args)
747 .stdin(Stdio::piped())
748 .stdout(Stdio::piped())
749 .stderr(Stdio::inherit());
750 if let Some(cwd) = options.cwd {
751 command.current_dir(cwd);
752 }
753
754 let mut child = command.spawn().map_err(|err| {
755 Error::config(format!(
756 "Failed to spawn RPC subprocess {}: {err}",
757 options.binary_path.display()
758 ))
759 })?;
760 let stdin = child
761 .stdin
762 .take()
763 .ok_or_else(|| Error::config("RPC subprocess stdin is not piped"))?;
764 let stdout = child
765 .stdout
766 .take()
767 .ok_or_else(|| Error::config("RPC subprocess stdout is not piped"))?;
768
769 Ok(Self {
770 child,
771 stdin: BufWriter::new(stdin),
772 stdout: BufReader::new(stdout),
773 next_request_id: 1,
774 })
775 }
776
777 #[allow(
778 clippy::unused_async,
779 reason = "SDK RPC transport keeps an async public API"
780 )]
781 pub async fn request(&mut self, command: &str, payload: Map<String, Value>) -> Result<Value> {
782 let request_id = self.next_request_id();
783 let mut command_payload = Map::new();
784 command_payload.insert("type".to_string(), Value::String(command.to_string()));
785 command_payload.insert("id".to_string(), Value::String(request_id.clone()));
786 command_payload.extend(payload);
787
788 self.write_json_line(&Value::Object(command_payload))?;
789 self.wait_for_response(&request_id, command)
790 }
791
792 fn parse_response_data<T: DeserializeOwned>(data: Value, command: &str) -> Result<T> {
793 serde_json::from_value(data).map_err(|err| {
794 Error::api(format!(
795 "Failed to decode RPC `{command}` response payload: {err}"
796 ))
797 })
798 }
799
800 async fn request_typed<T: DeserializeOwned>(
801 &mut self,
802 command: &str,
803 payload: Map<String, Value>,
804 ) -> Result<T> {
805 let data = self.request(command, payload).await?;
806 Self::parse_response_data(data, command)
807 }
808
809 async fn request_no_data(&mut self, command: &str, payload: Map<String, Value>) -> Result<()> {
810 let _ = self.request(command, payload).await?;
811 Ok(())
812 }
813
814 pub async fn steer(&mut self, message: impl Into<String>) -> Result<()> {
815 let mut payload = Map::new();
816 payload.insert("message".to_string(), Value::String(message.into()));
817 self.request_no_data("steer", payload).await
818 }
819
820 pub async fn follow_up(&mut self, message: impl Into<String>) -> Result<()> {
821 let mut payload = Map::new();
822 payload.insert("message".to_string(), Value::String(message.into()));
823 self.request_no_data("follow_up", payload).await
824 }
825
826 pub async fn abort(&mut self) -> Result<()> {
827 self.request_no_data("abort", Map::new()).await
828 }
829
830 pub async fn new_session(
831 &mut self,
832 parent_session: Option<&Path>,
833 ) -> Result<RpcCancelledResult> {
834 let mut payload = Map::new();
835 if let Some(parent_session) = parent_session {
836 payload.insert(
837 "parentSession".to_string(),
838 Value::String(parent_session.display().to_string()),
839 );
840 }
841 self.request_typed("new_session", payload).await
842 }
843
844 pub async fn get_state(&mut self) -> Result<RpcSessionState> {
845 self.request_typed("get_state", Map::new()).await
846 }
847
848 pub async fn get_session_stats(&mut self) -> Result<RpcSessionStats> {
849 self.request_typed("get_session_stats", Map::new()).await
850 }
851
852 pub async fn get_messages(&mut self) -> Result<Vec<Value>> {
853 #[derive(Deserialize)]
854 struct MessagesPayload {
855 messages: Vec<Value>,
856 }
857 let payload: MessagesPayload = self.request_typed("get_messages", Map::new()).await?;
858 Ok(payload.messages)
859 }
860
861 pub async fn get_available_models(&mut self) -> Result<Vec<RpcModelInfo>> {
862 #[derive(Deserialize)]
863 struct ModelsPayload {
864 models: Vec<RpcModelInfo>,
865 }
866 let payload: ModelsPayload = self
867 .request_typed("get_available_models", Map::new())
868 .await?;
869 Ok(payload.models)
870 }
871
872 pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<RpcModelInfo> {
873 let mut payload = Map::new();
874 payload.insert("provider".to_string(), Value::String(provider.to_string()));
875 payload.insert("modelId".to_string(), Value::String(model_id.to_string()));
876 self.request_typed("set_model", payload).await
877 }
878
879 pub async fn cycle_model(&mut self) -> Result<Option<RpcCycleModelResult>> {
880 self.request_typed("cycle_model", Map::new()).await
881 }
882
883 pub async fn set_thinking_level(&mut self, level: crate::model::ThinkingLevel) -> Result<()> {
884 let mut payload = Map::new();
885 payload.insert("level".to_string(), Value::String(level.to_string()));
886 self.request_no_data("set_thinking_level", payload).await
887 }
888
889 pub async fn cycle_thinking_level(&mut self) -> Result<Option<RpcThinkingLevelResult>> {
890 self.request_typed("cycle_thinking_level", Map::new()).await
891 }
892
893 pub async fn set_steering_mode(&mut self, mode: &str) -> Result<()> {
894 let mut payload = Map::new();
895 payload.insert("mode".to_string(), Value::String(mode.to_string()));
896 self.request_no_data("set_steering_mode", payload).await
897 }
898
899 pub async fn set_follow_up_mode(&mut self, mode: &str) -> Result<()> {
900 let mut payload = Map::new();
901 payload.insert("mode".to_string(), Value::String(mode.to_string()));
902 self.request_no_data("set_follow_up_mode", payload).await
903 }
904
905 pub async fn set_auto_compaction(&mut self, enabled: bool) -> Result<()> {
906 let mut payload = Map::new();
907 payload.insert("enabled".to_string(), Value::Bool(enabled));
908 self.request_no_data("set_auto_compaction", payload).await
909 }
910
911 pub async fn set_auto_retry(&mut self, enabled: bool) -> Result<()> {
912 let mut payload = Map::new();
913 payload.insert("enabled".to_string(), Value::Bool(enabled));
914 self.request_no_data("set_auto_retry", payload).await
915 }
916
917 pub async fn abort_retry(&mut self) -> Result<()> {
918 self.request_no_data("abort_retry", Map::new()).await
919 }
920
921 pub async fn set_session_name(&mut self, name: impl Into<String>) -> Result<()> {
922 let mut payload = Map::new();
923 payload.insert("name".to_string(), Value::String(name.into()));
924 self.request_no_data("set_session_name", payload).await
925 }
926
927 pub async fn get_last_assistant_text(&mut self) -> Result<Option<String>> {
928 let payload: RpcLastAssistantText = self
929 .request_typed("get_last_assistant_text", Map::new())
930 .await?;
931 Ok(payload.text)
932 }
933
934 pub async fn export_html(&mut self, output_path: Option<&Path>) -> Result<RpcExportHtmlResult> {
935 let mut payload = Map::new();
936 if let Some(path) = output_path {
937 payload.insert(
938 "outputPath".to_string(),
939 Value::String(path.display().to_string()),
940 );
941 }
942 self.request_typed("export_html", payload).await
943 }
944
945 pub async fn bash(&mut self, command: impl Into<String>) -> Result<RpcBashResult> {
946 let mut payload = Map::new();
947 payload.insert("command".to_string(), Value::String(command.into()));
948 self.request_typed("bash", payload).await
949 }
950
951 pub async fn abort_bash(&mut self) -> Result<()> {
952 self.request_no_data("abort_bash", Map::new()).await
953 }
954
955 pub async fn compact(&mut self) -> Result<RpcCompactionResult> {
956 self.compact_with_instructions(None).await
957 }
958
959 pub async fn compact_with_instructions(
960 &mut self,
961 custom_instructions: Option<&str>,
962 ) -> Result<RpcCompactionResult> {
963 let mut payload = Map::new();
964 if let Some(custom_instructions) = custom_instructions {
965 payload.insert(
966 "customInstructions".to_string(),
967 Value::String(custom_instructions.to_string()),
968 );
969 }
970 self.request_typed("compact", payload).await
971 }
972
973 pub async fn switch_session(&mut self, session_path: &Path) -> Result<RpcCancelledResult> {
974 let mut payload = Map::new();
975 payload.insert(
976 "sessionPath".to_string(),
977 Value::String(session_path.display().to_string()),
978 );
979 self.request_typed("switch_session", payload).await
980 }
981
982 pub async fn fork(&mut self, entry_id: impl Into<String>) -> Result<RpcForkResult> {
983 let mut payload = Map::new();
984 payload.insert("entryId".to_string(), Value::String(entry_id.into()));
985 self.request_typed("fork", payload).await
986 }
987
988 pub async fn get_fork_messages(&mut self) -> Result<Vec<RpcForkMessage>> {
989 #[derive(Deserialize)]
990 struct ForkMessagesPayload {
991 messages: Vec<RpcForkMessage>,
992 }
993 let payload: ForkMessagesPayload =
994 self.request_typed("get_fork_messages", Map::new()).await?;
995 Ok(payload.messages)
996 }
997
998 pub async fn get_commands(&mut self) -> Result<Vec<RpcCommandInfo>> {
999 #[derive(Deserialize)]
1000 struct CommandsPayload {
1001 commands: Vec<RpcCommandInfo>,
1002 }
1003 let payload: CommandsPayload = self.request_typed("get_commands", Map::new()).await?;
1004 Ok(payload.commands)
1005 }
1006
1007 pub async fn extension_ui_response(
1008 &mut self,
1009 request_id: &str,
1010 response: RpcExtensionUiResponse,
1011 ) -> Result<bool> {
1012 #[derive(Deserialize)]
1013 struct ExtensionUiResolvedPayload {
1014 resolved: bool,
1015 }
1016
1017 let mut payload = Map::new();
1018 payload.insert(
1019 "requestId".to_string(),
1020 Value::String(request_id.to_string()),
1021 );
1022
1023 match response {
1024 RpcExtensionUiResponse::Value { value } => {
1025 payload.insert("value".to_string(), value);
1026 }
1027 RpcExtensionUiResponse::Confirmed { confirmed } => {
1028 payload.insert("confirmed".to_string(), Value::Bool(confirmed));
1029 }
1030 RpcExtensionUiResponse::Cancelled => {
1031 payload.insert("cancelled".to_string(), Value::Bool(true));
1032 }
1033 }
1034
1035 let response: Option<ExtensionUiResolvedPayload> =
1036 self.request_typed("extension_ui_response", payload).await?;
1037 Ok(response.is_none_or(|payload| payload.resolved))
1038 }
1039
1040 pub async fn prompt(&mut self, message: impl Into<String>) -> Result<Vec<Value>> {
1041 self.prompt_with_options(message, None, None).await
1042 }
1043
1044 #[allow(
1045 clippy::unused_async,
1046 reason = "SDK RPC transport keeps an async public API"
1047 )]
1048 pub async fn prompt_with_options(
1049 &mut self,
1050 message: impl Into<String>,
1051 images: Option<Vec<ImageContent>>,
1052 streaming_behavior: Option<&str>,
1053 ) -> Result<Vec<Value>> {
1054 let request_id = self.next_request_id();
1055 let mut payload = Map::new();
1056 payload.insert("type".to_string(), Value::String("prompt".to_string()));
1057 payload.insert("id".to_string(), Value::String(request_id.clone()));
1058 payload.insert("message".to_string(), Value::String(message.into()));
1059 if let Some(images) = images {
1060 payload.insert(
1061 "images".to_string(),
1062 serde_json::to_value(images).map_err(|err| Error::Json(Box::new(err)))?,
1063 );
1064 }
1065 if let Some(streaming_behavior) = streaming_behavior {
1066 payload.insert(
1067 "streamingBehavior".to_string(),
1068 Value::String(streaming_behavior.to_string()),
1069 );
1070 }
1071 let payload = Value::Object(payload);
1072 self.write_json_line(&payload)?;
1073
1074 let mut saw_ack = false;
1075 let mut events = Vec::new();
1076 loop {
1077 let item = self.read_json_line()?;
1078 let item_type = item.get("type").and_then(Value::as_str);
1079 if item_type == Some("response") {
1080 if item.get("id").and_then(Value::as_str) != Some(request_id.as_str()) {
1081 continue;
1082 }
1083 let success = item
1084 .get("success")
1085 .and_then(Value::as_bool)
1086 .unwrap_or(false);
1087 if !success {
1088 return Err(rpc_error_from_response(&item, "prompt"));
1089 }
1090 saw_ack = true;
1091 continue;
1092 }
1093
1094 if saw_ack {
1095 let reached_end = item_type == Some("agent_end");
1096 events.push(item);
1097 if reached_end {
1098 return Ok(events);
1099 }
1100 }
1101 }
1102 }
1103
1104 pub fn shutdown(&mut self) -> Result<()> {
1105 if self
1106 .child
1107 .try_wait()
1108 .map_err(|err| Error::Io(Box::new(err)))?
1109 .is_none()
1110 {
1111 self.child.kill().map_err(|err| Error::Io(Box::new(err)))?;
1112 }
1113 let _ = self.child.wait();
1114 Ok(())
1115 }
1116
1117 fn next_request_id(&mut self) -> String {
1118 let id = format!("rpc-{}", self.next_request_id);
1119 self.next_request_id = self.next_request_id.saturating_add(1);
1120 id
1121 }
1122
1123 fn write_json_line(&mut self, payload: &Value) -> Result<()> {
1124 let encoded = serde_json::to_string(payload).map_err(|err| Error::Json(Box::new(err)))?;
1125 self.stdin
1126 .write_all(encoded.as_bytes())
1127 .map_err(|err| Error::Io(Box::new(err)))?;
1128 self.stdin
1129 .write_all(b"\n")
1130 .map_err(|err| Error::Io(Box::new(err)))?;
1131 self.stdin.flush().map_err(|err| Error::Io(Box::new(err)))?;
1132 Ok(())
1133 }
1134
1135 fn read_json_line(&mut self) -> Result<Value> {
1136 let mut line = String::new();
1137 let read = self
1138 .stdout
1139 .read_line(&mut line)
1140 .map_err(|err| Error::Io(Box::new(err)))?;
1141 if read == 0 {
1142 return Err(Error::api(
1143 "RPC subprocess exited before sending a response",
1144 ));
1145 }
1146 serde_json::from_str(line.trim_end()).map_err(|err| Error::Json(Box::new(err)))
1147 }
1148
1149 fn wait_for_response(&mut self, request_id: &str, command: &str) -> Result<Value> {
1150 loop {
1151 let item = self.read_json_line()?;
1152 let Some(item_type) = item.get("type").and_then(Value::as_str) else {
1153 continue;
1154 };
1155 if item_type != "response" {
1156 continue;
1157 }
1158 if item.get("id").and_then(Value::as_str) != Some(request_id) {
1159 continue;
1160 }
1161 if item.get("command").and_then(Value::as_str) != Some(command) {
1162 continue;
1163 }
1164
1165 let success = item
1166 .get("success")
1167 .and_then(Value::as_bool)
1168 .unwrap_or(false);
1169 if success {
1170 return Ok(item.get("data").cloned().unwrap_or(Value::Null));
1171 }
1172 return Err(rpc_error_from_response(&item, command));
1173 }
1174 }
1175}
1176
1177impl Drop for RpcTransportClient {
1178 fn drop(&mut self) {
1179 let _ = self.shutdown();
1180 }
1181}
1182
1183fn rpc_error_from_response(response: &Value, command: &str) -> Error {
1184 let error = response
1185 .get("error")
1186 .and_then(Value::as_str)
1187 .unwrap_or("RPC command failed");
1188 Error::api(format!("RPC {command} failed: {error}"))
1189}
1190
1191impl AgentSessionHandle {
1192 pub const fn from_session_with_listeners(
1197 session: AgentSession,
1198 listeners: EventListeners,
1199 ) -> Self {
1200 Self { session, listeners }
1201 }
1202
1203 pub async fn prompt(
1209 &mut self,
1210 input: impl Into<String>,
1211 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1212 ) -> Result<AssistantMessage> {
1213 let combined = self.make_combined_callback(on_event);
1214 self.session.run_text(input.into(), combined).await
1215 }
1216
1217 pub async fn prompt_with_abort(
1219 &mut self,
1220 input: impl Into<String>,
1221 abort_signal: AbortSignal,
1222 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1223 ) -> Result<AssistantMessage> {
1224 let combined = self.make_combined_callback(on_event);
1225 self.session
1226 .run_text_with_abort(input.into(), Some(abort_signal), combined)
1227 .await
1228 }
1229
1230 pub async fn continue_turn(
1236 &mut self,
1237 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1238 ) -> Result<AssistantMessage> {
1239 let combined = self.make_combined_callback(on_event);
1240 self.session
1241 .sync_runtime_selection_from_session_header()
1242 .await?;
1243 self.session
1244 .agent
1245 .run_continue_with_abort(None, combined)
1246 .await
1247 }
1248
1249 pub async fn continue_turn_with_abort(
1251 &mut self,
1252 abort_signal: AbortSignal,
1253 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1254 ) -> Result<AssistantMessage> {
1255 let combined = self.make_combined_callback(on_event);
1256 self.session
1257 .sync_runtime_selection_from_session_header()
1258 .await?;
1259 self.session
1260 .agent
1261 .run_continue_with_abort(Some(abort_signal), combined)
1262 .await
1263 }
1264
1265 pub fn new_abort_handle() -> (AbortHandle, AbortSignal) {
1267 AbortHandle::new()
1268 }
1269
1270 pub fn subscribe(
1277 &self,
1278 listener: impl Fn(AgentEvent) + Send + Sync + 'static,
1279 ) -> SubscriptionId {
1280 self.listeners.subscribe(Arc::new(listener))
1281 }
1282
1283 pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
1287 self.listeners.unsubscribe(id)
1288 }
1289
1290 pub const fn listeners(&self) -> &EventListeners {
1292 &self.listeners
1293 }
1294
1295 pub const fn listeners_mut(&mut self) -> &mut EventListeners {
1300 &mut self.listeners
1301 }
1302
1303 pub const fn has_extensions(&self) -> bool {
1309 self.session.extensions.is_some()
1310 }
1311
1312 pub fn extension_manager(&self) -> Option<&ExtensionManager> {
1314 self.session
1315 .extensions
1316 .as_ref()
1317 .map(ExtensionRegion::manager)
1318 }
1319
1320 pub const fn extension_region(&self) -> Option<&ExtensionRegion> {
1324 self.session.extensions.as_ref()
1325 }
1326
1327 pub fn model(&self) -> (String, String) {
1333 let provider = self.session.agent.provider();
1334 (provider.name().to_string(), provider.model_id().to_string())
1335 }
1336
1337 pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<()> {
1339 self.session.set_provider_model(provider, model_id).await
1340 }
1341
1342 pub const fn thinking_level(&self) -> Option<crate::model::ThinkingLevel> {
1344 self.session.agent.stream_options().thinking_level
1345 }
1346
1347 pub const fn thinking(&self) -> Option<crate::model::ThinkingLevel> {
1349 self.thinking_level()
1350 }
1351
1352 pub async fn set_thinking_level(&mut self, level: crate::model::ThinkingLevel) -> Result<()> {
1354 let cx = crate::agent_cx::AgentCx::for_request();
1355 let (effective_level, changed) = {
1356 let mut guard = self
1357 .session
1358 .session
1359 .lock(cx.cx())
1360 .await
1361 .map_err(|e| Error::session(e.to_string()))?;
1362 let (provider_id, model_id) = guard
1363 .effective_model_for_current_path()
1364 .unwrap_or_else(|| self.model());
1365 let effective_level =
1366 self.session
1367 .clamp_thinking_level_for_model(&provider_id, &model_id, level);
1368 let level_string = effective_level.to_string();
1369 let changed = guard.effective_thinking_level_for_current_path().as_deref()
1370 != Some(level_string.as_str());
1371 guard.set_model_header(None, None, Some(level_string.clone()));
1372 if changed {
1373 guard.append_thinking_level_change(level_string);
1374 }
1375 (effective_level, changed)
1376 };
1377 self.session.agent.stream_options_mut().thinking_level = Some(effective_level);
1378 if changed {
1379 self.session.persist_session().await
1380 } else {
1381 Ok(())
1382 }
1383 }
1384
1385 pub async fn set_session_name(&mut self, name: impl Into<String>) -> Result<()> {
1391 let name = name.into();
1392 let cx = crate::agent_cx::AgentCx::for_request();
1393 {
1394 let mut guard = self
1395 .session
1396 .session
1397 .lock(cx.cx())
1398 .await
1399 .map_err(|e| Error::session(e.to_string()))?;
1400 guard.append_session_info(Some(name));
1401 }
1402 self.session.persist_session().await
1403 }
1404
1405 pub const fn max_tokens(&self) -> Option<u32> {
1413 self.session.agent.stream_options().max_tokens
1414 }
1415
1416 pub const fn set_max_tokens(&mut self, max_tokens: Option<u32>) {
1422 self.session.agent.stream_options_mut().max_tokens = max_tokens;
1423 }
1424
1425 pub async fn messages(&self) -> Result<Vec<Message>> {
1427 let cx = crate::agent_cx::AgentCx::for_request();
1428 let guard = self
1429 .session
1430 .session
1431 .lock(cx.cx())
1432 .await
1433 .map_err(|e| Error::session(e.to_string()))?;
1434 Ok(guard.to_messages_for_current_path())
1435 }
1436
1437 pub async fn state(&self) -> Result<AgentSessionState> {
1439 let (provider, model_id) = self.model();
1440 let thinking_level = self.thinking_level();
1441 let save_enabled = self.session.save_enabled();
1442 let cx = crate::agent_cx::AgentCx::for_request();
1443 let guard = self
1444 .session
1445 .session
1446 .lock(cx.cx())
1447 .await
1448 .map_err(|e| Error::session(e.to_string()))?;
1449 let session_id = Some(guard.header.id.clone());
1450 let message_count = guard.to_messages_for_current_path().len();
1451
1452 Ok(AgentSessionState {
1453 session_id,
1454 provider,
1455 model_id,
1456 thinking_level,
1457 save_enabled,
1458 message_count,
1459 })
1460 }
1461
1462 pub async fn compact(
1464 &mut self,
1465 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1466 ) -> Result<()> {
1467 self.session.compact_now(on_event).await
1468 }
1469
1470 pub const fn session(&self) -> &AgentSession {
1472 &self.session
1473 }
1474
1475 pub const fn session_mut(&mut self) -> &mut AgentSession {
1477 &mut self.session
1478 }
1479
1480 pub fn into_inner(self) -> AgentSession {
1482 self.session
1483 }
1484
1485 fn make_combined_callback(
1488 &self,
1489 per_prompt: impl Fn(AgentEvent) + Send + Sync + 'static,
1490 ) -> impl Fn(AgentEvent) + Send + Sync + 'static {
1491 let listeners = self.listeners.clone();
1492 move |event: AgentEvent| {
1493 match &event {
1495 AgentEvent::ToolExecutionStart {
1496 tool_name, args, ..
1497 } => {
1498 listeners.notify_tool_start(tool_name, args);
1499 }
1500 AgentEvent::ToolExecutionEnd {
1501 tool_name,
1502 result,
1503 is_error,
1504 ..
1505 } => {
1506 listeners.notify_tool_end(tool_name, result, *is_error);
1507 }
1508 AgentEvent::MessageUpdate {
1509 assistant_message_event,
1510 ..
1511 } => {
1512 if let Some(stream_ev) =
1515 stream_event_from_assistant_message_event(assistant_message_event)
1516 {
1517 listeners.notify_stream_event(&stream_ev);
1518 }
1519 }
1520 _ => {}
1521 }
1522
1523 listeners.notify(&event);
1525
1526 per_prompt(event);
1528 }
1529 }
1530}
1531
1532fn stream_event_from_assistant_message_event(
1537 event: &crate::model::AssistantMessageEvent,
1538) -> Option<StreamEvent> {
1539 use crate::model::AssistantMessageEvent as AME;
1540 match event {
1541 AME::TextStart { content_index, .. } => Some(StreamEvent::TextStart {
1542 content_index: *content_index,
1543 }),
1544 AME::TextDelta {
1545 content_index,
1546 delta,
1547 ..
1548 } => Some(StreamEvent::TextDelta {
1549 content_index: *content_index,
1550 delta: delta.clone(),
1551 }),
1552 AME::TextEnd {
1553 content_index,
1554 content,
1555 ..
1556 } => Some(StreamEvent::TextEnd {
1557 content_index: *content_index,
1558 content: content.clone(),
1559 }),
1560 AME::ThinkingStart { content_index, .. } => Some(StreamEvent::ThinkingStart {
1561 content_index: *content_index,
1562 }),
1563 AME::ThinkingDelta {
1564 content_index,
1565 delta,
1566 ..
1567 } => Some(StreamEvent::ThinkingDelta {
1568 content_index: *content_index,
1569 delta: delta.clone(),
1570 }),
1571 AME::ThinkingEnd {
1572 content_index,
1573 content,
1574 ..
1575 } => Some(StreamEvent::ThinkingEnd {
1576 content_index: *content_index,
1577 content: content.clone(),
1578 }),
1579 AME::ToolCallStart { content_index, .. } => Some(StreamEvent::ToolCallStart {
1580 content_index: *content_index,
1581 }),
1582 AME::ToolCallDelta {
1583 content_index,
1584 delta,
1585 ..
1586 } => Some(StreamEvent::ToolCallDelta {
1587 content_index: *content_index,
1588 delta: delta.clone(),
1589 }),
1590 AME::ToolCallEnd {
1591 content_index,
1592 tool_call,
1593 ..
1594 } => Some(StreamEvent::ToolCallEnd {
1595 content_index: *content_index,
1596 tool_call: tool_call.clone(),
1597 }),
1598 AME::Done { reason, message } => Some(StreamEvent::Done {
1599 reason: *reason,
1600 message: (**message).clone(),
1601 }),
1602 AME::Error { reason, error } => Some(StreamEvent::Error {
1603 reason: *reason,
1604 error: (**error).clone(),
1605 }),
1606 AME::Start { .. } => None,
1607 }
1608}
1609
/// Resolves `path` against `cwd`.
///
/// Absolute paths are returned unchanged and relative paths are appended to
/// `cwd`. `Path::join` provides both behaviours: joining an absolute path
/// replaces the base entirely.
fn resolve_path_for_cwd(path: &Path, cwd: &Path) -> PathBuf {
    cwd.join(path)
}
1617
1618fn build_stream_options_with_optional_key(
1619 config: &Config,
1620 api_key: Option<String>,
1621 selection: &app::ModelSelection,
1622 session: &Session,
1623) -> StreamOptions {
1624 let mut options = StreamOptions {
1625 api_key,
1626 headers: selection.model_entry.headers.clone(),
1627 session_id: Some(session.header.id.clone()),
1628 thinking_level: Some(selection.thinking_level),
1629 ..Default::default()
1630 };
1631
1632 if let Some(budgets) = &config.thinking_budgets {
1633 let defaults = ThinkingBudgets::default();
1634 options.thinking_budgets = Some(ThinkingBudgets {
1635 minimal: budgets.minimal.unwrap_or(defaults.minimal),
1636 low: budgets.low.unwrap_or(defaults.low),
1637 medium: budgets.medium.unwrap_or(defaults.medium),
1638 high: budgets.high.unwrap_or(defaults.high),
1639 xhigh: budgets.xhigh.unwrap_or(defaults.xhigh),
1640 });
1641 }
1642
1643 options
1644}
1645
1646#[allow(clippy::too_many_lines)]
1651pub async fn create_agent_session(options: SessionOptions) -> Result<AgentSessionHandle> {
1652 let process_cwd =
1653 std::env::current_dir().map_err(|e| Error::config(format!("cwd lookup failed: {e}")))?;
1654 let cwd = options.working_directory.as_deref().map_or_else(
1655 || process_cwd.clone(),
1656 |path| resolve_path_for_cwd(path, &process_cwd),
1657 );
1658 let resolved_session_path = options
1659 .session_path
1660 .as_deref()
1661 .map(|path| resolve_path_for_cwd(path, &cwd));
1662 let resolved_session_dir = options
1663 .session_dir
1664 .as_deref()
1665 .map(|path| resolve_path_for_cwd(path, &cwd));
1666
1667 let mut cli = Cli::try_parse_from(["pi"])
1668 .map_err(|e| Error::validation(format!("CLI init failed: {e}")))?;
1669 cli.no_session = options.no_session;
1670 cli.provider = options.provider.clone();
1671 cli.model = options.model.clone();
1672 cli.api_key = options.api_key.clone();
1673 cli.system_prompt = options.system_prompt.clone();
1674 cli.append_system_prompt = options.append_system_prompt.clone();
1675 cli.hide_cwd_in_prompt = !options.include_cwd_in_prompt;
1676 cli.thinking = options.thinking.map(|t| t.to_string());
1677 cli.session = resolved_session_path
1678 .as_ref()
1679 .map(|p| p.to_string_lossy().to_string());
1680 cli.session_dir = resolved_session_dir
1681 .as_ref()
1682 .map(|p| p.to_string_lossy().to_string());
1683 if let Some(enabled_tools) = &options.enabled_tools {
1684 if enabled_tools.is_empty() {
1685 cli.no_tools = true;
1686 } else {
1687 cli.no_tools = false;
1688 cli.tools = enabled_tools.join(",");
1689 }
1690 }
1691
1692 let config = Config::load()?;
1693
1694 let mut auth = AuthStorage::load_async(Config::auth_path()).await?;
1695 auth.refresh_expired_oauth_tokens().await?;
1696
1697 let global_dir = Config::global_dir();
1698 let package_dir = Config::package_dir();
1699 let models_path = default_models_path(&global_dir);
1700 let model_registry = ModelRegistry::load(&auth, Some(models_path));
1701
1702 let mut session = Session::new(&cli, &config).await?;
1703 if resolved_session_path.is_none() {
1704 session.header.cwd = cwd.display().to_string();
1705 }
1706 let scoped_patterns = if let Some(models_arg) = &cli.models {
1707 app::parse_models_arg(models_arg)
1708 } else {
1709 config.enabled_models.clone().unwrap_or_default()
1710 };
1711 let scoped_models = if scoped_patterns.is_empty() {
1712 Vec::new()
1713 } else {
1714 app::resolve_model_scope(&scoped_patterns, &model_registry, cli.api_key.is_some())
1715 };
1716
1717 let selection = app::select_model_and_thinking(
1718 &cli,
1719 &config,
1720 &session,
1721 &model_registry,
1722 &scoped_models,
1723 &global_dir,
1724 )
1725 .map_err(|err| Error::validation(err.to_string()))?;
1726 app::update_session_for_selection(&mut session, &selection);
1727
1728 let enabled_tools_owned = cli
1729 .enabled_tools()
1730 .into_iter()
1731 .map(str::to_string)
1732 .collect::<Vec<_>>();
1733 let enabled_tools = enabled_tools_owned
1734 .iter()
1735 .map(String::as_str)
1736 .collect::<Vec<_>>();
1737
1738 let system_prompt = app::build_system_prompt(
1739 &cli,
1740 &cwd,
1741 &enabled_tools,
1742 None,
1743 &global_dir,
1744 &package_dir,
1745 std::env::var_os("PI_TEST_MODE").is_some(),
1746 options.include_cwd_in_prompt,
1747 )
1748 .map_err(|err| Error::validation(err.to_string()))?;
1749
1750 let provider = providers::create_provider(&selection.model_entry, None)
1751 .map_err(|e| Error::provider("sdk", e.to_string()))?;
1752
1753 let api_key = app::resolve_api_key(&auth, &cli, &selection.model_entry)
1754 .map_err(|err| Error::validation(err.to_string()))?;
1755
1756 let stream_options =
1757 build_stream_options_with_optional_key(&config, api_key, &selection, &session);
1758
1759 let agent_config = AgentConfig {
1760 system_prompt: Some(system_prompt),
1761 max_tool_iterations: options.max_tool_iterations,
1762 stream_options,
1763 block_images: config.image_block_images(),
1764 fail_closed_hooks: config.fail_closed_hooks(),
1765 tool_approval: None,
1766 };
1767
1768 let tools = options.tool_factory.as_ref().map_or_else(
1769 || ToolRegistry::new(&enabled_tools, &cwd, Some(&config)),
1770 |factory| factory.create_tool_registry(&enabled_tools, &cwd, &config),
1771 );
1772 let session_arc = Arc::new(asupersync::sync::Mutex::new(session));
1773
1774 let context_window_tokens = if selection.model_entry.model.context_window == 0 {
1775 ResolvedCompactionSettings::default().context_window_tokens
1776 } else {
1777 selection.model_entry.model.context_window
1778 };
1779 let compaction_settings = ResolvedCompactionSettings {
1780 enabled: config.compaction_enabled(),
1781 reserve_tokens: config.compaction_reserve_tokens(),
1782 keep_recent_tokens: config.compaction_keep_recent_tokens(),
1783 context_window_tokens,
1784 };
1785
1786 let mut agent_session = AgentSession::new(
1787 Agent::new(provider, tools, agent_config),
1788 Arc::clone(&session_arc),
1789 !cli.no_session,
1790 compaction_settings,
1791 );
1792 agent_session.set_api_key_override(options.api_key.clone());
1793
1794 if !options.extension_paths.is_empty() {
1795 let extension_paths = options
1796 .extension_paths
1797 .iter()
1798 .map(|path| resolve_path_for_cwd(path, &cwd))
1799 .collect::<Vec<_>>();
1800 let resolved_ext_policy =
1801 config.resolve_extension_policy_with_metadata(options.extension_policy.as_deref());
1802 let resolved_repair_policy =
1803 config.resolve_repair_policy_with_metadata(options.repair_policy.as_deref());
1804
1805 agent_session
1806 .enable_extensions_with_policy(
1807 &enabled_tools,
1808 &cwd,
1809 Some(&config),
1810 &extension_paths,
1811 Some(resolved_ext_policy.policy),
1812 Some(resolved_repair_policy.effective_mode),
1813 None,
1814 )
1815 .await?;
1816 }
1817
1818 agent_session.set_model_registry(model_registry.clone());
1819 agent_session.set_auth_storage(auth);
1820
1821 let history = {
1822 let cx = crate::agent_cx::AgentCx::for_request();
1823 let guard = session_arc
1824 .lock(cx.cx())
1825 .await
1826 .map_err(|e| Error::session(e.to_string()))?;
1827 guard.to_messages_for_current_path()
1828 };
1829 if !history.is_empty() {
1830 agent_session.agent.replace_messages(history);
1831 }
1832
1833 let mut listeners = EventListeners::new();
1834 if let Some(on_event) = options.on_event {
1835 listeners.subscribe(on_event);
1836 }
1837 listeners.on_tool_start = options.on_tool_start;
1838 listeners.on_tool_end = options.on_tool_end;
1839 listeners.on_stream_event = options.on_stream_event;
1840
1841 Ok(AgentSessionHandle {
1842 session: agent_session,
1843 listeners,
1844 })
1845}
1846
1847#[cfg(test)]
1848mod tests {
1849 use super::*;
1850 use asupersync::runtime::RuntimeBuilder;
1851 use asupersync::runtime::reactor::create_reactor;
1852 use asupersync::sync::Mutex as AsyncMutex;
1853 use std::env;
1854 use std::sync::{Arc, Mutex, OnceLock};
1855 use tempfile::tempdir;
1856
1857 fn run_async<F>(future: F) -> F::Output
1858 where
1859 F: std::future::Future,
1860 {
1861 let reactor = create_reactor().expect("create reactor");
1862 let runtime = RuntimeBuilder::current_thread()
1863 .with_reactor(reactor)
1864 .build()
1865 .expect("build runtime");
1866 runtime.block_on(future)
1867 }
1868
1869 fn current_dir_lock() -> std::sync::MutexGuard<'static, ()> {
1870 static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
1871 LOCK.get_or_init(|| Mutex::new(()))
1872 .lock()
1873 .unwrap_or_else(std::sync::PoisonError::into_inner)
1874 }
1875
1876 struct CurrentDirGuard {
1877 previous: PathBuf,
1878 }
1879
1880 impl CurrentDirGuard {
1881 fn new(path: &Path) -> Self {
1882 let previous = env::current_dir().expect("current dir");
1883 env::set_current_dir(path).expect("set current dir");
1884 Self { previous }
1885 }
1886 }
1887
1888 impl Drop for CurrentDirGuard {
1889 fn drop(&mut self) {
1890 let _ = env::set_current_dir(&self.previous);
1891 }
1892 }
1893
1894 fn hermetic_session_options(working_directory: &Path) -> SessionOptions {
1895 SessionOptions {
1896 provider: Some("openai".to_string()),
1897 model: Some("gpt-4o".to_string()),
1898 api_key: Some("dummy-key".to_string()),
1899 working_directory: Some(working_directory.to_path_buf()),
1900 no_session: true,
1901 ..SessionOptions::default()
1902 }
1903 }
1904
1905 #[test]
1906 fn create_agent_session_with_explicit_test_provider_succeeds() {
1907 let tmp = tempdir().expect("tempdir");
1908 let options = hermetic_session_options(tmp.path());
1909
1910 let handle = run_async(create_agent_session(options)).expect("create session");
1911 let provider = handle.session().agent.provider();
1912 assert!(!provider.name().is_empty());
1913 assert!(!provider.model_id().is_empty());
1914 assert_eq!(handle.model().0, provider.name());
1915 assert_eq!(handle.model().1, provider.model_id());
1916 }
1917
1918 #[test]
1919 fn create_agent_session_respects_provider_model_and_clamps_thinking() {
1920 let tmp = tempdir().expect("tempdir");
1921 let options = SessionOptions {
1922 provider: Some("openai".to_string()),
1923 model: Some("gpt-4o".to_string()),
1924 api_key: Some("dummy-key".to_string()),
1925 thinking: Some(crate::model::ThinkingLevel::Low),
1926 working_directory: Some(tmp.path().to_path_buf()),
1927 no_session: true,
1928 ..SessionOptions::default()
1929 };
1930
1931 let handle = run_async(create_agent_session(options)).expect("create session");
1932 let provider = handle.session().agent.provider();
1933 assert_eq!(provider.name(), "openai");
1934 assert_eq!(provider.model_id(), "gpt-4o");
1935 assert_eq!(
1936 handle.session().agent.stream_options().thinking_level,
1937 Some(crate::model::ThinkingLevel::Off)
1938 );
1939 }
1940
1941 #[test]
1942 fn create_agent_session_no_session_keeps_ephemeral_state() {
1943 let tmp = tempdir().expect("tempdir");
1944 let options = hermetic_session_options(tmp.path());
1945
1946 let handle = run_async(create_agent_session(options)).expect("create session");
1947 assert!(!handle.session().save_enabled());
1948
1949 let path_is_none = run_async(async {
1950 let cx = crate::agent_cx::AgentCx::for_request();
1951 let guard = handle
1952 .session()
1953 .session
1954 .lock(cx.cx())
1955 .await
1956 .expect("lock session");
1957 guard.path.is_none()
1958 });
1959 assert!(path_is_none);
1960 }
1961
1962 #[test]
1963 fn create_agent_session_uses_working_directory_for_new_session_header_and_path() {
1964 let _lock = current_dir_lock();
1965 let process_cwd = tempdir().expect("process cwd");
1966 let sdk_cwd = tempdir().expect("sdk cwd");
1967 let session_root = tempdir().expect("session root");
1968 let _guard = CurrentDirGuard::new(process_cwd.path());
1969
1970 let handle = run_async(create_agent_session(SessionOptions {
1971 provider: Some("openai".to_string()),
1972 model: Some("gpt-4o".to_string()),
1973 api_key: Some("dummy-key".to_string()),
1974 working_directory: Some(sdk_cwd.path().to_path_buf()),
1975 no_session: false,
1976 session_dir: Some(session_root.path().to_path_buf()),
1977 ..SessionOptions::default()
1978 }))
1979 .expect("create session");
1980
1981 let (header_cwd, path) = run_async(async {
1982 let cx = crate::agent_cx::AgentCx::for_request();
1983 let mut guard = handle
1984 .session()
1985 .session
1986 .lock(cx.cx())
1987 .await
1988 .expect("lock session");
1989 guard.save().await.expect("save sdk session");
1990 (
1991 guard.header.cwd.clone(),
1992 guard.path.clone().expect("saved session path"),
1993 )
1994 });
1995
1996 let expected_dir = session_root
1997 .path()
1998 .join(crate::session::encode_cwd(sdk_cwd.path()));
1999 let process_dir = session_root
2000 .path()
2001 .join(crate::session::encode_cwd(process_cwd.path()));
2002
2003 assert_eq!(header_cwd, sdk_cwd.path().display().to_string());
2004 assert_eq!(path.parent(), Some(expected_dir.as_path()));
2005 assert_ne!(path.parent(), Some(process_dir.as_path()));
2006 }
2007
2008 #[test]
2009 fn create_agent_session_resolves_relative_session_dir_against_working_directory() {
2010 let _lock = current_dir_lock();
2011 let process_cwd = tempdir().expect("process cwd");
2012 let sdk_cwd = tempdir().expect("sdk cwd");
2013 let _guard = CurrentDirGuard::new(process_cwd.path());
2014
2015 let handle = run_async(create_agent_session(SessionOptions {
2016 provider: Some("openai".to_string()),
2017 model: Some("gpt-4o".to_string()),
2018 api_key: Some("dummy-key".to_string()),
2019 working_directory: Some(sdk_cwd.path().to_path_buf()),
2020 no_session: false,
2021 session_dir: Some(PathBuf::from("sessions")),
2022 ..SessionOptions::default()
2023 }))
2024 .expect("create session");
2025
2026 let path = run_async(async {
2027 let cx = crate::agent_cx::AgentCx::for_request();
2028 let mut guard = handle
2029 .session()
2030 .session
2031 .lock(cx.cx())
2032 .await
2033 .expect("lock session");
2034 guard.save().await.expect("save sdk session");
2035 guard.path.clone().expect("saved session path")
2036 });
2037
2038 let expected_dir = sdk_cwd
2039 .path()
2040 .join("sessions")
2041 .join(crate::session::encode_cwd(sdk_cwd.path()));
2042 let process_dir = process_cwd
2043 .path()
2044 .join("sessions")
2045 .join(crate::session::encode_cwd(sdk_cwd.path()));
2046
2047 assert_eq!(path.parent(), Some(expected_dir.as_path()));
2048 assert_ne!(path.parent(), Some(process_dir.as_path()));
2049 }
2050
2051 #[test]
2052 fn create_agent_session_resolves_relative_session_path_against_working_directory() {
2053 let _lock = current_dir_lock();
2054 let process_cwd = tempdir().expect("process cwd");
2055 let sdk_cwd = tempdir().expect("sdk cwd");
2056 let _guard = CurrentDirGuard::new(process_cwd.path());
2057
2058 let session_path = sdk_cwd.path().join("relative").join("existing.jsonl");
2059 std::fs::create_dir_all(session_path.parent().expect("session parent"))
2060 .expect("create session parent");
2061 let mut header = crate::session::SessionHeader::new();
2062 header.cwd = sdk_cwd.path().display().to_string();
2063 let header_json = serde_json::to_string(&header).expect("serialize session header");
2064 std::fs::write(&session_path, format!("{header_json}\n")).expect("write session");
2065
2066 let handle = run_async(create_agent_session(SessionOptions {
2067 provider: Some("openai".to_string()),
2068 model: Some("gpt-4o".to_string()),
2069 api_key: Some("dummy-key".to_string()),
2070 working_directory: Some(sdk_cwd.path().to_path_buf()),
2071 no_session: false,
2072 session_path: Some(PathBuf::from("relative/existing.jsonl")),
2073 ..SessionOptions::default()
2074 }))
2075 .expect("create session");
2076
2077 let opened_path = run_async(async {
2078 let cx = crate::agent_cx::AgentCx::for_request();
2079 let guard = handle
2080 .session()
2081 .session
2082 .lock(cx.cx())
2083 .await
2084 .expect("lock session");
2085 guard.path.clone().expect("opened session path")
2086 });
2087
2088 assert_eq!(opened_path, session_path);
2089 }
2090
2091 #[test]
2092 fn from_session_with_listeners_set_model_switches_provider_model() {
2093 let dir = tempdir().expect("tempdir");
2094 let auth_path = dir.path().join("auth.json");
2095 let mut auth = AuthStorage::load(auth_path).expect("load auth");
2096 auth.set(
2097 "anthropic",
2098 crate::auth::AuthCredential::ApiKey {
2099 key: "anthropic-key".to_string(),
2100 },
2101 );
2102 auth.set(
2103 "openai",
2104 crate::auth::AuthCredential::ApiKey {
2105 key: "openai-key".to_string(),
2106 },
2107 );
2108
2109 let registry = ModelRegistry::load(&auth, None);
2110 let entry = registry
2111 .find("anthropic", "claude-sonnet-4-5")
2112 .expect("anthropic model in registry");
2113 let provider = providers::create_provider(&entry, None).expect("create anthropic provider");
2114 let tools = crate::tools::ToolRegistry::new(&[], std::path::Path::new("."), None);
2115 let agent = Agent::new(
2116 provider,
2117 tools,
2118 AgentConfig {
2119 system_prompt: None,
2120 max_tool_iterations: 50,
2121 stream_options: StreamOptions::default(),
2122 block_images: false,
2123 fail_closed_hooks: false,
2124 tool_approval: None,
2125 },
2126 );
2127
2128 let mut session = Session::in_memory();
2129 session.header.provider = Some("anthropic".to_string());
2130 session.header.model_id = Some("claude-sonnet-4-5".to_string());
2131
2132 let mut agent_session = AgentSession::new(
2133 agent,
2134 Arc::new(AsyncMutex::new(session)),
2135 false,
2136 ResolvedCompactionSettings::default(),
2137 );
2138 agent_session.set_model_registry(registry);
2139 agent_session.set_auth_storage(auth);
2140
2141 let mut handle =
2142 AgentSessionHandle::from_session_with_listeners(agent_session, EventListeners::new());
2143 run_async(handle.set_model("openai", "gpt-4o")).expect("set model");
2144 let provider = handle.session().agent.provider();
2145 assert_eq!(provider.name(), "openai");
2146 assert_eq!(provider.model_id(), "gpt-4o");
2147 }
2148
2149 #[test]
2150 fn create_agent_session_set_thinking_level_clamps_and_dedupes_history() {
2151 let tmp = tempdir().expect("tempdir");
2152 let options = SessionOptions {
2153 provider: Some("openai".to_string()),
2154 model: Some("gpt-4o".to_string()),
2155 api_key: Some("dummy-key".to_string()),
2156 working_directory: Some(tmp.path().to_path_buf()),
2157 no_session: true,
2158 ..SessionOptions::default()
2159 };
2160
2161 let mut handle = run_async(create_agent_session(options)).expect("create session");
2162 run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
2163 .expect("set thinking");
2164 run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
2165 .expect("reapply thinking");
2166
2167 assert_eq!(
2168 handle.session().agent.stream_options().thinking_level,
2169 Some(crate::model::ThinkingLevel::Off)
2170 );
2171
2172 let thinking_changes = run_async(async {
2173 let cx = crate::agent_cx::AgentCx::for_request();
2174 let guard = handle
2175 .session()
2176 .session
2177 .lock(cx.cx())
2178 .await
2179 .expect("lock session");
2180 assert_eq!(guard.header.thinking_level.as_deref(), Some("off"));
2181 guard
2182 .entries_for_current_path()
2183 .iter()
2184 .filter(|entry| {
2185 matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_))
2186 })
2187 .count()
2188 });
2189 assert_eq!(thinking_changes, 1);
2190 }
2191
2192 #[test]
2193 fn from_session_with_listeners_set_thinking_level_uses_session_header_target() {
2194 let dir = tempdir().expect("tempdir");
2195 let auth_path = dir.path().join("auth.json");
2196 let auth = crate::auth::AuthStorage::load(auth_path).expect("load auth");
2197 let mut registry = ModelRegistry::load(&auth, None);
2198 registry.merge_entries(vec![ModelEntry {
2199 model: Model {
2200 id: "plain-model".to_string(),
2201 name: "Plain Model".to_string(),
2202 api: "openai-completions".to_string(),
2203 provider: "acme".to_string(),
2204 base_url: "https://example.invalid/v1".to_string(),
2205 reasoning: false,
2206 input: vec![InputType::Text],
2207 cost: ModelCost {
2208 input: 0.0,
2209 output: 0.0,
2210 cache_read: 0.0,
2211 cache_write: 0.0,
2212 },
2213 context_window: 128_000,
2214 max_tokens: 8_192,
2215 headers: HashMap::new(),
2216 },
2217 api_key: None,
2218 headers: HashMap::new(),
2219 auth_header: false,
2220 compat: None,
2221 oauth_config: None,
2222 }]);
2223 let entry = registry
2224 .find("anthropic", "claude-sonnet-4-5")
2225 .expect("anthropic model in registry");
2226 let provider = providers::create_provider(&entry, None).expect("create anthropic provider");
2227 let tools = crate::tools::ToolRegistry::new(&[], std::path::Path::new("."), None);
2228 let agent = Agent::new(
2229 provider,
2230 tools,
2231 AgentConfig {
2232 system_prompt: None,
2233 max_tool_iterations: 50,
2234 stream_options: StreamOptions::default(),
2235 block_images: false,
2236 fail_closed_hooks: false,
2237 tool_approval: None,
2238 },
2239 );
2240
2241 let mut session = Session::in_memory();
2242 session.header.provider = Some("acme".to_string());
2243 session.header.model_id = Some("plain-model".to_string());
2244
2245 let mut agent_session = AgentSession::new(
2246 agent,
2247 Arc::new(AsyncMutex::new(session)),
2248 false,
2249 ResolvedCompactionSettings::default(),
2250 );
2251 agent_session.set_model_registry(registry);
2252
2253 let mut handle =
2254 AgentSessionHandle::from_session_with_listeners(agent_session, EventListeners::new());
2255 run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
2256 .expect("set thinking");
2257
2258 assert_eq!(
2259 handle.session().agent.stream_options().thinking_level,
2260 Some(crate::model::ThinkingLevel::Off)
2261 );
2262 assert_eq!(handle.model().0, "anthropic");
2263 assert_eq!(handle.model().1, "claude-sonnet-4-5");
2264 }
2265
2266 #[test]
2267 fn compact_without_history_is_noop() {
2268 let tmp = tempdir().expect("tempdir");
2269 let options = hermetic_session_options(tmp.path());
2270
2271 let mut handle = run_async(create_agent_session(options)).expect("create session");
2272 let events = Arc::new(Mutex::new(Vec::new()));
2273 let events_for_callback = Arc::clone(&events);
2274 run_async(handle.compact(move |event| {
2275 events_for_callback
2276 .lock()
2277 .expect("compact callback lock")
2278 .push(event);
2279 }))
2280 .expect("compact");
2281
2282 assert!(
2283 events
2284 .lock()
2285 .unwrap_or_else(std::sync::PoisonError::into_inner)
2286 .is_empty(),
2287 "expected no compaction lifecycle events for empty session"
2288 );
2289 }
2290
2291 #[test]
2292 fn resolve_path_for_cwd_uses_cwd_for_relative_paths() {
2293 let cwd = Path::new("/tmp/pi-sdk-cwd");
2294 assert_eq!(
2295 resolve_path_for_cwd(Path::new("relative/file.txt"), cwd),
2296 PathBuf::from("/tmp/pi-sdk-cwd/relative/file.txt")
2297 );
2298 assert_eq!(
2299 resolve_path_for_cwd(Path::new("/etc/hosts"), cwd),
2300 PathBuf::from("/etc/hosts")
2301 );
2302 }
2303
/// Checks that a subscriber receives a notified event and stops receiving
/// events once it has been unsubscribed.
#[test]
fn event_listeners_subscribe_and_notify() {
    let listeners = EventListeners::new();
    let inbox = Arc::new(Mutex::new(Vec::new()));

    let sink = Arc::clone(&inbox);
    let subscription = listeners.subscribe(Arc::new(move |event| {
        let mut collected = sink
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        collected.push(event);
    }));

    // Number of events delivered so far; the guard is released on return.
    let delivered = || {
        inbox
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len()
    };

    let first = AgentEvent::AgentStart {
        session_id: "test-123".into(),
    };
    listeners.notify(&first);
    assert_eq!(delivered(), 1);

    // After unsubscribing, a further notification must not reach the callback.
    assert!(listeners.unsubscribe(subscription));
    let second = AgentEvent::AgentStart {
        session_id: "test-456".into(),
    };
    listeners.notify(&second);
    assert_eq!(delivered(), 1);
}
2345
/// Checks that unsubscribing an id that was never registered returns `false`.
#[test]
fn event_listeners_unsubscribe_nonexistent_returns_false() {
    let listeners = EventListeners::new();
    let never_issued = SubscriptionId(999);
    assert!(!listeners.unsubscribe(never_issued));
}
2351
/// Checks that a single notification is delivered once to each of two
/// independently registered subscribers.
#[test]
fn event_listeners_multiple_subscribers() {
    let listeners = EventListeners::new();
    let counters = [Arc::new(Mutex::new(0u32)), Arc::new(Mutex::new(0u32))];

    // Register one counting subscriber per counter.
    for counter in &counters {
        let counter = Arc::clone(counter);
        listeners.subscribe(Arc::new(move |_| {
            let mut hits = counter
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            *hits += 1;
        }));
    }

    let event = AgentEvent::AgentStart {
        session_id: "s".into(),
    };
    listeners.notify(&event);

    for counter in &counters {
        let hits = *counter
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        assert_eq!(hits, 1);
    }
}
2385
/// Checks that the `on_tool_start` and `on_tool_end` hooks are invoked by
/// `notify_tool_start` / `notify_tool_end`, and that each hook observes the tool
/// name together with the arguments (start) or the error flag (end) that were
/// passed to the notify call.
#[test]
fn event_listeners_tool_hooks_fire() {
    let mut listeners = EventListeners::new();
    let starts = Arc::new(Mutex::new(Vec::new()));
    let ends = Arc::new(Mutex::new(Vec::new()));

    // Both hooks lock poison-tolerantly, matching the other listener tests, so a
    // panic inside one callback cannot mask the real failure with a lock error.
    let s = Arc::clone(&starts);
    listeners.on_tool_start = Some(Arc::new(move |name, args| {
        s.lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .push((name.to_string(), args.clone()));
    }));

    let e = Arc::clone(&ends);
    listeners.on_tool_end = Some(Arc::new(move |name, _output, is_error| {
        e.lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .push((name.to_string(), is_error));
    }));

    let args = serde_json::json!({"path": "/foo"});
    listeners.notify_tool_start("bash", &args);
    let output = ToolOutput {
        content: vec![ContentBlock::Text(TextContent::new("ok"))],
        details: None,
        is_error: false,
    };
    listeners.notify_tool_end("bash", &output, false);

    {
        let s = starts
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        assert_eq!(s.len(), 1);
        assert_eq!(s[0].0, "bash");
        // The hook recorded the arguments; verify they arrived unchanged.
        assert_eq!(s[0].1, args);
        drop(s);
    }

    {
        let e = ends
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        assert_eq!(e.len(), 1);
        assert_eq!(e[0].0, "bash");
        assert!(!e[0].1);
        drop(e);
    }
}
2435
/// Checks that the `on_stream_event` hook is invoked exactly once when a
/// stream event is notified.
#[test]
fn event_listeners_stream_event_hook_fires() {
    let seen = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&seen);

    let mut listeners = EventListeners::new();
    listeners.on_stream_event = Some(Arc::new(move |ev| {
        let mut collected = sink
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        collected.push(format!("{ev:?}"));
    }));

    listeners.notify_stream_event(&StreamEvent::TextDelta {
        content_index: 0,
        delta: "hello".to_string(),
    });

    let seen_count = seen
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .len();
    assert_eq!(seen_count, 1);
}
2462
/// Checks that an `on_event` callback supplied through `SessionOptions` shows
/// up as exactly one registered subscriber on the created session handle.
#[test]
fn session_options_on_event_wired_into_listeners() {
    let received = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&received);
    let tmp = tempdir().expect("tempdir");

    // Start from the hermetic defaults and attach only the event callback.
    let mut options = hermetic_session_options(tmp.path());
    options.on_event = Some(Arc::new(move |event| {
        let mut collected = sink
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        collected.push(format!("{event:?}"));
    }));

    let handle = run_async(create_agent_session(options)).expect("create session");
    let subscriber_count = handle
        .listeners()
        .subscribers
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .len();
    assert_eq!(
        subscriber_count, 1,
        "on_event from SessionOptions should register one subscriber"
    );
}
2491
/// Checks that subscribing through the session handle adds one subscriber,
/// unsubscribing removes it, and a repeated unsubscribe of the same id
/// returns `false`.
#[test]
fn subscribe_unsubscribe_on_handle() {
    let tmp = tempdir().expect("tempdir");
    let handle = run_async(create_agent_session(hermetic_session_options(tmp.path())))
        .expect("create session");

    // Current number of registered subscribers; the lock is released on return.
    let subscriber_count = || {
        handle
            .listeners()
            .subscribers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len()
    };

    let id = handle.subscribe(|_event| {});
    assert_eq!(subscriber_count(), 1);

    assert!(handle.unsubscribe(id));
    assert_eq!(subscriber_count(), 0);

    assert!(!handle.unsubscribe(id));
}
2523
/// Checks that an `AssistantMessageEvent::TextDelta` is converted into a
/// `StreamEvent::TextDelta` carrying the same `content_index` and `delta`.
#[test]
fn stream_event_from_assistant_message_event_converts_text_delta() {
    use crate::model::AssistantMessageEvent as AME;

    let partial = Arc::new(AssistantMessage {
        content: Vec::new(),
        api: String::new(),
        provider: String::new(),
        model: String::new(),
        usage: Usage::default(),
        stop_reason: StopReason::Stop,
        error_message: None,
        timestamp: 0,
    });
    let ame = AME::TextDelta {
        content_index: 2,
        delta: "chunk".to_string(),
        partial,
    };

    // Match on the `Option` directly instead of `is_some()` + `unwrap()`, so a
    // `None` or wrong-variant result is reported with the value actually returned.
    match stream_event_from_assistant_message_event(&ame) {
        Some(StreamEvent::TextDelta {
            content_index,
            delta,
        }) => {
            assert_eq!(content_index, 2);
            assert_eq!(delta, "chunk");
        }
        other => unreachable!("expected Some(TextDelta), got {other:?}"),
    }
}
2556
/// Checks that an `AssistantMessageEvent::Start` has no stream-event
/// counterpart: the conversion returns `None`.
#[test]
fn stream_event_from_assistant_message_event_start_returns_none() {
    use crate::model::AssistantMessageEvent as AME;

    let blank_message = AssistantMessage {
        content: Vec::new(),
        api: String::new(),
        provider: String::new(),
        model: String::new(),
        usage: Usage::default(),
        stop_reason: StopReason::Stop,
        error_message: None,
        timestamp: 0,
    };
    let start = AME::Start {
        partial: Arc::new(blank_message),
    };

    let converted = stream_event_from_assistant_message_event(&start);
    assert!(converted.is_none());
}
2574
/// Checks that the `Debug` rendering of `EventListeners` mentions the
/// `subscriber_count` and `has_on_tool_start` fields.
#[test]
fn event_listeners_debug_impl() {
    let rendered = format!("{:?}", EventListeners::new());
    assert!(rendered.contains("subscriber_count"));
    assert!(rendered.contains("has_on_tool_start"));
}
2582
/// Checks that a session created from the hermetic default options reports no
/// extensions, no extension manager and no extension region.
#[test]
fn has_extensions_false_by_default() {
    let tmp = tempdir().expect("tempdir");
    let handle = run_async(create_agent_session(hermetic_session_options(tmp.path())))
        .expect("create session");

    let extensions_loaded = handle.has_extensions();
    assert!(
        !extensions_loaded,
        "session without extension_paths should have no extensions"
    );
    assert!(handle.extension_manager().is_none());
    assert!(handle.extension_region().is_none());
}
2600
/// Checks that the read-tool factory yields a tool named "read" with a
/// non-empty description and an object-shaped parameter schema.
#[test]
fn create_read_tool_has_correct_name() {
    let workdir = tempdir().expect("tempdir");
    let read_tool = super::create_read_tool(workdir.path());

    assert_eq!(read_tool.name(), "read");
    assert!(!read_tool.description().is_empty());

    let schema = read_tool.parameters();
    assert!(schema.is_object(), "parameters should be a JSON object");
}
2614
/// Checks that the bash-tool factory yields a tool named "bash" with a
/// non-empty description.
#[test]
fn create_bash_tool_has_correct_name() {
    let workdir = tempdir().expect("tempdir");
    let bash_tool = super::create_bash_tool(workdir.path());

    assert_eq!(bash_tool.name(), "bash");
    assert!(!bash_tool.description().is_empty());
}
2622
/// Checks that the edit-tool factory yields a tool named "edit".
#[test]
fn create_edit_tool_has_correct_name() {
    let workdir = tempdir().expect("tempdir");
    let edit_tool = super::create_edit_tool(workdir.path());
    assert_eq!(edit_tool.name(), "edit");
}
2629
/// Checks that the write-tool factory yields a tool named "write".
#[test]
fn create_write_tool_has_correct_name() {
    let workdir = tempdir().expect("tempdir");
    let write_tool = super::create_write_tool(workdir.path());
    assert_eq!(write_tool.name(), "write");
}
2636
/// Checks that the grep-tool factory yields a tool named "grep".
#[test]
fn create_grep_tool_has_correct_name() {
    let workdir = tempdir().expect("tempdir");
    let grep_tool = super::create_grep_tool(workdir.path());
    assert_eq!(grep_tool.name(), "grep");
}
2643
/// Checks that the find-tool factory yields a tool named "find".
#[test]
fn create_find_tool_has_correct_name() {
    let workdir = tempdir().expect("tempdir");
    let find_tool = super::create_find_tool(workdir.path());
    assert_eq!(find_tool.name(), "find");
}
2650
/// Checks that the ls-tool factory yields a tool named "ls".
#[test]
fn create_ls_tool_has_correct_name() {
    let workdir = tempdir().expect("tempdir");
    let ls_tool = super::create_ls_tool(workdir.path());
    assert_eq!(ls_tool.name(), "ls");
}
2657
/// Checks that `create_all_tools` produces eight tools and that every name in
/// `BUILTIN_TOOL_NAMES` is among them.
#[test]
fn create_all_tools_returns_eight() {
    let workdir = tempdir().expect("tempdir");
    let tools = super::create_all_tools(workdir.path());
    assert_eq!(tools.len(), 8, "should create all 8 built-in tools");

    let created: Vec<&str> = tools.iter().map(|tool| tool.name()).collect();
    BUILTIN_TOOL_NAMES.iter().for_each(|expected| {
        assert!(created.contains(expected), "missing tool: {expected}");
    });
}
2669
/// Checks that converting the read tool into a definition keeps its name, a
/// non-empty description, and an object schema that has a `properties` key.
#[test]
fn tool_to_definition_preserves_schema() {
    let workdir = tempdir().expect("tempdir");
    let read_tool = super::create_read_tool(workdir.path());
    let definition = super::tool_to_definition(read_tool.as_ref());

    assert_eq!(definition.name, "read");
    assert!(!definition.description.is_empty());
    assert!(definition.parameters.is_object());

    let has_properties = definition.parameters.get("properties").is_some();
    assert!(has_properties, "schema should have properties");
}
2683
/// Checks that `all_tool_definitions` yields eight definitions, each with a
/// non-empty name, a non-empty description and an object-shaped schema.
#[test]
fn all_tool_definitions_returns_eight_schemas() {
    let workdir = tempdir().expect("tempdir");
    let definitions = super::all_tool_definitions(workdir.path());
    assert_eq!(definitions.len(), 8);

    definitions.iter().for_each(|definition| {
        assert!(!definition.name.is_empty());
        assert!(!definition.description.is_empty());
        assert!(definition.parameters.is_object());
    });
}
2696
/// Checks that the names of the tools built by `create_all_tools` equal
/// `BUILTIN_TOOL_NAMES`, element for element and in the same order.
#[test]
fn builtin_tool_names_matches_create_all() {
    let workdir = tempdir().expect("tempdir");
    let tools = super::create_all_tools(workdir.path());
    let created = tools
        .iter()
        .map(|tool| tool.name())
        .collect::<Vec<&str>>();

    assert_eq!(
        &created[..],
        BUILTIN_TOOL_NAMES,
        "create_all_tools order should match BUILTIN_TOOL_NAMES"
    );
}
2708
/// Checks that a `ToolRegistry` built from the factory tools resolves "read"
/// and "bash" and does not resolve an unknown name.
#[test]
fn tool_registry_from_factory_tools() {
    let workdir = tempdir().expect("tempdir");
    let registry = ToolRegistry::from_tools(super::create_all_tools(workdir.path()));

    let read_lookup = registry.get("read");
    assert!(read_lookup.is_some());
    let bash_lookup = registry.get("bash");
    assert!(bash_lookup.is_some());
    let unknown_lookup = registry.get("nonexistent");
    assert!(unknown_lookup.is_none());
}
2718
/// Checks that `set_session_name` leaves exactly one named `SessionInfo` entry
/// on the session's current path, carrying the new name.
#[test]
fn set_session_name_records_session_info_entry() {
    let tmp = tempdir().expect("tempdir");
    let mut handle = run_async(create_agent_session(hermetic_session_options(tmp.path())))
        .expect("create session");
    run_async(handle.set_session_name("renamed-by-sdk")).expect("set session name");

    // Gather the names of all `SessionInfo` entries while holding the session lock.
    let recorded_names = run_async(async {
        let cx = crate::agent_cx::AgentCx::for_request();
        let session = handle
            .session()
            .session
            .lock(cx.cx())
            .await
            .expect("lock session");
        session
            .entries_for_current_path()
            .iter()
            .filter_map(|entry| {
                if let crate::session::SessionEntry::SessionInfo(info) = entry {
                    info.name.clone()
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
    });

    assert_eq!(recorded_names, vec![String::from("renamed-by-sdk")]);
}
2746
/// Checks that `max_tokens` is `None` on a fresh session, that setting a value
/// is visible through both the handle and the agent's stream options, and that
/// setting `None` clears the handle's value again.
#[test]
fn max_tokens_default_is_none_and_set_overrides() {
    let tmp = tempdir().expect("tempdir");
    let mut handle = run_async(create_agent_session(hermetic_session_options(tmp.path())))
        .expect("create session");
    assert_eq!(handle.max_tokens(), None);

    handle.set_max_tokens(Some(32_000));
    assert_eq!(handle.max_tokens(), Some(32_000));
    let stream_limit = handle.session().agent.stream_options().max_tokens;
    assert_eq!(stream_limit, Some(32_000));

    handle.set_max_tokens(None);
    assert_eq!(handle.max_tokens(), None);
}
2765}