1pub mod agent;
7pub mod event;
8pub mod hooks;
9pub mod loop_types;
10pub mod message;
11pub mod state;
12pub mod tool;
13pub mod transport;
14pub mod validation;
15
16pub use agent::Agent;
17pub use event::{AgentEvent, AgentEventSink};
18pub use hooks::AgentHooks;
19pub use loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
20pub use message::AgentMessage;
21pub use state::AgentState;
22pub use tool::{ExecutionMode, Tool, ToolError, ToolResult};
23pub use transport::Transport;
24
25pub use opi_ai::message::ToolDef;
27
28use std::collections::{HashMap, VecDeque};
29use std::sync::{Arc, Mutex};
30
31use futures_util::StreamExt;
32use hooks::{
33 AfterToolCallContext, AfterToolCallResult, BeforeToolCallContext, BeforeToolCallResult,
34 ShouldStopAfterTurnContext,
35};
36use opi_ai::message::{AssistantContent, InputContent, Message, ToolResultMessage, UserMessage};
37use opi_ai::provider::Request;
38use serde_json::json;
39use tokio_util::sync::CancellationToken;
40
41pub async fn agent_loop(
47 context: AgentLoopContext,
48 config: AgentLoopConfig,
49 hooks: &dyn AgentHooks,
50 events: AgentEventSink,
51 cancel: CancellationToken,
52) -> Result<Vec<AgentMessage>, AgentError> {
53 let tools_map: HashMap<String, &dyn Tool> = context
54 .tools
55 .iter()
56 .map(|t| (t.definition().name.clone(), t.as_ref()))
57 .collect();
58 let tool_defs: Vec<_> = context.tools.iter().map(|t| t.definition()).collect();
59
60 let mut messages = context.messages;
61
62 events(AgentEvent::AgentStart);
63
64 let mut has_tools_pending;
65 for turn_idx in 0..config.max_turns {
66 if cancel.is_cancelled() {
67 events(AgentEvent::AgentEnd {
68 messages: messages.clone(),
69 });
70 return Err(AgentError::Cancelled);
71 }
72
73 events(AgentEvent::TurnStart);
74
75 let transformed = hooks
77 .transform_context(messages.clone(), cancel.clone())
78 .await?;
79
80 let llm_messages = hooks.convert_to_llm(&transformed)?;
82
83 let request = Request {
85 model: context.model.clone(),
86 system: context.system.clone(),
87 messages: llm_messages,
88 tools: tool_defs.clone(),
89 max_tokens: config.max_tokens,
90 temperature: config.temperature,
91 thinking: Default::default(),
92 stop_sequences: vec![],
93 metadata: None,
94 cancel: cancel.clone(),
95 };
96
97 let mut stream = context.provider.stream(request);
99 let mut assistant_content: Vec<AssistantContent> = Vec::new();
100 has_tools_pending = false;
101
102 while let Some(item) = {
103 tokio::select! {
104 biased;
105 _ = cancel.cancelled() => {
106 events(AgentEvent::AgentEnd {
107 messages: messages.clone(),
108 });
109 return Err(AgentError::Cancelled);
110 }
111 item = stream.next() => item,
112 }
113 } {
114 match item {
115 Ok(event) => {
116 if let Some(msg) = process_stream_event(&event, &mut assistant_content, &events)
117 {
118 let mut assistant_msg = msg;
120 assistant_msg.content = assistant_content.clone();
121 let agent_msg = AgentMessage::Llm(Message::Assistant(assistant_msg));
122
123 events(AgentEvent::MessageEnd {
124 message: agent_msg.clone(),
125 });
126
127 messages.push(agent_msg.clone());
128
129 let tool_calls: Vec<_> = assistant_content
131 .iter()
132 .filter_map(|c| match c {
133 AssistantContent::ToolCall { tool_call } => Some(tool_call.clone()),
134 _ => None,
135 })
136 .collect();
137
138 if !tool_calls.is_empty() {
139 has_tools_pending = true;
140 let mut tool_results = Vec::new();
141 let mut terminate_flags = Vec::new();
142
143 let batch_is_sequential = tool_calls.iter().any(|tc| {
146 tools_map
147 .get(tc.name.as_str())
148 .map(|t| t.execution_mode() == ExecutionMode::Sequential)
149 .unwrap_or(true)
150 });
151
152 if batch_is_sequential {
153 for tc in &tool_calls {
154 let args: serde_json::Value =
155 serde_json::from_str(&tc.arguments).unwrap_or(json!({}));
156
157 events(AgentEvent::ToolExecutionStart {
158 tool_call_id: tc.id.clone(),
159 tool_name: tc.name.clone(),
160 args: args.clone(),
161 });
162
163 let result = execute_tool(
164 &tc.id,
165 &tc.name,
166 &args,
167 &tools_map,
168 hooks,
169 &messages,
170 cancel.clone(),
171 )
172 .await;
173
174 let is_error = result.is_error;
175 terminate_flags.push(result.terminate);
176 events(AgentEvent::ToolExecutionEnd {
177 tool_call_id: tc.id.clone(),
178 tool_name: tc.name.clone(),
179 result: serde_json::json!(&result.content),
180 is_error,
181 });
182
183 let trm = ToolResultMessage {
184 tool_call_id: tc.id.clone(),
185 tool_name: tc.name.clone(),
186 content: result.content,
187 details: result.details,
188 is_error,
189 timestamp_ms: 0,
190 };
191 tool_results.push(trm.clone());
192 messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
193 }
194 } else {
195 let tc_args: Vec<_> = tool_calls
197 .iter()
198 .map(|tc| {
199 let args: serde_json::Value =
200 serde_json::from_str(&tc.arguments)
201 .unwrap_or(json!({}));
202 events(AgentEvent::ToolExecutionStart {
203 tool_call_id: tc.id.clone(),
204 tool_name: tc.name.clone(),
205 args: args.clone(),
206 });
207 (tc.clone(), args)
208 })
209 .collect();
210
211 let futures: Vec<_> = tc_args
212 .iter()
213 .map(|(tc, args)| {
214 let tools_map = &tools_map;
215 let messages = &messages;
216 let cancel = cancel.clone();
217 let tc_id = tc.id.clone();
218 let tc_name = tc.name.clone();
219 let args = args.clone();
220 async move {
221 let result = execute_tool(
222 &tc_id, &tc_name, &args, tools_map, hooks,
223 messages, cancel,
224 )
225 .await;
226 (tc_id, tc_name, result)
227 }
228 })
229 .collect();
230 let results = futures_util::future::join_all(futures).await;
231 for (tc_id, tc_name, result) in results {
232 let is_error = result.is_error;
233 terminate_flags.push(result.terminate);
234 events(AgentEvent::ToolExecutionEnd {
235 tool_call_id: tc_id.clone(),
236 tool_name: tc_name.clone(),
237 result: serde_json::json!(&result.content),
238 is_error,
239 });
240 let trm = ToolResultMessage {
241 tool_call_id: tc_id,
242 tool_name: tc_name,
243 content: result.content,
244 details: result.details,
245 is_error,
246 timestamp_ms: 0,
247 };
248 tool_results.push(trm.clone());
249 messages.push(AgentMessage::Llm(Message::ToolResult(trm)));
250 }
251 }
252
253 let all_terminate =
255 !terminate_flags.is_empty() && terminate_flags.iter().all(|t| *t);
256
257 events(AgentEvent::TurnEnd {
258 message: agent_msg,
259 tool_results: tool_results.clone(),
260 });
261
262 if all_terminate {
263 events(AgentEvent::AgentEnd {
264 messages: messages.clone(),
265 });
266 return Ok(messages);
267 }
268
269 let stop_ctx = ShouldStopAfterTurnContext {
271 messages: messages.clone(),
272 tool_results,
273 };
274 if hooks.should_stop_after_turn(stop_ctx).await {
275 events(AgentEvent::AgentEnd {
276 messages: messages.clone(),
277 });
278 return Ok(messages);
279 }
280
281 break;
283 }
284
285 events(AgentEvent::TurnEnd {
287 message: agent_msg.clone(),
288 tool_results: vec![],
289 });
290
291 let stop_ctx = ShouldStopAfterTurnContext {
293 messages: messages.clone(),
294 tool_results: vec![],
295 };
296 if hooks.should_stop_after_turn(stop_ctx).await {
297 events(AgentEvent::AgentEnd {
298 messages: messages.clone(),
299 });
300 return Ok(messages);
301 }
302 }
303 }
304 Err(e) => {
305 events(AgentEvent::AgentEnd {
306 messages: messages.clone(),
307 });
308 return Err(match &e {
309 opi_ai::provider::ProviderError::AuthFailed(msg) => {
310 AgentError::AuthFailed(msg.clone())
311 }
312 _ => AgentError::Provider(e.to_string()),
313 });
314 }
315 }
316 }
317
318 let next_turn_ctx = hooks::PrepareNextTurnContext {
322 messages: messages.clone(),
323 turn: turn_idx + 1,
324 };
325 let mut hook_injected = false;
326 if let Some(update) = hooks.prepare_next_turn(next_turn_ctx).await
327 && !update.extra_messages.is_empty()
328 {
329 hook_injected = true;
330 messages.extend(update.extra_messages);
331 }
332
333 let steering = drain_queue(&context.steering_queue);
335 if !steering.is_empty() {
336 events(AgentEvent::QueueUpdate {
337 steering: steering.clone(),
338 follow_up: vec![],
339 });
340 for msg in steering {
341 messages.push(user_text_message(msg));
342 }
343 continue; }
345
346 if hook_injected {
348 continue;
349 }
350
351 if !has_tools_pending {
353 let follow_up = pop_follow_up(&context.follow_up_queue);
354 if !follow_up.is_empty() {
355 events(AgentEvent::QueueUpdate {
356 steering: vec![],
357 follow_up: follow_up.clone(),
358 });
359 for msg in follow_up {
360 messages.push(user_text_message(msg));
361 }
362 continue; }
364 break; }
366
367 let _ = turn_idx;
369 }
370
371 events(AgentEvent::AgentEnd {
372 messages: messages.clone(),
373 });
374 Ok(messages)
375}
376
377fn process_stream_event(
380 event: &opi_ai::stream::AssistantStreamEvent,
381 content: &mut Vec<AssistantContent>,
382 events: &AgentEventSink,
383) -> Option<opi_ai::message::AssistantMessage> {
384 use opi_ai::stream::AssistantStreamEvent::*;
385
386 match event {
387 Start { partial } => {
388 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
389 events(AgentEvent::MessageStart { message: msg });
390 None
391 }
392 TextDelta { delta, partial, .. } => {
393 match content.last_mut() {
395 Some(AssistantContent::Text { text }) => {
396 text.push_str(delta);
397 }
398 _ => {
399 content.push(AssistantContent::Text {
400 text: delta.clone(),
401 });
402 }
403 }
404 let msg = AgentMessage::Llm(Message::Assistant(partial.clone()));
405 events(AgentEvent::MessageUpdate {
406 message: msg,
407 assistant_event: Box::new(event.clone()),
408 });
409 None
410 }
411 ToolCallEnd { tool_call, .. } => {
412 content.push(AssistantContent::ToolCall {
413 tool_call: tool_call.clone(),
414 });
415 None
416 }
417 Done { message, .. } => Some(message.clone()),
418 Error { message, .. } => Some(message.clone()),
419 _ => None,
420 }
421}
422
423async fn execute_tool(
425 call_id: &str,
426 tool_name: &str,
427 args: &serde_json::Value,
428 tools_map: &HashMap<String, &dyn Tool>,
429 hooks: &dyn AgentHooks,
430 messages: &[AgentMessage],
431 cancel: CancellationToken,
432) -> ToolResult {
433 let tool = match tools_map.get(tool_name) {
434 Some(t) => *t,
435 None => {
436 return ToolResult {
437 content: vec![opi_ai::message::OutputContent::Text {
438 text: format!("unknown tool: {tool_name}"),
439 }],
440 details: None,
441 is_error: true,
442 terminate: false,
443 };
444 }
445 };
446
447 let schema = &tool.definition().input_schema;
449 if let Err(err) = validation::validate(schema, args) {
450 return ToolResult::from_validation_error(err);
451 }
452
453 let ctx = BeforeToolCallContext {
455 tool_call_id: call_id.to_owned(),
456 tool_name: tool_name.to_owned(),
457 args: args.clone(),
458 messages: messages.to_vec(),
459 };
460 match hooks.before_tool_call(ctx).await {
461 BeforeToolCallResult::Allow => {}
462 BeforeToolCallResult::Deny { reason } => {
463 return ToolResult {
464 content: vec![opi_ai::message::OutputContent::Text { text: reason }],
465 details: None,
466 is_error: true,
467 terminate: false,
468 };
469 }
470 }
471
472 match tool.execute(call_id, args.clone(), cancel, None).await {
474 Ok(result) => {
475 let ctx = AfterToolCallContext {
476 tool_call_id: call_id.to_owned(),
477 tool_name: tool_name.to_owned(),
478 result: result.clone(),
479 };
480 match hooks.after_tool_call(ctx).await {
481 AfterToolCallResult::Keep => result,
482 AfterToolCallResult::Replace(replacement) => replacement,
483 }
484 }
485 Err(e) => ToolResult {
486 content: vec![opi_ai::message::OutputContent::Text {
487 text: e.to_string(),
488 }],
489 details: None,
490 is_error: true,
491 terminate: false,
492 },
493 }
494}
495
/// Removes and returns every message currently in `queue`, oldest first.
///
/// Returns an empty vector when no queue is configured (`None`) or the queue
/// is empty.
///
/// A poisoned mutex is recovered rather than propagated as a panic: the queue
/// holds plain strings, so a panic in another lock holder cannot leave it in
/// a state that is unsafe to read.
fn drain_queue(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(shared) = queue else {
        return Vec::new();
    };
    let mut pending = shared
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    pending.drain(..).collect()
}
506
/// Removes the oldest message from `queue` and returns it as a vector of at
/// most one element.
///
/// Returns an empty vector when no queue is configured (`None`) or the queue
/// is empty. Later entries stay queued for subsequent calls.
///
/// A poisoned mutex is recovered rather than propagated as a panic: the queue
/// holds plain strings, so a panic in another lock holder cannot leave it in
/// a state that is unsafe to read.
fn pop_follow_up(queue: &Option<Arc<Mutex<VecDeque<String>>>>) -> Vec<String> {
    let Some(shared) = queue else {
        return Vec::new();
    };
    let mut pending = shared
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    pending.pop_front().into_iter().collect()
}
520
521fn user_text_message(text: String) -> AgentMessage {
523 AgentMessage::Llm(Message::User(UserMessage {
524 content: vec![InputContent::Text { text }],
525 timestamp_ms: 0,
526 }))
527}