// arcan_core/runtime.rs

1use crate::context::{ContextConfig, compact_messages};
2use crate::error::CoreError;
3use crate::lifecycle::LifecycleHook;
4use crate::protocol::{
5    AgentEvent, ChatMessage, ModelDirective, ModelStopReason, ModelTurn, RunStopReason, TokenUsage,
6    ToolCall, ToolDefinition, ToolResult, ToolResultSummary,
7};
8use crate::state::AppState;
9use std::collections::BTreeMap;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicBool, Ordering};
12
/// Events emitted during streaming completions.
///
/// Every variant only borrows its delta as a `&str`, so the enum is `Copy`
/// and can be compared directly (useful when asserting on emitted deltas).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamEvent<'a> {
    /// Regular text content delta.
    Text(&'a str),
    /// Reasoning/thinking content delta (model's chain-of-thought).
    Reasoning(&'a str),
}
21
22/// A shared, swappable provider handle.
23///
24/// Uses `std::sync::RwLock` (not tokio) because providers use blocking I/O
25/// via `spawn_blocking`. Multiple readers can access concurrently; writes
26/// only hold the lock for the instant of swapping the inner `Arc`.
27pub type SwappableProviderHandle = Arc<std::sync::RwLock<Arc<dyn Provider>>>;
28
29/// Factory for building providers from a spec string at runtime.
30///
31/// The spec format is `"provider_name"` or `"provider_name:model"`.
32pub trait ProviderFactory: Send + Sync {
33    /// Build a new provider from a spec like `"anthropic"` or `"ollama:llama3.2"`.
34    fn build(&self, spec: &str) -> Result<Arc<dyn Provider>, CoreError>;
35
36    /// List the available provider names (e.g. `["anthropic", "openai", "ollama", "mock"]`).
37    fn available_providers(&self) -> Vec<String>;
38}
39
40/// Hook trait for approval gates, allowing the agent loop to wire event handlers
41/// into the gate without depending on the concrete `ApprovalGate` type.
42pub trait ApprovalGateHook: Send + Sync {
43    fn set_event_handler(&self, handler: Arc<dyn Fn(AgentEvent) + Send + Sync>);
44    fn clear_event_handler(&self);
45}
46
/// Resolves pending approvals on behalf of HTTP endpoints.
pub trait ApprovalResolver: Send + Sync {
    /// Apply `decision` (with an optional free-text `reason`) to the approval
    /// identified by `approval_id`.
    ///
    /// The meaning of the returned flag is defined by the implementor
    /// (presumably whether a matching pending approval was resolved — confirm
    /// against the concrete gate).
    fn resolve_approval(&self, approval_id: &str, decision: &str, reason: Option<String>) -> bool;

    /// Identifiers of the approvals that are currently pending.
    fn pending_approval_ids(&self) -> Vec<String>;
}
52
53#[derive(Debug, Clone)]
54pub struct ProviderRequest {
55    pub run_id: String,
56    pub session_id: String,
57    pub iteration: u32,
58    pub messages: Vec<ChatMessage>,
59    pub tools: Vec<ToolDefinition>,
60    pub state: AppState,
61}
62
63pub trait Provider: Send + Sync {
64    fn name(&self) -> &str;
65    fn complete(&self, request: &ProviderRequest) -> Result<ModelTurn, CoreError>;
66
67    /// Whether this provider supports streaming completions.
68    fn supports_streaming(&self) -> bool {
69        false
70    }
71
72    /// Stream a completion, calling `on_delta` for each content delta as it arrives.
73    /// Returns the final assembled `ModelTurn`. Default falls back to `complete()`.
74    fn complete_streaming(
75        &self,
76        request: &ProviderRequest,
77        _on_delta: &dyn Fn(StreamEvent<'_>),
78    ) -> Result<ModelTurn, CoreError> {
79        self.complete(request)
80    }
81
82    /// Context window size in tokens, if known.
83    ///
84    /// Used by the shell for auto-compaction thresholds. Returns `None` when
85    /// the provider doesn't know its context window (conservative defaults apply).
86    fn context_window(&self) -> Option<u32> {
87        None
88    }
89}
90
/// Identifies where in a run a tool call is being executed.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Identifier of the run issuing the tool call.
    pub run_id: String,
    /// Identifier of the session the run belongs to.
    pub session_id: String,
    /// Loop iteration in which the tool call was requested.
    pub iteration: u32,
}
97
98pub trait Tool: Send + Sync {
99    fn definition(&self) -> ToolDefinition;
100    fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, CoreError>;
101}
102
103pub trait Middleware: Send + Sync {
104    fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
105        Ok(())
106    }
107
108    fn after_model_call(
109        &self,
110        _request: &ProviderRequest,
111        _response: &ModelTurn,
112    ) -> Result<(), CoreError> {
113        Ok(())
114    }
115
116    fn pre_tool_call(&self, _context: &ToolContext, _call: &ToolCall) -> Result<(), CoreError> {
117        Ok(())
118    }
119
120    fn post_tool_call(
121        &self,
122        _context: &ToolContext,
123        _result: &ToolResult,
124    ) -> Result<(), CoreError> {
125        Ok(())
126    }
127
128    fn on_run_finished(&self, _output: &RunOutput) -> Result<(), CoreError> {
129        Ok(())
130    }
131}
132
133/// Middleware for a single agent turn.
134///
135/// Unlike the legacy [`Middleware`] trait, turn middleware receives mutable
136/// access to requests, tool calls, tool results, and the final run output so
137/// the chain can both veto and transform a turn.
138pub trait TurnMiddleware: Send + Sync {
139    fn before_model_call(&self, _request: &mut ProviderRequest) -> Result<(), CoreError> {
140        Ok(())
141    }
142
143    fn after_model_call(
144        &self,
145        _request: &ProviderRequest,
146        _response: &mut ModelTurn,
147    ) -> Result<(), CoreError> {
148        Ok(())
149    }
150
151    fn pre_tool_call(&self, _context: &ToolContext, _call: &mut ToolCall) -> Result<(), CoreError> {
152        Ok(())
153    }
154
155    fn post_tool_call(
156        &self,
157        _context: &ToolContext,
158        _result: &mut ToolResult,
159    ) -> Result<(), CoreError> {
160        Ok(())
161    }
162
163    fn on_run_finished(&self, _output: &mut RunOutput) -> Result<(), CoreError> {
164        Ok(())
165    }
166}
167
168struct LegacyMiddlewareAdapter {
169    inner: Arc<dyn Middleware>,
170}
171
172impl LegacyMiddlewareAdapter {
173    fn new(inner: Arc<dyn Middleware>) -> Self {
174        Self { inner }
175    }
176}
177
178impl TurnMiddleware for LegacyMiddlewareAdapter {
179    fn before_model_call(&self, request: &mut ProviderRequest) -> Result<(), CoreError> {
180        self.inner.before_model_call(request)
181    }
182
183    fn after_model_call(
184        &self,
185        request: &ProviderRequest,
186        response: &mut ModelTurn,
187    ) -> Result<(), CoreError> {
188        self.inner.after_model_call(request, response)
189    }
190
191    fn pre_tool_call(&self, context: &ToolContext, call: &mut ToolCall) -> Result<(), CoreError> {
192        self.inner.pre_tool_call(context, call)
193    }
194
195    fn post_tool_call(
196        &self,
197        context: &ToolContext,
198        result: &mut ToolResult,
199    ) -> Result<(), CoreError> {
200        self.inner.post_tool_call(context, result)
201    }
202
203    fn on_run_finished(&self, output: &mut RunOutput) -> Result<(), CoreError> {
204        self.inner.on_run_finished(output)
205    }
206}
207
208#[derive(Clone, Default)]
209pub struct ToolRegistry {
210    tools: BTreeMap<String, Arc<dyn Tool>>,
211}
212
213impl ToolRegistry {
214    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
215        self.tools
216            .insert(tool.definition().name.clone(), Arc::new(tool));
217    }
218
219    pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
220        self.tools.get(tool_name).cloned()
221    }
222
223    pub fn definitions(&self) -> Vec<ToolDefinition> {
224        self.tools.values().map(|tool| tool.definition()).collect()
225    }
226}
227
228#[derive(Debug, Clone)]
229pub struct OrchestratorConfig {
230    pub max_iterations: u32,
231    /// Context window management configuration.
232    /// When set, messages are compacted before each provider call to stay within limits.
233    pub context: Option<ContextConfig>,
234    /// Context compiler configuration for assembling system prompts from typed blocks.
235    /// When set, context blocks are compiled into system messages with budget enforcement.
236    pub context_compiler: Option<crate::context_compiler::ContextCompilerConfig>,
237}
238
239impl Default for OrchestratorConfig {
240    fn default() -> Self {
241        Self {
242            max_iterations: 24,
243            context: Some(ContextConfig::default()),
244            context_compiler: None,
245        }
246    }
247}
248
249#[derive(Debug, Clone)]
250pub struct RunInput {
251    pub run_id: String,
252    pub session_id: String,
253    pub branch_id: String,
254    pub messages: Vec<ChatMessage>,
255    pub state: AppState,
256}
257
258#[derive(Debug, Clone)]
259pub struct RunOutput {
260    pub run_id: String,
261    pub session_id: String,
262    pub branch_id: String,
263    pub events: Vec<AgentEvent>,
264    pub messages: Vec<ChatMessage>,
265    pub state: AppState,
266    pub reason: RunStopReason,
267    pub final_answer: Option<String>,
268    /// Accumulated token usage across all iterations.
269    pub total_usage: TokenUsage,
270}
271
272pub struct Orchestrator {
273    provider: Arc<std::sync::RwLock<Arc<dyn Provider>>>,
274    tools: ToolRegistry,
275    turn_middlewares: Vec<Arc<dyn TurnMiddleware>>,
276    lifecycle_hooks: Vec<Arc<dyn LifecycleHook>>,
277    config: OrchestratorConfig,
278}
279
280impl Orchestrator {
281    pub fn new(
282        provider: Arc<dyn Provider>,
283        tools: ToolRegistry,
284        middlewares: Vec<Arc<dyn Middleware>>,
285        config: OrchestratorConfig,
286    ) -> Self {
287        Self {
288            provider: Arc::new(std::sync::RwLock::new(provider)),
289            tools,
290            turn_middlewares: middlewares
291                .into_iter()
292                .map(|middleware| {
293                    Arc::new(LegacyMiddlewareAdapter::new(middleware)) as Arc<dyn TurnMiddleware>
294                })
295                .collect(),
296            lifecycle_hooks: Vec::new(),
297            config,
298        }
299    }
300
301    pub fn with_turn_middlewares(
302        provider: Arc<dyn Provider>,
303        tools: ToolRegistry,
304        turn_middlewares: Vec<Arc<dyn TurnMiddleware>>,
305        config: OrchestratorConfig,
306    ) -> Self {
307        Self {
308            provider: Arc::new(std::sync::RwLock::new(provider)),
309            tools,
310            turn_middlewares,
311            lifecycle_hooks: Vec::new(),
312            config,
313        }
314    }
315
316    /// Set lifecycle hooks for this orchestrator.
317    ///
318    /// Lifecycle hooks are fire-and-forget observers that cannot block or
319    /// transform the agent loop. They are called at key points (session
320    /// start/end, pre/post tool call, pre/post LLM call) for telemetry,
321    /// billing, and notification use-cases.
322    pub fn with_lifecycle_hooks(mut self, hooks: Vec<Arc<dyn LifecycleHook>>) -> Self {
323        self.lifecycle_hooks = hooks;
324        self
325    }
326
327    /// Add a single lifecycle hook to this orchestrator.
328    pub fn add_lifecycle_hook(&mut self, hook: Arc<dyn LifecycleHook>) {
329        self.lifecycle_hooks.push(hook);
330    }
331
332    /// Swap the active provider at runtime. Returns the name of the new provider.
333    pub fn swap_provider(&self, new_provider: Arc<dyn Provider>) -> Result<String, CoreError> {
334        let name = new_provider.name().to_string();
335        let mut guard = self
336            .provider
337            .write()
338            .map_err(|e| CoreError::LockPoisoned(format!("provider write lock: {e}")))?;
339        *guard = new_provider;
340        Ok(name)
341    }
342
343    /// Get the current provider name.
344    pub fn provider_name(&self) -> Result<String, CoreError> {
345        let guard = self
346            .provider
347            .read()
348            .map_err(|e| CoreError::LockPoisoned(format!("provider read lock: {e}")))?;
349        Ok(guard.name().to_string())
350    }
351
352    pub fn run(&self, input: RunInput, event_handler: impl FnMut(AgentEvent)) -> RunOutput {
353        self.run_cancellable(input, None, event_handler)
354    }
355
356    /// Run the orchestrator loop with an optional cancellation flag.
357    ///
358    /// If `cancel` is provided and set to `true` during execution,
359    /// the loop will stop at the next iteration boundary.
360    pub fn run_cancellable(
361        &self,
362        input: RunInput,
363        cancel: Option<&Arc<AtomicBool>>,
364        mut event_handler: impl FnMut(AgentEvent),
365    ) -> RunOutput {
366        let mut events = Vec::new();
367        let mut messages = input.messages;
368        let mut state = input.state;
369        let mut final_answer: Option<String> = None;
370        let mut stop_reason = RunStopReason::BudgetExceeded;
371        let mut total_iterations = 0;
372        let mut total_usage = TokenUsage::default();
373
374        // Acquire provider reference for this run
375        let provider = match self.provider.read() {
376            Ok(guard) => guard.clone(),
377            Err(e) => {
378                let err_event = AgentEvent::RunErrored {
379                    run_id: input.run_id.clone(),
380                    session_id: input.session_id.clone(),
381                    error: format!("provider lock poisoned: {e}"),
382                };
383                event_handler(err_event.clone());
384                return RunOutput {
385                    run_id: input.run_id,
386                    session_id: input.session_id,
387                    branch_id: input.branch_id,
388                    events: vec![err_event],
389                    final_answer: None,
390                    messages,
391                    state,
392                    reason: RunStopReason::Error,
393                    total_usage: TokenUsage::default(),
394                };
395            }
396        };
397
398        let start_event = AgentEvent::RunStarted {
399            run_id: input.run_id.clone(),
400            session_id: input.session_id.clone(),
401            provider: provider.name().to_string(),
402            max_iterations: self.config.max_iterations,
403        };
404        event_handler(start_event.clone());
405        events.push(start_event);
406
407        // Fire lifecycle hooks: session start
408        for hook in &self.lifecycle_hooks {
409            hook.on_session_start(&input.session_id);
410        }
411
412        for iteration in 1..=self.config.max_iterations {
413            // Check cancellation at each iteration boundary
414            if let Some(flag) = cancel
415                && flag.load(Ordering::Relaxed)
416            {
417                stop_reason = RunStopReason::Cancelled;
418                let err_event = AgentEvent::RunErrored {
419                    run_id: input.run_id.clone(),
420                    session_id: input.session_id.clone(),
421                    error: "run cancelled".to_string(),
422                };
423                event_handler(err_event.clone());
424                events.push(err_event);
425                break;
426            }
427
428            total_iterations = iteration;
429            let iter_event = AgentEvent::IterationStarted {
430                run_id: input.run_id.clone(),
431                session_id: input.session_id.clone(),
432                iteration,
433            };
434            event_handler(iter_event.clone());
435            events.push(iter_event);
436
437            // Context window compaction: trim messages before sending to provider
438            if let Some(ref ctx_config) = self.config.context
439                && let Some(result) = compact_messages(&messages, ctx_config)
440            {
441                let compact_event = AgentEvent::ContextCompacted {
442                    run_id: input.run_id.clone(),
443                    session_id: input.session_id.clone(),
444                    iteration,
445                    dropped_count: result.dropped_count,
446                    tokens_before: result.tokens_before,
447                    tokens_after: result.tokens_after,
448                };
449                event_handler(compact_event.clone());
450                events.push(compact_event);
451                messages = result.messages;
452            }
453
454            let mut provider_request = ProviderRequest {
455                run_id: input.run_id.clone(),
456                session_id: input.session_id.clone(),
457                iteration,
458                messages: messages.clone(),
459                tools: self.tools.definitions(),
460                state: state.clone(),
461            };
462
463            if let Err(err) = self.run_before_model(&mut provider_request) {
464                stop_reason = RunStopReason::BlockedByPolicy;
465                let err_event = AgentEvent::RunErrored {
466                    run_id: input.run_id.clone(),
467                    session_id: input.session_id.clone(),
468                    error: err.to_string(),
469                };
470                event_handler(err_event.clone());
471                events.push(err_event);
472                break;
473            }
474
475            // Fire lifecycle hooks: pre LLM call
476            for hook in &self.lifecycle_hooks {
477                hook.pre_llm_call(&provider_request);
478            }
479
480            let mut model_turn = match provider.complete(&provider_request) {
481                Ok(turn) => turn,
482                Err(err) => {
483                    stop_reason = RunStopReason::Error;
484                    let err_event = AgentEvent::RunErrored {
485                        run_id: input.run_id.clone(),
486                        session_id: input.session_id.clone(),
487                        error: err.to_string(),
488                    };
489                    event_handler(err_event.clone());
490                    events.push(err_event);
491                    break;
492                }
493            };
494
495            // Fire lifecycle hooks: post LLM call
496            for hook in &self.lifecycle_hooks {
497                hook.post_llm_call(&provider_request);
498            }
499
500            if let Err(err) = self.run_after_model(&provider_request, &mut model_turn) {
501                stop_reason = RunStopReason::BlockedByPolicy;
502                let err_event = AgentEvent::RunErrored {
503                    run_id: input.run_id.clone(),
504                    session_id: input.session_id.clone(),
505                    error: err.to_string(),
506                };
507                event_handler(err_event.clone());
508                events.push(err_event);
509                break;
510            }
511
512            // Accumulate token usage if reported
513            if let Some(ref usage) = model_turn.usage {
514                total_usage.accumulate(usage);
515            }
516
517            let output_event = AgentEvent::ModelOutput {
518                run_id: input.run_id.clone(),
519                session_id: input.session_id.clone(),
520                iteration,
521                stop_reason: model_turn.stop_reason,
522                directive_count: model_turn.directives.len(),
523                usage: model_turn.usage,
524            };
525            event_handler(output_event.clone());
526            events.push(output_event);
527
528            let mut requested_tool = false;
529
530            for directive in model_turn.directives {
531                match directive {
532                    ModelDirective::Text { delta } => {
533                        let delta_event = AgentEvent::TextDelta {
534                            run_id: input.run_id.clone(),
535                            session_id: input.session_id.clone(),
536                            iteration,
537                            delta: delta.clone(),
538                        };
539                        event_handler(delta_event.clone());
540                        events.push(delta_event);
541                        messages.push(ChatMessage::assistant(delta));
542                    }
543                    ModelDirective::ToolCall { mut call } => {
544                        requested_tool = true;
545                        let tc_event = AgentEvent::ToolCallRequested {
546                            run_id: input.run_id.clone(),
547                            session_id: input.session_id.clone(),
548                            iteration,
549                            call: call.clone(),
550                        };
551                        event_handler(tc_event.clone());
552                        events.push(tc_event);
553
554                        let context = ToolContext {
555                            run_id: input.run_id.clone(),
556                            session_id: input.session_id.clone(),
557                            iteration,
558                        };
559
560                        if let Err(err) = self.run_pre_tool(&context, &mut call) {
561                            stop_reason = RunStopReason::BlockedByPolicy;
562                            let err_event = AgentEvent::ToolCallFailed {
563                                run_id: input.run_id.clone(),
564                                session_id: input.session_id.clone(),
565                                iteration,
566                                call_id: call.call_id.clone(),
567                                tool_name: call.tool_name.clone(),
568                                error: err.to_string(),
569                            };
570                            event_handler(err_event.clone());
571                            events.push(err_event);
572                            break;
573                        }
574
575                        let Some(tool) = self.tools.get(&call.tool_name) else {
576                            stop_reason = RunStopReason::Error;
577                            let err_event = AgentEvent::ToolCallFailed {
578                                run_id: input.run_id.clone(),
579                                session_id: input.session_id.clone(),
580                                iteration,
581                                call_id: call.call_id.clone(),
582                                tool_name: call.tool_name.clone(),
583                                error: format!(
584                                    "{}",
585                                    CoreError::ToolNotFound {
586                                        tool_name: call.tool_name.clone(),
587                                    }
588                                ),
589                            };
590                            event_handler(err_event.clone());
591                            events.push(err_event);
592                            break;
593                        };
594
595                        // Fire lifecycle hooks: pre tool call
596                        for hook in &self.lifecycle_hooks {
597                            hook.pre_tool_call(&call.tool_name, &call.input);
598                        }
599
600                        match tool.execute(&call, &context) {
601                            Ok(mut result) => {
602                                if let Err(err) = self.run_post_tool(&context, &mut result) {
603                                    stop_reason = RunStopReason::BlockedByPolicy;
604                                    let err_event = AgentEvent::ToolCallFailed {
605                                        run_id: input.run_id.clone(),
606                                        session_id: input.session_id.clone(),
607                                        iteration,
608                                        call_id: call.call_id.clone(),
609                                        tool_name: call.tool_name.clone(),
610                                        error: err.to_string(),
611                                    };
612                                    event_handler(err_event.clone());
613                                    events.push(err_event);
614                                    break;
615                                }
616
617                                if let Some(patch) = &result.state_patch {
618                                    match state.apply_patch(patch) {
619                                        Ok(()) => {
620                                            let patch_event = AgentEvent::StatePatched {
621                                                run_id: input.run_id.clone(),
622                                                session_id: input.session_id.clone(),
623                                                iteration,
624                                                patch: patch.clone(),
625                                                revision: state.revision,
626                                            };
627                                            event_handler(patch_event.clone());
628                                            events.push(patch_event);
629                                        }
630                                        Err(err) => {
631                                            stop_reason = RunStopReason::Error;
632                                            let err_event = AgentEvent::ToolCallFailed {
633                                                run_id: input.run_id.clone(),
634                                                session_id: input.session_id.clone(),
635                                                iteration,
636                                                call_id: call.call_id.clone(),
637                                                tool_name: call.tool_name.clone(),
638                                                error: err.to_string(),
639                                            };
640                                            event_handler(err_event.clone());
641                                            events.push(err_event);
642                                            break;
643                                        }
644                                    }
645                                }
646
647                                // Fire lifecycle hooks: post tool call
648                                let result_str = serde_json::to_string(&result.output)
649                                    .unwrap_or_else(|_| "{}".to_string());
650                                for hook in &self.lifecycle_hooks {
651                                    hook.post_tool_call(&call.tool_name, &result_str);
652                                }
653
654                                let completed_event = AgentEvent::ToolCallCompleted {
655                                    run_id: input.run_id.clone(),
656                                    session_id: input.session_id.clone(),
657                                    iteration,
658                                    result: ToolResultSummary::from(&result),
659                                };
660                                event_handler(completed_event.clone());
661                                events.push(completed_event);
662
663                                messages.push(ChatMessage::tool_result(
664                                    &result.call_id,
665                                    serde_json::to_string(&result.output)
666                                        .unwrap_or_else(|_| "{}".to_string()),
667                                ));
668                            }
669                            Err(err) => {
670                                stop_reason = RunStopReason::Error;
671                                let err_event = AgentEvent::ToolCallFailed {
672                                    run_id: input.run_id.clone(),
673                                    session_id: input.session_id.clone(),
674                                    iteration,
675                                    call_id: call.call_id.clone(),
676                                    tool_name: call.tool_name.clone(),
677                                    error: err.to_string(),
678                                };
679                                event_handler(err_event.clone());
680                                events.push(err_event);
681                                break;
682                            }
683                        }
684                    }
685                    ModelDirective::StatePatch { patch } => match state.apply_patch(&patch) {
686                        Ok(()) => {
687                            let patch_event = AgentEvent::StatePatched {
688                                run_id: input.run_id.clone(),
689                                session_id: input.session_id.clone(),
690                                iteration,
691                                patch: patch.clone(),
692                                revision: state.revision,
693                            };
694                            event_handler(patch_event.clone());
695                            events.push(patch_event);
696                        }
697                        Err(err) => {
698                            stop_reason = RunStopReason::Error;
699                            let err_event = AgentEvent::RunErrored {
700                                run_id: input.run_id.clone(),
701                                session_id: input.session_id.clone(),
702                                error: err.to_string(),
703                            };
704                            event_handler(err_event.clone());
705                            events.push(err_event);
706                            break;
707                        }
708                    },
709                    ModelDirective::FinalAnswer { text } => {
710                        final_answer = Some(text.clone());
711                        let delta_event = AgentEvent::TextDelta {
712                            run_id: input.run_id.clone(),
713                            session_id: input.session_id.clone(),
714                            iteration,
715                            delta: text.clone(),
716                        };
717                        event_handler(delta_event.clone());
718                        events.push(delta_event);
719                        messages.push(ChatMessage::assistant(text));
720                    }
721                }
722            }
723
724            if matches!(
725                stop_reason,
726                RunStopReason::Error | RunStopReason::BlockedByPolicy | RunStopReason::Cancelled
727            ) {
728                break;
729            }
730
731            match model_turn.stop_reason {
732                ModelStopReason::EndTurn => {
733                    stop_reason = RunStopReason::Completed;
734                    break;
735                }
736                ModelStopReason::NeedsUser => {
737                    stop_reason = RunStopReason::NeedsUser;
738                    break;
739                }
740                ModelStopReason::Safety => {
741                    stop_reason = RunStopReason::BlockedByPolicy;
742                    break;
743                }
744                ModelStopReason::ToolUse => {
745                    if !requested_tool {
746                        stop_reason = RunStopReason::Error;
747                        let err_event = AgentEvent::RunErrored {
748                            run_id: input.run_id.clone(),
749                            session_id: input.session_id.clone(),
750                            error: "model requested tool_use stop reason without tool call"
751                                .to_string(),
752                        };
753                        event_handler(err_event.clone());
754                        events.push(err_event);
755                        break;
756                    }
757                }
758                ModelStopReason::MaxTokens | ModelStopReason::Unknown => {
759                    if !requested_tool {
760                        stop_reason = RunStopReason::Error;
761                        let err_event = AgentEvent::RunErrored {
762                            run_id: input.run_id.clone(),
763                            session_id: input.session_id.clone(),
764                            error: "model returned non-terminal stop reason without tool call"
765                                .to_string(),
766                        };
767                        event_handler(err_event.clone());
768                        events.push(err_event);
769                        break;
770                    }
771                }
772            }
773        }
774
775        if total_iterations == self.config.max_iterations
776            && stop_reason == RunStopReason::BudgetExceeded
777        {
778            let err_event = AgentEvent::RunErrored {
779                run_id: input.run_id.clone(),
780                session_id: input.session_id.clone(),
781                error: "max iteration budget exceeded".to_string(),
782            };
783            event_handler(err_event.clone());
784            events.push(err_event);
785        }
786
787        let finished_event = AgentEvent::RunFinished {
788            run_id: input.run_id.clone(),
789            session_id: input.session_id.clone(),
790            reason: stop_reason,
791            total_iterations,
792            final_answer: final_answer.clone(),
793            usage: if total_usage.total() > 0 {
794                Some(total_usage)
795            } else {
796                None
797            },
798        };
799        event_handler(finished_event.clone());
800        events.push(finished_event);
801
802        let mut output = RunOutput {
803            run_id: input.run_id,
804            session_id: input.session_id,
805            branch_id: input.branch_id,
806            events,
807            messages,
808            state,
809            reason: stop_reason,
810            final_answer,
811            total_usage,
812        };
813
814        if let Err(e) = self
815            .turn_middlewares
816            .iter()
817            .try_for_each(|m| m.on_run_finished(&mut output))
818        {
819            tracing::warn!(error = %e, "middleware on_run_finished failed (non-fatal)");
820        }
821
822        // Fire lifecycle hooks: session end
823        for hook in &self.lifecycle_hooks {
824            hook.on_session_end(&output.session_id, &output);
825        }
826
827        output
828    }
829
830    fn run_before_model(&self, request: &mut ProviderRequest) -> Result<(), CoreError> {
831        self.turn_middlewares
832            .iter()
833            .try_for_each(|middleware| middleware.before_model_call(request))
834    }
835
836    fn run_after_model(
837        &self,
838        request: &ProviderRequest,
839        response: &mut ModelTurn,
840    ) -> Result<(), CoreError> {
841        self.turn_middlewares
842            .iter()
843            .try_for_each(|middleware| middleware.after_model_call(request, response))
844    }
845
846    fn run_pre_tool(&self, context: &ToolContext, call: &mut ToolCall) -> Result<(), CoreError> {
847        self.turn_middlewares
848            .iter()
849            .try_for_each(|middleware| middleware.pre_tool_call(context, call))
850    }
851
852    fn run_post_tool(
853        &self,
854        context: &ToolContext,
855        result: &mut ToolResult,
856    ) -> Result<(), CoreError> {
857        self.turn_middlewares
858            .iter()
859            .try_for_each(|middleware| middleware.post_tool_call(context, result))
860    }
861}
862
863#[cfg(test)]
864mod tests {
865    use super::*;
866    use crate::protocol::{
867        ModelDirective, ModelStopReason, ModelTurn, StatePatch, StatePatchFormat, StatePatchSource,
868    };
869    use serde_json::json;
870    use std::sync::Mutex;
871
872    struct ScriptedProvider {
873        turns: Vec<ModelTurn>,
874        cursor: Mutex<usize>,
875    }
876
877    impl Provider for ScriptedProvider {
878        fn name(&self) -> &str {
879            "scripted"
880        }
881
882        fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
883            let mut cursor = self
884                .cursor
885                .lock()
886                .map_err(|_| CoreError::Provider("scripted provider lock poisoned".to_string()))?;
887            let idx = *cursor;
888            let Some(turn) = self.turns.get(idx) else {
889                return Err(CoreError::Provider("no scripted turn left".to_string()));
890            };
891            *cursor += 1;
892            Ok(turn.clone())
893        }
894    }
895
896    struct EchoTool;
897
898    impl Tool for EchoTool {
899        fn definition(&self) -> ToolDefinition {
900            ToolDefinition {
901                name: "echo".to_string(),
902                description: "Echoes the provided value".to_string(),
903                input_schema: json!({
904                    "type": "object",
905                    "properties": { "value": { "type": "string" } },
906                    "required": ["value"]
907                }),
908                title: None,
909                output_schema: None,
910                annotations: None,
911                category: None,
912                tags: Vec::new(),
913                timeout_secs: None,
914            }
915        }
916
917        fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, CoreError> {
918            let value = call.input.get("value").cloned().unwrap_or(json!(null));
919            Ok(ToolResult {
920                call_id: call.call_id.clone(),
921                tool_name: call.tool_name.clone(),
922                output: json!({ "echo": value.clone() }),
923                content: None,
924                is_error: false,
925                state_patch: Some(StatePatch {
926                    format: StatePatchFormat::MergePatch,
927                    patch: json!({ "last_echo": value }),
928                    source: StatePatchSource::Tool,
929                }),
930            })
931        }
932    }
933
#[test]
fn orchestrator_runs_tool_then_finishes() {
    // Happy path: turn one asks for the `echo` tool, turn two gives the final answer.
    let echo_request = ModelTurn {
        directives: vec![ModelDirective::ToolCall {
            call: ToolCall {
                call_id: "call-1".to_string(),
                tool_name: "echo".to_string(),
                input: json!({ "value": "hello" }),
            },
        }],
        stop_reason: ModelStopReason::ToolUse,
        usage: None,
    };
    let closing_answer = ModelTurn {
        directives: vec![ModelDirective::FinalAnswer {
            text: "done".to_string(),
        }],
        stop_reason: ModelStopReason::EndTurn,
        usage: None,
    };
    let script = ScriptedProvider {
        turns: vec![echo_request, closing_answer],
        cursor: Mutex::new(0),
    };

    let mut registry = ToolRegistry::default();
    registry.register(EchoTool);

    let settings = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::new(Arc::new(script), registry, Vec::new(), settings);

    let request = RunInput {
        run_id: "run-1".to_string(),
        session_id: "session-1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    let outcome = orchestrator.run(request, |_| {});

    // The run completed with the scripted answer, and the tool's state patch was applied.
    assert_eq!(outcome.reason, RunStopReason::Completed);
    assert_eq!(outcome.final_answer.as_deref(), Some("done"));
    assert_eq!(outcome.state.revision, 1);
    assert_eq!(outcome.state.data["last_echo"], "hello");

    // The event log records both the tool completion and the completed finish.
    let saw_tool_completion = outcome
        .events
        .iter()
        .any(|event| matches!(event, AgentEvent::ToolCallCompleted { .. }));
    assert!(saw_tool_completion);

    let saw_completed_finish = outcome.events.iter().any(|event| {
        matches!(
            event,
            AgentEvent::RunFinished {
                reason: RunStopReason::Completed,
                ..
            }
        )
    });
    assert!(saw_completed_finish);
}
1004
#[test]
fn provider_error_stops_run() {
    // Provider whose every completion fails with a provider error.
    struct AlwaysFailingProvider;

    impl Provider for AlwaysFailingProvider {
        fn name(&self) -> &str {
            "fail"
        }

        fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
            Err(CoreError::Provider("connection refused".to_string()))
        }
    }

    let settings = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::new(
        Arc::new(AlwaysFailingProvider),
        ToolRegistry::default(),
        Vec::new(),
        settings,
    );

    let request = RunInput {
        run_id: "run-1".to_string(),
        session_id: "s1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    let outcome = orchestrator.run(request, |_| {});

    // The failure must surface both as the stop reason and as a RunErrored event.
    assert_eq!(outcome.reason, RunStopReason::Error);
    let saw_run_error = outcome
        .events
        .iter()
        .any(|event| matches!(event, AgentEvent::RunErrored { .. }));
    assert!(saw_run_error);
}
1047
#[test]
fn tool_not_found_stops_run() {
    // The model asks for a tool that was never registered.
    let unknown_tool_turn = ModelTurn {
        directives: vec![ModelDirective::ToolCall {
            call: ToolCall {
                call_id: "c1".to_string(),
                tool_name: "nonexistent".to_string(),
                input: json!({}),
            },
        }],
        stop_reason: ModelStopReason::ToolUse,
        usage: None,
    };
    let script = ScriptedProvider {
        turns: vec![unknown_tool_turn],
        cursor: Mutex::new(0),
    };

    let settings = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    // Empty registry: no tool can satisfy the call.
    let orchestrator = Orchestrator::new(
        Arc::new(script),
        ToolRegistry::default(),
        Vec::new(),
        settings,
    );

    let request = RunInput {
        run_id: "run-1".to_string(),
        session_id: "s1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    let outcome = orchestrator.run(request, |_| {});

    assert_eq!(outcome.reason, RunStopReason::Error);
    let saw_tool_failure = outcome
        .events
        .iter()
        .any(|event| matches!(event, AgentEvent::ToolCallFailed { .. }));
    assert!(saw_tool_failure);
}
1095
#[test]
fn middleware_blocks_model_call() {
    // Middleware that rejects every model call before it is made.
    struct RejectingMiddleware;

    impl Middleware for RejectingMiddleware {
        fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
            Err(CoreError::Middleware("blocked by policy".to_string()))
        }
    }

    // A perfectly valid turn is scripted; the test asserts the run is blocked anyway.
    let greeting_turn = ModelTurn {
        directives: vec![ModelDirective::Text {
            delta: "hi".to_string(),
        }],
        stop_reason: ModelStopReason::EndTurn,
        usage: None,
    };
    let script = ScriptedProvider {
        turns: vec![greeting_turn],
        cursor: Mutex::new(0),
    };

    let settings = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::new(
        Arc::new(script),
        ToolRegistry::default(),
        vec![Arc::new(RejectingMiddleware)],
        settings,
    );

    let request = RunInput {
        run_id: "run-1".to_string(),
        session_id: "s1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    let outcome = orchestrator.run(request, |_| {});

    assert_eq!(outcome.reason, RunStopReason::BlockedByPolicy);
}
1140
#[test]
fn turn_middleware_can_rewrite_calls_and_responses() {
    // Middleware that overwrites every final answer and every tool-call input.
    struct OverwritingMiddleware;

    impl TurnMiddleware for OverwritingMiddleware {
        fn after_model_call(
            &self,
            _request: &ProviderRequest,
            response: &mut ModelTurn,
        ) -> Result<(), CoreError> {
            response.directives.iter_mut().for_each(|directive| {
                if let ModelDirective::FinalAnswer { text } = directive {
                    *text = "rewritten answer".to_string();
                }
            });
            Ok(())
        }

        fn pre_tool_call(
            &self,
            _context: &ToolContext,
            call: &mut ToolCall,
        ) -> Result<(), CoreError> {
            call.input = json!({ "value": "rewritten input" });
            Ok(())
        }
    }

    let echo_request = ModelTurn {
        directives: vec![ModelDirective::ToolCall {
            call: ToolCall {
                call_id: "call-1".to_string(),
                tool_name: "echo".to_string(),
                input: json!({ "value": "original input" }),
            },
        }],
        stop_reason: ModelStopReason::ToolUse,
        usage: None,
    };
    let closing_answer = ModelTurn {
        directives: vec![ModelDirective::FinalAnswer {
            text: "original answer".to_string(),
        }],
        stop_reason: ModelStopReason::EndTurn,
        usage: None,
    };
    let script = ScriptedProvider {
        turns: vec![echo_request, closing_answer],
        cursor: Mutex::new(0),
    };

    let mut registry = ToolRegistry::default();
    registry.register(EchoTool);

    let settings = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::with_turn_middlewares(
        Arc::new(script),
        registry,
        vec![Arc::new(OverwritingMiddleware)],
        settings,
    );

    let request = RunInput {
        run_id: "run-1".to_string(),
        session_id: "session-1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    let outcome = orchestrator.run(request, |_| {});

    // Both rewrites must be visible: the answer text and the value the tool echoed.
    assert_eq!(outcome.reason, RunStopReason::Completed);
    assert_eq!(outcome.final_answer.as_deref(), Some("rewritten answer"));
    assert_eq!(outcome.state.data["last_echo"], "rewritten input");
}
1222
#[test]
fn budget_exceeded_when_iterations_exhausted() {
    // Builds a turn that calls the registered `echo` tool and stops with ToolUse,
    // i.e. a turn that never ends the run on its own.
    let echo_turn = |call_id: &str, value: &str| ModelTurn {
        directives: vec![ModelDirective::ToolCall {
            call: ToolCall {
                call_id: call_id.to_string(),
                tool_name: "echo".to_string(),
                input: json!({ "value": value }),
            },
        }],
        stop_reason: ModelStopReason::ToolUse,
        usage: None,
    };

    // Exactly as many scripted turns as `max_iterations` below, none of them terminal.
    let script = ScriptedProvider {
        turns: vec![echo_turn("c1", "1"), echo_turn("c2", "2")],
        cursor: Mutex::new(0),
    };

    let mut registry = ToolRegistry::default();
    registry.register(EchoTool);

    let settings = OrchestratorConfig {
        max_iterations: 2,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::new(Arc::new(script), registry, Vec::new(), settings);

    let request = RunInput {
        run_id: "run-1".to_string(),
        session_id: "s1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    let outcome = orchestrator.run(request, |_| {});

    assert_eq!(outcome.reason, RunStopReason::BudgetExceeded);
}
1285
#[test]
fn text_only_response_completes() {
    // A single turn carrying only a text directive, ended by the model itself.
    let text_turn = ModelTurn {
        directives: vec![ModelDirective::Text {
            delta: "Hello, world!".to_string(),
        }],
        stop_reason: ModelStopReason::EndTurn,
        usage: None,
    };
    let script = ScriptedProvider {
        turns: vec![text_turn],
        cursor: Mutex::new(0),
    };

    let settings = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::new(
        Arc::new(script),
        ToolRegistry::default(),
        Vec::new(),
        settings,
    );

    let request = RunInput {
        run_id: "run-1".to_string(),
        session_id: "s1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("hi")],
        state: AppState::default(),
    };
    let outcome = orchestrator.run(request, |_| {});

    assert_eq!(outcome.reason, RunStopReason::Completed);
    // The text must have been recorded in the returned message history.
    let text_recorded = outcome.messages.iter().any(|m| m.content == "Hello, world!");
    assert!(text_recorded);
}
1324
#[test]
fn event_handler_receives_all_events() {
    let answer_turn = ModelTurn {
        directives: vec![ModelDirective::FinalAnswer {
            text: "done".to_string(),
        }],
        stop_reason: ModelStopReason::EndTurn,
        usage: None,
    };
    let script = ScriptedProvider {
        turns: vec![answer_turn],
        cursor: Mutex::new(0),
    };

    let settings = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::new(
        Arc::new(script),
        ToolRegistry::default(),
        Vec::new(),
        settings,
    );

    // Shared sink: the handler pushes into it, the test inspects it after the run.
    let sink = Arc::new(Mutex::new(Vec::new()));
    let sink_for_handler = Arc::clone(&sink);

    let request = RunInput {
        run_id: "run-1".to_string(),
        session_id: "s1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    orchestrator.run(request, move |event| {
        sink_for_handler.lock().unwrap().push(event);
    });

    let seen = sink.lock().unwrap();
    // Lower bound only; the exact event sequence is defined by the run loop.
    assert!(seen.len() >= 4);
    // The stream must be bracketed by RunStarted and RunFinished.
    assert!(matches!(seen.first(), Some(AgentEvent::RunStarted { .. })));
    assert!(matches!(seen.last(), Some(AgentEvent::RunFinished { .. })));
}
1373
#[test]
fn tool_result_includes_call_id() {
    let echo_request = ModelTurn {
        directives: vec![ModelDirective::ToolCall {
            call: ToolCall {
                call_id: "my-call-id".to_string(),
                tool_name: "echo".to_string(),
                input: json!({ "value": "test" }),
            },
        }],
        stop_reason: ModelStopReason::ToolUse,
        usage: None,
    };
    let closing_answer = ModelTurn {
        directives: vec![ModelDirective::FinalAnswer {
            text: "ok".to_string(),
        }],
        stop_reason: ModelStopReason::EndTurn,
        usage: None,
    };
    let script = ScriptedProvider {
        turns: vec![echo_request, closing_answer],
        cursor: Mutex::new(0),
    };

    let mut registry = ToolRegistry::default();
    registry.register(EchoTool);

    let settings = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::new(Arc::new(script), registry, Vec::new(), settings);

    let request = RunInput {
        run_id: "run-1".to_string(),
        session_id: "s1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    let outcome = orchestrator.run(request, |_| {});

    // The first tool-role message must carry the id of the call it answers.
    let tool_message = outcome
        .messages
        .iter()
        .find(|message| message.role == crate::protocol::Role::Tool)
        .expect("should have tool message");
    assert_eq!(tool_message.tool_call_id.as_deref(), Some("my-call-id"));
}
1433
#[test]
fn cancellation_stops_run() {
    // Script: one tool-call turn followed by a final answer. The test asserts
    // that the final answer is never produced because the run is cancelled.
    let provider = ScriptedProvider {
        turns: vec![
            ModelTurn {
                directives: vec![ModelDirective::ToolCall {
                    call: ToolCall {
                        call_id: "c1".to_string(),
                        tool_name: "echo".to_string(),
                        input: json!({"value": "1"}),
                    },
                }],
                stop_reason: ModelStopReason::ToolUse,
                usage: None,
            },
            ModelTurn {
                directives: vec![ModelDirective::FinalAnswer {
                    text: "should not reach".to_string(),
                }],
                stop_reason: ModelStopReason::EndTurn,
                usage: None,
            },
        ],
        cursor: Mutex::new(0),
    };

    let mut tools = ToolRegistry::default();
    tools.register(EchoTool);

    let orchestrator = Orchestrator::new(
        Arc::new(provider),
        tools,
        Vec::new(),
        OrchestratorConfig {
            max_iterations: 10,
            context: None,
            context_compiler: None,
        },
    );

    // `cancel_flag` is lent to the run; `cancel_trigger` is a second handle to
    // the same flag that is moved into the event handler.
    let cancel_flag = Arc::new(AtomicBool::new(false));
    let cancel_trigger = Arc::clone(&cancel_flag);

    let output = orchestrator.run_cancellable(
        RunInput {
            run_id: "run-1".to_string(),
            session_id: "s1".to_string(),
            branch_id: "main".to_string(),
            messages: vec![ChatMessage::user("test")],
            state: AppState::default(),
        },
        Some(&cancel_flag),
        move |event| {
            // Request cancellation as soon as a tool call completes. No counter
            // is needed: the very first completion already triggers it.
            if matches!(event, AgentEvent::ToolCallCompleted { .. }) {
                cancel_trigger.store(true, Ordering::Relaxed);
            }
        },
    );

    assert_eq!(output.reason, RunStopReason::Cancelled);
    // Should not have a final answer since we cancelled
    assert!(output.final_answer.is_none());
}
1505
#[test]
fn swappable_provider_handle_swap() {
    // Minimal provider identified only by its name; completions always fail.
    struct NamedStub(&'static str);

    impl Provider for NamedStub {
        fn name(&self) -> &str {
            self.0
        }

        fn complete(&self, _: &ProviderRequest) -> Result<ModelTurn, CoreError> {
            Err(CoreError::Provider("stub provider".into()))
        }
    }

    let handle: SwappableProviderHandle =
        Arc::new(std::sync::RwLock::new(Arc::new(NamedStub("provider-a"))));

    // Before the swap, readers see the first provider.
    let name_before = handle.read().unwrap().name().to_string();
    assert_eq!(name_before, "provider-a");

    // Swap under the write lock; the guard is released at the end of the statement.
    *handle.write().unwrap() = Arc::new(NamedStub("provider-b"));

    // After the swap, readers see the replacement.
    let name_after = handle.read().unwrap().name().to_string();
    assert_eq!(name_after, "provider-b");
}
1542
#[test]
fn token_usage_accumulated() {
    // Two turns with distinct usage figures so the totals can only come from summing.
    let first_usage = TokenUsage {
        input_tokens: 100,
        output_tokens: 50,
        cache_read_tokens: 0,
        cache_creation_tokens: 0,
    };
    let second_usage = TokenUsage {
        input_tokens: 200,
        output_tokens: 30,
        cache_read_tokens: 0,
        cache_creation_tokens: 0,
    };

    let echo_request = ModelTurn {
        directives: vec![ModelDirective::ToolCall {
            call: ToolCall {
                call_id: "c1".to_string(),
                tool_name: "echo".to_string(),
                input: json!({ "value": "hi" }),
            },
        }],
        stop_reason: ModelStopReason::ToolUse,
        usage: Some(first_usage),
    };
    let closing_answer = ModelTurn {
        directives: vec![ModelDirective::FinalAnswer {
            text: "done".to_string(),
        }],
        stop_reason: ModelStopReason::EndTurn,
        usage: Some(second_usage),
    };
    let script = ScriptedProvider {
        turns: vec![echo_request, closing_answer],
        cursor: Mutex::new(0),
    };

    let mut registry = ToolRegistry::default();
    registry.register(EchoTool);

    let settings = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::new(Arc::new(script), registry, Vec::new(), settings);

    let request = RunInput {
        run_id: "run-1".to_string(),
        session_id: "s1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    let outcome = orchestrator.run(request, |_| {});

    assert_eq!(outcome.reason, RunStopReason::Completed);
    // 100 + 200 input tokens, 50 + 30 output tokens.
    assert_eq!(outcome.total_usage.input_tokens, 300);
    assert_eq!(outcome.total_usage.output_tokens, 80);
    assert_eq!(outcome.total_usage.total(), 380);
}
1609
1610    #[test]
1611    fn lifecycle_hooks_fire_during_run() {
1612        use crate::lifecycle::LifecycleHook;
1613        use std::sync::atomic::{AtomicU32, Ordering};
1614
1615        struct CountingHook {
1616            pre_tool: AtomicU32,
1617            post_tool: AtomicU32,
1618            pre_llm: AtomicU32,
1619            post_llm: AtomicU32,
1620            session_start: AtomicU32,
1621            session_end: AtomicU32,
1622        }
1623
1624        impl CountingHook {
1625            fn new() -> Self {
1626                Self {
1627                    pre_tool: AtomicU32::new(0),
1628                    post_tool: AtomicU32::new(0),
1629                    pre_llm: AtomicU32::new(0),
1630                    post_llm: AtomicU32::new(0),
1631                    session_start: AtomicU32::new(0),
1632                    session_end: AtomicU32::new(0),
1633                }
1634            }
1635        }
1636
1637        impl LifecycleHook for CountingHook {
1638            fn pre_tool_call(&self, _tool_name: &str, _input: &serde_json::Value) {
1639                self.pre_tool.fetch_add(1, Ordering::Relaxed);
1640            }
1641            fn post_tool_call(&self, _tool_name: &str, _result: &str) {
1642                self.post_tool.fetch_add(1, Ordering::Relaxed);
1643            }
1644            fn pre_llm_call(&self, _request: &ProviderRequest) {
1645                self.pre_llm.fetch_add(1, Ordering::Relaxed);
1646            }
1647            fn post_llm_call(&self, _request: &ProviderRequest) {
1648                self.post_llm.fetch_add(1, Ordering::Relaxed);
1649            }
1650            fn on_session_start(&self, _session_id: &str) {
1651                self.session_start.fetch_add(1, Ordering::Relaxed);
1652            }
1653            fn on_session_end(&self, _session_id: &str, _output: &RunOutput) {
1654                self.session_end.fetch_add(1, Ordering::Relaxed);
1655            }
1656        }
1657
1658        let provider = ScriptedProvider {
1659            turns: vec![
1660                ModelTurn {
1661                    directives: vec![ModelDirective::ToolCall {
1662                        call: ToolCall {
1663                            call_id: "call-1".to_string(),
1664                            tool_name: "echo".to_string(),
1665                            input: json!({ "value": "hello" }),
1666                        },
1667                    }],
1668                    stop_reason: ModelStopReason::ToolUse,
1669                    usage: None,
1670                },
1671                ModelTurn {
1672                    directives: vec![ModelDirective::FinalAnswer {
1673                        text: "done".to_string(),
1674                    }],
1675                    stop_reason: ModelStopReason::EndTurn,
1676                    usage: None,
1677                },
1678            ],
1679            cursor: Mutex::new(0),
1680        };
1681
1682        let mut tools = ToolRegistry::default();
1683        tools.register(EchoTool);
1684
1685        let hook = Arc::new(CountingHook::new());
1686
1687        let orchestrator = Orchestrator::new(
1688            Arc::new(provider),
1689            tools,
1690            Vec::new(),
1691            OrchestratorConfig {
1692                max_iterations: 4,
1693                context: None,
1694                context_compiler: None,
1695            },
1696        )
1697        .with_lifecycle_hooks(vec![hook.clone()]);
1698
1699        let output = orchestrator.run(
1700            RunInput {
1701                run_id: "run-1".to_string(),
1702                session_id: "session-1".to_string(),
1703                branch_id: "main".to_string(),
1704                messages: vec![ChatMessage::user("test")],
1705                state: AppState::default(),
1706            },
1707            |_| {},
1708        );
1709
1710        assert_eq!(output.reason, RunStopReason::Completed);
1711
1712        // Session lifecycle: start and end should each fire once
1713        assert_eq!(hook.session_start.load(Ordering::Relaxed), 1);
1714        assert_eq!(hook.session_end.load(Ordering::Relaxed), 1);
1715
1716        // LLM calls: 2 iterations = 2 provider calls
1717        assert_eq!(hook.pre_llm.load(Ordering::Relaxed), 2);
1718        assert_eq!(hook.post_llm.load(Ordering::Relaxed), 2);
1719
1720        // Tool calls: 1 echo tool call
1721        assert_eq!(hook.pre_tool.load(Ordering::Relaxed), 1);
1722        assert_eq!(hook.post_tool.load(Ordering::Relaxed), 1);
1723    }
1724}