1use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use futures::StreamExt;
10use pi_ai::{
11 stream_simple, AssistantMessageEvent, Content, Context, Message, StopReason, ToolResultMessage,
12};
13use serde_json::Value;
14use tokio::sync::mpsc;
15use tracing::instrument;
16
17use crate::error::{AgentError, Result};
18use crate::types::{AgentConfig, AgentEvent, AgentTool, AgentToolResult, PermissionDecision};
19
20pub struct AgentRun {
21 pub messages: Vec<Message>,
22 pub stopped_at_turn_limit: bool,
23}
24
25#[instrument(skip(config, initial_prompt, events), fields(model = %config.model.id))]
26pub async fn run_agent(
27 config: &AgentConfig,
28 initial_prompt: Message,
29 events: Option<mpsc::UnboundedSender<AgentEvent>>,
30) -> Result<AgentRun> {
31 run_agent_with_history(config, vec![initial_prompt], events).await
32}
33
34pub async fn run_agent_with_history(
36 config: &AgentConfig,
37 mut messages: Vec<Message>,
38 events: Option<mpsc::UnboundedSender<AgentEvent>>,
39) -> Result<AgentRun> {
40 if let Some(last) = messages.last().cloned() {
41 emit(&events, AgentEvent::UserMessage { message: last });
42 }
43 emit(&events, AgentEvent::AgentStart);
44
45 let tool_index: HashMap<String, Arc<dyn AgentTool>> = config
46 .tools
47 .iter()
48 .map(|t| (t.name().to_string(), t.clone()))
49 .collect();
50 let tool_defs: Vec<pi_ai::Tool> = config
51 .tools
52 .iter()
53 .map(|t| crate::types::tool_def(t.as_ref()))
54 .collect();
55
56 let mut session_allowed: HashSet<String> = HashSet::new();
57 let mut turn: u32 = 0;
58 let mut stopped_at_turn_limit = false;
59
60 'outer: while turn < config.max_turns {
61 turn += 1;
62 emit(&events, AgentEvent::TurnStart);
63
64 let ctx = Context {
65 system_prompt: Some(config.system_prompt.clone()),
66 messages: messages.clone(),
67 tools: tool_defs.clone(),
68 };
69
70 let mut options = config.stream_options.clone();
71 if options.reasoning.is_none() && config.thinking_level != pi_ai::ThinkingLevel::Off {
72 options.reasoning = Some(config.thinking_level);
73 }
74
75 let mut stream = stream_simple(&config.model, &ctx, &options).await?;
76
77 let mut final_message: Option<pi_ai::AssistantMessage> = None;
78 let mut stop = StopReason::Stop;
79
80 while let Some(ev) = stream.next().await {
81 let ev = ev?;
82 match ev {
83 AssistantMessageEvent::Done { reason, message } => {
84 stop = reason;
85 final_message = Some(message);
86 break;
87 }
88 AssistantMessageEvent::Error { reason: _, error } => {
89 let err_msg = error
90 .error_message
91 .clone()
92 .unwrap_or_else(|| "provider error".into());
93 return Err(AgentError::Other(err_msg));
94 }
95 AssistantMessageEvent::TextDelta { delta, .. } => {
96 emit(&events, AgentEvent::TextDelta { delta });
97 }
98 AssistantMessageEvent::ThinkingDelta { delta, .. } => {
99 emit(&events, AgentEvent::ThinkingDelta { delta });
100 }
101 _ => {}
102 }
103 }
104
105 let Some(msg) = final_message else {
106 return Err(AgentError::Other(
107 "provider stream produced no terminal event".into(),
108 ));
109 };
110
111 let assistant_message = Message::Assistant(msg.clone());
112 messages.push(assistant_message.clone());
113 emit(
114 &events,
115 AgentEvent::AssistantMessage {
116 message: assistant_message,
117 },
118 );
119
120 let tool_calls: Vec<(String, String, Value)> = msg
121 .content
122 .iter()
123 .filter_map(|c| match c {
124 Content::ToolCall {
125 id,
126 name,
127 arguments,
128 } => Some((id.clone(), name.clone(), arguments.clone())),
129 _ => None,
130 })
131 .collect();
132
133 if tool_calls.is_empty() || stop != StopReason::ToolUse {
134 emit(&events, AgentEvent::TurnEnd);
135 break 'outer;
136 }
137
138 let mut any_terminate = !tool_calls.is_empty();
139 for (id, name, args) in tool_calls {
140 let tool_obj = tool_index.get(&name);
143 let needs_perm = tool_obj.map(|t| t.requires_permission()).unwrap_or(false)
144 && !session_allowed.contains(&name);
145 if needs_perm {
146 match config.permission.check(&name, &args).await {
147 PermissionDecision::Allow => {}
148 PermissionDecision::AllowSession => {
149 session_allowed.insert(name.clone());
150 }
151 PermissionDecision::Deny { reason } => {
152 emit(
153 &events,
154 AgentEvent::PermissionDenied {
155 tool_name: name.clone(),
156 reason: reason.clone(),
157 },
158 );
159 let tr = ToolResultMessage {
160 tool_call_id: id,
161 tool_name: name,
162 content: vec![Content::text(format!("permission denied: {reason}"))],
163 is_error: true,
164 timestamp: pi_ai::now_ms(),
165 };
166 messages.push(Message::ToolResult(tr));
167 any_terminate = false;
168 continue;
169 }
170 }
171 }
172
173 emit(
174 &events,
175 AgentEvent::ToolExecutionStart {
176 tool_call_id: id.clone(),
177 tool_name: name.clone(),
178 args: args.clone(),
179 },
180 );
181 let (content, is_error, terminate) = match tool_obj {
182 Some(tool) => match tool.execute(&id, args).await {
183 Ok(AgentToolResult {
184 content,
185 details: _,
186 terminate,
187 }) => (content, false, terminate),
188 Err(e) => (vec![Content::text(format!("tool error: {e}"))], true, false),
189 },
190 None => (
191 vec![Content::text(format!("unknown tool: {name}"))],
192 true,
193 false,
194 ),
195 };
196 if !terminate {
197 any_terminate = false;
198 }
199 emit(
200 &events,
201 AgentEvent::ToolExecutionEnd {
202 tool_call_id: id.clone(),
203 tool_name: name.clone(),
204 is_error,
205 content: content.clone(),
206 },
207 );
208 let tr = ToolResultMessage {
209 tool_call_id: id,
210 tool_name: name,
211 content,
212 is_error,
213 timestamp: pi_ai::now_ms(),
214 };
215 messages.push(Message::ToolResult(tr));
216 }
217 emit(&events, AgentEvent::TurnEnd);
218 if any_terminate {
219 break;
220 }
221 }
222
223 if turn >= config.max_turns {
224 stopped_at_turn_limit = true;
225 }
226
227 emit(
228 &events,
229 AgentEvent::AgentEnd {
230 messages: messages.clone(),
231 },
232 );
233 Ok(AgentRun {
234 messages,
235 stopped_at_turn_limit,
236 })
237}
238
239fn emit(sink: &Option<mpsc::UnboundedSender<AgentEvent>>, ev: AgentEvent) {
240 if let Some(s) = sink {
241 let _ = s.send(ev);
242 }
243}