1use crate::app;
23use crate::auth::AuthStorage;
24use crate::cli::Cli;
25use crate::compaction::ResolvedCompactionSettings;
26use crate::models::default_models_path;
27use crate::provider::ThinkingBudgets;
28use crate::providers;
29use clap::Parser;
30use serde::{Deserialize, Serialize, de::DeserializeOwned};
31use serde_json::{Map, Value};
32use std::collections::HashMap;
33use std::io::{BufRead, BufReader, BufWriter, Write};
34use std::path::{Path, PathBuf};
35use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
36use std::sync::Arc;
37use std::sync::atomic::{AtomicU64, Ordering};
38
39pub use crate::agent::{
40 AbortHandle, AbortSignal, Agent, AgentConfig, AgentEvent, AgentSession, QueueMode,
41};
42pub use crate::config::Config;
43pub use crate::error::{Error, Result};
44pub use crate::extensions::{ExtensionManager, ExtensionPolicy, ExtensionRegion};
45pub use crate::model::{
46 AssistantMessage, ContentBlock, Cost, CustomMessage, ImageContent, Message, StopReason,
47 StreamEvent, TextContent, ThinkingContent, ToolCall, ToolResultMessage, Usage, UserContent,
48 UserMessage,
49};
50pub use crate::models::{ModelEntry, ModelRegistry};
51pub use crate::provider::{
52 Context as ProviderContext, InputType, Model, ModelCost, Provider, StreamOptions,
53 ThinkingBudgets as ProviderThinkingBudgets, ToolDef,
54};
55pub use crate::session::Session;
56pub use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
57
58pub type ToolDefinition = ToolDef;
60
61use crate::tools::{BashTool, EditTool, FindTool, GrepTool, LsTool, ReadTool, WriteTool};
66
/// Names of the built-in tools, listed in the same order `create_all_tools` builds them.
pub const BUILTIN_TOOL_NAMES: &[&str] = &[
    "read", //
    "bash", //
    "edit", //
    "write", //
    "grep", //
    "find", //
    "ls",
];
69
70pub fn create_read_tool(cwd: &Path) -> Box<dyn Tool> {
72 Box::new(ReadTool::new(cwd))
73}
74
75pub fn create_bash_tool(cwd: &Path) -> Box<dyn Tool> {
77 Box::new(BashTool::new(cwd))
78}
79
80pub fn create_edit_tool(cwd: &Path) -> Box<dyn Tool> {
82 Box::new(EditTool::new(cwd))
83}
84
85pub fn create_write_tool(cwd: &Path) -> Box<dyn Tool> {
87 Box::new(WriteTool::new(cwd))
88}
89
90pub fn create_grep_tool(cwd: &Path) -> Box<dyn Tool> {
92 Box::new(GrepTool::new(cwd))
93}
94
95pub fn create_find_tool(cwd: &Path) -> Box<dyn Tool> {
97 Box::new(FindTool::new(cwd))
98}
99
100pub fn create_ls_tool(cwd: &Path) -> Box<dyn Tool> {
102 Box::new(LsTool::new(cwd))
103}
104
105pub fn create_all_tools(cwd: &Path) -> Vec<Box<dyn Tool>> {
107 vec![
108 create_read_tool(cwd),
109 create_bash_tool(cwd),
110 create_edit_tool(cwd),
111 create_write_tool(cwd),
112 create_grep_tool(cwd),
113 create_find_tool(cwd),
114 create_ls_tool(cwd),
115 ]
116}
117
118pub fn tool_to_definition(tool: &dyn Tool) -> ToolDefinition {
120 ToolDefinition {
121 name: tool.name().to_string(),
122 description: tool.description().to_string(),
123 parameters: tool.parameters(),
124 }
125}
126
127pub fn all_tool_definitions(cwd: &Path) -> Vec<ToolDefinition> {
129 create_all_tools(cwd)
130 .iter()
131 .map(|t| tool_to_definition(t.as_ref()))
132 .collect()
133}
134
/// Opaque token returned by `EventListeners::subscribe`; pass it to `unsubscribe` to remove
/// the listener again.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct SubscriptionId(u64);
145
146pub type OnToolStart = Arc<dyn Fn(&str, &Value) + Send + Sync>;
150
151pub type OnToolEnd = Arc<dyn Fn(&str, &ToolOutput, bool) + Send + Sync>;
155
156pub type OnStreamEvent = Arc<dyn Fn(&StreamEvent) + Send + Sync>;
161
162pub type EventSubscriber = Arc<dyn Fn(AgentEvent) + Send + Sync>;
163type EventSubscribers = HashMap<SubscriptionId, EventSubscriber>;
164
165#[derive(Clone, Default)]
171pub struct EventListeners {
172 next_id: Arc<AtomicU64>,
173 subscribers: Arc<std::sync::Mutex<EventSubscribers>>,
174 pub on_tool_start: Option<OnToolStart>,
175 pub on_tool_end: Option<OnToolEnd>,
176 pub on_stream_event: Option<OnStreamEvent>,
177}
178
179impl EventListeners {
180 fn new() -> Self {
181 Self {
182 next_id: Arc::new(AtomicU64::new(1)),
183 subscribers: Arc::new(std::sync::Mutex::new(HashMap::new())),
184 on_tool_start: None,
185 on_tool_end: None,
186 on_stream_event: None,
187 }
188 }
189
190 pub fn subscribe(&self, listener: EventSubscriber) -> SubscriptionId {
192 let id = SubscriptionId(self.next_id.fetch_add(1, Ordering::Relaxed));
193 self.subscribers
194 .lock()
195 .expect("EventListeners lock poisoned")
196 .insert(id, listener);
197 id
198 }
199
200 pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
202 self.subscribers
203 .lock()
204 .expect("EventListeners lock poisoned")
205 .remove(&id)
206 .is_some()
207 }
208
209 pub fn notify(&self, event: &AgentEvent) {
211 let subs = self
212 .subscribers
213 .lock()
214 .expect("EventListeners lock poisoned");
215 for listener in subs.values() {
216 listener(event.clone());
217 }
218 }
219
220 pub fn notify_tool_start(&self, tool_name: &str, args: &Value) {
222 if let Some(cb) = &self.on_tool_start {
223 cb(tool_name, args);
224 }
225 }
226
227 pub fn notify_tool_end(&self, tool_name: &str, output: &ToolOutput, is_error: bool) {
229 if let Some(cb) = &self.on_tool_end {
230 cb(tool_name, output, is_error);
231 }
232 }
233
234 pub fn notify_stream_event(&self, event: &StreamEvent) {
236 if let Some(cb) = &self.on_stream_event {
237 cb(event);
238 }
239 }
240}
241
242impl std::fmt::Debug for EventListeners {
243 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 let count = self.subscribers.lock().map_or(0, |s| s.len());
245 let next_id = self.next_id.load(Ordering::Relaxed);
246 f.debug_struct("EventListeners")
247 .field("subscriber_count", &count)
248 .field("next_id", &next_id)
249 .field("has_on_tool_start", &self.on_tool_start.is_some())
250 .field("has_on_tool_end", &self.on_tool_end.is_some())
251 .field("has_on_stream_event", &self.on_stream_event.is_some())
252 .finish()
253 }
254}
255
256#[derive(Clone)]
261pub struct SessionOptions {
262 pub provider: Option<String>,
263 pub model: Option<String>,
264 pub api_key: Option<String>,
265 pub thinking: Option<crate::model::ThinkingLevel>,
266 pub system_prompt: Option<String>,
267 pub append_system_prompt: Option<String>,
268 pub enabled_tools: Option<Vec<String>>,
269 pub working_directory: Option<PathBuf>,
270 pub no_session: bool,
271 pub session_path: Option<PathBuf>,
272 pub session_dir: Option<PathBuf>,
273 pub extension_paths: Vec<PathBuf>,
274 pub extension_policy: Option<String>,
275 pub repair_policy: Option<String>,
276 pub max_tool_iterations: usize,
277
278 pub on_event: Option<Arc<dyn Fn(AgentEvent) + Send + Sync>>,
283
284 pub on_tool_start: Option<OnToolStart>,
286
287 pub on_tool_end: Option<OnToolEnd>,
289
290 pub on_stream_event: Option<OnStreamEvent>,
292}
293
294impl Default for SessionOptions {
295 fn default() -> Self {
296 Self {
297 provider: None,
298 model: None,
299 api_key: None,
300 thinking: None,
301 system_prompt: None,
302 append_system_prompt: None,
303 enabled_tools: None,
304 working_directory: None,
305 no_session: true,
306 session_path: None,
307 session_dir: None,
308 extension_paths: Vec::new(),
309 extension_policy: None,
310 repair_policy: None,
311 max_tool_iterations: 50,
312 on_event: None,
313 on_tool_start: None,
314 on_tool_end: None,
315 on_stream_event: None,
316 }
317 }
318}
319
320pub struct AgentSessionHandle {
329 session: AgentSession,
330 listeners: EventListeners,
331}
332
333#[derive(Debug, Clone, PartialEq, Eq)]
335pub struct AgentSessionState {
336 pub session_id: Option<String>,
337 pub provider: String,
338 pub model_id: String,
339 pub thinking_level: Option<crate::model::ThinkingLevel>,
340 pub save_enabled: bool,
341 pub message_count: usize,
342}
343
344#[derive(Debug, Clone)]
346pub enum SessionPromptResult {
347 InProcess(AssistantMessage),
348 RpcEvents(Vec<Value>),
349}
350
351#[derive(Debug, Clone)]
353pub enum SessionTransportEvent {
354 InProcess(AgentEvent),
355 Rpc(Value),
356}
357
358#[derive(Debug, Clone, PartialEq)]
360pub enum SessionTransportState {
361 InProcess(AgentSessionState),
362 Rpc(Box<RpcSessionState>),
363}
364
365#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
367#[serde(rename_all = "camelCase")]
368pub struct RpcModelInfo {
369 pub id: String,
370 pub name: String,
371 pub api: String,
372 pub provider: String,
373 #[serde(default)]
374 pub base_url: String,
375 #[serde(default)]
376 pub reasoning: bool,
377 #[serde(default)]
378 pub input: Vec<InputType>,
379 #[serde(default)]
380 pub context_window: u32,
381 #[serde(default)]
382 pub max_tokens: u32,
383 #[serde(default)]
384 pub cost: Option<ModelCost>,
385}
386
387#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
389#[serde(rename_all = "camelCase")]
390pub struct RpcSessionState {
391 #[serde(default)]
392 pub model: Option<RpcModelInfo>,
393 #[serde(default)]
394 pub thinking_level: String,
395 #[serde(default)]
396 pub is_streaming: bool,
397 #[serde(default)]
398 pub is_compacting: bool,
399 #[serde(default)]
400 pub steering_mode: String,
401 #[serde(default)]
402 pub follow_up_mode: String,
403 #[serde(default)]
404 pub session_file: Option<String>,
405 #[serde(default)]
406 pub session_id: String,
407 #[serde(default)]
408 pub session_name: Option<String>,
409 #[serde(default)]
410 pub auto_compaction_enabled: bool,
411 #[serde(default)]
412 pub message_count: usize,
413 #[serde(default)]
414 pub pending_message_count: usize,
415}
416
417#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
419#[serde(rename_all = "camelCase")]
420pub struct RpcTokenStats {
421 pub input: u64,
422 pub output: u64,
423 pub cache_read: u64,
424 pub cache_write: u64,
425 pub total: u64,
426}
427
428#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
430#[serde(rename_all = "camelCase")]
431pub struct RpcSessionStats {
432 #[serde(default)]
433 pub session_file: Option<String>,
434 pub session_id: String,
435 pub user_messages: u64,
436 pub assistant_messages: u64,
437 pub tool_calls: u64,
438 pub tool_results: u64,
439 pub total_messages: u64,
440 pub tokens: RpcTokenStats,
441 pub cost: f64,
442}
443
444#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
446pub struct RpcCancelledResult {
447 pub cancelled: bool,
448}
449
450#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
452#[serde(rename_all = "camelCase")]
453pub struct RpcCycleModelResult {
454 pub model: RpcModelInfo,
455 pub thinking_level: crate::model::ThinkingLevel,
456 pub is_scoped: bool,
457}
458
459#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
461pub struct RpcThinkingLevelResult {
462 pub level: crate::model::ThinkingLevel,
463}
464
465#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
467#[serde(rename_all = "camelCase")]
468pub struct RpcBashResult {
469 pub output: String,
470 pub exit_code: i32,
471 pub cancelled: bool,
472 pub truncated: bool,
473 pub full_output_path: Option<String>,
474}
475
476#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
478#[serde(rename_all = "camelCase")]
479pub struct RpcCompactionResult {
480 pub summary: String,
481 pub first_kept_entry_id: String,
482 pub tokens_before: u64,
483 #[serde(default)]
484 pub details: Value,
485}
486
487#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
489pub struct RpcForkResult {
490 pub text: String,
491 pub cancelled: bool,
492}
493
494#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
496#[serde(rename_all = "camelCase")]
497pub struct RpcForkMessage {
498 pub entry_id: String,
499 pub text: String,
500}
501
502#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
504pub struct RpcCommandInfo {
505 pub name: String,
506 #[serde(default)]
507 pub description: Option<String>,
508 pub source: String,
509 #[serde(default)]
510 pub location: Option<String>,
511 #[serde(default)]
512 pub path: Option<String>,
513}
514
515#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
517pub struct RpcExportHtmlResult {
518 pub path: String,
519}
520
521#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
523pub struct RpcLastAssistantText {
524 pub text: Option<String>,
525}
526
527#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
529#[serde(tag = "kind", rename_all = "snake_case")]
530pub enum RpcExtensionUiResponse {
531 Value { value: Value },
532 Confirmed { confirmed: bool },
533 Cancelled,
534}
535
/// How to launch the RPC subprocess used by `RpcTransportClient::connect`.
#[derive(Debug, Clone)]
pub struct RpcTransportOptions {
    /// Executable to spawn.
    pub binary_path: PathBuf,
    /// Command-line arguments passed to the executable.
    pub args: Vec<String>,
    /// Working directory for the child; `None` leaves it unset (inherited).
    pub cwd: Option<PathBuf>,
}
543
544impl Default for RpcTransportOptions {
545 fn default() -> Self {
546 Self {
547 binary_path: PathBuf::from("pi"),
548 args: vec!["--mode".to_string(), "rpc".to_string()],
549 cwd: None,
550 }
551 }
552}
553
/// Client for the line-delimited JSON RPC protocol spoken by a spawned subprocess.
///
/// The child is killed and reaped on drop.
pub struct RpcTransportClient {
    /// The spawned subprocess.
    child: Child,
    /// Buffered writer over the child's stdin; flushed after every request line.
    stdin: BufWriter<ChildStdin>,
    /// Buffered reader over the child's stdout, read one JSON line at a time.
    stdout: BufReader<ChildStdout>,
    /// Counter used to build `rpc-<n>` request ids.
    next_request_id: u64,
}
561
562pub enum SessionTransport {
564 InProcess(Box<AgentSessionHandle>),
565 RpcSubprocess(RpcTransportClient),
566}
567
568impl SessionTransport {
569 pub async fn in_process(options: SessionOptions) -> Result<Self> {
570 create_agent_session(options)
571 .await
572 .map(Box::new)
573 .map(Self::InProcess)
574 }
575
576 pub fn rpc_subprocess(options: RpcTransportOptions) -> Result<Self> {
577 RpcTransportClient::connect(options).map(Self::RpcSubprocess)
578 }
579
580 #[allow(clippy::missing_const_for_fn)]
581 pub fn as_in_process_mut(&mut self) -> Option<&mut AgentSessionHandle> {
582 match self {
583 Self::InProcess(handle) => Some(handle.as_mut()),
584 Self::RpcSubprocess(_) => None,
585 }
586 }
587
588 #[allow(clippy::missing_const_for_fn)]
589 pub fn as_rpc_mut(&mut self) -> Option<&mut RpcTransportClient> {
590 match self {
591 Self::InProcess(_) => None,
592 Self::RpcSubprocess(client) => Some(client),
593 }
594 }
595
596 pub async fn prompt(
601 &mut self,
602 input: impl Into<String>,
603 on_event: impl Fn(SessionTransportEvent) + Send + Sync + 'static,
604 ) -> Result<SessionPromptResult> {
605 let input = input.into();
606 let on_event = Arc::new(on_event);
607 match self {
608 Self::InProcess(handle) => {
609 let on_event = Arc::clone(&on_event);
610 let assistant = handle
611 .prompt(input, move |event| {
612 (on_event)(SessionTransportEvent::InProcess(event));
613 })
614 .await?;
615 Ok(SessionPromptResult::InProcess(assistant))
616 }
617 Self::RpcSubprocess(client) => {
618 let events = client.prompt(input).await?;
619 for event in events.iter().cloned() {
620 (on_event)(SessionTransportEvent::Rpc(event));
621 }
622 Ok(SessionPromptResult::RpcEvents(events))
623 }
624 }
625 }
626
627 pub async fn state(&mut self) -> Result<SessionTransportState> {
629 match self {
630 Self::InProcess(handle) => handle.state().await.map(SessionTransportState::InProcess),
631 Self::RpcSubprocess(client) => client
632 .get_state()
633 .await
634 .map(Box::new)
635 .map(SessionTransportState::Rpc),
636 }
637 }
638
639 pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<()> {
641 match self {
642 Self::InProcess(handle) => handle.set_model(provider, model_id).await,
643 Self::RpcSubprocess(client) => {
644 let _ = client.set_model(provider, model_id).await?;
645 Ok(())
646 }
647 }
648 }
649
650 pub fn shutdown(&mut self) -> Result<()> {
652 match self {
653 Self::InProcess(_) => Ok(()),
654 Self::RpcSubprocess(client) => client.shutdown(),
655 }
656 }
657}
658
659impl RpcTransportClient {
660 pub fn connect(options: RpcTransportOptions) -> Result<Self> {
661 let mut command = Command::new(&options.binary_path);
662 command
663 .args(&options.args)
664 .stdin(Stdio::piped())
665 .stdout(Stdio::piped())
666 .stderr(Stdio::inherit());
667 if let Some(cwd) = options.cwd {
668 command.current_dir(cwd);
669 }
670
671 let mut child = command.spawn().map_err(|err| {
672 Error::config(format!(
673 "Failed to spawn RPC subprocess {}: {err}",
674 options.binary_path.display()
675 ))
676 })?;
677 let stdin = child
678 .stdin
679 .take()
680 .ok_or_else(|| Error::config("RPC subprocess stdin is not piped"))?;
681 let stdout = child
682 .stdout
683 .take()
684 .ok_or_else(|| Error::config("RPC subprocess stdout is not piped"))?;
685
686 Ok(Self {
687 child,
688 stdin: BufWriter::new(stdin),
689 stdout: BufReader::new(stdout),
690 next_request_id: 1,
691 })
692 }
693
694 pub async fn request(&mut self, command: &str, payload: Map<String, Value>) -> Result<Value> {
695 let request_id = self.next_request_id();
696 let mut command_payload = Map::new();
697 command_payload.insert("type".to_string(), Value::String(command.to_string()));
698 command_payload.insert("id".to_string(), Value::String(request_id.clone()));
699 command_payload.extend(payload);
700
701 self.write_json_line(&Value::Object(command_payload))?;
702 self.wait_for_response(&request_id, command)
703 }
704
705 fn parse_response_data<T: DeserializeOwned>(data: Value, command: &str) -> Result<T> {
706 serde_json::from_value(data).map_err(|err| {
707 Error::api(format!(
708 "Failed to decode RPC `{command}` response payload: {err}"
709 ))
710 })
711 }
712
713 async fn request_typed<T: DeserializeOwned>(
714 &mut self,
715 command: &str,
716 payload: Map<String, Value>,
717 ) -> Result<T> {
718 let data = self.request(command, payload).await?;
719 Self::parse_response_data(data, command)
720 }
721
722 async fn request_no_data(&mut self, command: &str, payload: Map<String, Value>) -> Result<()> {
723 let _ = self.request(command, payload).await?;
724 Ok(())
725 }
726
727 pub async fn steer(&mut self, message: impl Into<String>) -> Result<()> {
728 let mut payload = Map::new();
729 payload.insert("message".to_string(), Value::String(message.into()));
730 self.request_no_data("steer", payload).await
731 }
732
733 pub async fn follow_up(&mut self, message: impl Into<String>) -> Result<()> {
734 let mut payload = Map::new();
735 payload.insert("message".to_string(), Value::String(message.into()));
736 self.request_no_data("follow_up", payload).await
737 }
738
739 pub async fn abort(&mut self) -> Result<()> {
740 self.request_no_data("abort", Map::new()).await
741 }
742
743 pub async fn new_session(
744 &mut self,
745 parent_session: Option<&Path>,
746 ) -> Result<RpcCancelledResult> {
747 let mut payload = Map::new();
748 if let Some(parent_session) = parent_session {
749 payload.insert(
750 "parentSession".to_string(),
751 Value::String(parent_session.display().to_string()),
752 );
753 }
754 self.request_typed("new_session", payload).await
755 }
756
757 pub async fn get_state(&mut self) -> Result<RpcSessionState> {
758 self.request_typed("get_state", Map::new()).await
759 }
760
761 pub async fn get_session_stats(&mut self) -> Result<RpcSessionStats> {
762 self.request_typed("get_session_stats", Map::new()).await
763 }
764
765 pub async fn get_messages(&mut self) -> Result<Vec<Value>> {
766 #[derive(Deserialize)]
767 struct MessagesPayload {
768 messages: Vec<Value>,
769 }
770 let payload: MessagesPayload = self.request_typed("get_messages", Map::new()).await?;
771 Ok(payload.messages)
772 }
773
774 pub async fn get_available_models(&mut self) -> Result<Vec<RpcModelInfo>> {
775 #[derive(Deserialize)]
776 struct ModelsPayload {
777 models: Vec<RpcModelInfo>,
778 }
779 let payload: ModelsPayload = self
780 .request_typed("get_available_models", Map::new())
781 .await?;
782 Ok(payload.models)
783 }
784
785 pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<RpcModelInfo> {
786 let mut payload = Map::new();
787 payload.insert("provider".to_string(), Value::String(provider.to_string()));
788 payload.insert("modelId".to_string(), Value::String(model_id.to_string()));
789 self.request_typed("set_model", payload).await
790 }
791
792 pub async fn cycle_model(&mut self) -> Result<Option<RpcCycleModelResult>> {
793 self.request_typed("cycle_model", Map::new()).await
794 }
795
796 pub async fn set_thinking_level(&mut self, level: crate::model::ThinkingLevel) -> Result<()> {
797 let mut payload = Map::new();
798 payload.insert("level".to_string(), Value::String(level.to_string()));
799 self.request_no_data("set_thinking_level", payload).await
800 }
801
802 pub async fn cycle_thinking_level(&mut self) -> Result<Option<RpcThinkingLevelResult>> {
803 self.request_typed("cycle_thinking_level", Map::new()).await
804 }
805
806 pub async fn set_steering_mode(&mut self, mode: &str) -> Result<()> {
807 let mut payload = Map::new();
808 payload.insert("mode".to_string(), Value::String(mode.to_string()));
809 self.request_no_data("set_steering_mode", payload).await
810 }
811
812 pub async fn set_follow_up_mode(&mut self, mode: &str) -> Result<()> {
813 let mut payload = Map::new();
814 payload.insert("mode".to_string(), Value::String(mode.to_string()));
815 self.request_no_data("set_follow_up_mode", payload).await
816 }
817
818 pub async fn set_auto_compaction(&mut self, enabled: bool) -> Result<()> {
819 let mut payload = Map::new();
820 payload.insert("enabled".to_string(), Value::Bool(enabled));
821 self.request_no_data("set_auto_compaction", payload).await
822 }
823
824 pub async fn set_auto_retry(&mut self, enabled: bool) -> Result<()> {
825 let mut payload = Map::new();
826 payload.insert("enabled".to_string(), Value::Bool(enabled));
827 self.request_no_data("set_auto_retry", payload).await
828 }
829
830 pub async fn abort_retry(&mut self) -> Result<()> {
831 self.request_no_data("abort_retry", Map::new()).await
832 }
833
834 pub async fn set_session_name(&mut self, name: impl Into<String>) -> Result<()> {
835 let mut payload = Map::new();
836 payload.insert("name".to_string(), Value::String(name.into()));
837 self.request_no_data("set_session_name", payload).await
838 }
839
840 pub async fn get_last_assistant_text(&mut self) -> Result<Option<String>> {
841 let payload: RpcLastAssistantText = self
842 .request_typed("get_last_assistant_text", Map::new())
843 .await?;
844 Ok(payload.text)
845 }
846
847 pub async fn export_html(&mut self, output_path: Option<&Path>) -> Result<RpcExportHtmlResult> {
848 let mut payload = Map::new();
849 if let Some(path) = output_path {
850 payload.insert(
851 "outputPath".to_string(),
852 Value::String(path.display().to_string()),
853 );
854 }
855 self.request_typed("export_html", payload).await
856 }
857
858 pub async fn bash(&mut self, command: impl Into<String>) -> Result<RpcBashResult> {
859 let mut payload = Map::new();
860 payload.insert("command".to_string(), Value::String(command.into()));
861 self.request_typed("bash", payload).await
862 }
863
864 pub async fn abort_bash(&mut self) -> Result<()> {
865 self.request_no_data("abort_bash", Map::new()).await
866 }
867
868 pub async fn compact(&mut self) -> Result<RpcCompactionResult> {
869 self.compact_with_instructions(None).await
870 }
871
872 pub async fn compact_with_instructions(
873 &mut self,
874 custom_instructions: Option<&str>,
875 ) -> Result<RpcCompactionResult> {
876 let mut payload = Map::new();
877 if let Some(custom_instructions) = custom_instructions {
878 payload.insert(
879 "customInstructions".to_string(),
880 Value::String(custom_instructions.to_string()),
881 );
882 }
883 self.request_typed("compact", payload).await
884 }
885
886 pub async fn switch_session(&mut self, session_path: &Path) -> Result<RpcCancelledResult> {
887 let mut payload = Map::new();
888 payload.insert(
889 "sessionPath".to_string(),
890 Value::String(session_path.display().to_string()),
891 );
892 self.request_typed("switch_session", payload).await
893 }
894
895 pub async fn fork(&mut self, entry_id: impl Into<String>) -> Result<RpcForkResult> {
896 let mut payload = Map::new();
897 payload.insert("entryId".to_string(), Value::String(entry_id.into()));
898 self.request_typed("fork", payload).await
899 }
900
901 pub async fn get_fork_messages(&mut self) -> Result<Vec<RpcForkMessage>> {
902 #[derive(Deserialize)]
903 struct ForkMessagesPayload {
904 messages: Vec<RpcForkMessage>,
905 }
906 let payload: ForkMessagesPayload =
907 self.request_typed("get_fork_messages", Map::new()).await?;
908 Ok(payload.messages)
909 }
910
911 pub async fn get_commands(&mut self) -> Result<Vec<RpcCommandInfo>> {
912 #[derive(Deserialize)]
913 struct CommandsPayload {
914 commands: Vec<RpcCommandInfo>,
915 }
916 let payload: CommandsPayload = self.request_typed("get_commands", Map::new()).await?;
917 Ok(payload.commands)
918 }
919
920 pub async fn extension_ui_response(
921 &mut self,
922 request_id: &str,
923 response: RpcExtensionUiResponse,
924 ) -> Result<bool> {
925 #[derive(Deserialize)]
926 struct ExtensionUiResolvedPayload {
927 resolved: bool,
928 }
929
930 let mut payload = Map::new();
931 payload.insert(
932 "requestId".to_string(),
933 Value::String(request_id.to_string()),
934 );
935
936 match response {
937 RpcExtensionUiResponse::Value { value } => {
938 payload.insert("value".to_string(), value);
939 }
940 RpcExtensionUiResponse::Confirmed { confirmed } => {
941 payload.insert("confirmed".to_string(), Value::Bool(confirmed));
942 }
943 RpcExtensionUiResponse::Cancelled => {
944 payload.insert("cancelled".to_string(), Value::Bool(true));
945 }
946 }
947
948 let response: Option<ExtensionUiResolvedPayload> =
949 self.request_typed("extension_ui_response", payload).await?;
950 Ok(response.is_none_or(|payload| payload.resolved))
951 }
952
953 pub async fn prompt(&mut self, message: impl Into<String>) -> Result<Vec<Value>> {
954 self.prompt_with_options(message, None, None).await
955 }
956
957 pub async fn prompt_with_options(
958 &mut self,
959 message: impl Into<String>,
960 images: Option<Vec<ImageContent>>,
961 streaming_behavior: Option<&str>,
962 ) -> Result<Vec<Value>> {
963 let request_id = self.next_request_id();
964 let mut payload = Map::new();
965 payload.insert("type".to_string(), Value::String("prompt".to_string()));
966 payload.insert("id".to_string(), Value::String(request_id.clone()));
967 payload.insert("message".to_string(), Value::String(message.into()));
968 if let Some(images) = images {
969 payload.insert(
970 "images".to_string(),
971 serde_json::to_value(images).map_err(|err| Error::Json(Box::new(err)))?,
972 );
973 }
974 if let Some(streaming_behavior) = streaming_behavior {
975 payload.insert(
976 "streamingBehavior".to_string(),
977 Value::String(streaming_behavior.to_string()),
978 );
979 }
980 let payload = Value::Object(payload);
981 self.write_json_line(&payload)?;
982
983 let mut saw_ack = false;
984 let mut events = Vec::new();
985 loop {
986 let item = self.read_json_line()?;
987 let item_type = item.get("type").and_then(Value::as_str);
988 if item_type == Some("response") {
989 if item.get("id").and_then(Value::as_str) != Some(request_id.as_str()) {
990 continue;
991 }
992 let success = item
993 .get("success")
994 .and_then(Value::as_bool)
995 .unwrap_or(false);
996 if !success {
997 return Err(rpc_error_from_response(&item, "prompt"));
998 }
999 saw_ack = true;
1000 continue;
1001 }
1002
1003 if saw_ack {
1004 let reached_end = item_type == Some("agent_end");
1005 events.push(item);
1006 if reached_end {
1007 return Ok(events);
1008 }
1009 }
1010 }
1011 }
1012
1013 pub fn shutdown(&mut self) -> Result<()> {
1014 if self
1015 .child
1016 .try_wait()
1017 .map_err(|err| Error::Io(Box::new(err)))?
1018 .is_none()
1019 {
1020 self.child.kill().map_err(|err| Error::Io(Box::new(err)))?;
1021 }
1022 let _ = self.child.wait();
1023 Ok(())
1024 }
1025
1026 fn next_request_id(&mut self) -> String {
1027 let id = format!("rpc-{}", self.next_request_id);
1028 self.next_request_id = self.next_request_id.saturating_add(1);
1029 id
1030 }
1031
1032 fn write_json_line(&mut self, payload: &Value) -> Result<()> {
1033 let encoded = serde_json::to_string(payload).map_err(|err| Error::Json(Box::new(err)))?;
1034 self.stdin
1035 .write_all(encoded.as_bytes())
1036 .map_err(|err| Error::Io(Box::new(err)))?;
1037 self.stdin
1038 .write_all(b"\n")
1039 .map_err(|err| Error::Io(Box::new(err)))?;
1040 self.stdin.flush().map_err(|err| Error::Io(Box::new(err)))?;
1041 Ok(())
1042 }
1043
1044 fn read_json_line(&mut self) -> Result<Value> {
1045 let mut line = String::new();
1046 let read = self
1047 .stdout
1048 .read_line(&mut line)
1049 .map_err(|err| Error::Io(Box::new(err)))?;
1050 if read == 0 {
1051 return Err(Error::api(
1052 "RPC subprocess exited before sending a response",
1053 ));
1054 }
1055 serde_json::from_str(line.trim_end()).map_err(|err| Error::Json(Box::new(err)))
1056 }
1057
1058 fn wait_for_response(&mut self, request_id: &str, command: &str) -> Result<Value> {
1059 loop {
1060 let item = self.read_json_line()?;
1061 let Some(item_type) = item.get("type").and_then(Value::as_str) else {
1062 continue;
1063 };
1064 if item_type != "response" {
1065 continue;
1066 }
1067 if item.get("id").and_then(Value::as_str) != Some(request_id) {
1068 continue;
1069 }
1070 if item.get("command").and_then(Value::as_str) != Some(command) {
1071 continue;
1072 }
1073
1074 let success = item
1075 .get("success")
1076 .and_then(Value::as_bool)
1077 .unwrap_or(false);
1078 if success {
1079 return Ok(item.get("data").cloned().unwrap_or(Value::Null));
1080 }
1081 return Err(rpc_error_from_response(&item, command));
1082 }
1083 }
1084}
1085
1086impl Drop for RpcTransportClient {
1087 fn drop(&mut self) {
1088 let _ = self.shutdown();
1089 }
1090}
1091
1092fn rpc_error_from_response(response: &Value, command: &str) -> Error {
1093 let error = response
1094 .get("error")
1095 .and_then(Value::as_str)
1096 .unwrap_or("RPC command failed");
1097 Error::api(format!("RPC {command} failed: {error}"))
1098}
1099
1100impl AgentSessionHandle {
1101 pub const fn from_session_with_listeners(
1106 session: AgentSession,
1107 listeners: EventListeners,
1108 ) -> Self {
1109 Self { session, listeners }
1110 }
1111
1112 pub async fn prompt(
1118 &mut self,
1119 input: impl Into<String>,
1120 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1121 ) -> Result<AssistantMessage> {
1122 let combined = self.make_combined_callback(on_event);
1123 self.session.run_text(input.into(), combined).await
1124 }
1125
1126 pub async fn prompt_with_abort(
1128 &mut self,
1129 input: impl Into<String>,
1130 abort_signal: AbortSignal,
1131 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1132 ) -> Result<AssistantMessage> {
1133 let combined = self.make_combined_callback(on_event);
1134 self.session
1135 .run_text_with_abort(input.into(), Some(abort_signal), combined)
1136 .await
1137 }
1138
1139 pub fn new_abort_handle() -> (AbortHandle, AbortSignal) {
1141 AbortHandle::new()
1142 }
1143
1144 pub fn subscribe(
1151 &self,
1152 listener: impl Fn(AgentEvent) + Send + Sync + 'static,
1153 ) -> SubscriptionId {
1154 self.listeners.subscribe(Arc::new(listener))
1155 }
1156
1157 pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
1161 self.listeners.unsubscribe(id)
1162 }
1163
1164 pub const fn listeners(&self) -> &EventListeners {
1166 &self.listeners
1167 }
1168
1169 pub const fn listeners_mut(&mut self) -> &mut EventListeners {
1174 &mut self.listeners
1175 }
1176
1177 pub const fn has_extensions(&self) -> bool {
1183 self.session.extensions.is_some()
1184 }
1185
1186 pub fn extension_manager(&self) -> Option<&ExtensionManager> {
1188 self.session
1189 .extensions
1190 .as_ref()
1191 .map(ExtensionRegion::manager)
1192 }
1193
1194 pub const fn extension_region(&self) -> Option<&ExtensionRegion> {
1198 self.session.extensions.as_ref()
1199 }
1200
1201 pub fn model(&self) -> (String, String) {
1207 let provider = self.session.agent.provider();
1208 (provider.name().to_string(), provider.model_id().to_string())
1209 }
1210
1211 pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<()> {
1213 self.session.set_provider_model(provider, model_id).await
1214 }
1215
1216 pub const fn thinking_level(&self) -> Option<crate::model::ThinkingLevel> {
1218 self.session.agent.stream_options().thinking_level
1219 }
1220
1221 pub const fn thinking(&self) -> Option<crate::model::ThinkingLevel> {
1223 self.thinking_level()
1224 }
1225
1226 pub async fn set_thinking_level(&mut self, level: crate::model::ThinkingLevel) -> Result<()> {
1228 let level_string = level.to_string();
1229 let cx = crate::agent_cx::AgentCx::for_request();
1230 {
1231 let mut guard = self
1232 .session
1233 .session
1234 .lock(cx.cx())
1235 .await
1236 .map_err(|e| Error::session(e.to_string()))?;
1237 guard.set_model_header(None, None, Some(level_string.clone()));
1238 guard.append_thinking_level_change(level_string);
1239 }
1240 self.session.agent.stream_options_mut().thinking_level = Some(level);
1241 self.session.persist_session().await
1242 }
1243
1244 pub async fn messages(&self) -> Result<Vec<Message>> {
1246 let cx = crate::agent_cx::AgentCx::for_request();
1247 let guard = self
1248 .session
1249 .session
1250 .lock(cx.cx())
1251 .await
1252 .map_err(|e| Error::session(e.to_string()))?;
1253 Ok(guard.to_messages_for_current_path())
1254 }
1255
1256 pub async fn state(&self) -> Result<AgentSessionState> {
1258 let (provider, model_id) = self.model();
1259 let thinking_level = self.thinking_level();
1260 let save_enabled = self.session.save_enabled();
1261 let cx = crate::agent_cx::AgentCx::for_request();
1262 let guard = self
1263 .session
1264 .session
1265 .lock(cx.cx())
1266 .await
1267 .map_err(|e| Error::session(e.to_string()))?;
1268 let session_id = Some(guard.header.id.clone());
1269 let message_count = guard.to_messages_for_current_path().len();
1270
1271 Ok(AgentSessionState {
1272 session_id,
1273 provider,
1274 model_id,
1275 thinking_level,
1276 save_enabled,
1277 message_count,
1278 })
1279 }
1280
1281 pub async fn compact(
1283 &mut self,
1284 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1285 ) -> Result<()> {
1286 self.session.compact_now(on_event).await
1287 }
1288
1289 pub const fn session(&self) -> &AgentSession {
1291 &self.session
1292 }
1293
1294 pub const fn session_mut(&mut self) -> &mut AgentSession {
1296 &mut self.session
1297 }
1298
1299 pub fn into_inner(self) -> AgentSession {
1301 self.session
1302 }
1303
1304 fn make_combined_callback(
1307 &self,
1308 per_prompt: impl Fn(AgentEvent) + Send + Sync + 'static,
1309 ) -> impl Fn(AgentEvent) + Send + Sync + 'static {
1310 let listeners = self.listeners.clone();
1311 move |event: AgentEvent| {
1312 match &event {
1314 AgentEvent::ToolExecutionStart {
1315 tool_name, args, ..
1316 } => {
1317 listeners.notify_tool_start(tool_name, args);
1318 }
1319 AgentEvent::ToolExecutionEnd {
1320 tool_name,
1321 result,
1322 is_error,
1323 ..
1324 } => {
1325 listeners.notify_tool_end(tool_name, result, *is_error);
1326 }
1327 AgentEvent::MessageUpdate {
1328 assistant_message_event,
1329 ..
1330 } => {
1331 if let Some(stream_ev) =
1334 stream_event_from_assistant_message_event(assistant_message_event)
1335 {
1336 listeners.notify_stream_event(&stream_ev);
1337 }
1338 }
1339 _ => {}
1340 }
1341
1342 listeners.notify(&event);
1344
1345 per_prompt(event);
1347 }
1348 }
1349}
1350
1351fn stream_event_from_assistant_message_event(
1356 event: &crate::model::AssistantMessageEvent,
1357) -> Option<StreamEvent> {
1358 use crate::model::AssistantMessageEvent as AME;
1359 match event {
1360 AME::TextStart { content_index, .. } => Some(StreamEvent::TextStart {
1361 content_index: *content_index,
1362 }),
1363 AME::TextDelta {
1364 content_index,
1365 delta,
1366 ..
1367 } => Some(StreamEvent::TextDelta {
1368 content_index: *content_index,
1369 delta: delta.clone(),
1370 }),
1371 AME::TextEnd {
1372 content_index,
1373 content,
1374 ..
1375 } => Some(StreamEvent::TextEnd {
1376 content_index: *content_index,
1377 content: content.clone(),
1378 }),
1379 AME::ThinkingStart { content_index, .. } => Some(StreamEvent::ThinkingStart {
1380 content_index: *content_index,
1381 }),
1382 AME::ThinkingDelta {
1383 content_index,
1384 delta,
1385 ..
1386 } => Some(StreamEvent::ThinkingDelta {
1387 content_index: *content_index,
1388 delta: delta.clone(),
1389 }),
1390 AME::ThinkingEnd {
1391 content_index,
1392 content,
1393 ..
1394 } => Some(StreamEvent::ThinkingEnd {
1395 content_index: *content_index,
1396 content: content.clone(),
1397 }),
1398 AME::ToolCallStart { content_index, .. } => Some(StreamEvent::ToolCallStart {
1399 content_index: *content_index,
1400 }),
1401 AME::ToolCallDelta {
1402 content_index,
1403 delta,
1404 ..
1405 } => Some(StreamEvent::ToolCallDelta {
1406 content_index: *content_index,
1407 delta: delta.clone(),
1408 }),
1409 AME::ToolCallEnd {
1410 content_index,
1411 tool_call,
1412 ..
1413 } => Some(StreamEvent::ToolCallEnd {
1414 content_index: *content_index,
1415 tool_call: tool_call.clone(),
1416 }),
1417 AME::Done { reason, message } => Some(StreamEvent::Done {
1418 reason: *reason,
1419 message: (**message).clone(),
1420 }),
1421 AME::Error { reason, error } => Some(StreamEvent::Error {
1422 reason: *reason,
1423 error: (**error).clone(),
1424 }),
1425 AME::Start { .. } => None,
1426 }
1427}
1428
/// Interpret `path` relative to `cwd`.
///
/// Absolute paths are returned unchanged; any other path is joined onto
/// `cwd`. This is purely lexical: the filesystem is not touched and no
/// normalization (e.g. of `..`) takes place.
fn resolve_path_for_cwd(path: &Path, cwd: &Path) -> PathBuf {
    if !path.is_absolute() {
        return cwd.join(path);
    }
    path.to_path_buf()
}
1436
1437fn build_stream_options_with_optional_key(
1438 config: &Config,
1439 api_key: Option<String>,
1440 selection: &app::ModelSelection,
1441 session: &Session,
1442) -> StreamOptions {
1443 let mut options = StreamOptions {
1444 api_key,
1445 headers: selection.model_entry.headers.clone(),
1446 session_id: Some(session.header.id.clone()),
1447 thinking_level: Some(selection.thinking_level),
1448 ..Default::default()
1449 };
1450
1451 if let Some(budgets) = &config.thinking_budgets {
1452 let defaults = ThinkingBudgets::default();
1453 options.thinking_budgets = Some(ThinkingBudgets {
1454 minimal: budgets.minimal.unwrap_or(defaults.minimal),
1455 low: budgets.low.unwrap_or(defaults.low),
1456 medium: budgets.medium.unwrap_or(defaults.medium),
1457 high: budgets.high.unwrap_or(defaults.high),
1458 xhigh: budgets.xhigh.unwrap_or(defaults.xhigh),
1459 });
1460 }
1461
1462 options
1463}
1464
1465#[allow(clippy::too_many_lines)]
1470pub async fn create_agent_session(options: SessionOptions) -> Result<AgentSessionHandle> {
1471 let process_cwd =
1472 std::env::current_dir().map_err(|e| Error::config(format!("cwd lookup failed: {e}")))?;
1473 let cwd = options.working_directory.as_deref().map_or_else(
1474 || process_cwd.clone(),
1475 |path| resolve_path_for_cwd(path, &process_cwd),
1476 );
1477
1478 let mut cli = Cli::try_parse_from(["pi"])
1479 .map_err(|e| Error::validation(format!("CLI init failed: {e}")))?;
1480 cli.no_session = options.no_session;
1481 cli.provider = options.provider.clone();
1482 cli.model = options.model.clone();
1483 cli.api_key = options.api_key.clone();
1484 cli.system_prompt = options.system_prompt.clone();
1485 cli.append_system_prompt = options.append_system_prompt.clone();
1486 cli.thinking = options.thinking.map(|t| t.to_string());
1487 cli.session = options
1488 .session_path
1489 .as_ref()
1490 .map(|p| p.to_string_lossy().to_string());
1491 cli.session_dir = options
1492 .session_dir
1493 .as_ref()
1494 .map(|p| p.to_string_lossy().to_string());
1495 if let Some(enabled_tools) = &options.enabled_tools {
1496 if enabled_tools.is_empty() {
1497 cli.no_tools = true;
1498 } else {
1499 cli.no_tools = false;
1500 cli.tools = enabled_tools.join(",");
1501 }
1502 }
1503
1504 let config = Config::load()?;
1505
1506 let mut auth = AuthStorage::load_async(Config::auth_path()).await?;
1507 auth.refresh_expired_oauth_tokens().await?;
1508
1509 let global_dir = Config::global_dir();
1510 let package_dir = Config::package_dir();
1511 let models_path = default_models_path(&global_dir);
1512 let model_registry = ModelRegistry::load(&auth, Some(models_path));
1513
1514 let mut session = Session::new(&cli, &config).await?;
1515 let scoped_patterns = if let Some(models_arg) = &cli.models {
1516 app::parse_models_arg(models_arg)
1517 } else {
1518 config.enabled_models.clone().unwrap_or_default()
1519 };
1520 let scoped_models = if scoped_patterns.is_empty() {
1521 Vec::new()
1522 } else {
1523 app::resolve_model_scope(&scoped_patterns, &model_registry, cli.api_key.is_some())
1524 };
1525
1526 let selection = app::select_model_and_thinking(
1527 &cli,
1528 &config,
1529 &session,
1530 &model_registry,
1531 &scoped_models,
1532 &global_dir,
1533 )
1534 .map_err(|err| Error::validation(err.to_string()))?;
1535 app::update_session_for_selection(&mut session, &selection);
1536
1537 let enabled_tools_owned = cli
1538 .enabled_tools()
1539 .into_iter()
1540 .map(str::to_string)
1541 .collect::<Vec<_>>();
1542 let enabled_tools = enabled_tools_owned
1543 .iter()
1544 .map(String::as_str)
1545 .collect::<Vec<_>>();
1546
1547 let system_prompt = app::build_system_prompt(
1548 &cli,
1549 &cwd,
1550 &enabled_tools,
1551 None,
1552 &global_dir,
1553 &package_dir,
1554 std::env::var_os("PI_TEST_MODE").is_some(),
1555 );
1556
1557 let provider = providers::create_provider(&selection.model_entry, None)
1558 .map_err(|e| Error::provider("sdk", e.to_string()))?;
1559
1560 let api_key = auth
1561 .resolve_api_key(
1562 &selection.model_entry.model.provider,
1563 cli.api_key.as_deref(),
1564 )
1565 .or_else(|| selection.model_entry.api_key.clone());
1566
1567 let stream_options =
1568 build_stream_options_with_optional_key(&config, api_key, &selection, &session);
1569
1570 let agent_config = AgentConfig {
1571 system_prompt: Some(system_prompt),
1572 max_tool_iterations: options.max_tool_iterations,
1573 stream_options,
1574 block_images: config.image_block_images(),
1575 };
1576
1577 let tools = ToolRegistry::new(&enabled_tools, &cwd, Some(&config));
1578 let session_arc = Arc::new(asupersync::sync::Mutex::new(session));
1579
1580 let context_window_tokens = if selection.model_entry.model.context_window == 0 {
1581 ResolvedCompactionSettings::default().context_window_tokens
1582 } else {
1583 selection.model_entry.model.context_window
1584 };
1585 let compaction_settings = ResolvedCompactionSettings {
1586 enabled: config.compaction_enabled(),
1587 reserve_tokens: config.compaction_reserve_tokens(),
1588 keep_recent_tokens: config.compaction_keep_recent_tokens(),
1589 context_window_tokens,
1590 };
1591
1592 let mut agent_session = AgentSession::new(
1593 Agent::new(provider, tools, agent_config),
1594 Arc::clone(&session_arc),
1595 !cli.no_session,
1596 compaction_settings,
1597 );
1598
1599 if !options.extension_paths.is_empty() {
1600 let extension_paths = options
1601 .extension_paths
1602 .iter()
1603 .map(|path| resolve_path_for_cwd(path, &cwd))
1604 .collect::<Vec<_>>();
1605 let resolved_ext_policy =
1606 config.resolve_extension_policy_with_metadata(options.extension_policy.as_deref());
1607 let resolved_repair_policy =
1608 config.resolve_repair_policy_with_metadata(options.repair_policy.as_deref());
1609
1610 agent_session
1611 .enable_extensions_with_policy(
1612 &enabled_tools,
1613 &cwd,
1614 Some(&config),
1615 &extension_paths,
1616 Some(resolved_ext_policy.policy),
1617 Some(resolved_repair_policy.effective_mode),
1618 None,
1619 )
1620 .await?;
1621 }
1622
1623 agent_session.set_model_registry(model_registry);
1624 agent_session.set_auth_storage(auth);
1625
1626 let history = {
1627 let cx = crate::agent_cx::AgentCx::for_request();
1628 let guard = session_arc
1629 .lock(cx.cx())
1630 .await
1631 .map_err(|e| Error::session(e.to_string()))?;
1632 guard.to_messages_for_current_path()
1633 };
1634 if !history.is_empty() {
1635 agent_session.agent.replace_messages(history);
1636 }
1637
1638 let mut listeners = EventListeners::new();
1639 if let Some(on_event) = options.on_event {
1640 listeners.subscribe(on_event);
1641 }
1642 listeners.on_tool_start = options.on_tool_start;
1643 listeners.on_tool_end = options.on_tool_end;
1644 listeners.on_stream_event = options.on_stream_event;
1645
1646 Ok(AgentSessionHandle {
1647 session: agent_session,
1648 listeners,
1649 })
1650}
1651
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::runtime::RuntimeBuilder;
    use asupersync::runtime::reactor::create_reactor;
    use std::sync::{Arc, Mutex};
    use tempfile::tempdir;

    /// Drive `future` to completion on a fresh current-thread asupersync runtime.
    fn run_async<F>(future: F) -> F::Output
    where
        F: std::future::Future,
    {
        let reactor = create_reactor().expect("create reactor");
        let runtime = RuntimeBuilder::current_thread()
            .with_reactor(reactor)
            .build()
            .expect("build runtime");
        runtime.block_on(future)
    }

    /// Baseline options shared by most session tests: run inside `dir` and
    /// never persist the session.
    fn ephemeral_options(dir: &Path) -> SessionOptions {
        SessionOptions {
            working_directory: Some(dir.to_path_buf()),
            no_session: true,
            ..SessionOptions::default()
        }
    }

    /// An empty assistant message used as the `partial` payload of stream events.
    fn empty_partial() -> Arc<AssistantMessage> {
        Arc::new(AssistantMessage {
            content: Vec::new(),
            api: String::new(),
            provider: String::new(),
            model: String::new(),
            usage: Usage::default(),
            stop_reason: StopReason::Stop,
            error_message: None,
            timestamp: 0,
        })
    }

    #[test]
    fn create_agent_session_default_succeeds() {
        let tmp = tempdir().expect("tempdir");

        let handle =
            run_async(create_agent_session(ephemeral_options(tmp.path()))).expect("create session");
        let provider = handle.session().agent.provider();
        assert_eq!(provider.name(), "openai-codex");
        assert_eq!(provider.model_id(), "gpt-5.3-codex");
    }

    #[test]
    fn create_agent_session_respects_provider_model_and_clamps_thinking() {
        let tmp = tempdir().expect("tempdir");
        let options = SessionOptions {
            provider: Some("openai".to_string()),
            model: Some("gpt-4o".to_string()),
            thinking: Some(crate::model::ThinkingLevel::Low),
            ..ephemeral_options(tmp.path())
        };

        let handle = run_async(create_agent_session(options)).expect("create session");
        let provider = handle.session().agent.provider();
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.model_id(), "gpt-4o");
        assert_eq!(
            handle.session().agent.stream_options().thinking_level,
            Some(crate::model::ThinkingLevel::Off)
        );
    }

    #[test]
    fn create_agent_session_no_session_keeps_ephemeral_state() {
        let tmp = tempdir().expect("tempdir");

        let handle =
            run_async(create_agent_session(ephemeral_options(tmp.path()))).expect("create session");
        assert!(!handle.session().save_enabled());

        let path_is_none = run_async(async {
            let cx = crate::agent_cx::AgentCx::for_request();
            let guard = handle
                .session()
                .session
                .lock(cx.cx())
                .await
                .expect("lock session");
            guard.path.is_none()
        });
        assert!(path_is_none);
    }

    #[test]
    fn create_agent_session_set_model_switches_provider_model() {
        let tmp = tempdir().expect("tempdir");

        let mut handle =
            run_async(create_agent_session(ephemeral_options(tmp.path()))).expect("create session");
        run_async(handle.set_model("openai", "gpt-4o")).expect("set model");
        let provider = handle.session().agent.provider();
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.model_id(), "gpt-4o");
    }

    #[test]
    fn compact_without_history_is_noop() {
        let tmp = tempdir().expect("tempdir");

        let mut handle =
            run_async(create_agent_session(ephemeral_options(tmp.path()))).expect("create session");
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_for_callback = Arc::clone(&events);
        run_async(handle.compact(move |event| {
            events_for_callback
                .lock()
                .expect("compact callback lock")
                .push(event);
        }))
        .expect("compact");

        assert!(
            events.lock().expect("events lock").is_empty(),
            "expected no compaction lifecycle events for empty session"
        );
    }

    #[test]
    fn resolve_path_for_cwd_uses_cwd_for_relative_paths() {
        let cwd = Path::new("/tmp/pi-sdk-cwd");
        assert_eq!(
            resolve_path_for_cwd(Path::new("relative/file.txt"), cwd),
            PathBuf::from("/tmp/pi-sdk-cwd/relative/file.txt")
        );
        assert_eq!(
            resolve_path_for_cwd(Path::new("/etc/hosts"), cwd),
            PathBuf::from("/etc/hosts")
        );
    }

    #[test]
    fn event_listeners_subscribe_and_notify() {
        let listeners = EventListeners::new();
        let received = Arc::new(Mutex::new(Vec::new()));

        let sink = Arc::clone(&received);
        let id = listeners.subscribe(Arc::new(move |event| {
            sink.lock().expect("lock").push(event);
        }));

        listeners.notify(&AgentEvent::AgentStart {
            session_id: "test-123".into(),
        });
        assert_eq!(received.lock().expect("lock").len(), 1);

        // Once unsubscribed, the listener must not see further events.
        assert!(listeners.unsubscribe(id));
        listeners.notify(&AgentEvent::AgentStart {
            session_id: "test-456".into(),
        });
        assert_eq!(received.lock().expect("lock").len(), 1);
    }

    #[test]
    fn event_listeners_unsubscribe_nonexistent_returns_false() {
        let listeners = EventListeners::new();
        assert!(!listeners.unsubscribe(SubscriptionId(999)));
    }

    #[test]
    fn event_listeners_multiple_subscribers() {
        let listeners = EventListeners::new();
        let count_a = Arc::new(Mutex::new(0u32));
        let count_b = Arc::new(Mutex::new(0u32));

        let ca = Arc::clone(&count_a);
        listeners.subscribe(Arc::new(move |_| {
            *ca.lock().expect("lock") += 1;
        }));

        let cb = Arc::clone(&count_b);
        listeners.subscribe(Arc::new(move |_| {
            *cb.lock().expect("lock") += 1;
        }));

        listeners.notify(&AgentEvent::AgentStart {
            session_id: "s".into(),
        });

        assert_eq!(*count_a.lock().expect("lock"), 1);
        assert_eq!(*count_b.lock().expect("lock"), 1);
    }

    #[test]
    fn event_listeners_tool_hooks_fire() {
        let mut listeners = EventListeners::new();
        let starts = Arc::new(Mutex::new(Vec::new()));
        let ends = Arc::new(Mutex::new(Vec::new()));

        let start_sink = Arc::clone(&starts);
        listeners.on_tool_start = Some(Arc::new(move |name, args| {
            start_sink
                .lock()
                .expect("lock")
                .push((name.to_string(), args.clone()));
        }));

        let end_sink = Arc::clone(&ends);
        listeners.on_tool_end = Some(Arc::new(move |name, _output, is_error| {
            end_sink
                .lock()
                .expect("lock")
                .push((name.to_string(), is_error));
        }));

        let args = serde_json::json!({"path": "/foo"});
        listeners.notify_tool_start("bash", &args);
        let output = ToolOutput {
            content: vec![ContentBlock::Text(TextContent::new("ok"))],
            details: None,
            is_error: false,
        };
        listeners.notify_tool_end("bash", &output, false);

        // Snapshot the recordings so no lock guard outlives its statement.
        let recorded_starts = starts.lock().expect("lock").clone();
        assert_eq!(recorded_starts.len(), 1);
        assert_eq!(recorded_starts[0].0, "bash");
        assert_eq!(recorded_starts[0].1, args);

        let recorded_ends = ends.lock().expect("lock").clone();
        assert_eq!(recorded_ends.len(), 1);
        assert_eq!(recorded_ends[0].0, "bash");
        assert!(!recorded_ends[0].1);
    }

    #[test]
    fn event_listeners_stream_event_hook_fires() {
        let mut listeners = EventListeners::new();
        let received = Arc::new(Mutex::new(Vec::new()));

        let r = Arc::clone(&received);
        listeners.on_stream_event = Some(Arc::new(move |ev| {
            r.lock().expect("lock").push(format!("{ev:?}"));
        }));

        let event = StreamEvent::TextDelta {
            content_index: 0,
            delta: "hello".to_string(),
        };
        listeners.notify_stream_event(&event);

        assert_eq!(received.lock().expect("lock").len(), 1);
    }

    #[test]
    fn session_options_on_event_wired_into_listeners() {
        let received = Arc::new(Mutex::new(Vec::new()));
        let r = Arc::clone(&received);
        let tmp = tempdir().expect("tempdir");

        let options = SessionOptions {
            on_event: Some(Arc::new(move |event| {
                r.lock().expect("lock").push(format!("{event:?}"));
            })),
            ..ephemeral_options(tmp.path())
        };

        let handle = run_async(create_agent_session(options)).expect("create session");
        let count = handle.listeners().subscribers.lock().expect("lock").len();
        assert_eq!(
            count, 1,
            "on_event from SessionOptions should register one subscriber"
        );
    }

    #[test]
    fn subscribe_unsubscribe_on_handle() {
        let tmp = tempdir().expect("tempdir");

        let handle =
            run_async(create_agent_session(ephemeral_options(tmp.path()))).expect("create session");
        let id = handle.subscribe(|_event| {});
        assert_eq!(
            handle.listeners().subscribers.lock().expect("lock").len(),
            1
        );

        assert!(handle.unsubscribe(id));
        assert_eq!(
            handle.listeners().subscribers.lock().expect("lock").len(),
            0
        );

        // A second unsubscribe of the same id reports that nothing was removed.
        assert!(!handle.unsubscribe(id));
    }

    #[test]
    fn stream_event_from_assistant_message_event_converts_text_delta() {
        use crate::model::AssistantMessageEvent as AME;

        let ame = AME::TextDelta {
            content_index: 2,
            delta: "chunk".to_string(),
            partial: empty_partial(),
        };
        match stream_event_from_assistant_message_event(&ame) {
            Some(StreamEvent::TextDelta {
                content_index,
                delta,
            }) => {
                assert_eq!(content_index, 2);
                assert_eq!(delta, "chunk");
            }
            other => panic!("unexpected conversion result: {other:?}"),
        }
    }

    #[test]
    fn stream_event_from_assistant_message_event_start_returns_none() {
        use crate::model::AssistantMessageEvent as AME;

        let ame = AME::Start {
            partial: empty_partial(),
        };
        assert!(stream_event_from_assistant_message_event(&ame).is_none());
    }

    #[test]
    fn event_listeners_debug_impl() {
        let listeners = EventListeners::new();
        let debug = format!("{listeners:?}");
        assert!(debug.contains("subscriber_count"));
        assert!(debug.contains("has_on_tool_start"));
    }

    #[test]
    fn has_extensions_false_by_default() {
        let tmp = tempdir().expect("tempdir");

        let handle =
            run_async(create_agent_session(ephemeral_options(tmp.path()))).expect("create session");
        assert!(
            !handle.has_extensions(),
            "session without extension_paths should have no extensions"
        );
        assert!(handle.extension_manager().is_none());
        assert!(handle.extension_region().is_none());
    }

    #[test]
    fn create_read_tool_has_correct_name() {
        let tmp = tempdir().expect("tempdir");
        let tool = super::create_read_tool(tmp.path());
        assert_eq!(tool.name(), "read");
        assert!(!tool.description().is_empty());
        let params = tool.parameters();
        assert!(params.is_object(), "parameters should be a JSON object");
    }

    #[test]
    fn create_bash_tool_has_correct_name() {
        let tmp = tempdir().expect("tempdir");
        let tool = super::create_bash_tool(tmp.path());
        assert_eq!(tool.name(), "bash");
        assert!(!tool.description().is_empty());
    }

    #[test]
    fn create_edit_tool_has_correct_name() {
        let tmp = tempdir().expect("tempdir");
        let tool = super::create_edit_tool(tmp.path());
        assert_eq!(tool.name(), "edit");
    }

    #[test]
    fn create_write_tool_has_correct_name() {
        let tmp = tempdir().expect("tempdir");
        let tool = super::create_write_tool(tmp.path());
        assert_eq!(tool.name(), "write");
    }

    #[test]
    fn create_grep_tool_has_correct_name() {
        let tmp = tempdir().expect("tempdir");
        let tool = super::create_grep_tool(tmp.path());
        assert_eq!(tool.name(), "grep");
    }

    #[test]
    fn create_find_tool_has_correct_name() {
        let tmp = tempdir().expect("tempdir");
        let tool = super::create_find_tool(tmp.path());
        assert_eq!(tool.name(), "find");
    }

    #[test]
    fn create_ls_tool_has_correct_name() {
        let tmp = tempdir().expect("tempdir");
        let tool = super::create_ls_tool(tmp.path());
        assert_eq!(tool.name(), "ls");
    }

    #[test]
    fn create_all_tools_returns_seven() {
        let tmp = tempdir().expect("tempdir");
        let tools = super::create_all_tools(tmp.path());
        assert_eq!(tools.len(), 7, "should create all 7 built-in tools");

        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        for expected in BUILTIN_TOOL_NAMES {
            assert!(names.contains(expected), "missing tool: {expected}");
        }
    }

    #[test]
    fn tool_to_definition_preserves_schema() {
        let tmp = tempdir().expect("tempdir");
        let tool = super::create_read_tool(tmp.path());
        let def = super::tool_to_definition(tool.as_ref());
        assert_eq!(def.name, "read");
        assert!(!def.description.is_empty());
        assert!(def.parameters.is_object());
        assert!(
            def.parameters.get("properties").is_some(),
            "schema should have properties"
        );
    }

    #[test]
    fn all_tool_definitions_returns_seven_schemas() {
        let tmp = tempdir().expect("tempdir");
        let defs = super::all_tool_definitions(tmp.path());
        assert_eq!(defs.len(), 7);

        for def in &defs {
            assert!(!def.name.is_empty());
            assert!(!def.description.is_empty());
            assert!(def.parameters.is_object());
        }
    }

    #[test]
    fn builtin_tool_names_matches_create_all() {
        let tmp = tempdir().expect("tempdir");
        let tools = super::create_all_tools(tmp.path());
        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        assert_eq!(
            names.as_slice(),
            BUILTIN_TOOL_NAMES,
            "create_all_tools order should match BUILTIN_TOOL_NAMES"
        );
    }

    #[test]
    fn tool_registry_from_factory_tools() {
        let tmp = tempdir().expect("tempdir");
        let tools = super::create_all_tools(tmp.path());
        let registry = ToolRegistry::from_tools(tools);
        assert!(registry.get("read").is_some());
        assert!(registry.get("bash").is_some());
        assert!(registry.get("nonexistent").is_none());
    }
}