// arcan_core/runtime.rs

1use crate::error::CoreError;
2use crate::protocol::{
3    AgentEvent, ChatMessage, ModelDirective, ModelStopReason, ModelTurn, RunStopReason, TokenUsage,
4    ToolCall, ToolDefinition, ToolResult, ToolResultSummary,
5};
6use crate::state::AppState;
7use std::collections::BTreeMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10
11#[derive(Debug, Clone)]
12pub struct ProviderRequest {
13    pub run_id: String,
14    pub session_id: String,
15    pub iteration: u32,
16    pub messages: Vec<ChatMessage>,
17    pub tools: Vec<ToolDefinition>,
18    pub state: AppState,
19}
20
21pub trait Provider: Send + Sync {
22    fn name(&self) -> &str;
23    fn complete(&self, request: &ProviderRequest) -> Result<ModelTurn, CoreError>;
24}
25
/// Identifies where in a run a tool call is being executed.
///
/// Passed to [`Tool::execute`] and to the tool-related middleware hooks.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Identifier of the run that issued the tool call.
    pub run_id: String,
    /// Identifier of the session that owns the run.
    pub session_id: String,
    /// 1-based orchestrator iteration in which the call was requested.
    pub iteration: u32,
}
32
33pub trait Tool: Send + Sync {
34    fn definition(&self) -> ToolDefinition;
35    fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, CoreError>;
36}
37
38pub trait Middleware: Send + Sync {
39    fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
40        Ok(())
41    }
42
43    fn after_model_call(
44        &self,
45        _request: &ProviderRequest,
46        _response: &ModelTurn,
47    ) -> Result<(), CoreError> {
48        Ok(())
49    }
50
51    fn pre_tool_call(&self, _context: &ToolContext, _call: &ToolCall) -> Result<(), CoreError> {
52        Ok(())
53    }
54
55    fn post_tool_call(
56        &self,
57        _context: &ToolContext,
58        _result: &ToolResult,
59    ) -> Result<(), CoreError> {
60        Ok(())
61    }
62
63    fn on_run_finished(&self, _output: &RunOutput) -> Result<(), CoreError> {
64        Ok(())
65    }
66}
67
68#[derive(Clone, Default)]
69pub struct ToolRegistry {
70    tools: BTreeMap<String, Arc<dyn Tool>>,
71}
72
73impl ToolRegistry {
74    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
75        self.tools
76            .insert(tool.definition().name.clone(), Arc::new(tool));
77    }
78
79    pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
80        self.tools.get(tool_name).cloned()
81    }
82
83    pub fn definitions(&self) -> Vec<ToolDefinition> {
84        self.tools.values().map(|tool| tool.definition()).collect()
85    }
86}
87
/// Tunable limits for an orchestrator run.
#[derive(Debug, Clone)]
pub struct OrchestratorConfig {
    /// Upper bound on the number of model iterations a single run may use.
    pub max_iterations: u32,
}

