Skip to main content

opi_agent/
lib.rs

1//! General-purpose agent runtime with tool calling and session management.
2//!
3//! Provides the foundation for building specialized agents with pluggable
4//! tool systems, hooks, queues, and session persistence.
5
6pub mod agent;
7pub mod compaction;
8pub mod event;
9pub mod extension;
10pub mod hooks;
11pub mod loop_types;
12pub mod message;
13pub mod sdk;
14pub mod session;
15pub mod session_branch;
16pub mod session_event;
17pub mod state;
18pub mod streaming_proxy;
19pub mod tool;
20pub mod validation;
21
22pub use agent::Agent;
23pub use event::{AgentEvent, AgentEventSink};
24pub use extension::{
25    Extension, ExtensionCommand, ExtensionError, ExtensionHookResult, ExtensionRegistry,
26};
27pub use hooks::AgentHooks;
28pub use loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
29pub use message::AgentMessage;
30pub use sdk::{SDK_SCHEMA_VERSION, SdkCommand, SdkResponse};
31pub use session_event::AgentSessionEvent;
32pub use state::AgentState;
33pub use streaming_proxy::{
34    ProxyConfig, ProxyEvent, ProxyHandler, SecretRedactor, StreamingProxy, StreamingProxyError,
35};
36pub use tool::{ExecutionMode, Tool, ToolError, ToolResult};
37
38// Re-export provider-facing types needed at the agent boundary.
39pub use opi_ai::message::ToolDef;
40
41use std::collections::{HashMap, VecDeque};
42use std::sync::{Arc, Mutex};
43
44use futures_util::StreamExt;
45use hooks::{
46    AfterToolCallContext, AfterToolCallResult, BeforeToolCallContext, BeforeToolCallResult,
47    ShouldStopAfterTurnContext,
48};
49use opi_ai::message::{AssistantContent, InputContent, Message, ToolResultMessage, UserMessage};
50use opi_ai::provider::{Request, validate_request_capabilities};
51use serde_json::json;
52use tokio_util::sync::CancellationToken;
53
54/// Run the agent loop until completion or cancellation.
55///
56/// The loop iterates: provider request → stream response → detect tool calls
57/// → validate and execute tools → send tool results back → repeat until no
58/// tool calls or stop condition.
59pub async fn agent_loop(
60    context: AgentLoopContext,
61    config: AgentLoopConfig,
62    hooks: &dyn AgentHooks,
63    events: AgentEventSink,
64    cancel: CancellationToken,
65) -> Result<Vec<AgentMessage>, AgentError> {
66    let tools_map: HashMap<String, &dyn Tool> = context
67        .tools
68        .iter()
69        .map(|t| (t.definition().name.clone(), t.as_ref()))
70        .collect();
71    let tool_defs: Vec<_> = context.tools.iter().map(|t| t.definition()).collect();
72
73    let mut messages = context.messages;
74
75    events(AgentEvent::AgentStart);
76
77    let mut has_tools_pending;
78    for turn_idx in 0..config.max_turns {
79        if cancel.is_cancelled() {
80            events(AgentEvent::AgentEnd {
81                messages: messages.clone(),
82            });
83            return Err(AgentError::Cancelled);
84        }
85
86        events(AgentEvent::TurnStart);
87
88        // H5: transform context before provider call
89        let transformed = hooks
90            .transform_context(messages.clone(), cancel.clone())
91            .await?;
92
93        // Convert messages for the provider
94        let llm_messages = hooks.convert_to_llm(&transformed)?;
95
96        // Stream the response (with optional retry on retryable errors)
97        let mut assistant_content: Vec<AssistantContent> = Vec::new();
98        has_tools_pending = false;
99        let mut retry_attempt: u32 = 0;
100        let max_attempts = config.retry.as_ref().map(|r| r.max_attempts).unwrap_or(0);
101
102        'stream: loop {
103            let request = Request {
104                model: context.model.clone(),
105                system: context.system.clone(),
106                messages: llm_messages.clone(),
107                tools: tool_defs.clone(),
108                max_tokens: config.max_tokens,
109                temperature: config.temperature,
110                thinking: config.thinking.clone().unwrap_or_default(),
111                stop_sequences: vec![],
112                metadata: None,
113                cancel: cancel.clone(),
114            };
115            if let Err(e) = validate_request_capabilities(context.provider.as_ref(), &request) {
116                events(AgentEvent::AgentEnd {
117                    messages: messages.clone(),
118                });
119                return Err(AgentError::Provider(e.to_string()));
120            }
121            let mut stream = context.provider.stream(request);
122            assistant_content.clear();
123
124            while let Some(item) = {
125                tokio::select! {
126                    biased;
127                    _ = cancel.cancelled() => {
128                        events(AgentEvent::AgentEnd {
129                            messages: messages.clone(),
130                        });
131                        return Err(AgentError::Cancelled);
132                    }
133                    item = stream.next() => item,
134                }
135            } {
136                match item {
137                    Ok(event) => {
138                        if let Some(msg) =
139                            process_stream_event(&event, &mut assistant_content, &events)
140                        {
141                            // Use the provider's complete content (includes thinking)
142                            // rather than the locally-accumulated assistant_content
143                            // which only tracks text and tool calls.
144                            let agent_msg = AgentMessage::Llm(Message::Assistant(msg));
145
146                            events(AgentEvent::MessageEnd {
147                                message: agent_msg.clone(),
148                            });
149
150                            messages.push(agent_msg.clone());
151
152                            // Extract tool calls from the provider's complete content.
153                            let content = match &agent_msg {
154                                AgentMessage::Llm(Message::Assistant(a)) => &a.content,
155                                _ => &Vec::new(),
156                            };
157                            let tool_calls: Vec<_> = content
158                                .iter()
159                                .filter_map(|c| match c {
160                                    AssistantContent::ToolCall { tool_call } => {
161                                        Some(tool_call.clone())
162                                    }
163                                    _ => None,
164                                })
165                                .collect();
166
167                            if !tool_calls.is_empty() {
168                                has_tools_pending = true;
169                                let mut tool_results = Vec::new();
170                                let mut terminate_flags = Vec::new();
171
172                                let batch_is_sequential = tool_calls.iter().any(|tc| {
173                                    tools_map
174                                        .get(tc.name.as_str())
175                                        .map(|t| t.execution_mode() == ExecutionMode::Sequential)
176                                        .unwrap_or(true)
177                                });
178
179                                if batch_is_sequential {
180                                    for tc in &tool_calls {
181                                        let args: serde_json::Value =
182                                            serde_json::from_str(&tc.arguments)
183                                                .unwrap_or(json!({}));
184
185                                        events(AgentEvent::ToolExecutionStart {
186                                            tool_call_id: tc.id.clone(),
187                                            tool_name: tc.name.clone(),
188                                            args: args.clone(),
189                                        });
190
191                                        let result = execute_tool(
192                                            &tc.id,
193                                            &tc.name,
194                                            &args,
195                                            &tools_map,
196                                            hooks,
197                                            &messages,
198                                            cancel.clone(),
199                                        )
200                                        .await;
201
202                                        let is_error = result.is_error;
203                                        let details = result.details.clone();
204                                        terminate_flags.push(result.terminate);
205                                        events(AgentEvent::ToolExecutionEnd {
206                                            tool_call_id: tc.id.clone(),
207                                            tool_name: tc.name.clone(),
208                                            result: serde_json::json!(&result.content),
209                                            details,
210                                            is_error,
211                                        });
212
213                                        let trm = ToolResultMessage {
214                                            tool_call_id: tc.id.clone(),
215                                            tool_name: tc.name.clone(),
216                                            content: result.content,
217                                            details: result.details,
218                                            is_error,
219                                            timestamp_ms: 0,
220                                        };
221                                        tool_results.push(trm.clone());
222                                        messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
223                                    }
224                                } else {
225                                    let tc_args: Vec<_> = tool_calls
226                                        .iter()
227                                        .map(|tc| {
228                                            let args: serde_json::Value =
229                                                serde_json::from_str(&tc.arguments)
230                                                    .unwrap_or(json!({}));
231                                            events(AgentEvent::ToolExecutionStart {
232                                                tool_call_id: tc.id.clone(),
233                                                tool_name: tc.name.clone(),
234                                                args: args.clone(),
235                                            });
236                                            (tc.clone(), args)
237                                        })
238                                        .collect();
239
240                                    let futures: Vec<_> = tc_args
241                                        .iter()
242                                        .map(|(tc, args)| {
243                                            let tools_map = &tools_map;
244                                            let messages = &messages;
245                                            let cancel = cancel.clone();
246                                            let tc_id = tc.id.clone();
247                                            let tc_name = tc.name.clone();
248                                            let args = args.clone();
249                                            async move {
250                                                let result = execute_tool(
251                                                    &tc_id, &tc_name, &args, tools_map, hooks,
252                                                    messages, cancel,
253                                                )
254                                                .await;
255                                                (tc_id, tc_name, result)
256                                            }
257                                        })
258                                        .collect();
259                                    let results = futures_util::future::join_all(futures).await;
260                                    for (tc_id, tc_name, result) in results {
261                                        let is_error = result.is_error;
262                                        let details = result.details.clone();
263                                        terminate_flags.push(result.terminate);
264                                        events(AgentEvent::ToolExecutionEnd {
265                                            tool_call_id: tc_id.clone(),
266                                            tool_name: tc_name.clone(),
267                                            result: serde_json::json!(&result.content),
268                                            details,
269                                            is_error,
270                                        });
271                                        let trm = ToolResultMessage {
272                                            tool_call_id: tc_id,
273                                            tool_name: tc_name,
274                                            content: result.content,
275                                            details: result.details,
276                                            is_error,
277                                            timestamp_ms: 0,
278                                        };
279                                        tool_results.push(trm.clone());
280                                        messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
281                                    }
282                                }
283
284                                let all_terminate = !terminate_flags.is_empty()
285                                    && terminate_flags.iter().all(|t| *t);
286
287                                events(AgentEvent::TurnEnd {
288                                    message: agent_msg,
289                                    tool_results: tool_results.clone(),
290                                });
291
292                                if all_terminate {
293                                    events(AgentEvent::AgentEnd {
294                                        messages: messages.clone(),
295                                    });
296                                    return Ok(messages);
297                                }
298
299                                let stop_ctx = ShouldStopAfterTurnContext {
300                                    messages: messages.clone(),
301                                    tool_results,
302                                };
303                                if hooks.should_stop_after_turn(stop_ctx).await {
304                                    events(AgentEvent::AgentEnd {
305                                        messages: messages.clone(),
306                                    });
307                                    return Ok(messages);
308                                }
309
310                                break 'stream;
311                            }
312
313                            events(AgentEvent::TurnEnd {
314                                message: agent_msg.clone(),
315                                tool_results: vec![],
316                            });
317
318                            let stop_ctx = ShouldStopAfterTurnContext {
319                                messages: messages.clone(),
320                                tool_results: vec![],
321                            };
322                            if hooks.should_stop_after_turn(stop_ctx).await {
323                                events(AgentEvent::AgentEnd {
324                                    messages: messages.clone(),
325                                });
326                                return Ok(messages);
327                            }
328                        }
329                    }
330                    Err(e) => {
331                        if e.is_retryable()
332                            && retry_attempt < max_attempts
333                            && let Some(ref rc) = config.retry
334                        {
335                            let retry_after_ms = match &e {
336                                opi_ai::provider::ProviderError::RateLimited { retry_after_ms } => {
337                                    *retry_after_ms
338                                }
339                                _ => None,
340                            };
341                            let delay_ms = rc.delay_for_attempt(retry_attempt, retry_after_ms);
342                            retry_attempt += 1;
343
344                            events(AgentEvent::AutoRetryStart {
345                                attempt: retry_attempt,
346                                max_attempts: rc.max_attempts,
347                                delay_ms,
348                                error_message: e.to_string(),
349                            });
350
351                            tokio::select! {
352                                biased;
353                                _ = cancel.cancelled() => {
354                                    events(AgentEvent::AgentEnd {
355                                        messages: messages.clone(),
356                                    });
357                                    return Err(AgentError::Cancelled);
358                                }
359                                _ = tokio::time::sleep(
360                                    std::time::Duration::from_millis(delay_ms)
361                                ) => {}
362                            }
363                            continue 'stream;
364                        }
365
366                        if retry_attempt > 0 {
367                            events(AgentEvent::AutoRetryEnd {
368                                success: false,
369                                attempt: retry_attempt,
370                                final_error: Some(e.to_string()),
371                            });
372                        }
373
374                        events(AgentEvent::AgentEnd {
375                            messages: messages.clone(),
376                        });
377                        return Err(match &e {
378                            opi_ai::provider::ProviderError::AuthFailed(msg) => {
379                                AgentError::AuthFailed(msg.clone())
380                            }
381                            _ => AgentError::Provider(e.to_string()),
382                        });
383                    }
384                }
385            }
386
387            if retry_attempt > 0 {
388                events(AgentEvent::AutoRetryEnd {
389                    success: true,
390                    attempt: retry_attempt,
391                    final_error: None,
392                });
393            }
394            break 'stream;
395        }
396
397        // -- Queue polling after turn completes --------------------------------
398
399        // H5: prepare_next_turn hook
400        let next_turn_ctx = hooks::PrepareNextTurnContext {
401            messages: messages.clone(),
402            turn: turn_idx + 1,
403        };
404        let mut hook_injected = false;
405        if let Some(update) = hooks.prepare_next_turn(next_turn_ctx).await
406            && !update.extra_messages.is_empty()
407        {
408            hook_injected = true;
409            messages.extend(update.extra_messages);
410        }
411
412        // Poll steering queue (drain all)
413        let steering = drain_queue(&context.steering_queue);
414        if !steering.is_empty() {
415            events(AgentEvent::QueueUpdate {
416                steering: steering.clone(),
417                follow_up: vec![],
418            });
419            for msg in steering {
420                messages.push(user_text_message(msg));
421            }
422            continue; // next turn
423        }
424
425        // If hook injected messages, continue so they reach the provider
426        if hook_injected {
427            continue;
428        }
429
430        // If no tools pending (agent would stop), poll follow-up (one at a time)
431        if !has_tools_pending {
432            let follow_up = pop_follow_up(&context.follow_up_queue);
433            if !follow_up.is_empty() {
434                events(AgentEvent::QueueUpdate {
435                    steering: vec![],
436                    follow_up: follow_up.clone(),
437                });
438                for msg in follow_up {
439                    messages.push(user_text_message(msg));
440                }
441                continue; // next turn
442            }
443            break; // no tools, no queues → agent stops
444        }
445
446        // Tools were executed and queues are empty → continue to next turn
447        let _ = turn_idx;
448    }
449
450    events(AgentEvent::AgentEnd {
451        messages: messages.clone(),
452    });
453    Ok(messages)
454}
455
456/// Process a single stream event, updating content and emitting message events.
457/// Returns Some(AssistantMessage) when a terminal event is received.
458fn process_stream_event(
459    event: &opi_ai::stream::AssistantStreamEvent,
460    content: &mut Vec<AssistantContent>,
461    events: &AgentEventSink,
462) -> Option<opi_ai::message::AssistantMessage> {
463    use opi_ai::stream::AssistantStreamEvent::*;
464
465    match event {
466        Start { partial } => {
467            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
468            events(AgentEvent::MessageStart { message: msg });
469            None
470        }
471        TextDelta { delta, partial, .. } => {
472            // Accumulate text into content vector
473            match content.last_mut() {
474                Some(AssistantContent::Text { text }) => {
475                    text.push_str(delta);
476                }
477                _ => {
478                    content.push(AssistantContent::Text {
479                        text: delta.clone(),
480                    });
481                }
482            }
483            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
484            events(AgentEvent::MessageUpdate {
485                message: msg,
486                assistant_event: Box::new(event.clone()),
487            });
488            None
489        }
490        ToolCallEnd { tool_call, .. } => {
491            content.push(AssistantContent::ToolCall {
492                tool_call: tool_call.clone(),
493            });
494            None
495        }
496        ThinkingStart { partial, .. }
497        | ThinkingDelta { partial, .. }
498        | ThinkingEnd { partial, .. } => {
499            let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
500            events(AgentEvent::MessageUpdate {
501                message: msg,
502                assistant_event: Box::new(event.clone()),
503            });
504            None
505        }
506        Done { message, .. } => Some(message.clone()),
507        Error { message, .. } => Some(message.clone()),
508        _ => None,
509    }
510}
511
512/// Execute a single tool, with validation and hook integration.
513async fn execute_tool(
514    call_id: &str,
515    tool_name: &str,
516    args: &serde_json::Value,
517    tools_map: &HashMap<String, &dyn Tool>,
518    hooks: &dyn AgentHooks,
519    messages: &[AgentMessage],
520    cancel: CancellationToken,
521) -> ToolResult {
522    let tool = match tools_map.get(tool_name) {
523        Some(t) => *t,
524        None => {
525            return ToolResult {
526                content: vec![opi_ai::message::OutputContent::Text {
527                    text: format!("unknown tool: {tool_name}"),
528                }],
529                details: None,
530                is_error: true,
531                terminate: false,
532            };
533        }
534    };
535
536    // Validate arguments against schema
537    let schema = &tool.definition().input_schema;
538    if let Err(err) = validation::validate(schema, args) {
539        return ToolResult::from_validation_error(err);
540    }
541
542    // Run before_tool_call hook
543    let ctx = BeforeToolCallContext {
544        tool_call_id: call_id.to_owned(),
545        tool_name: tool_name.to_owned(),
546        args: args.clone(),
547        messages: messages.to_vec(),
548    };
549    match hooks.before_tool_call(ctx).await {
550        BeforeToolCallResult::Allow => {}
551        BeforeToolCallResult::Deny { reason } => {
552            return ToolResult {
553                content: vec![opi_ai::message::OutputContent::Text { text: reason }],
554                details: None,
555                is_error: true,
556                terminate: false,
557            };
558        }
559    }
560
561    // Execute the tool
562    match tool.execute(call_id, args.clone(), cancel, None).await {
563        Ok(result) => {
564            let ctx = AfterToolCallContext {
565                tool_call_id: call_id.to_owned(),
566                tool_name: tool_name.to_owned(),
567                result: result.clone(),
568            };
569            match hooks.after_tool_call(ctx).await {
570                AfterToolCallResult::Keep => result,
571                AfterToolCallResult::Replace(replacement) => replacement,
572            }
573        }
574        Err(e) => ToolResult {
575            content: vec![opi_ai::message::OutputContent::Text {
576                text: e.to_string(),
577            }],
578            details: None,
579            is_error: true,
580            terminate: false,
581        },
582    }
583}
584
/// Drain all messages from a queue (steering mode: All).
///
/// Returns every queued message in FIFO order and leaves the queue empty.
/// Returns an empty `Vec` when no queue is configured.
///
/// A poisoned mutex is recovered instead of panicking. The queue holds only
/// owned `String`s with no cross-entry invariant. A panic in some other holder
/// therefore should not take down the agent loop that polls it.
fn drain_queue(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(queue) = queue else {
        return Vec::new();
    };
    let mut guard = queue
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    guard.drain(..).collect()
}
595
/// Pop one message from a queue (follow-up mode: OneAtATime).
///
/// Returns a `Vec` holding the oldest queued message. The `Vec` is empty when
/// the queue is empty or when no queue is configured.
///
/// A poisoned mutex is recovered instead of panicking. The queue holds only
/// owned `String`s with no cross-entry invariant. A panic in some other holder
/// therefore should not take down the agent loop that polls it.
fn pop_follow_up(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(queue) = queue else {
        return Vec::new();
    };
    let mut guard = queue
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    guard.pop_front().into_iter().collect()
}
609
610/// Create a user text AgentMessage.
611fn user_text_message(text: String) -> AgentMessage {
612    AgentMessage::Llm(Message::User(UserMessage {
613        content: vec![InputContent::Text { text }],
614        timestamp_ms: 0,
615    }))
616}