1use crate::context::{ContextConfig, compact_messages};
2use crate::error::CoreError;
3use crate::protocol::{
4 AgentEvent, ChatMessage, ModelDirective, ModelStopReason, ModelTurn, RunStopReason, TokenUsage,
5 ToolCall, ToolDefinition, ToolResult, ToolResultSummary,
6};
7use crate::state::AppState;
8use std::collections::BTreeMap;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11
/// Object-safe hook for installing and removing an [`AgentEvent`] callback.
///
/// NOTE(review): the approval-gate semantics behind this hook are not visible in
/// this module; the docs below only describe what the signatures establish.
pub trait ApprovalGateHook: Send + Sync {
    /// Installs `handler` as the callback that receives [`AgentEvent`]s.
    fn set_event_handler(&self, handler: Arc<dyn Fn(AgentEvent) + Send + Sync>);
    /// Removes the previously installed event handler, if any.
    fn clear_event_handler(&self);
}
18
/// Resolves approval requests identified by string ids.
pub trait ApprovalResolver: Send + Sync {
    /// Applies `decision`, with an optional human-readable `reason`, to the approval
    /// identified by `approval_id`.
    ///
    /// NOTE(review): presumably returns `true` when a matching pending approval was
    /// found and resolved — confirm against implementors.
    fn resolve_approval(&self, approval_id: &str, decision: &str, reason: Option<String>) -> bool;
    /// Returns the ids of approvals that are currently pending.
    fn pending_approval_ids(&self) -> Vec<String>;
}
24
/// Everything a [`Provider`] receives for a single model call.
///
/// The orchestrator builds one of these per iteration from the current
/// conversation, the registered tool definitions, and the application state.
#[derive(Debug, Clone)]
pub struct ProviderRequest {
    /// Identifier of the run this call belongs to.
    pub run_id: String,
    /// Identifier of the session the run belongs to.
    pub session_id: String,
    /// 1-based iteration number within the run.
    pub iteration: u32,
    /// Conversation history sent to the model (after any context compaction).
    pub messages: Vec<ChatMessage>,
    /// Definitions of the tools the model may call.
    pub tools: Vec<ToolDefinition>,
    /// Snapshot of the application state at the time of the call.
    pub state: AppState,
}
34
/// A model backend that turns a [`ProviderRequest`] into a [`ModelTurn`].
pub trait Provider: Send + Sync {
    /// Provider name; reported in the `RunStarted` event of each run.
    fn name(&self) -> &str;
    /// Performs one synchronous model call and returns the model's turn.
    fn complete(&self, request: &ProviderRequest) -> Result<ModelTurn, CoreError>;

