1pub mod agent;
7pub mod compaction;
8pub mod event;
9pub mod extension;
10pub mod hooks;
11pub mod loop_types;
12pub mod message;
13pub mod sdk;
14pub mod session;
15pub mod session_branch;
16pub mod session_event;
17pub mod state;
18pub mod streaming_proxy;
19pub mod tool;
20pub mod validation;
21
22pub use agent::Agent;
23pub use event::{AgentEvent, AgentEventSink};
24pub use extension::{
25 Extension, ExtensionCommand, ExtensionError, ExtensionHookResult, ExtensionRegistry,
26};
27pub use hooks::AgentHooks;
28pub use loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
29pub use message::AgentMessage;
30pub use sdk::{SDK_SCHEMA_VERSION, SdkCommand, SdkResponse};
31pub use session_event::AgentSessionEvent;
32pub use state::AgentState;
33pub use streaming_proxy::{
34 ProxyConfig, ProxyEvent, ProxyHandler, SecretRedactor, StreamingProxy, StreamingProxyError,
35};
36pub use tool::{ExecutionMode, Tool, ToolError, ToolResult};
37
38pub use opi_ai::message::ToolDef;
40
41use std::collections::{HashMap, VecDeque};
42use std::sync::{Arc, Mutex};
43
44use futures_util::StreamExt;
45use hooks::{
46 AfterToolCallContext, AfterToolCallResult, BeforeToolCallContext, BeforeToolCallResult,
47 ShouldStopAfterTurnContext,
48};
49use opi_ai::message::{AssistantContent, InputContent, Message, ToolResultMessage, UserMessage};
50use opi_ai::provider::{Request, validate_request_capabilities};
51use serde_json::json;
52use tokio_util::sync::CancellationToken;
53
54pub async fn agent_loop(
60 context: AgentLoopContext,
61 config: AgentLoopConfig,
62 hooks: &dyn AgentHooks,
63 events: AgentEventSink,
64 cancel: CancellationToken,
65) -> Result<Vec<AgentMessage>, AgentError> {
66 let tools_map: HashMap<String, &dyn Tool> = context
67 .tools
68 .iter()
69 .map(|t| (t.definition().name.clone(), t.as_ref()))
70 .collect();
71 let tool_defs: Vec<_> = context.tools.iter().map(|t| t.definition()).collect();
72
73 let mut messages = context.messages;
74
75 events(AgentEvent::AgentStart);
76
77 let mut has_tools_pending;
78 for turn_idx in 0..config.max_turns {
79 if cancel.is_cancelled() {
80 events(AgentEvent::AgentEnd {
81 messages: messages.clone(),
82 });
83 return Err(AgentError::Cancelled);
84 }
85
86 events(AgentEvent::TurnStart);
87
88 let transformed = hooks
90 .transform_context(messages.clone(), cancel.clone())
91 .await?;
92
93 let llm_messages = hooks.convert_to_llm(&transformed)?;
95
96 let mut assistant_content: Vec<AssistantContent> = Vec::new();
98 has_tools_pending = false;
99 let mut retry_attempt: u32 = 0;
100 let max_attempts = config.retry.as_ref().map(|r| r.max_attempts).unwrap_or(0);
101
102 'stream: loop {
103 let request = Request {
104 model: context.model.clone(),
105 system: context.system.clone(),
106 messages: llm_messages.clone(),
107 tools: tool_defs.clone(),
108 max_tokens: config.max_tokens,
109 temperature: config.temperature,
110 thinking: config.thinking.clone().unwrap_or_default(),
111 stop_sequences: vec![],
112 metadata: None,
113 cancel: cancel.clone(),
114 };
115 if let Err(e) = validate_request_capabilities(context.provider.as_ref(), &request) {
116 events(AgentEvent::AgentEnd {
117 messages: messages.clone(),
118 });
119 return Err(AgentError::Provider(e.to_string()));
120 }
121 let mut stream = context.provider.stream(request);
122 assistant_content.clear();
123
124 while let Some(item) = {
125 tokio::select! {
126 biased;
127 _ = cancel.cancelled() => {
128 events(AgentEvent::AgentEnd {
129 messages: messages.clone(),
130 });
131 return Err(AgentError::Cancelled);
132 }
133 item = stream.next() => item,
134 }
135 } {
136 match item {
137 Ok(event) => {
138 if let Some(msg) =
139 process_stream_event(&event, &mut assistant_content, &events)
140 {
141 let agent_msg = AgentMessage::Llm(Message::Assistant(msg));
145
146 events(AgentEvent::MessageEnd {
147 message: agent_msg.clone(),
148 });
149
150 messages.push(agent_msg.clone());
151
152 let content = match &agent_msg {
154 AgentMessage::Llm(Message::Assistant(a)) => &a.content,
155 _ => &Vec::new(),
156 };
157 let tool_calls: Vec<_> = content
158 .iter()
159 .filter_map(|c| match c {
160 AssistantContent::ToolCall { tool_call } => {
161 Some(tool_call.clone())
162 }
163 _ => None,
164 })
165 .collect();
166
167 if !tool_calls.is_empty() {
168 has_tools_pending = true;
169 let mut tool_results = Vec::new();
170 let mut terminate_flags = Vec::new();
171
172 let batch_is_sequential = tool_calls.iter().any(|tc| {
173 tools_map
174 .get(tc.name.as_str())
175 .map(|t| t.execution_mode() == ExecutionMode::Sequential)
176 .unwrap_or(true)
177 });
178
179 if batch_is_sequential {
180 for tc in &tool_calls {
181 let args: serde_json::Value =
182 serde_json::from_str(&tc.arguments)
183 .unwrap_or(json!({}));
184
185 events(AgentEvent::ToolExecutionStart {
186 tool_call_id: tc.id.clone(),
187 tool_name: tc.name.clone(),
188 args: args.clone(),
189 });
190
191 let result = execute_tool(
192 &tc.id,
193 &tc.name,
194 &args,
195 &tools_map,
196 hooks,
197 &messages,
198 cancel.clone(),
199 )
200 .await;
201
202 let is_error = result.is_error;
203 let details = result.details.clone();
204 terminate_flags.push(result.terminate);
205 events(AgentEvent::ToolExecutionEnd {
206 tool_call_id: tc.id.clone(),
207 tool_name: tc.name.clone(),
208 result: serde_json::json!(&result.content),
209 details,
210 is_error,
211 });
212
213 let trm = ToolResultMessage {
214 tool_call_id: tc.id.clone(),
215 tool_name: tc.name.clone(),
216 content: result.content,
217 details: result.details,
218 is_error,
219 timestamp_ms: 0,
220 };
221 tool_results.push(trm.clone());
222 messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
223 }
224 } else {
225 let tc_args: Vec<_> = tool_calls
226 .iter()
227 .map(|tc| {
228 let args: serde_json::Value =
229 serde_json::from_str(&tc.arguments)
230 .unwrap_or(json!({}));
231 events(AgentEvent::ToolExecutionStart {
232 tool_call_id: tc.id.clone(),
233 tool_name: tc.name.clone(),
234 args: args.clone(),
235 });
236 (tc.clone(), args)
237 })
238 .collect();
239
240 let futures: Vec<_> = tc_args
241 .iter()
242 .map(|(tc, args)| {
243 let tools_map = &tools_map;
244 let messages = &messages;
245 let cancel = cancel.clone();
246 let tc_id = tc.id.clone();
247 let tc_name = tc.name.clone();
248 let args = args.clone();
249 async move {
250 let result = execute_tool(
251 &tc_id, &tc_name, &args, tools_map, hooks,
252 messages, cancel,
253 )
254 .await;
255 (tc_id, tc_name, result)
256 }
257 })
258 .collect();
259 let results = futures_util::future::join_all(futures).await;
260 for (tc_id, tc_name, result) in results {
261 let is_error = result.is_error;
262 let details = result.details.clone();
263 terminate_flags.push(result.terminate);
264 events(AgentEvent::ToolExecutionEnd {
265 tool_call_id: tc_id.clone(),
266 tool_name: tc_name.clone(),
267 result: serde_json::json!(&result.content),
268 details,
269 is_error,
270 });
271 let trm = ToolResultMessage {
272 tool_call_id: tc_id,
273 tool_name: tc_name,
274 content: result.content,
275 details: result.details,
276 is_error,
277 timestamp_ms: 0,
278 };
279 tool_results.push(trm.clone());
280 messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
281 }
282 }
283
284 let all_terminate = !terminate_flags.is_empty()
285 && terminate_flags.iter().all(|t| *t);
286
287 events(AgentEvent::TurnEnd {
288 message: agent_msg,
289 tool_results: tool_results.clone(),
290 });
291
292 if all_terminate {
293 events(AgentEvent::AgentEnd {
294 messages: messages.clone(),
295 });
296 return Ok(messages);
297 }
298
299 let stop_ctx = ShouldStopAfterTurnContext {
300 messages: messages.clone(),
301 tool_results,
302 };
303 if hooks.should_stop_after_turn(stop_ctx).await {
304 events(AgentEvent::AgentEnd {
305 messages: messages.clone(),
306 });
307 return Ok(messages);
308 }
309
310 break 'stream;
311 }
312
313 events(AgentEvent::TurnEnd {
314 message: agent_msg.clone(),
315 tool_results: vec![],
316 });
317
318 let stop_ctx = ShouldStopAfterTurnContext {
319 messages: messages.clone(),
320 tool_results: vec![],
321 };
322 if hooks.should_stop_after_turn(stop_ctx).await {
323 events(AgentEvent::AgentEnd {
324 messages: messages.clone(),
325 });
326 return Ok(messages);
327 }
328 }
329 }
330 Err(e) => {
331 if e.is_retryable()
332 && retry_attempt < max_attempts
333 && let Some(ref rc) = config.retry
334 {
335 let retry_after_ms = match &e {
336 opi_ai::provider::ProviderError::RateLimited { retry_after_ms } => {
337 *retry_after_ms
338 }
339 _ => None,
340 };
341 let delay_ms = rc.delay_for_attempt(retry_attempt, retry_after_ms);
342 retry_attempt += 1;
343
344 events(AgentEvent::AutoRetryStart {
345 attempt: retry_attempt,
346 max_attempts: rc.max_attempts,
347 delay_ms,
348 error_message: e.to_string(),
349 });
350
351 tokio::select! {
352 biased;
353 _ = cancel.cancelled() => {
354 events(AgentEvent::AgentEnd {
355 messages: messages.clone(),
356 });
357 return Err(AgentError::Cancelled);
358 }
359 _ = tokio::time::sleep(
360 std::time::Duration::from_millis(delay_ms)
361 ) => {}
362 }
363 continue 'stream;
364 }
365
366 if retry_attempt > 0 {
367 events(AgentEvent::AutoRetryEnd {
368 success: false,
369 attempt: retry_attempt,
370 final_error: Some(e.to_string()),
371 });
372 }
373
374 events(AgentEvent::AgentEnd {
375 messages: messages.clone(),
376 });
377 return Err(match &e {
378 opi_ai::provider::ProviderError::AuthFailed(msg) => {
379 AgentError::AuthFailed(msg.clone())
380 }
381 _ => AgentError::Provider(e.to_string()),
382 });
383 }
384 }
385 }
386
387 if retry_attempt > 0 {
388 events(AgentEvent::AutoRetryEnd {
389 success: true,
390 attempt: retry_attempt,
391 final_error: None,
392 });
393 }
394 break 'stream;
395 }
396
397 let next_turn_ctx = hooks::PrepareNextTurnContext {
401 messages: messages.clone(),
402 turn: turn_idx + 1,
403 };
404 let mut hook_injected = false;
405 if let Some(update) = hooks.prepare_next_turn(next_turn_ctx).await
406 && !update.extra_messages.is_empty()
407 {
408 hook_injected = true;
409 messages.extend(update.extra_messages);
410 }
411
412 let steering = drain_queue(&context.steering_queue);
414 if !steering.is_empty() {
415 events(AgentEvent::QueueUpdate {
416 steering: steering.clone(),
417 follow_up: vec![],
418 });
419 for msg in steering {
420 messages.push(user_text_message(msg));
421 }
422 continue; }
424
425 if hook_injected {
427 continue;
428 }
429
430 if !has_tools_pending {
432 let follow_up = pop_follow_up(&context.follow_up_queue);
433 if !follow_up.is_empty() {
434 events(AgentEvent::QueueUpdate {
435 steering: vec![],
436 follow_up: follow_up.clone(),
437 });
438 for msg in follow_up {
439 messages.push(user_text_message(msg));
440 }
441 continue; }
443 break; }
445
446 let _ = turn_idx;
448 }
449
450 events(AgentEvent::AgentEnd {
451 messages: messages.clone(),
452 });
453 Ok(messages)
454}
455
456fn process_stream_event(
459 event: &opi_ai::stream::AssistantStreamEvent,
460 content: &mut Vec<AssistantContent>,
461 events: &AgentEventSink,
462) -> Option<opi_ai::message::AssistantMessage> {
463 use opi_ai::stream::AssistantStreamEvent::*;
464
465 match event {
466 Start { partial } => {
467 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
468 events(AgentEvent::MessageStart { message: msg });
469 None
470 }
471 TextDelta { delta, partial, .. } => {
472 match content.last_mut() {
474 Some(AssistantContent::Text { text }) => {
475 text.push_str(delta);
476 }
477 _ => {
478 content.push(AssistantContent::Text {
479 text: delta.clone(),
480 });
481 }
482 }
483 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
484 events(AgentEvent::MessageUpdate {
485 message: msg,
486 assistant_event: Box::new(event.clone()),
487 });
488 None
489 }
490 ToolCallEnd { tool_call, .. } => {
491 content.push(AssistantContent::ToolCall {
492 tool_call: tool_call.clone(),
493 });
494 None
495 }
496 ThinkingStart { partial, .. }
497 | ThinkingDelta { partial, .. }
498 | ThinkingEnd { partial, .. } => {
499 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
500 events(AgentEvent::MessageUpdate {
501 message: msg,
502 assistant_event: Box::new(event.clone()),
503 });
504 None
505 }
506 Done { message, .. } => Some(message.clone()),
507 Error { message, .. } => Some(message.clone()),
508 _ => None,
509 }
510}
511
512async fn execute_tool(
514 call_id: &str,
515 tool_name: &str,
516 args: &serde_json::Value,
517 tools_map: &HashMap<String, &dyn Tool>,
518 hooks: &dyn AgentHooks,
519 messages: &[AgentMessage],
520 cancel: CancellationToken,
521) -> ToolResult {
522 let tool = match tools_map.get(tool_name) {
523 Some(t) => *t,
524 None => {
525 return ToolResult {
526 content: vec![opi_ai::message::OutputContent::Text {
527 text: format!("unknown tool: {tool_name}"),
528 }],
529 details: None,
530 is_error: true,
531 terminate: false,
532 };
533 }
534 };
535
536 let schema = &tool.definition().input_schema;
538 if let Err(err) = validation::validate(schema, args) {
539 return ToolResult::from_validation_error(err);
540 }
541
542 let ctx = BeforeToolCallContext {
544 tool_call_id: call_id.to_owned(),
545 tool_name: tool_name.to_owned(),
546 args: args.clone(),
547 messages: messages.to_vec(),
548 };
549 match hooks.before_tool_call(ctx).await {
550 BeforeToolCallResult::Allow => {}
551 BeforeToolCallResult::Deny { reason } => {
552 return ToolResult {
553 content: vec![opi_ai::message::OutputContent::Text { text: reason }],
554 details: None,
555 is_error: true,
556 terminate: false,
557 };
558 }
559 }
560
561 match tool.execute(call_id, args.clone(), cancel, None).await {
563 Ok(result) => {
564 let ctx = AfterToolCallContext {
565 tool_call_id: call_id.to_owned(),
566 tool_name: tool_name.to_owned(),
567 result: result.clone(),
568 };
569 match hooks.after_tool_call(ctx).await {
570 AfterToolCallResult::Keep => result,
571 AfterToolCallResult::Replace(replacement) => replacement,
572 }
573 }
574 Err(e) => ToolResult {
575 content: vec![opi_ai::message::OutputContent::Text {
576 text: e.to_string(),
577 }],
578 details: None,
579 is_error: true,
580 terminate: false,
581 },
582 }
583}
584
/// Removes and returns every message currently in `queue`, in front-to-back
/// order, leaving the queue empty.
///
/// Returns an empty vector when no queue is configured. A poisoned lock is
/// recovered rather than propagated: the queue only holds owned strings, so
/// its contents remain usable after another holder panicked.
fn drain_queue(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(queue) = queue else {
        return Vec::new();
    };
    let mut pending = queue
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    pending.drain(..).collect()
}
595
/// Takes at most one message from the front of `queue`.
///
/// Returns a one-element vector holding that message, or an empty vector when
/// the queue is empty or not configured. As with the steering queue, a
/// poisoned lock is recovered instead of propagating another thread's panic,
/// because the queued strings remain valid. The lock is released before the
/// result is built.
fn pop_follow_up(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(queue) = queue else {
        return Vec::new();
    };
    let next = queue
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .pop_front();
    next.into_iter().collect()
}
609
610fn user_text_message(text: String) -> AgentMessage {
612 AgentMessage::Llm(Message::User(UserMessage {
613 content: vec![InputContent::Text { text }],
614 timestamp_ms: 0,
615 }))
616}