1use crate::error::CoreError;
2use crate::protocol::{
3 AgentEvent, ChatMessage, ModelDirective, ModelStopReason, ModelTurn, RunStopReason, TokenUsage,
4 ToolCall, ToolDefinition, ToolResult, ToolResultSummary,
5};
6use crate::state::AppState;
7use std::collections::BTreeMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10
11#[derive(Debug, Clone)]
12pub struct ProviderRequest {
13 pub run_id: String,
14 pub session_id: String,
15 pub iteration: u32,
16 pub messages: Vec<ChatMessage>,
17 pub tools: Vec<ToolDefinition>,
18 pub state: AppState,
19}
20
21pub trait Provider: Send + Sync {
22 fn name(&self) -> &str;
23 fn complete(&self, request: &ProviderRequest) -> Result<ModelTurn, CoreError>;
24}
25
/// Where a tool call is being executed: which run, which session, which iteration.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Identifier of the run that requested the tool call.
    pub run_id: String,
    /// Identifier of the session the run belongs to.
    pub session_id: String,
    /// Iteration (starting at 1) in which the model requested the call.
    pub iteration: u32,
}
32
33pub trait Tool: Send + Sync {
34 fn definition(&self) -> ToolDefinition;
35 fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, CoreError>;
36}
37
38pub trait Middleware: Send + Sync {
39 fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
40 Ok(())
41 }
42
43 fn after_model_call(
44 &self,
45 _request: &ProviderRequest,
46 _response: &ModelTurn,
47 ) -> Result<(), CoreError> {
48 Ok(())
49 }
50
51 fn pre_tool_call(&self, _context: &ToolContext, _call: &ToolCall) -> Result<(), CoreError> {
52 Ok(())
53 }
54
55 fn post_tool_call(
56 &self,
57 _context: &ToolContext,
58 _result: &ToolResult,
59 ) -> Result<(), CoreError> {
60 Ok(())
61 }
62
63 fn on_run_finished(&self, _output: &RunOutput) -> Result<(), CoreError> {
64 Ok(())
65 }
66}
67
68#[derive(Clone, Default)]
69pub struct ToolRegistry {
70 tools: BTreeMap<String, Arc<dyn Tool>>,
71}
72
73impl ToolRegistry {
74 pub fn register<T: Tool + 'static>(&mut self, tool: T) {
75 self.tools
76 .insert(tool.definition().name.clone(), Arc::new(tool));
77 }
78
79 pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
80 self.tools.get(tool_name).cloned()
81 }
82
83 pub fn definitions(&self) -> Vec<ToolDefinition> {
84 self.tools.values().map(|tool| tool.definition()).collect()
85 }
86}
87
/// Tunables for an orchestrator run.
#[derive(Debug, Clone)]
pub struct OrchestratorConfig {
    /// Maximum number of iterations (at most one model call each) before the
    /// run stops with `RunStopReason::BudgetExceeded`.
    pub max_iterations: u32,
}

