1pub mod agent;
7pub mod compaction;
8pub mod event;
9pub mod hooks;
10pub mod loop_types;
11pub mod message;
12pub mod session;
13pub mod session_event;
14pub mod state;
15pub mod tool;
16pub mod transport;
17pub mod validation;
18
19pub use agent::Agent;
20pub use event::{AgentEvent, AgentEventSink};
21pub use hooks::AgentHooks;
22pub use loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
23pub use message::AgentMessage;
24pub use session_event::AgentSessionEvent;
25pub use state::AgentState;
26pub use tool::{ExecutionMode, Tool, ToolError, ToolResult};
27pub use transport::Transport;
28
29pub use opi_ai::message::ToolDef;
31
32use std::collections::{HashMap, VecDeque};
33use std::sync::{Arc, Mutex};
34
35use futures_util::StreamExt;
36use hooks::{
37 AfterToolCallContext, AfterToolCallResult, BeforeToolCallContext, BeforeToolCallResult,
38 ShouldStopAfterTurnContext,
39};
40use opi_ai::message::{AssistantContent, InputContent, Message, ToolResultMessage, UserMessage};
41use opi_ai::provider::Request;
42use serde_json::json;
43use tokio_util::sync::CancellationToken;
44
45pub async fn agent_loop(
51 context: AgentLoopContext,
52 config: AgentLoopConfig,
53 hooks: &dyn AgentHooks,
54 events: AgentEventSink,
55 cancel: CancellationToken,
56) -> Result<Vec<AgentMessage>, AgentError> {
57 let tools_map: HashMap<String, &dyn Tool> = context
58 .tools
59 .iter()
60 .map(|t| (t.definition().name.clone(), t.as_ref()))
61 .collect();
62 let tool_defs: Vec<_> = context.tools.iter().map(|t| t.definition()).collect();
63
64 let mut messages = context.messages;
65
66 events(AgentEvent::AgentStart);
67
68 let mut has_tools_pending;
69 for turn_idx in 0..config.max_turns {
70 if cancel.is_cancelled() {
71 events(AgentEvent::AgentEnd {
72 messages: messages.clone(),
73 });
74 return Err(AgentError::Cancelled);
75 }
76
77 events(AgentEvent::TurnStart);
78
79 let transformed = hooks
81 .transform_context(messages.clone(), cancel.clone())
82 .await?;
83
84 let llm_messages = hooks.convert_to_llm(&transformed)?;
86
87 let mut assistant_content: Vec<AssistantContent> = Vec::new();
89 has_tools_pending = false;
90 let mut retry_attempt: u32 = 0;
91 let max_attempts = config.retry.as_ref().map(|r| r.max_attempts).unwrap_or(0);
92
93 'stream: loop {
94 let request = Request {
95 model: context.model.clone(),
96 system: context.system.clone(),
97 messages: llm_messages.clone(),
98 tools: tool_defs.clone(),
99 max_tokens: config.max_tokens,
100 temperature: config.temperature,
101 thinking: config.thinking.clone().unwrap_or_default(),
102 stop_sequences: vec![],
103 metadata: None,
104 cancel: cancel.clone(),
105 };
106 let mut stream = context.provider.stream(request);
107 assistant_content.clear();
108
109 while let Some(item) = {
110 tokio::select! {
111 biased;
112 _ = cancel.cancelled() => {
113 events(AgentEvent::AgentEnd {
114 messages: messages.clone(),
115 });
116 return Err(AgentError::Cancelled);
117 }
118 item = stream.next() => item,
119 }
120 } {
121 match item {
122 Ok(event) => {
123 if let Some(msg) =
124 process_stream_event(&event, &mut assistant_content, &events)
125 {
126 let agent_msg = AgentMessage::Llm(Message::Assistant(msg));
130
131 events(AgentEvent::MessageEnd {
132 message: agent_msg.clone(),
133 });
134
135 messages.push(agent_msg.clone());
136
137 let content = match &agent_msg {
139 AgentMessage::Llm(Message::Assistant(a)) => &a.content,
140 _ => &Vec::new(),
141 };
142 let tool_calls: Vec<_> = content
143 .iter()
144 .filter_map(|c| match c {
145 AssistantContent::ToolCall { tool_call } => {
146 Some(tool_call.clone())
147 }
148 _ => None,
149 })
150 .collect();
151
152 if !tool_calls.is_empty() {
153 has_tools_pending = true;
154 let mut tool_results = Vec::new();
155 let mut terminate_flags = Vec::new();
156
157 let batch_is_sequential = tool_calls.iter().any(|tc| {
158 tools_map
159 .get(tc.name.as_str())
160 .map(|t| t.execution_mode() == ExecutionMode::Sequential)
161 .unwrap_or(true)
162 });
163
164 if batch_is_sequential {
165 for tc in &tool_calls {
166 let args: serde_json::Value =
167 serde_json::from_str(&tc.arguments)
168 .unwrap_or(json!({}));
169
170 events(AgentEvent::ToolExecutionStart {
171 tool_call_id: tc.id.clone(),
172 tool_name: tc.name.clone(),
173 args: args.clone(),
174 });
175
176 let result = execute_tool(
177 &tc.id,
178 &tc.name,
179 &args,
180 &tools_map,
181 hooks,
182 &messages,
183 cancel.clone(),
184 )
185 .await;
186
187 let is_error = result.is_error;
188 let details = result.details.clone();
189 terminate_flags.push(result.terminate);
190 events(AgentEvent::ToolExecutionEnd {
191 tool_call_id: tc.id.clone(),
192 tool_name: tc.name.clone(),
193 result: serde_json::json!(&result.content),
194 details,
195 is_error,
196 });
197
198 let trm = ToolResultMessage {
199 tool_call_id: tc.id.clone(),
200 tool_name: tc.name.clone(),
201 content: result.content,
202 details: result.details,
203 is_error,
204 timestamp_ms: 0,
205 };
206 tool_results.push(trm.clone());
207 messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
208 }
209 } else {
210 let tc_args: Vec<_> = tool_calls
211 .iter()
212 .map(|tc| {
213 let args: serde_json::Value =
214 serde_json::from_str(&tc.arguments)
215 .unwrap_or(json!({}));
216 events(AgentEvent::ToolExecutionStart {
217 tool_call_id: tc.id.clone(),
218 tool_name: tc.name.clone(),
219 args: args.clone(),
220 });
221 (tc.clone(), args)
222 })
223 .collect();
224
225 let futures: Vec<_> = tc_args
226 .iter()
227 .map(|(tc, args)| {
228 let tools_map = &tools_map;
229 let messages = &messages;
230 let cancel = cancel.clone();
231 let tc_id = tc.id.clone();
232 let tc_name = tc.name.clone();
233 let args = args.clone();
234 async move {
235 let result = execute_tool(
236 &tc_id, &tc_name, &args, tools_map, hooks,
237 messages, cancel,
238 )
239 .await;
240 (tc_id, tc_name, result)
241 }
242 })
243 .collect();
244 let results = futures_util::future::join_all(futures).await;
245 for (tc_id, tc_name, result) in results {
246 let is_error = result.is_error;
247 let details = result.details.clone();
248 terminate_flags.push(result.terminate);
249 events(AgentEvent::ToolExecutionEnd {
250 tool_call_id: tc_id.clone(),
251 tool_name: tc_name.clone(),
252 result: serde_json::json!(&result.content),
253 details,
254 is_error,
255 });
256 let trm = ToolResultMessage {
257 tool_call_id: tc_id,
258 tool_name: tc_name,
259 content: result.content,
260 details: result.details,
261 is_error,
262 timestamp_ms: 0,
263 };
264 tool_results.push(trm.clone());
265 messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
266 }
267 }
268
269 let all_terminate = !terminate_flags.is_empty()
270 && terminate_flags.iter().all(|t| *t);
271
272 events(AgentEvent::TurnEnd {
273 message: agent_msg,
274 tool_results: tool_results.clone(),
275 });
276
277 if all_terminate {
278 events(AgentEvent::AgentEnd {
279 messages: messages.clone(),
280 });
281 return Ok(messages);
282 }
283
284 let stop_ctx = ShouldStopAfterTurnContext {
285 messages: messages.clone(),
286 tool_results,
287 };
288 if hooks.should_stop_after_turn(stop_ctx).await {
289 events(AgentEvent::AgentEnd {
290 messages: messages.clone(),
291 });
292 return Ok(messages);
293 }
294
295 break 'stream;
296 }
297
298 events(AgentEvent::TurnEnd {
299 message: agent_msg.clone(),
300 tool_results: vec![],
301 });
302
303 let stop_ctx = ShouldStopAfterTurnContext {
304 messages: messages.clone(),
305 tool_results: vec![],
306 };
307 if hooks.should_stop_after_turn(stop_ctx).await {
308 events(AgentEvent::AgentEnd {
309 messages: messages.clone(),
310 });
311 return Ok(messages);
312 }
313 }
314 }
315 Err(e) => {
316 if e.is_retryable()
317 && retry_attempt < max_attempts
318 && let Some(ref rc) = config.retry
319 {
320 let retry_after_ms = match &e {
321 opi_ai::provider::ProviderError::RateLimited { retry_after_ms } => {
322 *retry_after_ms
323 }
324 _ => None,
325 };
326 let delay_ms = rc.delay_for_attempt(retry_attempt, retry_after_ms);
327 retry_attempt += 1;
328
329 events(AgentEvent::AutoRetryStart {
330 attempt: retry_attempt,
331 max_attempts: rc.max_attempts,
332 delay_ms,
333 error_message: e.to_string(),
334 });
335
336 tokio::select! {
337 biased;
338 _ = cancel.cancelled() => {
339 events(AgentEvent::AgentEnd {
340 messages: messages.clone(),
341 });
342 return Err(AgentError::Cancelled);
343 }
344 _ = tokio::time::sleep(
345 std::time::Duration::from_millis(delay_ms)
346 ) => {}
347 }
348 continue 'stream;
349 }
350
351 if retry_attempt > 0 {
352 events(AgentEvent::AutoRetryEnd {
353 success: false,
354 attempt: retry_attempt,
355 final_error: Some(e.to_string()),
356 });
357 }
358
359 events(AgentEvent::AgentEnd {
360 messages: messages.clone(),
361 });
362 return Err(match &e {
363 opi_ai::provider::ProviderError::AuthFailed(msg) => {
364 AgentError::AuthFailed(msg.clone())
365 }
366 _ => AgentError::Provider(e.to_string()),
367 });
368 }
369 }
370 }
371
372 if retry_attempt > 0 {
373 events(AgentEvent::AutoRetryEnd {
374 success: true,
375 attempt: retry_attempt,
376 final_error: None,
377 });
378 }
379 break 'stream;
380 }
381
382 let next_turn_ctx = hooks::PrepareNextTurnContext {
386 messages: messages.clone(),
387 turn: turn_idx + 1,
388 };
389 let mut hook_injected = false;
390 if let Some(update) = hooks.prepare_next_turn(next_turn_ctx).await
391 && !update.extra_messages.is_empty()
392 {
393 hook_injected = true;
394 messages.extend(update.extra_messages);
395 }
396
397 let steering = drain_queue(&context.steering_queue);
399 if !steering.is_empty() {
400 events(AgentEvent::QueueUpdate {
401 steering: steering.clone(),
402 follow_up: vec![],
403 });
404 for msg in steering {
405 messages.push(user_text_message(msg));
406 }
407 continue; }
409
410 if hook_injected {
412 continue;
413 }
414
415 if !has_tools_pending {
417 let follow_up = pop_follow_up(&context.follow_up_queue);
418 if !follow_up.is_empty() {
419 events(AgentEvent::QueueUpdate {
420 steering: vec![],
421 follow_up: follow_up.clone(),
422 });
423 for msg in follow_up {
424 messages.push(user_text_message(msg));
425 }
426 continue; }
428 break; }
430
431 let _ = turn_idx;
433 }
434
435 events(AgentEvent::AgentEnd {
436 messages: messages.clone(),
437 });
438 Ok(messages)
439}
440
441fn process_stream_event(
444 event: &opi_ai::stream::AssistantStreamEvent,
445 content: &mut Vec<AssistantContent>,
446 events: &AgentEventSink,
447) -> Option<opi_ai::message::AssistantMessage> {
448 use opi_ai::stream::AssistantStreamEvent::*;
449
450 match event {
451 Start { partial } => {
452 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
453 events(AgentEvent::MessageStart { message: msg });
454 None
455 }
456 TextDelta { delta, partial, .. } => {
457 match content.last_mut() {
459 Some(AssistantContent::Text { text }) => {
460 text.push_str(delta);
461 }
462 _ => {
463 content.push(AssistantContent::Text {
464 text: delta.clone(),
465 });
466 }
467 }
468 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
469 events(AgentEvent::MessageUpdate {
470 message: msg,
471 assistant_event: Box::new(event.clone()),
472 });
473 None
474 }
475 ToolCallEnd { tool_call, .. } => {
476 content.push(AssistantContent::ToolCall {
477 tool_call: tool_call.clone(),
478 });
479 None
480 }
481 ThinkingStart { partial, .. }
482 | ThinkingDelta { partial, .. }
483 | ThinkingEnd { partial, .. } => {
484 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
485 events(AgentEvent::MessageUpdate {
486 message: msg,
487 assistant_event: Box::new(event.clone()),
488 });
489 None
490 }
491 Done { message, .. } => Some(message.clone()),
492 Error { message, .. } => Some(message.clone()),
493 _ => None,
494 }
495}
496
497async fn execute_tool(
499 call_id: &str,
500 tool_name: &str,
501 args: &serde_json::Value,
502 tools_map: &HashMap<String, &dyn Tool>,
503 hooks: &dyn AgentHooks,
504 messages: &[AgentMessage],
505 cancel: CancellationToken,
506) -> ToolResult {
507 let tool = match tools_map.get(tool_name) {
508 Some(t) => *t,
509 None => {
510 return ToolResult {
511 content: vec![opi_ai::message::OutputContent::Text {
512 text: format!("unknown tool: {tool_name}"),
513 }],
514 details: None,
515 is_error: true,
516 terminate: false,
517 };
518 }
519 };
520
521 let schema = &tool.definition().input_schema;
523 if let Err(err) = validation::validate(schema, args) {
524 return ToolResult::from_validation_error(err);
525 }
526
527 let ctx = BeforeToolCallContext {
529 tool_call_id: call_id.to_owned(),
530 tool_name: tool_name.to_owned(),
531 args: args.clone(),
532 messages: messages.to_vec(),
533 };
534 match hooks.before_tool_call(ctx).await {
535 BeforeToolCallResult::Allow => {}
536 BeforeToolCallResult::Deny { reason } => {
537 return ToolResult {
538 content: vec![opi_ai::message::OutputContent::Text { text: reason }],
539 details: None,
540 is_error: true,
541 terminate: false,
542 };
543 }
544 }
545
546 match tool.execute(call_id, args.clone(), cancel, None).await {
548 Ok(result) => {
549 let ctx = AfterToolCallContext {
550 tool_call_id: call_id.to_owned(),
551 tool_name: tool_name.to_owned(),
552 result: result.clone(),
553 };
554 match hooks.after_tool_call(ctx).await {
555 AfterToolCallResult::Keep => result,
556 AfterToolCallResult::Replace(replacement) => replacement,
557 }
558 }
559 Err(e) => ToolResult {
560 content: vec![opi_ai::message::OutputContent::Text {
561 text: e.to_string(),
562 }],
563 details: None,
564 is_error: true,
565 terminate: false,
566 },
567 }
568}
569
/// Removes and returns every queued message, oldest first.
///
/// Returns an empty `Vec` when no queue is configured or the queue is empty.
///
/// A poisoned lock (another holder panicked while holding it) is recovered
/// instead of propagated. The deque of strings is still usable, and panicking
/// here would take down the whole agent run.
fn drain_queue(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(queue) = queue else {
        return Vec::new();
    };
    let mut pending = queue
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // Bound to a local so the `Drain` borrow ends before the guard is dropped.
    let drained: Vec<String> = pending.drain(..).collect();
    drained
}
580
/// Removes and returns at most one queued follow-up message (the oldest).
///
/// Returns a zero- or one-element `Vec`; it is empty when no queue is
/// configured or the queue is empty.
///
/// A poisoned lock is recovered instead of panicking, for the same reason as
/// the steering queue: the deque of strings remains valid after another holder
/// panics.
fn pop_follow_up(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(queue) = queue else {
        return Vec::new();
    };
    let mut pending = queue
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    pending.pop_front().into_iter().collect()
}
594
595fn user_text_message(text: String) -> AgentMessage {
597 AgentMessage::Llm(Message::User(UserMessage {
598 content: vec![InputContent::Text { text }],
599 timestamp_ms: 0,
600 }))
601}