1use crate::app;
23use crate::auth::AuthStorage;
24use crate::cli::Cli;
25use crate::compaction::ResolvedCompactionSettings;
26use crate::models::default_models_path;
27use crate::provider::ThinkingBudgets;
28use crate::providers;
29use clap::Parser;
30use serde::{Deserialize, Serialize, de::DeserializeOwned};
31use serde_json::{Map, Value};
32use std::collections::HashMap;
33use std::io::{BufRead, BufReader, BufWriter, Write};
34use std::path::{Path, PathBuf};
35use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
36use std::sync::Arc;
37use std::sync::atomic::{AtomicU64, Ordering};
38
39pub use crate::agent::{
40 AbortHandle, AbortSignal, Agent, AgentConfig, AgentEvent, AgentSession, QueueMode,
41};
42pub use crate::config::Config;
43pub use crate::error::{Error, Result};
44pub use crate::extensions::{ExtensionManager, ExtensionPolicy, ExtensionRegion};
45pub use crate::model::{
46 AssistantMessage, ContentBlock, Cost, CustomMessage, ImageContent, Message, StopReason,
47 StreamEvent, TextContent, ThinkingContent, ToolCall, ToolResultMessage, Usage, UserContent,
48 UserMessage,
49};
50pub use crate::models::{ModelEntry, ModelRegistry};
51pub use crate::provider::{
52 Context as ProviderContext, InputType, Model, ModelCost, Provider, StreamOptions,
53 ThinkingBudgets as ProviderThinkingBudgets, ToolDef,
54};
55pub use crate::session::Session;
56pub use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
57pub use crate::model::ThinkingLevel;
58
59pub type ToolDefinition = ToolDef;
61
62use crate::tools::{
67 BashTool, EditTool, FindTool, GrepTool, HashlineEditTool, LsTool, ReadTool, WriteTool,
68};
69
/// Names of the built-in tools, listed in the same order as the factories
/// invoked by `create_all_tools`.
pub const BUILTIN_TOOL_NAMES: &[&str] = &[
    // File access.
    "read", "write", 
    // Execution / editing.
    "bash", "edit",
    // Search / listing.
    "grep", "find", "ls",
    // Hash-anchored editing.
    "hashline_edit",
];
81
82pub fn create_read_tool(cwd: &Path) -> Box<dyn Tool> {
84 Box::new(ReadTool::new(cwd))
85}
86
87pub fn create_bash_tool(cwd: &Path) -> Box<dyn Tool> {
89 Box::new(BashTool::new(cwd))
90}
91
92pub fn create_edit_tool(cwd: &Path) -> Box<dyn Tool> {
94 Box::new(EditTool::new(cwd))
95}
96
97pub fn create_write_tool(cwd: &Path) -> Box<dyn Tool> {
99 Box::new(WriteTool::new(cwd))
100}
101
102pub fn create_grep_tool(cwd: &Path) -> Box<dyn Tool> {
104 Box::new(GrepTool::new(cwd))
105}
106
107pub fn create_find_tool(cwd: &Path) -> Box<dyn Tool> {
109 Box::new(FindTool::new(cwd))
110}
111
112pub fn create_ls_tool(cwd: &Path) -> Box<dyn Tool> {
114 Box::new(LsTool::new(cwd))
115}
116
117pub fn create_hashline_edit_tool(cwd: &Path) -> Box<dyn Tool> {
119 Box::new(HashlineEditTool::new(cwd))
120}
121
122pub fn create_all_tools(cwd: &Path) -> Vec<Box<dyn Tool>> {
124 vec![
125 create_read_tool(cwd),
126 create_bash_tool(cwd),
127 create_edit_tool(cwd),
128 create_write_tool(cwd),
129 create_grep_tool(cwd),
130 create_find_tool(cwd),
131 create_ls_tool(cwd),
132 create_hashline_edit_tool(cwd),
133 ]
134}
135
136pub fn tool_to_definition(tool: &dyn Tool) -> ToolDefinition {
138 ToolDefinition {
139 name: tool.name().to_string(),
140 description: tool.description().to_string(),
141 parameters: tool.parameters(),
142 }
143}
144
145pub fn all_tool_definitions(cwd: &Path) -> Vec<ToolDefinition> {
147 create_all_tools(cwd)
148 .iter()
149 .map(|t| tool_to_definition(t.as_ref()))
150 .collect()
151}
152
/// Opaque handle returned by `EventListeners::subscribe`.
///
/// Pass it to `unsubscribe` to remove the corresponding listener.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SubscriptionId(u64);
163
/// Hook receiving `(tool_name, args)`; invoked by `EventListeners::notify_tool_start`.
pub type OnToolStart = Arc<dyn Fn(&str, &Value) + Send + Sync>;
168
/// Hook receiving `(tool_name, output, is_error)`; invoked by
/// `EventListeners::notify_tool_end`.
pub type OnToolEnd = Arc<dyn Fn(&str, &ToolOutput, bool) + Send + Sync>;
173
/// Hook receiving each `StreamEvent` passed to `EventListeners::notify_stream_event`.
pub type OnStreamEvent = Arc<dyn Fn(&StreamEvent) + Send + Sync>;
179
/// General listener registered via `EventListeners::subscribe`.
///
/// Each listener receives an owned clone of every `AgentEvent` passed to
/// `notify`.
pub type EventSubscriber = Arc<dyn Fn(AgentEvent) + Send + Sync>;
// Subscriber storage keyed by the id handed back from `subscribe`.
type EventSubscribers = HashMap<SubscriptionId, EventSubscriber>;
182
183#[derive(Clone, Default)]
189pub struct EventListeners {
190 next_id: Arc<AtomicU64>,
191 subscribers: Arc<std::sync::Mutex<EventSubscribers>>,
192 pub on_tool_start: Option<OnToolStart>,
193 pub on_tool_end: Option<OnToolEnd>,
194 pub on_stream_event: Option<OnStreamEvent>,
195}
196
197impl EventListeners {
198 fn new() -> Self {
199 Self {
200 next_id: Arc::new(AtomicU64::new(1)),
201 subscribers: Arc::new(std::sync::Mutex::new(HashMap::new())),
202 on_tool_start: None,
203 on_tool_end: None,
204 on_stream_event: None,
205 }
206 }
207
208 pub fn subscribe(&self, listener: EventSubscriber) -> SubscriptionId {
210 let id = SubscriptionId(self.next_id.fetch_add(1, Ordering::Relaxed));
211 let mut subs = self
212 .subscribers
213 .lock()
214 .unwrap_or_else(std::sync::PoisonError::into_inner);
215 subs.insert(id, listener);
216 id
217 }
218
219 pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
221 let mut subs = self
222 .subscribers
223 .lock()
224 .unwrap_or_else(std::sync::PoisonError::into_inner);
225 subs.remove(&id).is_some()
226 }
227
228 pub fn notify(&self, event: &AgentEvent) {
230 let listeners: Vec<_> = {
231 let subs = self
232 .subscribers
233 .lock()
234 .unwrap_or_else(std::sync::PoisonError::into_inner);
235 subs.values().cloned().collect()
236 };
237 for listener in listeners {
238 listener(event.clone());
239 }
240 }
241
242 pub fn notify_tool_start(&self, tool_name: &str, args: &Value) {
244 if let Some(cb) = &self.on_tool_start {
245 cb(tool_name, args);
246 }
247 }
248
249 pub fn notify_tool_end(&self, tool_name: &str, output: &ToolOutput, is_error: bool) {
251 if let Some(cb) = &self.on_tool_end {
252 cb(tool_name, output, is_error);
253 }
254 }
255
256 pub fn notify_stream_event(&self, event: &StreamEvent) {
258 if let Some(cb) = &self.on_stream_event {
259 cb(event);
260 }
261 }
262}
263
264impl std::fmt::Debug for EventListeners {
265 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266 let count = self.subscribers.lock().map_or(0, |s| s.len());
267 let next_id = self.next_id.load(Ordering::Relaxed);
268 f.debug_struct("EventListeners")
269 .field("subscriber_count", &count)
270 .field("next_id", &next_id)
271 .field("has_on_tool_start", &self.on_tool_start.is_some())
272 .field("has_on_tool_end", &self.on_tool_end.is_some())
273 .field("has_on_stream_event", &self.on_stream_event.is_some())
274 .finish()
275 }
276}
277
/// Options for building an in-process agent session.
///
/// `SessionTransport::in_process` passes these options to
/// `create_agent_session`.
///
/// `Default` yields:
/// - `no_session: true`
/// - `include_cwd_in_prompt: true`
/// - `max_tool_iterations: 50`
/// - no extension paths
/// - every `Option` unset
#[derive(Clone)]
pub struct SessionOptions {
    /// Provider name. `None` leaves the choice to session setup (presumably
    /// configuration defaults; confirm in `create_agent_session`).
    pub provider: Option<String>,
    /// Model id. `None` leaves the choice to session setup.
    pub model: Option<String>,
    /// Explicit API key. `None` presumably falls back to stored
    /// credentials; confirm.
    pub api_key: Option<String>,
    /// Initial thinking level, if any.
    pub thinking: Option<crate::model::ThinkingLevel>,
    /// Custom system prompt text, if any.
    pub system_prompt: Option<String>,
    /// Text to append to the system prompt, if any.
    pub append_system_prompt: Option<String>,
    /// Names of the tools to enable. `None` presumably means the default
    /// tool set (see `BUILTIN_TOOL_NAMES`); confirm.
    pub enabled_tools: Option<Vec<String>>,
    /// Working directory for the session. `None` presumably means the
    /// process's current directory; confirm.
    pub working_directory: Option<PathBuf>,
    /// Defaults to `true`. NOTE(review): presumably disables session
    /// persistence; confirm.
    pub no_session: bool,
    /// Specific session file to use, if any.
    pub session_path: Option<PathBuf>,
    /// Directory for session files, if any.
    pub session_dir: Option<PathBuf>,
    /// Extension paths to load; empty by default.
    pub extension_paths: Vec<PathBuf>,
    /// Extension policy name, kept as a raw string here.
    pub extension_policy: Option<String>,
    /// Repair policy name, kept as a raw string here. The accepted values
    /// are not visible in this file.
    pub repair_policy: Option<String>,
    /// Whether the working directory is included in the prompt; defaults to
    /// `true`.
    pub include_cwd_in_prompt: bool,
    /// Upper bound on tool iterations; defaults to 50.
    pub max_tool_iterations: usize,

