Skip to main content

arcan_core/
runtime.rs

1use crate::context::{ContextConfig, compact_messages};
2use crate::error::CoreError;
3use crate::protocol::{
4    AgentEvent, ChatMessage, ModelDirective, ModelStopReason, ModelTurn, RunStopReason, TokenUsage,
5    ToolCall, ToolDefinition, ToolResult, ToolResultSummary,
6};
7use crate::state::AppState;
8use std::collections::BTreeMap;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11
12/// Hook trait for approval gates, allowing the agent loop to wire event handlers
13/// into the gate without depending on the concrete `ApprovalGate` type.
14pub trait ApprovalGateHook: Send + Sync {
15    fn set_event_handler(&self, handler: Arc<dyn Fn(AgentEvent) + Send + Sync>);
16    fn clear_event_handler(&self);
17}
18
/// Entry point used by HTTP endpoints to settle approvals that are still pending.
pub trait ApprovalResolver: Send + Sync {
    /// Apply `decision` (with an optional free-form `reason`) to the approval
    /// identified by `approval_id`.
    ///
    /// NOTE(review): the returned flag presumably reports whether a matching
    /// pending approval was found and resolved — confirm against the implementor.
    fn resolve_approval(&self, approval_id: &str, decision: &str, reason: Option<String>) -> bool;

