1pub mod agent;
7pub mod compaction;
8pub mod event;
9pub mod hooks;
10pub mod loop_types;
11pub mod message;
12pub mod session;
13pub mod session_event;
14pub mod state;
15pub mod tool;
16pub mod transport;
17pub mod validation;
18
19pub use agent::Agent;
20pub use event::{AgentEvent, AgentEventSink};
21pub use hooks::AgentHooks;
22pub use loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
23pub use message::AgentMessage;
24pub use session_event::AgentSessionEvent;
25pub use state::AgentState;
26pub use tool::{ExecutionMode, Tool, ToolError, ToolResult};
27pub use transport::Transport;
28
29pub use opi_ai::message::ToolDef;
31
32use std::collections::{HashMap, VecDeque};
33use std::sync::{Arc, Mutex};
34
35use futures_util::StreamExt;
36use hooks::{
37 AfterToolCallContext, AfterToolCallResult, BeforeToolCallContext, BeforeToolCallResult,
38 ShouldStopAfterTurnContext,
39};
40use opi_ai::message::{AssistantContent, InputContent, Message, ToolResultMessage, UserMessage};
41use opi_ai::provider::{Request, validate_request_capabilities};
42use serde_json::json;
43use tokio_util::sync::CancellationToken;
44
45pub async fn agent_loop(
51 context: AgentLoopContext,
52 config: AgentLoopConfig,
53 hooks: &dyn AgentHooks,
54 events: AgentEventSink,
55 cancel: CancellationToken,
56) -> Result<Vec<AgentMessage>, AgentError> {
57 let tools_map: HashMap<String, &dyn Tool> = context
58 .tools
59 .iter()
60 .map(|t| (t.definition().name.clone(), t.as_ref()))
61 .collect();
62 let tool_defs: Vec<_> = context.tools.iter().map(|t| t.definition()).collect();
63
64 let mut messages = context.messages;
65
66 events(AgentEvent::AgentStart);
67
68 let mut has_tools_pending;
69 for turn_idx in 0..config.max_turns {
70 if cancel.is_cancelled() {
71 events(AgentEvent::AgentEnd {
72 messages: messages.clone(),
73 });
74 return Err(AgentError::Cancelled);
75 }
76
77 events(AgentEvent::TurnStart);
78
79 let transformed = hooks
81 .transform_context(messages.clone(), cancel.clone())
82 .await?;
83
84 let llm_messages = hooks.convert_to_llm(&transformed)?;
86
87 let mut assistant_content: Vec<AssistantContent> = Vec::new();
89 has_tools_pending = false;
90 let mut retry_attempt: u32 = 0;
91 let max_attempts = config.retry.as_ref().map(|r| r.max_attempts).unwrap_or(0);
92
93 'stream: loop {
94 let request = Request {
95 model: context.model.clone(),
96 system: context.system.clone(),
97 messages: llm_messages.clone(),
98 tools: tool_defs.clone(),
99 max_tokens: config.max_tokens,
100 temperature: config.temperature,
101 thinking: config.thinking.clone().unwrap_or_default(),
102 stop_sequences: vec![],
103 metadata: None,
104 cancel: cancel.clone(),
105 };
106 if let Err(e) = validate_request_capabilities(context.provider.as_ref(), &request) {
107 events(AgentEvent::AgentEnd {
108 messages: messages.clone(),
109 });
110 return Err(AgentError::Provider(e.to_string()));
111 }
112 let mut stream = context.provider.stream(request);
113 assistant_content.clear();
114
115 while let Some(item) = {
116 tokio::select! {
117 biased;
118 _ = cancel.cancelled() => {
119 events(AgentEvent::AgentEnd {
120 messages: messages.clone(),
121 });
122 return Err(AgentError::Cancelled);
123 }
124 item = stream.next() => item,
125 }
126 } {
127 match item {
128 Ok(event) => {
129 if let Some(msg) =
130 process_stream_event(&event, &mut assistant_content, &events)
131 {
132 let agent_msg = AgentMessage::Llm(Message::Assistant(msg));
136
137 events(AgentEvent::MessageEnd {
138 message: agent_msg.clone(),
139 });
140
141 messages.push(agent_msg.clone());
142
143 let content = match &agent_msg {
145 AgentMessage::Llm(Message::Assistant(a)) => &a.content,
146 _ => &Vec::new(),
147 };
148 let tool_calls: Vec<_> = content
149 .iter()
150 .filter_map(|c| match c {
151 AssistantContent::ToolCall { tool_call } => {
152 Some(tool_call.clone())
153 }
154 _ => None,
155 })
156 .collect();
157
158 if !tool_calls.is_empty() {
159 has_tools_pending = true;
160 let mut tool_results = Vec::new();
161 let mut terminate_flags = Vec::new();
162
163 let batch_is_sequential = tool_calls.iter().any(|tc| {
164 tools_map
165 .get(tc.name.as_str())
166 .map(|t| t.execution_mode() == ExecutionMode::Sequential)
167 .unwrap_or(true)
168 });
169
170 if batch_is_sequential {
171 for tc in &tool_calls {
172 let args: serde_json::Value =
173 serde_json::from_str(&tc.arguments)
174 .unwrap_or(json!({}));
175
176 events(AgentEvent::ToolExecutionStart {
177 tool_call_id: tc.id.clone(),
178 tool_name: tc.name.clone(),
179 args: args.clone(),
180 });
181
182 let result = execute_tool(
183 &tc.id,
184 &tc.name,
185 &args,
186 &tools_map,
187 hooks,
188 &messages,
189 cancel.clone(),
190 )
191 .await;
192
193 let is_error = result.is_error;
194 let details = result.details.clone();
195 terminate_flags.push(result.terminate);
196 events(AgentEvent::ToolExecutionEnd {
197 tool_call_id: tc.id.clone(),
198 tool_name: tc.name.clone(),
199 result: serde_json::json!(&result.content),
200 details,
201 is_error,
202 });
203
204 let trm = ToolResultMessage {
205 tool_call_id: tc.id.clone(),
206 tool_name: tc.name.clone(),
207 content: result.content,
208 details: result.details,
209 is_error,
210 timestamp_ms: 0,
211 };
212 tool_results.push(trm.clone());
213 messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
214 }
215 } else {
216 let tc_args: Vec<_> = tool_calls
217 .iter()
218 .map(|tc| {
219 let args: serde_json::Value =
220 serde_json::from_str(&tc.arguments)
221 .unwrap_or(json!({}));
222 events(AgentEvent::ToolExecutionStart {
223 tool_call_id: tc.id.clone(),
224 tool_name: tc.name.clone(),
225 args: args.clone(),
226 });
227 (tc.clone(), args)
228 })
229 .collect();
230
231 let futures: Vec<_> = tc_args
232 .iter()
233 .map(|(tc, args)| {
234 let tools_map = &tools_map;
235 let messages = &messages;
236 let cancel = cancel.clone();
237 let tc_id = tc.id.clone();
238 let tc_name = tc.name.clone();
239 let args = args.clone();
240 async move {
241 let result = execute_tool(
242 &tc_id, &tc_name, &args, tools_map, hooks,
243 messages, cancel,
244 )
245 .await;
246 (tc_id, tc_name, result)
247 }
248 })
249 .collect();
250 let results = futures_util::future::join_all(futures).await;
251 for (tc_id, tc_name, result) in results {
252 let is_error = result.is_error;
253 let details = result.details.clone();
254 terminate_flags.push(result.terminate);
255 events(AgentEvent::ToolExecutionEnd {
256 tool_call_id: tc_id.clone(),
257 tool_name: tc_name.clone(),
258 result: serde_json::json!(&result.content),
259 details,
260 is_error,
261 });
262 let trm = ToolResultMessage {
263 tool_call_id: tc_id,
264 tool_name: tc_name,
265 content: result.content,
266 details: result.details,
267 is_error,
268 timestamp_ms: 0,
269 };
270 tool_results.push(trm.clone());
271 messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
272 }
273 }
274
275 let all_terminate = !terminate_flags.is_empty()
276 && terminate_flags.iter().all(|t| *t);
277
278 events(AgentEvent::TurnEnd {
279 message: agent_msg,
280 tool_results: tool_results.clone(),
281 });
282
283 if all_terminate {
284 events(AgentEvent::AgentEnd {
285 messages: messages.clone(),
286 });
287 return Ok(messages);
288 }
289
290 let stop_ctx = ShouldStopAfterTurnContext {
291 messages: messages.clone(),
292 tool_results,
293 };
294 if hooks.should_stop_after_turn(stop_ctx).await {
295 events(AgentEvent::AgentEnd {
296 messages: messages.clone(),
297 });
298 return Ok(messages);
299 }
300
301 break 'stream;
302 }
303
304 events(AgentEvent::TurnEnd {
305 message: agent_msg.clone(),
306 tool_results: vec![],
307 });
308
309 let stop_ctx = ShouldStopAfterTurnContext {
310 messages: messages.clone(),
311 tool_results: vec![],
312 };
313 if hooks.should_stop_after_turn(stop_ctx).await {
314 events(AgentEvent::AgentEnd {
315 messages: messages.clone(),
316 });
317 return Ok(messages);
318 }
319 }
320 }
321 Err(e) => {
322 if e.is_retryable()
323 && retry_attempt < max_attempts
324 && let Some(ref rc) = config.retry
325 {
326 let retry_after_ms = match &e {
327 opi_ai::provider::ProviderError::RateLimited { retry_after_ms } => {
328 *retry_after_ms
329 }
330 _ => None,
331 };
332 let delay_ms = rc.delay_for_attempt(retry_attempt, retry_after_ms);
333 retry_attempt += 1;
334
335 events(AgentEvent::AutoRetryStart {
336 attempt: retry_attempt,
337 max_attempts: rc.max_attempts,
338 delay_ms,
339 error_message: e.to_string(),
340 });
341
342 tokio::select! {
343 biased;
344 _ = cancel.cancelled() => {
345 events(AgentEvent::AgentEnd {
346 messages: messages.clone(),
347 });
348 return Err(AgentError::Cancelled);
349 }
350 _ = tokio::time::sleep(
351 std::time::Duration::from_millis(delay_ms)
352 ) => {}
353 }
354 continue 'stream;
355 }
356
357 if retry_attempt > 0 {
358 events(AgentEvent::AutoRetryEnd {
359 success: false,
360 attempt: retry_attempt,
361 final_error: Some(e.to_string()),
362 });
363 }
364
365 events(AgentEvent::AgentEnd {
366 messages: messages.clone(),
367 });
368 return Err(match &e {
369 opi_ai::provider::ProviderError::AuthFailed(msg) => {
370 AgentError::AuthFailed(msg.clone())
371 }
372 _ => AgentError::Provider(e.to_string()),
373 });
374 }
375 }
376 }
377
378 if retry_attempt > 0 {
379 events(AgentEvent::AutoRetryEnd {
380 success: true,
381 attempt: retry_attempt,
382 final_error: None,
383 });
384 }
385 break 'stream;
386 }
387
388 let next_turn_ctx = hooks::PrepareNextTurnContext {
392 messages: messages.clone(),
393 turn: turn_idx + 1,
394 };
395 let mut hook_injected = false;
396 if let Some(update) = hooks.prepare_next_turn(next_turn_ctx).await
397 && !update.extra_messages.is_empty()
398 {
399 hook_injected = true;
400 messages.extend(update.extra_messages);
401 }
402
403 let steering = drain_queue(&context.steering_queue);
405 if !steering.is_empty() {
406 events(AgentEvent::QueueUpdate {
407 steering: steering.clone(),
408 follow_up: vec![],
409 });
410 for msg in steering {
411 messages.push(user_text_message(msg));
412 }
413 continue; }
415
416 if hook_injected {
418 continue;
419 }
420
421 if !has_tools_pending {
423 let follow_up = pop_follow_up(&context.follow_up_queue);
424 if !follow_up.is_empty() {
425 events(AgentEvent::QueueUpdate {
426 steering: vec![],
427 follow_up: follow_up.clone(),
428 });
429 for msg in follow_up {
430 messages.push(user_text_message(msg));
431 }
432 continue; }
434 break; }
436
437 let _ = turn_idx;
439 }
440
441 events(AgentEvent::AgentEnd {
442 messages: messages.clone(),
443 });
444 Ok(messages)
445}
446
447fn process_stream_event(
450 event: &opi_ai::stream::AssistantStreamEvent,
451 content: &mut Vec<AssistantContent>,
452 events: &AgentEventSink,
453) -> Option<opi_ai::message::AssistantMessage> {
454 use opi_ai::stream::AssistantStreamEvent::*;
455
456 match event {
457 Start { partial } => {
458 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
459 events(AgentEvent::MessageStart { message: msg });
460 None
461 }
462 TextDelta { delta, partial, .. } => {
463 match content.last_mut() {
465 Some(AssistantContent::Text { text }) => {
466 text.push_str(delta);
467 }
468 _ => {
469 content.push(AssistantContent::Text {
470 text: delta.clone(),
471 });
472 }
473 }
474 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
475 events(AgentEvent::MessageUpdate {
476 message: msg,
477 assistant_event: Box::new(event.clone()),
478 });
479 None
480 }
481 ToolCallEnd { tool_call, .. } => {
482 content.push(AssistantContent::ToolCall {
483 tool_call: tool_call.clone(),
484 });
485 None
486 }
487 ThinkingStart { partial, .. }
488 | ThinkingDelta { partial, .. }
489 | ThinkingEnd { partial, .. } => {
490 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
491 events(AgentEvent::MessageUpdate {
492 message: msg,
493 assistant_event: Box::new(event.clone()),
494 });
495 None
496 }
497 Done { message, .. } => Some(message.clone()),
498 Error { message, .. } => Some(message.clone()),
499 _ => None,
500 }
501}
502
503async fn execute_tool(
505 call_id: &str,
506 tool_name: &str,
507 args: &serde_json::Value,
508 tools_map: &HashMap<String, &dyn Tool>,
509 hooks: &dyn AgentHooks,
510 messages: &[AgentMessage],
511 cancel: CancellationToken,
512) -> ToolResult {
513 let tool = match tools_map.get(tool_name) {
514 Some(t) => *t,
515 None => {
516 return ToolResult {
517 content: vec![opi_ai::message::OutputContent::Text {
518 text: format!("unknown tool: {tool_name}"),
519 }],
520 details: None,
521 is_error: true,
522 terminate: false,
523 };
524 }
525 };
526
527 let schema = &tool.definition().input_schema;
529 if let Err(err) = validation::validate(schema, args) {
530 return ToolResult::from_validation_error(err);
531 }
532
533 let ctx = BeforeToolCallContext {
535 tool_call_id: call_id.to_owned(),
536 tool_name: tool_name.to_owned(),
537 args: args.clone(),
538 messages: messages.to_vec(),
539 };
540 match hooks.before_tool_call(ctx).await {
541 BeforeToolCallResult::Allow => {}
542 BeforeToolCallResult::Deny { reason } => {
543 return ToolResult {
544 content: vec![opi_ai::message::OutputContent::Text { text: reason }],
545 details: None,
546 is_error: true,
547 terminate: false,
548 };
549 }
550 }
551
552 match tool.execute(call_id, args.clone(), cancel, None).await {
554 Ok(result) => {
555 let ctx = AfterToolCallContext {
556 tool_call_id: call_id.to_owned(),
557 tool_name: tool_name.to_owned(),
558 result: result.clone(),
559 };
560 match hooks.after_tool_call(ctx).await {
561 AfterToolCallResult::Keep => result,
562 AfterToolCallResult::Replace(replacement) => replacement,
563 }
564 }
565 Err(e) => ToolResult {
566 content: vec![opi_ai::message::OutputContent::Text {
567 text: e.to_string(),
568 }],
569 details: None,
570 is_error: true,
571 terminate: false,
572 },
573 }
574}
575
/// Removes and returns every pending message from an optional shared queue, oldest first.
///
/// Returns an empty `Vec` when no queue is configured. A poisoned lock (another thread
/// panicked while holding it) is recovered instead of propagating the panic into the agent
/// loop: the queue only holds plain `String`s, so its contents remain usable.
fn drain_queue(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(q) = queue else {
        return Vec::new();
    };
    let mut guard = q.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
    guard.drain(..).collect()
}
586
/// Removes and returns at most one message (the oldest) from an optional shared queue.
///
/// Returns an empty `Vec` when no queue is configured or the queue is empty. A poisoned
/// lock is recovered rather than propagated, because the queued `String`s remain usable
/// after another thread panicked while holding the lock.
fn pop_follow_up(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(q) = queue else {
        return Vec::new();
    };
    q.lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .pop_front()
        .into_iter()
        .collect()
}
600
601fn user_text_message(text: String) -> AgentMessage {
603 AgentMessage::Llm(Message::User(UserMessage {
604 content: vec![InputContent::Text { text }],
605 timestamp_ms: 0,
606 }))
607}