    /// General `AgentEvent` callback, if any. Same shape as `EventSubscriber`.
    pub on_event: Option<Arc<dyn Fn(AgentEvent) + Send + Sync>>,

    /// Hook used for tool-start notifications, if any.
    pub on_tool_start: Option<OnToolStart>,

    /// Hook used for tool-end notifications, if any.
    pub on_tool_end: Option<OnToolEnd>,

    /// Hook used for stream-event notifications, if any.
    pub on_stream_event: Option<OnStreamEvent>,
}
316
317impl Default for SessionOptions {
318 fn default() -> Self {
319 Self {
320 provider: None,
321 model: None,
322 api_key: None,
323 thinking: None,
324 system_prompt: None,
325 append_system_prompt: None,
326 enabled_tools: None,
327 working_directory: None,
328 no_session: true,
329 session_path: None,
330 session_dir: None,
331 extension_paths: Vec::new(),
332 extension_policy: None,
333 repair_policy: None,
334 include_cwd_in_prompt: true,
335 max_tool_iterations: 50,
336 on_event: None,
337 on_tool_start: None,
338 on_tool_end: None,
339 on_stream_event: None,
340 }
341 }
342}
343
/// In-process session handle.
///
/// Wraps an `AgentSession` together with the `EventListeners` that are
/// combined with each call's own callback when prompting.
pub struct AgentSessionHandle {
    session: AgentSession,
    listeners: EventListeners,
}
356
/// Snapshot of an in-process session's state, as returned through
/// `SessionTransportState::InProcess`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentSessionState {
    /// Session id, if the session has one.
    pub session_id: Option<String>,
    /// Active provider name.
    pub provider: String,
    /// Active model id.
    pub model_id: String,
    /// Current thinking level, if set.
    pub thinking_level: Option<crate::model::ThinkingLevel>,
    /// NOTE(review): presumably whether the session is persisted; confirm.
    pub save_enabled: bool,
    /// Number of messages in the session.
    pub message_count: usize,
}
367
/// Result of `SessionTransport::prompt`, shaped by the transport in use.
#[derive(Debug, Clone)]
pub enum SessionPromptResult {
    /// Final assistant message from the in-process session.
    InProcess(AssistantMessage),
    /// Raw JSON events received after the RPC prompt was acknowledged.
    /// The last event is the `agent_end` event.
    RpcEvents(Vec<Value>),
}
374
/// Event delivered to the `SessionTransport::prompt` callback.
#[derive(Debug, Clone)]
pub enum SessionTransportEvent {
    /// Typed agent event from the in-process session.
    InProcess(AgentEvent),
    /// Raw JSON event line from the RPC subprocess.
    Rpc(Value),
}
381
/// Session state as reported by `SessionTransport::state`.
#[derive(Debug, Clone, PartialEq)]
pub enum SessionTransportState {
    /// State of the in-process session.
    InProcess(AgentSessionState),
    /// Decoded `get_state` response. Boxed to keep the enum small.
    Rpc(Box<RpcSessionState>),
}
388
/// Model description exchanged over RPC.
///
/// Field names are camelCase on the wire. Fields marked `#[serde(default)]`
/// may be absent in responses.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RpcModelInfo {
    /// Model id.
    pub id: String,
    /// Display name.
    pub name: String,
    /// API identifier string.
    pub api: String,
    /// Provider name.
    pub provider: String,
    /// Base URL (`baseUrl`); empty when absent.
    #[serde(default)]
    pub base_url: String,
    /// Reasoning flag; `false` when absent.
    #[serde(default)]
    pub reasoning: bool,
    /// Supported input types; empty when absent.
    #[serde(default)]
    pub input: Vec<InputType>,
    /// Context window (`contextWindow`); 0 when absent.
    #[serde(default)]
    pub context_window: u32,
    /// Maximum tokens (`maxTokens`); 0 when absent.
    #[serde(default)]
    pub max_tokens: u32,
    /// Pricing, if reported.
    #[serde(default)]
    pub cost: Option<ModelCost>,
}
410
/// Decoded `get_state` response (camelCase on the wire).
///
/// Every field defaults when absent, so partial responses still decode.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
// The RPC state genuinely carries several independent flags.
#[allow(clippy::struct_excessive_bools)]
pub struct RpcSessionState {
    /// Active model, if any.
    #[serde(default)]
    pub model: Option<RpcModelInfo>,
    /// Thinking level, kept as a raw string here.
    /// (`RpcCycleModelResult` decodes it as `ThinkingLevel`.)
    #[serde(default)]
    pub thinking_level: String,
    #[serde(default)]
    pub is_streaming: bool,
    #[serde(default)]
    pub is_compacting: bool,
    /// Steering mode string, as set via `set_steering_mode`.
    #[serde(default)]
    pub steering_mode: String,
    /// Follow-up mode string, as set via `set_follow_up_mode`.
    #[serde(default)]
    pub follow_up_mode: String,
    #[serde(default)]
    pub session_file: Option<String>,
    /// Empty string when absent.
    #[serde(default)]
    pub session_id: String,
    #[serde(default)]
    pub session_name: Option<String>,
    #[serde(default)]
    pub auto_compaction_enabled: bool,
    #[serde(default)]
    pub auto_retry_enabled: bool,
    #[serde(default)]
    pub message_count: usize,
    #[serde(default)]
    pub pending_message_count: usize,
    /// Durability mode string; the accepted values are not visible here.
    #[serde(default)]
    pub durability_mode: String,
}
445
/// Token counters reported inside `RpcSessionStats` (camelCase on the wire).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcTokenStats {
    pub input: u64,
    pub output: u64,
    pub cache_read: u64,
    pub cache_write: u64,
    pub total: u64,
}
456
/// Decoded `get_session_stats` response (camelCase on the wire).
///
/// Only `session_file` is optional; every other field must be present.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RpcSessionStats {
    #[serde(default)]
    pub session_file: Option<String>,
    pub session_id: String,
    pub user_messages: u64,
    pub assistant_messages: u64,
    pub tool_calls: u64,
    pub tool_results: u64,
    pub total_messages: u64,
    pub tokens: RpcTokenStats,
    /// Cost figure as reported by the server (units not stated here).
    pub cost: f64,
}
472
/// Response of `new_session` and `switch_session`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcCancelledResult {
    /// Whether the operation was cancelled rather than performed.
    pub cancelled: bool,
}
478
/// Response of `cycle_model` (camelCase on the wire).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RpcCycleModelResult {
    /// Model selected after cycling.
    pub model: RpcModelInfo,
    /// Thinking level in effect after cycling.
    pub thinking_level: crate::model::ThinkingLevel,
    /// NOTE(review): presumably whether cycling used a scoped model list; confirm.
    pub is_scoped: bool,
}
487
/// Response of `cycle_thinking_level`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcThinkingLevelResult {
    /// Thinking level in effect after cycling.
    pub level: crate::model::ThinkingLevel,
}
493
/// Response of the `bash` RPC command (camelCase on the wire).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcBashResult {
    /// Captured output; see `truncated` / `full_output_path`.
    pub output: String,
    pub exit_code: i32,
    pub cancelled: bool,
    /// Whether `output` was truncated.
    pub truncated: bool,
    /// Location of the full output, when the server provides one.
    pub full_output_path: Option<String>,
}
504
/// Response of `compact` / `compact_with_instructions` (camelCase on the wire).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcCompactionResult {
    pub summary: String,
    pub first_kept_entry_id: String,
    pub tokens_before: u64,
    /// Free-form extra details; `Null` when absent.
    #[serde(default)]
    pub details: Value,
}
515
/// Response of `fork`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcForkResult {
    pub text: String,
    pub cancelled: bool,
}
522
/// One entry of the `get_fork_messages` response (camelCase on the wire).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcForkMessage {
    /// Entry id to pass to `fork`.
    pub entry_id: String,
    pub text: String,
}
530
/// One entry of the `get_commands` response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcCommandInfo {
    pub name: String,
    #[serde(default)]
    pub description: Option<String>,
    /// Origin of the command, as a raw string.
    pub source: String,
    #[serde(default)]
    pub location: Option<String>,
    #[serde(default)]
    pub path: Option<String>,
}
543
/// Response of `export_html`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcExportHtmlResult {
    /// Path of the exported file.
    pub path: String,
}
549
/// Response of `get_last_assistant_text`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcLastAssistantText {
    /// `None` when the server reports no assistant text.
    pub text: Option<String>,
}
555
/// Answer to an extension UI request.
///
/// When serialized on its own, the enum is tagged with `kind`.
/// `RpcTransportClient::extension_ui_response` does not use that form: it
/// writes the `value` / `confirmed` / `cancelled` field directly into the
/// request payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum RpcExtensionUiResponse {
    Value { value: Value },
    Confirmed { confirmed: bool },
    Cancelled,
}
564
/// How to launch the RPC subprocess used by `RpcTransportClient::connect`.
#[derive(Debug, Clone)]
pub struct RpcTransportOptions {
    /// Executable to run (default `pi`).
    pub binary_path: PathBuf,
    /// Arguments passed to it (default `--mode rpc`).
    pub args: Vec<String>,
    /// Working directory for the child. When `None`, the child inherits
    /// this process's directory.
    pub cwd: Option<PathBuf>,
}
572
573impl Default for RpcTransportOptions {
574 fn default() -> Self {
575 Self {
576 binary_path: PathBuf::from("pi"),
577 args: vec!["--mode".to_string(), "rpc".to_string()],
578 cwd: None,
579 }
580 }
581}
582
/// Client for a child process speaking the line-delimited JSON RPC protocol.
///
/// Each request is one JSON object per line on the child's stdin. Responses
/// and events are read one JSON object per line from its stdout. The child is
/// killed on `shutdown` or drop.
pub struct RpcTransportClient {
    child: Child,
    stdin: BufWriter<ChildStdin>,
    stdout: BufReader<ChildStdout>,
    // Numeric part of the next `rpc-N` request id.
    next_request_id: u64,
}
590
/// A session reachable either in-process or through an RPC subprocess.
///
/// Both variants are driven through one API.
pub enum SessionTransport {
    /// In-process session. Boxed to keep the enum small.
    InProcess(Box<AgentSessionHandle>),
    /// Session hosted by a child process.
    RpcSubprocess(RpcTransportClient),
}
596
597impl SessionTransport {
598 pub async fn in_process(options: SessionOptions) -> Result<Self> {
599 create_agent_session(options)
600 .await
601 .map(Box::new)
602 .map(Self::InProcess)
603 }
604
605 pub fn rpc_subprocess(options: RpcTransportOptions) -> Result<Self> {
606 RpcTransportClient::connect(options).map(Self::RpcSubprocess)
607 }
608
609 #[allow(clippy::missing_const_for_fn)]
610 pub fn as_in_process_mut(&mut self) -> Option<&mut AgentSessionHandle> {
611 match self {
612 Self::InProcess(handle) => Some(handle.as_mut()),
613 Self::RpcSubprocess(_) => None,
614 }
615 }
616
617 #[allow(clippy::missing_const_for_fn)]
618 pub fn as_rpc_mut(&mut self) -> Option<&mut RpcTransportClient> {
619 match self {
620 Self::InProcess(_) => None,
621 Self::RpcSubprocess(client) => Some(client),
622 }
623 }
624
625 pub async fn prompt(
630 &mut self,
631 input: impl Into<String>,
632 on_event: impl Fn(SessionTransportEvent) + Send + Sync + 'static,
633 ) -> Result<SessionPromptResult> {
634 let input = input.into();
635 let on_event = Arc::new(on_event);
636 match self {
637 Self::InProcess(handle) => {
638 let on_event = Arc::clone(&on_event);
639 let assistant = handle
640 .prompt(input, move |event| {
641 (on_event)(SessionTransportEvent::InProcess(event));
642 })
643 .await?;
644 Ok(SessionPromptResult::InProcess(assistant))
645 }
646 Self::RpcSubprocess(client) => {
647 let events = client.prompt(input).await?;
648 for event in events.iter().cloned() {
649 (on_event)(SessionTransportEvent::Rpc(event));
650 }
651 Ok(SessionPromptResult::RpcEvents(events))
652 }
653 }
654 }
655
656 pub async fn state(&mut self) -> Result<SessionTransportState> {
658 match self {
659 Self::InProcess(handle) => handle.state().await.map(SessionTransportState::InProcess),
660 Self::RpcSubprocess(client) => client
661 .get_state()
662 .await
663 .map(Box::new)
664 .map(SessionTransportState::Rpc),
665 }
666 }
667
668 pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<()> {
670 match self {
671 Self::InProcess(handle) => handle.set_model(provider, model_id).await,
672 Self::RpcSubprocess(client) => {
673 let _ = client.set_model(provider, model_id).await?;
674 Ok(())
675 }
676 }
677 }
678
679 pub fn shutdown(&mut self) -> Result<()> {
681 match self {
682 Self::InProcess(_) => Ok(()),
683 Self::RpcSubprocess(client) => client.shutdown(),
684 }
685 }
686}
687
688impl RpcTransportClient {
689 pub fn connect(options: RpcTransportOptions) -> Result<Self> {
690 let mut command = Command::new(&options.binary_path);
691 command
692 .args(&options.args)
693 .stdin(Stdio::piped())
694 .stdout(Stdio::piped())
695 .stderr(Stdio::inherit());
696 if let Some(cwd) = options.cwd {
697 command.current_dir(cwd);
698 }
699
700 let mut child = command.spawn().map_err(|err| {
701 Error::config(format!(
702 "Failed to spawn RPC subprocess {}: {err}",
703 options.binary_path.display()
704 ))
705 })?;
706 let stdin = child
707 .stdin
708 .take()
709 .ok_or_else(|| Error::config("RPC subprocess stdin is not piped"))?;
710 let stdout = child
711 .stdout
712 .take()
713 .ok_or_else(|| Error::config("RPC subprocess stdout is not piped"))?;
714
715 Ok(Self {
716 child,
717 stdin: BufWriter::new(stdin),
718 stdout: BufReader::new(stdout),
719 next_request_id: 1,
720 })
721 }
722
723 pub async fn request(&mut self, command: &str, payload: Map<String, Value>) -> Result<Value> {
724 let request_id = self.next_request_id();
725 let mut command_payload = Map::new();
726 command_payload.insert("type".to_string(), Value::String(command.to_string()));
727 command_payload.insert("id".to_string(), Value::String(request_id.clone()));
728 command_payload.extend(payload);
729
730 self.write_json_line(&Value::Object(command_payload))?;
731 self.wait_for_response(&request_id, command)
732 }
733
734 fn parse_response_data<T: DeserializeOwned>(data: Value, command: &str) -> Result<T> {
735 serde_json::from_value(data).map_err(|err| {
736 Error::api(format!(
737 "Failed to decode RPC `{command}` response payload: {err}"
738 ))
739 })
740 }
741
742 async fn request_typed<T: DeserializeOwned>(
743 &mut self,
744 command: &str,
745 payload: Map<String, Value>,
746 ) -> Result<T> {
747 let data = self.request(command, payload).await?;
748 Self::parse_response_data(data, command)
749 }
750
751 async fn request_no_data(&mut self, command: &str, payload: Map<String, Value>) -> Result<()> {
752 let _ = self.request(command, payload).await?;
753 Ok(())
754 }
755
756 pub async fn steer(&mut self, message: impl Into<String>) -> Result<()> {
757 let mut payload = Map::new();
758 payload.insert("message".to_string(), Value::String(message.into()));
759 self.request_no_data("steer", payload).await
760 }
761
762 pub async fn follow_up(&mut self, message: impl Into<String>) -> Result<()> {
763 let mut payload = Map::new();
764 payload.insert("message".to_string(), Value::String(message.into()));
765 self.request_no_data("follow_up", payload).await
766 }
767
768 pub async fn abort(&mut self) -> Result<()> {
769 self.request_no_data("abort", Map::new()).await
770 }
771
772 pub async fn new_session(
773 &mut self,
774 parent_session: Option<&Path>,
775 ) -> Result<RpcCancelledResult> {
776 let mut payload = Map::new();
777 if let Some(parent_session) = parent_session {
778 payload.insert(
779 "parentSession".to_string(),
780 Value::String(parent_session.display().to_string()),
781 );
782 }
783 self.request_typed("new_session", payload).await
784 }
785
786 pub async fn get_state(&mut self) -> Result<RpcSessionState> {
787 self.request_typed("get_state", Map::new()).await
788 }
789
790 pub async fn get_session_stats(&mut self) -> Result<RpcSessionStats> {
791 self.request_typed("get_session_stats", Map::new()).await
792 }
793
794 pub async fn get_messages(&mut self) -> Result<Vec<Value>> {
795 #[derive(Deserialize)]
796 struct MessagesPayload {
797 messages: Vec<Value>,
798 }
799 let payload: MessagesPayload = self.request_typed("get_messages", Map::new()).await?;
800 Ok(payload.messages)
801 }
802
803 pub async fn get_available_models(&mut self) -> Result<Vec<RpcModelInfo>> {
804 #[derive(Deserialize)]
805 struct ModelsPayload {
806 models: Vec<RpcModelInfo>,
807 }
808 let payload: ModelsPayload = self
809 .request_typed("get_available_models", Map::new())
810 .await?;
811 Ok(payload.models)
812 }
813
814 pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<RpcModelInfo> {
815 let mut payload = Map::new();
816 payload.insert("provider".to_string(), Value::String(provider.to_string()));
817 payload.insert("modelId".to_string(), Value::String(model_id.to_string()));
818 self.request_typed("set_model", payload).await
819 }
820
821 pub async fn cycle_model(&mut self) -> Result<Option<RpcCycleModelResult>> {
822 self.request_typed("cycle_model", Map::new()).await
823 }
824
825 pub async fn set_thinking_level(&mut self, level: crate::model::ThinkingLevel) -> Result<()> {
826 let mut payload = Map::new();
827 payload.insert("level".to_string(), Value::String(level.to_string()));
828 self.request_no_data("set_thinking_level", payload).await
829 }
830
831 pub async fn cycle_thinking_level(&mut self) -> Result<Option<RpcThinkingLevelResult>> {
832 self.request_typed("cycle_thinking_level", Map::new()).await
833 }
834
835 pub async fn set_steering_mode(&mut self, mode: &str) -> Result<()> {
836 let mut payload = Map::new();
837 payload.insert("mode".to_string(), Value::String(mode.to_string()));
838 self.request_no_data("set_steering_mode", payload).await
839 }
840
841 pub async fn set_follow_up_mode(&mut self, mode: &str) -> Result<()> {
842 let mut payload = Map::new();
843 payload.insert("mode".to_string(), Value::String(mode.to_string()));
844 self.request_no_data("set_follow_up_mode", payload).await
845 }
846
847 pub async fn set_auto_compaction(&mut self, enabled: bool) -> Result<()> {
848 let mut payload = Map::new();
849 payload.insert("enabled".to_string(), Value::Bool(enabled));
850 self.request_no_data("set_auto_compaction", payload).await
851 }
852
853 pub async fn set_auto_retry(&mut self, enabled: bool) -> Result<()> {
854 let mut payload = Map::new();
855 payload.insert("enabled".to_string(), Value::Bool(enabled));
856 self.request_no_data("set_auto_retry", payload).await
857 }
858
859 pub async fn abort_retry(&mut self) -> Result<()> {
860 self.request_no_data("abort_retry", Map::new()).await
861 }
862
863 pub async fn set_session_name(&mut self, name: impl Into<String>) -> Result<()> {
864 let mut payload = Map::new();
865 payload.insert("name".to_string(), Value::String(name.into()));
866 self.request_no_data("set_session_name", payload).await
867 }
868
869 pub async fn get_last_assistant_text(&mut self) -> Result<Option<String>> {
870 let payload: RpcLastAssistantText = self
871 .request_typed("get_last_assistant_text", Map::new())
872 .await?;
873 Ok(payload.text)
874 }
875
876 pub async fn export_html(&mut self, output_path: Option<&Path>) -> Result<RpcExportHtmlResult> {
877 let mut payload = Map::new();
878 if let Some(path) = output_path {
879 payload.insert(
880 "outputPath".to_string(),
881 Value::String(path.display().to_string()),
882 );
883 }
884 self.request_typed("export_html", payload).await
885 }
886
887 pub async fn bash(&mut self, command: impl Into<String>) -> Result<RpcBashResult> {
888 let mut payload = Map::new();
889 payload.insert("command".to_string(), Value::String(command.into()));
890 self.request_typed("bash", payload).await
891 }
892
893 pub async fn abort_bash(&mut self) -> Result<()> {
894 self.request_no_data("abort_bash", Map::new()).await
895 }
896
897 pub async fn compact(&mut self) -> Result<RpcCompactionResult> {
898 self.compact_with_instructions(None).await
899 }
900
901 pub async fn compact_with_instructions(
902 &mut self,
903 custom_instructions: Option<&str>,
904 ) -> Result<RpcCompactionResult> {
905 let mut payload = Map::new();
906 if let Some(custom_instructions) = custom_instructions {
907 payload.insert(
908 "customInstructions".to_string(),
909 Value::String(custom_instructions.to_string()),
910 );
911 }
912 self.request_typed("compact", payload).await
913 }
914
915 pub async fn switch_session(&mut self, session_path: &Path) -> Result<RpcCancelledResult> {
916 let mut payload = Map::new();
917 payload.insert(
918 "sessionPath".to_string(),
919 Value::String(session_path.display().to_string()),
920 );
921 self.request_typed("switch_session", payload).await
922 }
923
924 pub async fn fork(&mut self, entry_id: impl Into<String>) -> Result<RpcForkResult> {
925 let mut payload = Map::new();
926 payload.insert("entryId".to_string(), Value::String(entry_id.into()));
927 self.request_typed("fork", payload).await
928 }
929
930 pub async fn get_fork_messages(&mut self) -> Result<Vec<RpcForkMessage>> {
931 #[derive(Deserialize)]
932 struct ForkMessagesPayload {
933 messages: Vec<RpcForkMessage>,
934 }
935 let payload: ForkMessagesPayload =
936 self.request_typed("get_fork_messages", Map::new()).await?;
937 Ok(payload.messages)
938 }
939
940 pub async fn get_commands(&mut self) -> Result<Vec<RpcCommandInfo>> {
941 #[derive(Deserialize)]
942 struct CommandsPayload {
943 commands: Vec<RpcCommandInfo>,
944 }
945 let payload: CommandsPayload = self.request_typed("get_commands", Map::new()).await?;
946 Ok(payload.commands)
947 }
948
949 pub async fn extension_ui_response(
950 &mut self,
951 request_id: &str,
952 response: RpcExtensionUiResponse,
953 ) -> Result<bool> {
954 #[derive(Deserialize)]
955 struct ExtensionUiResolvedPayload {
956 resolved: bool,
957 }
958
959 let mut payload = Map::new();
960 payload.insert(
961 "requestId".to_string(),
962 Value::String(request_id.to_string()),
963 );
964
965 match response {
966 RpcExtensionUiResponse::Value { value } => {
967 payload.insert("value".to_string(), value);
968 }
969 RpcExtensionUiResponse::Confirmed { confirmed } => {
970 payload.insert("confirmed".to_string(), Value::Bool(confirmed));
971 }
972 RpcExtensionUiResponse::Cancelled => {
973 payload.insert("cancelled".to_string(), Value::Bool(true));
974 }
975 }
976
977 let response: Option<ExtensionUiResolvedPayload> =
978 self.request_typed("extension_ui_response", payload).await?;
979 Ok(response.is_none_or(|payload| payload.resolved))
980 }
981
982 pub async fn prompt(&mut self, message: impl Into<String>) -> Result<Vec<Value>> {
983 self.prompt_with_options(message, None, None).await
984 }
985
986 pub async fn prompt_with_options(
987 &mut self,
988 message: impl Into<String>,
989 images: Option<Vec<ImageContent>>,
990 streaming_behavior: Option<&str>,
991 ) -> Result<Vec<Value>> {
992 let request_id = self.next_request_id();
993 let mut payload = Map::new();
994 payload.insert("type".to_string(), Value::String("prompt".to_string()));
995 payload.insert("id".to_string(), Value::String(request_id.clone()));
996 payload.insert("message".to_string(), Value::String(message.into()));
997 if let Some(images) = images {
998 payload.insert(
999 "images".to_string(),
1000 serde_json::to_value(images).map_err(|err| Error::Json(Box::new(err)))?,
1001 );
1002 }
1003 if let Some(streaming_behavior) = streaming_behavior {
1004 payload.insert(
1005 "streamingBehavior".to_string(),
1006 Value::String(streaming_behavior.to_string()),
1007 );
1008 }
1009 let payload = Value::Object(payload);
1010 self.write_json_line(&payload)?;
1011
1012 let mut saw_ack = false;
1013 let mut events = Vec::new();
1014 loop {
1015 let item = self.read_json_line()?;
1016 let item_type = item.get("type").and_then(Value::as_str);
1017 if item_type == Some("response") {
1018 if item.get("id").and_then(Value::as_str) != Some(request_id.as_str()) {
1019 continue;
1020 }
1021 let success = item
1022 .get("success")
1023 .and_then(Value::as_bool)
1024 .unwrap_or(false);
1025 if !success {
1026 return Err(rpc_error_from_response(&item, "prompt"));
1027 }
1028 saw_ack = true;
1029 continue;
1030 }
1031
1032 if saw_ack {
1033 let reached_end = item_type == Some("agent_end");
1034 events.push(item);
1035 if reached_end {
1036 return Ok(events);
1037 }
1038 }
1039 }
1040 }
1041
1042 pub fn shutdown(&mut self) -> Result<()> {
1043 if self
1044 .child
1045 .try_wait()
1046 .map_err(|err| Error::Io(Box::new(err)))?
1047 .is_none()
1048 {
1049 self.child.kill().map_err(|err| Error::Io(Box::new(err)))?;
1050 }
1051 let _ = self.child.wait();
1052 Ok(())
1053 }
1054
1055 fn next_request_id(&mut self) -> String {
1056 let id = format!("rpc-{}", self.next_request_id);
1057 self.next_request_id = self.next_request_id.saturating_add(1);
1058 id
1059 }
1060
1061 fn write_json_line(&mut self, payload: &Value) -> Result<()> {
1062 let encoded = serde_json::to_string(payload).map_err(|err| Error::Json(Box::new(err)))?;
1063 self.stdin
1064 .write_all(encoded.as_bytes())
1065 .map_err(|err| Error::Io(Box::new(err)))?;
1066 self.stdin
1067 .write_all(b"\n")
1068 .map_err(|err| Error::Io(Box::new(err)))?;
1069 self.stdin.flush().map_err(|err| Error::Io(Box::new(err)))?;
1070 Ok(())
1071 }
1072
1073 fn read_json_line(&mut self) -> Result<Value> {
1074 let mut line = String::new();
1075 let read = self
1076 .stdout
1077 .read_line(&mut line)
1078 .map_err(|err| Error::Io(Box::new(err)))?;
1079 if read == 0 {
1080 return Err(Error::api(
1081 "RPC subprocess exited before sending a response",
1082 ));
1083 }
1084 serde_json::from_str(line.trim_end()).map_err(|err| Error::Json(Box::new(err)))
1085 }
1086
1087 fn wait_for_response(&mut self, request_id: &str, command: &str) -> Result<Value> {
1088 loop {
1089 let item = self.read_json_line()?;
1090 let Some(item_type) = item.get("type").and_then(Value::as_str) else {
1091 continue;
1092 };
1093 if item_type != "response" {
1094 continue;
1095 }
1096 if item.get("id").and_then(Value::as_str) != Some(request_id) {
1097 continue;
1098 }
1099 if item.get("command").and_then(Value::as_str) != Some(command) {
1100 continue;
1101 }
1102
1103 let success = item
1104 .get("success")
1105 .and_then(Value::as_bool)
1106 .unwrap_or(false);
1107 if success {
1108 return Ok(item.get("data").cloned().unwrap_or(Value::Null));
1109 }
1110 return Err(rpc_error_from_response(&item, command));
1111 }
1112 }
1113}
1114
1115impl Drop for RpcTransportClient {
1116 fn drop(&mut self) {
1117 let _ = self.shutdown();
1118 }
1119}
1120
1121fn rpc_error_from_response(response: &Value, command: &str) -> Error {
1122 let error = response
1123 .get("error")
1124 .and_then(Value::as_str)
1125 .unwrap_or("RPC command failed");
1126 Error::api(format!("RPC {command} failed: {error}"))
1127}
1128
1129impl AgentSessionHandle {
1130 pub const fn from_session_with_listeners(
1135 session: AgentSession,
1136 listeners: EventListeners,
1137 ) -> Self {
1138 Self { session, listeners }
1139 }
1140
1141 pub async fn prompt(
1147 &mut self,
1148 input: impl Into<String>,
1149 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1150 ) -> Result<AssistantMessage> {
1151 let combined = self.make_combined_callback(on_event);
1152 self.session.run_text(input.into(), combined).await
1153 }
1154
1155 pub async fn prompt_with_abort(
1157 &mut self,
1158 input: impl Into<String>,
1159 abort_signal: AbortSignal,
1160 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1161 ) -> Result<AssistantMessage> {
1162 let combined = self.make_combined_callback(on_event);
1163 self.session
1164 .run_text_with_abort(input.into(), Some(abort_signal), combined)
1165 .await
1166 }
1167
1168 pub async fn continue_turn(
1174 &mut self,
1175 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1176 ) -> Result<AssistantMessage> {
1177 let combined = self.make_combined_callback(on_event);
1178 self.session
1179 .sync_runtime_selection_from_session_header()
1180 .await?;
1181 self.session
1182 .agent
1183 .run_continue_with_abort(None, combined)
1184 .await
1185 }
1186
1187 pub async fn continue_turn_with_abort(
1189 &mut self,
1190 abort_signal: AbortSignal,
1191 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1192 ) -> Result<AssistantMessage> {
1193 let combined = self.make_combined_callback(on_event);
1194 self.session
1195 .sync_runtime_selection_from_session_header()
1196 .await?;
1197 self.session
1198 .agent
1199 .run_continue_with_abort(Some(abort_signal), combined)
1200 .await
1201 }
1202
/// Creates a fresh abort pair.
///
/// Pass the `AbortSignal` to `prompt_with_abort` or
/// `continue_turn_with_abort`, and keep the `AbortHandle` to trigger
/// cancellation (see `AbortHandle` for the exact API).
pub fn new_abort_handle() -> (AbortHandle, AbortSignal) {
    AbortHandle::new()
}
1207
1208 pub fn subscribe(
1215 &self,
1216 listener: impl Fn(AgentEvent) + Send + Sync + 'static,
1217 ) -> SubscriptionId {
1218 self.listeners.subscribe(Arc::new(listener))
1219 }
1220
1221 pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
1225 self.listeners.unsubscribe(id)
1226 }
1227
1228 pub const fn listeners(&self) -> &EventListeners {
1230 &self.listeners
1231 }
1232
1233 pub const fn listeners_mut(&mut self) -> &mut EventListeners {
1238 &mut self.listeners
1239 }
1240
1241 pub const fn has_extensions(&self) -> bool {
1247 self.session.extensions.is_some()
1248 }
1249
1250 pub fn extension_manager(&self) -> Option<&ExtensionManager> {
1252 self.session
1253 .extensions
1254 .as_ref()
1255 .map(ExtensionRegion::manager)
1256 }
1257
1258 pub const fn extension_region(&self) -> Option<&ExtensionRegion> {
1262 self.session.extensions.as_ref()
1263 }
1264
1265 pub fn model(&self) -> (String, String) {
1271 let provider = self.session.agent.provider();
1272 (provider.name().to_string(), provider.model_id().to_string())
1273 }
1274
1275 pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<()> {
1277 self.session.set_provider_model(provider, model_id).await
1278 }
1279
1280 pub const fn thinking_level(&self) -> Option<crate::model::ThinkingLevel> {
1282 self.session.agent.stream_options().thinking_level
1283 }
1284
1285 pub const fn thinking(&self) -> Option<crate::model::ThinkingLevel> {
1287 self.thinking_level()
1288 }
1289
1290 pub async fn set_thinking_level(&mut self, level: crate::model::ThinkingLevel) -> Result<()> {
1292 let cx = crate::agent_cx::AgentCx::for_request();
1293 let (effective_level, changed) = {
1294 let mut guard = self
1295 .session
1296 .session
1297 .lock(cx.cx())
1298 .await
1299 .map_err(|e| Error::session(e.to_string()))?;
1300 let (provider_id, model_id) = guard
1301 .effective_model_for_current_path()
1302 .unwrap_or_else(|| self.model());
1303 let effective_level =
1304 self.session
1305 .clamp_thinking_level_for_model(&provider_id, &model_id, level);
1306 let level_string = effective_level.to_string();
1307 let changed = guard.effective_thinking_level_for_current_path().as_deref()
1308 != Some(level_string.as_str());
1309 guard.set_model_header(None, None, Some(level_string.clone()));
1310 if changed {
1311 guard.append_thinking_level_change(level_string);
1312 }
1313 (effective_level, changed)
1314 };
1315 self.session.agent.stream_options_mut().thinking_level = Some(effective_level);
1316 if changed {
1317 self.session.persist_session().await
1318 } else {
1319 Ok(())
1320 }
1321 }
1322
1323 pub async fn messages(&self) -> Result<Vec<Message>> {
1325 let cx = crate::agent_cx::AgentCx::for_request();
1326 let guard = self
1327 .session
1328 .session
1329 .lock(cx.cx())
1330 .await
1331 .map_err(|e| Error::session(e.to_string()))?;
1332 Ok(guard.to_messages_for_current_path())
1333 }
1334
1335 pub async fn state(&self) -> Result<AgentSessionState> {
1337 let (provider, model_id) = self.model();
1338 let thinking_level = self.thinking_level();
1339 let save_enabled = self.session.save_enabled();
1340 let cx = crate::agent_cx::AgentCx::for_request();
1341 let guard = self
1342 .session
1343 .session
1344 .lock(cx.cx())
1345 .await
1346 .map_err(|e| Error::session(e.to_string()))?;
1347 let session_id = Some(guard.header.id.clone());
1348 let message_count = guard.to_messages_for_current_path().len();
1349
1350 Ok(AgentSessionState {
1351 session_id,
1352 provider,
1353 model_id,
1354 thinking_level,
1355 save_enabled,
1356 message_count,
1357 })
1358 }
1359
1360 pub async fn compact(
1362 &mut self,
1363 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1364 ) -> Result<()> {
1365 self.session.compact_now(on_event).await
1366 }
1367
1368 pub const fn session(&self) -> &AgentSession {
1370 &self.session
1371 }
1372
1373 pub const fn session_mut(&mut self) -> &mut AgentSession {
1375 &mut self.session
1376 }
1377
1378 pub fn into_inner(self) -> AgentSession {
1380 self.session
1381 }
1382
1383 fn make_combined_callback(
1386 &self,
1387 per_prompt: impl Fn(AgentEvent) + Send + Sync + 'static,
1388 ) -> impl Fn(AgentEvent) + Send + Sync + 'static {
1389 let listeners = self.listeners.clone();
1390 move |event: AgentEvent| {
1391 match &event {
1393 AgentEvent::ToolExecutionStart {
1394 tool_name, args, ..
1395 } => {
1396 listeners.notify_tool_start(tool_name, args);
1397 }
1398 AgentEvent::ToolExecutionEnd {
1399 tool_name,
1400 result,
1401 is_error,
1402 ..
1403 } => {
1404 listeners.notify_tool_end(tool_name, result, *is_error);
1405 }
1406 AgentEvent::MessageUpdate {
1407 assistant_message_event,
1408 ..
1409 } => {
1410 if let Some(stream_ev) =
1413 stream_event_from_assistant_message_event(assistant_message_event)
1414 {
1415 listeners.notify_stream_event(&stream_ev);
1416 }
1417 }
1418 _ => {}
1419 }
1420
1421 listeners.notify(&event);
1423
1424 per_prompt(event);
1426 }
1427 }
1428}
1429
1430fn stream_event_from_assistant_message_event(
1435 event: &crate::model::AssistantMessageEvent,
1436) -> Option<StreamEvent> {
1437 use crate::model::AssistantMessageEvent as AME;
1438 match event {
1439 AME::TextStart { content_index, .. } => Some(StreamEvent::TextStart {
1440 content_index: *content_index,
1441 }),
1442 AME::TextDelta {
1443 content_index,
1444 delta,
1445 ..
1446 } => Some(StreamEvent::TextDelta {
1447 content_index: *content_index,
1448 delta: delta.clone(),
1449 }),
1450 AME::TextEnd {
1451 content_index,
1452 content,
1453 ..
1454 } => Some(StreamEvent::TextEnd {
1455 content_index: *content_index,
1456 content: content.clone(),
1457 }),
1458 AME::ThinkingStart { content_index, .. } => Some(StreamEvent::ThinkingStart {
1459 content_index: *content_index,
1460 }),
1461 AME::ThinkingDelta {
1462 content_index,
1463 delta,
1464 ..
1465 } => Some(StreamEvent::ThinkingDelta {
1466 content_index: *content_index,
1467 delta: delta.clone(),
1468 }),
1469 AME::ThinkingEnd {
1470 content_index,
1471 content,
1472 ..
1473 } => Some(StreamEvent::ThinkingEnd {
1474 content_index: *content_index,
1475 content: content.clone(),
1476 }),
1477 AME::ToolCallStart { content_index, .. } => Some(StreamEvent::ToolCallStart {
1478 content_index: *content_index,
1479 }),
1480 AME::ToolCallDelta {
1481 content_index,
1482 delta,
1483 ..
1484 } => Some(StreamEvent::ToolCallDelta {
1485 content_index: *content_index,
1486 delta: delta.clone(),
1487 }),
1488 AME::ToolCallEnd {
1489 content_index,
1490 tool_call,
1491 ..
1492 } => Some(StreamEvent::ToolCallEnd {
1493 content_index: *content_index,
1494 tool_call: tool_call.clone(),
1495 }),
1496 AME::Done { reason, message } => Some(StreamEvent::Done {
1497 reason: *reason,
1498 message: (**message).clone(),
1499 }),
1500 AME::Error { reason, error } => Some(StreamEvent::Error {
1501 reason: *reason,
1502 error: (**error).clone(),
1503 }),
1504 AME::Start { .. } => None,
1505 }
1506}
1507
/// Anchors `path` at `cwd` unless it is already absolute.
///
/// Absolute paths are returned unchanged (as an owned `PathBuf`); relative
/// paths are joined onto `cwd`.
fn resolve_path_for_cwd(path: &Path, cwd: &Path) -> PathBuf {
    if path.is_relative() {
        return cwd.join(path);
    }
    path.to_owned()
}
1515
1516fn build_stream_options_with_optional_key(
1517 config: &Config,
1518 api_key: Option<String>,
1519 selection: &app::ModelSelection,
1520 session: &Session,
1521) -> StreamOptions {
1522 let mut options = StreamOptions {
1523 api_key,
1524 headers: selection.model_entry.headers.clone(),
1525 session_id: Some(session.header.id.clone()),
1526 thinking_level: Some(selection.thinking_level),
1527 ..Default::default()
1528 };
1529
1530 if let Some(budgets) = &config.thinking_budgets {
1531 let defaults = ThinkingBudgets::default();
1532 options.thinking_budgets = Some(ThinkingBudgets {
1533 minimal: budgets.minimal.unwrap_or(defaults.minimal),
1534 low: budgets.low.unwrap_or(defaults.low),
1535 medium: budgets.medium.unwrap_or(defaults.medium),
1536 high: budgets.high.unwrap_or(defaults.high),
1537 xhigh: budgets.xhigh.unwrap_or(defaults.xhigh),
1538 });
1539 }
1540
1541 options
1542}
1543
1544#[allow(clippy::too_many_lines)]
1549pub async fn create_agent_session(options: SessionOptions) -> Result<AgentSessionHandle> {
1550 let process_cwd =
1551 std::env::current_dir().map_err(|e| Error::config(format!("cwd lookup failed: {e}")))?;
1552 let cwd = options.working_directory.as_deref().map_or_else(
1553 || process_cwd.clone(),
1554 |path| resolve_path_for_cwd(path, &process_cwd),
1555 );
1556
1557 let mut cli = Cli::try_parse_from(["pi"])
1558 .map_err(|e| Error::validation(format!("CLI init failed: {e}")))?;
1559 cli.no_session = options.no_session;
1560 cli.provider = options.provider.clone();
1561 cli.model = options.model.clone();
1562 cli.api_key = options.api_key.clone();
1563 cli.system_prompt = options.system_prompt.clone();
1564 cli.append_system_prompt = options.append_system_prompt.clone();
1565 cli.hide_cwd_in_prompt = !options.include_cwd_in_prompt;
1566 cli.thinking = options.thinking.map(|t| t.to_string());
1567 cli.session = options
1568 .session_path
1569 .as_ref()
1570 .map(|p| p.to_string_lossy().to_string());
1571 cli.session_dir = options
1572 .session_dir
1573 .as_ref()
1574 .map(|p| p.to_string_lossy().to_string());
1575 if let Some(enabled_tools) = &options.enabled_tools {
1576 if enabled_tools.is_empty() {
1577 cli.no_tools = true;
1578 } else {
1579 cli.no_tools = false;
1580 cli.tools = enabled_tools.join(",");
1581 }
1582 }
1583
1584 let config = Config::load()?;
1585
1586 let mut auth = AuthStorage::load_async(Config::auth_path()).await?;
1587 auth.refresh_expired_oauth_tokens().await?;
1588
1589 let global_dir = Config::global_dir();
1590 let package_dir = Config::package_dir();
1591 let models_path = default_models_path(&global_dir);
1592 let model_registry = ModelRegistry::load(&auth, Some(models_path));
1593
1594 let mut session = Session::new(&cli, &config).await?;
1595 let scoped_patterns = if let Some(models_arg) = &cli.models {
1596 app::parse_models_arg(models_arg)
1597 } else {
1598 config.enabled_models.clone().unwrap_or_default()
1599 };
1600 let scoped_models = if scoped_patterns.is_empty() {
1601 Vec::new()
1602 } else {
1603 app::resolve_model_scope(&scoped_patterns, &model_registry, cli.api_key.is_some())
1604 };
1605
1606 let selection = app::select_model_and_thinking(
1607 &cli,
1608 &config,
1609 &session,
1610 &model_registry,
1611 &scoped_models,
1612 &global_dir,
1613 )
1614 .map_err(|err| Error::validation(err.to_string()))?;
1615 app::update_session_for_selection(&mut session, &selection);
1616
1617 let enabled_tools_owned = cli
1618 .enabled_tools()
1619 .into_iter()
1620 .map(str::to_string)
1621 .collect::<Vec<_>>();
1622 let enabled_tools = enabled_tools_owned
1623 .iter()
1624 .map(String::as_str)
1625 .collect::<Vec<_>>();
1626
1627 let system_prompt = app::build_system_prompt(
1628 &cli,
1629 &cwd,
1630 &enabled_tools,
1631 None,
1632 &global_dir,
1633 &package_dir,
1634 std::env::var_os("PI_TEST_MODE").is_some(),
1635 options.include_cwd_in_prompt,
1636 )
1637 .map_err(|err| Error::validation(err.to_string()))?;
1638
1639 let provider = providers::create_provider(&selection.model_entry, None)
1640 .map_err(|e| Error::provider("sdk", e.to_string()))?;
1641
1642 let api_key = app::resolve_api_key(&auth, &cli, &selection.model_entry)
1643 .map_err(|err| Error::validation(err.to_string()))?;
1644
1645 let stream_options =
1646 build_stream_options_with_optional_key(&config, api_key, &selection, &session);
1647
1648 let agent_config = AgentConfig {
1649 system_prompt: Some(system_prompt),
1650 max_tool_iterations: options.max_tool_iterations,
1651 stream_options,
1652 block_images: config.image_block_images(),
1653 fail_closed_hooks: config.fail_closed_hooks(),
1654 };
1655
1656 let tools = ToolRegistry::new(&enabled_tools, &cwd, Some(&config));
1657 let session_arc = Arc::new(asupersync::sync::Mutex::new(session));
1658
1659 let context_window_tokens = if selection.model_entry.model.context_window == 0 {
1660 ResolvedCompactionSettings::default().context_window_tokens
1661 } else {
1662 selection.model_entry.model.context_window
1663 };
1664 let compaction_settings = ResolvedCompactionSettings {
1665 enabled: config.compaction_enabled(),
1666 reserve_tokens: config.compaction_reserve_tokens(),
1667 keep_recent_tokens: config.compaction_keep_recent_tokens(),
1668 context_window_tokens,
1669 };
1670
1671 let mut agent_session = AgentSession::new(
1672 Agent::new(provider, tools, agent_config),
1673 Arc::clone(&session_arc),
1674 !cli.no_session,
1675 compaction_settings,
1676 );
1677
1678 if !options.extension_paths.is_empty() {
1679 let extension_paths = options
1680 .extension_paths
1681 .iter()
1682 .map(|path| resolve_path_for_cwd(path, &cwd))
1683 .collect::<Vec<_>>();
1684 let resolved_ext_policy =
1685 config.resolve_extension_policy_with_metadata(options.extension_policy.as_deref());
1686 let resolved_repair_policy =
1687 config.resolve_repair_policy_with_metadata(options.repair_policy.as_deref());
1688
1689 agent_session
1690 .enable_extensions_with_policy(
1691 &enabled_tools,
1692 &cwd,
1693 Some(&config),
1694 &extension_paths,
1695 Some(resolved_ext_policy.policy),
1696 Some(resolved_repair_policy.effective_mode),
1697 None,
1698 )
1699 .await?;
1700 }
1701
1702 agent_session.set_model_registry(model_registry.clone());
1703 agent_session.set_auth_storage(auth);
1704
1705 let history = {
1706 let cx = crate::agent_cx::AgentCx::for_request();
1707 let guard = session_arc
1708 .lock(cx.cx())
1709 .await
1710 .map_err(|e| Error::session(e.to_string()))?;
1711 guard.to_messages_for_current_path()
1712 };
1713 if !history.is_empty() {
1714 agent_session.agent.replace_messages(history);
1715 }
1716
1717 let mut listeners = EventListeners::new();
1718 if let Some(on_event) = options.on_event {
1719 listeners.subscribe(on_event);
1720 }
1721 listeners.on_tool_start = options.on_tool_start;
1722 listeners.on_tool_end = options.on_tool_end;
1723 listeners.on_stream_event = options.on_stream_event;
1724
1725 Ok(AgentSessionHandle {
1726 session: agent_session,
1727 listeners,
1728 })
1729}
1730
1731#[cfg(test)]
1732mod tests {
1733 use super::*;
1734 use asupersync::runtime::RuntimeBuilder;
1735 use asupersync::runtime::reactor::create_reactor;
1736 use asupersync::sync::Mutex as AsyncMutex;
1737 use std::sync::{Arc, Mutex};
1738 use tempfile::tempdir;
1739
1740 fn run_async<F>(future: F) -> F::Output
1741 where
1742 F: std::future::Future,
1743 {
1744 let reactor = create_reactor().expect("create reactor");
1745 let runtime = RuntimeBuilder::current_thread()
1746 .with_reactor(reactor)
1747 .build()
1748 .expect("build runtime");
1749 runtime.block_on(future)
1750 }
1751
1752 #[test]
1753 fn create_agent_session_default_succeeds() {
1754 let tmp = tempdir().expect("tempdir");
1755 let options = SessionOptions {
1756 working_directory: Some(tmp.path().to_path_buf()),
1757 no_session: true,
1758 ..SessionOptions::default()
1759 };
1760
1761 let handle = run_async(create_agent_session(options)).expect("create session");
1762 let provider = handle.session().agent.provider();
1763 assert!(!provider.name().is_empty());
1764 assert!(!provider.model_id().is_empty());
1765 assert_eq!(handle.model().0, provider.name());
1766 assert_eq!(handle.model().1, provider.model_id());
1767 }
1768
1769 #[test]
1770 fn create_agent_session_respects_provider_model_and_clamps_thinking() {
1771 let tmp = tempdir().expect("tempdir");
1772 let options = SessionOptions {
1773 provider: Some("openai".to_string()),
1774 model: Some("gpt-4o".to_string()),
1775 api_key: Some("dummy-key".to_string()),
1776 thinking: Some(crate::model::ThinkingLevel::Low),
1777 working_directory: Some(tmp.path().to_path_buf()),
1778 no_session: true,
1779 ..SessionOptions::default()
1780 };
1781
1782 let handle = run_async(create_agent_session(options)).expect("create session");
1783 let provider = handle.session().agent.provider();
1784 assert_eq!(provider.name(), "openai");
1785 assert_eq!(provider.model_id(), "gpt-4o");
1786 assert_eq!(
1787 handle.session().agent.stream_options().thinking_level,
1788 Some(crate::model::ThinkingLevel::Off)
1789 );
1790 }
1791
1792 #[test]
1793 fn create_agent_session_no_session_keeps_ephemeral_state() {
1794 let tmp = tempdir().expect("tempdir");
1795 let options = SessionOptions {
1796 working_directory: Some(tmp.path().to_path_buf()),
1797 no_session: true,
1798 ..SessionOptions::default()
1799 };
1800
1801 let handle = run_async(create_agent_session(options)).expect("create session");
1802 assert!(!handle.session().save_enabled());
1803
1804 let path_is_none = run_async(async {
1805 let cx = crate::agent_cx::AgentCx::for_request();
1806 let guard = handle
1807 .session()
1808 .session
1809 .lock(cx.cx())
1810 .await
1811 .expect("lock session");
1812 guard.path.is_none()
1813 });
1814 assert!(path_is_none);
1815 }
1816
1817 #[test]
1818 fn from_session_with_listeners_set_model_switches_provider_model() {
1819 let dir = tempdir().expect("tempdir");
1820 let auth_path = dir.path().join("auth.json");
1821 let mut auth = AuthStorage::load(auth_path).expect("load auth");
1822 auth.set(
1823 "anthropic",
1824 crate::auth::AuthCredential::ApiKey {
1825 key: "anthropic-key".to_string(),
1826 },
1827 );
1828 auth.set(
1829 "openai",
1830 crate::auth::AuthCredential::ApiKey {
1831 key: "openai-key".to_string(),
1832 },
1833 );
1834
1835 let registry = ModelRegistry::load(&auth, None);
1836 let entry = registry
1837 .find("anthropic", "claude-sonnet-4-5")
1838 .expect("anthropic model in registry");
1839 let provider = providers::create_provider(&entry, None).expect("create anthropic provider");
1840 let tools = crate::tools::ToolRegistry::new(&[], std::path::Path::new("."), None);
1841 let agent = Agent::new(
1842 provider,
1843 tools,
1844 AgentConfig {
1845 system_prompt: None,
1846 max_tool_iterations: 50,
1847 stream_options: StreamOptions::default(),
1848 block_images: false,
1849 fail_closed_hooks: false,
1850 },
1851 );
1852
1853 let mut session = Session::in_memory();
1854 session.header.provider = Some("anthropic".to_string());
1855 session.header.model_id = Some("claude-sonnet-4-5".to_string());
1856
1857 let mut agent_session = AgentSession::new(
1858 agent,
1859 Arc::new(AsyncMutex::new(session)),
1860 false,
1861 ResolvedCompactionSettings::default(),
1862 );
1863 agent_session.set_model_registry(registry);
1864 agent_session.set_auth_storage(auth);
1865
1866 let mut handle =
1867 AgentSessionHandle::from_session_with_listeners(agent_session, EventListeners::new());
1868 run_async(handle.set_model("openai", "gpt-4o")).expect("set model");
1869 let provider = handle.session().agent.provider();
1870 assert_eq!(provider.name(), "openai");
1871 assert_eq!(provider.model_id(), "gpt-4o");
1872 }
1873
1874 #[test]
1875 fn create_agent_session_set_thinking_level_clamps_and_dedupes_history() {
1876 let tmp = tempdir().expect("tempdir");
1877 let options = SessionOptions {
1878 provider: Some("openai".to_string()),
1879 model: Some("gpt-4o".to_string()),
1880 api_key: Some("dummy-key".to_string()),
1881 working_directory: Some(tmp.path().to_path_buf()),
1882 no_session: true,
1883 ..SessionOptions::default()
1884 };
1885
1886 let mut handle = run_async(create_agent_session(options)).expect("create session");
1887 run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
1888 .expect("set thinking");
1889 run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
1890 .expect("reapply thinking");
1891
1892 assert_eq!(
1893 handle.session().agent.stream_options().thinking_level,
1894 Some(crate::model::ThinkingLevel::Off)
1895 );
1896
1897 let thinking_changes = run_async(async {
1898 let cx = crate::agent_cx::AgentCx::for_request();
1899 let guard = handle
1900 .session()
1901 .session
1902 .lock(cx.cx())
1903 .await
1904 .expect("lock session");
1905 assert_eq!(guard.header.thinking_level.as_deref(), Some("off"));
1906 guard
1907 .entries_for_current_path()
1908 .iter()
1909 .filter(|entry| {
1910 matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_))
1911 })
1912 .count()
1913 });
1914 assert_eq!(thinking_changes, 1);
1915 }
1916
1917 #[test]
1918 fn from_session_with_listeners_set_thinking_level_uses_session_header_target() {
1919 let dir = tempdir().expect("tempdir");
1920 let auth_path = dir.path().join("auth.json");
1921 let auth = crate::auth::AuthStorage::load(auth_path).expect("load auth");
1922 let mut registry = ModelRegistry::load(&auth, None);
1923 registry.merge_entries(vec![ModelEntry {
1924 model: Model {
1925 id: "plain-model".to_string(),
1926 name: "Plain Model".to_string(),
1927 api: "openai-completions".to_string(),
1928 provider: "acme".to_string(),
1929 base_url: "https://example.invalid/v1".to_string(),
1930 reasoning: false,
1931 input: vec![InputType::Text],
1932 cost: ModelCost {
1933 input: 0.0,
1934 output: 0.0,
1935 cache_read: 0.0,
1936 cache_write: 0.0,
1937 },
1938 context_window: 128_000,
1939 max_tokens: 8_192,
1940 headers: HashMap::new(),
1941 },
1942 api_key: None,
1943 headers: HashMap::new(),
1944 auth_header: false,
1945 compat: None,
1946 oauth_config: None,
1947 }]);
1948 let entry = registry
1949 .find("anthropic", "claude-sonnet-4-5")
1950 .expect("anthropic model in registry");
1951 let provider = providers::create_provider(&entry, None).expect("create anthropic provider");
1952 let tools = crate::tools::ToolRegistry::new(&[], std::path::Path::new("."), None);
1953 let agent = Agent::new(
1954 provider,
1955 tools,
1956 AgentConfig {
1957 system_prompt: None,
1958 max_tool_iterations: 50,
1959 stream_options: StreamOptions::default(),
1960 block_images: false,
1961 fail_closed_hooks: false,
1962 },
1963 );
1964
1965 let mut session = Session::in_memory();
1966 session.header.provider = Some("acme".to_string());
1967 session.header.model_id = Some("plain-model".to_string());
1968
1969 let mut agent_session = AgentSession::new(
1970 agent,
1971 Arc::new(AsyncMutex::new(session)),
1972 false,
1973 ResolvedCompactionSettings::default(),
1974 );
1975 agent_session.set_model_registry(registry);
1976
1977 let mut handle =
1978 AgentSessionHandle::from_session_with_listeners(agent_session, EventListeners::new());
1979 run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
1980 .expect("set thinking");
1981
1982 assert_eq!(
1983 handle.session().agent.stream_options().thinking_level,
1984 Some(crate::model::ThinkingLevel::Off)
1985 );
1986 assert_eq!(handle.model().0, "anthropic");
1987 assert_eq!(handle.model().1, "claude-sonnet-4-5");
1988 }
1989
1990 #[test]
1991 fn compact_without_history_is_noop() {
1992 let tmp = tempdir().expect("tempdir");
1993 let options = SessionOptions {
1994 working_directory: Some(tmp.path().to_path_buf()),
1995 no_session: true,
1996 ..SessionOptions::default()
1997 };
1998
1999 let mut handle = run_async(create_agent_session(options)).expect("create session");
2000 let events = Arc::new(Mutex::new(Vec::new()));
2001 let events_for_callback = Arc::clone(&events);
2002 run_async(handle.compact(move |event| {
2003 events_for_callback
2004 .lock()
2005 .expect("compact callback lock")
2006 .push(event);
2007 }))
2008 .expect("compact");
2009
2010 assert!(
2011 events
2012 .lock()
2013 .unwrap_or_else(std::sync::PoisonError::into_inner)
2014 .is_empty(),
2015 "expected no compaction lifecycle events for empty session"
2016 );
2017 }
2018
2019 #[test]
2020 fn resolve_path_for_cwd_uses_cwd_for_relative_paths() {
2021 let cwd = Path::new("/tmp/pi-sdk-cwd");
2022 assert_eq!(
2023 resolve_path_for_cwd(Path::new("relative/file.txt"), cwd),
2024 PathBuf::from("/tmp/pi-sdk-cwd/relative/file.txt")
2025 );
2026 assert_eq!(
2027 resolve_path_for_cwd(Path::new("/etc/hosts"), cwd),
2028 PathBuf::from("/etc/hosts")
2029 );
2030 }
2031
2032 #[test]
2037 fn event_listeners_subscribe_and_notify() {
2038 let listeners = EventListeners::new();
2039 let received = Arc::new(Mutex::new(Vec::new()));
2040
2041 let recv_clone = Arc::clone(&received);
2042 let id = listeners.subscribe(Arc::new(move |event| {
2043 recv_clone
2044 .lock()
2045 .unwrap_or_else(std::sync::PoisonError::into_inner)
2046 .push(event);
2047 }));
2048
2049 let event = AgentEvent::AgentStart {
2050 session_id: "test-123".into(),
2051 };
2052 listeners.notify(&event);
2053
2054 let events = received
2055 .lock()
2056 .unwrap_or_else(std::sync::PoisonError::into_inner);
2057 assert_eq!(events.len(), 1);
2058
2059 drop(events);
2061 assert!(listeners.unsubscribe(id));
2062 listeners.notify(&AgentEvent::AgentStart {
2063 session_id: "test-456".into(),
2064 });
2065 assert_eq!(
2066 received
2067 .lock()
2068 .unwrap_or_else(std::sync::PoisonError::into_inner)
2069 .len(),
2070 1
2071 );
2072 }
2073
2074 #[test]
2075 fn event_listeners_unsubscribe_nonexistent_returns_false() {
2076 let listeners = EventListeners::new();
2077 assert!(!listeners.unsubscribe(SubscriptionId(999)));
2078 }
2079
2080 #[test]
2081 fn event_listeners_multiple_subscribers() {
2082 let listeners = EventListeners::new();
2083 let count_a = Arc::new(Mutex::new(0u32));
2084 let count_b = Arc::new(Mutex::new(0u32));
2085
2086 let ca = Arc::clone(&count_a);
2087 listeners.subscribe(Arc::new(move |_| {
2088 *ca.lock().unwrap_or_else(std::sync::PoisonError::into_inner) += 1;
2089 }));
2090
2091 let cb = Arc::clone(&count_b);
2092 listeners.subscribe(Arc::new(move |_| {
2093 *cb.lock().unwrap_or_else(std::sync::PoisonError::into_inner) += 1;
2094 }));
2095
2096 listeners.notify(&AgentEvent::AgentStart {
2097 session_id: "s".into(),
2098 });
2099
2100 assert_eq!(
2101 *count_a
2102 .lock()
2103 .unwrap_or_else(std::sync::PoisonError::into_inner),
2104 1
2105 );
2106 assert_eq!(
2107 *count_b
2108 .lock()
2109 .unwrap_or_else(std::sync::PoisonError::into_inner),
2110 1
2111 );
2112 }
2113
2114 #[test]
2115 fn event_listeners_tool_hooks_fire() {
2116 let listeners = EventListeners::new();
2117 let starts = Arc::new(Mutex::new(Vec::new()));
2118 let ends = Arc::new(Mutex::new(Vec::new()));
2119
2120 let s = Arc::clone(&starts);
2121 let mut listeners = listeners;
2122 listeners.on_tool_start = Some(Arc::new(move |name, args| {
2123 s.lock()
2124 .expect("lock")
2125 .push((name.to_string(), args.clone()));
2126 }));
2127
2128 let e = Arc::clone(&ends);
2129 listeners.on_tool_end = Some(Arc::new(move |name, _output, is_error| {
2130 e.lock()
2131 .unwrap_or_else(std::sync::PoisonError::into_inner)
2132 .push((name.to_string(), is_error));
2133 }));
2134
2135 let args = serde_json::json!({"path": "/foo"});
2136 listeners.notify_tool_start("bash", &args);
2137 let output = ToolOutput {
2138 content: vec![ContentBlock::Text(TextContent::new("ok"))],
2139 details: None,
2140 is_error: false,
2141 };
2142 listeners.notify_tool_end("bash", &output, false);
2143
2144 {
2145 let s = starts
2146 .lock()
2147 .unwrap_or_else(std::sync::PoisonError::into_inner);
2148 assert_eq!(s.len(), 1);
2149 assert_eq!(s[0].0, "bash");
2150 drop(s);
2151 }
2152
2153 {
2154 let e = ends
2155 .lock()
2156 .unwrap_or_else(std::sync::PoisonError::into_inner);
2157 assert_eq!(e.len(), 1);
2158 assert_eq!(e[0].0, "bash");
2159 assert!(!e[0].1);
2160 drop(e);
2161 }
2162 }
2163
2164 #[test]
2165 fn event_listeners_stream_event_hook_fires() {
2166 let mut listeners = EventListeners::new();
2167 let received = Arc::new(Mutex::new(Vec::new()));
2168
2169 let r = Arc::clone(&received);
2170 listeners.on_stream_event = Some(Arc::new(move |ev| {
2171 r.lock()
2172 .unwrap_or_else(std::sync::PoisonError::into_inner)
2173 .push(format!("{ev:?}"));
2174 }));
2175
2176 let event = StreamEvent::TextDelta {
2177 content_index: 0,
2178 delta: "hello".to_string(),
2179 };
2180 listeners.notify_stream_event(&event);
2181
2182 assert_eq!(
2183 received
2184 .lock()
2185 .unwrap_or_else(std::sync::PoisonError::into_inner)
2186 .len(),
2187 1
2188 );
2189 }
2190
2191 #[test]
2192 fn session_options_on_event_wired_into_listeners() {
2193 let received = Arc::new(Mutex::new(Vec::new()));
2194 let r = Arc::clone(&received);
2195 let tmp = tempdir().expect("tempdir");
2196
2197 let options = SessionOptions {
2198 working_directory: Some(tmp.path().to_path_buf()),
2199 no_session: true,
2200 on_event: Some(Arc::new(move |event| {
2201 r.lock()
2202 .unwrap_or_else(std::sync::PoisonError::into_inner)
2203 .push(format!("{event:?}"));
2204 })),
2205 ..SessionOptions::default()
2206 };
2207
2208 let handle = run_async(create_agent_session(options)).expect("create session");
2209 let count = handle
2211 .listeners()
2212 .subscribers
2213 .lock()
2214 .unwrap_or_else(std::sync::PoisonError::into_inner)
2215 .len();
2216 assert_eq!(
2217 count, 1,
2218 "on_event from SessionOptions should register one subscriber"
2219 );
2220 }
2221
2222 #[test]
2223 fn subscribe_unsubscribe_on_handle() {
2224 let tmp = tempdir().expect("tempdir");
2225 let options = SessionOptions {
2226 working_directory: Some(tmp.path().to_path_buf()),
2227 no_session: true,
2228 ..SessionOptions::default()
2229 };
2230
2231 let handle = run_async(create_agent_session(options)).expect("create session");
2232 let id = handle.subscribe(|_event| {});
2233 assert_eq!(
2234 handle
2235 .listeners()
2236 .subscribers
2237 .lock()
2238 .unwrap_or_else(std::sync::PoisonError::into_inner)
2239 .len(),
2240 1
2241 );
2242
2243 assert!(handle.unsubscribe(id));
2244 assert_eq!(
2245 handle
2246 .listeners()
2247 .subscribers
2248 .lock()
2249 .unwrap_or_else(std::sync::PoisonError::into_inner)
2250 .len(),
2251 0
2252 );
2253
2254 assert!(!handle.unsubscribe(id));
2256 }
2257
2258 #[test]
2259 fn stream_event_from_assistant_message_event_converts_text_delta() {
2260 use crate::model::AssistantMessageEvent as AME;
2261
2262 let partial = Arc::new(AssistantMessage {
2263 content: Vec::new(),
2264 api: String::new(),
2265 provider: String::new(),
2266 model: String::new(),
2267 usage: Usage::default(),
2268 stop_reason: StopReason::Stop,
2269 error_message: None,
2270 timestamp: 0,
2271 });
2272 let ame = AME::TextDelta {
2273 content_index: 2,
2274 delta: "chunk".to_string(),
2275 partial,
2276 };
2277 let result = stream_event_from_assistant_message_event(&ame);
2278 assert!(result.is_some());
2279 match result.unwrap() {
2280 StreamEvent::TextDelta {
2281 content_index,
2282 delta,
2283 } => {
2284 assert_eq!(content_index, 2);
2285 assert_eq!(delta, "chunk");
2286 }
2287 other => unreachable!("expected TextDelta, got {other:?}"),
2288 }
2289 }
2290
2291 #[test]
2292 fn stream_event_from_assistant_message_event_start_returns_none() {
2293 use crate::model::AssistantMessageEvent as AME;
2294
2295 let partial = Arc::new(AssistantMessage {
2296 content: Vec::new(),
2297 api: String::new(),
2298 provider: String::new(),
2299 model: String::new(),
2300 usage: Usage::default(),
2301 stop_reason: StopReason::Stop,
2302 error_message: None,
2303 timestamp: 0,
2304 });
2305 let ame = AME::Start { partial };
2306 assert!(stream_event_from_assistant_message_event(&ame).is_none());
2307 }
2308
2309 #[test]
2310 fn event_listeners_debug_impl() {
2311 let listeners = EventListeners::new();
2312 let debug = format!("{listeners:?}");
2313 assert!(debug.contains("subscriber_count"));
2314 assert!(debug.contains("has_on_tool_start"));
2315 }
2316
2317 #[test]
2322 fn has_extensions_false_by_default() {
2323 let tmp = tempdir().expect("tempdir");
2324 let options = SessionOptions {
2325 working_directory: Some(tmp.path().to_path_buf()),
2326 no_session: true,
2327 ..SessionOptions::default()
2328 };
2329
2330 let handle = run_async(create_agent_session(options)).expect("create session");
2331 assert!(
2332 !handle.has_extensions(),
2333 "session without extension_paths should have no extensions"
2334 );
2335 assert!(handle.extension_manager().is_none());
2336 assert!(handle.extension_region().is_none());
2337 }
2338
2339 #[test]
2344 fn create_read_tool_has_correct_name() {
2345 let tmp = tempdir().expect("tempdir");
2346 let tool = super::create_read_tool(tmp.path());
2347 assert_eq!(tool.name(), "read");
2348 assert!(!tool.description().is_empty());
2349 let params = tool.parameters();
2350 assert!(params.is_object(), "parameters should be a JSON object");
2351 }
2352
2353 #[test]
2354 fn create_bash_tool_has_correct_name() {
2355 let tmp = tempdir().expect("tempdir");
2356 let tool = super::create_bash_tool(tmp.path());
2357 assert_eq!(tool.name(), "bash");
2358 assert!(!tool.description().is_empty());
2359 }
2360
2361 #[test]
2362 fn create_edit_tool_has_correct_name() {
2363 let tmp = tempdir().expect("tempdir");
2364 let tool = super::create_edit_tool(tmp.path());
2365 assert_eq!(tool.name(), "edit");
2366 }
2367
2368 #[test]
2369 fn create_write_tool_has_correct_name() {
2370 let tmp = tempdir().expect("tempdir");
2371 let tool = super::create_write_tool(tmp.path());
2372 assert_eq!(tool.name(), "write");
2373 }
2374
2375 #[test]
2376 fn create_grep_tool_has_correct_name() {
2377 let tmp = tempdir().expect("tempdir");
2378 let tool = super::create_grep_tool(tmp.path());
2379 assert_eq!(tool.name(), "grep");
2380 }
2381
2382 #[test]
2383 fn create_find_tool_has_correct_name() {
2384 let tmp = tempdir().expect("tempdir");
2385 let tool = super::create_find_tool(tmp.path());
2386 assert_eq!(tool.name(), "find");
2387 }
2388
2389 #[test]
2390 fn create_ls_tool_has_correct_name() {
2391 let tmp = tempdir().expect("tempdir");
2392 let tool = super::create_ls_tool(tmp.path());
2393 assert_eq!(tool.name(), "ls");
2394 }
2395
2396 #[test]
2397 fn create_all_tools_returns_eight() {
2398 let tmp = tempdir().expect("tempdir");
2399 let tools = super::create_all_tools(tmp.path());
2400 assert_eq!(tools.len(), 8, "should create all 8 built-in tools");
2401
2402 let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
2403 for expected in BUILTIN_TOOL_NAMES {
2404 assert!(names.contains(expected), "missing tool: {expected}");
2405 }
2406 }
2407
2408 #[test]
2409 fn tool_to_definition_preserves_schema() {
2410 let tmp = tempdir().expect("tempdir");
2411 let tool = super::create_read_tool(tmp.path());
2412 let def = super::tool_to_definition(tool.as_ref());
2413 assert_eq!(def.name, "read");
2414 assert!(!def.description.is_empty());
2415 assert!(def.parameters.is_object());
2416 assert!(
2417 def.parameters.get("properties").is_some(),
2418 "schema should have properties"
2419 );
2420 }
2421
2422 #[test]
2423 fn all_tool_definitions_returns_eight_schemas() {
2424 let tmp = tempdir().expect("tempdir");
2425 let defs = super::all_tool_definitions(tmp.path());
2426 assert_eq!(defs.len(), 8);
2427
2428 for def in &defs {
2429 assert!(!def.name.is_empty());
2430 assert!(!def.description.is_empty());
2431 assert!(def.parameters.is_object());
2432 }
2433 }
2434
2435 #[test]
2436 fn builtin_tool_names_matches_create_all() {
2437 let tmp = tempdir().expect("tempdir");
2438 let tools = super::create_all_tools(tmp.path());
2439 let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
2440 assert_eq!(
2441 names.as_slice(),
2442 BUILTIN_TOOL_NAMES,
2443 "create_all_tools order should match BUILTIN_TOOL_NAMES"
2444 );
2445 }
2446
2447 #[test]
2448 fn tool_registry_from_factory_tools() {
2449 let tmp = tempdir().expect("tempdir");
2450 let tools = super::create_all_tools(tmp.path());
2451 let registry = ToolRegistry::from_tools(tools);
2452 assert!(registry.get("read").is_some());
2453 assert!(registry.get("bash").is_some());
2454 assert!(registry.get("nonexistent").is_none());
2455 }
2456}