1use crate::context::{ContextConfig, compact_messages};
2use crate::error::CoreError;
3use crate::lifecycle::LifecycleHook;
4use crate::protocol::{
5 AgentEvent, ChatMessage, ModelDirective, ModelStopReason, ModelTurn, RunStopReason, TokenUsage,
6 ToolCall, ToolDefinition, ToolResult, ToolResultSummary,
7};
8use crate::state::AppState;
9use std::collections::BTreeMap;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicBool, Ordering};
12
/// Incremental output passed to the `on_delta` callback of
/// [`Provider::complete_streaming`].
///
/// Both variants borrow their text from the provider for the duration of the
/// callback, so the enum is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamEvent<'a> {
    /// A fragment of response text.
    Text(&'a str),
    /// A fragment tagged as reasoning output, kept separate from
    /// [`StreamEvent::Text`].
    Reasoning(&'a str),
}
21
/// A provider slot that can be replaced at runtime.
///
/// The outer `Arc` shares the slot, the `RwLock` guards replacement, and the
/// inner `Arc` lets a reader clone the current provider out of the lock
/// instead of holding the guard while it works.
pub type SwappableProviderHandle = Arc<std::sync::RwLock<Arc<dyn Provider>>>;
28
/// Constructs [`Provider`]s from a textual spec.
///
/// The returned `Arc<dyn Provider>` has the type accepted by
/// [`Orchestrator::swap_provider`].
pub trait ProviderFactory: Send + Sync {
    /// Builds the provider described by `spec`.
    ///
    /// The spec syntax is implementation-defined.
    ///
    /// # Errors
    /// Returns a [`CoreError`] when `spec` cannot be turned into a provider.
    fn build(&self, spec: &str) -> Result<Arc<dyn Provider>, CoreError>;

    /// Lists the providers this factory advertises.
    ///
    /// NOTE(review): presumably each entry is a valid `spec` for `build` — confirm.
    fn available_providers(&self) -> Vec<String>;
}
39
/// Event-handler plumbing for an approval gate.
///
/// Only the handler registration is declared here; the gate's own behavior
/// lives in the implementing type.
pub trait ApprovalGateHook: Send + Sync {
    /// Installs `handler`.
    ///
    /// NOTE(review): presumably the gate invokes it for each [`AgentEvent`]
    /// it emits, and a later call replaces an earlier handler — confirm.
    fn set_event_handler(&self, handler: Arc<dyn Fn(AgentEvent) + Send + Sync>);
    /// Removes the handler installed by [`ApprovalGateHook::set_event_handler`].
    fn clear_event_handler(&self);
}
46
/// Resolves pending approval requests by id.
pub trait ApprovalResolver: Send + Sync {
    /// Records `decision`, with an optional free-form `reason`, for `approval_id`.
    ///
    /// The accepted `decision` strings are implementation-defined.
    /// NOTE(review): the returned `bool` presumably reports whether
    /// `approval_id` matched a pending approval — confirm against implementations.
    fn resolve_approval(&self, approval_id: &str, decision: &str, reason: Option<String>) -> bool;
    /// Ids of approvals that are still awaiting a decision.
    fn pending_approval_ids(&self) -> Vec<String>;
}
52
/// Everything a [`Provider`] receives for one model call.
///
/// The orchestrator builds a fresh request on every iteration; turn
/// middlewares may rewrite it in `before_model_call`.
#[derive(Debug, Clone)]
pub struct ProviderRequest {
    /// Id of the run this call belongs to.
    pub run_id: String,
    /// Id of the session the run belongs to.
    pub session_id: String,
    /// 1-based iteration number within the run.
    pub iteration: u32,
    /// Conversation transcript (possibly compacted) sent to the model.
    pub messages: Vec<ChatMessage>,
    /// Tool definitions offered to the model (the orchestrator fills in
    /// every registered tool).
    pub tools: Vec<ToolDefinition>,
    /// Snapshot of the application state at the time of the call.
    pub state: AppState,
}
62
/// A model backend that turns a [`ProviderRequest`] into a [`ModelTurn`].
pub trait Provider: Send + Sync {
    /// Provider name; the orchestrator reports it in [`AgentEvent::RunStarted`].
    fn name(&self) -> &str;
    /// Performs one synchronous model call.
    ///
    /// # Errors
    /// Any [`CoreError`]; the orchestrator ends the run with
    /// [`RunStopReason::Error`] when this fails.
    fn complete(&self, request: &ProviderRequest) -> Result<ModelTurn, CoreError>;

    /// Whether [`Provider::complete_streaming`] reports incremental output.
    /// Defaults to `false`.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Like [`Provider::complete`], but may report partial output through
    /// `_on_delta` while the call is in progress.
    ///
    /// The default implementation never calls `_on_delta` and simply
    /// delegates to [`Provider::complete`].
    fn complete_streaming(
        &self,
        request: &ProviderRequest,
        _on_delta: &dyn Fn(StreamEvent<'_>),
    ) -> Result<ModelTurn, CoreError> {
        self.complete(request)
    }

    /// Size of the model's context window, if known (presumably in tokens —
    /// confirm per provider). Defaults to `None`.
    fn context_window(&self) -> Option<u32> {
        None
    }
}
90
/// Run-level context handed to a [`Tool`] and to tool middlewares.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Id of the run that requested the tool call.
    pub run_id: String,
    /// Id of the session the run belongs to.
    pub session_id: String,
    /// 1-based iteration in which the call was requested.
    pub iteration: u32,
}
97
/// A callable tool exposed to the model.
pub trait Tool: Send + Sync {
    /// Schema and metadata advertised to the model. Its `name` is also the
    /// key under which [`ToolRegistry::register`] stores the tool.
    fn definition(&self) -> ToolDefinition;
    /// Executes `call`.
    ///
    /// # Errors
    /// Any [`CoreError`]; the orchestrator reports it as
    /// [`AgentEvent::ToolCallFailed`] and ends the run with an error.
    fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, CoreError>;
}
102
/// Read-only middleware around model and tool calls.
///
/// Every hook defaults to `Ok(())`. Returning `Err` from a model or tool hook
/// stops the run with [`RunStopReason::BlockedByPolicy`]; errors from
/// `on_run_finished` are only logged. [`Orchestrator::new`] adapts these into
/// [`TurnMiddleware`]s, which receive mutable access.
pub trait Middleware: Send + Sync {
    /// Called before each model call.
    fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called after each successful model call.
    fn after_model_call(
        &self,
        _request: &ProviderRequest,
        _response: &ModelTurn,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called before a requested tool is looked up and executed.
    fn pre_tool_call(&self, _context: &ToolContext, _call: &ToolCall) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called after a tool executes successfully.
    fn post_tool_call(
        &self,
        _context: &ToolContext,
        _result: &ToolResult,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called once with the final output of the run.
    fn on_run_finished(&self, _output: &RunOutput) -> Result<(), CoreError> {
        Ok(())
    }
}
132
/// Middleware that may rewrite requests, model turns, tool calls, tool
/// results and the final run output.
///
/// Every hook defaults to `Ok(())`. Returning `Err` from a model or tool hook
/// stops the run with [`RunStopReason::BlockedByPolicy`]; errors from
/// `on_run_finished` are only logged. Middlewares run in registration order.
pub trait TurnMiddleware: Send + Sync {
    /// Called before each model call; may edit the outgoing request.
    fn before_model_call(&self, _request: &mut ProviderRequest) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called after each successful model call; may edit the model's turn
    /// before its directives are applied.
    fn after_model_call(
        &self,
        _request: &ProviderRequest,
        _response: &mut ModelTurn,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called before a requested tool is looked up; may edit the call.
    fn pre_tool_call(&self, _context: &ToolContext, _call: &mut ToolCall) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called after a tool executes successfully; may edit its result.
    fn post_tool_call(
        &self,
        _context: &ToolContext,
        _result: &mut ToolResult,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called once with the final output of the run; may edit it.
    fn on_run_finished(&self, _output: &mut RunOutput) -> Result<(), CoreError> {
        Ok(())
    }
}
167
168struct LegacyMiddlewareAdapter {
169 inner: Arc<dyn Middleware>,
170}
171
172impl LegacyMiddlewareAdapter {
173 fn new(inner: Arc<dyn Middleware>) -> Self {
174 Self { inner }
175 }
176}
177
178impl TurnMiddleware for LegacyMiddlewareAdapter {
179 fn before_model_call(&self, request: &mut ProviderRequest) -> Result<(), CoreError> {
180 self.inner.before_model_call(request)
181 }
182
183 fn after_model_call(
184 &self,
185 request: &ProviderRequest,
186 response: &mut ModelTurn,
187 ) -> Result<(), CoreError> {
188 self.inner.after_model_call(request, response)
189 }
190
191 fn pre_tool_call(&self, context: &ToolContext, call: &mut ToolCall) -> Result<(), CoreError> {
192 self.inner.pre_tool_call(context, call)
193 }
194
195 fn post_tool_call(
196 &self,
197 context: &ToolContext,
198 result: &mut ToolResult,
199 ) -> Result<(), CoreError> {
200 self.inner.post_tool_call(context, result)
201 }
202
203 fn on_run_finished(&self, output: &mut RunOutput) -> Result<(), CoreError> {
204 self.inner.on_run_finished(output)
205 }
206}
207
208#[derive(Clone, Default)]
209pub struct ToolRegistry {
210 tools: BTreeMap<String, Arc<dyn Tool>>,
211}
212
213impl ToolRegistry {
214 pub fn register<T: Tool + 'static>(&mut self, tool: T) {
215 self.tools
216 .insert(tool.definition().name.clone(), Arc::new(tool));
217 }
218
219 pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
220 self.tools.get(tool_name).cloned()
221 }
222
223 pub fn definitions(&self) -> Vec<ToolDefinition> {
224 self.tools.values().map(|tool| tool.definition()).collect()
225 }
226}
227
228#[derive(Debug, Clone)]
229pub struct OrchestratorConfig {
230 pub max_iterations: u32,
231 pub context: Option<ContextConfig>,
234 pub context_compiler: Option<crate::context_compiler::ContextCompilerConfig>,
237}
238
239impl Default for OrchestratorConfig {
240 fn default() -> Self {
241 Self {
242 max_iterations: 24,
243 context: Some(ContextConfig::default()),
244 context_compiler: None,
245 }
246 }
247}
248
/// Everything needed to start one orchestrator run.
#[derive(Debug, Clone)]
pub struct RunInput {
    /// Caller-chosen id for this run; copied into every emitted event.
    pub run_id: String,
    /// Session the run belongs to; passed to lifecycle hooks.
    pub session_id: String,
    /// Conversation branch; carried through unchanged to [`RunOutput`].
    pub branch_id: String,
    /// Initial transcript.
    pub messages: Vec<ChatMessage>,
    /// Initial application state; tool and model patches are applied to it.
    pub state: AppState,
}
257
/// Result of one orchestrator run.
#[derive(Debug, Clone)]
pub struct RunOutput {
    /// Same id as [`RunInput::run_id`].
    pub run_id: String,
    /// Same id as [`RunInput::session_id`].
    pub session_id: String,
    /// Same value as [`RunInput::branch_id`].
    pub branch_id: String,
    /// Every event emitted during the run, in emission order.
    pub events: Vec<AgentEvent>,
    /// Final transcript, including assistant text and tool results added
    /// during the run (after any compaction).
    pub messages: Vec<ChatMessage>,
    /// Final application state after all applied patches.
    pub state: AppState,
    /// Why the run stopped.
    pub reason: RunStopReason,
    /// Text of the last `FinalAnswer` directive, if the model produced one.
    pub final_answer: Option<String>,
    /// Sum of the usage reported by every model turn of the run.
    pub total_usage: TokenUsage,
}
271
272pub struct Orchestrator {
273 provider: Arc<std::sync::RwLock<Arc<dyn Provider>>>,
274 tools: ToolRegistry,
275 turn_middlewares: Vec<Arc<dyn TurnMiddleware>>,
276 lifecycle_hooks: Vec<Arc<dyn LifecycleHook>>,
277 config: OrchestratorConfig,
278}
279
280impl Orchestrator {
281 pub fn new(
282 provider: Arc<dyn Provider>,
283 tools: ToolRegistry,
284 middlewares: Vec<Arc<dyn Middleware>>,
285 config: OrchestratorConfig,
286 ) -> Self {
287 Self {
288 provider: Arc::new(std::sync::RwLock::new(provider)),
289 tools,
290 turn_middlewares: middlewares
291 .into_iter()
292 .map(|middleware| {
293 Arc::new(LegacyMiddlewareAdapter::new(middleware)) as Arc<dyn TurnMiddleware>
294 })
295 .collect(),
296 lifecycle_hooks: Vec::new(),
297 config,
298 }
299 }
300
301 pub fn with_turn_middlewares(
302 provider: Arc<dyn Provider>,
303 tools: ToolRegistry,
304 turn_middlewares: Vec<Arc<dyn TurnMiddleware>>,
305 config: OrchestratorConfig,
306 ) -> Self {
307 Self {
308 provider: Arc::new(std::sync::RwLock::new(provider)),
309 tools,
310 turn_middlewares,
311 lifecycle_hooks: Vec::new(),
312 config,
313 }
314 }
315
316 pub fn with_lifecycle_hooks(mut self, hooks: Vec<Arc<dyn LifecycleHook>>) -> Self {
323 self.lifecycle_hooks = hooks;
324 self
325 }
326
327 pub fn add_lifecycle_hook(&mut self, hook: Arc<dyn LifecycleHook>) {
329 self.lifecycle_hooks.push(hook);
330 }
331
332 pub fn swap_provider(&self, new_provider: Arc<dyn Provider>) -> Result<String, CoreError> {
334 let name = new_provider.name().to_string();
335 let mut guard = self
336 .provider
337 .write()
338 .map_err(|e| CoreError::LockPoisoned(format!("provider write lock: {e}")))?;
339 *guard = new_provider;
340 Ok(name)
341 }
342
343 pub fn provider_name(&self) -> Result<String, CoreError> {
345 let guard = self
346 .provider
347 .read()
348 .map_err(|e| CoreError::LockPoisoned(format!("provider read lock: {e}")))?;
349 Ok(guard.name().to_string())
350 }
351
352 pub fn run(&self, input: RunInput, event_handler: impl FnMut(AgentEvent)) -> RunOutput {
353 self.run_cancellable(input, None, event_handler)
354 }
355
356 pub fn run_cancellable(
361 &self,
362 input: RunInput,
363 cancel: Option<&Arc<AtomicBool>>,
364 mut event_handler: impl FnMut(AgentEvent),
365 ) -> RunOutput {
366 let mut events = Vec::new();
367 let mut messages = input.messages;
368 let mut state = input.state;
369 let mut final_answer: Option<String> = None;
370 let mut stop_reason = RunStopReason::BudgetExceeded;
371 let mut total_iterations = 0;
372 let mut total_usage = TokenUsage::default();
373
374 let provider = match self.provider.read() {
376 Ok(guard) => guard.clone(),
377 Err(e) => {
378 let err_event = AgentEvent::RunErrored {
379 run_id: input.run_id.clone(),
380 session_id: input.session_id.clone(),
381 error: format!("provider lock poisoned: {e}"),
382 };
383 event_handler(err_event.clone());
384 return RunOutput {
385 run_id: input.run_id,
386 session_id: input.session_id,
387 branch_id: input.branch_id,
388 events: vec![err_event],
389 final_answer: None,
390 messages,
391 state,
392 reason: RunStopReason::Error,
393 total_usage: TokenUsage::default(),
394 };
395 }
396 };
397
398 let start_event = AgentEvent::RunStarted {
399 run_id: input.run_id.clone(),
400 session_id: input.session_id.clone(),
401 provider: provider.name().to_string(),
402 max_iterations: self.config.max_iterations,
403 };
404 event_handler(start_event.clone());
405 events.push(start_event);
406
407 for hook in &self.lifecycle_hooks {
409 hook.on_session_start(&input.session_id);
410 }
411
412 for iteration in 1..=self.config.max_iterations {
413 if let Some(flag) = cancel
415 && flag.load(Ordering::Relaxed)
416 {
417 stop_reason = RunStopReason::Cancelled;
418 let err_event = AgentEvent::RunErrored {
419 run_id: input.run_id.clone(),
420 session_id: input.session_id.clone(),
421 error: "run cancelled".to_string(),
422 };
423 event_handler(err_event.clone());
424 events.push(err_event);
425 break;
426 }
427
428 total_iterations = iteration;
429 let iter_event = AgentEvent::IterationStarted {
430 run_id: input.run_id.clone(),
431 session_id: input.session_id.clone(),
432 iteration,
433 };
434 event_handler(iter_event.clone());
435 events.push(iter_event);
436
437 if let Some(ref ctx_config) = self.config.context
439 && let Some(result) = compact_messages(&messages, ctx_config)
440 {
441 let compact_event = AgentEvent::ContextCompacted {
442 run_id: input.run_id.clone(),
443 session_id: input.session_id.clone(),
444 iteration,
445 dropped_count: result.dropped_count,
446 tokens_before: result.tokens_before,
447 tokens_after: result.tokens_after,
448 };
449 event_handler(compact_event.clone());
450 events.push(compact_event);
451 messages = result.messages;
452 }
453
454 let mut provider_request = ProviderRequest {
455 run_id: input.run_id.clone(),
456 session_id: input.session_id.clone(),
457 iteration,
458 messages: messages.clone(),
459 tools: self.tools.definitions(),
460 state: state.clone(),
461 };
462
463 if let Err(err) = self.run_before_model(&mut provider_request) {
464 stop_reason = RunStopReason::BlockedByPolicy;
465 let err_event = AgentEvent::RunErrored {
466 run_id: input.run_id.clone(),
467 session_id: input.session_id.clone(),
468 error: err.to_string(),
469 };
470 event_handler(err_event.clone());
471 events.push(err_event);
472 break;
473 }
474
475 for hook in &self.lifecycle_hooks {
477 hook.pre_llm_call(&provider_request);
478 }
479
480 let mut model_turn = match provider.complete(&provider_request) {
481 Ok(turn) => turn,
482 Err(err) => {
483 stop_reason = RunStopReason::Error;
484 let err_event = AgentEvent::RunErrored {
485 run_id: input.run_id.clone(),
486 session_id: input.session_id.clone(),
487 error: err.to_string(),
488 };
489 event_handler(err_event.clone());
490 events.push(err_event);
491 break;
492 }
493 };
494
495 for hook in &self.lifecycle_hooks {
497 hook.post_llm_call(&provider_request);
498 }
499
500 if let Err(err) = self.run_after_model(&provider_request, &mut model_turn) {
501 stop_reason = RunStopReason::BlockedByPolicy;
502 let err_event = AgentEvent::RunErrored {
503 run_id: input.run_id.clone(),
504 session_id: input.session_id.clone(),
505 error: err.to_string(),
506 };
507 event_handler(err_event.clone());
508 events.push(err_event);
509 break;
510 }
511
512 if let Some(ref usage) = model_turn.usage {
514 total_usage.accumulate(usage);
515 }
516
517 let output_event = AgentEvent::ModelOutput {
518 run_id: input.run_id.clone(),
519 session_id: input.session_id.clone(),
520 iteration,
521 stop_reason: model_turn.stop_reason,
522 directive_count: model_turn.directives.len(),
523 usage: model_turn.usage,
524 };
525 event_handler(output_event.clone());
526 events.push(output_event);
527
528 let mut requested_tool = false;
529
530 for directive in model_turn.directives {
531 match directive {
532 ModelDirective::Text { delta } => {
533 let delta_event = AgentEvent::TextDelta {
534 run_id: input.run_id.clone(),
535 session_id: input.session_id.clone(),
536 iteration,
537 delta: delta.clone(),
538 };
539 event_handler(delta_event.clone());
540 events.push(delta_event);
541 messages.push(ChatMessage::assistant(delta));
542 }
543 ModelDirective::ToolCall { mut call } => {
544 requested_tool = true;
545 let tc_event = AgentEvent::ToolCallRequested {
546 run_id: input.run_id.clone(),
547 session_id: input.session_id.clone(),
548 iteration,
549 call: call.clone(),
550 };
551 event_handler(tc_event.clone());
552 events.push(tc_event);
553
554 let context = ToolContext {
555 run_id: input.run_id.clone(),
556 session_id: input.session_id.clone(),
557 iteration,
558 };
559
560 if let Err(err) = self.run_pre_tool(&context, &mut call) {
561 stop_reason = RunStopReason::BlockedByPolicy;
562 let err_event = AgentEvent::ToolCallFailed {
563 run_id: input.run_id.clone(),
564 session_id: input.session_id.clone(),
565 iteration,
566 call_id: call.call_id.clone(),
567 tool_name: call.tool_name.clone(),
568 error: err.to_string(),
569 };
570 event_handler(err_event.clone());
571 events.push(err_event);
572 break;
573 }
574
575 let Some(tool) = self.tools.get(&call.tool_name) else {
576 stop_reason = RunStopReason::Error;
577 let err_event = AgentEvent::ToolCallFailed {
578 run_id: input.run_id.clone(),
579 session_id: input.session_id.clone(),
580 iteration,
581 call_id: call.call_id.clone(),
582 tool_name: call.tool_name.clone(),
583 error: format!(
584 "{}",
585 CoreError::ToolNotFound {
586 tool_name: call.tool_name.clone(),
587 }
588 ),
589 };
590 event_handler(err_event.clone());
591 events.push(err_event);
592 break;
593 };
594
595 for hook in &self.lifecycle_hooks {
597 hook.pre_tool_call(&call.tool_name, &call.input);
598 }
599
600 match tool.execute(&call, &context) {
601 Ok(mut result) => {
602 if let Err(err) = self.run_post_tool(&context, &mut result) {
603 stop_reason = RunStopReason::BlockedByPolicy;
604 let err_event = AgentEvent::ToolCallFailed {
605 run_id: input.run_id.clone(),
606 session_id: input.session_id.clone(),
607 iteration,
608 call_id: call.call_id.clone(),
609 tool_name: call.tool_name.clone(),
610 error: err.to_string(),
611 };
612 event_handler(err_event.clone());
613 events.push(err_event);
614 break;
615 }
616
617 if let Some(patch) = &result.state_patch {
618 match state.apply_patch(patch) {
619 Ok(()) => {
620 let patch_event = AgentEvent::StatePatched {
621 run_id: input.run_id.clone(),
622 session_id: input.session_id.clone(),
623 iteration,
624 patch: patch.clone(),
625 revision: state.revision,
626 };
627 event_handler(patch_event.clone());
628 events.push(patch_event);
629 }
630 Err(err) => {
631 stop_reason = RunStopReason::Error;
632 let err_event = AgentEvent::ToolCallFailed {
633 run_id: input.run_id.clone(),
634 session_id: input.session_id.clone(),
635 iteration,
636 call_id: call.call_id.clone(),
637 tool_name: call.tool_name.clone(),
638 error: err.to_string(),
639 };
640 event_handler(err_event.clone());
641 events.push(err_event);
642 break;
643 }
644 }
645 }
646
647 let result_str = serde_json::to_string(&result.output)
649 .unwrap_or_else(|_| "{}".to_string());
650 for hook in &self.lifecycle_hooks {
651 hook.post_tool_call(&call.tool_name, &result_str);
652 }
653
654 let completed_event = AgentEvent::ToolCallCompleted {
655 run_id: input.run_id.clone(),
656 session_id: input.session_id.clone(),
657 iteration,
658 result: ToolResultSummary::from(&result),
659 };
660 event_handler(completed_event.clone());
661 events.push(completed_event);
662
663 messages.push(ChatMessage::tool_result(
664 &result.call_id,
665 serde_json::to_string(&result.output)
666 .unwrap_or_else(|_| "{}".to_string()),
667 ));
668 }
669 Err(err) => {
670 stop_reason = RunStopReason::Error;
671 let err_event = AgentEvent::ToolCallFailed {
672 run_id: input.run_id.clone(),
673 session_id: input.session_id.clone(),
674 iteration,
675 call_id: call.call_id.clone(),
676 tool_name: call.tool_name.clone(),
677 error: err.to_string(),
678 };
679 event_handler(err_event.clone());
680 events.push(err_event);
681 break;
682 }
683 }
684 }
685 ModelDirective::StatePatch { patch } => match state.apply_patch(&patch) {
686 Ok(()) => {
687 let patch_event = AgentEvent::StatePatched {
688 run_id: input.run_id.clone(),
689 session_id: input.session_id.clone(),
690 iteration,
691 patch: patch.clone(),
692 revision: state.revision,
693 };
694 event_handler(patch_event.clone());
695 events.push(patch_event);
696 }
697 Err(err) => {
698 stop_reason = RunStopReason::Error;
699 let err_event = AgentEvent::RunErrored {
700 run_id: input.run_id.clone(),
701 session_id: input.session_id.clone(),
702 error: err.to_string(),
703 };
704 event_handler(err_event.clone());
705 events.push(err_event);
706 break;
707 }
708 },
709 ModelDirective::FinalAnswer { text } => {
710 final_answer = Some(text.clone());
711 let delta_event = AgentEvent::TextDelta {
712 run_id: input.run_id.clone(),
713 session_id: input.session_id.clone(),
714 iteration,
715 delta: text.clone(),
716 };
717 event_handler(delta_event.clone());
718 events.push(delta_event);
719 messages.push(ChatMessage::assistant(text));
720 }
721 }
722 }
723
724 if matches!(
725 stop_reason,
726 RunStopReason::Error | RunStopReason::BlockedByPolicy | RunStopReason::Cancelled
727 ) {
728 break;
729 }
730
731 match model_turn.stop_reason {
732 ModelStopReason::EndTurn => {
733 stop_reason = RunStopReason::Completed;
734 break;
735 }
736 ModelStopReason::NeedsUser => {
737 stop_reason = RunStopReason::NeedsUser;
738 break;
739 }
740 ModelStopReason::Safety => {
741 stop_reason = RunStopReason::BlockedByPolicy;
742 break;
743 }
744 ModelStopReason::ToolUse => {
745 if !requested_tool {
746 stop_reason = RunStopReason::Error;
747 let err_event = AgentEvent::RunErrored {
748 run_id: input.run_id.clone(),
749 session_id: input.session_id.clone(),
750 error: "model requested tool_use stop reason without tool call"
751 .to_string(),
752 };
753 event_handler(err_event.clone());
754 events.push(err_event);
755 break;
756 }
757 }
758 ModelStopReason::MaxTokens | ModelStopReason::Unknown => {
759 if !requested_tool {
760 stop_reason = RunStopReason::Error;
761 let err_event = AgentEvent::RunErrored {
762 run_id: input.run_id.clone(),
763 session_id: input.session_id.clone(),
764 error: "model returned non-terminal stop reason without tool call"
765 .to_string(),
766 };
767 event_handler(err_event.clone());
768 events.push(err_event);
769 break;
770 }
771 }
772 }
773 }
774
775 if total_iterations == self.config.max_iterations
776 && stop_reason == RunStopReason::BudgetExceeded
777 {
778 let err_event = AgentEvent::RunErrored {
779 run_id: input.run_id.clone(),
780 session_id: input.session_id.clone(),
781 error: "max iteration budget exceeded".to_string(),
782 };
783 event_handler(err_event.clone());
784 events.push(err_event);
785 }
786
787 let finished_event = AgentEvent::RunFinished {
788 run_id: input.run_id.clone(),
789 session_id: input.session_id.clone(),
790 reason: stop_reason,
791 total_iterations,
792 final_answer: final_answer.clone(),
793 usage: if total_usage.total() > 0 {
794 Some(total_usage)
795 } else {
796 None
797 },
798 };
799 event_handler(finished_event.clone());
800 events.push(finished_event);
801
802 let mut output = RunOutput {
803 run_id: input.run_id,
804 session_id: input.session_id,
805 branch_id: input.branch_id,
806 events,
807 messages,
808 state,
809 reason: stop_reason,
810 final_answer,
811 total_usage,
812 };
813
814 if let Err(e) = self
815 .turn_middlewares
816 .iter()
817 .try_for_each(|m| m.on_run_finished(&mut output))
818 {
819 tracing::warn!(error = %e, "middleware on_run_finished failed (non-fatal)");
820 }
821
822 for hook in &self.lifecycle_hooks {
824 hook.on_session_end(&output.session_id, &output);
825 }
826
827 output
828 }
829
830 fn run_before_model(&self, request: &mut ProviderRequest) -> Result<(), CoreError> {
831 self.turn_middlewares
832 .iter()
833 .try_for_each(|middleware| middleware.before_model_call(request))
834 }
835
836 fn run_after_model(
837 &self,
838 request: &ProviderRequest,
839 response: &mut ModelTurn,
840 ) -> Result<(), CoreError> {
841 self.turn_middlewares
842 .iter()
843 .try_for_each(|middleware| middleware.after_model_call(request, response))
844 }
845
846 fn run_pre_tool(&self, context: &ToolContext, call: &mut ToolCall) -> Result<(), CoreError> {
847 self.turn_middlewares
848 .iter()
849 .try_for_each(|middleware| middleware.pre_tool_call(context, call))
850 }
851
852 fn run_post_tool(
853 &self,
854 context: &ToolContext,
855 result: &mut ToolResult,
856 ) -> Result<(), CoreError> {
857 self.turn_middlewares
858 .iter()
859 .try_for_each(|middleware| middleware.post_tool_call(context, result))
860 }
861}
862
863#[cfg(test)]
864mod tests {
865 use super::*;
866 use crate::protocol::{
867 ModelDirective, ModelStopReason, ModelTurn, StatePatch, StatePatchFormat, StatePatchSource,
868 };
869 use serde_json::json;
870 use std::sync::Mutex;
871
872 struct ScriptedProvider {
873 turns: Vec<ModelTurn>,
874 cursor: Mutex<usize>,
875 }
876
877 impl Provider for ScriptedProvider {
878 fn name(&self) -> &str {
879 "scripted"
880 }
881
882 fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
883 let mut cursor = self
884 .cursor
885 .lock()
886 .map_err(|_| CoreError::Provider("scripted provider lock poisoned".to_string()))?;
887 let idx = *cursor;
888 let Some(turn) = self.turns.get(idx) else {
889 return Err(CoreError::Provider("no scripted turn left".to_string()));
890 };
891 *cursor += 1;
892 Ok(turn.clone())
893 }
894 }
895
896 struct EchoTool;
897
898 impl Tool for EchoTool {
899 fn definition(&self) -> ToolDefinition {
900 ToolDefinition {
901 name: "echo".to_string(),
902 description: "Echoes the provided value".to_string(),
903 input_schema: json!({
904 "type": "object",
905 "properties": { "value": { "type": "string" } },
906 "required": ["value"]
907 }),
908 title: None,
909 output_schema: None,
910 annotations: None,
911 category: None,
912 tags: Vec::new(),
913 timeout_secs: None,
914 }
915 }
916
917 fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, CoreError> {
918 let value = call.input.get("value").cloned().unwrap_or(json!(null));
919 Ok(ToolResult {
920 call_id: call.call_id.clone(),
921 tool_name: call.tool_name.clone(),
922 output: json!({ "echo": value.clone() }),
923 content: None,
924 is_error: false,
925 state_patch: Some(StatePatch {
926 format: StatePatchFormat::MergePatch,
927 patch: json!({ "last_echo": value }),
928 source: StatePatchSource::Tool,
929 }),
930 })
931 }
932 }
933
934 #[test]
935 fn orchestrator_runs_tool_then_finishes() {
936 let provider = ScriptedProvider {
937 turns: vec![
938 ModelTurn {
939 directives: vec![ModelDirective::ToolCall {
940 call: ToolCall {
941 call_id: "call-1".to_string(),
942 tool_name: "echo".to_string(),
943 input: json!({ "value": "hello" }),
944 },
945 }],
946 stop_reason: ModelStopReason::ToolUse,
947 usage: None,
948 },
949 ModelTurn {
950 directives: vec![ModelDirective::FinalAnswer {
951 text: "done".to_string(),
952 }],
953 stop_reason: ModelStopReason::EndTurn,
954 usage: None,
955 },
956 ],
957 cursor: Mutex::new(0),
958 };
959
960 let mut tools = ToolRegistry::default();
961 tools.register(EchoTool);
962
963 let orchestrator = Orchestrator::new(
964 Arc::new(provider),
965 tools,
966 Vec::new(),
967 OrchestratorConfig {
968 max_iterations: 4,
969 context: None,
970 context_compiler: None,
971 },
972 );
973
974 let output = orchestrator.run(
975 RunInput {
976 run_id: "run-1".to_string(),
977 session_id: "session-1".to_string(),
978 branch_id: "main".to_string(),
979 messages: vec![ChatMessage::user("test")],
980 state: AppState::default(),
981 },
982 |_| {},
983 );
984
985 assert_eq!(output.reason, RunStopReason::Completed);
986 assert_eq!(output.final_answer.as_deref(), Some("done"));
987 assert_eq!(output.state.revision, 1);
988 assert_eq!(output.state.data["last_echo"], "hello");
989
990 assert!(
991 output
992 .events
993 .iter()
994 .any(|event| matches!(event, AgentEvent::ToolCallCompleted { .. }))
995 );
996 assert!(output.events.iter().any(|event| matches!(
997 event,
998 AgentEvent::RunFinished {
999 reason: RunStopReason::Completed,
1000 ..
1001 }
1002 )));
1003 }
1004
1005 #[test]
1006 fn provider_error_stops_run() {
1007 struct FailProvider;
1008 impl Provider for FailProvider {
1009 fn name(&self) -> &str {
1010 "fail"
1011 }
1012 fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
1013 Err(CoreError::Provider("connection refused".to_string()))
1014 }
1015 }
1016
1017 let orchestrator = Orchestrator::new(
1018 Arc::new(FailProvider),
1019 ToolRegistry::default(),
1020 Vec::new(),
1021 OrchestratorConfig {
1022 max_iterations: 4,
1023 context: None,
1024 context_compiler: None,
1025 },
1026 );
1027
1028 let output = orchestrator.run(
1029 RunInput {
1030 run_id: "run-1".to_string(),
1031 session_id: "s1".to_string(),
1032 branch_id: "main".to_string(),
1033 messages: vec![ChatMessage::user("test")],
1034 state: AppState::default(),
1035 },
1036 |_| {},
1037 );
1038
1039 assert_eq!(output.reason, RunStopReason::Error);
1040 assert!(
1041 output
1042 .events
1043 .iter()
1044 .any(|e| matches!(e, AgentEvent::RunErrored { .. }))
1045 );
1046 }
1047
1048 #[test]
1049 fn tool_not_found_stops_run() {
1050 let provider = ScriptedProvider {
1051 turns: vec![ModelTurn {
1052 directives: vec![ModelDirective::ToolCall {
1053 call: ToolCall {
1054 call_id: "c1".to_string(),
1055 tool_name: "nonexistent".to_string(),
1056 input: json!({}),
1057 },
1058 }],
1059 stop_reason: ModelStopReason::ToolUse,
1060 usage: None,
1061 }],
1062 cursor: Mutex::new(0),
1063 };
1064
1065 let orchestrator = Orchestrator::new(
1066 Arc::new(provider),
1067 ToolRegistry::default(),
1068 Vec::new(),
1069 OrchestratorConfig {
1070 max_iterations: 4,
1071 context: None,
1072 context_compiler: None,
1073 },
1074 );
1075
1076 let output = orchestrator.run(
1077 RunInput {
1078 run_id: "run-1".to_string(),
1079 session_id: "s1".to_string(),
1080 branch_id: "main".to_string(),
1081 messages: vec![ChatMessage::user("test")],
1082 state: AppState::default(),
1083 },
1084 |_| {},
1085 );
1086
1087 assert_eq!(output.reason, RunStopReason::Error);
1088 assert!(
1089 output
1090 .events
1091 .iter()
1092 .any(|e| matches!(e, AgentEvent::ToolCallFailed { .. }))
1093 );
1094 }
1095
1096 #[test]
1097 fn middleware_blocks_model_call() {
1098 struct BlockMiddleware;
1099 impl Middleware for BlockMiddleware {
1100 fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
1101 Err(CoreError::Middleware("blocked by policy".to_string()))
1102 }
1103 }
1104
1105 let provider = ScriptedProvider {
1106 turns: vec![ModelTurn {
1107 directives: vec![ModelDirective::Text {
1108 delta: "hi".to_string(),
1109 }],
1110 stop_reason: ModelStopReason::EndTurn,
1111 usage: None,
1112 }],
1113 cursor: Mutex::new(0),
1114 };
1115
1116 let orchestrator = Orchestrator::new(
1117 Arc::new(provider),
1118 ToolRegistry::default(),
1119 vec![Arc::new(BlockMiddleware)],
1120 OrchestratorConfig {
1121 max_iterations: 4,
1122 context: None,
1123 context_compiler: None,
1124 },
1125 );
1126
1127 let output = orchestrator.run(
1128 RunInput {
1129 run_id: "run-1".to_string(),
1130 session_id: "s1".to_string(),
1131 branch_id: "main".to_string(),
1132 messages: vec![ChatMessage::user("test")],
1133 state: AppState::default(),
1134 },
1135 |_| {},
1136 );
1137
1138 assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
1139 }
1140
1141 #[test]
1142 fn turn_middleware_can_rewrite_calls_and_responses() {
1143 struct RewriteMiddleware;
1144
1145 impl TurnMiddleware for RewriteMiddleware {
1146 fn after_model_call(
1147 &self,
1148 _request: &ProviderRequest,
1149 response: &mut ModelTurn,
1150 ) -> Result<(), CoreError> {
1151 for directive in &mut response.directives {
1152 if let ModelDirective::FinalAnswer { text } = directive {
1153 *text = "rewritten answer".to_string();
1154 }
1155 }
1156 Ok(())
1157 }
1158
1159 fn pre_tool_call(
1160 &self,
1161 _context: &ToolContext,
1162 call: &mut ToolCall,
1163 ) -> Result<(), CoreError> {
1164 call.input = json!({ "value": "rewritten input" });
1165 Ok(())
1166 }
1167 }
1168
1169 let provider = ScriptedProvider {
1170 turns: vec![
1171 ModelTurn {
1172 directives: vec![ModelDirective::ToolCall {
1173 call: ToolCall {
1174 call_id: "call-1".to_string(),
1175 tool_name: "echo".to_string(),
1176 input: json!({ "value": "original input" }),
1177 },
1178 }],
1179 stop_reason: ModelStopReason::ToolUse,
1180 usage: None,
1181 },
1182 ModelTurn {
1183 directives: vec![ModelDirective::FinalAnswer {
1184 text: "original answer".to_string(),
1185 }],
1186 stop_reason: ModelStopReason::EndTurn,
1187 usage: None,
1188 },
1189 ],
1190 cursor: Mutex::new(0),
1191 };
1192
1193 let mut tools = ToolRegistry::default();
1194 tools.register(EchoTool);
1195
1196 let orchestrator = Orchestrator::with_turn_middlewares(
1197 Arc::new(provider),
1198 tools,
1199 vec![Arc::new(RewriteMiddleware)],
1200 OrchestratorConfig {
1201 max_iterations: 4,
1202 context: None,
1203 context_compiler: None,
1204 },
1205 );
1206
1207 let output = orchestrator.run(
1208 RunInput {
1209 run_id: "run-1".to_string(),
1210 session_id: "session-1".to_string(),
1211 branch_id: "main".to_string(),
1212 messages: vec![ChatMessage::user("test")],
1213 state: AppState::default(),
1214 },
1215 |_| {},
1216 );
1217
1218 assert_eq!(output.reason, RunStopReason::Completed);
1219 assert_eq!(output.final_answer.as_deref(), Some("rewritten answer"));
1220 assert_eq!(output.state.data["last_echo"], "rewritten input");
1221 }
1222
1223 #[test]
1224 fn budget_exceeded_when_iterations_exhausted() {
1225 let provider = ScriptedProvider {
1229 turns: vec![
1230 ModelTurn {
1231 directives: vec![ModelDirective::ToolCall {
1232 call: ToolCall {
1233 call_id: "c1".to_string(),
1234 tool_name: "echo".to_string(),
1235 input: json!({"value": "1"}),
1236 },
1237 }],
1238 stop_reason: ModelStopReason::ToolUse,
1239 usage: None,
1240 },
1241 ModelTurn {
1242 directives: vec![ModelDirective::ToolCall {
1243 call: ToolCall {
1244 call_id: "c2".to_string(),
1245 tool_name: "echo".to_string(),
1246 input: json!({"value": "2"}),
1247 },
1248 }],
1249 stop_reason: ModelStopReason::ToolUse,
1250 usage: None,
1251 },
1252 ],
1255 cursor: Mutex::new(0),
1256 };
1257
1258 let mut tools = ToolRegistry::default();
1259 tools.register(EchoTool);
1260
1261 let orchestrator = Orchestrator::new(
1262 Arc::new(provider),
1263 tools,
1264 Vec::new(),
1265 OrchestratorConfig {
1266 max_iterations: 2,
1267 context: None,
1268 context_compiler: None,
1269 },
1270 );
1271
1272 let output = orchestrator.run(
1273 RunInput {
1274 run_id: "run-1".to_string(),
1275 session_id: "s1".to_string(),
1276 branch_id: "main".to_string(),
1277 messages: vec![ChatMessage::user("test")],
1278 state: AppState::default(),
1279 },
1280 |_| {},
1281 );
1282
1283 assert_eq!(output.reason, RunStopReason::BudgetExceeded);
1284 }
1285
1286 #[test]
1287 fn text_only_response_completes() {
1288 let provider = ScriptedProvider {
1289 turns: vec![ModelTurn {
1290 directives: vec![ModelDirective::Text {
1291 delta: "Hello, world!".to_string(),
1292 }],
1293 stop_reason: ModelStopReason::EndTurn,
1294 usage: None,
1295 }],
1296 cursor: Mutex::new(0),
1297 };
1298
1299 let orchestrator = Orchestrator::new(
1300 Arc::new(provider),
1301 ToolRegistry::default(),
1302 Vec::new(),
1303 OrchestratorConfig {
1304 max_iterations: 4,
1305 context: None,
1306 context_compiler: None,
1307 },
1308 );
1309
1310 let output = orchestrator.run(
1311 RunInput {
1312 run_id: "run-1".to_string(),
1313 session_id: "s1".to_string(),
1314 branch_id: "main".to_string(),
1315 messages: vec![ChatMessage::user("hi")],
1316 state: AppState::default(),
1317 },
1318 |_| {},
1319 );
1320
1321 assert_eq!(output.reason, RunStopReason::Completed);
1322 assert!(output.messages.iter().any(|m| m.content == "Hello, world!"));
1323 }
1324
1325 #[test]
1326 fn event_handler_receives_all_events() {
1327 let provider = ScriptedProvider {
1328 turns: vec![ModelTurn {
1329 directives: vec![ModelDirective::FinalAnswer {
1330 text: "done".to_string(),
1331 }],
1332 stop_reason: ModelStopReason::EndTurn,
1333 usage: None,
1334 }],
1335 cursor: Mutex::new(0),
1336 };
1337
1338 let orchestrator = Orchestrator::new(
1339 Arc::new(provider),
1340 ToolRegistry::default(),
1341 Vec::new(),
1342 OrchestratorConfig {
1343 max_iterations: 4,
1344 context: None,
1345 context_compiler: None,
1346 },
1347 );
1348
1349 let received = Arc::new(Mutex::new(Vec::new()));
1350 let received_clone = received.clone();
1351
1352 orchestrator.run(
1353 RunInput {
1354 run_id: "run-1".to_string(),
1355 session_id: "s1".to_string(),
1356 branch_id: "main".to_string(),
1357 messages: vec![ChatMessage::user("test")],
1358 state: AppState::default(),
1359 },
1360 move |event| {
1361 received_clone.lock().unwrap().push(event);
1362 },
1363 );
1364
1365 let events = received.lock().unwrap();
1366 assert!(events.len() >= 4); assert!(matches!(events[0], AgentEvent::RunStarted { .. }));
1368 assert!(matches!(
1369 events.last().unwrap(),
1370 AgentEvent::RunFinished { .. }
1371 ));
1372 }
1373
1374 #[test]
1375 fn tool_result_includes_call_id() {
1376 let provider = ScriptedProvider {
1377 turns: vec![
1378 ModelTurn {
1379 directives: vec![ModelDirective::ToolCall {
1380 call: ToolCall {
1381 call_id: "my-call-id".to_string(),
1382 tool_name: "echo".to_string(),
1383 input: json!({"value": "test"}),
1384 },
1385 }],
1386 stop_reason: ModelStopReason::ToolUse,
1387 usage: None,
1388 },
1389 ModelTurn {
1390 directives: vec![ModelDirective::FinalAnswer {
1391 text: "ok".to_string(),
1392 }],
1393 stop_reason: ModelStopReason::EndTurn,
1394 usage: None,
1395 },
1396 ],
1397 cursor: Mutex::new(0),
1398 };
1399
1400 let mut tools = ToolRegistry::default();
1401 tools.register(EchoTool);
1402
1403 let orchestrator = Orchestrator::new(
1404 Arc::new(provider),
1405 tools,
1406 Vec::new(),
1407 OrchestratorConfig {
1408 max_iterations: 4,
1409 context: None,
1410 context_compiler: None,
1411 },
1412 );
1413
1414 let output = orchestrator.run(
1415 RunInput {
1416 run_id: "run-1".to_string(),
1417 session_id: "s1".to_string(),
1418 branch_id: "main".to_string(),
1419 messages: vec![ChatMessage::user("test")],
1420 state: AppState::default(),
1421 },
1422 |_| {},
1423 );
1424
1425 let tool_msg = output
1427 .messages
1428 .iter()
1429 .find(|m| m.role == crate::protocol::Role::Tool)
1430 .expect("should have tool message");
1431 assert_eq!(tool_msg.tool_call_id.as_deref(), Some("my-call-id"));
1432 }
1433
1434 #[test]
1435 fn cancellation_stops_run() {
1436 let provider = ScriptedProvider {
1437 turns: vec![
1438 ModelTurn {
1439 directives: vec![ModelDirective::ToolCall {
1440 call: ToolCall {
1441 call_id: "c1".to_string(),
1442 tool_name: "echo".to_string(),
1443 input: json!({"value": "1"}),
1444 },
1445 }],
1446 stop_reason: ModelStopReason::ToolUse,
1447 usage: None,
1448 },
1449 ModelTurn {
1450 directives: vec![ModelDirective::FinalAnswer {
1451 text: "should not reach".to_string(),
1452 }],
1453 stop_reason: ModelStopReason::EndTurn,
1454 usage: None,
1455 },
1456 ],
1457 cursor: Mutex::new(0),
1458 };
1459
1460 let mut tools = ToolRegistry::default();
1461 tools.register(EchoTool);
1462
1463 let orchestrator = Orchestrator::new(
1464 Arc::new(provider),
1465 tools,
1466 Vec::new(),
1467 OrchestratorConfig {
1468 max_iterations: 10,
1469 context: None,
1470 context_compiler: None,
1471 },
1472 );
1473
1474 let cancel = Arc::new(AtomicBool::new(false));
1476 let cancel_clone = cancel.clone();
1477 let call_count = Arc::new(Mutex::new(0u32));
1478 let call_count_clone = call_count.clone();
1479
1480 let output = orchestrator.run_cancellable(
1481 RunInput {
1482 run_id: "run-1".to_string(),
1483 session_id: "s1".to_string(),
1484 branch_id: "main".to_string(),
1485 messages: vec![ChatMessage::user("test")],
1486 state: AppState::default(),
1487 },
1488 Some(&cancel_clone),
1489 move |event| {
1490 if matches!(event, AgentEvent::ToolCallCompleted { .. }) {
1492 let mut count = call_count_clone.lock().unwrap();
1493 *count += 1;
1494 if *count >= 1 {
1495 cancel.store(true, Ordering::Relaxed);
1496 }
1497 }
1498 },
1499 );
1500
1501 assert_eq!(output.reason, RunStopReason::Cancelled);
1502 assert!(output.final_answer.is_none());
1504 }
1505
1506 #[test]
1507 fn swappable_provider_handle_swap() {
1508 struct ProviderA;
1509 impl Provider for ProviderA {
1510 fn name(&self) -> &str {
1511 "provider-a"
1512 }
1513 fn complete(&self, _: &ProviderRequest) -> Result<ModelTurn, CoreError> {
1514 Err(CoreError::Provider("stub provider".into()))
1515 }
1516 }
1517
1518 struct ProviderB;
1519 impl Provider for ProviderB {
1520 fn name(&self) -> &str {
1521 "provider-b"
1522 }
1523 fn complete(&self, _: &ProviderRequest) -> Result<ModelTurn, CoreError> {
1524 Err(CoreError::Provider("stub provider".into()))
1525 }
1526 }
1527
1528 let handle: SwappableProviderHandle = Arc::new(std::sync::RwLock::new(Arc::new(ProviderA)));
1529
1530 assert_eq!(handle.read().unwrap().name(), "provider-a");
1532
1533 {
1535 let mut guard = handle.write().unwrap();
1536 *guard = Arc::new(ProviderB);
1537 }
1538
1539 assert_eq!(handle.read().unwrap().name(), "provider-b");
1541 }
1542
1543 #[test]
1544 fn token_usage_accumulated() {
1545 let provider = ScriptedProvider {
1546 turns: vec![
1547 ModelTurn {
1548 directives: vec![ModelDirective::ToolCall {
1549 call: ToolCall {
1550 call_id: "c1".to_string(),
1551 tool_name: "echo".to_string(),
1552 input: json!({"value": "hi"}),
1553 },
1554 }],
1555 stop_reason: ModelStopReason::ToolUse,
1556 usage: Some(TokenUsage {
1557 input_tokens: 100,
1558 output_tokens: 50,
1559 cache_read_tokens: 0,
1560 cache_creation_tokens: 0,
1561 }),
1562 },
1563 ModelTurn {
1564 directives: vec![ModelDirective::FinalAnswer {
1565 text: "done".to_string(),
1566 }],
1567 stop_reason: ModelStopReason::EndTurn,
1568 usage: Some(TokenUsage {
1569 input_tokens: 200,
1570 output_tokens: 30,
1571 cache_read_tokens: 0,
1572 cache_creation_tokens: 0,
1573 }),
1574 },
1575 ],
1576 cursor: Mutex::new(0),
1577 };
1578
1579 let mut tools = ToolRegistry::default();
1580 tools.register(EchoTool);
1581
1582 let orchestrator = Orchestrator::new(
1583 Arc::new(provider),
1584 tools,
1585 Vec::new(),
1586 OrchestratorConfig {
1587 max_iterations: 4,
1588 context: None,
1589 context_compiler: None,
1590 },
1591 );
1592
1593 let output = orchestrator.run(
1594 RunInput {
1595 run_id: "run-1".to_string(),
1596 session_id: "s1".to_string(),
1597 branch_id: "main".to_string(),
1598 messages: vec![ChatMessage::user("test")],
1599 state: AppState::default(),
1600 },
1601 |_| {},
1602 );
1603
1604 assert_eq!(output.reason, RunStopReason::Completed);
1605 assert_eq!(output.total_usage.input_tokens, 300);
1606 assert_eq!(output.total_usage.output_tokens, 80);
1607 assert_eq!(output.total_usage.total(), 380);
1608 }
1609
1610 #[test]
1611 fn lifecycle_hooks_fire_during_run() {
1612 use crate::lifecycle::LifecycleHook;
1613 use std::sync::atomic::{AtomicU32, Ordering};
1614
1615 struct CountingHook {
1616 pre_tool: AtomicU32,
1617 post_tool: AtomicU32,
1618 pre_llm: AtomicU32,
1619 post_llm: AtomicU32,
1620 session_start: AtomicU32,
1621 session_end: AtomicU32,
1622 }
1623
1624 impl CountingHook {
1625 fn new() -> Self {
1626 Self {
1627 pre_tool: AtomicU32::new(0),
1628 post_tool: AtomicU32::new(0),
1629 pre_llm: AtomicU32::new(0),
1630 post_llm: AtomicU32::new(0),
1631 session_start: AtomicU32::new(0),
1632 session_end: AtomicU32::new(0),
1633 }
1634 }
1635 }
1636
1637 impl LifecycleHook for CountingHook {
1638 fn pre_tool_call(&self, _tool_name: &str, _input: &serde_json::Value) {
1639 self.pre_tool.fetch_add(1, Ordering::Relaxed);
1640 }
1641 fn post_tool_call(&self, _tool_name: &str, _result: &str) {
1642 self.post_tool.fetch_add(1, Ordering::Relaxed);
1643 }
1644 fn pre_llm_call(&self, _request: &ProviderRequest) {
1645 self.pre_llm.fetch_add(1, Ordering::Relaxed);
1646 }
1647 fn post_llm_call(&self, _request: &ProviderRequest) {
1648 self.post_llm.fetch_add(1, Ordering::Relaxed);
1649 }
1650 fn on_session_start(&self, _session_id: &str) {
1651 self.session_start.fetch_add(1, Ordering::Relaxed);
1652 }
1653 fn on_session_end(&self, _session_id: &str, _output: &RunOutput) {
1654 self.session_end.fetch_add(1, Ordering::Relaxed);
1655 }
1656 }
1657
1658 let provider = ScriptedProvider {
1659 turns: vec![
1660 ModelTurn {
1661 directives: vec![ModelDirective::ToolCall {
1662 call: ToolCall {
1663 call_id: "call-1".to_string(),
1664 tool_name: "echo".to_string(),
1665 input: json!({ "value": "hello" }),
1666 },
1667 }],
1668 stop_reason: ModelStopReason::ToolUse,
1669 usage: None,
1670 },
1671 ModelTurn {
1672 directives: vec![ModelDirective::FinalAnswer {
1673 text: "done".to_string(),
1674 }],
1675 stop_reason: ModelStopReason::EndTurn,
1676 usage: None,
1677 },
1678 ],
1679 cursor: Mutex::new(0),
1680 };
1681
1682 let mut tools = ToolRegistry::default();
1683 tools.register(EchoTool);
1684
1685 let hook = Arc::new(CountingHook::new());
1686
1687 let orchestrator = Orchestrator::new(
1688 Arc::new(provider),
1689 tools,
1690 Vec::new(),
1691 OrchestratorConfig {
1692 max_iterations: 4,
1693 context: None,
1694 context_compiler: None,
1695 },
1696 )
1697 .with_lifecycle_hooks(vec![hook.clone()]);
1698
1699 let output = orchestrator.run(
1700 RunInput {
1701 run_id: "run-1".to_string(),
1702 session_id: "session-1".to_string(),
1703 branch_id: "main".to_string(),
1704 messages: vec![ChatMessage::user("test")],
1705 state: AppState::default(),
1706 },
1707 |_| {},
1708 );
1709
1710 assert_eq!(output.reason, RunStopReason::Completed);
1711
1712 assert_eq!(hook.session_start.load(Ordering::Relaxed), 1);
1714 assert_eq!(hook.session_end.load(Ordering::Relaxed), 1);
1715
1716 assert_eq!(hook.pre_llm.load(Ordering::Relaxed), 2);
1718 assert_eq!(hook.post_llm.load(Ordering::Relaxed), 2);
1719
1720 assert_eq!(hook.pre_tool.load(Ordering::Relaxed), 1);
1722 assert_eq!(hook.post_tool.load(Ordering::Relaxed), 1);
1723 }
1724}