Skip to main content

pi_agent/
agent_loop.rs

//! Agent loop — Rust port of `packages/agent/src/agent-loop.ts`.
//!
//! Streams assistant deltas, executes tool calls, and surfaces permission
//! decisions. Cancellation is honored via `StreamOptions::cancel`.

6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use futures::StreamExt;
10use pi_ai::{
11    stream_simple, AssistantMessageEvent, Content, Context, Message, StopReason, ToolResultMessage,
12};
13use serde_json::Value;
14use tokio::sync::mpsc;
15use tracing::instrument;
16
17use crate::error::{AgentError, Result};
18use crate::types::{AgentConfig, AgentEvent, AgentTool, AgentToolResult, PermissionDecision};
19
20pub struct AgentRun {
21    pub messages: Vec<Message>,
22    pub stopped_at_turn_limit: bool,
23}
24
25#[instrument(skip(config, initial_prompt, events), fields(model = %config.model.id))]
26pub async fn run_agent(
27    config: &AgentConfig,
28    initial_prompt: Message,
29    events: Option<mpsc::UnboundedSender<AgentEvent>>,
30) -> Result<AgentRun> {
31    run_agent_with_history(config, vec![initial_prompt], events).await
32}
33
34/// Continue a run with an existing transcript. Use this for `pi --resume`.
35pub async fn run_agent_with_history(
36    config: &AgentConfig,
37    mut messages: Vec<Message>,
38    events: Option<mpsc::UnboundedSender<AgentEvent>>,
39) -> Result<AgentRun> {
40    if let Some(last) = messages.last().cloned() {
41        emit(&events, AgentEvent::UserMessage { message: last });
42    }
43    emit(&events, AgentEvent::AgentStart);
44
45    let tool_index: HashMap<String, Arc<dyn AgentTool>> = config
46        .tools
47        .iter()
48        .map(|t| (t.name().to_string(), t.clone()))
49        .collect();
50    let tool_defs: Vec<pi_ai::Tool> = config
51        .tools
52        .iter()
53        .map(|t| crate::types::tool_def(t.as_ref()))
54        .collect();
55
56    let mut session_allowed: HashSet<String> = HashSet::new();
57    let mut turn: u32 = 0;
58    let mut stopped_at_turn_limit = false;
59
60    'outer: while turn < config.max_turns {
61        turn += 1;
62        emit(&events, AgentEvent::TurnStart);
63
64        let ctx = Context {
65            system_prompt: Some(config.system_prompt.clone()),
66            messages: messages.clone(),
67            tools: tool_defs.clone(),
68        };
69
70        let mut options = config.stream_options.clone();
71        if options.reasoning.is_none() && config.thinking_level != pi_ai::ThinkingLevel::Off {
72            options.reasoning = Some(config.thinking_level);
73        }
74
75        let mut stream = stream_simple(&config.model, &ctx, &options).await?;
76
77        let mut final_message: Option<pi_ai::AssistantMessage> = None;
78        let mut stop = StopReason::Stop;
79
80        while let Some(ev) = stream.next().await {
81            let ev = ev?;
82            match ev {
83                AssistantMessageEvent::Done { reason, message } => {
84                    stop = reason;
85                    final_message = Some(message);
86                    break;
87                }
88                AssistantMessageEvent::Error { reason: _, error } => {
89                    let err_msg = error
90                        .error_message
91                        .clone()
92                        .unwrap_or_else(|| "provider error".into());
93                    return Err(AgentError::Other(err_msg));
94                }
95                AssistantMessageEvent::TextDelta { delta, .. } => {
96                    emit(&events, AgentEvent::TextDelta { delta });
97                }
98                AssistantMessageEvent::ThinkingDelta { delta, .. } => {
99                    emit(&events, AgentEvent::ThinkingDelta { delta });
100                }
101                _ => {}
102            }
103        }
104
105        let Some(msg) = final_message else {
106            return Err(AgentError::Other(
107                "provider stream produced no terminal event".into(),
108            ));
109        };
110
111        let assistant_message = Message::Assistant(msg.clone());
112        messages.push(assistant_message.clone());
113        emit(
114            &events,
115            AgentEvent::AssistantMessage {
116                message: assistant_message,
117            },
118        );
119
120        let tool_calls: Vec<(String, String, Value)> = msg
121            .content
122            .iter()
123            .filter_map(|c| match c {
124                Content::ToolCall {
125                    id,
126                    name,
127                    arguments,
128                } => Some((id.clone(), name.clone(), arguments.clone())),
129                _ => None,
130            })
131            .collect();
132
133        if tool_calls.is_empty() || stop != StopReason::ToolUse {
134            emit(&events, AgentEvent::TurnEnd);
135            break 'outer;
136        }
137
138        let mut any_terminate = !tool_calls.is_empty();
139        for (id, name, args) in tool_calls {
140            // Permission gate (only for tools that require it, and only once
141            // per name per run if the user said "allow session").
142            let tool_obj = tool_index.get(&name);
143            let needs_perm = tool_obj.map(|t| t.requires_permission()).unwrap_or(false)
144                && !session_allowed.contains(&name);
145            if needs_perm {
146                match config.permission.check(&name, &args).await {
147                    PermissionDecision::Allow => {}
148                    PermissionDecision::AllowSession => {
149                        session_allowed.insert(name.clone());
150                    }
151                    PermissionDecision::Deny { reason } => {
152                        emit(
153                            &events,
154                            AgentEvent::PermissionDenied {
155                                tool_name: name.clone(),
156                                reason: reason.clone(),
157                            },
158                        );
159                        let tr = ToolResultMessage {
160                            tool_call_id: id,
161                            tool_name: name,
162                            content: vec![Content::text(format!("permission denied: {reason}"))],
163                            is_error: true,
164                            timestamp: pi_ai::now_ms(),
165                        };
166                        messages.push(Message::ToolResult(tr));
167                        any_terminate = false;
168                        continue;
169                    }
170                }
171            }
172
173            emit(
174                &events,
175                AgentEvent::ToolExecutionStart {
176                    tool_call_id: id.clone(),
177                    tool_name: name.clone(),
178                    args: args.clone(),
179                },
180            );
181            let (content, is_error, terminate) = match tool_obj {
182                Some(tool) => match tool.execute(&id, args).await {
183                    Ok(AgentToolResult {
184                        content,
185                        details: _,
186                        terminate,
187                    }) => (content, false, terminate),
188                    Err(e) => (vec![Content::text(format!("tool error: {e}"))], true, false),
189                },
190                None => (
191                    vec![Content::text(format!("unknown tool: {name}"))],
192                    true,
193                    false,
194                ),
195            };
196            if !terminate {
197                any_terminate = false;
198            }
199            emit(
200                &events,
201                AgentEvent::ToolExecutionEnd {
202                    tool_call_id: id.clone(),
203                    tool_name: name.clone(),
204                    is_error,
205                    content: content.clone(),
206                },
207            );
208            let tr = ToolResultMessage {
209                tool_call_id: id,
210                tool_name: name,
211                content,
212                is_error,
213                timestamp: pi_ai::now_ms(),
214            };
215            messages.push(Message::ToolResult(tr));
216        }
217        emit(&events, AgentEvent::TurnEnd);
218        if any_terminate {
219            break;
220        }
221    }
222
223    if turn >= config.max_turns {
224        stopped_at_turn_limit = true;
225    }
226
227    emit(
228        &events,
229        AgentEvent::AgentEnd {
230            messages: messages.clone(),
231        },
232    );
233    Ok(AgentRun {
234        messages,
235        stopped_at_turn_limit,
236    })
237}
238
239fn emit(sink: &Option<mpsc::UnboundedSender<AgentEvent>>, ev: AgentEvent) {
240    if let Some(s) = sink {
241        let _ = s.send(ev);
242    }
243}