    /// Whether [`Provider::complete_streaming`] reports incremental text.
    ///
    /// Defaults to `false`.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Streaming variant of [`Provider::complete`] that may report text through the
    /// callback as it arrives.
    ///
    /// The default implementation ignores the callback and delegates to
    /// [`Provider::complete`], so no incremental text is delivered.
    fn complete_streaming(
        &self,
        request: &ProviderRequest,
        _on_text: &dyn Fn(&str),
    ) -> Result<ModelTurn, CoreError> {
        self.complete(request)
    }
}
54
/// Per-call context passed to [`Tool::execute`] and to tool middleware hooks.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Identifier of the run issuing the tool call.
    pub run_id: String,
    /// Identifier of the session the run belongs to.
    pub session_id: String,
    /// Iteration in which the model requested the call.
    pub iteration: u32,
}
61
/// A capability the model can invoke through a tool-call directive.
pub trait Tool: Send + Sync {
    /// Describes the tool; its `name` is the key under which [`ToolRegistry`]
    /// stores it.
    fn definition(&self) -> ToolDefinition;
    /// Executes `call` and returns its result, which may carry a state patch to be
    /// applied to the run's application state.
    fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, CoreError>;
}
66
/// Hooks invoked around model calls, tool calls, and run completion.
///
/// Every hook defaults to a no-op returning `Ok(())`. Middlewares are invoked in
/// registration order; for the model and tool hooks, the first `Err` skips the
/// remaining middlewares and stops the run.
pub trait Middleware: Send + Sync {
    /// Called before each provider call. An `Err` stops the run with
    /// `RunStopReason::BlockedByPolicy`.
    fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called after each successful provider call with the model's turn. An `Err`
    /// stops the run with `RunStopReason::BlockedByPolicy`.
    fn after_model_call(
        &self,
        _request: &ProviderRequest,
        _response: &ModelTurn,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called before a requested tool is looked up and executed. An `Err` emits a
    /// `ToolCallFailed` event and stops the run with `BlockedByPolicy`.
    fn pre_tool_call(&self, _context: &ToolContext, _call: &ToolCall) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called after a tool succeeds and any state patch it returned has been
    /// applied. An `Err` emits a `ToolCallFailed` event and stops the run with
    /// `BlockedByPolicy`.
    fn post_tool_call(
        &self,
        _context: &ToolContext,
        _result: &ToolResult,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called with the final output once the run has ended. Its error is not
    /// propagated to the caller of the run.
    fn on_run_finished(&self, _output: &RunOutput) -> Result<(), CoreError> {
        Ok(())
    }
}
96
/// Name-indexed collection of the tools available to the model.
///
/// Backed by a `BTreeMap`, so iteration order is ascending by tool name.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    // Shared handles so lookups can hand out clones without copying the tool.
    tools: BTreeMap<String, Arc<dyn Tool>>,
}
101
102impl ToolRegistry {
103 pub fn register<T: Tool + 'static>(&mut self, tool: T) {
104 self.tools
105 .insert(tool.definition().name.clone(), Arc::new(tool));
106 }
107
108 pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
109 self.tools.get(tool_name).cloned()
110 }
111
112 pub fn definitions(&self) -> Vec<ToolDefinition> {
113 self.tools.values().map(|tool| tool.definition()).collect()
114 }
115}
116
/// Tuning knobs for an [`Orchestrator`].
#[derive(Debug, Clone)]
pub struct OrchestratorConfig {
    /// Maximum number of loop iterations (one model call each) per run; a run that
    /// exhausts it stops with `RunStopReason::BudgetExceeded`.
    pub max_iterations: u32,
    /// Context-compaction settings applied at the start of every iteration;
    /// `None` disables compaction.
    pub context: Option<ContextConfig>,
    /// Optional context-compiler settings.
    ///
    /// NOTE(review): not read by the orchestrator loop in this module — confirm
    /// where it is consumed.
    pub context_compiler: Option<crate::context_compiler::ContextCompilerConfig>,
}
127
128impl Default for OrchestratorConfig {
129 fn default() -> Self {
130 Self {
131 max_iterations: 24,
132 context: Some(ContextConfig::default()),
133 context_compiler: None,
134 }
135 }
136}
137
/// Inputs for a single orchestrator run.
#[derive(Debug, Clone)]
pub struct RunInput {
    /// Identifier of the run; copied into every emitted event.
    pub run_id: String,
    /// Identifier of the session; copied into every emitted event.
    pub session_id: String,
    /// Branch identifier, passed through unchanged to the run's output.
    pub branch_id: String,
    /// Conversation so far; the run appends assistant and tool-result messages.
    pub messages: Vec<ChatMessage>,
    /// Application state the run starts from; model and tool patches apply to it.
    pub state: AppState,
}
146
/// Result of an orchestrator run.
#[derive(Debug, Clone)]
pub struct RunOutput {
    /// Run identifier, echoed from the input.
    pub run_id: String,
    /// Session identifier, echoed from the input.
    pub session_id: String,
    /// Branch identifier, echoed from the input.
    pub branch_id: String,
    /// Every event emitted during the run, in emission order.
    pub events: Vec<AgentEvent>,
    /// Conversation after the run (possibly compacted), including appended
    /// assistant and tool-result messages.
    pub messages: Vec<ChatMessage>,
    /// Application state after all successfully applied patches.
    pub state: AppState,
    /// Why the run stopped.
    pub reason: RunStopReason,
    /// Text of the last final-answer directive, if the model produced one.
    pub final_answer: Option<String>,
    /// Token usage summed over all model turns that reported usage.
    pub total_usage: TokenUsage,
}
160
/// Drives the model/tool loop: calls the provider, executes requested tools,
/// applies state patches, and emits [`AgentEvent`]s along the way.
pub struct Orchestrator {
    // Active provider behind a lock so `swap_provider` can replace it through `&self`.
    provider: Arc<std::sync::RwLock<Arc<dyn Provider>>>,
    // Tools the model may call, looked up by name.
    tools: ToolRegistry,
    // Hooks invoked in order around model and tool calls.
    middlewares: Vec<Arc<dyn Middleware>>,
    // Iteration budget and context settings.
    config: OrchestratorConfig,
}
167
168impl Orchestrator {
169 pub fn new(
170 provider: Arc<dyn Provider>,
171 tools: ToolRegistry,
172 middlewares: Vec<Arc<dyn Middleware>>,
173 config: OrchestratorConfig,
174 ) -> Self {
175 Self {
176 provider: Arc::new(std::sync::RwLock::new(provider)),
177 tools,
178 middlewares,
179 config,
180 }
181 }
182
183 pub fn swap_provider(&self, new_provider: Arc<dyn Provider>) -> String {
185 let name = new_provider.name().to_string();
186 let mut guard = self.provider.write().expect("provider lock poisoned");
187 *guard = new_provider;
188 name
189 }
190
191 pub fn provider_name(&self) -> String {
193 let guard = self.provider.read().expect("provider lock poisoned");
194 guard.name().to_string()
195 }
196
197 pub fn run(&self, input: RunInput, event_handler: impl FnMut(AgentEvent)) -> RunOutput {
198 self.run_cancellable(input, None, event_handler)
199 }
200
201 pub fn run_cancellable(
206 &self,
207 input: RunInput,
208 cancel: Option<&Arc<AtomicBool>>,
209 mut event_handler: impl FnMut(AgentEvent),
210 ) -> RunOutput {
211 let mut events = Vec::new();
212 let mut messages = input.messages;
213 let mut state = input.state;
214 let mut final_answer: Option<String> = None;
215 let mut stop_reason = RunStopReason::BudgetExceeded;
216 let mut total_iterations = 0;
217 let mut total_usage = TokenUsage::default();
218
219 let provider = self
221 .provider
222 .read()
223 .expect("provider lock poisoned")
224 .clone();
225
226 let start_event = AgentEvent::RunStarted {
227 run_id: input.run_id.clone(),
228 session_id: input.session_id.clone(),
229 provider: provider.name().to_string(),
230 max_iterations: self.config.max_iterations,
231 };
232 event_handler(start_event.clone());
233 events.push(start_event);
234
235 for iteration in 1..=self.config.max_iterations {
236 if let Some(flag) = cancel {
238 if flag.load(Ordering::Relaxed) {
239 stop_reason = RunStopReason::Cancelled;
240 let err_event = AgentEvent::RunErrored {
241 run_id: input.run_id.clone(),
242 session_id: input.session_id.clone(),
243 error: "run cancelled".to_string(),
244 };
245 event_handler(err_event.clone());
246 events.push(err_event);
247 break;
248 }
249 }
250
251 total_iterations = iteration;
252 let iter_event = AgentEvent::IterationStarted {
253 run_id: input.run_id.clone(),
254 session_id: input.session_id.clone(),
255 iteration,
256 };
257 event_handler(iter_event.clone());
258 events.push(iter_event);
259
260 if let Some(ref ctx_config) = self.config.context {
262 if let Some(result) = compact_messages(&messages, ctx_config) {
263 let compact_event = AgentEvent::ContextCompacted {
264 run_id: input.run_id.clone(),
265 session_id: input.session_id.clone(),
266 iteration,
267 dropped_count: result.dropped_count,
268 tokens_before: result.tokens_before,
269 tokens_after: result.tokens_after,
270 };
271 event_handler(compact_event.clone());
272 events.push(compact_event);
273 messages = result.messages;
274 }
275 }
276
277 let provider_request = ProviderRequest {
278 run_id: input.run_id.clone(),
279 session_id: input.session_id.clone(),
280 iteration,
281 messages: messages.clone(),
282 tools: self.tools.definitions(),
283 state: state.clone(),
284 };
285
286 if let Err(err) = self.run_before_model(&provider_request) {
287 stop_reason = RunStopReason::BlockedByPolicy;
288 let err_event = AgentEvent::RunErrored {
289 run_id: input.run_id.clone(),
290 session_id: input.session_id.clone(),
291 error: err.to_string(),
292 };
293 event_handler(err_event.clone());
294 events.push(err_event);
295 break;
296 }
297
298 let model_turn = match provider.complete(&provider_request) {
299 Ok(turn) => turn,
300 Err(err) => {
301 stop_reason = RunStopReason::Error;
302 let err_event = AgentEvent::RunErrored {
303 run_id: input.run_id.clone(),
304 session_id: input.session_id.clone(),
305 error: err.to_string(),
306 };
307 event_handler(err_event.clone());
308 events.push(err_event);
309 break;
310 }
311 };
312
313 if let Err(err) = self.run_after_model(&provider_request, &model_turn) {
314 stop_reason = RunStopReason::BlockedByPolicy;
315 let err_event = AgentEvent::RunErrored {
316 run_id: input.run_id.clone(),
317 session_id: input.session_id.clone(),
318 error: err.to_string(),
319 };
320 event_handler(err_event.clone());
321 events.push(err_event);
322 break;
323 }
324
325 if let Some(ref usage) = model_turn.usage {
327 total_usage.accumulate(usage);
328 }
329
330 let output_event = AgentEvent::ModelOutput {
331 run_id: input.run_id.clone(),
332 session_id: input.session_id.clone(),
333 iteration,
334 stop_reason: model_turn.stop_reason,
335 directive_count: model_turn.directives.len(),
336 usage: model_turn.usage,
337 };
338 event_handler(output_event.clone());
339 events.push(output_event);
340
341 let mut requested_tool = false;
342
343 for directive in model_turn.directives {
344 match directive {
345 ModelDirective::Text { delta } => {
346 let delta_event = AgentEvent::TextDelta {
347 run_id: input.run_id.clone(),
348 session_id: input.session_id.clone(),
349 iteration,
350 delta: delta.clone(),
351 };
352 event_handler(delta_event.clone());
353 events.push(delta_event);
354 messages.push(ChatMessage::assistant(delta));
355 }
356 ModelDirective::ToolCall { call } => {
357 requested_tool = true;
358 let tc_event = AgentEvent::ToolCallRequested {
359 run_id: input.run_id.clone(),
360 session_id: input.session_id.clone(),
361 iteration,
362 call: call.clone(),
363 };
364 event_handler(tc_event.clone());
365 events.push(tc_event);
366
367 let context = ToolContext {
368 run_id: input.run_id.clone(),
369 session_id: input.session_id.clone(),
370 iteration,
371 };
372
373 if let Err(err) = self.run_pre_tool(&context, &call) {
374 stop_reason = RunStopReason::BlockedByPolicy;
375 let err_event = AgentEvent::ToolCallFailed {
376 run_id: input.run_id.clone(),
377 session_id: input.session_id.clone(),
378 iteration,
379 call_id: call.call_id.clone(),
380 tool_name: call.tool_name.clone(),
381 error: err.to_string(),
382 };
383 event_handler(err_event.clone());
384 events.push(err_event);
385 break;
386 }
387
388 let Some(tool) = self.tools.get(&call.tool_name) else {
389 stop_reason = RunStopReason::Error;
390 let err_event = AgentEvent::ToolCallFailed {
391 run_id: input.run_id.clone(),
392 session_id: input.session_id.clone(),
393 iteration,
394 call_id: call.call_id.clone(),
395 tool_name: call.tool_name.clone(),
396 error: format!(
397 "{}",
398 CoreError::ToolNotFound {
399 tool_name: call.tool_name.clone(),
400 }
401 ),
402 };
403 event_handler(err_event.clone());
404 events.push(err_event);
405 break;
406 };
407
408 match tool.execute(&call, &context) {
409 Ok(result) => {
410 if let Some(patch) = &result.state_patch {
411 match state.apply_patch(patch) {
412 Ok(()) => {
413 let patch_event = AgentEvent::StatePatched {
414 run_id: input.run_id.clone(),
415 session_id: input.session_id.clone(),
416 iteration,
417 patch: patch.clone(),
418 revision: state.revision,
419 };
420 event_handler(patch_event.clone());
421 events.push(patch_event);
422 }
423 Err(err) => {
424 stop_reason = RunStopReason::Error;
425 let err_event = AgentEvent::ToolCallFailed {
426 run_id: input.run_id.clone(),
427 session_id: input.session_id.clone(),
428 iteration,
429 call_id: call.call_id.clone(),
430 tool_name: call.tool_name.clone(),
431 error: err.to_string(),
432 };
433 event_handler(err_event.clone());
434 events.push(err_event);
435 break;
436 }
437 }
438 }
439
440 if let Err(err) = self.run_post_tool(&context, &result) {
441 stop_reason = RunStopReason::BlockedByPolicy;
442 let err_event = AgentEvent::ToolCallFailed {
443 run_id: input.run_id.clone(),
444 session_id: input.session_id.clone(),
445 iteration,
446 call_id: call.call_id.clone(),
447 tool_name: call.tool_name.clone(),
448 error: err.to_string(),
449 };
450 event_handler(err_event.clone());
451 events.push(err_event);
452 break;
453 }
454
455 let completed_event = AgentEvent::ToolCallCompleted {
456 run_id: input.run_id.clone(),
457 session_id: input.session_id.clone(),
458 iteration,
459 result: ToolResultSummary::from(&result),
460 };
461 event_handler(completed_event.clone());
462 events.push(completed_event);
463
464 messages.push(ChatMessage::tool_result(
465 &result.call_id,
466 serde_json::to_string(&result.output)
467 .unwrap_or_else(|_| "{}".to_string()),
468 ));
469 }
470 Err(err) => {
471 stop_reason = RunStopReason::Error;
472 let err_event = AgentEvent::ToolCallFailed {
473 run_id: input.run_id.clone(),
474 session_id: input.session_id.clone(),
475 iteration,
476 call_id: call.call_id.clone(),
477 tool_name: call.tool_name.clone(),
478 error: err.to_string(),
479 };
480 event_handler(err_event.clone());
481 events.push(err_event);
482 break;
483 }
484 }
485 }
486 ModelDirective::StatePatch { patch } => match state.apply_patch(&patch) {
487 Ok(()) => {
488 let patch_event = AgentEvent::StatePatched {
489 run_id: input.run_id.clone(),
490 session_id: input.session_id.clone(),
491 iteration,
492 patch: patch.clone(),
493 revision: state.revision,
494 };
495 event_handler(patch_event.clone());
496 events.push(patch_event);
497 }
498 Err(err) => {
499 stop_reason = RunStopReason::Error;
500 let err_event = AgentEvent::RunErrored {
501 run_id: input.run_id.clone(),
502 session_id: input.session_id.clone(),
503 error: err.to_string(),
504 };
505 event_handler(err_event.clone());
506 events.push(err_event);
507 break;
508 }
509 },
510 ModelDirective::FinalAnswer { text } => {
511 final_answer = Some(text.clone());
512 let delta_event = AgentEvent::TextDelta {
513 run_id: input.run_id.clone(),
514 session_id: input.session_id.clone(),
515 iteration,
516 delta: text.clone(),
517 };
518 event_handler(delta_event.clone());
519 events.push(delta_event);
520 messages.push(ChatMessage::assistant(text));
521 }
522 }
523 }
524
525 if matches!(
526 stop_reason,
527 RunStopReason::Error | RunStopReason::BlockedByPolicy | RunStopReason::Cancelled
528 ) {
529 break;
530 }
531
532 match model_turn.stop_reason {
533 ModelStopReason::EndTurn => {
534 stop_reason = RunStopReason::Completed;
535 break;
536 }
537 ModelStopReason::NeedsUser => {
538 stop_reason = RunStopReason::NeedsUser;
539 break;
540 }
541 ModelStopReason::Safety => {
542 stop_reason = RunStopReason::BlockedByPolicy;
543 break;
544 }
545 ModelStopReason::ToolUse => {
546 if !requested_tool {
547 stop_reason = RunStopReason::Error;
548 let err_event = AgentEvent::RunErrored {
549 run_id: input.run_id.clone(),
550 session_id: input.session_id.clone(),
551 error: "model requested tool_use stop reason without tool call"
552 .to_string(),
553 };
554 event_handler(err_event.clone());
555 events.push(err_event);
556 break;
557 }
558 }
559 ModelStopReason::MaxTokens | ModelStopReason::Unknown => {
560 if !requested_tool {
561 stop_reason = RunStopReason::Error;
562 let err_event = AgentEvent::RunErrored {
563 run_id: input.run_id.clone(),
564 session_id: input.session_id.clone(),
565 error: "model returned non-terminal stop reason without tool call"
566 .to_string(),
567 };
568 event_handler(err_event.clone());
569 events.push(err_event);
570 break;
571 }
572 }
573 }
574 }
575
576 if total_iterations == self.config.max_iterations
577 && stop_reason == RunStopReason::BudgetExceeded
578 {
579 let err_event = AgentEvent::RunErrored {
580 run_id: input.run_id.clone(),
581 session_id: input.session_id.clone(),
582 error: "max iteration budget exceeded".to_string(),
583 };
584 event_handler(err_event.clone());
585 events.push(err_event);
586 }
587
588 let finished_event = AgentEvent::RunFinished {
589 run_id: input.run_id.clone(),
590 session_id: input.session_id.clone(),
591 reason: stop_reason,
592 total_iterations,
593 final_answer: final_answer.clone(),
594 };
595 event_handler(finished_event.clone());
596 events.push(finished_event);
597
598 let output = RunOutput {
599 run_id: input.run_id,
600 session_id: input.session_id,
601 branch_id: input.branch_id,
602 events,
603 messages,
604 state,
605 reason: stop_reason,
606 final_answer,
607 total_usage,
608 };
609
610 let _ = self
611 .middlewares
612 .iter()
613 .try_for_each(|middleware| middleware.on_run_finished(&output));
614
615 output
616 }
617
618 fn run_before_model(&self, request: &ProviderRequest) -> Result<(), CoreError> {
619 self.middlewares
620 .iter()
621 .try_for_each(|middleware| middleware.before_model_call(request))
622 }
623
624 fn run_after_model(
625 &self,
626 request: &ProviderRequest,
627 response: &ModelTurn,
628 ) -> Result<(), CoreError> {
629 self.middlewares
630 .iter()
631 .try_for_each(|middleware| middleware.after_model_call(request, response))
632 }
633
634 fn run_pre_tool(&self, context: &ToolContext, call: &ToolCall) -> Result<(), CoreError> {
635 self.middlewares
636 .iter()
637 .try_for_each(|middleware| middleware.pre_tool_call(context, call))
638 }
639
640 fn run_post_tool(&self, context: &ToolContext, result: &ToolResult) -> Result<(), CoreError> {
641 self.middlewares
642 .iter()
643 .try_for_each(|middleware| middleware.post_tool_call(context, result))
644 }
645}
646
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{
        ModelDirective, ModelStopReason, ModelTurn, StatePatch, StatePatchFormat, StatePatchSource,
    };
    use serde_json::json;
    use std::sync::Mutex;

    /// Provider that replays a fixed script of turns and errors once it runs out.
    struct ScriptedProvider {
        turns: Vec<ModelTurn>,
        cursor: Mutex<usize>,
    }

    impl ScriptedProvider {
        fn new(turns: Vec<ModelTurn>) -> Self {
            Self {
                turns,
                cursor: Mutex::new(0),
            }
        }
    }

    impl Provider for ScriptedProvider {
        fn name(&self) -> &str {
            "scripted"
        }

        fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
            let mut cursor = self
                .cursor
                .lock()
                .map_err(|_| CoreError::Provider("scripted provider lock poisoned".to_string()))?;
            let idx = *cursor;
            let Some(turn) = self.turns.get(idx) else {
                return Err(CoreError::Provider("no scripted turn left".to_string()));
            };
            *cursor += 1;
            Ok(turn.clone())
        }
    }

    /// Tool that echoes its `value` input and records it in state as `last_echo`.
    struct EchoTool;

    impl Tool for EchoTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echoes the provided value".to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": { "value": { "type": "string" } },
                    "required": ["value"]
                }),
                title: None,
                output_schema: None,
                annotations: None,
                category: None,
                tags: Vec::new(),
                timeout_secs: None,
            }
        }

        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, CoreError> {
            let value = call.input.get("value").cloned().unwrap_or(json!(null));
            Ok(ToolResult {
                call_id: call.call_id.clone(),
                tool_name: call.tool_name.clone(),
                output: json!({ "echo": value.clone() }),
                content: None,
                is_error: false,
                state_patch: Some(StatePatch {
                    format: StatePatchFormat::MergePatch,
                    patch: json!({ "last_echo": value }),
                    source: StatePatchSource::Tool,
                }),
            })
        }
    }

    /// Builds a turn without usage information.
    fn turn(directives: Vec<ModelDirective>, stop_reason: ModelStopReason) -> ModelTurn {
        ModelTurn {
            directives,
            stop_reason,
            usage: None,
        }
    }

    /// Builds a turn that requests a single tool call.
    fn tool_turn(call_id: &str, tool_name: &str, input: serde_json::Value) -> ModelTurn {
        turn(
            vec![ModelDirective::ToolCall {
                call: ToolCall {
                    call_id: call_id.to_string(),
                    tool_name: tool_name.to_string(),
                    input,
                },
            }],
            ModelStopReason::ToolUse,
        )
    }

    /// Builds a terminal turn carrying a final answer.
    fn final_turn(text: &str) -> ModelTurn {
        turn(
            vec![ModelDirective::FinalAnswer {
                text: text.to_string(),
            }],
            ModelStopReason::EndTurn,
        )
    }

    /// Builds a terminal turn carrying plain text.
    fn text_turn(text: &str) -> ModelTurn {
        turn(
            vec![ModelDirective::Text {
                delta: text.to_string(),
            }],
            ModelStopReason::EndTurn,
        )
    }

    /// Registry containing only [`EchoTool`].
    fn echo_tools() -> ToolRegistry {
        let mut tools = ToolRegistry::default();
        tools.register(EchoTool);
        tools
    }

    /// Orchestrator with context compaction disabled.
    fn orchestrator(
        provider: impl Provider + 'static,
        tools: ToolRegistry,
        middlewares: Vec<Arc<dyn Middleware>>,
        max_iterations: u32,
    ) -> Orchestrator {
        Orchestrator::new(
            Arc::new(provider),
            tools,
            middlewares,
            OrchestratorConfig {
                max_iterations,
                context: None,
                context_compiler: None,
            },
        )
    }

    /// Run input holding a single user message and default state.
    fn user_input(text: &str) -> RunInput {
        RunInput {
            run_id: "run-1".to_string(),
            session_id: "s1".to_string(),
            branch_id: "main".to_string(),
            messages: vec![ChatMessage::user(text)],
            state: AppState::default(),
        }
    }

    #[test]
    fn orchestrator_runs_tool_then_finishes() {
        let provider = ScriptedProvider::new(vec![
            tool_turn("call-1", "echo", json!({ "value": "hello" })),
            final_turn("done"),
        ]);

        let output =
            orchestrator(provider, echo_tools(), Vec::new(), 4).run(user_input("test"), |_| {});

        assert_eq!(output.reason, RunStopReason::Completed);
        assert_eq!(output.final_answer.as_deref(), Some("done"));
        assert_eq!(output.state.revision, 1);
        assert_eq!(output.state.data["last_echo"], "hello");
        assert!(
            output
                .events
                .iter()
                .any(|event| matches!(event, AgentEvent::ToolCallCompleted { .. }))
        );
        assert!(output.events.iter().any(|event| matches!(
            event,
            AgentEvent::RunFinished {
                reason: RunStopReason::Completed,
                ..
            }
        )));
    }

    #[test]
    fn provider_error_stops_run() {
        struct FailProvider;
        impl Provider for FailProvider {
            fn name(&self) -> &str {
                "fail"
            }
            fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
                Err(CoreError::Provider("connection refused".to_string()))
            }
        }

        let output = orchestrator(FailProvider, ToolRegistry::default(), Vec::new(), 4)
            .run(user_input("test"), |_| {});

        assert_eq!(output.reason, RunStopReason::Error);
        assert!(
            output
                .events
                .iter()
                .any(|e| matches!(e, AgentEvent::RunErrored { .. }))
        );
    }

    #[test]
    fn tool_not_found_stops_run() {
        let provider = ScriptedProvider::new(vec![tool_turn("c1", "nonexistent", json!({}))]);

        let output = orchestrator(provider, ToolRegistry::default(), Vec::new(), 4)
            .run(user_input("test"), |_| {});

        assert_eq!(output.reason, RunStopReason::Error);
        assert!(
            output
                .events
                .iter()
                .any(|e| matches!(e, AgentEvent::ToolCallFailed { .. }))
        );
    }

    #[test]
    fn middleware_blocks_model_call() {
        struct BlockMiddleware;
        impl Middleware for BlockMiddleware {
            fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
                Err(CoreError::Middleware("blocked by policy".to_string()))
            }
        }

        let provider = ScriptedProvider::new(vec![text_turn("hi")]);

        let output = orchestrator(
            provider,
            ToolRegistry::default(),
            vec![Arc::new(BlockMiddleware)],
            4,
        )
        .run(user_input("test"), |_| {});

        assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
    }

    #[test]
    fn pre_tool_middleware_blocks_tool_execution() {
        struct DenyTools;
        impl Middleware for DenyTools {
            fn pre_tool_call(
                &self,
                _context: &ToolContext,
                _call: &ToolCall,
            ) -> Result<(), CoreError> {
                Err(CoreError::Middleware("tool denied".to_string()))
            }
        }

        let provider = ScriptedProvider::new(vec![
            tool_turn("c1", "echo", json!({ "value": "x" })),
            final_turn("unreachable"),
        ]);

        let output = orchestrator(provider, echo_tools(), vec![Arc::new(DenyTools)], 4)
            .run(user_input("test"), |_| {});

        assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
        assert!(output.final_answer.is_none());
        // The tool never ran, so its state patch was never applied.
        assert_eq!(output.state.revision, AppState::default().revision);
        assert!(
            output
                .events
                .iter()
                .any(|e| matches!(e, AgentEvent::ToolCallFailed { .. }))
        );
    }

    #[test]
    fn tool_use_stop_without_tool_call_is_error() {
        let provider = ScriptedProvider::new(vec![ModelTurn {
            stop_reason: ModelStopReason::ToolUse,
            ..text_turn("thinking")
        }]);

        let output = orchestrator(provider, echo_tools(), Vec::new(), 4)
            .run(user_input("test"), |_| {});

        assert_eq!(output.reason, RunStopReason::Error);
        assert!(
            output
                .events
                .iter()
                .any(|e| matches!(e, AgentEvent::RunErrored { .. }))
        );
    }

    #[test]
    fn budget_exceeded_when_iterations_exhausted() {
        // Every turn asks for another tool call, so only the budget can stop the run.
        let provider = ScriptedProvider::new(vec![
            tool_turn("c1", "echo", json!({ "value": "1" })),
            tool_turn("c2", "echo", json!({ "value": "2" })),
        ]);

        let output =
            orchestrator(provider, echo_tools(), Vec::new(), 2).run(user_input("test"), |_| {});

        assert_eq!(output.reason, RunStopReason::BudgetExceeded);
        assert!(
            output
                .events
                .iter()
                .any(|e| matches!(e, AgentEvent::RunErrored { .. }))
        );
    }

    #[test]
    fn text_only_response_completes() {
        let provider = ScriptedProvider::new(vec![text_turn("Hello, world!")]);

        let output = orchestrator(provider, ToolRegistry::default(), Vec::new(), 4)
            .run(user_input("hi"), |_| {});

        assert_eq!(output.reason, RunStopReason::Completed);
        assert!(output.messages.iter().any(|m| m.content == "Hello, world!"));
    }

    #[test]
    fn event_handler_receives_all_events() {
        let provider = ScriptedProvider::new(vec![final_turn("done")]);

        let received = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&received);

        let output = orchestrator(provider, ToolRegistry::default(), Vec::new(), 4).run(
            user_input("test"),
            move |event| {
                sink.lock().unwrap().push(event);
            },
        );

        let events = received.lock().unwrap();
        assert!(events.len() >= 4);
        assert_eq!(events.len(), output.events.len());
        assert!(matches!(events[0], AgentEvent::RunStarted { .. }));
        assert!(matches!(events.last(), Some(AgentEvent::RunFinished { .. })));
    }

    #[test]
    fn tool_result_includes_call_id() {
        let provider = ScriptedProvider::new(vec![
            tool_turn("my-call-id", "echo", json!({ "value": "test" })),
            final_turn("ok"),
        ]);

        let output =
            orchestrator(provider, echo_tools(), Vec::new(), 4).run(user_input("test"), |_| {});

        let tool_msg = output
            .messages
            .iter()
            .find(|m| m.role == crate::protocol::Role::Tool)
            .expect("should have tool message");
        assert_eq!(tool_msg.tool_call_id.as_deref(), Some("my-call-id"));
    }

    #[test]
    fn cancellation_stops_run() {
        let provider = ScriptedProvider::new(vec![
            tool_turn("c1", "echo", json!({ "value": "1" })),
            final_turn("should not reach"),
        ]);

        let cancel = Arc::new(AtomicBool::new(false));
        let cancel_from_handler = Arc::clone(&cancel);

        // Request cancellation as soon as the first tool completes; the flag is
        // observed at the start of the next iteration.
        let output = orchestrator(provider, echo_tools(), Vec::new(), 10).run_cancellable(
            user_input("test"),
            Some(&cancel),
            move |event| {
                if matches!(event, AgentEvent::ToolCallCompleted { .. }) {
                    cancel_from_handler.store(true, Ordering::Relaxed);
                }
            },
        );

        assert_eq!(output.reason, RunStopReason::Cancelled);
        assert!(output.final_answer.is_none());
    }

    #[test]
    fn token_usage_accumulated() {
        let provider = ScriptedProvider::new(vec![
            ModelTurn {
                usage: Some(TokenUsage {
                    input_tokens: 100,
                    output_tokens: 50,
                    cache_read_tokens: 0,
                    cache_creation_tokens: 0,
                }),
                ..tool_turn("c1", "echo", json!({ "value": "hi" }))
            },
            ModelTurn {
                usage: Some(TokenUsage {
                    input_tokens: 200,
                    output_tokens: 30,
                    cache_read_tokens: 0,
                    cache_creation_tokens: 0,
                }),
                ..final_turn("done")
            },
        ]);

        let output =
            orchestrator(provider, echo_tools(), Vec::new(), 4).run(user_input("test"), |_| {});

        assert_eq!(output.reason, RunStopReason::Completed);
        assert_eq!(output.total_usage.input_tokens, 300);
        assert_eq!(output.total_usage.output_tokens, 80);
        assert_eq!(output.total_usage.total(), 380);
    }
}