impl Default for OrchestratorConfig {
    /// Allows up to 24 iterations per run.
    fn default() -> Self {
        OrchestratorConfig { max_iterations: 24 }
    }
}
98
99#[derive(Debug, Clone)]
100pub struct RunInput {
101 pub run_id: String,
102 pub session_id: String,
103 pub messages: Vec<ChatMessage>,
104 pub state: AppState,
105}
106
107#[derive(Debug, Clone)]
108pub struct RunOutput {
109 pub run_id: String,
110 pub session_id: String,
111 pub events: Vec<AgentEvent>,
112 pub messages: Vec<ChatMessage>,
113 pub state: AppState,
114 pub reason: RunStopReason,
115 pub final_answer: Option<String>,
116 pub total_usage: TokenUsage,
118}
119
120pub struct Orchestrator {
121 provider: Arc<dyn Provider>,
122 tools: ToolRegistry,
123 middlewares: Vec<Arc<dyn Middleware>>,
124 config: OrchestratorConfig,
125}
126
127impl Orchestrator {
128 pub fn new(
129 provider: Arc<dyn Provider>,
130 tools: ToolRegistry,
131 middlewares: Vec<Arc<dyn Middleware>>,
132 config: OrchestratorConfig,
133 ) -> Self {
134 Self {
135 provider,
136 tools,
137 middlewares,
138 config,
139 }
140 }
141
142 pub fn run(&self, input: RunInput, event_handler: impl FnMut(AgentEvent)) -> RunOutput {
143 self.run_cancellable(input, None, event_handler)
144 }
145
146 pub fn run_cancellable(
151 &self,
152 input: RunInput,
153 cancel: Option<Arc<AtomicBool>>,
154 mut event_handler: impl FnMut(AgentEvent),
155 ) -> RunOutput {
156 let mut events = Vec::new();
157 let mut messages = input.messages;
158 let mut state = input.state;
159 let mut final_answer: Option<String> = None;
160 let mut stop_reason = RunStopReason::BudgetExceeded;
161 let mut total_iterations = 0;
162 let mut total_usage = TokenUsage::default();
163
164 let start_event = AgentEvent::RunStarted {
165 run_id: input.run_id.clone(),
166 session_id: input.session_id.clone(),
167 provider: self.provider.name().to_string(),
168 max_iterations: self.config.max_iterations,
169 };
170 event_handler(start_event.clone());
171 events.push(start_event);
172
173 for iteration in 1..=self.config.max_iterations {
174 if let Some(ref flag) = cancel {
176 if flag.load(Ordering::Relaxed) {
177 stop_reason = RunStopReason::Cancelled;
178 let err_event = AgentEvent::RunErrored {
179 run_id: input.run_id.clone(),
180 session_id: input.session_id.clone(),
181 error: "run cancelled".to_string(),
182 };
183 event_handler(err_event.clone());
184 events.push(err_event);
185 break;
186 }
187 }
188
189 total_iterations = iteration;
190 let iter_event = AgentEvent::IterationStarted {
191 run_id: input.run_id.clone(),
192 session_id: input.session_id.clone(),
193 iteration,
194 };
195 event_handler(iter_event.clone());
196 events.push(iter_event);
197
198 let provider_request = ProviderRequest {
199 run_id: input.run_id.clone(),
200 session_id: input.session_id.clone(),
201 iteration,
202 messages: messages.clone(),
203 tools: self.tools.definitions(),
204 state: state.clone(),
205 };
206
207 if let Err(err) = self.run_before_model(&provider_request) {
208 stop_reason = RunStopReason::BlockedByPolicy;
209 let err_event = AgentEvent::RunErrored {
210 run_id: input.run_id.clone(),
211 session_id: input.session_id.clone(),
212 error: err.to_string(),
213 };
214 event_handler(err_event.clone());
215 events.push(err_event);
216 break;
217 }
218
219 let model_turn = match self.provider.complete(&provider_request) {
220 Ok(turn) => turn,
221 Err(err) => {
222 stop_reason = RunStopReason::Error;
223 let err_event = AgentEvent::RunErrored {
224 run_id: input.run_id.clone(),
225 session_id: input.session_id.clone(),
226 error: err.to_string(),
227 };
228 event_handler(err_event.clone());
229 events.push(err_event);
230 break;
231 }
232 };
233
234 if let Err(err) = self.run_after_model(&provider_request, &model_turn) {
235 stop_reason = RunStopReason::BlockedByPolicy;
236 let err_event = AgentEvent::RunErrored {
237 run_id: input.run_id.clone(),
238 session_id: input.session_id.clone(),
239 error: err.to_string(),
240 };
241 event_handler(err_event.clone());
242 events.push(err_event);
243 break;
244 }
245
246 if let Some(ref usage) = model_turn.usage {
248 total_usage.accumulate(usage);
249 }
250
251 let output_event = AgentEvent::ModelOutput {
252 run_id: input.run_id.clone(),
253 session_id: input.session_id.clone(),
254 iteration,
255 stop_reason: model_turn.stop_reason,
256 directive_count: model_turn.directives.len(),
257 usage: model_turn.usage,
258 };
259 event_handler(output_event.clone());
260 events.push(output_event);
261
262 let mut requested_tool = false;
263
264 for directive in model_turn.directives {
265 match directive {
266 ModelDirective::Text { delta } => {
267 let delta_event = AgentEvent::TextDelta {
268 run_id: input.run_id.clone(),
269 session_id: input.session_id.clone(),
270 iteration,
271 delta: delta.clone(),
272 };
273 event_handler(delta_event.clone());
274 events.push(delta_event);
275 messages.push(ChatMessage::assistant(delta));
276 }
277 ModelDirective::ToolCall { call } => {
278 requested_tool = true;
279 let tc_event = AgentEvent::ToolCallRequested {
280 run_id: input.run_id.clone(),
281 session_id: input.session_id.clone(),
282 iteration,
283 call: call.clone(),
284 };
285 event_handler(tc_event.clone());
286 events.push(tc_event);
287
288 let context = ToolContext {
289 run_id: input.run_id.clone(),
290 session_id: input.session_id.clone(),
291 iteration,
292 };
293
294 if let Err(err) = self.run_pre_tool(&context, &call) {
295 stop_reason = RunStopReason::BlockedByPolicy;
296 let err_event = AgentEvent::ToolCallFailed {
297 run_id: input.run_id.clone(),
298 session_id: input.session_id.clone(),
299 iteration,
300 call_id: call.call_id.clone(),
301 tool_name: call.tool_name.clone(),
302 error: err.to_string(),
303 };
304 event_handler(err_event.clone());
305 events.push(err_event);
306 break;
307 }
308
309 let Some(tool) = self.tools.get(&call.tool_name) else {
310 stop_reason = RunStopReason::Error;
311 let err_event = AgentEvent::ToolCallFailed {
312 run_id: input.run_id.clone(),
313 session_id: input.session_id.clone(),
314 iteration,
315 call_id: call.call_id.clone(),
316 tool_name: call.tool_name.clone(),
317 error: format!(
318 "{}",
319 CoreError::ToolNotFound {
320 tool_name: call.tool_name.clone(),
321 }
322 ),
323 };
324 event_handler(err_event.clone());
325 events.push(err_event);
326 break;
327 };
328
329 match tool.execute(&call, &context) {
330 Ok(result) => {
331 if let Some(patch) = &result.state_patch {
332 match state.apply_patch(patch) {
333 Ok(()) => {
334 let patch_event = AgentEvent::StatePatched {
335 run_id: input.run_id.clone(),
336 session_id: input.session_id.clone(),
337 iteration,
338 patch: patch.clone(),
339 revision: state.revision,
340 };
341 event_handler(patch_event.clone());
342 events.push(patch_event);
343 }
344 Err(err) => {
345 stop_reason = RunStopReason::Error;
346 let err_event = AgentEvent::ToolCallFailed {
347 run_id: input.run_id.clone(),
348 session_id: input.session_id.clone(),
349 iteration,
350 call_id: call.call_id.clone(),
351 tool_name: call.tool_name.clone(),
352 error: err.to_string(),
353 };
354 event_handler(err_event.clone());
355 events.push(err_event);
356 break;
357 }
358 }
359 }
360
361 if let Err(err) = self.run_post_tool(&context, &result) {
362 stop_reason = RunStopReason::BlockedByPolicy;
363 let err_event = AgentEvent::ToolCallFailed {
364 run_id: input.run_id.clone(),
365 session_id: input.session_id.clone(),
366 iteration,
367 call_id: call.call_id.clone(),
368 tool_name: call.tool_name.clone(),
369 error: err.to_string(),
370 };
371 event_handler(err_event.clone());
372 events.push(err_event);
373 break;
374 }
375
376 let completed_event = AgentEvent::ToolCallCompleted {
377 run_id: input.run_id.clone(),
378 session_id: input.session_id.clone(),
379 iteration,
380 result: ToolResultSummary::from(&result),
381 };
382 event_handler(completed_event.clone());
383 events.push(completed_event);
384
385 messages.push(ChatMessage::tool_result(
386 &result.call_id,
387 serde_json::to_string(&result.output)
388 .unwrap_or_else(|_| "{}".to_string()),
389 ));
390 }
391 Err(err) => {
392 stop_reason = RunStopReason::Error;
393 let err_event = AgentEvent::ToolCallFailed {
394 run_id: input.run_id.clone(),
395 session_id: input.session_id.clone(),
396 iteration,
397 call_id: call.call_id.clone(),
398 tool_name: call.tool_name.clone(),
399 error: err.to_string(),
400 };
401 event_handler(err_event.clone());
402 events.push(err_event);
403 break;
404 }
405 }
406 }
407 ModelDirective::StatePatch { patch } => match state.apply_patch(&patch) {
408 Ok(()) => {
409 let patch_event = AgentEvent::StatePatched {
410 run_id: input.run_id.clone(),
411 session_id: input.session_id.clone(),
412 iteration,
413 patch: patch.clone(),
414 revision: state.revision,
415 };
416 event_handler(patch_event.clone());
417 events.push(patch_event);
418 }
419 Err(err) => {
420 stop_reason = RunStopReason::Error;
421 let err_event = AgentEvent::RunErrored {
422 run_id: input.run_id.clone(),
423 session_id: input.session_id.clone(),
424 error: err.to_string(),
425 };
426 event_handler(err_event.clone());
427 events.push(err_event);
428 break;
429 }
430 },
431 ModelDirective::FinalAnswer { text } => {
432 final_answer = Some(text.clone());
433 let delta_event = AgentEvent::TextDelta {
434 run_id: input.run_id.clone(),
435 session_id: input.session_id.clone(),
436 iteration,
437 delta: text.clone(),
438 };
439 event_handler(delta_event.clone());
440 events.push(delta_event);
441 messages.push(ChatMessage::assistant(text));
442 }
443 }
444 }
445
446 if matches!(
447 stop_reason,
448 RunStopReason::Error | RunStopReason::BlockedByPolicy | RunStopReason::Cancelled
449 ) {
450 break;
451 }
452
453 match model_turn.stop_reason {
454 ModelStopReason::EndTurn => {
455 stop_reason = RunStopReason::Completed;
456 break;
457 }
458 ModelStopReason::NeedsUser => {
459 stop_reason = RunStopReason::NeedsUser;
460 break;
461 }
462 ModelStopReason::Safety => {
463 stop_reason = RunStopReason::BlockedByPolicy;
464 break;
465 }
466 ModelStopReason::ToolUse => {
467 if !requested_tool {
468 stop_reason = RunStopReason::Error;
469 let err_event = AgentEvent::RunErrored {
470 run_id: input.run_id.clone(),
471 session_id: input.session_id.clone(),
472 error: "model requested tool_use stop reason without tool call"
473 .to_string(),
474 };
475 event_handler(err_event.clone());
476 events.push(err_event);
477 break;
478 }
479 }
480 ModelStopReason::MaxTokens | ModelStopReason::Unknown => {
481 if !requested_tool {
482 stop_reason = RunStopReason::Error;
483 let err_event = AgentEvent::RunErrored {
484 run_id: input.run_id.clone(),
485 session_id: input.session_id.clone(),
486 error: "model returned non-terminal stop reason without tool call"
487 .to_string(),
488 };
489 event_handler(err_event.clone());
490 events.push(err_event);
491 break;
492 }
493 }
494 }
495 }
496
497 if total_iterations == self.config.max_iterations
498 && stop_reason == RunStopReason::BudgetExceeded
499 {
500 let err_event = AgentEvent::RunErrored {
501 run_id: input.run_id.clone(),
502 session_id: input.session_id.clone(),
503 error: "max iteration budget exceeded".to_string(),
504 };
505 event_handler(err_event.clone());
506 events.push(err_event);
507 }
508
509 let finished_event = AgentEvent::RunFinished {
510 run_id: input.run_id.clone(),
511 session_id: input.session_id.clone(),
512 reason: stop_reason,
513 total_iterations,
514 final_answer: final_answer.clone(),
515 };
516 event_handler(finished_event.clone());
517 events.push(finished_event);
518
519 let output = RunOutput {
520 run_id: input.run_id,
521 session_id: input.session_id,
522 events,
523 messages,
524 state,
525 reason: stop_reason,
526 final_answer,
527 total_usage,
528 };
529
530 let _ = self
531 .middlewares
532 .iter()
533 .try_for_each(|middleware| middleware.on_run_finished(&output));
534
535 output
536 }
537
538 fn run_before_model(&self, request: &ProviderRequest) -> Result<(), CoreError> {
539 self.middlewares
540 .iter()
541 .try_for_each(|middleware| middleware.before_model_call(request))
542 }
543
544 fn run_after_model(
545 &self,
546 request: &ProviderRequest,
547 response: &ModelTurn,
548 ) -> Result<(), CoreError> {
549 self.middlewares
550 .iter()
551 .try_for_each(|middleware| middleware.after_model_call(request, response))
552 }
553
554 fn run_pre_tool(&self, context: &ToolContext, call: &ToolCall) -> Result<(), CoreError> {
555 self.middlewares
556 .iter()
557 .try_for_each(|middleware| middleware.pre_tool_call(context, call))
558 }
559
560 fn run_post_tool(&self, context: &ToolContext, result: &ToolResult) -> Result<(), CoreError> {
561 self.middlewares
562 .iter()
563 .try_for_each(|middleware| middleware.post_tool_call(context, result))
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{
        ModelDirective, ModelStopReason, ModelTurn, StatePatch, StatePatchFormat, StatePatchSource,
    };
    use serde_json::json;
    use std::sync::Mutex;

    /// Provider that replays a fixed script of turns, one per `complete` call,
    /// and fails once the script is exhausted.
    struct ScriptedProvider {
        turns: Vec<ModelTurn>,
        cursor: Mutex<usize>,
    }

    impl ScriptedProvider {
        fn new(turns: Vec<ModelTurn>) -> Self {
            Self {
                turns,
                cursor: Mutex::new(0),
            }
        }
    }

    impl Provider for ScriptedProvider {
        fn name(&self) -> &str {
            "scripted"
        }

        fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
            let mut cursor = self
                .cursor
                .lock()
                .map_err(|_| CoreError::Provider("scripted provider lock poisoned".to_string()))?;
            let idx = *cursor;
            let Some(turn) = self.turns.get(idx) else {
                return Err(CoreError::Provider("no scripted turn left".to_string()));
            };
            *cursor += 1;
            Ok(turn.clone())
        }
    }

    /// Tool that echoes `input.value` and records it in state as `last_echo`.
    struct EchoTool;

    impl Tool for EchoTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echoes the provided value".to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": { "value": { "type": "string" } },
                    "required": ["value"]
                }),
                title: None,
                output_schema: None,
                annotations: None,
                category: None,
                tags: Vec::new(),
                timeout_secs: None,
            }
        }

        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, CoreError> {
            let value = call
                .input
                .get("value")
                .cloned()
                .unwrap_or_else(|| json!(null));
            Ok(ToolResult {
                call_id: call.call_id.clone(),
                tool_name: call.tool_name.clone(),
                output: json!({ "echo": value.clone() }),
                content: None,
                is_error: false,
                state_patch: Some(StatePatch {
                    format: StatePatchFormat::MergePatch,
                    patch: json!({ "last_echo": value }),
                    source: StatePatchSource::Tool,
                }),
            })
        }
    }

    /// A turn requesting a single call to `tool_name` with `input`, stopping with `ToolUse`.
    fn tool_turn(call_id: &str, tool_name: &str, input: serde_json::Value) -> ModelTurn {
        ModelTurn {
            directives: vec![ModelDirective::ToolCall {
                call: ToolCall {
                    call_id: call_id.to_string(),
                    tool_name: tool_name.to_string(),
                    input,
                },
            }],
            stop_reason: ModelStopReason::ToolUse,
            usage: None,
        }
    }

    /// A turn requesting a single `echo` call with the given `value`.
    fn echo_turn(call_id: &str, value: &str) -> ModelTurn {
        tool_turn(call_id, "echo", json!({ "value": value }))
    }

    /// A turn emitting plain assistant text and ending with `EndTurn`.
    fn text_turn(delta: &str) -> ModelTurn {
        ModelTurn {
            directives: vec![ModelDirective::Text {
                delta: delta.to_string(),
            }],
            stop_reason: ModelStopReason::EndTurn,
            usage: None,
        }
    }

    /// A turn emitting a final answer and ending with `EndTurn`.
    fn final_turn(text: &str) -> ModelTurn {
        ModelTurn {
            directives: vec![ModelDirective::FinalAnswer {
                text: text.to_string(),
            }],
            stop_reason: ModelStopReason::EndTurn,
            usage: None,
        }
    }

    /// Run input with run id `run-1`, the given session and a single user prompt.
    fn input(session_id: &str, prompt: &str) -> RunInput {
        RunInput {
            run_id: "run-1".to_string(),
            session_id: session_id.to_string(),
            messages: vec![ChatMessage::user(prompt)],
            state: AppState::default(),
        }
    }

    /// Registry containing only [`EchoTool`].
    fn echo_registry() -> ToolRegistry {
        let mut tools = ToolRegistry::default();
        tools.register(EchoTool);
        tools
    }

    /// Orchestrator without middlewares and with the given iteration budget.
    fn build_orchestrator(
        provider: impl Provider + 'static,
        tools: ToolRegistry,
        max_iterations: u32,
    ) -> Orchestrator {
        Orchestrator::new(
            Arc::new(provider),
            tools,
            Vec::new(),
            OrchestratorConfig { max_iterations },
        )
    }

    #[test]
    fn orchestrator_runs_tool_then_finishes() {
        let provider =
            ScriptedProvider::new(vec![echo_turn("call-1", "hello"), final_turn("done")]);
        let orchestrator = build_orchestrator(provider, echo_registry(), 4);

        let output = orchestrator.run(input("session-1", "test"), |_| {});

        assert_eq!(output.reason, RunStopReason::Completed);
        assert_eq!(output.final_answer.as_deref(), Some("done"));
        // The echo tool's merge patch is the only state change of the run.
        assert_eq!(output.state.revision, 1);
        assert_eq!(output.state.data["last_echo"], "hello");

        assert!(output
            .events
            .iter()
            .any(|event| matches!(event, AgentEvent::ToolCallCompleted { .. })));
        assert!(output.events.iter().any(|event| matches!(
            event,
            AgentEvent::RunFinished {
                reason: RunStopReason::Completed,
                ..
            }
        )));
    }

    #[test]
    fn provider_error_stops_run() {
        struct FailProvider;
        impl Provider for FailProvider {
            fn name(&self) -> &str {
                "fail"
            }
            fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
                Err(CoreError::Provider("connection refused".to_string()))
            }
        }

        let orchestrator = build_orchestrator(FailProvider, ToolRegistry::default(), 4);

        let output = orchestrator.run(input("s1", "test"), |_| {});

        assert_eq!(output.reason, RunStopReason::Error);
        assert!(output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::RunErrored { .. })));
    }

    #[test]
    fn tool_not_found_stops_run() {
        let provider = ScriptedProvider::new(vec![tool_turn("c1", "nonexistent", json!({}))]);
        let orchestrator = build_orchestrator(provider, ToolRegistry::default(), 4);

        let output = orchestrator.run(input("s1", "test"), |_| {});

        assert_eq!(output.reason, RunStopReason::Error);
        assert!(output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::ToolCallFailed { .. })));
    }

    #[test]
    fn middleware_blocks_model_call() {
        struct BlockMiddleware;
        impl Middleware for BlockMiddleware {
            fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
                Err(CoreError::Middleware("blocked by policy".to_string()))
            }
        }

        let orchestrator = Orchestrator::new(
            Arc::new(ScriptedProvider::new(vec![text_turn("hi")])),
            ToolRegistry::default(),
            vec![Arc::new(BlockMiddleware)],
            OrchestratorConfig { max_iterations: 4 },
        );

        let output = orchestrator.run(input("s1", "test"), |_| {});

        assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
    }

    #[test]
    fn budget_exceeded_when_iterations_exhausted() {
        // Every turn asks for more tool work, so the run only stops when the
        // two-iteration budget runs out.
        let provider = ScriptedProvider::new(vec![echo_turn("c1", "1"), echo_turn("c2", "2")]);
        let orchestrator = build_orchestrator(provider, echo_registry(), 2);

        let output = orchestrator.run(input("s1", "test"), |_| {});

        assert_eq!(output.reason, RunStopReason::BudgetExceeded);
    }

    #[test]
    fn text_only_response_completes() {
        let provider = ScriptedProvider::new(vec![text_turn("Hello, world!")]);
        let orchestrator = build_orchestrator(provider, ToolRegistry::default(), 4);

        let output = orchestrator.run(input("s1", "hi"), |_| {});

        assert_eq!(output.reason, RunStopReason::Completed);
        assert!(output.messages.iter().any(|m| m.content == "Hello, world!"));
    }

    #[test]
    fn event_handler_receives_all_events() {
        let provider = ScriptedProvider::new(vec![final_turn("done")]);
        let orchestrator = build_orchestrator(provider, ToolRegistry::default(), 4);

        let received = Arc::new(Mutex::new(Vec::new()));
        let received_clone = received.clone();

        let output = orchestrator.run(input("s1", "test"), move |event| {
            received_clone.lock().unwrap().push(event);
        });

        let events = received.lock().unwrap();
        assert!(events.len() >= 4);
        assert!(matches!(events[0], AgentEvent::RunStarted { .. }));
        assert!(matches!(
            events.last().unwrap(),
            AgentEvent::RunFinished { .. }
        ));
        // The handler must see exactly the recorded events, in the same order.
        assert_eq!(events.len(), output.events.len());
        assert!(events
            .iter()
            .zip(&output.events)
            .all(|(streamed, recorded)| std::mem::discriminant(streamed)
                == std::mem::discriminant(recorded)));
    }

    #[test]
    fn tool_result_includes_call_id() {
        let provider =
            ScriptedProvider::new(vec![echo_turn("my-call-id", "test"), final_turn("ok")]);
        let orchestrator = build_orchestrator(provider, echo_registry(), 4);

        let output = orchestrator.run(input("s1", "test"), |_| {});

        let tool_msg = output
            .messages
            .iter()
            .find(|m| m.role == crate::protocol::Role::Tool)
            .expect("should have tool message");
        assert_eq!(tool_msg.tool_call_id.as_deref(), Some("my-call-id"));
    }

    #[test]
    fn cancellation_stops_run() {
        let provider = ScriptedProvider::new(vec![
            echo_turn("c1", "1"),
            final_turn("should not reach"),
        ]);
        let orchestrator = build_orchestrator(provider, echo_registry(), 10);

        let cancel = Arc::new(AtomicBool::new(false));
        let flag = cancel.clone();

        let output = orchestrator.run_cancellable(input("s1", "test"), Some(flag), move |event| {
            // Request cancellation as soon as the first tool call has completed.
            if matches!(event, AgentEvent::ToolCallCompleted { .. }) {
                cancel.store(true, Ordering::Relaxed);
            }
        });

        assert_eq!(output.reason, RunStopReason::Cancelled);
        assert!(output.final_answer.is_none());
        assert!(output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::ToolCallCompleted { .. })));
    }

    #[test]
    fn token_usage_accumulated() {
        let provider = ScriptedProvider::new(vec![
            ModelTurn {
                usage: Some(TokenUsage {
                    input_tokens: 100,
                    output_tokens: 50,
                    cache_read_tokens: 0,
                    cache_creation_tokens: 0,
                }),
                ..echo_turn("c1", "hi")
            },
            ModelTurn {
                usage: Some(TokenUsage {
                    input_tokens: 200,
                    output_tokens: 30,
                    cache_read_tokens: 0,
                    cache_creation_tokens: 0,
                }),
                ..final_turn("done")
            },
        ]);
        let orchestrator = build_orchestrator(provider, echo_registry(), 4);

        let output = orchestrator.run(input("s1", "test"), |_| {});

        assert_eq!(output.reason, RunStopReason::Completed);
        assert_eq!(output.total_usage.input_tokens, 300);
        assert_eq!(output.total_usage.output_tokens, 80);
        assert_eq!(output.total_usage.total(), 380);
    }
}