1use crate::error::CoreError;
2use crate::protocol::{
3 AgentEvent, ChatMessage, ModelDirective, ModelStopReason, ModelTurn, RunStopReason, ToolCall,
4 ToolDefinition, ToolResult, ToolResultSummary,
5};
6use crate::state::AppState;
7use std::collections::BTreeMap;
8use std::sync::Arc;
9
10#[derive(Debug, Clone)]
11pub struct ProviderRequest {
12 pub run_id: String,
13 pub session_id: String,
14 pub iteration: u32,
15 pub messages: Vec<ChatMessage>,
16 pub tools: Vec<ToolDefinition>,
17 pub state: AppState,
18}
19
20pub trait Provider: Send + Sync {
21 fn name(&self) -> &str;
22 fn complete(&self, request: &ProviderRequest) -> Result<ModelTurn, CoreError>;
23}
24
/// Per-call information handed to [`Tool::execute`] and to the tool middleware hooks.
#[derive(Clone, Debug)]
pub struct ToolContext {
    /// Identifier of the run that issued the call.
    pub run_id: String,
    /// Identifier of the session the run belongs to.
    pub session_id: String,
    /// 1-based iteration in which the model requested the call.
    pub iteration: u32,
}
31
32pub trait Tool: Send + Sync {
33 fn definition(&self) -> ToolDefinition;
34 fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, CoreError>;
35}
36
37pub trait Middleware: Send + Sync {
38 fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
39 Ok(())
40 }
41
42 fn after_model_call(
43 &self,
44 _request: &ProviderRequest,
45 _response: &ModelTurn,
46 ) -> Result<(), CoreError> {
47 Ok(())
48 }
49
50 fn pre_tool_call(&self, _context: &ToolContext, _call: &ToolCall) -> Result<(), CoreError> {
51 Ok(())
52 }
53
54 fn post_tool_call(
55 &self,
56 _context: &ToolContext,
57 _result: &ToolResult,
58 ) -> Result<(), CoreError> {
59 Ok(())
60 }
61
62 fn on_run_finished(&self, _output: &RunOutput) -> Result<(), CoreError> {
63 Ok(())
64 }
65}
66
67#[derive(Clone, Default)]
68pub struct ToolRegistry {
69 tools: BTreeMap<String, Arc<dyn Tool>>,
70}
71
72impl ToolRegistry {
73 pub fn register<T: Tool + 'static>(&mut self, tool: T) {
74 self.tools
75 .insert(tool.definition().name.clone(), Arc::new(tool));
76 }
77
78 pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
79 self.tools.get(tool_name).cloned()
80 }
81
82 pub fn definitions(&self) -> Vec<ToolDefinition> {
83 self.tools.values().map(|tool| tool.definition()).collect()
84 }
85}
86
/// Tunables for an [`Orchestrator`].
#[derive(Clone, Debug)]
pub struct OrchestratorConfig {
    /// Upper bound on model calls per run.
    ///
    /// When it is reached without a terminal stop reason, the run ends with
    /// `RunStopReason::BudgetExceeded`.
    pub max_iterations: u32,
}

