Skip to main content

arcan_core/
runtime.rs

1use crate::error::CoreError;
2use crate::protocol::{
3    AgentEvent, ChatMessage, ModelDirective, ModelStopReason, ModelTurn, RunStopReason, ToolCall,
4    ToolDefinition, ToolResult, ToolResultSummary,
5};
6use crate::state::AppState;
7use std::collections::BTreeMap;
8use std::sync::Arc;
9
/// Everything a [`Provider`] receives for one model call.
///
/// The orchestrator builds a fresh request at the start of every iteration
/// from the current conversation, the registered tools and the current state.
#[derive(Debug, Clone)]
pub struct ProviderRequest {
    /// Identifier of the run this call belongs to.
    pub run_id: String,
    /// Identifier of the session this call belongs to.
    pub session_id: String,
    /// Iteration counter of the run; the orchestrator starts counting at 1.
    pub iteration: u32,
    /// Snapshot of the conversation at the time of the call.
    pub messages: Vec<ChatMessage>,
    /// Definitions of all tools registered with the orchestrator.
    pub tools: Vec<ToolDefinition>,
    /// Snapshot of the application state at the time of the call.
    pub state: AppState,
}
19
/// A model backend that turns a [`ProviderRequest`] into a [`ModelTurn`].
pub trait Provider: Send + Sync {
    /// Name of the provider; the orchestrator reports it in the
    /// `RunStarted` event.
    fn name(&self) -> &str;

    /// Produces the model's turn for `request`.
    ///
    /// # Errors
    ///
    /// Any error returned here makes the orchestrator emit a `RunErrored`
    /// event and stop the run with `RunStopReason::Error`.
    fn complete(&self, request: &ProviderRequest) -> Result<ModelTurn, CoreError>;
}
24
/// Identifies where in a run a tool call is being executed.
///
/// Passed to [`Tool::execute`] and to the tool-related [`Middleware`] hooks.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Identifier of the run that issued the tool call.
    pub run_id: String,
    /// Identifier of the session that issued the tool call.
    pub session_id: String,
    /// Iteration of the run in which the tool call was requested.
    pub iteration: u32,
}
31
/// A callable tool the model can request through a `ToolCall` directive.
pub trait Tool: Send + Sync {
    /// Describes the tool. The definition's `name` is the key under which the
    /// tool is stored in a [`ToolRegistry`], and the definition is forwarded to
    /// the provider in every [`ProviderRequest`].
    fn definition(&self) -> ToolDefinition;

    /// Executes `call` in the given context.
    ///
    /// # Errors
    ///
    /// Any error returned here makes the orchestrator emit a `ToolCallFailed`
    /// event and stop the run with `RunStopReason::Error`.
    fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, CoreError>;
}
36
/// Hooks invoked by the orchestrator around model calls and tool calls.
///
/// Every hook defaults to a no-op returning `Ok(())`, so implementors only
/// override the hooks they care about. Middlewares run in registration order.
pub trait Middleware: Send + Sync {
    /// Called before each provider call. Returning an error stops the run
    /// with `RunStopReason::BlockedByPolicy`.
    fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called after a successful provider call, before its directives are
    /// processed. Returning an error stops the run with
    /// `RunStopReason::BlockedByPolicy`.
    fn after_model_call(
        &self,
        _request: &ProviderRequest,
        _response: &ModelTurn,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called before a requested tool is looked up and executed. Returning an
    /// error stops the run with `RunStopReason::BlockedByPolicy` and the tool
    /// is not executed.
    fn pre_tool_call(&self, _context: &ToolContext, _call: &ToolCall) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called after a tool executed successfully and its state patch (if any)
    /// was applied. Returning an error stops the run with
    /// `RunStopReason::BlockedByPolicy`.
    fn post_tool_call(
        &self,
        _context: &ToolContext,
        _result: &ToolResult,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// Called once with the final output of a run. The orchestrator discards
    /// the returned result, so an error here cannot change the run outcome.
    fn on_run_finished(&self, _output: &RunOutput) -> Result<(), CoreError> {
        Ok(())
    }
}
66
/// Collection of tools keyed by their definition name.
///
/// Cloning the registry is cheap with respect to the tools themselves: the
/// tools are shared through `Arc`.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    // Ordered map, so iteration (and therefore `definitions()`) follows the
    // sorted order of tool names.
    tools: BTreeMap<String, Arc<dyn Tool>>,
}
71
72impl ToolRegistry {
73    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
74        self.tools
75            .insert(tool.definition().name.clone(), Arc::new(tool));
76    }
77
78    pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
79        self.tools.get(tool_name).cloned()
80    }
81
82    pub fn definitions(&self) -> Vec<ToolDefinition> {
83        self.tools.values().map(|tool| tool.definition()).collect()
84    }
85}
86
/// Tunables for an orchestrator run.
#[derive(Debug, Clone)]
pub struct OrchestratorConfig {
    /// Upper bound on the number of model iterations in a single run.
    pub max_iterations: u32,
}

