Skip to main content

opi_agent/
lib.rs

1//! General-purpose agent runtime with tool calling and transport abstraction.
2//!
3//! Provides the foundation for building specialized agents with pluggable
4//! tool systems and communication transports.
5
6pub mod agent;
7pub mod event;
8pub mod hooks;
9pub mod loop_types;
10pub mod message;
11pub mod state;
12pub mod tool;
13pub mod transport;
14pub mod validation;
15
16pub use agent::Agent;
17pub use event::{AgentEvent, AgentEventSink};
18pub use hooks::AgentHooks;
19pub use loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
20pub use message::AgentMessage;
21pub use state::AgentState;
22pub use tool::{ExecutionMode, Tool, ToolError, ToolResult};
23pub use transport::Transport;
24
25// Re-export provider-facing types needed at the agent boundary.
26pub use opi_ai::message::ToolDef;
27
28use std::collections::{HashMap, VecDeque};
29use std::sync::{Arc, Mutex};
30
31use futures_util::StreamExt;
32use hooks::{
33    AfterToolCallContext, AfterToolCallResult, BeforeToolCallContext, BeforeToolCallResult,
34    ShouldStopAfterTurnContext,
35};
36use opi_ai::message::{AssistantContent, InputContent, Message, ToolResultMessage, UserMessage};
37use opi_ai::provider::Request;
38use serde_json::json;
39use tokio_util::sync::CancellationToken;
40
41/// Run the agent loop until completion or cancellation.
42///
43/// The loop iterates: provider request → stream response → detect tool calls
44/// → validate and execute tools → send tool results back → repeat until no
45/// tool calls or stop condition.
46pub async fn agent_loop(
47    context: AgentLoopContext,
48    config: AgentLoopConfig,
49    hooks: &dyn AgentHooks,
50    events: AgentEventSink,
51    cancel: CancellationToken,
52) -> Result<Vec<AgentMessage>, AgentError> {
53    let tools_map: HashMap<String, &dyn Tool> = context
54        .tools
55        .iter()
56        .map(|t| (t.definition().name.clone(), t.as_ref()))
57        .collect();
58    let tool_defs: Vec<_> = context.tools.iter().map(|t| t.definition()).collect();
59
60    let mut messages = context.messages;
61
62    events(AgentEvent::AgentStart);
63
64    let mut has_tools_pending;
65    for turn_idx in 0..config.max_turns {
66        if cancel.is_cancelled() {
67            events(AgentEvent::AgentEnd {
68                messages: messages.clone(),
69            });
70            return Err(AgentError::Cancelled);
71        }
72
73        events(AgentEvent::TurnStart);
74
75        // H5: transform context before provider call
76        let transformed = hooks
77            .transform_context(messages.clone(), cancel.clone())
78            .await?;
79
80        // Convert messages for the provider
81        let llm_messages = hooks.convert_to_llm(&transformed)?;
82
83        // Build the provider request
84        let request = Request {
85            model: context.model.clone(),
86            system: context.system.clone(),
87            messages: llm_messages,
88            tools: tool_defs.clone(),
89            max_tokens: config.max_tokens,
90            temperature: config.temperature,
91            thinking: Default::default(),
92            stop_sequences: vec![],
93            metadata: None,
94            cancel: cancel.clone(),
95        };
96
97        // Stream the response
98        let mut stream = context.provider.stream(request);
99        let mut assistant_content: Vec<AssistantContent> = Vec::new();
100        has_tools_pending = false;
101
102        while let Some(item) = {
103            tokio::select! {
104                biased;
105                _ = cancel.cancelled() => {
106                    events(AgentEvent::AgentEnd {
107                        messages: messages.clone(),
108                    });
109                    return Err(AgentError::Cancelled);
110                }
111                item = stream.next() => item,
112            }
113        } {
114            match item {
115                Ok(event) => {
116                    if let Some(msg) = process_stream_event(&event, &mut assistant_content, &events)
117                    {
118                        // Build the assistant message from accumulated content
119                        let mut assistant_msg = msg;
120                        assistant_msg.content = assistant_content.clone();
121                        let agent_msg = AgentMessage::Llm(Message::Assistant(assistant_msg));
122
123                        events(AgentEvent::MessageEnd {
124                            message: agent_msg.clone(),
125                        });
126
127                        messages.push(agent_msg.clone());
128
129                        // Check for tool calls
130                        let tool_calls: Vec<_> = assistant_content
131                            .iter()
132                            .filter_map(|c| match c {
133                                AssistantContent::ToolCall { tool_call } => Some(tool_call.clone()),
134                                _ => None,
135                            })
136                            .collect();
137
138                        if !tool_calls.is_empty() {
139                            has_tools_pending = true;
140                            let mut tool_results = Vec::new();
141                            let mut terminate_flags = Vec::new();
142
143                            // Determine batch execution mode (H3):
144                            // parallel by default; any sequential tool forces serial
145                            let batch_is_sequential = tool_calls.iter().any(|tc| {
146                                tools_map
147                                    .get(tc.name.as_str())
148                                    .map(|t| t.execution_mode() == ExecutionMode::Sequential)
149                                    .unwrap_or(true)
150                            });
151
152                            if batch_is_sequential {
153                                for tc in &tool_calls {
154                                    let args: serde_json::Value =
155                                        serde_json::from_str(&tc.arguments).unwrap_or(json!({}));
156
157                                    events(AgentEvent::ToolExecutionStart {
158                                        tool_call_id: tc.id.clone(),
159                                        tool_name: tc.name.clone(),
160                                        args: args.clone(),
161                                    });
162
163                                    let result = execute_tool(
164                                        &tc.id,
165                                        &tc.name,
166                                        &args,
167                                        &tools_map,
168                                        hooks,
169                                        &messages,
170                                        cancel.clone(),
171                                    )
172                                    .await;
173
174                                    let is_error = result.is_error;
175                                    terminate_flags.push(result.terminate);
176                                    events(AgentEvent::ToolExecutionEnd {
177                                        tool_call_id: tc.id.clone(),
178                                        tool_name: tc.name.clone(),
179                                        result: serde_json::json!(&result.content),
180                                        is_error,
181                                    });
182
183                                    let trm = ToolResultMessage {
184                                        tool_call_id: tc.id.clone(),
185                                        tool_name: tc.name.clone(),
186                                        content: result.content,
187                                        details: result.details,
188                                        is_error,
189                                        timestamp_ms: 0,
190                                    };
191                                    tool_results.push(trm.clone());
192                                    messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
193                                }
194                            } else {
195                                // Parallel execution — emit Start events before spawning
196                                let tc_args: Vec<_> = tool_calls
197                                    .iter()
198                                    .map(|tc| {
199                                        let args: serde_json::Value =
200                                            serde_json::from_str(&tc.arguments)
201                                                .unwrap_or(json!({}));
202                                        events(AgentEvent::ToolExecutionStart {
203                                            tool_call_id: tc.id.clone(),
204                                            tool_name: tc.name.clone(),
205                                            args: args.clone(),
206                                        });
207                                        (tc.clone(), args)
208                                    })
209                                    .collect();
210
211                                let futures: Vec<_> = tc_args
212                                    .iter()
213                                    .map(|(tc, args)| {
214                                        let tools_map = &tools_map;
215                                        let messages = &messages;
216                                        let cancel = cancel.clone();
217                                        let tc_id = tc.id.clone();
218                                        let tc_name = tc.name.clone();
219                                        let args = args.clone();
220                                        async move {
221                                            let result = execute_tool(
222                                                &tc_id, &tc_name, &args, tools_map, hooks,
223                                                messages, cancel,
224                                            )
225                                            .await;
226                                            (tc_id, tc_name, result)
227                                        }
228                                    })
229                                    .collect();
230                                let results = futures_util::future::join_all(futures).await;
231                                for (tc_id, tc_name, result) in results {
232                                    let is_error = result.is_error;
233                                    terminate_flags.push(result.terminate);
234                                    events(AgentEvent::ToolExecutionEnd {
235                                        tool_call_id: tc_id.clone(),
236                                        tool_name: tc_name.clone(),
237                                        result: serde_json::json!(&result.content),
238                                        is_error,
239                                    });
240                                    let trm = ToolResultMessage {
241                                        tool_call_id: tc_id,
242                                        tool_name: tc_name,
243                                        content: result.content,
244                                        details: result.details,
245                                        is_error,
246                                        timestamp_ms: 0,
247                                    };
248                                    tool_results.push(trm.clone());
249                                    messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
250                                }
251                            }
252
253                            // H4: early stop if ALL results have terminate=true
254                            let all_terminate =
255                                !terminate_flags.is_empty() && terminate_flags.iter().all(|t| *t);
256
257                            events(AgentEvent::TurnEnd {
258                                message: agent_msg,
259                                tool_results: tool_results.clone(),
260                            });
261
262                            if all_terminate {
263                                events(AgentEvent::AgentEnd {
264                                    messages: messages.clone(),
265                                });
266                                return Ok(messages);
267                            }
268
269                            // M1: pass only current turn's tool_results
270                            let stop_ctx = ShouldStopAfterTurnContext {
271                                messages: messages.clone(),
272                                tool_results,
273                            };
274                            if hooks.should_stop_after_turn(stop_ctx).await {
275                                events(AgentEvent::AgentEnd {
276                                    messages: messages.clone(),
277                                });
278                                return Ok(messages);
279                            }
280
281                            // Break inner loop; outer for loop continues
282                            break;
283                        }
284
285                        // No tool calls — this turn is done
286                        events(AgentEvent::TurnEnd {
287                            message: agent_msg.clone(),
288                            tool_results: vec![],
289                        });
290
291                        // M1: no tool results for a text-only turn
292                        let stop_ctx = ShouldStopAfterTurnContext {
293                            messages: messages.clone(),
294                            tool_results: vec![],
295                        };
296                        if hooks.should_stop_after_turn(stop_ctx).await {
297                            events(AgentEvent::AgentEnd {
298                                messages: messages.clone(),
299                            });
300                            return Ok(messages);
301                        }
302                    }
303                }
304                Err(e) => {
305                    events(AgentEvent::AgentEnd {
306                        messages: messages.clone(),
307                    });
308                    return Err(match &e {
309                        opi_ai::provider::ProviderError::AuthFailed(msg) => {
310                            AgentError::AuthFailed(msg.clone())
311                        }
312                        _ => AgentError::Provider(e.to_string()),
313                    });
314                }
315            }
316        }
317
318        // -- Queue polling after turn completes --------------------------------
319
320        // H5: prepare_next_turn hook
321        let next_turn_ctx = hooks::PrepareNextTurnContext {
322            messages: messages.clone(),
323            turn: turn_idx + 1,
324        };
325        let mut hook_injected = false;
326        if let Some(update) = hooks.prepare_next_turn(next_turn_ctx).await
327            && !update.extra_messages.is_empty()
328        {
329            hook_injected = true;
330            messages.extend(update.extra_messages);
331        }
332
333        // Poll steering queue (drain all)
334        let steering = drain_queue(&context.steering_queue);
335        if !steering.is_empty() {
336            events(AgentEvent::QueueUpdate {
337                steering: steering.clone(),
338                follow_up: vec![],
339            });
340            for msg in steering {
341                messages.push(user_text_message(msg));
342            }
343            continue; // next turn
344        }
345
346        // If hook injected messages, continue so they reach the provider
347        if hook_injected {
348            continue;
349        }
350
351        // If no tools pending (agent would stop), poll follow-up (one at a time)
352        if !has_tools_pending {
353            let follow_up = pop_follow_up(&context.follow_up_queue);
354            if !follow_up.is_empty() {
355                events(AgentEvent::QueueUpdate {
356                    steering: vec![],
357                    follow_up: follow_up.clone(),
358                });
359                for msg in follow_up {
360                    messages.push(user_text_message(msg));
361                }
362                continue; // next turn
363            }
364            break; // no tools, no queues → agent stops
365        }
366
367        // Tools were executed and queues are empty → continue to next turn
368        let _ = turn_idx;
369    }
370
371    events(AgentEvent::AgentEnd {
372        messages: messages.clone(),
373    });
374    Ok(messages)
375}
376
377/// Process a single stream event, updating content and emitting message events.
378/// Returns Some(AssistantMessage) when a terminal event is received.
379fn process_stream_event(
380    event: &opi_ai::stream::AssistantStreamEvent,
381    content: &mut Vec<AssistantContent>,
382    events: &AgentEventSink,
383) -> Option<opi_ai::message::AssistantMessage> {
384    use opi_ai::stream::AssistantStreamEvent::*;
385
386    match event {
387        Start { partial } => {
388            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
389            events(AgentEvent::MessageStart { message: msg });
390            None
391        }
392        TextDelta { delta, partial, .. } => {
393            // Accumulate text into content vector
394            match content.last_mut() {
395                Some(AssistantContent::Text { text }) => {
396                    text.push_str(delta);
397                }
398                _ => {
399                    content.push(AssistantContent::Text {
400                        text: delta.clone(),
401                    });
402                }
403            }
404            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
405            events(AgentEvent::MessageUpdate {
406                message: msg,
407                assistant_event: Box::new(event.clone()),
408            });
409            None
410        }
411        ToolCallEnd { tool_call, .. } => {
412            content.push(AssistantContent::ToolCall {
413                tool_call: tool_call.clone(),
414            });
415            None
416        }
417        Done { message, .. } => Some(message.clone()),
418        Error { message, .. } => Some(message.clone()),
419        _ => None,
420    }
421}
422
423/// Execute a single tool, with validation and hook integration.
424async fn execute_tool(
425    call_id: &str,
426    tool_name: &str,
427    args: &serde_json::Value,
428    tools_map: &HashMap<String, &dyn Tool>,
429    hooks: &dyn AgentHooks,
430    messages: &[AgentMessage],
431    cancel: CancellationToken,
432) -> ToolResult {
433    let tool = match tools_map.get(tool_name) {
434        Some(t) => *t,
435        None => {
436            return ToolResult {
437                content: vec![opi_ai::message::OutputContent::Text {
438                    text: format!("unknown tool: {tool_name}"),
439                }],
440                details: None,
441                is_error: true,
442                terminate: false,
443            };
444        }
445    };
446
447    // Validate arguments against schema
448    let schema = &tool.definition().input_schema;
449    if let Err(err) = validation::validate(schema, args) {
450        return ToolResult::from_validation_error(err);
451    }
452
453    // Run before_tool_call hook
454    let ctx = BeforeToolCallContext {
455        tool_call_id: call_id.to_owned(),
456        tool_name: tool_name.to_owned(),
457        args: args.clone(),
458        messages: messages.to_vec(),
459    };
460    match hooks.before_tool_call(ctx).await {
461        BeforeToolCallResult::Allow => {}
462        BeforeToolCallResult::Deny { reason } => {
463            return ToolResult {
464                content: vec![opi_ai::message::OutputContent::Text { text: reason }],
465                details: None,
466                is_error: true,
467                terminate: false,
468            };
469        }
470    }
471
472    // Execute the tool
473    match tool.execute(call_id, args.clone(), cancel, None).await {
474        Ok(result) => {
475            let ctx = AfterToolCallContext {
476                tool_call_id: call_id.to_owned(),
477                tool_name: tool_name.to_owned(),
478                result: result.clone(),
479            };
480            match hooks.after_tool_call(ctx).await {
481                AfterToolCallResult::Keep => result,
482                AfterToolCallResult::Replace(replacement) => replacement,
483            }
484        }
485        Err(e) => ToolResult {
486            content: vec![opi_ai::message::OutputContent::Text {
487                text: e.to_string(),
488            }],
489            details: None,
490            is_error: true,
491            terminate: false,
492        },
493    }
494}
495
/// Take every pending message from `queue` in front-to-back order (steering mode: All).
///
/// Returns an empty vector when no queue is configured or the queue is empty.
///
/// A poisoned lock (another holder panicked while holding it) is recovered
/// instead of propagated. The queue holds plain `String`s, so its contents are
/// still usable. Unwrapping would instead panic the agent loop on this and
/// every later turn.
fn drain_queue(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    match queue {
        // Guard and Drain are temporaries of this arm; the Drain drops first.
        Some(q) => q
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .drain(..)
            .collect(),
        None => Vec::new(),
    }
}
506
/// Pop at most one message from the front of `queue` (follow-up mode: OneAtATime).
///
/// Returns a vector holding that single message, or an empty vector when no
/// queue is configured or the queue is empty.
///
/// As with the steering queue, a poisoned lock is recovered instead of panicking
/// the agent loop. The `VecDeque<String>` is still consistent after a holder
/// panics.
fn pop_follow_up(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    match queue {
        Some(q) => q
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .pop_front()
            .into_iter()
            .collect(),
        None => Vec::new(),
    }
}
520
521/// Create a user text AgentMessage.
522fn user_text_message(text: String) -> AgentMessage {
523    AgentMessage::Llm(Message::User(UserMessage {
524        content: vec![InputContent::Text { text }],
525        timestamp_ms: 0,
526    }))
527}