Skip to main content

opi_agent/
lib.rs

1//! General-purpose agent runtime with tool calling and transport abstraction.
2//!
3//! Provides the foundation for building specialized agents with pluggable
4//! tool systems and communication transports.
5
6pub mod agent;
7pub mod compaction;
8pub mod event;
9pub mod hooks;
10pub mod loop_types;
11pub mod message;
12pub mod session;
13pub mod session_event;
14pub mod state;
15pub mod tool;
16pub mod transport;
17pub mod validation;
18
19pub use agent::Agent;
20pub use event::{AgentEvent, AgentEventSink};
21pub use hooks::AgentHooks;
22pub use loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
23pub use message::AgentMessage;
24pub use session_event::AgentSessionEvent;
25pub use state::AgentState;
26pub use tool::{ExecutionMode, Tool, ToolError, ToolResult};
27pub use transport::Transport;
28
29// Re-export provider-facing types needed at the agent boundary.
30pub use opi_ai::message::ToolDef;
31
32use std::collections::{HashMap, VecDeque};
33use std::sync::{Arc, Mutex};
34
35use futures_util::StreamExt;
36use hooks::{
37    AfterToolCallContext, AfterToolCallResult, BeforeToolCallContext, BeforeToolCallResult,
38    ShouldStopAfterTurnContext,
39};
40use opi_ai::message::{AssistantContent, InputContent, Message, ToolResultMessage, UserMessage};
41use opi_ai::provider::Request;
42use serde_json::json;
43use tokio_util::sync::CancellationToken;
44
45/// Run the agent loop until completion or cancellation.
46///
47/// The loop iterates: provider request → stream response → detect tool calls
48/// → validate and execute tools → send tool results back → repeat until no
49/// tool calls or stop condition.
50pub async fn agent_loop(
51    context: AgentLoopContext,
52    config: AgentLoopConfig,
53    hooks: &dyn AgentHooks,
54    events: AgentEventSink,
55    cancel: CancellationToken,
56) -> Result<Vec<AgentMessage>, AgentError> {
57    let tools_map: HashMap<String, &dyn Tool> = context
58        .tools
59        .iter()
60        .map(|t| (t.definition().name.clone(), t.as_ref()))
61        .collect();
62    let tool_defs: Vec<_> = context.tools.iter().map(|t| t.definition()).collect();
63
64    let mut messages = context.messages;
65
66    events(AgentEvent::AgentStart);
67
68    let mut has_tools_pending;
69    for turn_idx in 0..config.max_turns {
70        if cancel.is_cancelled() {
71            events(AgentEvent::AgentEnd {
72                messages: messages.clone(),
73            });
74            return Err(AgentError::Cancelled);
75        }
76
77        events(AgentEvent::TurnStart);
78
79        // H5: transform context before provider call
80        let transformed = hooks
81            .transform_context(messages.clone(), cancel.clone())
82            .await?;
83
84        // Convert messages for the provider
85        let llm_messages = hooks.convert_to_llm(&transformed)?;
86
87        // Stream the response (with optional retry on retryable errors)
88        let mut assistant_content: Vec<AssistantContent> = Vec::new();
89        has_tools_pending = false;
90        let mut retry_attempt: u32 = 0;
91        let max_attempts = config.retry.as_ref().map(|r| r.max_attempts).unwrap_or(0);
92
93        'stream: loop {
94            let request = Request {
95                model: context.model.clone(),
96                system: context.system.clone(),
97                messages: llm_messages.clone(),
98                tools: tool_defs.clone(),
99                max_tokens: config.max_tokens,
100                temperature: config.temperature,
101                thinking: config.thinking.clone().unwrap_or_default(),
102                stop_sequences: vec![],
103                metadata: None,
104                cancel: cancel.clone(),
105            };
106            let mut stream = context.provider.stream(request);
107            assistant_content.clear();
108
109            while let Some(item) = {
110                tokio::select! {
111                    biased;
112                    _ = cancel.cancelled() => {
113                        events(AgentEvent::AgentEnd {
114                            messages: messages.clone(),
115                        });
116                        return Err(AgentError::Cancelled);
117                    }
118                    item = stream.next() => item,
119                }
120            } {
121                match item {
122                    Ok(event) => {
123                        if let Some(msg) =
124                            process_stream_event(&event, &mut assistant_content, &events)
125                        {
126                            // Use the provider's complete content (includes thinking)
127                            // rather than the locally-accumulated assistant_content
128                            // which only tracks text and tool calls.
129                            let agent_msg = AgentMessage::Llm(Message::Assistant(msg));
130
131                            events(AgentEvent::MessageEnd {
132                                message: agent_msg.clone(),
133                            });
134
135                            messages.push(agent_msg.clone());
136
137                            // Extract tool calls from the provider's complete content.
138                            let content = match &agent_msg {
139                                AgentMessage::Llm(Message::Assistant(a)) => &a.content,
140                                _ => &Vec::new(),
141                            };
142                            let tool_calls: Vec<_> = content
143                                .iter()
144                                .filter_map(|c| match c {
145                                    AssistantContent::ToolCall { tool_call } => {
146                                        Some(tool_call.clone())
147                                    }
148                                    _ => None,
149                                })
150                                .collect();
151
152                            if !tool_calls.is_empty() {
153                                has_tools_pending = true;
154                                let mut tool_results = Vec::new();
155                                let mut terminate_flags = Vec::new();
156
157                                let batch_is_sequential = tool_calls.iter().any(|tc| {
158                                    tools_map
159                                        .get(tc.name.as_str())
160                                        .map(|t| t.execution_mode() == ExecutionMode::Sequential)
161                                        .unwrap_or(true)
162                                });
163
164                                if batch_is_sequential {
165                                    for tc in &tool_calls {
166                                        let args: serde_json::Value =
167                                            serde_json::from_str(&tc.arguments)
168                                                .unwrap_or(json!({}));
169
170                                        events(AgentEvent::ToolExecutionStart {
171                                            tool_call_id: tc.id.clone(),
172                                            tool_name: tc.name.clone(),
173                                            args: args.clone(),
174                                        });
175
176                                        let result = execute_tool(
177                                            &tc.id,
178                                            &tc.name,
179                                            &args,
180                                            &tools_map,
181                                            hooks,
182                                            &messages,
183                                            cancel.clone(),
184                                        )
185                                        .await;
186
187                                        let is_error = result.is_error;
188                                        let details = result.details.clone();
189                                        terminate_flags.push(result.terminate);
190                                        events(AgentEvent::ToolExecutionEnd {
191                                            tool_call_id: tc.id.clone(),
192                                            tool_name: tc.name.clone(),
193                                            result: serde_json::json!(&result.content),
194                                            details,
195                                            is_error,
196                                        });
197
198                                        let trm = ToolResultMessage {
199                                            tool_call_id: tc.id.clone(),
200                                            tool_name: tc.name.clone(),
201                                            content: result.content,
202                                            details: result.details,
203                                            is_error,
204                                            timestamp_ms: 0,
205                                        };
206                                        tool_results.push(trm.clone());
207                                        messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
208                                    }
209                                } else {
210                                    let tc_args: Vec<_> = tool_calls
211                                        .iter()
212                                        .map(|tc| {
213                                            let args: serde_json::Value =
214                                                serde_json::from_str(&tc.arguments)
215                                                    .unwrap_or(json!({}));
216                                            events(AgentEvent::ToolExecutionStart {
217                                                tool_call_id: tc.id.clone(),
218                                                tool_name: tc.name.clone(),
219                                                args: args.clone(),
220                                            });
221                                            (tc.clone(), args)
222                                        })
223                                        .collect();
224
225                                    let futures: Vec<_> = tc_args
226                                        .iter()
227                                        .map(|(tc, args)| {
228                                            let tools_map = &tools_map;
229                                            let messages = &messages;
230                                            let cancel = cancel.clone();
231                                            let tc_id = tc.id.clone();
232                                            let tc_name = tc.name.clone();
233                                            let args = args.clone();
234                                            async move {
235                                                let result = execute_tool(
236                                                    &tc_id, &tc_name, &args, tools_map, hooks,
237                                                    messages, cancel,
238                                                )
239                                                .await;
240                                                (tc_id, tc_name, result)
241                                            }
242                                        })
243                                        .collect();
244                                    let results = futures_util::future::join_all(futures).await;
245                                    for (tc_id, tc_name, result) in results {
246                                        let is_error = result.is_error;
247                                        let details = result.details.clone();
248                                        terminate_flags.push(result.terminate);
249                                        events(AgentEvent::ToolExecutionEnd {
250                                            tool_call_id: tc_id.clone(),
251                                            tool_name: tc_name.clone(),
252                                            result: serde_json::json!(&result.content),
253                                            details,
254                                            is_error,
255                                        });
256                                        let trm = ToolResultMessage {
257                                            tool_call_id: tc_id,
258                                            tool_name: tc_name,
259                                            content: result.content,
260                                            details: result.details,
261                                            is_error,
262                                            timestamp_ms: 0,
263                                        };
264                                        tool_results.push(trm.clone());
265                                        messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
266                                    }
267                                }
268
269                                let all_terminate = !terminate_flags.is_empty()
270                                    && terminate_flags.iter().all(|t| *t);
271
272                                events(AgentEvent::TurnEnd {
273                                    message: agent_msg,
274                                    tool_results: tool_results.clone(),
275                                });
276
277                                if all_terminate {
278                                    events(AgentEvent::AgentEnd {
279                                        messages: messages.clone(),
280                                    });
281                                    return Ok(messages);
282                                }
283
284                                let stop_ctx = ShouldStopAfterTurnContext {
285                                    messages: messages.clone(),
286                                    tool_results,
287                                };
288                                if hooks.should_stop_after_turn(stop_ctx).await {
289                                    events(AgentEvent::AgentEnd {
290                                        messages: messages.clone(),
291                                    });
292                                    return Ok(messages);
293                                }
294
295                                break 'stream;
296                            }
297
298                            events(AgentEvent::TurnEnd {
299                                message: agent_msg.clone(),
300                                tool_results: vec![],
301                            });
302
303                            let stop_ctx = ShouldStopAfterTurnContext {
304                                messages: messages.clone(),
305                                tool_results: vec![],
306                            };
307                            if hooks.should_stop_after_turn(stop_ctx).await {
308                                events(AgentEvent::AgentEnd {
309                                    messages: messages.clone(),
310                                });
311                                return Ok(messages);
312                            }
313                        }
314                    }
315                    Err(e) => {
316                        if e.is_retryable()
317                            && retry_attempt < max_attempts
318                            && let Some(ref rc) = config.retry
319                        {
320                            let retry_after_ms = match &e {
321                                opi_ai::provider::ProviderError::RateLimited { retry_after_ms } => {
322                                    *retry_after_ms
323                                }
324                                _ => None,
325                            };
326                            let delay_ms = rc.delay_for_attempt(retry_attempt, retry_after_ms);
327                            retry_attempt += 1;
328
329                            events(AgentEvent::AutoRetryStart {
330                                attempt: retry_attempt,
331                                max_attempts: rc.max_attempts,
332                                delay_ms,
333                                error_message: e.to_string(),
334                            });
335
336                            tokio::select! {
337                                biased;
338                                _ = cancel.cancelled() => {
339                                    events(AgentEvent::AgentEnd {
340                                        messages: messages.clone(),
341                                    });
342                                    return Err(AgentError::Cancelled);
343                                }
344                                _ = tokio::time::sleep(
345                                    std::time::Duration::from_millis(delay_ms)
346                                ) => {}
347                            }
348                            continue 'stream;
349                        }
350
351                        if retry_attempt > 0 {
352                            events(AgentEvent::AutoRetryEnd {
353                                success: false,
354                                attempt: retry_attempt,
355                                final_error: Some(e.to_string()),
356                            });
357                        }
358
359                        events(AgentEvent::AgentEnd {
360                            messages: messages.clone(),
361                        });
362                        return Err(match &e {
363                            opi_ai::provider::ProviderError::AuthFailed(msg) => {
364                                AgentError::AuthFailed(msg.clone())
365                            }
366                            _ => AgentError::Provider(e.to_string()),
367                        });
368                    }
369                }
370            }
371
372            if retry_attempt > 0 {
373                events(AgentEvent::AutoRetryEnd {
374                    success: true,
375                    attempt: retry_attempt,
376                    final_error: None,
377                });
378            }
379            break 'stream;
380        }
381
382        // -- Queue polling after turn completes --------------------------------
383
384        // H5: prepare_next_turn hook
385        let next_turn_ctx = hooks::PrepareNextTurnContext {
386            messages: messages.clone(),
387            turn: turn_idx + 1,
388        };
389        let mut hook_injected = false;
390        if let Some(update) = hooks.prepare_next_turn(next_turn_ctx).await
391            && !update.extra_messages.is_empty()
392        {
393            hook_injected = true;
394            messages.extend(update.extra_messages);
395        }
396
397        // Poll steering queue (drain all)
398        let steering = drain_queue(&context.steering_queue);
399        if !steering.is_empty() {
400            events(AgentEvent::QueueUpdate {
401                steering: steering.clone(),
402                follow_up: vec![],
403            });
404            for msg in steering {
405                messages.push(user_text_message(msg));
406            }
407            continue; // next turn
408        }
409
410        // If hook injected messages, continue so they reach the provider
411        if hook_injected {
412            continue;
413        }
414
415        // If no tools pending (agent would stop), poll follow-up (one at a time)
416        if !has_tools_pending {
417            let follow_up = pop_follow_up(&context.follow_up_queue);
418            if !follow_up.is_empty() {
419                events(AgentEvent::QueueUpdate {
420                    steering: vec![],
421                    follow_up: follow_up.clone(),
422                });
423                for msg in follow_up {
424                    messages.push(user_text_message(msg));
425                }
426                continue; // next turn
427            }
428            break; // no tools, no queues → agent stops
429        }
430
431        // Tools were executed and queues are empty → continue to next turn
432        let _ = turn_idx;
433    }
434
435    events(AgentEvent::AgentEnd {
436        messages: messages.clone(),
437    });
438    Ok(messages)
439}
440
441/// Process a single stream event, updating content and emitting message events.
442/// Returns Some(AssistantMessage) when a terminal event is received.
443fn process_stream_event(
444    event: &opi_ai::stream::AssistantStreamEvent,
445    content: &mut Vec<AssistantContent>,
446    events: &AgentEventSink,
447) -> Option<opi_ai::message::AssistantMessage> {
448    use opi_ai::stream::AssistantStreamEvent::*;
449
450    match event {
451        Start { partial } => {
452            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
453            events(AgentEvent::MessageStart { message: msg });
454            None
455        }
456        TextDelta { delta, partial, .. } => {
457            // Accumulate text into content vector
458            match content.last_mut() {
459                Some(AssistantContent::Text { text }) => {
460                    text.push_str(delta);
461                }
462                _ => {
463                    content.push(AssistantContent::Text {
464                        text: delta.clone(),
465                    });
466                }
467            }
468            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
469            events(AgentEvent::MessageUpdate {
470                message: msg,
471                assistant_event: Box::new(event.clone()),
472            });
473            None
474        }
475        ToolCallEnd { tool_call, .. } => {
476            content.push(AssistantContent::ToolCall {
477                tool_call: tool_call.clone(),
478            });
479            None
480        }
481        ThinkingStart { partial, .. }
482        | ThinkingDelta { partial, .. }
483        | ThinkingEnd { partial, .. } => {
484            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
485            events(AgentEvent::MessageUpdate {
486                message: msg,
487                assistant_event: Box::new(event.clone()),
488            });
489            None
490        }
491        Done { message, .. } => Some(message.clone()),
492        Error { message, .. } => Some(message.clone()),
493        _ => None,
494    }
495}
496
497/// Execute a single tool, with validation and hook integration.
498async fn execute_tool(
499    call_id: &str,
500    tool_name: &str,
501    args: &serde_json::Value,
502    tools_map: &HashMap<String, &dyn Tool>,
503    hooks: &dyn AgentHooks,
504    messages: &[AgentMessage],
505    cancel: CancellationToken,
506) -> ToolResult {
507    let tool = match tools_map.get(tool_name) {
508        Some(t) => *t,
509        None => {
510            return ToolResult {
511                content: vec![opi_ai::message::OutputContent::Text {
512                    text: format!("unknown tool: {tool_name}"),
513                }],
514                details: None,
515                is_error: true,
516                terminate: false,
517            };
518        }
519    };
520
521    // Validate arguments against schema
522    let schema = &tool.definition().input_schema;
523    if let Err(err) = validation::validate(schema, args) {
524        return ToolResult::from_validation_error(err);
525    }
526
527    // Run before_tool_call hook
528    let ctx = BeforeToolCallContext {
529        tool_call_id: call_id.to_owned(),
530        tool_name: tool_name.to_owned(),
531        args: args.clone(),
532        messages: messages.to_vec(),
533    };
534    match hooks.before_tool_call(ctx).await {
535        BeforeToolCallResult::Allow => {}
536        BeforeToolCallResult::Deny { reason } => {
537            return ToolResult {
538                content: vec![opi_ai::message::OutputContent::Text { text: reason }],
539                details: None,
540                is_error: true,
541                terminate: false,
542            };
543        }
544    }
545
546    // Execute the tool
547    match tool.execute(call_id, args.clone(), cancel, None).await {
548        Ok(result) => {
549            let ctx = AfterToolCallContext {
550                tool_call_id: call_id.to_owned(),
551                tool_name: tool_name.to_owned(),
552                result: result.clone(),
553            };
554            match hooks.after_tool_call(ctx).await {
555                AfterToolCallResult::Keep => result,
556                AfterToolCallResult::Replace(replacement) => replacement,
557            }
558        }
559        Err(e) => ToolResult {
560            content: vec![opi_ai::message::OutputContent::Text {
561                text: e.to_string(),
562            }],
563            details: None,
564            is_error: true,
565            terminate: false,
566        },
567    }
568}
569
/// Drain all messages from a queue (steering mode: All).
///
/// Returns the queued messages in FIFO order and leaves the queue empty. Returns
/// an empty `Vec` when no queue is configured.
///
/// A poisoned lock (another thread panicked while holding it) is recovered
/// instead of panicking here. The queue holds plain owned strings with no
/// cross-item invariants to protect, so aborting the agent loop over it would
/// only turn an unrelated panic into a second one.
fn drain_queue(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(queue) = queue else {
        return Vec::new();
    };
    queue
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .drain(..)
        .collect()
}
580
/// Pop one message from a queue (follow-up mode: OneAtATime).
///
/// Returns a `Vec` holding the oldest queued message. The `Vec` is empty when
/// the queue is empty or not configured.
///
/// A poisoned lock (another thread panicked while holding it) is recovered
/// instead of panicking here. The queue holds plain owned strings with no
/// cross-item invariants to protect.
fn pop_follow_up(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(queue) = queue else {
        return Vec::new();
    };
    queue
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .pop_front()
        .into_iter()
        .collect()
}
594
595/// Create a user text AgentMessage.
596fn user_text_message(text: String) -> AgentMessage {
597    AgentMessage::Llm(Message::User(UserMessage {
598        content: vec![InputContent::Text { text }],
599        timestamp_ms: 0,
600    }))
601}