impl Default for OrchestratorConfig {
    /// Allows up to 24 iterations per run.
    fn default() -> Self {
        const DEFAULT_MAX_ITERATIONS: u32 = 24;
        OrchestratorConfig {
            max_iterations: DEFAULT_MAX_ITERATIONS,
        }
    }
}
97
98#[derive(Debug, Clone)]
99pub struct RunInput {
100 pub run_id: String,
101 pub session_id: String,
102 pub messages: Vec<ChatMessage>,
103 pub state: AppState,
104}
105
106#[derive(Debug, Clone)]
107pub struct RunOutput {
108 pub run_id: String,
109 pub session_id: String,
110 pub events: Vec<AgentEvent>,
111 pub messages: Vec<ChatMessage>,
112 pub state: AppState,
113 pub reason: RunStopReason,
114 pub final_answer: Option<String>,
115}
116
117pub struct Orchestrator {
118 provider: Arc<dyn Provider>,
119 tools: ToolRegistry,
120 middlewares: Vec<Arc<dyn Middleware>>,
121 config: OrchestratorConfig,
122}
123
124impl Orchestrator {
125 pub fn new(
126 provider: Arc<dyn Provider>,
127 tools: ToolRegistry,
128 middlewares: Vec<Arc<dyn Middleware>>,
129 config: OrchestratorConfig,
130 ) -> Self {
131 Self {
132 provider,
133 tools,
134 middlewares,
135 config,
136 }
137 }
138
139 pub fn run(&self, input: RunInput, mut event_handler: impl FnMut(AgentEvent)) -> RunOutput {
140 let mut events = Vec::new();
141 let mut messages = input.messages;
142 let mut state = input.state;
143 let mut final_answer: Option<String> = None;
144 let mut stop_reason = RunStopReason::BudgetExceeded;
145 let mut total_iterations = 0;
146
147 let start_event = AgentEvent::RunStarted {
148 run_id: input.run_id.clone(),
149 session_id: input.session_id.clone(),
150 provider: self.provider.name().to_string(),
151 max_iterations: self.config.max_iterations,
152 };
153 event_handler(start_event.clone());
154 events.push(start_event);
155
156 for iteration in 1..=self.config.max_iterations {
157 total_iterations = iteration;
158 let iter_event = AgentEvent::IterationStarted {
159 run_id: input.run_id.clone(),
160 session_id: input.session_id.clone(),
161 iteration,
162 };
163 event_handler(iter_event.clone());
164 events.push(iter_event);
165
166 let provider_request = ProviderRequest {
167 run_id: input.run_id.clone(),
168 session_id: input.session_id.clone(),
169 iteration,
170 messages: messages.clone(),
171 tools: self.tools.definitions(),
172 state: state.clone(),
173 };
174
175 if let Err(err) = self.run_before_model(&provider_request) {
176 stop_reason = RunStopReason::BlockedByPolicy;
177 let err_event = AgentEvent::RunErrored {
178 run_id: input.run_id.clone(),
179 session_id: input.session_id.clone(),
180 error: err.to_string(),
181 };
182 event_handler(err_event.clone());
183 events.push(err_event);
184 break;
185 }
186
187 let model_turn = match self.provider.complete(&provider_request) {
188 Ok(turn) => turn,
189 Err(err) => {
190 stop_reason = RunStopReason::Error;
191 let err_event = AgentEvent::RunErrored {
192 run_id: input.run_id.clone(),
193 session_id: input.session_id.clone(),
194 error: err.to_string(),
195 };
196 event_handler(err_event.clone());
197 events.push(err_event);
198 break;
199 }
200 };
201
202 if let Err(err) = self.run_after_model(&provider_request, &model_turn) {
203 stop_reason = RunStopReason::BlockedByPolicy;
204 let err_event = AgentEvent::RunErrored {
205 run_id: input.run_id.clone(),
206 session_id: input.session_id.clone(),
207 error: err.to_string(),
208 };
209 event_handler(err_event.clone());
210 events.push(err_event);
211 break;
212 }
213
214 let output_event = AgentEvent::ModelOutput {
215 run_id: input.run_id.clone(),
216 session_id: input.session_id.clone(),
217 iteration,
218 stop_reason: model_turn.stop_reason,
219 directive_count: model_turn.directives.len(),
220 };
221 event_handler(output_event.clone());
222 events.push(output_event);
223
224 let mut requested_tool = false;
225
226 for directive in model_turn.directives {
227 match directive {
228 ModelDirective::Text { delta } => {
229 let delta_event = AgentEvent::TextDelta {
230 run_id: input.run_id.clone(),
231 session_id: input.session_id.clone(),
232 iteration,
233 delta: delta.clone(),
234 };
235 event_handler(delta_event.clone());
236 events.push(delta_event);
237 messages.push(ChatMessage::assistant(delta));
238 }
239 ModelDirective::ToolCall { call } => {
240 requested_tool = true;
241 let tc_event = AgentEvent::ToolCallRequested {
242 run_id: input.run_id.clone(),
243 session_id: input.session_id.clone(),
244 iteration,
245 call: call.clone(),
246 };
247 event_handler(tc_event.clone());
248 events.push(tc_event);
249
250 let context = ToolContext {
251 run_id: input.run_id.clone(),
252 session_id: input.session_id.clone(),
253 iteration,
254 };
255
256 if let Err(err) = self.run_pre_tool(&context, &call) {
257 stop_reason = RunStopReason::BlockedByPolicy;
258 let err_event = AgentEvent::ToolCallFailed {
259 run_id: input.run_id.clone(),
260 session_id: input.session_id.clone(),
261 iteration,
262 call_id: call.call_id.clone(),
263 tool_name: call.tool_name.clone(),
264 error: err.to_string(),
265 };
266 event_handler(err_event.clone());
267 events.push(err_event);
268 break;
269 }
270
271 let Some(tool) = self.tools.get(&call.tool_name) else {
272 stop_reason = RunStopReason::Error;
273 let err_event = AgentEvent::ToolCallFailed {
274 run_id: input.run_id.clone(),
275 session_id: input.session_id.clone(),
276 iteration,
277 call_id: call.call_id.clone(),
278 tool_name: call.tool_name.clone(),
279 error: format!(
280 "{}",
281 CoreError::ToolNotFound {
282 tool_name: call.tool_name.clone(),
283 }
284 ),
285 };
286 event_handler(err_event.clone());
287 events.push(err_event);
288 break;
289 };
290
291 match tool.execute(&call, &context) {
292 Ok(result) => {
293 if let Some(patch) = &result.state_patch {
294 match state.apply_patch(patch) {
295 Ok(()) => {
296 let patch_event = AgentEvent::StatePatched {
297 run_id: input.run_id.clone(),
298 session_id: input.session_id.clone(),
299 iteration,
300 patch: patch.clone(),
301 revision: state.revision,
302 };
303 event_handler(patch_event.clone());
304 events.push(patch_event);
305 }
306 Err(err) => {
307 stop_reason = RunStopReason::Error;
308 let err_event = AgentEvent::ToolCallFailed {
309 run_id: input.run_id.clone(),
310 session_id: input.session_id.clone(),
311 iteration,
312 call_id: call.call_id.clone(),
313 tool_name: call.tool_name.clone(),
314 error: err.to_string(),
315 };
316 event_handler(err_event.clone());
317 events.push(err_event);
318 break;
319 }
320 }
321 }
322
323 if let Err(err) = self.run_post_tool(&context, &result) {
324 stop_reason = RunStopReason::BlockedByPolicy;
325 let err_event = AgentEvent::ToolCallFailed {
326 run_id: input.run_id.clone(),
327 session_id: input.session_id.clone(),
328 iteration,
329 call_id: call.call_id.clone(),
330 tool_name: call.tool_name.clone(),
331 error: err.to_string(),
332 };
333 event_handler(err_event.clone());
334 events.push(err_event);
335 break;
336 }
337
338 let completed_event = AgentEvent::ToolCallCompleted {
339 run_id: input.run_id.clone(),
340 session_id: input.session_id.clone(),
341 iteration,
342 result: ToolResultSummary::from(&result),
343 };
344 event_handler(completed_event.clone());
345 events.push(completed_event);
346
347 messages.push(ChatMessage::tool_result(
348 &result.call_id,
349 serde_json::to_string(&result.output)
350 .unwrap_or_else(|_| "{}".to_string()),
351 ));
352 }
353 Err(err) => {
354 stop_reason = RunStopReason::Error;
355 let err_event = AgentEvent::ToolCallFailed {
356 run_id: input.run_id.clone(),
357 session_id: input.session_id.clone(),
358 iteration,
359 call_id: call.call_id.clone(),
360 tool_name: call.tool_name.clone(),
361 error: err.to_string(),
362 };
363 event_handler(err_event.clone());
364 events.push(err_event);
365 break;
366 }
367 }
368 }
369 ModelDirective::StatePatch { patch } => match state.apply_patch(&patch) {
370 Ok(()) => {
371 let patch_event = AgentEvent::StatePatched {
372 run_id: input.run_id.clone(),
373 session_id: input.session_id.clone(),
374 iteration,
375 patch: patch.clone(),
376 revision: state.revision,
377 };
378 event_handler(patch_event.clone());
379 events.push(patch_event);
380 }
381 Err(err) => {
382 stop_reason = RunStopReason::Error;
383 let err_event = AgentEvent::RunErrored {
384 run_id: input.run_id.clone(),
385 session_id: input.session_id.clone(),
386 error: err.to_string(),
387 };
388 event_handler(err_event.clone());
389 events.push(err_event);
390 break;
391 }
392 },
393 ModelDirective::FinalAnswer { text } => {
394 final_answer = Some(text.clone());
395 let delta_event = AgentEvent::TextDelta {
396 run_id: input.run_id.clone(),
397 session_id: input.session_id.clone(),
398 iteration,
399 delta: text.clone(),
400 };
401 event_handler(delta_event.clone());
402 events.push(delta_event);
403 messages.push(ChatMessage::assistant(text));
404 }
405 }
406 }
407
408 if matches!(
409 stop_reason,
410 RunStopReason::Error | RunStopReason::BlockedByPolicy
411 ) {
412 break;
413 }
414
415 match model_turn.stop_reason {
416 ModelStopReason::EndTurn => {
417 stop_reason = RunStopReason::Completed;
418 break;
419 }
420 ModelStopReason::NeedsUser => {
421 stop_reason = RunStopReason::NeedsUser;
422 break;
423 }
424 ModelStopReason::Safety => {
425 stop_reason = RunStopReason::BlockedByPolicy;
426 break;
427 }
428 ModelStopReason::ToolUse => {
429 if !requested_tool {
430 stop_reason = RunStopReason::Error;
431 let err_event = AgentEvent::RunErrored {
432 run_id: input.run_id.clone(),
433 session_id: input.session_id.clone(),
434 error: "model requested tool_use stop reason without tool call"
435 .to_string(),
436 };
437 event_handler(err_event.clone());
438 events.push(err_event);
439 break;
440 }
441 }
442 ModelStopReason::MaxTokens | ModelStopReason::Unknown => {
443 if !requested_tool {
444 stop_reason = RunStopReason::Error;
445 let err_event = AgentEvent::RunErrored {
446 run_id: input.run_id.clone(),
447 session_id: input.session_id.clone(),
448 error: "model returned non-terminal stop reason without tool call"
449 .to_string(),
450 };
451 event_handler(err_event.clone());
452 events.push(err_event);
453 break;
454 }
455 }
456 }
457 }
458
459 if total_iterations == self.config.max_iterations
460 && stop_reason == RunStopReason::BudgetExceeded
461 {
462 let err_event = AgentEvent::RunErrored {
463 run_id: input.run_id.clone(),
464 session_id: input.session_id.clone(),
465 error: "max iteration budget exceeded".to_string(),
466 };
467 event_handler(err_event.clone());
468 events.push(err_event);
469 }
470
471 let finished_event = AgentEvent::RunFinished {
472 run_id: input.run_id.clone(),
473 session_id: input.session_id.clone(),
474 reason: stop_reason,
475 total_iterations,
476 final_answer: final_answer.clone(),
477 };
478 event_handler(finished_event.clone());
479 events.push(finished_event);
480
481 let output = RunOutput {
482 run_id: input.run_id,
483 session_id: input.session_id,
484 events,
485 messages,
486 state,
487 reason: stop_reason,
488 final_answer,
489 };
490
491 let _ = self
492 .middlewares
493 .iter()
494 .try_for_each(|middleware| middleware.on_run_finished(&output));
495
496 output
497 }
498
499 fn run_before_model(&self, request: &ProviderRequest) -> Result<(), CoreError> {
500 self.middlewares
501 .iter()
502 .try_for_each(|middleware| middleware.before_model_call(request))
503 }
504
505 fn run_after_model(
506 &self,
507 request: &ProviderRequest,
508 response: &ModelTurn,
509 ) -> Result<(), CoreError> {
510 self.middlewares
511 .iter()
512 .try_for_each(|middleware| middleware.after_model_call(request, response))
513 }
514
515 fn run_pre_tool(&self, context: &ToolContext, call: &ToolCall) -> Result<(), CoreError> {
516 self.middlewares
517 .iter()
518 .try_for_each(|middleware| middleware.pre_tool_call(context, call))
519 }
520
521 fn run_post_tool(&self, context: &ToolContext, result: &ToolResult) -> Result<(), CoreError> {
522 self.middlewares
523 .iter()
524 .try_for_each(|middleware| middleware.post_tool_call(context, result))
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{
        ModelDirective, ModelStopReason, ModelTurn, StatePatch, StatePatchFormat, StatePatchSource,
    };
    use serde_json::json;
    use std::sync::Mutex;

    /// Provider that replays a fixed script of turns, one per `complete` call,
    /// and fails once the script is exhausted.
    struct ScriptedProvider {
        turns: Vec<ModelTurn>,
        cursor: Mutex<usize>,
    }

    impl ScriptedProvider {
        fn new(turns: Vec<ModelTurn>) -> Self {
            Self {
                turns,
                cursor: Mutex::new(0),
            }
        }
    }

    impl Provider for ScriptedProvider {
        fn name(&self) -> &str {
            "scripted"
        }

        fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
            let mut cursor = self
                .cursor
                .lock()
                .map_err(|_| CoreError::Provider("scripted provider lock poisoned".to_string()))?;
            let Some(turn) = self.turns.get(*cursor) else {
                return Err(CoreError::Provider("no scripted turn left".to_string()));
            };
            *cursor += 1;
            Ok(turn.clone())
        }
    }

    /// Tool that echoes `input.value` and records it in state as `last_echo`.
    struct EchoTool;

    impl Tool for EchoTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echoes the provided value".to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": { "value": { "type": "string" } },
                    "required": ["value"]
                }),
                title: None,
                output_schema: None,
                annotations: None,
                category: None,
                tags: Vec::new(),
                timeout_secs: None,
            }
        }

        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, CoreError> {
            let value = call
                .input
                .get("value")
                .cloned()
                .unwrap_or_else(|| json!(null));
            Ok(ToolResult {
                call_id: call.call_id.clone(),
                tool_name: call.tool_name.clone(),
                output: json!({ "echo": value.clone() }),
                content: None,
                is_error: false,
                state_patch: Some(StatePatch {
                    format: StatePatchFormat::MergePatch,
                    patch: json!({ "last_echo": value }),
                    source: StatePatchSource::Tool,
                }),
            })
        }
    }

    /// Middleware that rejects every model call.
    struct BlockModelMiddleware;

    impl Middleware for BlockModelMiddleware {
        fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
            Err(CoreError::Middleware("blocked by policy".to_string()))
        }
    }

    /// Middleware that rejects every tool call before it executes.
    struct BlockToolsMiddleware;

    impl Middleware for BlockToolsMiddleware {
        fn pre_tool_call(&self, _context: &ToolContext, _call: &ToolCall) -> Result<(), CoreError> {
            Err(CoreError::Middleware("tool blocked by policy".to_string()))
        }
    }

    /// A turn requesting a single tool call, ending with `ToolUse`.
    fn tool_call_turn(call_id: &str, tool_name: &str, input: serde_json::Value) -> ModelTurn {
        ModelTurn {
            directives: vec![ModelDirective::ToolCall {
                call: ToolCall {
                    call_id: call_id.to_string(),
                    tool_name: tool_name.to_string(),
                    input,
                },
            }],
            stop_reason: ModelStopReason::ToolUse,
        }
    }

    /// A turn carrying a final answer, ending with `EndTurn`.
    fn final_answer_turn(text: &str) -> ModelTurn {
        ModelTurn {
            directives: vec![ModelDirective::FinalAnswer {
                text: text.to_string(),
            }],
            stop_reason: ModelStopReason::EndTurn,
        }
    }

    /// A turn carrying a single text delta and the given stop reason.
    fn text_turn(delta: &str, stop_reason: ModelStopReason) -> ModelTurn {
        ModelTurn {
            directives: vec![ModelDirective::Text {
                delta: delta.to_string(),
            }],
            stop_reason,
        }
    }

    /// A registry containing only [`EchoTool`].
    fn echo_registry() -> ToolRegistry {
        let mut tools = ToolRegistry::default();
        tools.register(EchoTool);
        tools
    }

    /// A fresh run input with one user message and default state.
    fn run_input(user_message: &'static str) -> RunInput {
        RunInput {
            run_id: "run-1".to_string(),
            session_id: "session-1".to_string(),
            messages: vec![ChatMessage::user(user_message)],
            state: AppState::default(),
        }
    }

    /// Runs an orchestrator built from the given parts, discarding streamed events.
    fn run_orchestrator(
        provider: impl Provider + 'static,
        tools: ToolRegistry,
        middlewares: Vec<Arc<dyn Middleware>>,
        max_iterations: u32,
    ) -> RunOutput {
        Orchestrator::new(
            Arc::new(provider),
            tools,
            middlewares,
            OrchestratorConfig { max_iterations },
        )
        .run(run_input("test"), |_| {})
    }

    #[test]
    fn orchestrator_runs_tool_then_finishes() {
        let provider = ScriptedProvider::new(vec![
            tool_call_turn("call-1", "echo", json!({ "value": "hello" })),
            final_answer_turn("done"),
        ]);

        let output = run_orchestrator(provider, echo_registry(), Vec::new(), 4);

        assert_eq!(output.reason, RunStopReason::Completed);
        assert_eq!(output.final_answer.as_deref(), Some("done"));
        assert_eq!(output.state.revision, 1);
        assert_eq!(output.state.data["last_echo"], "hello");
        assert!(output
            .events
            .iter()
            .any(|event| matches!(event, AgentEvent::ToolCallCompleted { .. })));
        assert!(output.events.iter().any(|event| matches!(
            event,
            AgentEvent::RunFinished {
                reason: RunStopReason::Completed,
                ..
            }
        )));
    }

    #[test]
    fn provider_error_stops_run() {
        struct FailProvider;
        impl Provider for FailProvider {
            fn name(&self) -> &str {
                "fail"
            }
            fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
                Err(CoreError::Provider("connection refused".to_string()))
            }
        }

        let output = run_orchestrator(FailProvider, ToolRegistry::default(), Vec::new(), 4);

        assert_eq!(output.reason, RunStopReason::Error);
        assert!(output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::RunErrored { .. })));
    }

    #[test]
    fn tool_not_found_stops_run() {
        let provider =
            ScriptedProvider::new(vec![tool_call_turn("c1", "nonexistent", json!({}))]);

        let output = run_orchestrator(provider, ToolRegistry::default(), Vec::new(), 4);

        assert_eq!(output.reason, RunStopReason::Error);
        assert!(output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::ToolCallFailed { .. })));
    }

    #[test]
    fn middleware_blocks_model_call() {
        let provider = ScriptedProvider::new(vec![text_turn("hi", ModelStopReason::EndTurn)]);

        let output = run_orchestrator(
            provider,
            ToolRegistry::default(),
            vec![Arc::new(BlockModelMiddleware)],
            4,
        );

        assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
        assert!(output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::RunErrored { .. })));
    }

    #[test]
    fn pre_tool_middleware_blocks_tool_call() {
        let provider =
            ScriptedProvider::new(vec![tool_call_turn("c1", "echo", json!({ "value": "x" }))]);

        let output = run_orchestrator(
            provider,
            echo_registry(),
            vec![Arc::new(BlockToolsMiddleware)],
            4,
        );

        assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
        assert!(output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::ToolCallFailed { .. })));
        // The veto happens before execution, so neither the result nor its patch lands.
        assert!(!output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::ToolCallCompleted { .. })));
        assert!(!output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::StatePatched { .. })));
    }

    #[test]
    fn budget_exceeded_when_iterations_exhausted() {
        let provider = ScriptedProvider::new(vec![
            tool_call_turn("c1", "echo", json!({ "value": "1" })),
            tool_call_turn("c2", "echo", json!({ "value": "2" })),
        ]);

        let output = run_orchestrator(provider, echo_registry(), Vec::new(), 2);

        assert_eq!(output.reason, RunStopReason::BudgetExceeded);
        assert!(output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::RunErrored { .. })));
        assert!(output.events.iter().any(|e| matches!(
            e,
            AgentEvent::RunFinished {
                total_iterations: 2,
                ..
            }
        )));
    }

    #[test]
    fn tool_use_without_tool_call_is_error() {
        let provider =
            ScriptedProvider::new(vec![text_turn("thinking", ModelStopReason::ToolUse)]);

        let output = run_orchestrator(provider, echo_registry(), Vec::new(), 4);

        assert_eq!(output.reason, RunStopReason::Error);
        assert!(output
            .events
            .iter()
            .any(|e| matches!(e, AgentEvent::RunErrored { .. })));
    }

    #[test]
    fn needs_user_stop_reason_ends_run() {
        let provider =
            ScriptedProvider::new(vec![text_turn("which file?", ModelStopReason::NeedsUser)]);

        let output = run_orchestrator(provider, ToolRegistry::default(), Vec::new(), 4);

        assert_eq!(output.reason, RunStopReason::NeedsUser);
    }

    #[test]
    fn safety_stop_reason_blocks_run() {
        let provider = ScriptedProvider::new(vec![text_turn("refused", ModelStopReason::Safety)]);

        let output = run_orchestrator(provider, ToolRegistry::default(), Vec::new(), 4);

        assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
    }

    #[test]
    fn text_only_response_completes() {
        let provider =
            ScriptedProvider::new(vec![text_turn("Hello, world!", ModelStopReason::EndTurn)]);

        let output = run_orchestrator(provider, ToolRegistry::default(), Vec::new(), 4);

        assert_eq!(output.reason, RunStopReason::Completed);
        assert!(output.messages.iter().any(|m| m.content == "Hello, world!"));
    }

    #[test]
    fn event_handler_receives_all_events() {
        let provider = ScriptedProvider::new(vec![final_answer_turn("done")]);
        let orchestrator = Orchestrator::new(
            Arc::new(provider),
            ToolRegistry::default(),
            Vec::new(),
            OrchestratorConfig { max_iterations: 4 },
        );

        let received = Arc::new(Mutex::new(Vec::new()));
        let received_clone = received.clone();

        orchestrator.run(run_input("test"), move |event| {
            received_clone.lock().unwrap().push(event);
        });

        let events = received.lock().unwrap();
        assert!(events.len() >= 4);
        assert!(matches!(events[0], AgentEvent::RunStarted { .. }));
        assert!(matches!(
            events.last().unwrap(),
            AgentEvent::RunFinished { .. }
        ));
    }

    #[test]
    fn tool_result_includes_call_id() {
        let provider = ScriptedProvider::new(vec![
            tool_call_turn("my-call-id", "echo", json!({ "value": "test" })),
            final_answer_turn("ok"),
        ]);

        let output = run_orchestrator(provider, echo_registry(), Vec::new(), 4);

        let tool_msg = output
            .messages
            .iter()
            .find(|m| m.role == crate::protocol::Role::Tool)
            .expect("should have tool message");
        assert_eq!(tool_msg.tool_call_id.as_deref(), Some("my-call-id"));
    }
}