use crate::app;
use crate::auth::AuthStorage;
use crate::cli::Cli;
use crate::compaction::ResolvedCompactionSettings;
use crate::models::default_models_path;
use crate::provider::ThinkingBudgets;
use crate::providers;
use clap::Parser;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::{Map, Value};
use std::collections::HashMap;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
pub use crate::agent::{
AbortHandle, AbortSignal, Agent, AgentConfig, AgentEvent, AgentSession, QueueMode,
};
pub use crate::config::Config;
pub use crate::error::{Error, Result};
pub use crate::extensions::{ExtensionManager, ExtensionPolicy, ExtensionRegion};
pub use crate::model::{
AssistantMessage, ContentBlock, Cost, CustomMessage, ImageContent, Message, StopReason,
StreamEvent, TextContent, ThinkingContent, ToolCall, ToolResultMessage, Usage, UserContent,
UserMessage,
};
pub use crate::models::{ModelEntry, ModelRegistry};
pub use crate::provider::{
Context as ProviderContext, InputType, Model, ModelCost, Provider, StreamOptions,
ThinkingBudgets as ProviderThinkingBudgets, ToolDef,
};
pub use crate::session::Session;
pub use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
pub use crate::model::ThinkingLevel;
pub type ToolDefinition = ToolDef;
use crate::tools::{
BashTool, EditTool, FindTool, GrepTool, HashlineEditTool, LsTool, ReadTool, WriteTool,
};
pub const BUILTIN_TOOL_NAMES: &[&str] = &[
"read",
"bash",
"edit",
"write",
"grep",
"find",
"ls",
"hashline_edit",
];
pub fn create_read_tool(cwd: &Path) -> Box<dyn Tool> {
Box::new(ReadTool::new(cwd))
}
pub fn create_bash_tool(cwd: &Path) -> Box<dyn Tool> {
Box::new(BashTool::new(cwd))
}
pub fn create_edit_tool(cwd: &Path) -> Box<dyn Tool> {
Box::new(EditTool::new(cwd))
}
pub fn create_write_tool(cwd: &Path) -> Box<dyn Tool> {
Box::new(WriteTool::new(cwd))
}
pub fn create_grep_tool(cwd: &Path) -> Box<dyn Tool> {
Box::new(GrepTool::new(cwd))
}
pub fn create_find_tool(cwd: &Path) -> Box<dyn Tool> {
Box::new(FindTool::new(cwd))
}
pub fn create_ls_tool(cwd: &Path) -> Box<dyn Tool> {
Box::new(LsTool::new(cwd))
}
pub fn create_hashline_edit_tool(cwd: &Path) -> Box<dyn Tool> {
Box::new(HashlineEditTool::new(cwd))
}
pub fn create_all_tools(cwd: &Path) -> Vec<Box<dyn Tool>> {
vec![
create_read_tool(cwd),
create_bash_tool(cwd),
create_edit_tool(cwd),
create_write_tool(cwd),
create_grep_tool(cwd),
create_find_tool(cwd),
create_ls_tool(cwd),
create_hashline_edit_tool(cwd),
]
}
pub fn tool_to_definition(tool: &dyn Tool) -> ToolDefinition {
ToolDefinition {
name: tool.name().to_string(),
description: tool.description().to_string(),
parameters: tool.parameters(),
}
}
pub fn all_tool_definitions(cwd: &Path) -> Vec<ToolDefinition> {
create_all_tools(cwd)
.iter()
.map(|t| tool_to_definition(t.as_ref()))
.collect()
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SubscriptionId(u64);
pub type OnToolStart = Arc<dyn Fn(&str, &Value) + Send + Sync>;
pub type OnToolEnd = Arc<dyn Fn(&str, &ToolOutput, bool) + Send + Sync>;
pub type OnStreamEvent = Arc<dyn Fn(&StreamEvent) + Send + Sync>;
pub type EventSubscriber = Arc<dyn Fn(AgentEvent) + Send + Sync>;
type EventSubscribers = HashMap<SubscriptionId, EventSubscriber>;
#[derive(Clone, Default)]
pub struct EventListeners {
next_id: Arc<AtomicU64>,
subscribers: Arc<std::sync::Mutex<EventSubscribers>>,
pub on_tool_start: Option<OnToolStart>,
pub on_tool_end: Option<OnToolEnd>,
pub on_stream_event: Option<OnStreamEvent>,
}
impl EventListeners {
fn new() -> Self {
Self {
next_id: Arc::new(AtomicU64::new(1)),
subscribers: Arc::new(std::sync::Mutex::new(HashMap::new())),
on_tool_start: None,
on_tool_end: None,
on_stream_event: None,
}
}
pub fn subscribe(&self, listener: EventSubscriber) -> SubscriptionId {
let id = SubscriptionId(self.next_id.fetch_add(1, Ordering::Relaxed));
let mut subs = self
.subscribers
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
subs.insert(id, listener);
id
}
pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
let mut subs = self
.subscribers
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
subs.remove(&id).is_some()
}
pub fn notify(&self, event: &AgentEvent) {
let listeners: Vec<_> = {
let subs = self
.subscribers
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
subs.values().cloned().collect()
};
for listener in listeners {
listener(event.clone());
}
}
pub fn notify_tool_start(&self, tool_name: &str, args: &Value) {
if let Some(cb) = &self.on_tool_start {
cb(tool_name, args);
}
}
pub fn notify_tool_end(&self, tool_name: &str, output: &ToolOutput, is_error: bool) {
if let Some(cb) = &self.on_tool_end {
cb(tool_name, output, is_error);
}
}
pub fn notify_stream_event(&self, event: &StreamEvent) {
if let Some(cb) = &self.on_stream_event {
cb(event);
}
}
}
impl std::fmt::Debug for EventListeners {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let count = self.subscribers.lock().map_or(0, |s| s.len());
let next_id = self.next_id.load(Ordering::Relaxed);
f.debug_struct("EventListeners")
.field("subscriber_count", &count)
.field("next_id", &next_id)
.field("has_on_tool_start", &self.on_tool_start.is_some())
.field("has_on_tool_end", &self.on_tool_end.is_some())
.field("has_on_stream_event", &self.on_stream_event.is_some())
.finish()
}
}
#[derive(Clone)]
pub struct SessionOptions {
pub provider: Option<String>,
pub model: Option<String>,
pub api_key: Option<String>,
pub thinking: Option<crate::model::ThinkingLevel>,
pub system_prompt: Option<String>,
pub append_system_prompt: Option<String>,
pub enabled_tools: Option<Vec<String>>,
pub working_directory: Option<PathBuf>,
pub no_session: bool,
pub session_path: Option<PathBuf>,
pub session_dir: Option<PathBuf>,
pub extension_paths: Vec<PathBuf>,
pub extension_policy: Option<String>,
pub repair_policy: Option<String>,
pub include_cwd_in_prompt: bool,
pub max_tool_iterations: usize,
pub on_event: Option<Arc<dyn Fn(AgentEvent) + Send + Sync>>,
pub on_tool_start: Option<OnToolStart>,
pub on_tool_end: Option<OnToolEnd>,
pub on_stream_event: Option<OnStreamEvent>,
}
impl Default for SessionOptions {
fn default() -> Self {
Self {
provider: None,
model: None,
api_key: None,
thinking: None,
system_prompt: None,
append_system_prompt: None,
enabled_tools: None,
working_directory: None,
no_session: true,
session_path: None,
session_dir: None,
extension_paths: Vec::new(),
extension_policy: None,
repair_policy: None,
include_cwd_in_prompt: true,
max_tool_iterations: 50,
on_event: None,
on_tool_start: None,
on_tool_end: None,
on_stream_event: None,
}
}
}
pub struct AgentSessionHandle {
session: AgentSession,
listeners: EventListeners,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentSessionState {
pub session_id: Option<String>,
pub provider: String,
pub model_id: String,
pub thinking_level: Option<crate::model::ThinkingLevel>,
pub save_enabled: bool,
pub message_count: usize,
}
#[derive(Debug, Clone)]
pub enum SessionPromptResult {
InProcess(AssistantMessage),
RpcEvents(Vec<Value>),
}
#[derive(Debug, Clone)]
pub enum SessionTransportEvent {
InProcess(AgentEvent),
Rpc(Value),
}
#[derive(Debug, Clone, PartialEq)]
pub enum SessionTransportState {
InProcess(AgentSessionState),
Rpc(Box<RpcSessionState>),
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RpcModelInfo {
pub id: String,
pub name: String,
pub api: String,
pub provider: String,
#[serde(default)]
pub base_url: String,
#[serde(default)]
pub reasoning: bool,
#[serde(default)]
pub input: Vec<InputType>,
#[serde(default)]
pub context_window: u32,
#[serde(default)]
pub max_tokens: u32,
#[serde(default)]
pub cost: Option<ModelCost>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
#[allow(clippy::struct_excessive_bools)]
pub struct RpcSessionState {
#[serde(default)]
pub model: Option<RpcModelInfo>,
#[serde(default)]
pub thinking_level: String,
#[serde(default)]
pub is_streaming: bool,
#[serde(default)]
pub is_compacting: bool,
#[serde(default)]
pub steering_mode: String,
#[serde(default)]
pub follow_up_mode: String,
#[serde(default)]
pub session_file: Option<String>,
#[serde(default)]
pub session_id: String,
#[serde(default)]
pub session_name: Option<String>,
#[serde(default)]
pub auto_compaction_enabled: bool,
#[serde(default)]
pub auto_retry_enabled: bool,
#[serde(default)]
pub message_count: usize,
#[serde(default)]
pub pending_message_count: usize,
#[serde(default)]
pub durability_mode: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcTokenStats {
pub input: u64,
pub output: u64,
pub cache_read: u64,
pub cache_write: u64,
pub total: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RpcSessionStats {
#[serde(default)]
pub session_file: Option<String>,
pub session_id: String,
pub user_messages: u64,
pub assistant_messages: u64,
pub tool_calls: u64,
pub tool_results: u64,
pub total_messages: u64,
pub tokens: RpcTokenStats,
pub cost: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcCancelledResult {
pub cancelled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RpcCycleModelResult {
pub model: RpcModelInfo,
pub thinking_level: crate::model::ThinkingLevel,
pub is_scoped: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcThinkingLevelResult {
pub level: crate::model::ThinkingLevel,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcBashResult {
pub output: String,
pub exit_code: i32,
pub cancelled: bool,
pub truncated: bool,
pub full_output_path: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcCompactionResult {
pub summary: String,
pub first_kept_entry_id: String,
pub tokens_before: u64,
#[serde(default)]
pub details: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcForkResult {
pub text: String,
pub cancelled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RpcForkMessage {
pub entry_id: String,
pub text: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcCommandInfo {
pub name: String,
#[serde(default)]
pub description: Option<String>,
pub source: String,
#[serde(default)]
pub location: Option<String>,
#[serde(default)]
pub path: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcExportHtmlResult {
pub path: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcLastAssistantText {
pub text: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum RpcExtensionUiResponse {
Value { value: Value },
Confirmed { confirmed: bool },
Cancelled,
}
#[derive(Debug, Clone)]
pub struct RpcTransportOptions {
pub binary_path: PathBuf,
pub args: Vec<String>,
pub cwd: Option<PathBuf>,
}
impl Default for RpcTransportOptions {
fn default() -> Self {
Self {
binary_path: PathBuf::from("pi"),
args: vec!["--mode".to_string(), "rpc".to_string()],
cwd: None,
}
}
}
pub struct RpcTransportClient {
child: Child,
stdin: BufWriter<ChildStdin>,
stdout: BufReader<ChildStdout>,
next_request_id: u64,
}
pub enum SessionTransport {
InProcess(Box<AgentSessionHandle>),
RpcSubprocess(RpcTransportClient),
}
impl SessionTransport {
pub async fn in_process(options: SessionOptions) -> Result<Self> {
create_agent_session(options)
.await
.map(Box::new)
.map(Self::InProcess)
}
pub fn rpc_subprocess(options: RpcTransportOptions) -> Result<Self> {
RpcTransportClient::connect(options).map(Self::RpcSubprocess)
}
#[allow(clippy::missing_const_for_fn)]
pub fn as_in_process_mut(&mut self) -> Option<&mut AgentSessionHandle> {
match self {
Self::InProcess(handle) => Some(handle.as_mut()),
Self::RpcSubprocess(_) => None,
}
}
#[allow(clippy::missing_const_for_fn)]
pub fn as_rpc_mut(&mut self) -> Option<&mut RpcTransportClient> {
match self {
Self::InProcess(_) => None,
Self::RpcSubprocess(client) => Some(client),
}
}
pub async fn prompt(
&mut self,
input: impl Into<String>,
on_event: impl Fn(SessionTransportEvent) + Send + Sync + 'static,
) -> Result<SessionPromptResult> {
let input = input.into();
let on_event = Arc::new(on_event);
match self {
Self::InProcess(handle) => {
let on_event = Arc::clone(&on_event);
let assistant = handle
.prompt(input, move |event| {
(on_event)(SessionTransportEvent::InProcess(event));
})
.await?;
Ok(SessionPromptResult::InProcess(assistant))
}
Self::RpcSubprocess(client) => {
let events = client.prompt(input).await?;
for event in events.iter().cloned() {
(on_event)(SessionTransportEvent::Rpc(event));
}
Ok(SessionPromptResult::RpcEvents(events))
}
}
}
pub async fn state(&mut self) -> Result<SessionTransportState> {
match self {
Self::InProcess(handle) => handle.state().await.map(SessionTransportState::InProcess),
Self::RpcSubprocess(client) => client
.get_state()
.await
.map(Box::new)
.map(SessionTransportState::Rpc),
}
}
pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<()> {
match self {
Self::InProcess(handle) => handle.set_model(provider, model_id).await,
Self::RpcSubprocess(client) => {
let _ = client.set_model(provider, model_id).await?;
Ok(())
}
}
}
pub fn shutdown(&mut self) -> Result<()> {
match self {
Self::InProcess(_) => Ok(()),
Self::RpcSubprocess(client) => client.shutdown(),
}
}
}
impl RpcTransportClient {
pub fn connect(options: RpcTransportOptions) -> Result<Self> {
let mut command = Command::new(&options.binary_path);
command
.args(&options.args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit());
if let Some(cwd) = options.cwd {
command.current_dir(cwd);
}
let mut child = command.spawn().map_err(|err| {
Error::config(format!(
"Failed to spawn RPC subprocess {}: {err}",
options.binary_path.display()
))
})?;
let stdin = child
.stdin
.take()
.ok_or_else(|| Error::config("RPC subprocess stdin is not piped"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| Error::config("RPC subprocess stdout is not piped"))?;
Ok(Self {
child,
stdin: BufWriter::new(stdin),
stdout: BufReader::new(stdout),
next_request_id: 1,
})
}
pub async fn request(&mut self, command: &str, payload: Map<String, Value>) -> Result<Value> {
let request_id = self.next_request_id();
let mut command_payload = Map::new();
command_payload.insert("type".to_string(), Value::String(command.to_string()));
command_payload.insert("id".to_string(), Value::String(request_id.clone()));
command_payload.extend(payload);
self.write_json_line(&Value::Object(command_payload))?;
self.wait_for_response(&request_id, command)
}
fn parse_response_data<T: DeserializeOwned>(data: Value, command: &str) -> Result<T> {
serde_json::from_value(data).map_err(|err| {
Error::api(format!(
"Failed to decode RPC `{command}` response payload: {err}"
))
})
}
async fn request_typed<T: DeserializeOwned>(
&mut self,
command: &str,
payload: Map<String, Value>,
) -> Result<T> {
let data = self.request(command, payload).await?;
Self::parse_response_data(data, command)
}
async fn request_no_data(&mut self, command: &str, payload: Map<String, Value>) -> Result<()> {
let _ = self.request(command, payload).await?;
Ok(())
}
pub async fn steer(&mut self, message: impl Into<String>) -> Result<()> {
let mut payload = Map::new();
payload.insert("message".to_string(), Value::String(message.into()));
self.request_no_data("steer", payload).await
}
pub async fn follow_up(&mut self, message: impl Into<String>) -> Result<()> {
let mut payload = Map::new();
payload.insert("message".to_string(), Value::String(message.into()));
self.request_no_data("follow_up", payload).await
}
pub async fn abort(&mut self) -> Result<()> {
self.request_no_data("abort", Map::new()).await
}
pub async fn new_session(
&mut self,
parent_session: Option<&Path>,
) -> Result<RpcCancelledResult> {
let mut payload = Map::new();
if let Some(parent_session) = parent_session {
payload.insert(
"parentSession".to_string(),
Value::String(parent_session.display().to_string()),
);
}
self.request_typed("new_session", payload).await
}
pub async fn get_state(&mut self) -> Result<RpcSessionState> {
self.request_typed("get_state", Map::new()).await
}
pub async fn get_session_stats(&mut self) -> Result<RpcSessionStats> {
self.request_typed("get_session_stats", Map::new()).await
}
pub async fn get_messages(&mut self) -> Result<Vec<Value>> {
#[derive(Deserialize)]
struct MessagesPayload {
messages: Vec<Value>,
}
let payload: MessagesPayload = self.request_typed("get_messages", Map::new()).await?;
Ok(payload.messages)
}
pub async fn get_available_models(&mut self) -> Result<Vec<RpcModelInfo>> {
#[derive(Deserialize)]
struct ModelsPayload {
models: Vec<RpcModelInfo>,
}
let payload: ModelsPayload = self
.request_typed("get_available_models", Map::new())
.await?;
Ok(payload.models)
}
pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<RpcModelInfo> {
let mut payload = Map::new();
payload.insert("provider".to_string(), Value::String(provider.to_string()));
payload.insert("modelId".to_string(), Value::String(model_id.to_string()));
self.request_typed("set_model", payload).await
}
pub async fn cycle_model(&mut self) -> Result<Option<RpcCycleModelResult>> {
self.request_typed("cycle_model", Map::new()).await
}
pub async fn set_thinking_level(&mut self, level: crate::model::ThinkingLevel) -> Result<()> {
let mut payload = Map::new();
payload.insert("level".to_string(), Value::String(level.to_string()));
self.request_no_data("set_thinking_level", payload).await
}
pub async fn cycle_thinking_level(&mut self) -> Result<Option<RpcThinkingLevelResult>> {
self.request_typed("cycle_thinking_level", Map::new()).await
}
pub async fn set_steering_mode(&mut self, mode: &str) -> Result<()> {
let mut payload = Map::new();
payload.insert("mode".to_string(), Value::String(mode.to_string()));
self.request_no_data("set_steering_mode", payload).await
}
pub async fn set_follow_up_mode(&mut self, mode: &str) -> Result<()> {
let mut payload = Map::new();
payload.insert("mode".to_string(), Value::String(mode.to_string()));
self.request_no_data("set_follow_up_mode", payload).await
}
pub async fn set_auto_compaction(&mut self, enabled: bool) -> Result<()> {
let mut payload = Map::new();
payload.insert("enabled".to_string(), Value::Bool(enabled));
self.request_no_data("set_auto_compaction", payload).await
}
pub async fn set_auto_retry(&mut self, enabled: bool) -> Result<()> {
let mut payload = Map::new();
payload.insert("enabled".to_string(), Value::Bool(enabled));
self.request_no_data("set_auto_retry", payload).await
}
pub async fn abort_retry(&mut self) -> Result<()> {
self.request_no_data("abort_retry", Map::new()).await
}
pub async fn set_session_name(&mut self, name: impl Into<String>) -> Result<()> {
let mut payload = Map::new();
payload.insert("name".to_string(), Value::String(name.into()));
self.request_no_data("set_session_name", payload).await
}
pub async fn get_last_assistant_text(&mut self) -> Result<Option<String>> {
let payload: RpcLastAssistantText = self
.request_typed("get_last_assistant_text", Map::new())
.await?;
Ok(payload.text)
}
pub async fn export_html(&mut self, output_path: Option<&Path>) -> Result<RpcExportHtmlResult> {
let mut payload = Map::new();
if let Some(path) = output_path {
payload.insert(
"outputPath".to_string(),
Value::String(path.display().to_string()),
);
}
self.request_typed("export_html", payload).await
}
pub async fn bash(&mut self, command: impl Into<String>) -> Result<RpcBashResult> {
let mut payload = Map::new();
payload.insert("command".to_string(), Value::String(command.into()));
self.request_typed("bash", payload).await
}
pub async fn abort_bash(&mut self) -> Result<()> {
self.request_no_data("abort_bash", Map::new()).await
}
pub async fn compact(&mut self) -> Result<RpcCompactionResult> {
self.compact_with_instructions(None).await
}
pub async fn compact_with_instructions(
&mut self,
custom_instructions: Option<&str>,
) -> Result<RpcCompactionResult> {
let mut payload = Map::new();
if let Some(custom_instructions) = custom_instructions {
payload.insert(
"customInstructions".to_string(),
Value::String(custom_instructions.to_string()),
);
}
self.request_typed("compact", payload).await
}
pub async fn switch_session(&mut self, session_path: &Path) -> Result<RpcCancelledResult> {
let mut payload = Map::new();
payload.insert(
"sessionPath".to_string(),
Value::String(session_path.display().to_string()),
);
self.request_typed("switch_session", payload).await
}
pub async fn fork(&mut self, entry_id: impl Into<String>) -> Result<RpcForkResult> {
let mut payload = Map::new();
payload.insert("entryId".to_string(), Value::String(entry_id.into()));
self.request_typed("fork", payload).await
}
pub async fn get_fork_messages(&mut self) -> Result<Vec<RpcForkMessage>> {
#[derive(Deserialize)]
struct ForkMessagesPayload {
messages: Vec<RpcForkMessage>,
}
let payload: ForkMessagesPayload =
self.request_typed("get_fork_messages", Map::new()).await?;
Ok(payload.messages)
}
pub async fn get_commands(&mut self) -> Result<Vec<RpcCommandInfo>> {
#[derive(Deserialize)]
struct CommandsPayload {
commands: Vec<RpcCommandInfo>,
}
let payload: CommandsPayload = self.request_typed("get_commands", Map::new()).await?;
Ok(payload.commands)
}
pub async fn extension_ui_response(
&mut self,
request_id: &str,
response: RpcExtensionUiResponse,
) -> Result<bool> {
#[derive(Deserialize)]
struct ExtensionUiResolvedPayload {
resolved: bool,
}
let mut payload = Map::new();
payload.insert(
"requestId".to_string(),
Value::String(request_id.to_string()),
);
match response {
RpcExtensionUiResponse::Value { value } => {
payload.insert("value".to_string(), value);
}
RpcExtensionUiResponse::Confirmed { confirmed } => {
payload.insert("confirmed".to_string(), Value::Bool(confirmed));
}
RpcExtensionUiResponse::Cancelled => {
payload.insert("cancelled".to_string(), Value::Bool(true));
}
}
let response: Option<ExtensionUiResolvedPayload> =
self.request_typed("extension_ui_response", payload).await?;
Ok(response.is_none_or(|payload| payload.resolved))
}
pub async fn prompt(&mut self, message: impl Into<String>) -> Result<Vec<Value>> {
self.prompt_with_options(message, None, None).await
}
pub async fn prompt_with_options(
&mut self,
message: impl Into<String>,
images: Option<Vec<ImageContent>>,
streaming_behavior: Option<&str>,
) -> Result<Vec<Value>> {
let request_id = self.next_request_id();
let mut payload = Map::new();
payload.insert("type".to_string(), Value::String("prompt".to_string()));
payload.insert("id".to_string(), Value::String(request_id.clone()));
payload.insert("message".to_string(), Value::String(message.into()));
if let Some(images) = images {
payload.insert(
"images".to_string(),
serde_json::to_value(images).map_err(|err| Error::Json(Box::new(err)))?,
);
}
if let Some(streaming_behavior) = streaming_behavior {
payload.insert(
"streamingBehavior".to_string(),
Value::String(streaming_behavior.to_string()),
);
}
let payload = Value::Object(payload);
self.write_json_line(&payload)?;
let mut saw_ack = false;
let mut events = Vec::new();
loop {
let item = self.read_json_line()?;
let item_type = item.get("type").and_then(Value::as_str);
if item_type == Some("response") {
if item.get("id").and_then(Value::as_str) != Some(request_id.as_str()) {
continue;
}
let success = item
.get("success")
.and_then(Value::as_bool)
.unwrap_or(false);
if !success {
return Err(rpc_error_from_response(&item, "prompt"));
}
saw_ack = true;
continue;
}
if saw_ack {
let reached_end = item_type == Some("agent_end");
events.push(item);
if reached_end {
return Ok(events);
}
}
}
}
pub fn shutdown(&mut self) -> Result<()> {
if self
.child
.try_wait()
.map_err(|err| Error::Io(Box::new(err)))?
.is_none()
{
self.child.kill().map_err(|err| Error::Io(Box::new(err)))?;
}
let _ = self.child.wait();
Ok(())
}
fn next_request_id(&mut self) -> String {
let id = format!("rpc-{}", self.next_request_id);
self.next_request_id = self.next_request_id.saturating_add(1);
id
}
fn write_json_line(&mut self, payload: &Value) -> Result<()> {
let encoded = serde_json::to_string(payload).map_err(|err| Error::Json(Box::new(err)))?;
self.stdin
.write_all(encoded.as_bytes())
.map_err(|err| Error::Io(Box::new(err)))?;
self.stdin
.write_all(b"\n")
.map_err(|err| Error::Io(Box::new(err)))?;
self.stdin.flush().map_err(|err| Error::Io(Box::new(err)))?;
Ok(())
}
fn read_json_line(&mut self) -> Result<Value> {
let mut line = String::new();
let read = self
.stdout
.read_line(&mut line)
.map_err(|err| Error::Io(Box::new(err)))?;
if read == 0 {
return Err(Error::api(
"RPC subprocess exited before sending a response",
));
}
serde_json::from_str(line.trim_end()).map_err(|err| Error::Json(Box::new(err)))
}
fn wait_for_response(&mut self, request_id: &str, command: &str) -> Result<Value> {
loop {
let item = self.read_json_line()?;
let Some(item_type) = item.get("type").and_then(Value::as_str) else {
continue;
};
if item_type != "response" {
continue;
}
if item.get("id").and_then(Value::as_str) != Some(request_id) {
continue;
}
if item.get("command").and_then(Value::as_str) != Some(command) {
continue;
}
let success = item
.get("success")
.and_then(Value::as_bool)
.unwrap_or(false);
if success {
return Ok(item.get("data").cloned().unwrap_or(Value::Null));
}
return Err(rpc_error_from_response(&item, command));
}
}
}
impl Drop for RpcTransportClient {
fn drop(&mut self) {
let _ = self.shutdown();
}
}
fn rpc_error_from_response(response: &Value, command: &str) -> Error {
let error = response
.get("error")
.and_then(Value::as_str)
.unwrap_or("RPC command failed");
Error::api(format!("RPC {command} failed: {error}"))
}
impl AgentSessionHandle {
pub const fn from_session_with_listeners(
session: AgentSession,
listeners: EventListeners,
) -> Self {
Self { session, listeners }
}
pub async fn prompt(
&mut self,
input: impl Into<String>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
let combined = self.make_combined_callback(on_event);
self.session.run_text(input.into(), combined).await
}
pub async fn prompt_with_abort(
&mut self,
input: impl Into<String>,
abort_signal: AbortSignal,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
let combined = self.make_combined_callback(on_event);
self.session
.run_text_with_abort(input.into(), Some(abort_signal), combined)
.await
}
pub async fn continue_turn(
&mut self,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
let combined = self.make_combined_callback(on_event);
self.session
.sync_runtime_selection_from_session_header()
.await?;
self.session
.agent
.run_continue_with_abort(None, combined)
.await
}
pub async fn continue_turn_with_abort(
&mut self,
abort_signal: AbortSignal,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
let combined = self.make_combined_callback(on_event);
self.session
.sync_runtime_selection_from_session_header()
.await?;
self.session
.agent
.run_continue_with_abort(Some(abort_signal), combined)
.await
}
pub fn new_abort_handle() -> (AbortHandle, AbortSignal) {
AbortHandle::new()
}
pub fn subscribe(
&self,
listener: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> SubscriptionId {
self.listeners.subscribe(Arc::new(listener))
}
pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
self.listeners.unsubscribe(id)
}
pub const fn listeners(&self) -> &EventListeners {
&self.listeners
}
pub const fn listeners_mut(&mut self) -> &mut EventListeners {
&mut self.listeners
}
pub const fn has_extensions(&self) -> bool {
self.session.extensions.is_some()
}
pub fn extension_manager(&self) -> Option<&ExtensionManager> {
self.session
.extensions
.as_ref()
.map(ExtensionRegion::manager)
}
pub const fn extension_region(&self) -> Option<&ExtensionRegion> {
self.session.extensions.as_ref()
}
pub fn model(&self) -> (String, String) {
let provider = self.session.agent.provider();
(provider.name().to_string(), provider.model_id().to_string())
}
pub async fn set_model(&mut self, provider: &str, model_id: &str) -> Result<()> {
self.session.set_provider_model(provider, model_id).await
}
pub const fn thinking_level(&self) -> Option<crate::model::ThinkingLevel> {
self.session.agent.stream_options().thinking_level
}
pub const fn thinking(&self) -> Option<crate::model::ThinkingLevel> {
self.thinking_level()
}
pub async fn set_thinking_level(&mut self, level: crate::model::ThinkingLevel) -> Result<()> {
let cx = crate::agent_cx::AgentCx::for_request();
let (effective_level, changed) = {
let mut guard = self
.session
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
let (provider_id, model_id) = guard
.effective_model_for_current_path()
.unwrap_or_else(|| self.model());
let effective_level =
self.session
.clamp_thinking_level_for_model(&provider_id, &model_id, level);
let level_string = effective_level.to_string();
let changed = guard.effective_thinking_level_for_current_path().as_deref()
!= Some(level_string.as_str());
guard.set_model_header(None, None, Some(level_string.clone()));
if changed {
guard.append_thinking_level_change(level_string);
}
(effective_level, changed)
};
self.session.agent.stream_options_mut().thinking_level = Some(effective_level);
if changed {
self.session.persist_session().await
} else {
Ok(())
}
}
pub async fn messages(&self) -> Result<Vec<Message>> {
let cx = crate::agent_cx::AgentCx::for_request();
let guard = self
.session
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
Ok(guard.to_messages_for_current_path())
}
pub async fn state(&self) -> Result<AgentSessionState> {
let (provider, model_id) = self.model();
let thinking_level = self.thinking_level();
let save_enabled = self.session.save_enabled();
let cx = crate::agent_cx::AgentCx::for_request();
let guard = self
.session
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
let session_id = Some(guard.header.id.clone());
let message_count = guard.to_messages_for_current_path().len();
Ok(AgentSessionState {
session_id,
provider,
model_id,
thinking_level,
save_enabled,
message_count,
})
}
pub async fn compact(
&mut self,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<()> {
self.session.compact_now(on_event).await
}
pub const fn session(&self) -> &AgentSession {
&self.session
}
pub const fn session_mut(&mut self) -> &mut AgentSession {
&mut self.session
}
pub fn into_inner(self) -> AgentSession {
self.session
}
fn make_combined_callback(
&self,
per_prompt: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> impl Fn(AgentEvent) + Send + Sync + 'static {
let listeners = self.listeners.clone();
move |event: AgentEvent| {
match &event {
AgentEvent::ToolExecutionStart {
tool_name, args, ..
} => {
listeners.notify_tool_start(tool_name, args);
}
AgentEvent::ToolExecutionEnd {
tool_name,
result,
is_error,
..
} => {
listeners.notify_tool_end(tool_name, result, *is_error);
}
AgentEvent::MessageUpdate {
assistant_message_event,
..
} => {
if let Some(stream_ev) =
stream_event_from_assistant_message_event(assistant_message_event)
{
listeners.notify_stream_event(&stream_ev);
}
}
_ => {}
}
listeners.notify(&event);
per_prompt(event);
}
}
}
fn stream_event_from_assistant_message_event(
event: &crate::model::AssistantMessageEvent,
) -> Option<StreamEvent> {
use crate::model::AssistantMessageEvent as AME;
match event {
AME::TextStart { content_index, .. } => Some(StreamEvent::TextStart {
content_index: *content_index,
}),
AME::TextDelta {
content_index,
delta,
..
} => Some(StreamEvent::TextDelta {
content_index: *content_index,
delta: delta.clone(),
}),
AME::TextEnd {
content_index,
content,
..
} => Some(StreamEvent::TextEnd {
content_index: *content_index,
content: content.clone(),
}),
AME::ThinkingStart { content_index, .. } => Some(StreamEvent::ThinkingStart {
content_index: *content_index,
}),
AME::ThinkingDelta {
content_index,
delta,
..
} => Some(StreamEvent::ThinkingDelta {
content_index: *content_index,
delta: delta.clone(),
}),
AME::ThinkingEnd {
content_index,
content,
..
} => Some(StreamEvent::ThinkingEnd {
content_index: *content_index,
content: content.clone(),
}),
AME::ToolCallStart { content_index, .. } => Some(StreamEvent::ToolCallStart {
content_index: *content_index,
}),
AME::ToolCallDelta {
content_index,
delta,
..
} => Some(StreamEvent::ToolCallDelta {
content_index: *content_index,
delta: delta.clone(),
}),
AME::ToolCallEnd {
content_index,
tool_call,
..
} => Some(StreamEvent::ToolCallEnd {
content_index: *content_index,
tool_call: tool_call.clone(),
}),
AME::Done { reason, message } => Some(StreamEvent::Done {
reason: *reason,
message: (**message).clone(),
}),
AME::Error { reason, error } => Some(StreamEvent::Error {
reason: *reason,
error: (**error).clone(),
}),
AME::Start { .. } => None,
}
}
fn resolve_path_for_cwd(path: &Path, cwd: &Path) -> PathBuf {
if path.is_absolute() {
path.to_path_buf()
} else {
cwd.join(path)
}
}
fn build_stream_options_with_optional_key(
config: &Config,
api_key: Option<String>,
selection: &app::ModelSelection,
session: &Session,
) -> StreamOptions {
let mut options = StreamOptions {
api_key,
headers: selection.model_entry.headers.clone(),
session_id: Some(session.header.id.clone()),
thinking_level: Some(selection.thinking_level),
..Default::default()
};
if let Some(budgets) = &config.thinking_budgets {
let defaults = ThinkingBudgets::default();
options.thinking_budgets = Some(ThinkingBudgets {
minimal: budgets.minimal.unwrap_or(defaults.minimal),
low: budgets.low.unwrap_or(defaults.low),
medium: budgets.medium.unwrap_or(defaults.medium),
high: budgets.high.unwrap_or(defaults.high),
xhigh: budgets.xhigh.unwrap_or(defaults.xhigh),
});
}
options
}
#[allow(clippy::too_many_lines)]
pub async fn create_agent_session(options: SessionOptions) -> Result<AgentSessionHandle> {
let process_cwd =
std::env::current_dir().map_err(|e| Error::config(format!("cwd lookup failed: {e}")))?;
let cwd = options.working_directory.as_deref().map_or_else(
|| process_cwd.clone(),
|path| resolve_path_for_cwd(path, &process_cwd),
);
let mut cli = Cli::try_parse_from(["pi"])
.map_err(|e| Error::validation(format!("CLI init failed: {e}")))?;
cli.no_session = options.no_session;
cli.provider = options.provider.clone();
cli.model = options.model.clone();
cli.api_key = options.api_key.clone();
cli.system_prompt = options.system_prompt.clone();
cli.append_system_prompt = options.append_system_prompt.clone();
cli.hide_cwd_in_prompt = !options.include_cwd_in_prompt;
cli.thinking = options.thinking.map(|t| t.to_string());
cli.session = options
.session_path
.as_ref()
.map(|p| p.to_string_lossy().to_string());
cli.session_dir = options
.session_dir
.as_ref()
.map(|p| p.to_string_lossy().to_string());
if let Some(enabled_tools) = &options.enabled_tools {
if enabled_tools.is_empty() {
cli.no_tools = true;
} else {
cli.no_tools = false;
cli.tools = enabled_tools.join(",");
}
}
let config = Config::load()?;
let mut auth = AuthStorage::load_async(Config::auth_path()).await?;
auth.refresh_expired_oauth_tokens().await?;
let global_dir = Config::global_dir();
let package_dir = Config::package_dir();
let models_path = default_models_path(&global_dir);
let model_registry = ModelRegistry::load(&auth, Some(models_path));
let mut session = Session::new(&cli, &config).await?;
let scoped_patterns = if let Some(models_arg) = &cli.models {
app::parse_models_arg(models_arg)
} else {
config.enabled_models.clone().unwrap_or_default()
};
let scoped_models = if scoped_patterns.is_empty() {
Vec::new()
} else {
app::resolve_model_scope(&scoped_patterns, &model_registry, cli.api_key.is_some())
};
let selection = app::select_model_and_thinking(
&cli,
&config,
&session,
&model_registry,
&scoped_models,
&global_dir,
)
.map_err(|err| Error::validation(err.to_string()))?;
app::update_session_for_selection(&mut session, &selection);
let enabled_tools_owned = cli
.enabled_tools()
.into_iter()
.map(str::to_string)
.collect::<Vec<_>>();
let enabled_tools = enabled_tools_owned
.iter()
.map(String::as_str)
.collect::<Vec<_>>();
let system_prompt = app::build_system_prompt(
&cli,
&cwd,
&enabled_tools,
None,
&global_dir,
&package_dir,
std::env::var_os("PI_TEST_MODE").is_some(),
options.include_cwd_in_prompt,
)
.map_err(|err| Error::validation(err.to_string()))?;
let provider = providers::create_provider(&selection.model_entry, None)
.map_err(|e| Error::provider("sdk", e.to_string()))?;
let api_key = app::resolve_api_key(&auth, &cli, &selection.model_entry)
.map_err(|err| Error::validation(err.to_string()))?;
let stream_options =
build_stream_options_with_optional_key(&config, api_key, &selection, &session);
let agent_config = AgentConfig {
system_prompt: Some(system_prompt),
max_tool_iterations: options.max_tool_iterations,
stream_options,
block_images: config.image_block_images(),
fail_closed_hooks: config.fail_closed_hooks(),
};
let tools = ToolRegistry::new(&enabled_tools, &cwd, Some(&config));
let session_arc = Arc::new(asupersync::sync::Mutex::new(session));
let context_window_tokens = if selection.model_entry.model.context_window == 0 {
ResolvedCompactionSettings::default().context_window_tokens
} else {
selection.model_entry.model.context_window
};
let compaction_settings = ResolvedCompactionSettings {
enabled: config.compaction_enabled(),
reserve_tokens: config.compaction_reserve_tokens(),
keep_recent_tokens: config.compaction_keep_recent_tokens(),
context_window_tokens,
};
let mut agent_session = AgentSession::new(
Agent::new(provider, tools, agent_config),
Arc::clone(&session_arc),
!cli.no_session,
compaction_settings,
);
if !options.extension_paths.is_empty() {
let extension_paths = options
.extension_paths
.iter()
.map(|path| resolve_path_for_cwd(path, &cwd))
.collect::<Vec<_>>();
let resolved_ext_policy =
config.resolve_extension_policy_with_metadata(options.extension_policy.as_deref());
let resolved_repair_policy =
config.resolve_repair_policy_with_metadata(options.repair_policy.as_deref());
agent_session
.enable_extensions_with_policy(
&enabled_tools,
&cwd,
Some(&config),
&extension_paths,
Some(resolved_ext_policy.policy),
Some(resolved_repair_policy.effective_mode),
None,
)
.await?;
}
agent_session.set_model_registry(model_registry.clone());
agent_session.set_auth_storage(auth);
let history = {
let cx = crate::agent_cx::AgentCx::for_request();
let guard = session_arc
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
guard.to_messages_for_current_path()
};
if !history.is_empty() {
agent_session.agent.replace_messages(history);
}
let mut listeners = EventListeners::new();
if let Some(on_event) = options.on_event {
listeners.subscribe(on_event);
}
listeners.on_tool_start = options.on_tool_start;
listeners.on_tool_end = options.on_tool_end;
listeners.on_stream_event = options.on_stream_event;
Ok(AgentSessionHandle {
session: agent_session,
listeners,
})
}
#[cfg(test)]
mod tests {
use super::*;
use asupersync::runtime::RuntimeBuilder;
use asupersync::runtime::reactor::create_reactor;
use asupersync::sync::Mutex as AsyncMutex;
use std::sync::{Arc, Mutex};
use tempfile::tempdir;
fn run_async<F>(future: F) -> F::Output
where
F: std::future::Future,
{
let reactor = create_reactor().expect("create reactor");
let runtime = RuntimeBuilder::current_thread()
.with_reactor(reactor)
.build()
.expect("build runtime");
runtime.block_on(future)
}
#[test]
fn create_agent_session_default_succeeds() {
let tmp = tempdir().expect("tempdir");
let options = SessionOptions {
working_directory: Some(tmp.path().to_path_buf()),
no_session: true,
..SessionOptions::default()
};
let handle = run_async(create_agent_session(options)).expect("create session");
let provider = handle.session().agent.provider();
assert!(!provider.name().is_empty());
assert!(!provider.model_id().is_empty());
assert_eq!(handle.model().0, provider.name());
assert_eq!(handle.model().1, provider.model_id());
}
#[test]
fn create_agent_session_respects_provider_model_and_clamps_thinking() {
let tmp = tempdir().expect("tempdir");
let options = SessionOptions {
provider: Some("openai".to_string()),
model: Some("gpt-4o".to_string()),
api_key: Some("dummy-key".to_string()),
thinking: Some(crate::model::ThinkingLevel::Low),
working_directory: Some(tmp.path().to_path_buf()),
no_session: true,
..SessionOptions::default()
};
let handle = run_async(create_agent_session(options)).expect("create session");
let provider = handle.session().agent.provider();
assert_eq!(provider.name(), "openai");
assert_eq!(provider.model_id(), "gpt-4o");
assert_eq!(
handle.session().agent.stream_options().thinking_level,
Some(crate::model::ThinkingLevel::Off)
);
}
#[test]
fn create_agent_session_no_session_keeps_ephemeral_state() {
let tmp = tempdir().expect("tempdir");
let options = SessionOptions {
working_directory: Some(tmp.path().to_path_buf()),
no_session: true,
..SessionOptions::default()
};
let handle = run_async(create_agent_session(options)).expect("create session");
assert!(!handle.session().save_enabled());
let path_is_none = run_async(async {
let cx = crate::agent_cx::AgentCx::for_request();
let guard = handle
.session()
.session
.lock(cx.cx())
.await
.expect("lock session");
guard.path.is_none()
});
assert!(path_is_none);
}
#[test]
fn from_session_with_listeners_set_model_switches_provider_model() {
let dir = tempdir().expect("tempdir");
let auth_path = dir.path().join("auth.json");
let mut auth = AuthStorage::load(auth_path).expect("load auth");
auth.set(
"anthropic",
crate::auth::AuthCredential::ApiKey {
key: "anthropic-key".to_string(),
},
);
auth.set(
"openai",
crate::auth::AuthCredential::ApiKey {
key: "openai-key".to_string(),
},
);
let registry = ModelRegistry::load(&auth, None);
let entry = registry
.find("anthropic", "claude-sonnet-4-5")
.expect("anthropic model in registry");
let provider = providers::create_provider(&entry, None).expect("create anthropic provider");
let tools = crate::tools::ToolRegistry::new(&[], std::path::Path::new("."), None);
let agent = Agent::new(
provider,
tools,
AgentConfig {
system_prompt: None,
max_tool_iterations: 50,
stream_options: StreamOptions::default(),
block_images: false,
fail_closed_hooks: false,
},
);
let mut session = Session::in_memory();
session.header.provider = Some("anthropic".to_string());
session.header.model_id = Some("claude-sonnet-4-5".to_string());
let mut agent_session = AgentSession::new(
agent,
Arc::new(AsyncMutex::new(session)),
false,
ResolvedCompactionSettings::default(),
);
agent_session.set_model_registry(registry);
agent_session.set_auth_storage(auth);
let mut handle =
AgentSessionHandle::from_session_with_listeners(agent_session, EventListeners::new());
run_async(handle.set_model("openai", "gpt-4o")).expect("set model");
let provider = handle.session().agent.provider();
assert_eq!(provider.name(), "openai");
assert_eq!(provider.model_id(), "gpt-4o");
}
#[test]
fn create_agent_session_set_thinking_level_clamps_and_dedupes_history() {
let tmp = tempdir().expect("tempdir");
let options = SessionOptions {
provider: Some("openai".to_string()),
model: Some("gpt-4o".to_string()),
api_key: Some("dummy-key".to_string()),
working_directory: Some(tmp.path().to_path_buf()),
no_session: true,
..SessionOptions::default()
};
let mut handle = run_async(create_agent_session(options)).expect("create session");
run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
.expect("set thinking");
run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
.expect("reapply thinking");
assert_eq!(
handle.session().agent.stream_options().thinking_level,
Some(crate::model::ThinkingLevel::Off)
);
let thinking_changes = run_async(async {
let cx = crate::agent_cx::AgentCx::for_request();
let guard = handle
.session()
.session
.lock(cx.cx())
.await
.expect("lock session");
assert_eq!(guard.header.thinking_level.as_deref(), Some("off"));
guard
.entries_for_current_path()
.iter()
.filter(|entry| {
matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_))
})
.count()
});
assert_eq!(thinking_changes, 1);
}
#[test]
fn from_session_with_listeners_set_thinking_level_uses_session_header_target() {
let dir = tempdir().expect("tempdir");
let auth_path = dir.path().join("auth.json");
let auth = crate::auth::AuthStorage::load(auth_path).expect("load auth");
let mut registry = ModelRegistry::load(&auth, None);
registry.merge_entries(vec![ModelEntry {
model: Model {
id: "plain-model".to_string(),
name: "Plain Model".to_string(),
api: "openai-completions".to_string(),
provider: "acme".to_string(),
base_url: "https://example.invalid/v1".to_string(),
reasoning: false,
input: vec![InputType::Text],
cost: ModelCost {
input: 0.0,
output: 0.0,
cache_read: 0.0,
cache_write: 0.0,
},
context_window: 128_000,
max_tokens: 8_192,
headers: HashMap::new(),
},
api_key: None,
headers: HashMap::new(),
auth_header: false,
compat: None,
oauth_config: None,
}]);
let entry = registry
.find("anthropic", "claude-sonnet-4-5")
.expect("anthropic model in registry");
let provider = providers::create_provider(&entry, None).expect("create anthropic provider");
let tools = crate::tools::ToolRegistry::new(&[], std::path::Path::new("."), None);
let agent = Agent::new(
provider,
tools,
AgentConfig {
system_prompt: None,
max_tool_iterations: 50,
stream_options: StreamOptions::default(),
block_images: false,
fail_closed_hooks: false,
},
);
let mut session = Session::in_memory();
session.header.provider = Some("acme".to_string());
session.header.model_id = Some("plain-model".to_string());
let mut agent_session = AgentSession::new(
agent,
Arc::new(AsyncMutex::new(session)),
false,
ResolvedCompactionSettings::default(),
);
agent_session.set_model_registry(registry);
let mut handle =
AgentSessionHandle::from_session_with_listeners(agent_session, EventListeners::new());
run_async(handle.set_thinking_level(crate::model::ThinkingLevel::High))
.expect("set thinking");
assert_eq!(
handle.session().agent.stream_options().thinking_level,
Some(crate::model::ThinkingLevel::Off)
);
assert_eq!(handle.model().0, "anthropic");
assert_eq!(handle.model().1, "claude-sonnet-4-5");
}
#[test]
fn compact_without_history_is_noop() {
let tmp = tempdir().expect("tempdir");
let options = SessionOptions {
working_directory: Some(tmp.path().to_path_buf()),
no_session: true,
..SessionOptions::default()
};
let mut handle = run_async(create_agent_session(options)).expect("create session");
let events = Arc::new(Mutex::new(Vec::new()));
let events_for_callback = Arc::clone(&events);
run_async(handle.compact(move |event| {
events_for_callback
.lock()
.expect("compact callback lock")
.push(event);
}))
.expect("compact");
assert!(
events
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.is_empty(),
"expected no compaction lifecycle events for empty session"
);
}
#[test]
fn resolve_path_for_cwd_uses_cwd_for_relative_paths() {
let cwd = Path::new("/tmp/pi-sdk-cwd");
assert_eq!(
resolve_path_for_cwd(Path::new("relative/file.txt"), cwd),
PathBuf::from("/tmp/pi-sdk-cwd/relative/file.txt")
);
assert_eq!(
resolve_path_for_cwd(Path::new("/etc/hosts"), cwd),
PathBuf::from("/etc/hosts")
);
}
#[test]
fn event_listeners_subscribe_and_notify() {
let listeners = EventListeners::new();
let received = Arc::new(Mutex::new(Vec::new()));
let recv_clone = Arc::clone(&received);
let id = listeners.subscribe(Arc::new(move |event| {
recv_clone
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.push(event);
}));
let event = AgentEvent::AgentStart {
session_id: "test-123".into(),
};
listeners.notify(&event);
let events = received
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
assert_eq!(events.len(), 1);
drop(events);
assert!(listeners.unsubscribe(id));
listeners.notify(&AgentEvent::AgentStart {
session_id: "test-456".into(),
});
assert_eq!(
received
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.len(),
1
);
}
#[test]
fn event_listeners_unsubscribe_nonexistent_returns_false() {
let listeners = EventListeners::new();
assert!(!listeners.unsubscribe(SubscriptionId(999)));
}
#[test]
fn event_listeners_multiple_subscribers() {
let listeners = EventListeners::new();
let count_a = Arc::new(Mutex::new(0u32));
let count_b = Arc::new(Mutex::new(0u32));
let ca = Arc::clone(&count_a);
listeners.subscribe(Arc::new(move |_| {
*ca.lock().unwrap_or_else(std::sync::PoisonError::into_inner) += 1;
}));
let cb = Arc::clone(&count_b);
listeners.subscribe(Arc::new(move |_| {
*cb.lock().unwrap_or_else(std::sync::PoisonError::into_inner) += 1;
}));
listeners.notify(&AgentEvent::AgentStart {
session_id: "s".into(),
});
assert_eq!(
*count_a
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner),
1
);
assert_eq!(
*count_b
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner),
1
);
}
#[test]
fn event_listeners_tool_hooks_fire() {
let listeners = EventListeners::new();
let starts = Arc::new(Mutex::new(Vec::new()));
let ends = Arc::new(Mutex::new(Vec::new()));
let s = Arc::clone(&starts);
let mut listeners = listeners;
listeners.on_tool_start = Some(Arc::new(move |name, args| {
s.lock()
.expect("lock")
.push((name.to_string(), args.clone()));
}));
let e = Arc::clone(&ends);
listeners.on_tool_end = Some(Arc::new(move |name, _output, is_error| {
e.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.push((name.to_string(), is_error));
}));
let args = serde_json::json!({"path": "/foo"});
listeners.notify_tool_start("bash", &args);
let output = ToolOutput {
content: vec![ContentBlock::Text(TextContent::new("ok"))],
details: None,
is_error: false,
};
listeners.notify_tool_end("bash", &output, false);
{
let s = starts
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
assert_eq!(s.len(), 1);
assert_eq!(s[0].0, "bash");
drop(s);
}
{
let e = ends
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
assert_eq!(e.len(), 1);
assert_eq!(e[0].0, "bash");
assert!(!e[0].1);
drop(e);
}
}
#[test]
fn event_listeners_stream_event_hook_fires() {
let mut listeners = EventListeners::new();
let received = Arc::new(Mutex::new(Vec::new()));
let r = Arc::clone(&received);
listeners.on_stream_event = Some(Arc::new(move |ev| {
r.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.push(format!("{ev:?}"));
}));
let event = StreamEvent::TextDelta {
content_index: 0,
delta: "hello".to_string(),
};
listeners.notify_stream_event(&event);
assert_eq!(
received
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.len(),
1
);
}
#[test]
fn session_options_on_event_wired_into_listeners() {
let received = Arc::new(Mutex::new(Vec::new()));
let r = Arc::clone(&received);
let tmp = tempdir().expect("tempdir");
let options = SessionOptions {
working_directory: Some(tmp.path().to_path_buf()),
no_session: true,
on_event: Some(Arc::new(move |event| {
r.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.push(format!("{event:?}"));
})),
..SessionOptions::default()
};
let handle = run_async(create_agent_session(options)).expect("create session");
let count = handle
.listeners()
.subscribers
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.len();
assert_eq!(
count, 1,
"on_event from SessionOptions should register one subscriber"
);
}
#[test]
fn subscribe_unsubscribe_on_handle() {
let tmp = tempdir().expect("tempdir");
let options = SessionOptions {
working_directory: Some(tmp.path().to_path_buf()),
no_session: true,
..SessionOptions::default()
};
let handle = run_async(create_agent_session(options)).expect("create session");
let id = handle.subscribe(|_event| {});
assert_eq!(
handle
.listeners()
.subscribers
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.len(),
1
);
assert!(handle.unsubscribe(id));
assert_eq!(
handle
.listeners()
.subscribers
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.len(),
0
);
assert!(!handle.unsubscribe(id));
}
#[test]
fn stream_event_from_assistant_message_event_converts_text_delta() {
use crate::model::AssistantMessageEvent as AME;
let partial = Arc::new(AssistantMessage {
content: Vec::new(),
api: String::new(),
provider: String::new(),
model: String::new(),
usage: Usage::default(),
stop_reason: StopReason::Stop,
error_message: None,
timestamp: 0,
});
let ame = AME::TextDelta {
content_index: 2,
delta: "chunk".to_string(),
partial,
};
let result = stream_event_from_assistant_message_event(&ame);
assert!(result.is_some());
match result.unwrap() {
StreamEvent::TextDelta {
content_index,
delta,
} => {
assert_eq!(content_index, 2);
assert_eq!(delta, "chunk");
}
other => unreachable!("expected TextDelta, got {other:?}"),
}
}
#[test]
fn stream_event_from_assistant_message_event_start_returns_none() {
use crate::model::AssistantMessageEvent as AME;
let partial = Arc::new(AssistantMessage {
content: Vec::new(),
api: String::new(),
provider: String::new(),
model: String::new(),
usage: Usage::default(),
stop_reason: StopReason::Stop,
error_message: None,
timestamp: 0,
});
let ame = AME::Start { partial };
assert!(stream_event_from_assistant_message_event(&ame).is_none());
}
#[test]
fn event_listeners_debug_impl() {
let listeners = EventListeners::new();
let debug = format!("{listeners:?}");
assert!(debug.contains("subscriber_count"));
assert!(debug.contains("has_on_tool_start"));
}
#[test]
fn has_extensions_false_by_default() {
let tmp = tempdir().expect("tempdir");
let options = SessionOptions {
working_directory: Some(tmp.path().to_path_buf()),
no_session: true,
..SessionOptions::default()
};
let handle = run_async(create_agent_session(options)).expect("create session");
assert!(
!handle.has_extensions(),
"session without extension_paths should have no extensions"
);
assert!(handle.extension_manager().is_none());
assert!(handle.extension_region().is_none());
}
#[test]
fn create_read_tool_has_correct_name() {
let tmp = tempdir().expect("tempdir");
let tool = super::create_read_tool(tmp.path());
assert_eq!(tool.name(), "read");
assert!(!tool.description().is_empty());
let params = tool.parameters();
assert!(params.is_object(), "parameters should be a JSON object");
}
#[test]
fn create_bash_tool_has_correct_name() {
let tmp = tempdir().expect("tempdir");
let tool = super::create_bash_tool(tmp.path());
assert_eq!(tool.name(), "bash");
assert!(!tool.description().is_empty());
}
#[test]
fn create_edit_tool_has_correct_name() {
let tmp = tempdir().expect("tempdir");
let tool = super::create_edit_tool(tmp.path());
assert_eq!(tool.name(), "edit");
}
#[test]
fn create_write_tool_has_correct_name() {
let tmp = tempdir().expect("tempdir");
let tool = super::create_write_tool(tmp.path());
assert_eq!(tool.name(), "write");
}
#[test]
fn create_grep_tool_has_correct_name() {
let tmp = tempdir().expect("tempdir");
let tool = super::create_grep_tool(tmp.path());
assert_eq!(tool.name(), "grep");
}
#[test]
fn create_find_tool_has_correct_name() {
let tmp = tempdir().expect("tempdir");
let tool = super::create_find_tool(tmp.path());
assert_eq!(tool.name(), "find");
}
#[test]
fn create_ls_tool_has_correct_name() {
let tmp = tempdir().expect("tempdir");
let tool = super::create_ls_tool(tmp.path());
assert_eq!(tool.name(), "ls");
}
#[test]
fn create_all_tools_returns_eight() {
let tmp = tempdir().expect("tempdir");
let tools = super::create_all_tools(tmp.path());
assert_eq!(tools.len(), 8, "should create all 8 built-in tools");
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
for expected in BUILTIN_TOOL_NAMES {
assert!(names.contains(expected), "missing tool: {expected}");
}
}
#[test]
fn tool_to_definition_preserves_schema() {
let tmp = tempdir().expect("tempdir");
let tool = super::create_read_tool(tmp.path());
let def = super::tool_to_definition(tool.as_ref());
assert_eq!(def.name, "read");
assert!(!def.description.is_empty());
assert!(def.parameters.is_object());
assert!(
def.parameters.get("properties").is_some(),
"schema should have properties"
);
}
#[test]
fn all_tool_definitions_returns_eight_schemas() {
let tmp = tempdir().expect("tempdir");
let defs = super::all_tool_definitions(tmp.path());
assert_eq!(defs.len(), 8);
for def in &defs {
assert!(!def.name.is_empty());
assert!(!def.description.is_empty());
assert!(def.parameters.is_object());
}
}
#[test]
fn builtin_tool_names_matches_create_all() {
let tmp = tempdir().expect("tempdir");
let tools = super::create_all_tools(tmp.path());
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert_eq!(
names.as_slice(),
BUILTIN_TOOL_NAMES,
"create_all_tools order should match BUILTIN_TOOL_NAMES"
);
}
#[test]
fn tool_registry_from_factory_tools() {
let tmp = tempdir().expect("tempdir");
let tools = super::create_all_tools(tmp.path());
let registry = ToolRegistry::from_tools(tools);
assert!(registry.get("read").is_some());
assert!(registry.get("bash").is_some());
assert!(registry.get("nonexistent").is_none());
}
}