impl Default for OrchestratorConfig {
    /// Returns a configuration that allows up to 24 iterations per run.
    fn default() -> Self {
        let max_iterations = 24;
        OrchestratorConfig { max_iterations }
    }
}
97
/// Input for a single orchestrator run.
#[derive(Debug, Clone)]
pub struct RunInput {
    /// Identifier stamped on every event and request of the run.
    pub run_id: String,
    /// Identifier of the session the run belongs to.
    pub session_id: String,
    /// Conversation the run starts from; the orchestrator appends to it.
    pub messages: Vec<ChatMessage>,
    /// Application state the run starts from; patches are applied on top.
    pub state: AppState,
}
105
/// Result of a single orchestrator run.
#[derive(Debug, Clone)]
pub struct RunOutput {
    /// Identifier of the run, copied from the input.
    pub run_id: String,
    /// Identifier of the session, copied from the input.
    pub session_id: String,
    /// Every event emitted during the run, in emission order. The same events
    /// were also delivered to the event handler as they happened.
    pub events: Vec<AgentEvent>,
    /// Input conversation plus the assistant texts and tool results appended
    /// during the run.
    pub messages: Vec<ChatMessage>,
    /// Application state after all successfully applied patches.
    pub state: AppState,
    /// Why the run stopped.
    pub reason: RunStopReason,
    /// Text of the most recent `FinalAnswer` directive, if the model sent one.
    pub final_answer: Option<String>,
}
116
/// Drives the model/tool loop: calls the provider, executes requested tools,
/// applies state patches and reports progress as `AgentEvent`s.
pub struct Orchestrator {
    // Model backend queried once per iteration.
    provider: Arc<dyn Provider>,
    // Tools the model may call, looked up by name.
    tools: ToolRegistry,
    // Hooks run, in order, around model calls and tool calls.
    middlewares: Vec<Arc<dyn Middleware>>,
    // Run limits (iteration budget).
    config: OrchestratorConfig,
}
123
124impl Orchestrator {
125    pub fn new(
126        provider: Arc<dyn Provider>,
127        tools: ToolRegistry,
128        middlewares: Vec<Arc<dyn Middleware>>,
129        config: OrchestratorConfig,
130    ) -> Self {
131        Self {
132            provider,
133            tools,
134            middlewares,
135            config,
136        }
137    }
138
139    pub fn run(&self, input: RunInput, mut event_handler: impl FnMut(AgentEvent)) -> RunOutput {
140        let mut events = Vec::new();
141        let mut messages = input.messages;
142        let mut state = input.state;
143        let mut final_answer: Option<String> = None;
144        let mut stop_reason = RunStopReason::BudgetExceeded;
145        let mut total_iterations = 0;
146
147        let start_event = AgentEvent::RunStarted {
148            run_id: input.run_id.clone(),
149            session_id: input.session_id.clone(),
150            provider: self.provider.name().to_string(),
151            max_iterations: self.config.max_iterations,
152        };
153        event_handler(start_event.clone());
154        events.push(start_event);
155
156        for iteration in 1..=self.config.max_iterations {
157            total_iterations = iteration;
158            let iter_event = AgentEvent::IterationStarted {
159                run_id: input.run_id.clone(),
160                session_id: input.session_id.clone(),
161                iteration,
162            };
163            event_handler(iter_event.clone());
164            events.push(iter_event);
165
166            let provider_request = ProviderRequest {
167                run_id: input.run_id.clone(),
168                session_id: input.session_id.clone(),
169                iteration,
170                messages: messages.clone(),
171                tools: self.tools.definitions(),
172                state: state.clone(),
173            };
174
175            if let Err(err) = self.run_before_model(&provider_request) {
176                stop_reason = RunStopReason::BlockedByPolicy;
177                let err_event = AgentEvent::RunErrored {
178                    run_id: input.run_id.clone(),
179                    session_id: input.session_id.clone(),
180                    error: err.to_string(),
181                };
182                event_handler(err_event.clone());
183                events.push(err_event);
184                break;
185            }
186
187            let model_turn = match self.provider.complete(&provider_request) {
188                Ok(turn) => turn,
189                Err(err) => {
190                    stop_reason = RunStopReason::Error;
191                    let err_event = AgentEvent::RunErrored {
192                        run_id: input.run_id.clone(),
193                        session_id: input.session_id.clone(),
194                        error: err.to_string(),
195                    };
196                    event_handler(err_event.clone());
197                    events.push(err_event);
198                    break;
199                }
200            };
201
202            if let Err(err) = self.run_after_model(&provider_request, &model_turn) {
203                stop_reason = RunStopReason::BlockedByPolicy;
204                let err_event = AgentEvent::RunErrored {
205                    run_id: input.run_id.clone(),
206                    session_id: input.session_id.clone(),
207                    error: err.to_string(),
208                };
209                event_handler(err_event.clone());
210                events.push(err_event);
211                break;
212            }
213
214            let output_event = AgentEvent::ModelOutput {
215                run_id: input.run_id.clone(),
216                session_id: input.session_id.clone(),
217                iteration,
218                stop_reason: model_turn.stop_reason,
219                directive_count: model_turn.directives.len(),
220            };
221            event_handler(output_event.clone());
222            events.push(output_event);
223
224            let mut requested_tool = false;
225
226            for directive in model_turn.directives {
227                match directive {
228                    ModelDirective::Text { delta } => {
229                        let delta_event = AgentEvent::TextDelta {
230                            run_id: input.run_id.clone(),
231                            session_id: input.session_id.clone(),
232                            iteration,
233                            delta: delta.clone(),
234                        };
235                        event_handler(delta_event.clone());
236                        events.push(delta_event);
237                        messages.push(ChatMessage::assistant(delta));
238                    }
239                    ModelDirective::ToolCall { call } => {
240                        requested_tool = true;
241                        let tc_event = AgentEvent::ToolCallRequested {
242                            run_id: input.run_id.clone(),
243                            session_id: input.session_id.clone(),
244                            iteration,
245                            call: call.clone(),
246                        };
247                        event_handler(tc_event.clone());
248                        events.push(tc_event);
249
250                        let context = ToolContext {
251                            run_id: input.run_id.clone(),
252                            session_id: input.session_id.clone(),
253                            iteration,
254                        };
255
256                        if let Err(err) = self.run_pre_tool(&context, &call) {
257                            stop_reason = RunStopReason::BlockedByPolicy;
258                            let err_event = AgentEvent::ToolCallFailed {
259                                run_id: input.run_id.clone(),
260                                session_id: input.session_id.clone(),
261                                iteration,
262                                call_id: call.call_id.clone(),
263                                tool_name: call.tool_name.clone(),
264                                error: err.to_string(),
265                            };
266                            event_handler(err_event.clone());
267                            events.push(err_event);
268                            break;
269                        }
270
271                        let Some(tool) = self.tools.get(&call.tool_name) else {
272                            stop_reason = RunStopReason::Error;
273                            let err_event = AgentEvent::ToolCallFailed {
274                                run_id: input.run_id.clone(),
275                                session_id: input.session_id.clone(),
276                                iteration,
277                                call_id: call.call_id.clone(),
278                                tool_name: call.tool_name.clone(),
279                                error: format!(
280                                    "{}",
281                                    CoreError::ToolNotFound {
282                                        tool_name: call.tool_name.clone(),
283                                    }
284                                ),
285                            };
286                            event_handler(err_event.clone());
287                            events.push(err_event);
288                            break;
289                        };
290
291                        match tool.execute(&call, &context) {
292                            Ok(result) => {
293                                if let Some(patch) = &result.state_patch {
294                                    match state.apply_patch(patch) {
295                                        Ok(()) => {
296                                            let patch_event = AgentEvent::StatePatched {
297                                                run_id: input.run_id.clone(),
298                                                session_id: input.session_id.clone(),
299                                                iteration,
300                                                patch: patch.clone(),
301                                                revision: state.revision,
302                                            };
303                                            event_handler(patch_event.clone());
304                                            events.push(patch_event);
305                                        }
306                                        Err(err) => {
307                                            stop_reason = RunStopReason::Error;
308                                            let err_event = AgentEvent::ToolCallFailed {
309                                                run_id: input.run_id.clone(),
310                                                session_id: input.session_id.clone(),
311                                                iteration,
312                                                call_id: call.call_id.clone(),
313                                                tool_name: call.tool_name.clone(),
314                                                error: err.to_string(),
315                                            };
316                                            event_handler(err_event.clone());
317                                            events.push(err_event);
318                                            break;
319                                        }
320                                    }
321                                }
322
323                                if let Err(err) = self.run_post_tool(&context, &result) {
324                                    stop_reason = RunStopReason::BlockedByPolicy;
325                                    let err_event = AgentEvent::ToolCallFailed {
326                                        run_id: input.run_id.clone(),
327                                        session_id: input.session_id.clone(),
328                                        iteration,
329                                        call_id: call.call_id.clone(),
330                                        tool_name: call.tool_name.clone(),
331                                        error: err.to_string(),
332                                    };
333                                    event_handler(err_event.clone());
334                                    events.push(err_event);
335                                    break;
336                                }
337
338                                let completed_event = AgentEvent::ToolCallCompleted {
339                                    run_id: input.run_id.clone(),
340                                    session_id: input.session_id.clone(),
341                                    iteration,
342                                    result: ToolResultSummary::from(&result),
343                                };
344                                event_handler(completed_event.clone());
345                                events.push(completed_event);
346
347                                messages.push(ChatMessage::tool_result(
348                                    &result.call_id,
349                                    serde_json::to_string(&result.output)
350                                        .unwrap_or_else(|_| "{}".to_string()),
351                                ));
352                            }
353                            Err(err) => {
354                                stop_reason = RunStopReason::Error;
355                                let err_event = AgentEvent::ToolCallFailed {
356                                    run_id: input.run_id.clone(),
357                                    session_id: input.session_id.clone(),
358                                    iteration,
359                                    call_id: call.call_id.clone(),
360                                    tool_name: call.tool_name.clone(),
361                                    error: err.to_string(),
362                                };
363                                event_handler(err_event.clone());
364                                events.push(err_event);
365                                break;
366                            }
367                        }
368                    }
369                    ModelDirective::StatePatch { patch } => match state.apply_patch(&patch) {
370                        Ok(()) => {
371                            let patch_event = AgentEvent::StatePatched {
372                                run_id: input.run_id.clone(),
373                                session_id: input.session_id.clone(),
374                                iteration,
375                                patch: patch.clone(),
376                                revision: state.revision,
377                            };
378                            event_handler(patch_event.clone());
379                            events.push(patch_event);
380                        }
381                        Err(err) => {
382                            stop_reason = RunStopReason::Error;
383                            let err_event = AgentEvent::RunErrored {
384                                run_id: input.run_id.clone(),
385                                session_id: input.session_id.clone(),
386                                error: err.to_string(),
387                            };
388                            event_handler(err_event.clone());
389                            events.push(err_event);
390                            break;
391                        }
392                    },
393                    ModelDirective::FinalAnswer { text } => {
394                        final_answer = Some(text.clone());
395                        let delta_event = AgentEvent::TextDelta {
396                            run_id: input.run_id.clone(),
397                            session_id: input.session_id.clone(),
398                            iteration,
399                            delta: text.clone(),
400                        };
401                        event_handler(delta_event.clone());
402                        events.push(delta_event);
403                        messages.push(ChatMessage::assistant(text));
404                    }
405                }
406            }
407
408            if matches!(
409                stop_reason,
410                RunStopReason::Error | RunStopReason::BlockedByPolicy
411            ) {
412                break;
413            }
414
415            match model_turn.stop_reason {
416                ModelStopReason::EndTurn => {
417                    stop_reason = RunStopReason::Completed;
418                    break;
419                }
420                ModelStopReason::NeedsUser => {
421                    stop_reason = RunStopReason::NeedsUser;
422                    break;
423                }
424                ModelStopReason::Safety => {
425                    stop_reason = RunStopReason::BlockedByPolicy;
426                    break;
427                }
428                ModelStopReason::ToolUse => {
429                    if !requested_tool {
430                        stop_reason = RunStopReason::Error;
431                        let err_event = AgentEvent::RunErrored {
432                            run_id: input.run_id.clone(),
433                            session_id: input.session_id.clone(),
434                            error: "model requested tool_use stop reason without tool call"
435                                .to_string(),
436                        };
437                        event_handler(err_event.clone());
438                        events.push(err_event);
439                        break;
440                    }
441                }
442                ModelStopReason::MaxTokens | ModelStopReason::Unknown => {
443                    if !requested_tool {
444                        stop_reason = RunStopReason::Error;
445                        let err_event = AgentEvent::RunErrored {
446                            run_id: input.run_id.clone(),
447                            session_id: input.session_id.clone(),
448                            error: "model returned non-terminal stop reason without tool call"
449                                .to_string(),
450                        };
451                        event_handler(err_event.clone());
452                        events.push(err_event);
453                        break;
454                    }
455                }
456            }
457        }
458
459        if total_iterations == self.config.max_iterations
460            && stop_reason == RunStopReason::BudgetExceeded
461        {
462            let err_event = AgentEvent::RunErrored {
463                run_id: input.run_id.clone(),
464                session_id: input.session_id.clone(),
465                error: "max iteration budget exceeded".to_string(),
466            };
467            event_handler(err_event.clone());
468            events.push(err_event);
469        }
470
471        let finished_event = AgentEvent::RunFinished {
472            run_id: input.run_id.clone(),
473            session_id: input.session_id.clone(),
474            reason: stop_reason,
475            total_iterations,
476            final_answer: final_answer.clone(),
477        };
478        event_handler(finished_event.clone());
479        events.push(finished_event);
480
481        let output = RunOutput {
482            run_id: input.run_id,
483            session_id: input.session_id,
484            events,
485            messages,
486            state,
487            reason: stop_reason,
488            final_answer,
489        };
490
491        let _ = self
492            .middlewares
493            .iter()
494            .try_for_each(|middleware| middleware.on_run_finished(&output));
495
496        output
497    }
498
499    fn run_before_model(&self, request: &ProviderRequest) -> Result<(), CoreError> {
500        self.middlewares
501            .iter()
502            .try_for_each(|middleware| middleware.before_model_call(request))
503    }
504
505    fn run_after_model(
506        &self,
507        request: &ProviderRequest,
508        response: &ModelTurn,
509    ) -> Result<(), CoreError> {
510        self.middlewares
511            .iter()
512            .try_for_each(|middleware| middleware.after_model_call(request, response))
513    }
514
515    fn run_pre_tool(&self, context: &ToolContext, call: &ToolCall) -> Result<(), CoreError> {
516        self.middlewares
517            .iter()
518            .try_for_each(|middleware| middleware.pre_tool_call(context, call))
519    }
520
521    fn run_post_tool(&self, context: &ToolContext, result: &ToolResult) -> Result<(), CoreError> {
522        self.middlewares
523            .iter()
524            .try_for_each(|middleware| middleware.post_tool_call(context, result))
525    }
526}
527
528#[cfg(test)]
529mod tests {
530    use super::*;
531    use crate::protocol::{
532        ModelDirective, ModelStopReason, ModelTurn, StatePatch, StatePatchFormat, StatePatchSource,
533    };
534    use serde_json::json;
535    use std::sync::Mutex;
536
537    struct ScriptedProvider {
538        turns: Vec<ModelTurn>,
539        cursor: Mutex<usize>,
540    }
541
542    impl Provider for ScriptedProvider {
543        fn name(&self) -> &str {
544            "scripted"
545        }
546
547        fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
548            let mut cursor = self
549                .cursor
550                .lock()
551                .map_err(|_| CoreError::Provider("scripted provider lock poisoned".to_string()))?;
552            let idx = *cursor;
553            let Some(turn) = self.turns.get(idx) else {
554                return Err(CoreError::Provider("no scripted turn left".to_string()));
555            };
556            *cursor += 1;
557            Ok(turn.clone())
558        }
559    }
560
561    struct EchoTool;
562
563    impl Tool for EchoTool {
564        fn definition(&self) -> ToolDefinition {
565            ToolDefinition {
566                name: "echo".to_string(),
567                description: "Echoes the provided value".to_string(),
568                input_schema: json!({
569                    "type": "object",
570                    "properties": { "value": { "type": "string" } },
571                    "required": ["value"]
572                }),
573                title: None,
574                output_schema: None,
575                annotations: None,
576                category: None,
577                tags: Vec::new(),
578                timeout_secs: None,
579            }
580        }
581
582        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, CoreError> {
583            let value = call
584                .input
585                .get("value")
586                .cloned()
587                .unwrap_or_else(|| json!(null));
588            Ok(ToolResult {
589                call_id: call.call_id.clone(),
590                tool_name: call.tool_name.clone(),
591                output: json!({ "echo": value.clone() }),
592                content: None,
593                is_error: false,
594                state_patch: Some(StatePatch {
595                    format: StatePatchFormat::MergePatch,
596                    patch: json!({ "last_echo": value }),
597                    source: StatePatchSource::Tool,
598                }),
599            })
600        }
601    }
602
603    #[test]
604    fn orchestrator_runs_tool_then_finishes() {
605        let provider = ScriptedProvider {
606            turns: vec![
607                ModelTurn {
608                    directives: vec![ModelDirective::ToolCall {
609                        call: ToolCall {
610                            call_id: "call-1".to_string(),
611                            tool_name: "echo".to_string(),
612                            input: json!({ "value": "hello" }),
613                        },
614                    }],
615                    stop_reason: ModelStopReason::ToolUse,
616                },
617                ModelTurn {
618                    directives: vec![ModelDirective::FinalAnswer {
619                        text: "done".to_string(),
620                    }],
621                    stop_reason: ModelStopReason::EndTurn,
622                },
623            ],
624            cursor: Mutex::new(0),
625        };
626
627        let mut tools = ToolRegistry::default();
628        tools.register(EchoTool);
629
630        let orchestrator = Orchestrator::new(
631            Arc::new(provider),
632            tools,
633            Vec::new(),
634            OrchestratorConfig { max_iterations: 4 },
635        );
636
637        let output = orchestrator.run(
638            RunInput {
639                run_id: "run-1".to_string(),
640                session_id: "session-1".to_string(),
641                messages: vec![ChatMessage::user("test")],
642                state: AppState::default(),
643            },
644            |_| {},
645        );
646
647        assert_eq!(output.reason, RunStopReason::Completed);
648        assert_eq!(output.final_answer.as_deref(), Some("done"));
649        assert_eq!(output.state.revision, 1);
650        assert_eq!(output.state.data["last_echo"], "hello");
651
652        assert!(output
653            .events
654            .iter()
655            .any(|event| matches!(event, AgentEvent::ToolCallCompleted { .. })));
656        assert!(output.events.iter().any(|event| matches!(
657            event,
658            AgentEvent::RunFinished {
659                reason: RunStopReason::Completed,
660                ..
661            }
662        )));
663    }
664
665    #[test]
666    fn provider_error_stops_run() {
667        struct FailProvider;
668        impl Provider for FailProvider {
669            fn name(&self) -> &str {
670                "fail"
671            }
672            fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
673                Err(CoreError::Provider("connection refused".to_string()))
674            }
675        }
676
677        let orchestrator = Orchestrator::new(
678            Arc::new(FailProvider),
679            ToolRegistry::default(),
680            Vec::new(),
681            OrchestratorConfig { max_iterations: 4 },
682        );
683
684        let output = orchestrator.run(
685            RunInput {
686                run_id: "run-1".to_string(),
687                session_id: "s1".to_string(),
688                messages: vec![ChatMessage::user("test")],
689                state: AppState::default(),
690            },
691            |_| {},
692        );
693
694        assert_eq!(output.reason, RunStopReason::Error);
695        assert!(output
696            .events
697            .iter()
698            .any(|e| matches!(e, AgentEvent::RunErrored { .. })));
699    }
700
701    #[test]
702    fn tool_not_found_stops_run() {
703        let provider = ScriptedProvider {
704            turns: vec![ModelTurn {
705                directives: vec![ModelDirective::ToolCall {
706                    call: ToolCall {
707                        call_id: "c1".to_string(),
708                        tool_name: "nonexistent".to_string(),
709                        input: json!({}),
710                    },
711                }],
712                stop_reason: ModelStopReason::ToolUse,
713            }],
714            cursor: Mutex::new(0),
715        };
716
717        let orchestrator = Orchestrator::new(
718            Arc::new(provider),
719            ToolRegistry::default(),
720            Vec::new(),
721            OrchestratorConfig { max_iterations: 4 },
722        );
723
724        let output = orchestrator.run(
725            RunInput {
726                run_id: "run-1".to_string(),
727                session_id: "s1".to_string(),
728                messages: vec![ChatMessage::user("test")],
729                state: AppState::default(),
730            },
731            |_| {},
732        );
733
734        assert_eq!(output.reason, RunStopReason::Error);
735        assert!(output
736            .events
737            .iter()
738            .any(|e| matches!(e, AgentEvent::ToolCallFailed { .. })));
739    }
740
741    #[test]
742    fn middleware_blocks_model_call() {
743        struct BlockMiddleware;
744        impl Middleware for BlockMiddleware {
745            fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
746                Err(CoreError::Middleware("blocked by policy".to_string()))
747            }
748        }
749
750        let provider = ScriptedProvider {
751            turns: vec![ModelTurn {
752                directives: vec![ModelDirective::Text {
753                    delta: "hi".to_string(),
754                }],
755                stop_reason: ModelStopReason::EndTurn,
756            }],
757            cursor: Mutex::new(0),
758        };
759
760        let orchestrator = Orchestrator::new(
761            Arc::new(provider),
762            ToolRegistry::default(),
763            vec![Arc::new(BlockMiddleware)],
764            OrchestratorConfig { max_iterations: 4 },
765        );
766
767        let output = orchestrator.run(
768            RunInput {
769                run_id: "run-1".to_string(),
770                session_id: "s1".to_string(),
771                messages: vec![ChatMessage::user("test")],
772                state: AppState::default(),
773            },
774            |_| {},
775        );
776
777        assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
778    }
779
780    #[test]
781    fn budget_exceeded_when_iterations_exhausted() {
782        // Provider always returns ToolUse but no tool call directives → continues loop
783        // Actually, we need it to keep looping. Use a tool that works, but provider
784        // always asks for more.
785        let provider = ScriptedProvider {
786            turns: vec![
787                ModelTurn {
788                    directives: vec![ModelDirective::ToolCall {
789                        call: ToolCall {
790                            call_id: "c1".to_string(),
791                            tool_name: "echo".to_string(),
792                            input: json!({"value": "1"}),
793                        },
794                    }],
795                    stop_reason: ModelStopReason::ToolUse,
796                },
797                ModelTurn {
798                    directives: vec![ModelDirective::ToolCall {
799                        call: ToolCall {
800                            call_id: "c2".to_string(),
801                            tool_name: "echo".to_string(),
802                            input: json!({"value": "2"}),
803                        },
804                    }],
805                    stop_reason: ModelStopReason::ToolUse,
806                },
807                // Only 2 turns, but max_iterations = 2, so it exhausts budget
808                // 3rd iteration will fail because no more scripted turns
809            ],
810            cursor: Mutex::new(0),
811        };
812
813        let mut tools = ToolRegistry::default();
814        tools.register(EchoTool);
815
816        let orchestrator = Orchestrator::new(
817            Arc::new(provider),
818            tools,
819            Vec::new(),
820            OrchestratorConfig { max_iterations: 2 },
821        );
822
823        let output = orchestrator.run(
824            RunInput {
825                run_id: "run-1".to_string(),
826                session_id: "s1".to_string(),
827                messages: vec![ChatMessage::user("test")],
828                state: AppState::default(),
829            },
830            |_| {},
831        );
832
833        assert_eq!(output.reason, RunStopReason::BudgetExceeded);
834    }
835
836    #[test]
837    fn text_only_response_completes() {
838        let provider = ScriptedProvider {
839            turns: vec![ModelTurn {
840                directives: vec![ModelDirective::Text {
841                    delta: "Hello, world!".to_string(),
842                }],
843                stop_reason: ModelStopReason::EndTurn,
844            }],
845            cursor: Mutex::new(0),
846        };
847
848        let orchestrator = Orchestrator::new(
849            Arc::new(provider),
850            ToolRegistry::default(),
851            Vec::new(),
852            OrchestratorConfig { max_iterations: 4 },
853        );
854
855        let output = orchestrator.run(
856            RunInput {
857                run_id: "run-1".to_string(),
858                session_id: "s1".to_string(),
859                messages: vec![ChatMessage::user("hi")],
860                state: AppState::default(),
861            },
862            |_| {},
863        );
864
865        assert_eq!(output.reason, RunStopReason::Completed);
866        assert!(output.messages.iter().any(|m| m.content == "Hello, world!"));
867    }
868
869    #[test]
870    fn event_handler_receives_all_events() {
871        let provider = ScriptedProvider {
872            turns: vec![ModelTurn {
873                directives: vec![ModelDirective::FinalAnswer {
874                    text: "done".to_string(),
875                }],
876                stop_reason: ModelStopReason::EndTurn,
877            }],
878            cursor: Mutex::new(0),
879        };
880
881        let orchestrator = Orchestrator::new(
882            Arc::new(provider),
883            ToolRegistry::default(),
884            Vec::new(),
885            OrchestratorConfig { max_iterations: 4 },
886        );
887
888        let received = Arc::new(Mutex::new(Vec::new()));
889        let received_clone = received.clone();
890
891        orchestrator.run(
892            RunInput {
893                run_id: "run-1".to_string(),
894                session_id: "s1".to_string(),
895                messages: vec![ChatMessage::user("test")],
896                state: AppState::default(),
897            },
898            move |event| {
899                received_clone.lock().unwrap().push(event);
900            },
901        );
902
903        let events = received.lock().unwrap();
904        assert!(events.len() >= 4); // RunStarted, IterationStarted, ModelOutput, TextDelta, RunFinished
905        assert!(matches!(events[0], AgentEvent::RunStarted { .. }));
906        assert!(matches!(
907            events.last().unwrap(),
908            AgentEvent::RunFinished { .. }
909        ));
910    }
911
912    #[test]
913    fn tool_result_includes_call_id() {
914        let provider = ScriptedProvider {
915            turns: vec![
916                ModelTurn {
917                    directives: vec![ModelDirective::ToolCall {
918                        call: ToolCall {
919                            call_id: "my-call-id".to_string(),
920                            tool_name: "echo".to_string(),
921                            input: json!({"value": "test"}),
922                        },
923                    }],
924                    stop_reason: ModelStopReason::ToolUse,
925                },
926                ModelTurn {
927                    directives: vec![ModelDirective::FinalAnswer {
928                        text: "ok".to_string(),
929                    }],
930                    stop_reason: ModelStopReason::EndTurn,
931                },
932            ],
933            cursor: Mutex::new(0),
934        };
935
936        let mut tools = ToolRegistry::default();
937        tools.register(EchoTool);
938
939        let orchestrator = Orchestrator::new(
940            Arc::new(provider),
941            tools,
942            Vec::new(),
943            OrchestratorConfig { max_iterations: 4 },
944        );
945
946        let output = orchestrator.run(
947            RunInput {
948                run_id: "run-1".to_string(),
949                session_id: "s1".to_string(),
950                messages: vec![ChatMessage::user("test")],
951                state: AppState::default(),
952            },
953            |_| {},
954        );
955
956        // Verify tool result message has the correct call_id
957        let tool_msg = output
958            .messages
959            .iter()
960            .find(|m| m.role == crate::protocol::Role::Tool)
961            .expect("should have tool message");
962        assert_eq!(tool_msg.tool_call_id.as_deref(), Some("my-call-id"));
963    }
964}