    /// Identifiers of the approvals that are currently pending.
    fn pending_approval_ids(&self) -> Vec<String>;
}
24
25#[derive(Debug, Clone)]
26pub struct ProviderRequest {
27    pub run_id: String,
28    pub session_id: String,
29    pub iteration: u32,
30    pub messages: Vec<ChatMessage>,
31    pub tools: Vec<ToolDefinition>,
32    pub state: AppState,
33}
34
35pub trait Provider: Send + Sync {
36    fn name(&self) -> &str;
37    fn complete(&self, request: &ProviderRequest) -> Result<ModelTurn, CoreError>;
38
39    /// Whether this provider supports streaming completions.
40    fn supports_streaming(&self) -> bool {
41        false
42    }
43
44    /// Stream a completion, calling `on_text` for each text delta as it arrives.
45    /// Returns the final assembled `ModelTurn`. Default falls back to `complete()`.
46    fn complete_streaming(
47        &self,
48        request: &ProviderRequest,
49        _on_text: &dyn Fn(&str),
50    ) -> Result<ModelTurn, CoreError> {
51        self.complete(request)
52    }
53}
54
/// Identifies where in a run a tool invocation happens.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Identifier of the run issuing the tool call.
    pub run_id: String,
    /// Identifier of the session issuing the tool call.
    pub session_id: String,
    /// Orchestrator loop iteration in which the call was requested.
    pub iteration: u32,
}
61
62pub trait Tool: Send + Sync {
63    fn definition(&self) -> ToolDefinition;
64    fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, CoreError>;
65}
66
67pub trait Middleware: Send + Sync {
68    fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
69        Ok(())
70    }
71
72    fn after_model_call(
73        &self,
74        _request: &ProviderRequest,
75        _response: &ModelTurn,
76    ) -> Result<(), CoreError> {
77        Ok(())
78    }
79
80    fn pre_tool_call(&self, _context: &ToolContext, _call: &ToolCall) -> Result<(), CoreError> {
81        Ok(())
82    }
83
84    fn post_tool_call(
85        &self,
86        _context: &ToolContext,
87        _result: &ToolResult,
88    ) -> Result<(), CoreError> {
89        Ok(())
90    }
91
92    fn on_run_finished(&self, _output: &RunOutput) -> Result<(), CoreError> {
93        Ok(())
94    }
95}
96
97#[derive(Clone, Default)]
98pub struct ToolRegistry {
99    tools: BTreeMap<String, Arc<dyn Tool>>,
100}
101
102impl ToolRegistry {
103    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
104        self.tools
105            .insert(tool.definition().name.clone(), Arc::new(tool));
106    }
107
108    pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
109        self.tools.get(tool_name).cloned()
110    }
111
112    pub fn definitions(&self) -> Vec<ToolDefinition> {
113        self.tools.values().map(|tool| tool.definition()).collect()
114    }
115}
116
117#[derive(Debug, Clone)]
118pub struct OrchestratorConfig {
119    pub max_iterations: u32,
120    /// Context window management configuration.
121    /// When set, messages are compacted before each provider call to stay within limits.
122    pub context: Option<ContextConfig>,
123    /// Context compiler configuration for assembling system prompts from typed blocks.
124    /// When set, context blocks are compiled into system messages with budget enforcement.
125    pub context_compiler: Option<crate::context_compiler::ContextCompilerConfig>,
126}
127
128impl Default for OrchestratorConfig {
129    fn default() -> Self {
130        Self {
131            max_iterations: 24,
132            context: Some(ContextConfig::default()),
133            context_compiler: None,
134        }
135    }
136}
137
138#[derive(Debug, Clone)]
139pub struct RunInput {
140    pub run_id: String,
141    pub session_id: String,
142    pub branch_id: String,
143    pub messages: Vec<ChatMessage>,
144    pub state: AppState,
145}
146
147#[derive(Debug, Clone)]
148pub struct RunOutput {
149    pub run_id: String,
150    pub session_id: String,
151    pub branch_id: String,
152    pub events: Vec<AgentEvent>,
153    pub messages: Vec<ChatMessage>,
154    pub state: AppState,
155    pub reason: RunStopReason,
156    pub final_answer: Option<String>,
157    /// Accumulated token usage across all iterations.
158    pub total_usage: TokenUsage,
159}
160
161pub struct Orchestrator {
162    provider: Arc<std::sync::RwLock<Arc<dyn Provider>>>,
163    tools: ToolRegistry,
164    middlewares: Vec<Arc<dyn Middleware>>,
165    config: OrchestratorConfig,
166}
167
168impl Orchestrator {
169    pub fn new(
170        provider: Arc<dyn Provider>,
171        tools: ToolRegistry,
172        middlewares: Vec<Arc<dyn Middleware>>,
173        config: OrchestratorConfig,
174    ) -> Self {
175        Self {
176            provider: Arc::new(std::sync::RwLock::new(provider)),
177            tools,
178            middlewares,
179            config,
180        }
181    }
182
183    /// Swap the active provider at runtime. Returns the name of the new provider.
184    pub fn swap_provider(&self, new_provider: Arc<dyn Provider>) -> String {
185        let name = new_provider.name().to_string();
186        let mut guard = self.provider.write().expect("provider lock poisoned");
187        *guard = new_provider;
188        name
189    }
190
191    /// Get the current provider name.
192    pub fn provider_name(&self) -> String {
193        let guard = self.provider.read().expect("provider lock poisoned");
194        guard.name().to_string()
195    }
196
197    pub fn run(&self, input: RunInput, event_handler: impl FnMut(AgentEvent)) -> RunOutput {
198        self.run_cancellable(input, None, event_handler)
199    }
200
201    /// Run the orchestrator loop with an optional cancellation flag.
202    ///
203    /// If `cancel` is provided and set to `true` during execution,
204    /// the loop will stop at the next iteration boundary.
205    pub fn run_cancellable(
206        &self,
207        input: RunInput,
208        cancel: Option<&Arc<AtomicBool>>,
209        mut event_handler: impl FnMut(AgentEvent),
210    ) -> RunOutput {
211        let mut events = Vec::new();
212        let mut messages = input.messages;
213        let mut state = input.state;
214        let mut final_answer: Option<String> = None;
215        let mut stop_reason = RunStopReason::BudgetExceeded;
216        let mut total_iterations = 0;
217        let mut total_usage = TokenUsage::default();
218
219        // Acquire provider reference for this run
220        let provider = self
221            .provider
222            .read()
223            .expect("provider lock poisoned")
224            .clone();
225
226        let start_event = AgentEvent::RunStarted {
227            run_id: input.run_id.clone(),
228            session_id: input.session_id.clone(),
229            provider: provider.name().to_string(),
230            max_iterations: self.config.max_iterations,
231        };
232        event_handler(start_event.clone());
233        events.push(start_event);
234
235        for iteration in 1..=self.config.max_iterations {
236            // Check cancellation at each iteration boundary
237            if let Some(flag) = cancel {
238                if flag.load(Ordering::Relaxed) {
239                    stop_reason = RunStopReason::Cancelled;
240                    let err_event = AgentEvent::RunErrored {
241                        run_id: input.run_id.clone(),
242                        session_id: input.session_id.clone(),
243                        error: "run cancelled".to_string(),
244                    };
245                    event_handler(err_event.clone());
246                    events.push(err_event);
247                    break;
248                }
249            }
250
251            total_iterations = iteration;
252            let iter_event = AgentEvent::IterationStarted {
253                run_id: input.run_id.clone(),
254                session_id: input.session_id.clone(),
255                iteration,
256            };
257            event_handler(iter_event.clone());
258            events.push(iter_event);
259
260            // Context window compaction: trim messages before sending to provider
261            if let Some(ref ctx_config) = self.config.context {
262                if let Some(result) = compact_messages(&messages, ctx_config) {
263                    let compact_event = AgentEvent::ContextCompacted {
264                        run_id: input.run_id.clone(),
265                        session_id: input.session_id.clone(),
266                        iteration,
267                        dropped_count: result.dropped_count,
268                        tokens_before: result.tokens_before,
269                        tokens_after: result.tokens_after,
270                    };
271                    event_handler(compact_event.clone());
272                    events.push(compact_event);
273                    messages = result.messages;
274                }
275            }
276
277            let provider_request = ProviderRequest {
278                run_id: input.run_id.clone(),
279                session_id: input.session_id.clone(),
280                iteration,
281                messages: messages.clone(),
282                tools: self.tools.definitions(),
283                state: state.clone(),
284            };
285
286            if let Err(err) = self.run_before_model(&provider_request) {
287                stop_reason = RunStopReason::BlockedByPolicy;
288                let err_event = AgentEvent::RunErrored {
289                    run_id: input.run_id.clone(),
290                    session_id: input.session_id.clone(),
291                    error: err.to_string(),
292                };
293                event_handler(err_event.clone());
294                events.push(err_event);
295                break;
296            }
297
298            let model_turn = match provider.complete(&provider_request) {
299                Ok(turn) => turn,
300                Err(err) => {
301                    stop_reason = RunStopReason::Error;
302                    let err_event = AgentEvent::RunErrored {
303                        run_id: input.run_id.clone(),
304                        session_id: input.session_id.clone(),
305                        error: err.to_string(),
306                    };
307                    event_handler(err_event.clone());
308                    events.push(err_event);
309                    break;
310                }
311            };
312
313            if let Err(err) = self.run_after_model(&provider_request, &model_turn) {
314                stop_reason = RunStopReason::BlockedByPolicy;
315                let err_event = AgentEvent::RunErrored {
316                    run_id: input.run_id.clone(),
317                    session_id: input.session_id.clone(),
318                    error: err.to_string(),
319                };
320                event_handler(err_event.clone());
321                events.push(err_event);
322                break;
323            }
324
325            // Accumulate token usage if reported
326            if let Some(ref usage) = model_turn.usage {
327                total_usage.accumulate(usage);
328            }
329
330            let output_event = AgentEvent::ModelOutput {
331                run_id: input.run_id.clone(),
332                session_id: input.session_id.clone(),
333                iteration,
334                stop_reason: model_turn.stop_reason,
335                directive_count: model_turn.directives.len(),
336                usage: model_turn.usage,
337            };
338            event_handler(output_event.clone());
339            events.push(output_event);
340
341            let mut requested_tool = false;
342
343            for directive in model_turn.directives {
344                match directive {
345                    ModelDirective::Text { delta } => {
346                        let delta_event = AgentEvent::TextDelta {
347                            run_id: input.run_id.clone(),
348                            session_id: input.session_id.clone(),
349                            iteration,
350                            delta: delta.clone(),
351                        };
352                        event_handler(delta_event.clone());
353                        events.push(delta_event);
354                        messages.push(ChatMessage::assistant(delta));
355                    }
356                    ModelDirective::ToolCall { call } => {
357                        requested_tool = true;
358                        let tc_event = AgentEvent::ToolCallRequested {
359                            run_id: input.run_id.clone(),
360                            session_id: input.session_id.clone(),
361                            iteration,
362                            call: call.clone(),
363                        };
364                        event_handler(tc_event.clone());
365                        events.push(tc_event);
366
367                        let context = ToolContext {
368                            run_id: input.run_id.clone(),
369                            session_id: input.session_id.clone(),
370                            iteration,
371                        };
372
373                        if let Err(err) = self.run_pre_tool(&context, &call) {
374                            stop_reason = RunStopReason::BlockedByPolicy;
375                            let err_event = AgentEvent::ToolCallFailed {
376                                run_id: input.run_id.clone(),
377                                session_id: input.session_id.clone(),
378                                iteration,
379                                call_id: call.call_id.clone(),
380                                tool_name: call.tool_name.clone(),
381                                error: err.to_string(),
382                            };
383                            event_handler(err_event.clone());
384                            events.push(err_event);
385                            break;
386                        }
387
388                        let Some(tool) = self.tools.get(&call.tool_name) else {
389                            stop_reason = RunStopReason::Error;
390                            let err_event = AgentEvent::ToolCallFailed {
391                                run_id: input.run_id.clone(),
392                                session_id: input.session_id.clone(),
393                                iteration,
394                                call_id: call.call_id.clone(),
395                                tool_name: call.tool_name.clone(),
396                                error: format!(
397                                    "{}",
398                                    CoreError::ToolNotFound {
399                                        tool_name: call.tool_name.clone(),
400                                    }
401                                ),
402                            };
403                            event_handler(err_event.clone());
404                            events.push(err_event);
405                            break;
406                        };
407
408                        match tool.execute(&call, &context) {
409                            Ok(result) => {
410                                if let Some(patch) = &result.state_patch {
411                                    match state.apply_patch(patch) {
412                                        Ok(()) => {
413                                            let patch_event = AgentEvent::StatePatched {
414                                                run_id: input.run_id.clone(),
415                                                session_id: input.session_id.clone(),
416                                                iteration,
417                                                patch: patch.clone(),
418                                                revision: state.revision,
419                                            };
420                                            event_handler(patch_event.clone());
421                                            events.push(patch_event);
422                                        }
423                                        Err(err) => {
424                                            stop_reason = RunStopReason::Error;
425                                            let err_event = AgentEvent::ToolCallFailed {
426                                                run_id: input.run_id.clone(),
427                                                session_id: input.session_id.clone(),
428                                                iteration,
429                                                call_id: call.call_id.clone(),
430                                                tool_name: call.tool_name.clone(),
431                                                error: err.to_string(),
432                                            };
433                                            event_handler(err_event.clone());
434                                            events.push(err_event);
435                                            break;
436                                        }
437                                    }
438                                }
439
440                                if let Err(err) = self.run_post_tool(&context, &result) {
441                                    stop_reason = RunStopReason::BlockedByPolicy;
442                                    let err_event = AgentEvent::ToolCallFailed {
443                                        run_id: input.run_id.clone(),
444                                        session_id: input.session_id.clone(),
445                                        iteration,
446                                        call_id: call.call_id.clone(),
447                                        tool_name: call.tool_name.clone(),
448                                        error: err.to_string(),
449                                    };
450                                    event_handler(err_event.clone());
451                                    events.push(err_event);
452                                    break;
453                                }
454
455                                let completed_event = AgentEvent::ToolCallCompleted {
456                                    run_id: input.run_id.clone(),
457                                    session_id: input.session_id.clone(),
458                                    iteration,
459                                    result: ToolResultSummary::from(&result),
460                                };
461                                event_handler(completed_event.clone());
462                                events.push(completed_event);
463
464                                messages.push(ChatMessage::tool_result(
465                                    &result.call_id,
466                                    serde_json::to_string(&result.output)
467                                        .unwrap_or_else(|_| "{}".to_string()),
468                                ));
469                            }
470                            Err(err) => {
471                                stop_reason = RunStopReason::Error;
472                                let err_event = AgentEvent::ToolCallFailed {
473                                    run_id: input.run_id.clone(),
474                                    session_id: input.session_id.clone(),
475                                    iteration,
476                                    call_id: call.call_id.clone(),
477                                    tool_name: call.tool_name.clone(),
478                                    error: err.to_string(),
479                                };
480                                event_handler(err_event.clone());
481                                events.push(err_event);
482                                break;
483                            }
484                        }
485                    }
486                    ModelDirective::StatePatch { patch } => match state.apply_patch(&patch) {
487                        Ok(()) => {
488                            let patch_event = AgentEvent::StatePatched {
489                                run_id: input.run_id.clone(),
490                                session_id: input.session_id.clone(),
491                                iteration,
492                                patch: patch.clone(),
493                                revision: state.revision,
494                            };
495                            event_handler(patch_event.clone());
496                            events.push(patch_event);
497                        }
498                        Err(err) => {
499                            stop_reason = RunStopReason::Error;
500                            let err_event = AgentEvent::RunErrored {
501                                run_id: input.run_id.clone(),
502                                session_id: input.session_id.clone(),
503                                error: err.to_string(),
504                            };
505                            event_handler(err_event.clone());
506                            events.push(err_event);
507                            break;
508                        }
509                    },
510                    ModelDirective::FinalAnswer { text } => {
511                        final_answer = Some(text.clone());
512                        let delta_event = AgentEvent::TextDelta {
513                            run_id: input.run_id.clone(),
514                            session_id: input.session_id.clone(),
515                            iteration,
516                            delta: text.clone(),
517                        };
518                        event_handler(delta_event.clone());
519                        events.push(delta_event);
520                        messages.push(ChatMessage::assistant(text));
521                    }
522                }
523            }
524
525            if matches!(
526                stop_reason,
527                RunStopReason::Error | RunStopReason::BlockedByPolicy | RunStopReason::Cancelled
528            ) {
529                break;
530            }
531
532            match model_turn.stop_reason {
533                ModelStopReason::EndTurn => {
534                    stop_reason = RunStopReason::Completed;
535                    break;
536                }
537                ModelStopReason::NeedsUser => {
538                    stop_reason = RunStopReason::NeedsUser;
539                    break;
540                }
541                ModelStopReason::Safety => {
542                    stop_reason = RunStopReason::BlockedByPolicy;
543                    break;
544                }
545                ModelStopReason::ToolUse => {
546                    if !requested_tool {
547                        stop_reason = RunStopReason::Error;
548                        let err_event = AgentEvent::RunErrored {
549                            run_id: input.run_id.clone(),
550                            session_id: input.session_id.clone(),
551                            error: "model requested tool_use stop reason without tool call"
552                                .to_string(),
553                        };
554                        event_handler(err_event.clone());
555                        events.push(err_event);
556                        break;
557                    }
558                }
559                ModelStopReason::MaxTokens | ModelStopReason::Unknown => {
560                    if !requested_tool {
561                        stop_reason = RunStopReason::Error;
562                        let err_event = AgentEvent::RunErrored {
563                            run_id: input.run_id.clone(),
564                            session_id: input.session_id.clone(),
565                            error: "model returned non-terminal stop reason without tool call"
566                                .to_string(),
567                        };
568                        event_handler(err_event.clone());
569                        events.push(err_event);
570                        break;
571                    }
572                }
573            }
574        }
575
576        if total_iterations == self.config.max_iterations
577            && stop_reason == RunStopReason::BudgetExceeded
578        {
579            let err_event = AgentEvent::RunErrored {
580                run_id: input.run_id.clone(),
581                session_id: input.session_id.clone(),
582                error: "max iteration budget exceeded".to_string(),
583            };
584            event_handler(err_event.clone());
585            events.push(err_event);
586        }
587
588        let finished_event = AgentEvent::RunFinished {
589            run_id: input.run_id.clone(),
590            session_id: input.session_id.clone(),
591            reason: stop_reason,
592            total_iterations,
593            final_answer: final_answer.clone(),
594        };
595        event_handler(finished_event.clone());
596        events.push(finished_event);
597
598        let output = RunOutput {
599            run_id: input.run_id,
600            session_id: input.session_id,
601            branch_id: input.branch_id,
602            events,
603            messages,
604            state,
605            reason: stop_reason,
606            final_answer,
607            total_usage,
608        };
609
610        let _ = self
611            .middlewares
612            .iter()
613            .try_for_each(|middleware| middleware.on_run_finished(&output));
614
615        output
616    }
617
618    fn run_before_model(&self, request: &ProviderRequest) -> Result<(), CoreError> {
619        self.middlewares
620            .iter()
621            .try_for_each(|middleware| middleware.before_model_call(request))
622    }
623
624    fn run_after_model(
625        &self,
626        request: &ProviderRequest,
627        response: &ModelTurn,
628    ) -> Result<(), CoreError> {
629        self.middlewares
630            .iter()
631            .try_for_each(|middleware| middleware.after_model_call(request, response))
632    }
633
634    fn run_pre_tool(&self, context: &ToolContext, call: &ToolCall) -> Result<(), CoreError> {
635        self.middlewares
636            .iter()
637            .try_for_each(|middleware| middleware.pre_tool_call(context, call))
638    }
639
640    fn run_post_tool(&self, context: &ToolContext, result: &ToolResult) -> Result<(), CoreError> {
641        self.middlewares
642            .iter()
643            .try_for_each(|middleware| middleware.post_tool_call(context, result))
644    }
645}
646
647#[cfg(test)]
648mod tests {
649    use super::*;
650    use crate::protocol::{
651        ModelDirective, ModelStopReason, ModelTurn, StatePatch, StatePatchFormat, StatePatchSource,
652    };
653    use serde_json::json;
654    use std::sync::Mutex;
655
656    struct ScriptedProvider {
657        turns: Vec<ModelTurn>,
658        cursor: Mutex<usize>,
659    }
660
661    impl Provider for ScriptedProvider {
662        fn name(&self) -> &str {
663            "scripted"
664        }
665
666        fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
667            let mut cursor = self
668                .cursor
669                .lock()
670                .map_err(|_| CoreError::Provider("scripted provider lock poisoned".to_string()))?;
671            let idx = *cursor;
672            let Some(turn) = self.turns.get(idx) else {
673                return Err(CoreError::Provider("no scripted turn left".to_string()));
674            };
675            *cursor += 1;
676            Ok(turn.clone())
677        }
678    }
679
680    struct EchoTool;
681
682    impl Tool for EchoTool {
683        fn definition(&self) -> ToolDefinition {
684            ToolDefinition {
685                name: "echo".to_string(),
686                description: "Echoes the provided value".to_string(),
687                input_schema: json!({
688                    "type": "object",
689                    "properties": { "value": { "type": "string" } },
690                    "required": ["value"]
691                }),
692                title: None,
693                output_schema: None,
694                annotations: None,
695                category: None,
696                tags: Vec::new(),
697                timeout_secs: None,
698            }
699        }
700
701        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, CoreError> {
702            let value = call.input.get("value").cloned().unwrap_or(json!(null));
703            Ok(ToolResult {
704                call_id: call.call_id.clone(),
705                tool_name: call.tool_name.clone(),
706                output: json!({ "echo": value.clone() }),
707                content: None,
708                is_error: false,
709                state_patch: Some(StatePatch {
710                    format: StatePatchFormat::MergePatch,
711                    patch: json!({ "last_echo": value }),
712                    source: StatePatchSource::Tool,
713                }),
714            })
715        }
716    }
717
718    #[test]
719    fn orchestrator_runs_tool_then_finishes() {
720        let provider = ScriptedProvider {
721            turns: vec![
722                ModelTurn {
723                    directives: vec![ModelDirective::ToolCall {
724                        call: ToolCall {
725                            call_id: "call-1".to_string(),
726                            tool_name: "echo".to_string(),
727                            input: json!({ "value": "hello" }),
728                        },
729                    }],
730                    stop_reason: ModelStopReason::ToolUse,
731                    usage: None,
732                },
733                ModelTurn {
734                    directives: vec![ModelDirective::FinalAnswer {
735                        text: "done".to_string(),
736                    }],
737                    stop_reason: ModelStopReason::EndTurn,
738                    usage: None,
739                },
740            ],
741            cursor: Mutex::new(0),
742        };
743
744        let mut tools = ToolRegistry::default();
745        tools.register(EchoTool);
746
747        let orchestrator = Orchestrator::new(
748            Arc::new(provider),
749            tools,
750            Vec::new(),
751            OrchestratorConfig {
752                max_iterations: 4,
753                context: None,
754                context_compiler: None,
755            },
756        );
757
758        let output = orchestrator.run(
759            RunInput {
760                run_id: "run-1".to_string(),
761                session_id: "session-1".to_string(),
762                branch_id: "main".to_string(),
763                messages: vec![ChatMessage::user("test")],
764                state: AppState::default(),
765            },
766            |_| {},
767        );
768
769        assert_eq!(output.reason, RunStopReason::Completed);
770        assert_eq!(output.final_answer.as_deref(), Some("done"));
771        assert_eq!(output.state.revision, 1);
772        assert_eq!(output.state.data["last_echo"], "hello");
773
774        assert!(
775            output
776                .events
777                .iter()
778                .any(|event| matches!(event, AgentEvent::ToolCallCompleted { .. }))
779        );
780        assert!(output.events.iter().any(|event| matches!(
781            event,
782            AgentEvent::RunFinished {
783                reason: RunStopReason::Completed,
784                ..
785            }
786        )));
787    }
788
789    #[test]
790    fn provider_error_stops_run() {
791        struct FailProvider;
792        impl Provider for FailProvider {
793            fn name(&self) -> &str {
794                "fail"
795            }
796            fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
797                Err(CoreError::Provider("connection refused".to_string()))
798            }
799        }
800
801        let orchestrator = Orchestrator::new(
802            Arc::new(FailProvider),
803            ToolRegistry::default(),
804            Vec::new(),
805            OrchestratorConfig {
806                max_iterations: 4,
807                context: None,
808                context_compiler: None,
809            },
810        );
811
812        let output = orchestrator.run(
813            RunInput {
814                run_id: "run-1".to_string(),
815                session_id: "s1".to_string(),
816                branch_id: "main".to_string(),
817                messages: vec![ChatMessage::user("test")],
818                state: AppState::default(),
819            },
820            |_| {},
821        );
822
823        assert_eq!(output.reason, RunStopReason::Error);
824        assert!(
825            output
826                .events
827                .iter()
828                .any(|e| matches!(e, AgentEvent::RunErrored { .. }))
829        );
830    }
831
832    #[test]
833    fn tool_not_found_stops_run() {
834        let provider = ScriptedProvider {
835            turns: vec![ModelTurn {
836                directives: vec![ModelDirective::ToolCall {
837                    call: ToolCall {
838                        call_id: "c1".to_string(),
839                        tool_name: "nonexistent".to_string(),
840                        input: json!({}),
841                    },
842                }],
843                stop_reason: ModelStopReason::ToolUse,
844                usage: None,
845            }],
846            cursor: Mutex::new(0),
847        };
848
849        let orchestrator = Orchestrator::new(
850            Arc::new(provider),
851            ToolRegistry::default(),
852            Vec::new(),
853            OrchestratorConfig {
854                max_iterations: 4,
855                context: None,
856                context_compiler: None,
857            },
858        );
859
860        let output = orchestrator.run(
861            RunInput {
862                run_id: "run-1".to_string(),
863                session_id: "s1".to_string(),
864                branch_id: "main".to_string(),
865                messages: vec![ChatMessage::user("test")],
866                state: AppState::default(),
867            },
868            |_| {},
869        );
870
871        assert_eq!(output.reason, RunStopReason::Error);
872        assert!(
873            output
874                .events
875                .iter()
876                .any(|e| matches!(e, AgentEvent::ToolCallFailed { .. }))
877        );
878    }
879
880    #[test]
881    fn middleware_blocks_model_call() {
882        struct BlockMiddleware;
883        impl Middleware for BlockMiddleware {
884            fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
885                Err(CoreError::Middleware("blocked by policy".to_string()))
886            }
887        }
888
889        let provider = ScriptedProvider {
890            turns: vec![ModelTurn {
891                directives: vec![ModelDirective::Text {
892                    delta: "hi".to_string(),
893                }],
894                stop_reason: ModelStopReason::EndTurn,
895                usage: None,
896            }],
897            cursor: Mutex::new(0),
898        };
899
900        let orchestrator = Orchestrator::new(
901            Arc::new(provider),
902            ToolRegistry::default(),
903            vec![Arc::new(BlockMiddleware)],
904            OrchestratorConfig {
905                max_iterations: 4,
906                context: None,
907                context_compiler: None,
908            },
909        );
910
911        let output = orchestrator.run(
912            RunInput {
913                run_id: "run-1".to_string(),
914                session_id: "s1".to_string(),
915                branch_id: "main".to_string(),
916                messages: vec![ChatMessage::user("test")],
917                state: AppState::default(),
918            },
919            |_| {},
920        );
921
922        assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
923    }
924
925    #[test]
926    fn budget_exceeded_when_iterations_exhausted() {
927        // Provider always returns ToolUse but no tool call directives → continues loop
928        // Actually, we need it to keep looping. Use a tool that works, but provider
929        // always asks for more.
930        let provider = ScriptedProvider {
931            turns: vec![
932                ModelTurn {
933                    directives: vec![ModelDirective::ToolCall {
934                        call: ToolCall {
935                            call_id: "c1".to_string(),
936                            tool_name: "echo".to_string(),
937                            input: json!({"value": "1"}),
938                        },
939                    }],
940                    stop_reason: ModelStopReason::ToolUse,
941                    usage: None,
942                },
943                ModelTurn {
944                    directives: vec![ModelDirective::ToolCall {
945                        call: ToolCall {
946                            call_id: "c2".to_string(),
947                            tool_name: "echo".to_string(),
948                            input: json!({"value": "2"}),
949                        },
950                    }],
951                    stop_reason: ModelStopReason::ToolUse,
952                    usage: None,
953                },
954                // Only 2 turns, but max_iterations = 2, so it exhausts budget
955                // 3rd iteration will fail because no more scripted turns
956            ],
957            cursor: Mutex::new(0),
958        };
959
960        let mut tools = ToolRegistry::default();
961        tools.register(EchoTool);
962
963        let orchestrator = Orchestrator::new(
964            Arc::new(provider),
965            tools,
966            Vec::new(),
967            OrchestratorConfig {
968                max_iterations: 2,
969                context: None,
970                context_compiler: None,
971            },
972        );
973
974        let output = orchestrator.run(
975            RunInput {
976                run_id: "run-1".to_string(),
977                session_id: "s1".to_string(),
978                branch_id: "main".to_string(),
979                messages: vec![ChatMessage::user("test")],
980                state: AppState::default(),
981            },
982            |_| {},
983        );
984
985        assert_eq!(output.reason, RunStopReason::BudgetExceeded);
986    }
987
988    #[test]
989    fn text_only_response_completes() {
990        let provider = ScriptedProvider {
991            turns: vec![ModelTurn {
992                directives: vec![ModelDirective::Text {
993                    delta: "Hello, world!".to_string(),
994                }],
995                stop_reason: ModelStopReason::EndTurn,
996                usage: None,
997            }],
998            cursor: Mutex::new(0),
999        };
1000
1001        let orchestrator = Orchestrator::new(
1002            Arc::new(provider),
1003            ToolRegistry::default(),
1004            Vec::new(),
1005            OrchestratorConfig {
1006                max_iterations: 4,
1007                context: None,
1008                context_compiler: None,
1009            },
1010        );
1011
1012        let output = orchestrator.run(
1013            RunInput {
1014                run_id: "run-1".to_string(),
1015                session_id: "s1".to_string(),
1016                branch_id: "main".to_string(),
1017                messages: vec![ChatMessage::user("hi")],
1018                state: AppState::default(),
1019            },
1020            |_| {},
1021        );
1022
1023        assert_eq!(output.reason, RunStopReason::Completed);
1024        assert!(output.messages.iter().any(|m| m.content == "Hello, world!"));
1025    }
1026
1027    #[test]
1028    fn event_handler_receives_all_events() {
1029        let provider = ScriptedProvider {
1030            turns: vec![ModelTurn {
1031                directives: vec![ModelDirective::FinalAnswer {
1032                    text: "done".to_string(),
1033                }],
1034                stop_reason: ModelStopReason::EndTurn,
1035                usage: None,
1036            }],
1037            cursor: Mutex::new(0),
1038        };
1039
1040        let orchestrator = Orchestrator::new(
1041            Arc::new(provider),
1042            ToolRegistry::default(),
1043            Vec::new(),
1044            OrchestratorConfig {
1045                max_iterations: 4,
1046                context: None,
1047                context_compiler: None,
1048            },
1049        );
1050
1051        let received = Arc::new(Mutex::new(Vec::new()));
1052        let received_clone = received.clone();
1053
1054        orchestrator.run(
1055            RunInput {
1056                run_id: "run-1".to_string(),
1057                session_id: "s1".to_string(),
1058                branch_id: "main".to_string(),
1059                messages: vec![ChatMessage::user("test")],
1060                state: AppState::default(),
1061            },
1062            move |event| {
1063                received_clone.lock().unwrap().push(event);
1064            },
1065        );
1066
1067        let events = received.lock().unwrap();
1068        assert!(events.len() >= 4); // RunStarted, IterationStarted, ModelOutput, TextDelta, RunFinished
1069        assert!(matches!(events[0], AgentEvent::RunStarted { .. }));
1070        assert!(matches!(
1071            events.last().unwrap(),
1072            AgentEvent::RunFinished { .. }
1073        ));
1074    }
1075
1076    #[test]
1077    fn tool_result_includes_call_id() {
1078        let provider = ScriptedProvider {
1079            turns: vec![
1080                ModelTurn {
1081                    directives: vec![ModelDirective::ToolCall {
1082                        call: ToolCall {
1083                            call_id: "my-call-id".to_string(),
1084                            tool_name: "echo".to_string(),
1085                            input: json!({"value": "test"}),
1086                        },
1087                    }],
1088                    stop_reason: ModelStopReason::ToolUse,
1089                    usage: None,
1090                },
1091                ModelTurn {
1092                    directives: vec![ModelDirective::FinalAnswer {
1093                        text: "ok".to_string(),
1094                    }],
1095                    stop_reason: ModelStopReason::EndTurn,
1096                    usage: None,
1097                },
1098            ],
1099            cursor: Mutex::new(0),
1100        };
1101
1102        let mut tools = ToolRegistry::default();
1103        tools.register(EchoTool);
1104
1105        let orchestrator = Orchestrator::new(
1106            Arc::new(provider),
1107            tools,
1108            Vec::new(),
1109            OrchestratorConfig {
1110                max_iterations: 4,
1111                context: None,
1112                context_compiler: None,
1113            },
1114        );
1115
1116        let output = orchestrator.run(
1117            RunInput {
1118                run_id: "run-1".to_string(),
1119                session_id: "s1".to_string(),
1120                branch_id: "main".to_string(),
1121                messages: vec![ChatMessage::user("test")],
1122                state: AppState::default(),
1123            },
1124            |_| {},
1125        );
1126
1127        // Verify tool result message has the correct call_id
1128        let tool_msg = output
1129            .messages
1130            .iter()
1131            .find(|m| m.role == crate::protocol::Role::Tool)
1132            .expect("should have tool message");
1133        assert_eq!(tool_msg.tool_call_id.as_deref(), Some("my-call-id"));
1134    }
1135
1136    #[test]
1137    fn cancellation_stops_run() {
1138        let provider = ScriptedProvider {
1139            turns: vec![
1140                ModelTurn {
1141                    directives: vec![ModelDirective::ToolCall {
1142                        call: ToolCall {
1143                            call_id: "c1".to_string(),
1144                            tool_name: "echo".to_string(),
1145                            input: json!({"value": "1"}),
1146                        },
1147                    }],
1148                    stop_reason: ModelStopReason::ToolUse,
1149                    usage: None,
1150                },
1151                ModelTurn {
1152                    directives: vec![ModelDirective::FinalAnswer {
1153                        text: "should not reach".to_string(),
1154                    }],
1155                    stop_reason: ModelStopReason::EndTurn,
1156                    usage: None,
1157                },
1158            ],
1159            cursor: Mutex::new(0),
1160        };
1161
1162        let mut tools = ToolRegistry::default();
1163        tools.register(EchoTool);
1164
1165        let orchestrator = Orchestrator::new(
1166            Arc::new(provider),
1167            tools,
1168            Vec::new(),
1169            OrchestratorConfig {
1170                max_iterations: 10,
1171                context: None,
1172                context_compiler: None,
1173            },
1174        );
1175
1176        // Set cancellation flag before the second iteration
1177        let cancel = Arc::new(AtomicBool::new(false));
1178        let cancel_clone = cancel.clone();
1179        let call_count = Arc::new(Mutex::new(0u32));
1180        let call_count_clone = call_count.clone();
1181
1182        let output = orchestrator.run_cancellable(
1183            RunInput {
1184                run_id: "run-1".to_string(),
1185                session_id: "s1".to_string(),
1186                branch_id: "main".to_string(),
1187                messages: vec![ChatMessage::user("test")],
1188                state: AppState::default(),
1189            },
1190            Some(&cancel_clone),
1191            move |event| {
1192                // Cancel after first iteration completes
1193                if matches!(event, AgentEvent::ToolCallCompleted { .. }) {
1194                    let mut count = call_count_clone.lock().unwrap();
1195                    *count += 1;
1196                    if *count >= 1 {
1197                        cancel.store(true, Ordering::Relaxed);
1198                    }
1199                }
1200            },
1201        );
1202
1203        assert_eq!(output.reason, RunStopReason::Cancelled);
1204        // Should not have a final answer since we cancelled
1205        assert!(output.final_answer.is_none());
1206    }
1207
1208    #[test]
1209    fn token_usage_accumulated() {
1210        let provider = ScriptedProvider {
1211            turns: vec![
1212                ModelTurn {
1213                    directives: vec![ModelDirective::ToolCall {
1214                        call: ToolCall {
1215                            call_id: "c1".to_string(),
1216                            tool_name: "echo".to_string(),
1217                            input: json!({"value": "hi"}),
1218                        },
1219                    }],
1220                    stop_reason: ModelStopReason::ToolUse,
1221                    usage: Some(TokenUsage {
1222                        input_tokens: 100,
1223                        output_tokens: 50,
1224                        cache_read_tokens: 0,
1225                        cache_creation_tokens: 0,
1226                    }),
1227                },
1228                ModelTurn {
1229                    directives: vec![ModelDirective::FinalAnswer {
1230                        text: "done".to_string(),
1231                    }],
1232                    stop_reason: ModelStopReason::EndTurn,
1233                    usage: Some(TokenUsage {
1234                        input_tokens: 200,
1235                        output_tokens: 30,
1236                        cache_read_tokens: 0,
1237                        cache_creation_tokens: 0,
1238                    }),
1239                },
1240            ],
1241            cursor: Mutex::new(0),
1242        };
1243
1244        let mut tools = ToolRegistry::default();
1245        tools.register(EchoTool);
1246
1247        let orchestrator = Orchestrator::new(
1248            Arc::new(provider),
1249            tools,
1250            Vec::new(),
1251            OrchestratorConfig {
1252                max_iterations: 4,
1253                context: None,
1254                context_compiler: None,
1255            },
1256        );
1257
1258        let output = orchestrator.run(
1259            RunInput {
1260                run_id: "run-1".to_string(),
1261                session_id: "s1".to_string(),
1262                branch_id: "main".to_string(),
1263                messages: vec![ChatMessage::user("test")],
1264                state: AppState::default(),
1265            },
1266            |_| {},
1267        );
1268
1269        assert_eq!(output.reason, RunStopReason::Completed);
1270        assert_eq!(output.total_usage.input_tokens, 300);
1271        assert_eq!(output.total_usage.output_tokens, 80);
1272        assert_eq!(output.total_usage.total(), 380);
1273    }
1274}