Skip to main content

opi_agent/
lib.rs

1//! General-purpose agent runtime with tool calling and transport abstraction.
2//!
3//! Provides the foundation for building specialized agents with pluggable
4//! tool systems and communication transports.
5
6pub mod agent;
7pub mod compaction;
8pub mod event;
9pub mod hooks;
10pub mod loop_types;
11pub mod message;
12pub mod session;
13pub mod session_event;
14pub mod state;
15pub mod tool;
16pub mod transport;
17pub mod validation;
18
19pub use agent::Agent;
20pub use event::{AgentEvent, AgentEventSink};
21pub use hooks::AgentHooks;
22pub use loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
23pub use message::AgentMessage;
24pub use session_event::AgentSessionEvent;
25pub use state::AgentState;
26pub use tool::{ExecutionMode, Tool, ToolError, ToolResult};
27pub use transport::Transport;
28
29// Re-export provider-facing types needed at the agent boundary.
30pub use opi_ai::message::ToolDef;
31
32use std::collections::{HashMap, VecDeque};
33use std::sync::{Arc, Mutex};
34
35use futures_util::StreamExt;
36use hooks::{
37    AfterToolCallContext, AfterToolCallResult, BeforeToolCallContext, BeforeToolCallResult,
38    ShouldStopAfterTurnContext,
39};
40use opi_ai::message::{AssistantContent, InputContent, Message, ToolResultMessage, UserMessage};
41use opi_ai::provider::{Request, validate_request_capabilities};
42use serde_json::json;
43use tokio_util::sync::CancellationToken;
44
45/// Run the agent loop until completion or cancellation.
46///
47/// The loop iterates: provider request → stream response → detect tool calls
48/// → validate and execute tools → send tool results back → repeat until no
49/// tool calls or stop condition.
50pub async fn agent_loop(
51    context: AgentLoopContext,
52    config: AgentLoopConfig,
53    hooks: &dyn AgentHooks,
54    events: AgentEventSink,
55    cancel: CancellationToken,
56) -> Result<Vec<AgentMessage>, AgentError> {
57    let tools_map: HashMap<String, &dyn Tool> = context
58        .tools
59        .iter()
60        .map(|t| (t.definition().name.clone(), t.as_ref()))
61        .collect();
62    let tool_defs: Vec<_> = context.tools.iter().map(|t| t.definition()).collect();
63
64    let mut messages = context.messages;
65
66    events(AgentEvent::AgentStart);
67
68    let mut has_tools_pending;
69    for turn_idx in 0..config.max_turns {
70        if cancel.is_cancelled() {
71            events(AgentEvent::AgentEnd {
72                messages: messages.clone(),
73            });
74            return Err(AgentError::Cancelled);
75        }
76
77        events(AgentEvent::TurnStart);
78
79        // H5: transform context before provider call
80        let transformed = hooks
81            .transform_context(messages.clone(), cancel.clone())
82            .await?;
83
84        // Convert messages for the provider
85        let llm_messages = hooks.convert_to_llm(&transformed)?;
86
87        // Stream the response (with optional retry on retryable errors)
88        let mut assistant_content: Vec<AssistantContent> = Vec::new();
89        has_tools_pending = false;
90        let mut retry_attempt: u32 = 0;
91        let max_attempts = config.retry.as_ref().map(|r| r.max_attempts).unwrap_or(0);
92
93        'stream: loop {
94            let request = Request {
95                model: context.model.clone(),
96                system: context.system.clone(),
97                messages: llm_messages.clone(),
98                tools: tool_defs.clone(),
99                max_tokens: config.max_tokens,
100                temperature: config.temperature,
101                thinking: config.thinking.clone().unwrap_or_default(),
102                stop_sequences: vec![],
103                metadata: None,
104                cancel: cancel.clone(),
105            };
106            if let Err(e) = validate_request_capabilities(context.provider.as_ref(), &request) {
107                events(AgentEvent::AgentEnd {
108                    messages: messages.clone(),
109                });
110                return Err(AgentError::Provider(e.to_string()));
111            }
112            let mut stream = context.provider.stream(request);
113            assistant_content.clear();
114
115            while let Some(item) = {
116                tokio::select! {
117                    biased;
118                    _ = cancel.cancelled() => {
119                        events(AgentEvent::AgentEnd {
120                            messages: messages.clone(),
121                        });
122                        return Err(AgentError::Cancelled);
123                    }
124                    item = stream.next() => item,
125                }
126            } {
127                match item {
128                    Ok(event) => {
129                        if let Some(msg) =
130                            process_stream_event(&event, &mut assistant_content, &events)
131                        {
132                            // Use the provider's complete content (includes thinking)
133                            // rather than the locally-accumulated assistant_content
134                            // which only tracks text and tool calls.
135                            let agent_msg = AgentMessage::Llm(Message::Assistant(msg));
136
137                            events(AgentEvent::MessageEnd {
138                                message: agent_msg.clone(),
139                            });
140
141                            messages.push(agent_msg.clone());
142
143                            // Extract tool calls from the provider's complete content.
144                            let content = match &agent_msg {
145                                AgentMessage::Llm(Message::Assistant(a)) => &a.content,
146                                _ => &Vec::new(),
147                            };
148                            let tool_calls: Vec<_> = content
149                                .iter()
150                                .filter_map(|c| match c {
151                                    AssistantContent::ToolCall { tool_call } => {
152                                        Some(tool_call.clone())
153                                    }
154                                    _ => None,
155                                })
156                                .collect();
157
158                            if !tool_calls.is_empty() {
159                                has_tools_pending = true;
160                                let mut tool_results = Vec::new();
161                                let mut terminate_flags = Vec::new();
162
163                                let batch_is_sequential = tool_calls.iter().any(|tc| {
164                                    tools_map
165                                        .get(tc.name.as_str())
166                                        .map(|t| t.execution_mode() == ExecutionMode::Sequential)
167                                        .unwrap_or(true)
168                                });
169
170                                if batch_is_sequential {
171                                    for tc in &tool_calls {
172                                        let args: serde_json::Value =
173                                            serde_json::from_str(&tc.arguments)
174                                                .unwrap_or(json!({}));
175
176                                        events(AgentEvent::ToolExecutionStart {
177                                            tool_call_id: tc.id.clone(),
178                                            tool_name: tc.name.clone(),
179                                            args: args.clone(),
180                                        });
181
182                                        let result = execute_tool(
183                                            &tc.id,
184                                            &tc.name,
185                                            &args,
186                                            &tools_map,
187                                            hooks,
188                                            &messages,
189                                            cancel.clone(),
190                                        )
191                                        .await;
192
193                                        let is_error = result.is_error;
194                                        let details = result.details.clone();
195                                        terminate_flags.push(result.terminate);
196                                        events(AgentEvent::ToolExecutionEnd {
197                                            tool_call_id: tc.id.clone(),
198                                            tool_name: tc.name.clone(),
199                                            result: serde_json::json!(&result.content),
200                                            details,
201                                            is_error,
202                                        });
203
204                                        let trm = ToolResultMessage {
205                                            tool_call_id: tc.id.clone(),
206                                            tool_name: tc.name.clone(),
207                                            content: result.content,
208                                            details: result.details,
209                                            is_error,
210                                            timestamp_ms: 0,
211                                        };
212                                        tool_results.push(trm.clone());
213                                        messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
214                                    }
215                                } else {
216                                    let tc_args: Vec<_> = tool_calls
217                                        .iter()
218                                        .map(|tc| {
219                                            let args: serde_json::Value =
220                                                serde_json::from_str(&tc.arguments)
221                                                    .unwrap_or(json!({}));
222                                            events(AgentEvent::ToolExecutionStart {
223                                                tool_call_id: tc.id.clone(),
224                                                tool_name: tc.name.clone(),
225                                                args: args.clone(),
226                                            });
227                                            (tc.clone(), args)
228                                        })
229                                        .collect();
230
231                                    let futures: Vec<_> = tc_args
232                                        .iter()
233                                        .map(|(tc, args)| {
234                                            let tools_map = &tools_map;
235                                            let messages = &messages;
236                                            let cancel = cancel.clone();
237                                            let tc_id = tc.id.clone();
238                                            let tc_name = tc.name.clone();
239                                            let args = args.clone();
240                                            async move {
241                                                let result = execute_tool(
242                                                    &tc_id, &tc_name, &args, tools_map, hooks,
243                                                    messages, cancel,
244                                                )
245                                                .await;
246                                                (tc_id, tc_name, result)
247                                            }
248                                        })
249                                        .collect();
250                                    let results = futures_util::future::join_all(futures).await;
251                                    for (tc_id, tc_name, result) in results {
252                                        let is_error = result.is_error;
253                                        let details = result.details.clone();
254                                        terminate_flags.push(result.terminate);
255                                        events(AgentEvent::ToolExecutionEnd {
256                                            tool_call_id: tc_id.clone(),
257                                            tool_name: tc_name.clone(),
258                                            result: serde_json::json!(&result.content),
259                                            details,
260                                            is_error,
261                                        });
262                                        let trm = ToolResultMessage {
263                                            tool_call_id: tc_id,
264                                            tool_name: tc_name,
265                                            content: result.content,
266                                            details: result.details,
267                                            is_error,
268                                            timestamp_ms: 0,
269                                        };
270                                        tool_results.push(trm.clone());
271                                        messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
272                                    }
273                                }
274
275                                let all_terminate = !terminate_flags.is_empty()
276                                    && terminate_flags.iter().all(|t| *t);
277
278                                events(AgentEvent::TurnEnd {
279                                    message: agent_msg,
280                                    tool_results: tool_results.clone(),
281                                });
282
283                                if all_terminate {
284                                    events(AgentEvent::AgentEnd {
285                                        messages: messages.clone(),
286                                    });
287                                    return Ok(messages);
288                                }
289
290                                let stop_ctx = ShouldStopAfterTurnContext {
291                                    messages: messages.clone(),
292                                    tool_results,
293                                };
294                                if hooks.should_stop_after_turn(stop_ctx).await {
295                                    events(AgentEvent::AgentEnd {
296                                        messages: messages.clone(),
297                                    });
298                                    return Ok(messages);
299                                }
300
301                                break 'stream;
302                            }
303
304                            events(AgentEvent::TurnEnd {
305                                message: agent_msg.clone(),
306                                tool_results: vec![],
307                            });
308
309                            let stop_ctx = ShouldStopAfterTurnContext {
310                                messages: messages.clone(),
311                                tool_results: vec![],
312                            };
313                            if hooks.should_stop_after_turn(stop_ctx).await {
314                                events(AgentEvent::AgentEnd {
315                                    messages: messages.clone(),
316                                });
317                                return Ok(messages);
318                            }
319                        }
320                    }
321                    Err(e) => {
322                        if e.is_retryable()
323                            && retry_attempt < max_attempts
324                            && let Some(ref rc) = config.retry
325                        {
326                            let retry_after_ms = match &e {
327                                opi_ai::provider::ProviderError::RateLimited { retry_after_ms } => {
328                                    *retry_after_ms
329                                }
330                                _ => None,
331                            };
332                            let delay_ms = rc.delay_for_attempt(retry_attempt, retry_after_ms);
333                            retry_attempt += 1;
334
335                            events(AgentEvent::AutoRetryStart {
336                                attempt: retry_attempt,
337                                max_attempts: rc.max_attempts,
338                                delay_ms,
339                                error_message: e.to_string(),
340                            });
341
342                            tokio::select! {
343                                biased;
344                                _ = cancel.cancelled() => {
345                                    events(AgentEvent::AgentEnd {
346                                        messages: messages.clone(),
347                                    });
348                                    return Err(AgentError::Cancelled);
349                                }
350                                _ = tokio::time::sleep(
351                                    std::time::Duration::from_millis(delay_ms)
352                                ) => {}
353                            }
354                            continue 'stream;
355                        }
356
357                        if retry_attempt > 0 {
358                            events(AgentEvent::AutoRetryEnd {
359                                success: false,
360                                attempt: retry_attempt,
361                                final_error: Some(e.to_string()),
362                            });
363                        }
364
365                        events(AgentEvent::AgentEnd {
366                            messages: messages.clone(),
367                        });
368                        return Err(match &e {
369                            opi_ai::provider::ProviderError::AuthFailed(msg) => {
370                                AgentError::AuthFailed(msg.clone())
371                            }
372                            _ => AgentError::Provider(e.to_string()),
373                        });
374                    }
375                }
376            }
377
378            if retry_attempt > 0 {
379                events(AgentEvent::AutoRetryEnd {
380                    success: true,
381                    attempt: retry_attempt,
382                    final_error: None,
383                });
384            }
385            break 'stream;
386        }
387
388        // -- Queue polling after turn completes --------------------------------
389
390        // H5: prepare_next_turn hook
391        let next_turn_ctx = hooks::PrepareNextTurnContext {
392            messages: messages.clone(),
393            turn: turn_idx + 1,
394        };
395        let mut hook_injected = false;
396        if let Some(update) = hooks.prepare_next_turn(next_turn_ctx).await
397            && !update.extra_messages.is_empty()
398        {
399            hook_injected = true;
400            messages.extend(update.extra_messages);
401        }
402
403        // Poll steering queue (drain all)
404        let steering = drain_queue(&context.steering_queue);
405        if !steering.is_empty() {
406            events(AgentEvent::QueueUpdate {
407                steering: steering.clone(),
408                follow_up: vec![],
409            });
410            for msg in steering {
411                messages.push(user_text_message(msg));
412            }
413            continue; // next turn
414        }
415
416        // If hook injected messages, continue so they reach the provider
417        if hook_injected {
418            continue;
419        }
420
421        // If no tools pending (agent would stop), poll follow-up (one at a time)
422        if !has_tools_pending {
423            let follow_up = pop_follow_up(&context.follow_up_queue);
424            if !follow_up.is_empty() {
425                events(AgentEvent::QueueUpdate {
426                    steering: vec![],
427                    follow_up: follow_up.clone(),
428                });
429                for msg in follow_up {
430                    messages.push(user_text_message(msg));
431                }
432                continue; // next turn
433            }
434            break; // no tools, no queues → agent stops
435        }
436
437        // Tools were executed and queues are empty → continue to next turn
438        let _ = turn_idx;
439    }
440
441    events(AgentEvent::AgentEnd {
442        messages: messages.clone(),
443    });
444    Ok(messages)
445}
446
447/// Process a single stream event, updating content and emitting message events.
448/// Returns Some(AssistantMessage) when a terminal event is received.
449fn process_stream_event(
450    event: &opi_ai::stream::AssistantStreamEvent,
451    content: &mut Vec<AssistantContent>,
452    events: &AgentEventSink,
453) -> Option<opi_ai::message::AssistantMessage> {
454    use opi_ai::stream::AssistantStreamEvent::*;
455
456    match event {
457        Start { partial } => {
458            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
459            events(AgentEvent::MessageStart { message: msg });
460            None
461        }
462        TextDelta { delta, partial, .. } => {
463            // Accumulate text into content vector
464            match content.last_mut() {
465                Some(AssistantContent::Text { text }) => {
466                    text.push_str(delta);
467                }
468                _ => {
469                    content.push(AssistantContent::Text {
470                        text: delta.clone(),
471                    });
472                }
473            }
474            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
475            events(AgentEvent::MessageUpdate {
476                message: msg,
477                assistant_event: Box::new(event.clone()),
478            });
479            None
480        }
481        ToolCallEnd { tool_call, .. } => {
482            content.push(AssistantContent::ToolCall {
483                tool_call: tool_call.clone(),
484            });
485            None
486        }
487        ThinkingStart { partial, .. }
488        | ThinkingDelta { partial, .. }
489        | ThinkingEnd { partial, .. } => {
490            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
491            events(AgentEvent::MessageUpdate {
492                message: msg,
493                assistant_event: Box::new(event.clone()),
494            });
495            None
496        }
497        Done { message, .. } => Some(message.clone()),
498        Error { message, .. } => Some(message.clone()),
499        _ => None,
500    }
501}
502
503/// Execute a single tool, with validation and hook integration.
504async fn execute_tool(
505    call_id: &str,
506    tool_name: &str,
507    args: &serde_json::Value,
508    tools_map: &HashMap<String, &dyn Tool>,
509    hooks: &dyn AgentHooks,
510    messages: &[AgentMessage],
511    cancel: CancellationToken,
512) -> ToolResult {
513    let tool = match tools_map.get(tool_name) {
514        Some(t) => *t,
515        None => {
516            return ToolResult {
517                content: vec![opi_ai::message::OutputContent::Text {
518                    text: format!("unknown tool: {tool_name}"),
519                }],
520                details: None,
521                is_error: true,
522                terminate: false,
523            };
524        }
525    };
526
527    // Validate arguments against schema
528    let schema = &tool.definition().input_schema;
529    if let Err(err) = validation::validate(schema, args) {
530        return ToolResult::from_validation_error(err);
531    }
532
533    // Run before_tool_call hook
534    let ctx = BeforeToolCallContext {
535        tool_call_id: call_id.to_owned(),
536        tool_name: tool_name.to_owned(),
537        args: args.clone(),
538        messages: messages.to_vec(),
539    };
540    match hooks.before_tool_call(ctx).await {
541        BeforeToolCallResult::Allow => {}
542        BeforeToolCallResult::Deny { reason } => {
543            return ToolResult {
544                content: vec![opi_ai::message::OutputContent::Text { text: reason }],
545                details: None,
546                is_error: true,
547                terminate: false,
548            };
549        }
550    }
551
552    // Execute the tool
553    match tool.execute(call_id, args.clone(), cancel, None).await {
554        Ok(result) => {
555            let ctx = AfterToolCallContext {
556                tool_call_id: call_id.to_owned(),
557                tool_name: tool_name.to_owned(),
558                result: result.clone(),
559            };
560            match hooks.after_tool_call(ctx).await {
561                AfterToolCallResult::Keep => result,
562                AfterToolCallResult::Replace(replacement) => replacement,
563            }
564        }
565        Err(e) => ToolResult {
566            content: vec![opi_ai::message::OutputContent::Text {
567                text: e.to_string(),
568            }],
569            details: None,
570            is_error: true,
571            terminate: false,
572        },
573    }
574}
575
/// Drain every pending message from a queue (steering mode: All).
///
/// Returns the messages in queue (FIFO) order and leaves the queue empty. A
/// `None` queue yields an empty `Vec`.
///
/// The queue is shared with external producers. If one of them panicked while
/// holding the lock, the mutex is poisoned; the queue only holds plain
/// `String`s, so it is recovered rather than panicking the agent loop on every
/// later poll.
fn drain_queue(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(q) = queue else {
        return Vec::new();
    };
    let mut pending = q.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
    pending.drain(..).collect()
}
586
/// Pop at most one message from a queue (follow-up mode: OneAtATime).
///
/// Returns a one-element `Vec` holding the oldest message. Returns an empty
/// `Vec` when the queue is empty or `None`.
///
/// The queue is shared with external producers. A poisoned mutex (a producer
/// panicked while holding the lock) is recovered rather than panicking the
/// agent loop, because the queued `String`s remain valid.
fn pop_follow_up(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(q) = queue else {
        return Vec::new();
    };
    let mut pending = q.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
    pending.pop_front().into_iter().collect()
}
600
601/// Create a user text AgentMessage.
602fn user_text_message(text: String) -> AgentMessage {
603    AgentMessage::Llm(Message::User(UserMessage {
604        content: vec![InputContent::Text { text }],
605        timestamp_ms: 0,
606    }))
607}