impl Default for OrchestratorConfig {
    /// Returns a configuration allowing 24 iterations per run.
    fn default() -> Self {
        const DEFAULT_MAX_ITERATIONS: u32 = 24;
        Self {
            max_iterations: DEFAULT_MAX_ITERATIONS,
        }
    }
}
98
99#[derive(Debug, Clone)]
100pub struct RunInput {
101    pub run_id: String,
102    pub session_id: String,
103    pub messages: Vec<ChatMessage>,
104    pub state: AppState,
105}
106
107#[derive(Debug, Clone)]
108pub struct RunOutput {
109    pub run_id: String,
110    pub session_id: String,
111    pub events: Vec<AgentEvent>,
112    pub messages: Vec<ChatMessage>,
113    pub state: AppState,
114    pub reason: RunStopReason,
115    pub final_answer: Option<String>,
116    /// Accumulated token usage across all iterations.
117    pub total_usage: TokenUsage,
118}
119
120pub struct Orchestrator {
121    provider: Arc<dyn Provider>,
122    tools: ToolRegistry,
123    middlewares: Vec<Arc<dyn Middleware>>,
124    config: OrchestratorConfig,
125}
126
127impl Orchestrator {
128    pub fn new(
129        provider: Arc<dyn Provider>,
130        tools: ToolRegistry,
131        middlewares: Vec<Arc<dyn Middleware>>,
132        config: OrchestratorConfig,
133    ) -> Self {
134        Self {
135            provider,
136            tools,
137            middlewares,
138            config,
139        }
140    }
141
142    pub fn run(&self, input: RunInput, event_handler: impl FnMut(AgentEvent)) -> RunOutput {
143        self.run_cancellable(input, None, event_handler)
144    }
145
146    /// Run the orchestrator loop with an optional cancellation flag.
147    ///
148    /// If `cancel` is provided and set to `true` during execution,
149    /// the loop will stop at the next iteration boundary.
150    pub fn run_cancellable(
151        &self,
152        input: RunInput,
153        cancel: Option<Arc<AtomicBool>>,
154        mut event_handler: impl FnMut(AgentEvent),
155    ) -> RunOutput {
156        let mut events = Vec::new();
157        let mut messages = input.messages;
158        let mut state = input.state;
159        let mut final_answer: Option<String> = None;
160        let mut stop_reason = RunStopReason::BudgetExceeded;
161        let mut total_iterations = 0;
162        let mut total_usage = TokenUsage::default();
163
164        let start_event = AgentEvent::RunStarted {
165            run_id: input.run_id.clone(),
166            session_id: input.session_id.clone(),
167            provider: self.provider.name().to_string(),
168            max_iterations: self.config.max_iterations,
169        };
170        event_handler(start_event.clone());
171        events.push(start_event);
172
173        for iteration in 1..=self.config.max_iterations {
174            // Check cancellation at each iteration boundary
175            if let Some(ref flag) = cancel {
176                if flag.load(Ordering::Relaxed) {
177                    stop_reason = RunStopReason::Cancelled;
178                    let err_event = AgentEvent::RunErrored {
179                        run_id: input.run_id.clone(),
180                        session_id: input.session_id.clone(),
181                        error: "run cancelled".to_string(),
182                    };
183                    event_handler(err_event.clone());
184                    events.push(err_event);
185                    break;
186                }
187            }
188
189            total_iterations = iteration;
190            let iter_event = AgentEvent::IterationStarted {
191                run_id: input.run_id.clone(),
192                session_id: input.session_id.clone(),
193                iteration,
194            };
195            event_handler(iter_event.clone());
196            events.push(iter_event);
197
198            let provider_request = ProviderRequest {
199                run_id: input.run_id.clone(),
200                session_id: input.session_id.clone(),
201                iteration,
202                messages: messages.clone(),
203                tools: self.tools.definitions(),
204                state: state.clone(),
205            };
206
207            if let Err(err) = self.run_before_model(&provider_request) {
208                stop_reason = RunStopReason::BlockedByPolicy;
209                let err_event = AgentEvent::RunErrored {
210                    run_id: input.run_id.clone(),
211                    session_id: input.session_id.clone(),
212                    error: err.to_string(),
213                };
214                event_handler(err_event.clone());
215                events.push(err_event);
216                break;
217            }
218
219            let model_turn = match self.provider.complete(&provider_request) {
220                Ok(turn) => turn,
221                Err(err) => {
222                    stop_reason = RunStopReason::Error;
223                    let err_event = AgentEvent::RunErrored {
224                        run_id: input.run_id.clone(),
225                        session_id: input.session_id.clone(),
226                        error: err.to_string(),
227                    };
228                    event_handler(err_event.clone());
229                    events.push(err_event);
230                    break;
231                }
232            };
233
234            if let Err(err) = self.run_after_model(&provider_request, &model_turn) {
235                stop_reason = RunStopReason::BlockedByPolicy;
236                let err_event = AgentEvent::RunErrored {
237                    run_id: input.run_id.clone(),
238                    session_id: input.session_id.clone(),
239                    error: err.to_string(),
240                };
241                event_handler(err_event.clone());
242                events.push(err_event);
243                break;
244            }
245
246            // Accumulate token usage if reported
247            if let Some(ref usage) = model_turn.usage {
248                total_usage.accumulate(usage);
249            }
250
251            let output_event = AgentEvent::ModelOutput {
252                run_id: input.run_id.clone(),
253                session_id: input.session_id.clone(),
254                iteration,
255                stop_reason: model_turn.stop_reason,
256                directive_count: model_turn.directives.len(),
257                usage: model_turn.usage,
258            };
259            event_handler(output_event.clone());
260            events.push(output_event);
261
262            let mut requested_tool = false;
263
264            for directive in model_turn.directives {
265                match directive {
266                    ModelDirective::Text { delta } => {
267                        let delta_event = AgentEvent::TextDelta {
268                            run_id: input.run_id.clone(),
269                            session_id: input.session_id.clone(),
270                            iteration,
271                            delta: delta.clone(),
272                        };
273                        event_handler(delta_event.clone());
274                        events.push(delta_event);
275                        messages.push(ChatMessage::assistant(delta));
276                    }
277                    ModelDirective::ToolCall { call } => {
278                        requested_tool = true;
279                        let tc_event = AgentEvent::ToolCallRequested {
280                            run_id: input.run_id.clone(),
281                            session_id: input.session_id.clone(),
282                            iteration,
283                            call: call.clone(),
284                        };
285                        event_handler(tc_event.clone());
286                        events.push(tc_event);
287
288                        let context = ToolContext {
289                            run_id: input.run_id.clone(),
290                            session_id: input.session_id.clone(),
291                            iteration,
292                        };
293
294                        if let Err(err) = self.run_pre_tool(&context, &call) {
295                            stop_reason = RunStopReason::BlockedByPolicy;
296                            let err_event = AgentEvent::ToolCallFailed {
297                                run_id: input.run_id.clone(),
298                                session_id: input.session_id.clone(),
299                                iteration,
300                                call_id: call.call_id.clone(),
301                                tool_name: call.tool_name.clone(),
302                                error: err.to_string(),
303                            };
304                            event_handler(err_event.clone());
305                            events.push(err_event);
306                            break;
307                        }
308
309                        let Some(tool) = self.tools.get(&call.tool_name) else {
310                            stop_reason = RunStopReason::Error;
311                            let err_event = AgentEvent::ToolCallFailed {
312                                run_id: input.run_id.clone(),
313                                session_id: input.session_id.clone(),
314                                iteration,
315                                call_id: call.call_id.clone(),
316                                tool_name: call.tool_name.clone(),
317                                error: format!(
318                                    "{}",
319                                    CoreError::ToolNotFound {
320                                        tool_name: call.tool_name.clone(),
321                                    }
322                                ),
323                            };
324                            event_handler(err_event.clone());
325                            events.push(err_event);
326                            break;
327                        };
328
329                        match tool.execute(&call, &context) {
330                            Ok(result) => {
331                                if let Some(patch) = &result.state_patch {
332                                    match state.apply_patch(patch) {
333                                        Ok(()) => {
334                                            let patch_event = AgentEvent::StatePatched {
335                                                run_id: input.run_id.clone(),
336                                                session_id: input.session_id.clone(),
337                                                iteration,
338                                                patch: patch.clone(),
339                                                revision: state.revision,
340                                            };
341                                            event_handler(patch_event.clone());
342                                            events.push(patch_event);
343                                        }
344                                        Err(err) => {
345                                            stop_reason = RunStopReason::Error;
346                                            let err_event = AgentEvent::ToolCallFailed {
347                                                run_id: input.run_id.clone(),
348                                                session_id: input.session_id.clone(),
349                                                iteration,
350                                                call_id: call.call_id.clone(),
351                                                tool_name: call.tool_name.clone(),
352                                                error: err.to_string(),
353                                            };
354                                            event_handler(err_event.clone());
355                                            events.push(err_event);
356                                            break;
357                                        }
358                                    }
359                                }
360
361                                if let Err(err) = self.run_post_tool(&context, &result) {
362                                    stop_reason = RunStopReason::BlockedByPolicy;
363                                    let err_event = AgentEvent::ToolCallFailed {
364                                        run_id: input.run_id.clone(),
365                                        session_id: input.session_id.clone(),
366                                        iteration,
367                                        call_id: call.call_id.clone(),
368                                        tool_name: call.tool_name.clone(),
369                                        error: err.to_string(),
370                                    };
371                                    event_handler(err_event.clone());
372                                    events.push(err_event);
373                                    break;
374                                }
375
376                                let completed_event = AgentEvent::ToolCallCompleted {
377                                    run_id: input.run_id.clone(),
378                                    session_id: input.session_id.clone(),
379                                    iteration,
380                                    result: ToolResultSummary::from(&result),
381                                };
382                                event_handler(completed_event.clone());
383                                events.push(completed_event);
384
385                                messages.push(ChatMessage::tool_result(
386                                    &result.call_id,
387                                    serde_json::to_string(&result.output)
388                                        .unwrap_or_else(|_| "{}".to_string()),
389                                ));
390                            }
391                            Err(err) => {
392                                stop_reason = RunStopReason::Error;
393                                let err_event = AgentEvent::ToolCallFailed {
394                                    run_id: input.run_id.clone(),
395                                    session_id: input.session_id.clone(),
396                                    iteration,
397                                    call_id: call.call_id.clone(),
398                                    tool_name: call.tool_name.clone(),
399                                    error: err.to_string(),
400                                };
401                                event_handler(err_event.clone());
402                                events.push(err_event);
403                                break;
404                            }
405                        }
406                    }
407                    ModelDirective::StatePatch { patch } => match state.apply_patch(&patch) {
408                        Ok(()) => {
409                            let patch_event = AgentEvent::StatePatched {
410                                run_id: input.run_id.clone(),
411                                session_id: input.session_id.clone(),
412                                iteration,
413                                patch: patch.clone(),
414                                revision: state.revision,
415                            };
416                            event_handler(patch_event.clone());
417                            events.push(patch_event);
418                        }
419                        Err(err) => {
420                            stop_reason = RunStopReason::Error;
421                            let err_event = AgentEvent::RunErrored {
422                                run_id: input.run_id.clone(),
423                                session_id: input.session_id.clone(),
424                                error: err.to_string(),
425                            };
426                            event_handler(err_event.clone());
427                            events.push(err_event);
428                            break;
429                        }
430                    },
431                    ModelDirective::FinalAnswer { text } => {
432                        final_answer = Some(text.clone());
433                        let delta_event = AgentEvent::TextDelta {
434                            run_id: input.run_id.clone(),
435                            session_id: input.session_id.clone(),
436                            iteration,
437                            delta: text.clone(),
438                        };
439                        event_handler(delta_event.clone());
440                        events.push(delta_event);
441                        messages.push(ChatMessage::assistant(text));
442                    }
443                }
444            }
445
446            if matches!(
447                stop_reason,
448                RunStopReason::Error | RunStopReason::BlockedByPolicy | RunStopReason::Cancelled
449            ) {
450                break;
451            }
452
453            match model_turn.stop_reason {
454                ModelStopReason::EndTurn => {
455                    stop_reason = RunStopReason::Completed;
456                    break;
457                }
458                ModelStopReason::NeedsUser => {
459                    stop_reason = RunStopReason::NeedsUser;
460                    break;
461                }
462                ModelStopReason::Safety => {
463                    stop_reason = RunStopReason::BlockedByPolicy;
464                    break;
465                }
466                ModelStopReason::ToolUse => {
467                    if !requested_tool {
468                        stop_reason = RunStopReason::Error;
469                        let err_event = AgentEvent::RunErrored {
470                            run_id: input.run_id.clone(),
471                            session_id: input.session_id.clone(),
472                            error: "model requested tool_use stop reason without tool call"
473                                .to_string(),
474                        };
475                        event_handler(err_event.clone());
476                        events.push(err_event);
477                        break;
478                    }
479                }
480                ModelStopReason::MaxTokens | ModelStopReason::Unknown => {
481                    if !requested_tool {
482                        stop_reason = RunStopReason::Error;
483                        let err_event = AgentEvent::RunErrored {
484                            run_id: input.run_id.clone(),
485                            session_id: input.session_id.clone(),
486                            error: "model returned non-terminal stop reason without tool call"
487                                .to_string(),
488                        };
489                        event_handler(err_event.clone());
490                        events.push(err_event);
491                        break;
492                    }
493                }
494            }
495        }
496
497        if total_iterations == self.config.max_iterations
498            && stop_reason == RunStopReason::BudgetExceeded
499        {
500            let err_event = AgentEvent::RunErrored {
501                run_id: input.run_id.clone(),
502                session_id: input.session_id.clone(),
503                error: "max iteration budget exceeded".to_string(),
504            };
505            event_handler(err_event.clone());
506            events.push(err_event);
507        }
508
509        let finished_event = AgentEvent::RunFinished {
510            run_id: input.run_id.clone(),
511            session_id: input.session_id.clone(),
512            reason: stop_reason,
513            total_iterations,
514            final_answer: final_answer.clone(),
515        };
516        event_handler(finished_event.clone());
517        events.push(finished_event);
518
519        let output = RunOutput {
520            run_id: input.run_id,
521            session_id: input.session_id,
522            events,
523            messages,
524            state,
525            reason: stop_reason,
526            final_answer,
527            total_usage,
528        };
529
530        let _ = self
531            .middlewares
532            .iter()
533            .try_for_each(|middleware| middleware.on_run_finished(&output));
534
535        output
536    }
537
538    fn run_before_model(&self, request: &ProviderRequest) -> Result<(), CoreError> {
539        self.middlewares
540            .iter()
541            .try_for_each(|middleware| middleware.before_model_call(request))
542    }
543
544    fn run_after_model(
545        &self,
546        request: &ProviderRequest,
547        response: &ModelTurn,
548    ) -> Result<(), CoreError> {
549        self.middlewares
550            .iter()
551            .try_for_each(|middleware| middleware.after_model_call(request, response))
552    }
553
554    fn run_pre_tool(&self, context: &ToolContext, call: &ToolCall) -> Result<(), CoreError> {
555        self.middlewares
556            .iter()
557            .try_for_each(|middleware| middleware.pre_tool_call(context, call))
558    }
559
560    fn run_post_tool(&self, context: &ToolContext, result: &ToolResult) -> Result<(), CoreError> {
561        self.middlewares
562            .iter()
563            .try_for_each(|middleware| middleware.post_tool_call(context, result))
564    }
565}
566
567#[cfg(test)]
568mod tests {
569    use super::*;
570    use crate::protocol::{
571        ModelDirective, ModelStopReason, ModelTurn, StatePatch, StatePatchFormat, StatePatchSource,
572    };
573    use serde_json::json;
574    use std::sync::Mutex;
575
576    struct ScriptedProvider {
577        turns: Vec<ModelTurn>,
578        cursor: Mutex<usize>,
579    }
580
581    impl Provider for ScriptedProvider {
582        fn name(&self) -> &str {
583            "scripted"
584        }
585
586        fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
587            let mut cursor = self
588                .cursor
589                .lock()
590                .map_err(|_| CoreError::Provider("scripted provider lock poisoned".to_string()))?;
591            let idx = *cursor;
592            let Some(turn) = self.turns.get(idx) else {
593                return Err(CoreError::Provider("no scripted turn left".to_string()));
594            };
595            *cursor += 1;
596            Ok(turn.clone())
597        }
598    }
599
600    struct EchoTool;
601
602    impl Tool for EchoTool {
603        fn definition(&self) -> ToolDefinition {
604            ToolDefinition {
605                name: "echo".to_string(),
606                description: "Echoes the provided value".to_string(),
607                input_schema: json!({
608                    "type": "object",
609                    "properties": { "value": { "type": "string" } },
610                    "required": ["value"]
611                }),
612                title: None,
613                output_schema: None,
614                annotations: None,
615                category: None,
616                tags: Vec::new(),
617                timeout_secs: None,
618            }
619        }
620
621        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, CoreError> {
622            let value = call
623                .input
624                .get("value")
625                .cloned()
626                .unwrap_or_else(|| json!(null));
627            Ok(ToolResult {
628                call_id: call.call_id.clone(),
629                tool_name: call.tool_name.clone(),
630                output: json!({ "echo": value.clone() }),
631                content: None,
632                is_error: false,
633                state_patch: Some(StatePatch {
634                    format: StatePatchFormat::MergePatch,
635                    patch: json!({ "last_echo": value }),
636                    source: StatePatchSource::Tool,
637                }),
638            })
639        }
640    }
641
642    #[test]
643    fn orchestrator_runs_tool_then_finishes() {
644        let provider = ScriptedProvider {
645            turns: vec![
646                ModelTurn {
647                    directives: vec![ModelDirective::ToolCall {
648                        call: ToolCall {
649                            call_id: "call-1".to_string(),
650                            tool_name: "echo".to_string(),
651                            input: json!({ "value": "hello" }),
652                        },
653                    }],
654                    stop_reason: ModelStopReason::ToolUse,
655                    usage: None,
656                },
657                ModelTurn {
658                    directives: vec![ModelDirective::FinalAnswer {
659                        text: "done".to_string(),
660                    }],
661                    stop_reason: ModelStopReason::EndTurn,
662                    usage: None,
663                },
664            ],
665            cursor: Mutex::new(0),
666        };
667
668        let mut tools = ToolRegistry::default();
669        tools.register(EchoTool);
670
671        let orchestrator = Orchestrator::new(
672            Arc::new(provider),
673            tools,
674            Vec::new(),
675            OrchestratorConfig { max_iterations: 4 },
676        );
677
678        let output = orchestrator.run(
679            RunInput {
680                run_id: "run-1".to_string(),
681                session_id: "session-1".to_string(),
682                messages: vec![ChatMessage::user("test")],
683                state: AppState::default(),
684            },
685            |_| {},
686        );
687
688        assert_eq!(output.reason, RunStopReason::Completed);
689        assert_eq!(output.final_answer.as_deref(), Some("done"));
690        assert_eq!(output.state.revision, 1);
691        assert_eq!(output.state.data["last_echo"], "hello");
692
693        assert!(output
694            .events
695            .iter()
696            .any(|event| matches!(event, AgentEvent::ToolCallCompleted { .. })));
697        assert!(output.events.iter().any(|event| matches!(
698            event,
699            AgentEvent::RunFinished {
700                reason: RunStopReason::Completed,
701                ..
702            }
703        )));
704    }
705
706    #[test]
707    fn provider_error_stops_run() {
708        struct FailProvider;
709        impl Provider for FailProvider {
710            fn name(&self) -> &str {
711                "fail"
712            }
713            fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
714                Err(CoreError::Provider("connection refused".to_string()))
715            }
716        }
717
718        let orchestrator = Orchestrator::new(
719            Arc::new(FailProvider),
720            ToolRegistry::default(),
721            Vec::new(),
722            OrchestratorConfig { max_iterations: 4 },
723        );
724
725        let output = orchestrator.run(
726            RunInput {
727                run_id: "run-1".to_string(),
728                session_id: "s1".to_string(),
729                messages: vec![ChatMessage::user("test")],
730                state: AppState::default(),
731            },
732            |_| {},
733        );
734
735        assert_eq!(output.reason, RunStopReason::Error);
736        assert!(output
737            .events
738            .iter()
739            .any(|e| matches!(e, AgentEvent::RunErrored { .. })));
740    }
741
742    #[test]
743    fn tool_not_found_stops_run() {
744        let provider = ScriptedProvider {
745            turns: vec![ModelTurn {
746                directives: vec![ModelDirective::ToolCall {
747                    call: ToolCall {
748                        call_id: "c1".to_string(),
749                        tool_name: "nonexistent".to_string(),
750                        input: json!({}),
751                    },
752                }],
753                stop_reason: ModelStopReason::ToolUse,
754                usage: None,
755            }],
756            cursor: Mutex::new(0),
757        };
758
759        let orchestrator = Orchestrator::new(
760            Arc::new(provider),
761            ToolRegistry::default(),
762            Vec::new(),
763            OrchestratorConfig { max_iterations: 4 },
764        );
765
766        let output = orchestrator.run(
767            RunInput {
768                run_id: "run-1".to_string(),
769                session_id: "s1".to_string(),
770                messages: vec![ChatMessage::user("test")],
771                state: AppState::default(),
772            },
773            |_| {},
774        );
775
776        assert_eq!(output.reason, RunStopReason::Error);
777        assert!(output
778            .events
779            .iter()
780            .any(|e| matches!(e, AgentEvent::ToolCallFailed { .. })));
781    }
782
783    #[test]
784    fn middleware_blocks_model_call() {
785        struct BlockMiddleware;
786        impl Middleware for BlockMiddleware {
787            fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
788                Err(CoreError::Middleware("blocked by policy".to_string()))
789            }
790        }
791
792        let provider = ScriptedProvider {
793            turns: vec![ModelTurn {
794                directives: vec![ModelDirective::Text {
795                    delta: "hi".to_string(),
796                }],
797                stop_reason: ModelStopReason::EndTurn,
798                usage: None,
799            }],
800            cursor: Mutex::new(0),
801        };
802
803        let orchestrator = Orchestrator::new(
804            Arc::new(provider),
805            ToolRegistry::default(),
806            vec![Arc::new(BlockMiddleware)],
807            OrchestratorConfig { max_iterations: 4 },
808        );
809
810        let output = orchestrator.run(
811            RunInput {
812                run_id: "run-1".to_string(),
813                session_id: "s1".to_string(),
814                messages: vec![ChatMessage::user("test")],
815                state: AppState::default(),
816            },
817            |_| {},
818        );
819
820        assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
821    }
822
823    #[test]
824    fn budget_exceeded_when_iterations_exhausted() {
825        // Provider always returns ToolUse but no tool call directives → continues loop
826        // Actually, we need it to keep looping. Use a tool that works, but provider
827        // always asks for more.
828        let provider = ScriptedProvider {
829            turns: vec![
830                ModelTurn {
831                    directives: vec![ModelDirective::ToolCall {
832                        call: ToolCall {
833                            call_id: "c1".to_string(),
834                            tool_name: "echo".to_string(),
835                            input: json!({"value": "1"}),
836                        },
837                    }],
838                    stop_reason: ModelStopReason::ToolUse,
839                    usage: None,
840                },
841                ModelTurn {
842                    directives: vec![ModelDirective::ToolCall {
843                        call: ToolCall {
844                            call_id: "c2".to_string(),
845                            tool_name: "echo".to_string(),
846                            input: json!({"value": "2"}),
847                        },
848                    }],
849                    stop_reason: ModelStopReason::ToolUse,
850                    usage: None,
851                },
852                // Only 2 turns, but max_iterations = 2, so it exhausts budget
853                // 3rd iteration will fail because no more scripted turns
854            ],
855            cursor: Mutex::new(0),
856        };
857
858        let mut tools = ToolRegistry::default();
859        tools.register(EchoTool);
860
861        let orchestrator = Orchestrator::new(
862            Arc::new(provider),
863            tools,
864            Vec::new(),
865            OrchestratorConfig { max_iterations: 2 },
866        );
867
868        let output = orchestrator.run(
869            RunInput {
870                run_id: "run-1".to_string(),
871                session_id: "s1".to_string(),
872                messages: vec![ChatMessage::user("test")],
873                state: AppState::default(),
874            },
875            |_| {},
876        );
877
878        assert_eq!(output.reason, RunStopReason::BudgetExceeded);
879    }
880
881    #[test]
882    fn text_only_response_completes() {
883        let provider = ScriptedProvider {
884            turns: vec![ModelTurn {
885                directives: vec![ModelDirective::Text {
886                    delta: "Hello, world!".to_string(),
887                }],
888                stop_reason: ModelStopReason::EndTurn,
889                usage: None,
890            }],
891            cursor: Mutex::new(0),
892        };
893
894        let orchestrator = Orchestrator::new(
895            Arc::new(provider),
896            ToolRegistry::default(),
897            Vec::new(),
898            OrchestratorConfig { max_iterations: 4 },
899        );
900
901        let output = orchestrator.run(
902            RunInput {
903                run_id: "run-1".to_string(),
904                session_id: "s1".to_string(),
905                messages: vec![ChatMessage::user("hi")],
906                state: AppState::default(),
907            },
908            |_| {},
909        );
910
911        assert_eq!(output.reason, RunStopReason::Completed);
912        assert!(output.messages.iter().any(|m| m.content == "Hello, world!"));
913    }
914
915    #[test]
916    fn event_handler_receives_all_events() {
917        let provider = ScriptedProvider {
918            turns: vec![ModelTurn {
919                directives: vec![ModelDirective::FinalAnswer {
920                    text: "done".to_string(),
921                }],
922                stop_reason: ModelStopReason::EndTurn,
923                usage: None,
924            }],
925            cursor: Mutex::new(0),
926        };
927
928        let orchestrator = Orchestrator::new(
929            Arc::new(provider),
930            ToolRegistry::default(),
931            Vec::new(),
932            OrchestratorConfig { max_iterations: 4 },
933        );
934
935        let received = Arc::new(Mutex::new(Vec::new()));
936        let received_clone = received.clone();
937
938        orchestrator.run(
939            RunInput {
940                run_id: "run-1".to_string(),
941                session_id: "s1".to_string(),
942                messages: vec![ChatMessage::user("test")],
943                state: AppState::default(),
944            },
945            move |event| {
946                received_clone.lock().unwrap().push(event);
947            },
948        );
949
950        let events = received.lock().unwrap();
951        assert!(events.len() >= 4); // RunStarted, IterationStarted, ModelOutput, TextDelta, RunFinished
952        assert!(matches!(events[0], AgentEvent::RunStarted { .. }));
953        assert!(matches!(
954            events.last().unwrap(),
955            AgentEvent::RunFinished { .. }
956        ));
957    }
958
959    #[test]
960    fn tool_result_includes_call_id() {
961        let provider = ScriptedProvider {
962            turns: vec![
963                ModelTurn {
964                    directives: vec![ModelDirective::ToolCall {
965                        call: ToolCall {
966                            call_id: "my-call-id".to_string(),
967                            tool_name: "echo".to_string(),
968                            input: json!({"value": "test"}),
969                        },
970                    }],
971                    stop_reason: ModelStopReason::ToolUse,
972                    usage: None,
973                },
974                ModelTurn {
975                    directives: vec![ModelDirective::FinalAnswer {
976                        text: "ok".to_string(),
977                    }],
978                    stop_reason: ModelStopReason::EndTurn,
979                    usage: None,
980                },
981            ],
982            cursor: Mutex::new(0),
983        };
984
985        let mut tools = ToolRegistry::default();
986        tools.register(EchoTool);
987
988        let orchestrator = Orchestrator::new(
989            Arc::new(provider),
990            tools,
991            Vec::new(),
992            OrchestratorConfig { max_iterations: 4 },
993        );
994
995        let output = orchestrator.run(
996            RunInput {
997                run_id: "run-1".to_string(),
998                session_id: "s1".to_string(),
999                messages: vec![ChatMessage::user("test")],
1000                state: AppState::default(),
1001            },
1002            |_| {},
1003        );
1004
1005        // Verify tool result message has the correct call_id
1006        let tool_msg = output
1007            .messages
1008            .iter()
1009            .find(|m| m.role == crate::protocol::Role::Tool)
1010            .expect("should have tool message");
1011        assert_eq!(tool_msg.tool_call_id.as_deref(), Some("my-call-id"));
1012    }
1013
1014    #[test]
1015    fn cancellation_stops_run() {
1016        let provider = ScriptedProvider {
1017            turns: vec![
1018                ModelTurn {
1019                    directives: vec![ModelDirective::ToolCall {
1020                        call: ToolCall {
1021                            call_id: "c1".to_string(),
1022                            tool_name: "echo".to_string(),
1023                            input: json!({"value": "1"}),
1024                        },
1025                    }],
1026                    stop_reason: ModelStopReason::ToolUse,
1027                    usage: None,
1028                },
1029                ModelTurn {
1030                    directives: vec![ModelDirective::FinalAnswer {
1031                        text: "should not reach".to_string(),
1032                    }],
1033                    stop_reason: ModelStopReason::EndTurn,
1034                    usage: None,
1035                },
1036            ],
1037            cursor: Mutex::new(0),
1038        };
1039
1040        let mut tools = ToolRegistry::default();
1041        tools.register(EchoTool);
1042
1043        let orchestrator = Orchestrator::new(
1044            Arc::new(provider),
1045            tools,
1046            Vec::new(),
1047            OrchestratorConfig { max_iterations: 10 },
1048        );
1049
1050        // Set cancellation flag before the second iteration
1051        let cancel = Arc::new(AtomicBool::new(false));
1052        let cancel_clone = cancel.clone();
1053        let call_count = Arc::new(Mutex::new(0u32));
1054        let call_count_clone = call_count.clone();
1055
1056        let output = orchestrator.run_cancellable(
1057            RunInput {
1058                run_id: "run-1".to_string(),
1059                session_id: "s1".to_string(),
1060                messages: vec![ChatMessage::user("test")],
1061                state: AppState::default(),
1062            },
1063            Some(cancel_clone),
1064            move |event| {
1065                // Cancel after first iteration completes
1066                if matches!(event, AgentEvent::ToolCallCompleted { .. }) {
1067                    let mut count = call_count_clone.lock().unwrap();
1068                    *count += 1;
1069                    if *count >= 1 {
1070                        cancel.store(true, Ordering::Relaxed);
1071                    }
1072                }
1073            },
1074        );
1075
1076        assert_eq!(output.reason, RunStopReason::Cancelled);
1077        // Should not have a final answer since we cancelled
1078        assert!(output.final_answer.is_none());
1079    }
1080
1081    #[test]
1082    fn token_usage_accumulated() {
1083        let provider = ScriptedProvider {
1084            turns: vec![
1085                ModelTurn {
1086                    directives: vec![ModelDirective::ToolCall {
1087                        call: ToolCall {
1088                            call_id: "c1".to_string(),
1089                            tool_name: "echo".to_string(),
1090                            input: json!({"value": "hi"}),
1091                        },
1092                    }],
1093                    stop_reason: ModelStopReason::ToolUse,
1094                    usage: Some(TokenUsage {
1095                        input_tokens: 100,
1096                        output_tokens: 50,
1097                        cache_read_tokens: 0,
1098                        cache_creation_tokens: 0,
1099                    }),
1100                },
1101                ModelTurn {
1102                    directives: vec![ModelDirective::FinalAnswer {
1103                        text: "done".to_string(),
1104                    }],
1105                    stop_reason: ModelStopReason::EndTurn,
1106                    usage: Some(TokenUsage {
1107                        input_tokens: 200,
1108                        output_tokens: 30,
1109                        cache_read_tokens: 0,
1110                        cache_creation_tokens: 0,
1111                    }),
1112                },
1113            ],
1114            cursor: Mutex::new(0),
1115        };
1116
1117        let mut tools = ToolRegistry::default();
1118        tools.register(EchoTool);
1119
1120        let orchestrator = Orchestrator::new(
1121            Arc::new(provider),
1122            tools,
1123            Vec::new(),
1124            OrchestratorConfig { max_iterations: 4 },
1125        );
1126
1127        let output = orchestrator.run(
1128            RunInput {
1129                run_id: "run-1".to_string(),
1130                session_id: "s1".to_string(),
1131                messages: vec![ChatMessage::user("test")],
1132                state: AppState::default(),
1133            },
1134            |_| {},
1135        );
1136
1137        assert_eq!(output.reason, RunStopReason::Completed);
1138        assert_eq!(output.total_usage.input_tokens, 300);
1139        assert_eq!(output.total_usage.output_tokens, 80);
1140        assert_eq!(output.total_usage.total(), 380);
1141    }
1142}