Skip to main content

hh_cli/core/agent/
mod.rs

1pub mod state;
2pub mod subagent_manager;
3
4pub use super::{AgentEvents, NoopEvents};
5
6use crate::core::{
7    ApprovalDecision, ApprovalPolicy, Message, Provider, ProviderRequest, ProviderStreamEvent,
8    QuestionAnswers, QuestionPrompt, Role, SessionReader, SessionSink, ToolCall, ToolExecutor,
9};
10use crate::safety::sanitize_tool_output;
11use crate::session::{SessionEvent, event_id};
12use crate::tool::ToolResult;
13use futures::stream::{FuturesUnordered, StreamExt};
14use serde::Serialize;
15use state::AgentState;
16use std::future::Future;
17
18pub struct AgentLoop<P, E, T, A, S>
19where
20    P: Provider,
21    E: AgentEvents,
22    T: ToolExecutor,
23    A: ApprovalPolicy,
24    S: SessionSink + SessionReader,
25{
26    pub provider: P,
27    pub tools: T,
28    pub approvals: A,
29    pub max_steps: usize,
30    pub model: String,
31    pub system_prompt: String,
32    pub session: S,
33    pub events: E,
34}
35
36impl<P, E, T, A, S> AgentLoop<P, E, T, A, S>
37where
38    P: Provider,
39    E: AgentEvents,
40    T: ToolExecutor,
41    A: ApprovalPolicy,
42    S: SessionSink + SessionReader,
43{
44    pub async fn run<F>(&self, prompt: Message, mut approve: F) -> anyhow::Result<String>
45    where
46        F: FnMut(&str) -> anyhow::Result<bool>,
47    {
48        self.run_with_question_tool(prompt, &mut approve, |_questions| async {
49            anyhow::bail!("question tool is unavailable in this mode; provide a question handler")
50        })
51        .await
52    }
53
54    pub async fn run_with_question_tool<F, Q, QFut>(
55        &self,
56        prompt: Message,
57        mut approve: F,
58        mut ask_question: Q,
59    ) -> anyhow::Result<String>
60    where
61        F: FnMut(&str) -> anyhow::Result<bool>,
62        Q: FnMut(Vec<QuestionPrompt>) -> QFut,
63        QFut: Future<Output = anyhow::Result<QuestionAnswers>> + Send,
64    {
65        let replayed_events = self.session.replay_events()?;
66        let mut state = AgentState {
67            messages: self.session.replay_messages()?,
68            todo_items: Vec::new(),
69            step: 0,
70        };
71
72        let mut tool_name_by_call_id = std::collections::HashMap::new();
73        for event in replayed_events {
74            match event {
75                SessionEvent::ToolCall { call } => {
76                    tool_name_by_call_id.insert(call.id, call.name);
77                }
78                SessionEvent::ToolResult { id, result, .. } => {
79                    if let (Some(name), Some(tool_result)) =
80                        (tool_name_by_call_id.get(&id), result.as_ref())
81                    {
82                        state.apply_tool_result(name, tool_result);
83                    }
84                }
85                _ => {}
86            }
87        }
88
89        if state
90            .messages
91            .iter()
92            .all(|message| message.role != Role::System)
93            && !self.system_prompt.trim().is_empty()
94        {
95            self.append_message(
96                &mut state,
97                Message {
98                    role: Role::System,
99                    content: self.system_prompt.clone(),
100                    attachments: Vec::new(),
101                    tool_call_id: None,
102                },
103            )?;
104        }
105
106        self.append_message(&mut state, prompt)?;
107
108        loop {
109            if self.max_steps > 0 && state.step >= self.max_steps {
110                anyhow::bail!("Reached max steps without final answer")
111            }
112
113            let mut request_messages = state.messages.clone();
114            if let Some(state_message) = state.state_for_llm() {
115                request_messages.push(state_message);
116            }
117
118            let req = ProviderRequest {
119                model: self.model.clone(),
120                messages: request_messages,
121                tools: self.tools.schemas(),
122            };
123
124            let mut assistant_content = String::new();
125            let mut thinking_content = String::new();
126            let response = self
127                .provider
128                .complete_stream(req, |event| match event {
129                    ProviderStreamEvent::AssistantDelta(delta) => {
130                        assistant_content.push_str(&delta);
131                        self.events.on_assistant_delta(&delta);
132                    }
133                    ProviderStreamEvent::ThinkingDelta(delta) => {
134                        thinking_content.push_str(&delta);
135                        self.events.on_thinking(&delta);
136                    }
137                })
138                .await?;
139
140            if let Some(tokens) = response.context_tokens {
141                self.events.on_context_usage(tokens);
142            }
143
144            if assistant_content.is_empty() {
145                assistant_content = response.assistant_message.content.clone();
146                if !assistant_content.is_empty() {
147                    self.events.on_assistant_delta(&assistant_content);
148                }
149            }
150
151            if thinking_content.is_empty()
152                && let Some(t) = &response.thinking
153            {
154                thinking_content = t.clone();
155            }
156
157            if !thinking_content.is_empty() {
158                self.session.append(&SessionEvent::Thinking {
159                    id: event_id(),
160                    content: thinking_content,
161                })?;
162            }
163
164            let assistant = Message {
165                role: Role::Assistant,
166                content: assistant_content.clone(),
167                attachments: Vec::new(),
168                tool_call_id: None,
169            };
170
171            self.append_message(&mut state, assistant.clone())?;
172
173            if response.done {
174                self.events.on_assistant_done();
175                return Ok(assistant_content);
176            }
177
178            let mut pending_non_blocking = FuturesUnordered::new();
179
180            for call in response.tool_calls {
181                self.session
182                    .append(&SessionEvent::ToolCall { call: call.clone() })?;
183
184                match self.approvals.decision_for_tool(&call.name) {
185                    ApprovalDecision::Deny => {
186                        let output = format!("tool denied: {}", call.name);
187                        self.record_tool_error(&call, output, &mut state)?;
188                        continue;
189                    }
190                    ApprovalDecision::Ask => {
191                        self.events.on_tool_start(&call.name, &call.arguments);
192                        let approved = approve(&call.name)?;
193                        self.session.append(&SessionEvent::Approval {
194                            id: event_id(),
195                            tool_name: call.name.clone(),
196                            approved,
197                        })?;
198                        if !approved {
199                            self.record_tool_error(
200                                &call,
201                                format!("tool approval denied: {}", call.name),
202                                &mut state,
203                            )?;
204                            continue;
205                        }
206                    }
207                    ApprovalDecision::Allow => {}
208                }
209
210                if call.name == "question" {
211                    self.events.on_tool_start(&call.name, &call.arguments);
212                    let result = self
213                        .execute_question_tool_call(&call, &mut ask_question)
214                        .await;
215                    self.events.on_tool_end(&call.name, &result);
216                    self.record_tool_result(&call, result, &mut state)?;
217                    continue;
218                }
219
220                if self.tools.is_non_blocking(&call.name) {
221                    let event_args = decorate_tool_start_args(&call.name, &call.arguments);
222                    self.events.on_tool_start(&call.name, &event_args);
223                    pending_non_blocking.push(async {
224                        let mut result =
225                            self.tools.execute(&call.name, call.arguments.clone()).await;
226                        result.output = sanitize_tool_output(&result.output);
227                        (call, result)
228                    });
229                    continue;
230                }
231
232                self.execute_tool_call(&call, &mut state).await?;
233            }
234
235            while let Some((call, result)) = pending_non_blocking.next().await {
236                self.events.on_tool_end(&call.name, &result);
237                self.record_tool_result(&call, result, &mut state)?;
238            }
239
240            state.step += 1;
241        }
242    }
243
244    async fn execute_question_tool_call<Q, QFut>(
245        &self,
246        call: &ToolCall,
247        ask_question: &mut Q,
248    ) -> ToolResult
249    where
250        Q: FnMut(Vec<QuestionPrompt>) -> QFut,
251        QFut: Future<Output = anyhow::Result<QuestionAnswers>> + Send,
252    {
253        let parsed = match crate::tool::question::parse_question_args(call.arguments.clone()) {
254            Ok(parsed) => parsed,
255            Err(err) => return ToolResult::err_text("invalid_question_args", err.to_string()),
256        };
257
258        match ask_question(parsed.questions.clone()).await {
259            Ok(answers) => crate::tool::question::question_result(&parsed.questions, answers),
260            Err(err) => ToolResult::err_text("question_dismissed", err.to_string()),
261        }
262    }
263
264    async fn execute_tool_call(
265        &self,
266        call: &ToolCall,
267        state: &mut AgentState,
268    ) -> anyhow::Result<()> {
269        let event_args = decorate_tool_start_args(&call.name, &call.arguments);
270        self.events.on_tool_start(&call.name, &event_args);
271        let mut result = if call.name == "todo_read" {
272            todo_snapshot_result(&state.todo_items)
273        } else {
274            self.tools.execute(&call.name, call.arguments.clone()).await
275        };
276        result.output = sanitize_tool_output(&result.output);
277        self.events.on_tool_end(&call.name, &result);
278        self.record_tool_result(call, result, state)
279    }
280
281    fn record_tool_error(
282        &self,
283        call: &ToolCall,
284        output: String,
285        state: &mut AgentState,
286    ) -> anyhow::Result<()> {
287        self.events.on_tool_start(&call.name, &call.arguments);
288        let result = ToolResult::err_text("denied", sanitize_tool_output(&output));
289        self.events.on_tool_end(&call.name, &result);
290        self.record_tool_result(call, result, state)
291    }
292
293    fn record_tool_result(
294        &self,
295        call: &ToolCall,
296        result: ToolResult,
297        state: &mut AgentState,
298    ) -> anyhow::Result<()> {
299        let call_id = call.id.clone();
300        state.push(Message {
301            role: Role::Tool,
302            content: result.output.clone(),
303            attachments: Vec::new(),
304            tool_call_id: Some(call_id.clone()),
305        });
306        if state.apply_tool_result(&call.name, &result) {
307            self.events.on_todo_items_changed(&state.todo_items);
308        }
309        self.session.append(&SessionEvent::ToolResult {
310            id: call_id,
311            is_error: result.is_error,
312            output: result.output.clone(),
313            result: Some(result),
314        })?;
315        Ok(())
316    }
317
318    fn append_message(&self, state: &mut AgentState, message: Message) -> anyhow::Result<()> {
319        state.push(message.clone());
320        self.session.append(&SessionEvent::Message {
321            id: event_id(),
322            message,
323        })
324    }
325}
326
327fn decorate_tool_start_args(name: &str, args: &serde_json::Value) -> serde_json::Value {
328    if name != "task" {
329        return args.clone();
330    }
331    let mut obj = args.as_object().cloned().unwrap_or_default();
332    let now = std::time::SystemTime::now()
333        .duration_since(std::time::UNIX_EPOCH)
334        .map_or(0, |d| d.as_secs());
335    obj.insert("__started_at".to_string(), serde_json::Value::from(now));
336    serde_json::Value::Object(obj)
337}
338
/// Per-status tallies reported alongside a todo list snapshot.
#[derive(Debug, Serialize)]
struct TodoSnapshotCounts {
    /// Number of items in the snapshot, regardless of status.
    total: usize,
    /// Items whose status is `Pending`.
    pending: usize,
    /// Items whose status is `InProgress`.
    in_progress: usize,
    /// Items whose status is `Completed`.
    completed: usize,
    /// Items whose status is `Cancelled`.
    cancelled: usize,
}
347
/// Serialized payload of a todo list snapshot: the items plus their per-status counts.
#[derive(Debug, Serialize)]
struct TodoSnapshotOutput {
    /// Copy of the todo items at the time the snapshot was taken.
    todos: Vec<crate::core::TodoItem>,
    /// Tallies computed from `todos`.
    counts: TodoSnapshotCounts,
}
353
354fn todo_snapshot_result(items: &[crate::core::TodoItem]) -> ToolResult {
355    let mut counts = TodoSnapshotCounts {
356        total: items.len(),
357        pending: 0,
358        in_progress: 0,
359        completed: 0,
360        cancelled: 0,
361    };
362
363    for item in items {
364        match item.status {
365            crate::core::TodoStatus::Pending => counts.pending += 1,
366            crate::core::TodoStatus::InProgress => counts.in_progress += 1,
367            crate::core::TodoStatus::Completed => counts.completed += 1,
368            crate::core::TodoStatus::Cancelled => counts.cancelled += 1,
369        }
370    }
371
372    let output = TodoSnapshotOutput {
373        todos: items.to_vec(),
374        counts,
375    };
376
377    ToolResult::ok_json_typed_serializable(
378        "todo list snapshot",
379        "application/vnd.hh.todo+json",
380        &output,
381    )
382}