Skip to main content

neuron_op_react/
lib.rs

1#![deny(missing_docs)]
2//! ReAct operator — model + tools in a reasoning loop.
3//!
4//! Implements `layer0::Operator` by running the Reason-Act-Observe cycle:
5//! assemble context → call model → execute tools → repeat until done.
6
7use async_trait::async_trait;
8use layer0::content::Content;
9use layer0::duration::DurationMs;
10use layer0::effect::{Effect, Scope, SignalPayload};
11use layer0::error::OperatorError;
12use layer0::hook::{HookAction, HookContext, HookPoint};
13use layer0::id::{AgentId, WorkflowId};
14use layer0::operator::{
15    ExitReason, Operator, OperatorInput, OperatorMetadata, OperatorOutput, ToolCallRecord,
16};
17use neuron_hooks::HookRegistry;
18use neuron_tool::{ToolConcurrencyHint, ToolRegistry};
19use neuron_turn::context::ContextStrategy;
20use neuron_turn::convert::{content_to_user_message, parts_to_content};
21use neuron_turn::provider::Provider;
22use neuron_turn::types::*;
23use rust_decimal::Decimal;
24use std::sync::Arc;
25use std::time::Instant;
26
/// Static configuration for a ReactOperator instance.
///
/// These values are the operator-wide defaults; `resolve_config` combines
/// them with any per-request overrides found on the operator input.
#[derive(Debug, Clone)]
pub struct ReactConfig {
    /// Base system prompt.
    pub system_prompt: String,
    /// Default model identifier. An empty string means "no default model".
    pub default_model: String,
    /// Default max tokens per response.
    pub default_max_tokens: u32,
    /// Default max turns before stopping.
    pub default_max_turns: u32,
}

impl Default for ReactConfig {
    /// Empty prompt and model, 4096 response tokens, 10 turns.
    fn default() -> Self {
        Self {
            system_prompt: String::new(),
            default_model: String::new(),
            default_max_tokens: 4096,
            default_max_turns: 10,
        }
    }
}
49
/// Tool names that are turned into `Effect` values rather than being
/// dispatched to a locally registered tool.
const EFFECT_TOOL_NAMES: &[&str] = &[
    // Memory effects
    "write_memory",
    "delete_memory",
    // Agent-to-agent effects
    "delegate",
    "handoff",
    // Workflow effects
    "signal",
];
58
59/// Resolved configuration merging defaults with per-request overrides.
60struct ResolvedConfig {
61    model: Option<String>,
62    system: String,
63    max_turns: u32,
64    max_cost: Option<Decimal>,
65    max_duration: Option<DurationMs>,
66    allowed_tools: Option<Vec<String>>,
67    max_tokens: u32,
68}
69
70// Re-export turn-kit primitives
71pub use neuron_turn_kit::{
72    BarrierPlanner, BatchItem, Concurrency, ConcurrencyDecider, SteeringSource,
73    ToolExecutionPlanner,
74};
75
76/// Default decider: all tools Exclusive.
77struct DefaultDecider;
78impl ConcurrencyDecider for DefaultDecider {
79    /// Return the concurrency class for a tool by name.
80    fn concurrency(&self, _tool_name: &str) -> Concurrency {
81        Concurrency::Exclusive
82    }
83}
84
85/// Concurrency decider that reads per-tool metadata from ToolRegistry.
86struct MetadataDecider {
87    tools: ToolRegistry,
88}
89impl ConcurrencyDecider for MetadataDecider {
90    fn concurrency(&self, tool_name: &str) -> Concurrency {
91        match self.tools.get(tool_name) {
92            Some(tool) => match tool.concurrency_hint() {
93                ToolConcurrencyHint::Shared => Concurrency::Shared,
94                ToolConcurrencyHint::Exclusive => Concurrency::Exclusive,
95                _ => Concurrency::Exclusive,
96            },
97            None => Concurrency::Exclusive,
98        }
99    }
100}
101
102/// Sequential planner: each tool runs alone.
103struct SequentialPlanner;
104impl ToolExecutionPlanner for SequentialPlanner {
105    fn plan(
106        &self,
107        tool_uses: &[(String, String, serde_json::Value)],
108        _decider: &dyn ConcurrencyDecider,
109    ) -> Vec<BatchItem> {
110        tool_uses
111            .iter()
112            .cloned()
113            .map(BatchItem::Exclusive)
114            .collect()
115    }
116}
117/// A full-featured Operator implementation with a ReAct loop.
118///
119/// Generic over `P: Provider` (not object-safe). The object-safe boundary
120/// is `layer0::Operator`, which `ReactOperator<P>` implements via `#[async_trait]`.
121pub struct ReactOperator<P: Provider> {
122    provider: P,
123    tools: ToolRegistry,
124    context_strategy: Box<dyn ContextStrategy>,
125    hooks: HookRegistry,
126    state_reader: Arc<dyn layer0::StateReader>,
127    config: ReactConfig,
128    planner: Box<dyn ToolExecutionPlanner>,
129    decider: Box<dyn ConcurrencyDecider>,
130    steering: Option<Arc<dyn SteeringSource>>,
131}
132
133impl<P: Provider> ReactOperator<P> {
134    /// Create a new ReactOperator with all dependencies.
135    pub fn new(
136        provider: P,
137        tools: ToolRegistry,
138        context_strategy: Box<dyn ContextStrategy>,
139        hooks: HookRegistry,
140        state_reader: Arc<dyn layer0::StateReader>,
141        config: ReactConfig,
142    ) -> Self {
143        Self {
144            provider,
145            tools,
146            context_strategy,
147            hooks,
148            state_reader,
149            config,
150            planner: Box::new(SequentialPlanner),
151            decider: Box::new(DefaultDecider),
152            steering: None,
153        }
154    }
155    /// Opt-in: set a custom tool execution planner.
156    pub fn with_planner(mut self, planner: Box<dyn ToolExecutionPlanner>) -> Self {
157        self.planner = planner;
158        self
159    }
160    /// Opt-in: set a custom concurrency decider.
161    pub fn with_concurrency_decider(mut self, decider: Box<dyn ConcurrencyDecider>) -> Self {
162        self.decider = decider;
163        self
164    }
165    /// Opt-in: use tool metadata to decide concurrency.
166    pub fn with_metadata_concurrency(mut self) -> Self {
167        self.decider = Box::new(MetadataDecider {
168            tools: self.tools.clone(),
169        });
170        self
171    }
172    /// Opt-in: attach a steering source.
173    pub fn with_steering(mut self, s: Arc<dyn SteeringSource>) -> Self {
174        self.steering = Some(s);
175        self
176    }
177
178    fn resolve_config(&self, input: &OperatorInput) -> ResolvedConfig {
179        let tc = input.config.as_ref();
180        let system = match tc.and_then(|c| c.system_addendum.as_ref()) {
181            Some(addendum) => format!("{}\n{}", self.config.system_prompt, addendum),
182            None => self.config.system_prompt.clone(),
183        };
184        ResolvedConfig {
185            model: tc.and_then(|c| c.model.clone()).or_else(|| {
186                if self.config.default_model.is_empty() {
187                    None
188                } else {
189                    Some(self.config.default_model.clone())
190                }
191            }),
192            system,
193            max_turns: tc
194                .and_then(|c| c.max_turns)
195                .unwrap_or(self.config.default_max_turns),
196            max_cost: tc.and_then(|c| c.max_cost),
197            max_duration: tc.and_then(|c| c.max_duration),
198            allowed_tools: tc.and_then(|c| c.allowed_tools.clone()),
199            max_tokens: self.config.default_max_tokens,
200        }
201    }
202
203    fn build_tool_schemas(&self, config: &ResolvedConfig) -> Vec<ToolSchema> {
204        let mut schemas: Vec<ToolSchema> = self
205            .tools
206            .iter()
207            .map(|tool| ToolSchema {
208                name: tool.name().to_string(),
209                description: tool.description().to_string(),
210                input_schema: tool.input_schema(),
211            })
212            .collect();
213
214        // Add effect tool schemas
215        schemas.extend(effect_tool_schemas());
216
217        // Filter by allowed_tools if specified
218        if let Some(allowed) = &config.allowed_tools {
219            schemas.retain(|s| allowed.contains(&s.name));
220        }
221
222        schemas
223    }
224
225    async fn assemble_context(
226        &self,
227        input: &OperatorInput,
228    ) -> Result<Vec<ProviderMessage>, OperatorError> {
229        let mut messages = Vec::new();
230
231        // Read history from state if session is present
232        if let Some(session) = &input.session {
233            let scope = Scope::Session(session.clone());
234            match self.state_reader.read(&scope, "messages").await {
235                Ok(Some(history)) => {
236                    if let Ok(history_messages) =
237                        serde_json::from_value::<Vec<ProviderMessage>>(history)
238                    {
239                        messages = history_messages;
240                    }
241                }
242                Ok(None) => {} // No history yet
243                Err(_) => {}   // State read errors are non-fatal
244            }
245        }
246
247        // Add the new user message
248        messages.push(content_to_user_message(&input.message));
249
250        Ok(messages)
251    }
252
253    fn try_as_effect(&self, name: &str, input: &serde_json::Value) -> Option<Effect> {
254        match name {
255            "write_memory" => {
256                let scope_str = input.get("scope")?.as_str()?;
257                let key = input.get("key")?.as_str()?.to_string();
258                let value = input.get("value")?.clone();
259                let scope = parse_scope(scope_str);
260                Some(Effect::WriteMemory { scope, key, value })
261            }
262            "delete_memory" => {
263                let scope_str = input.get("scope")?.as_str()?;
264                let key = input.get("key")?.as_str()?.to_string();
265                let scope = parse_scope(scope_str);
266                Some(Effect::DeleteMemory { scope, key })
267            }
268            "delegate" => {
269                let agent = input.get("agent")?.as_str()?;
270                let message = input.get("message").and_then(|m| m.as_str()).unwrap_or("");
271                let delegate_input =
272                    OperatorInput::new(Content::text(message), layer0::operator::TriggerType::Task);
273                Some(Effect::Delegate {
274                    agent: AgentId::new(agent),
275                    input: Box::new(delegate_input),
276                })
277            }
278            "handoff" => {
279                let agent = input.get("agent")?.as_str()?;
280                let state = input
281                    .get("state")
282                    .cloned()
283                    .unwrap_or(serde_json::Value::Null);
284                Some(Effect::Handoff {
285                    agent: AgentId::new(agent),
286                    state,
287                })
288            }
289            "signal" => {
290                let target = input.get("target")?.as_str()?;
291                let signal_type = input
292                    .get("signal_type")
293                    .and_then(|s| s.as_str())
294                    .unwrap_or("default");
295                let data = input
296                    .get("data")
297                    .cloned()
298                    .unwrap_or(serde_json::Value::Null);
299                Some(Effect::Signal {
300                    target: WorkflowId::new(target),
301                    payload: SignalPayload::new(signal_type, data),
302                })
303            }
304            _ => None,
305        }
306    }
307
308    fn build_metadata(
309        &self,
310        tokens_in: u64,
311        tokens_out: u64,
312        cost: Decimal,
313        turns_used: u32,
314        tools_called: Vec<ToolCallRecord>,
315        duration: DurationMs,
316    ) -> OperatorMetadata {
317        let mut meta = OperatorMetadata::default();
318        meta.tokens_in = tokens_in;
319        meta.tokens_out = tokens_out;
320        meta.cost = cost;
321        meta.turns_used = turns_used;
322        meta.tools_called = tools_called;
323        meta.duration = duration;
324        meta
325    }
326
327    fn make_output(
328        message: Content,
329        exit_reason: ExitReason,
330        metadata: OperatorMetadata,
331        effects: Vec<Effect>,
332    ) -> OperatorOutput {
333        let mut output = OperatorOutput::new(message, exit_reason);
334        output.metadata = metadata;
335        output.effects = effects;
336        output
337    }
338
339    fn build_hook_context(
340        &self,
341        point: HookPoint,
342        tokens_in: u64,
343        tokens_out: u64,
344        cost: Decimal,
345        turns_completed: u32,
346        elapsed: DurationMs,
347    ) -> HookContext {
348        let mut ctx = HookContext::new(point);
349        ctx.tokens_used = tokens_in + tokens_out;
350        ctx.cost = cost;
351        ctx.turns_completed = turns_completed;
352        ctx.elapsed = elapsed;
353        ctx
354    }
355}
356
357#[async_trait]
358impl<P: Provider + 'static> Operator for ReactOperator<P> {
359    async fn execute(&self, input: OperatorInput) -> Result<OperatorOutput, OperatorError> {
360        let start = Instant::now();
361        let config = self.resolve_config(&input);
362        let mut messages = self.assemble_context(&input).await?;
363        let tools = self.build_tool_schemas(&config);
364
365        let mut total_tokens_in: u64 = 0;
366        let mut total_tokens_out: u64 = 0;
367        let mut total_cost = Decimal::ZERO;
368        let mut turns_used: u32 = 0;
369        let mut tool_records: Vec<ToolCallRecord> = vec![];
370        let mut effects: Vec<Effect> = vec![];
371        let mut last_content: Vec<ContentPart> = vec![];
372
373        loop {
374            turns_used += 1;
375
376            // 1. Hook: PreInference
377            let hook_ctx = self.build_hook_context(
378                HookPoint::PreInference,
379                total_tokens_in,
380                total_tokens_out,
381                total_cost,
382                turns_used - 1,
383                DurationMs::from(start.elapsed()),
384            );
385            if let HookAction::Halt { reason } = self.hooks.dispatch(&hook_ctx).await {
386                return Ok(Self::make_output(
387                    parts_to_content(&last_content),
388                    ExitReason::ObserverHalt { reason },
389                    self.build_metadata(
390                        total_tokens_in,
391                        total_tokens_out,
392                        total_cost,
393                        turns_used,
394                        tool_records,
395                        DurationMs::from(start.elapsed()),
396                    ),
397                    effects,
398                ));
399            }
400
401            // 2. Build ProviderRequest
402            let request = ProviderRequest {
403                model: config.model.clone(),
404                messages: messages.clone(),
405                tools: tools.clone(),
406                max_tokens: Some(config.max_tokens),
407                temperature: None,
408                system: Some(config.system.clone()),
409                extra: input.metadata.clone(),
410            };
411
412            // 3. Call provider
413            let response = self.provider.complete(request).await.map_err(|e| {
414                if e.is_retryable() {
415                    OperatorError::Retryable(e.to_string())
416                } else {
417                    OperatorError::Model(e.to_string())
418                }
419            })?;
420
421            // 4. Hook: PostInference
422            let mut hook_ctx = self.build_hook_context(
423                HookPoint::PostInference,
424                total_tokens_in + response.usage.input_tokens,
425                total_tokens_out + response.usage.output_tokens,
426                total_cost + response.cost.unwrap_or(Decimal::ZERO),
427                turns_used,
428                DurationMs::from(start.elapsed()),
429            );
430            hook_ctx.model_output = Some(parts_to_content(&response.content));
431            if let HookAction::Halt { reason } = self.hooks.dispatch(&hook_ctx).await {
432                return Ok(Self::make_output(
433                    parts_to_content(&response.content),
434                    ExitReason::ObserverHalt { reason },
435                    self.build_metadata(
436                        total_tokens_in + response.usage.input_tokens,
437                        total_tokens_out + response.usage.output_tokens,
438                        total_cost + response.cost.unwrap_or(Decimal::ZERO),
439                        turns_used,
440                        tool_records,
441                        DurationMs::from(start.elapsed()),
442                    ),
443                    effects,
444                ));
445            }
446
447            // 5. Aggregate tokens + cost
448            total_tokens_in += response.usage.input_tokens;
449            total_tokens_out += response.usage.output_tokens;
450            if let Some(cost) = response.cost {
451                total_cost += cost;
452            }
453
454            last_content.clone_from(&response.content);
455
456            // 6. Check StopReason
457            match response.stop_reason {
458                StopReason::MaxTokens => {
459                    return Err(OperatorError::Model("output truncated (max_tokens)".into()));
460                }
461                StopReason::ContentFilter => {
462                    return Err(OperatorError::Model("content filtered".into()));
463                }
464                StopReason::EndTurn => {
465                    return Ok(Self::make_output(
466                        parts_to_content(&response.content),
467                        ExitReason::Complete,
468                        self.build_metadata(
469                            total_tokens_in,
470                            total_tokens_out,
471                            total_cost,
472                            turns_used,
473                            tool_records,
474                            DurationMs::from(start.elapsed()),
475                        ),
476                        effects,
477                    ));
478                }
479                StopReason::ToolUse => {
480                    // Continue to tool execution below
481                }
482            }
483
484            // 7. Tool execution
485            // Add assistant message to context
486            messages.push(ProviderMessage {
487                role: Role::Assistant,
488                content: response.content.clone(),
489            });
490
491            let _tool_uses: Vec<(String, String, serde_json::Value)> = response
492                .content
493                .iter()
494                .filter_map(|part| match part {
495                    ContentPart::ToolUse { id, name, input } => {
496                        Some((id.clone(), name.clone(), input.clone()))
497                    }
498                    _ => None,
499                })
500                .collect();
501            let mut tool_results: Vec<ContentPart> = Vec::new();
502            // Use planner to decide batches. Build (id,name,input) vector first.
503            let planned = {
504                let calls: Vec<(String, String, serde_json::Value)> = response
505                    .content
506                    .iter()
507                    .filter_map(|part| match part {
508                        ContentPart::ToolUse { id, name, input } => {
509                            Some((id.clone(), name.clone(), input.clone()))
510                        }
511                        _ => None,
512                    })
513                    .collect();
514                self.planner.plan(&calls, self.decider.as_ref())
515            };
516
517            let mut _steered = false;
518            'batches: for batch in planned {
519                match batch {
520                    BatchItem::Shared(call_group) => {
521                        // Pre-batch steering poll
522                        if let Some(s) = &self.steering {
523                            let injected = s.drain();
524                            if !injected.is_empty() {
525                                messages.extend(injected);
526                                // All tools in this batch are skipped with placeholders
527                                for (id, name, _input) in call_group.into_iter() {
528                                    tool_results.push(ContentPart::ToolResult {
529                                        tool_use_id: id,
530                                        content: "Skipped due to steering".into(),
531                                        is_error: false,
532                                    });
533                                    tool_records.push(ToolCallRecord::new(
534                                        &name,
535                                        DurationMs::ZERO,
536                                        false,
537                                    ));
538                                }
539                                _steered = true;
540                                break 'batches;
541                            }
542                        }
543                        // Execute shared tools sequentially to allow steering to interrupt mid-batch
544                        let len = call_group.len();
545                        for idx in 0..len {
546                            // Pre-next-tool steering poll (after some tools completed)
547                            if idx > 0
548                                && let Some(s) = &self.steering
549                            {
550                                let injected = s.drain();
551                                if !injected.is_empty() {
552                                    messages.extend(injected);
553                                    for (rid, rname, _rinput) in
554                                        call_group.iter().skip(idx).cloned()
555                                    {
556                                        tool_results.push(ContentPart::ToolResult {
557                                            tool_use_id: rid,
558                                            content: "Skipped due to steering".into(),
559                                            is_error: false,
560                                        });
561                                        tool_records.push(ToolCallRecord::new(
562                                            &rname,
563                                            DurationMs::ZERO,
564                                            false,
565                                        ));
566                                    }
567                                    _steered = true;
568                                    _steered = true;
569                                }
570                            }
571                            let (id, name, tool_input) = call_group[idx].clone();
572                            // Effects handled immediately
573                            if EFFECT_TOOL_NAMES.contains(&name.as_str()) {
574                                if let Some(effect) = self.try_as_effect(&name, &tool_input) {
575                                    effects.push(effect);
576                                }
577                                tool_results.push(ContentPart::ToolResult {
578                                    tool_use_id: id,
579                                    content: format!("{name} effect recorded."),
580                                    is_error: false,
581                                });
582                                tool_records.push(ToolCallRecord::new(
583                                    &name,
584                                    DurationMs::ZERO,
585                                    true,
586                                ));
587                            } else {
588                                // Hook: PreToolUse
589                                let mut actual_input = tool_input.clone();
590                                let mut hook_ctx = HookContext::new(HookPoint::PreToolUse);
591                                hook_ctx.tool_name = Some(name.clone());
592                                hook_ctx.tool_input = Some(tool_input.clone());
593                                hook_ctx.tokens_used = total_tokens_in + total_tokens_out;
594                                hook_ctx.cost = total_cost;
595                                hook_ctx.turns_completed = turns_used;
596                                hook_ctx.elapsed = DurationMs::from(start.elapsed());
597                                match self.hooks.dispatch(&hook_ctx).await {
598                                    HookAction::Halt { reason } => {
599                                        return Ok(Self::make_output(
600                                            parts_to_content(&last_content),
601                                            ExitReason::ObserverHalt { reason },
602                                            self.build_metadata(
603                                                total_tokens_in,
604                                                total_tokens_out,
605                                                total_cost,
606                                                turns_used,
607                                                tool_records,
608                                                DurationMs::from(start.elapsed()),
609                                            ),
610                                            effects,
611                                        ));
612                                    }
613                                    HookAction::SkipTool { reason } => {
614                                        tool_results.push(ContentPart::ToolResult {
615                                            tool_use_id: id,
616                                            content: format!("Skipped: {reason}"),
617                                            is_error: false,
618                                        });
619                                        tool_records.push(ToolCallRecord::new(
620                                            &name,
621                                            DurationMs::ZERO,
622                                            false,
623                                        ));
624                                        continue;
625                                    }
626                                    HookAction::ModifyToolInput { new_input } => {
627                                        actual_input = new_input;
628                                    }
629                                    HookAction::Continue => {}
630                                    _ => {}
631                                }
632                                // Execute tool (streaming if supported)
633                                let tool_start = Instant::now();
634                                // Defaults for non-streaming path
635                                let (mut result_content, is_error, success, duration) = match self
636                                    .tools
637                                    .get(&name)
638                                {
639                                    Some(tool) => {
640                                        if let Some(stream) = tool.maybe_streaming() {
641                                            // Collect chunks during streaming
642                                            let chunks_arc =
643                                                std::sync::Arc::new(std::sync::Mutex::new(Vec::<
644                                                    String,
645                                                >::new(
646                                                )));
647                                            let chunks_cb = chunks_arc.clone();
648                                            let res = stream
649                                                .call_streaming(
650                                                    actual_input.clone(),
651                                                    Box::new(move |c: &str| {
652                                                        if let Ok(mut v) = chunks_cb.lock() {
653                                                            v.push(c.to_string());
654                                                        }
655                                                    }),
656                                                )
657                                                .await;
658                                            let tool_duration =
659                                                DurationMs::from(tool_start.elapsed());
660                                            // Dispatch chunk updates in order, ignoring actions/errors
661                                            if let Ok(chunks) =
662                                                std::sync::Arc::try_unwrap(chunks_arc)
663                                                    .map(|m| m.into_inner().unwrap())
664                                            {
665                                                for ch in &chunks {
666                                                    let mut uctx = HookContext::new(
667                                                        HookPoint::ToolExecutionUpdate,
668                                                    );
669                                                    uctx.tool_name = Some(name.clone());
670                                                    uctx.tool_chunk = Some(ch.clone());
671                                                    uctx.tokens_used =
672                                                        total_tokens_in + total_tokens_out;
673                                                    uctx.cost = total_cost;
674                                                    uctx.turns_completed = turns_used;
675                                                    uctx.elapsed =
676                                                        DurationMs::from(start.elapsed());
677                                                    let _ = self.hooks.dispatch(&uctx).await;
678                                                }
679                                                match res {
680                                                    Ok(()) => (
681                                                        chunks.concat(),
682                                                        false,
683                                                        true,
684                                                        tool_duration,
685                                                    ),
686                                                    Err(e) => {
687                                                        (e.to_string(), true, false, tool_duration)
688                                                    }
689                                                }
690                                            } else {
691                                                // Fallback if Arc could not be unwrapped
692                                                match res {
693                                                    Ok(()) => {
694                                                        (String::new(), false, true, tool_duration)
695                                                    }
696                                                    Err(e) => {
697                                                        (e.to_string(), true, false, tool_duration)
698                                                    }
699                                                }
700                                            }
701                                        } else {
702                                            // Non-streaming
703                                            match tool.call(actual_input.clone()).await {
704                                                Ok(value) => (
705                                                    serde_json::to_string(&value)
706                                                        .unwrap_or_default(),
707                                                    false,
708                                                    true,
709                                                    DurationMs::from(tool_start.elapsed()),
710                                                ),
711                                                Err(e) => (
712                                                    e.to_string(),
713                                                    true,
714                                                    false,
715                                                    DurationMs::from(tool_start.elapsed()),
716                                                ),
717                                            }
718                                        }
719                                    }
720                                    None => (
721                                        neuron_tool::ToolError::NotFound(name.clone()).to_string(),
722                                        true,
723                                        false,
724                                        DurationMs::from(tool_start.elapsed()),
725                                    ),
726                                };
727                                // PostToolUse hook
728                                let mut hook_ctx = HookContext::new(HookPoint::PostToolUse);
729                                hook_ctx.tool_name = Some(name.clone());
730                                hook_ctx.tool_result = Some(result_content.clone());
731                                hook_ctx.tokens_used = total_tokens_in + total_tokens_out;
732                                hook_ctx.cost = total_cost;
733                                hook_ctx.turns_completed = turns_used;
734                                hook_ctx.elapsed = DurationMs::from(start.elapsed());
735                                match self.hooks.dispatch(&hook_ctx).await {
736                                    HookAction::Halt { reason } => {
737                                        return Ok(Self::make_output(
738                                            parts_to_content(&last_content),
739                                            ExitReason::ObserverHalt { reason },
740                                            self.build_metadata(
741                                                total_tokens_in,
742                                                total_tokens_out,
743                                                total_cost,
744                                                turns_used,
745                                                tool_records,
746                                                DurationMs::from(start.elapsed()),
747                                            ),
748                                            effects,
749                                        ));
750                                    }
751                                    HookAction::ModifyToolOutput { new_output } => {
752                                        result_content = new_output.to_string();
753                                    }
754                                    _ => {}
755                                }
756                                tool_results.push(ContentPart::ToolResult {
757                                    tool_use_id: id,
758                                    content: result_content,
759                                    is_error,
760                                });
761                                tool_records.push(ToolCallRecord::new(name, duration, success));
762                            }
763                            // Mid-batch steering poll — skip remaining tools in this batch
764                            if let Some(s) = &self.steering {
765                                let injected = s.drain();
766                                if !injected.is_empty() {
767                                    messages.extend(injected);
768                                }
769                                if idx + 1 < len {
770                                    for (rid, rname, _rinput) in
771                                        call_group.iter().skip(idx + 1).cloned()
772                                    {
773                                        tool_results.push(ContentPart::ToolResult {
774                                            tool_use_id: rid,
775                                            content: "Skipped due to steering".into(),
776                                            is_error: false,
777                                        });
778                                        tool_records.push(ToolCallRecord::new(
779                                            &rname,
780                                            DurationMs::ZERO,
781                                            false,
782                                        ));
783                                    }
784                                    break 'batches;
785                                }
786                            }
787                        }
788                        // Post-batch steering poll
789                        if let Some(s) = &self.steering {
790                            let injected = s.drain();
791                            if !injected.is_empty() {
792                                messages.extend(injected);
793                                _steered = true;
794                                break 'batches;
795                            }
796                        }
797                    }
798                    BatchItem::Exclusive((id, name, tool_input)) => {
799                        // Pre-exclusive steering poll
800                        if let Some(s) = &self.steering {
801                            let injected = s.drain();
802                            if !injected.is_empty() {
803                                messages.extend(injected);
804                                tool_results.push(ContentPart::ToolResult {
805                                    tool_use_id: id,
806                                    content: "Skipped due to steering".into(),
807                                    is_error: false,
808                                });
809                                tool_records.push(ToolCallRecord::new(
810                                    &name,
811                                    DurationMs::ZERO,
812                                    false,
813                                ));
814                                _steered = true;
815                                break 'batches;
816                            }
817                        }
818                        if EFFECT_TOOL_NAMES.contains(&name.as_str()) {
819                            if let Some(effect) = self.try_as_effect(&name, &tool_input) {
820                                effects.push(effect);
821                            }
822                            tool_results.push(ContentPart::ToolResult {
823                                tool_use_id: id,
824                                content: format!("{name} effect recorded."),
825                                is_error: false,
826                            });
827                            tool_records.push(ToolCallRecord::new(&name, DurationMs::ZERO, true));
828                            continue;
829                        }
830                        let mut actual_input = tool_input.clone();
831                        let mut hook_ctx = HookContext::new(HookPoint::PreToolUse);
832                        hook_ctx.tool_name = Some(name.clone());
833                        hook_ctx.tool_input = Some(tool_input.clone());
834                        hook_ctx.tokens_used = total_tokens_in + total_tokens_out;
835                        hook_ctx.cost = total_cost;
836                        hook_ctx.turns_completed = turns_used;
837                        hook_ctx.elapsed = DurationMs::from(start.elapsed());
838                        match self.hooks.dispatch(&hook_ctx).await {
839                            HookAction::Halt { reason } => {
840                                return Ok(Self::make_output(
841                                    parts_to_content(&last_content),
842                                    ExitReason::ObserverHalt { reason },
843                                    self.build_metadata(
844                                        total_tokens_in,
845                                        total_tokens_out,
846                                        total_cost,
847                                        turns_used,
848                                        tool_records,
849                                        DurationMs::from(start.elapsed()),
850                                    ),
851                                    effects,
852                                ));
853                            }
854                            HookAction::SkipTool { reason } => {
855                                tool_results.push(ContentPart::ToolResult {
856                                    tool_use_id: id,
857                                    content: format!("Skipped: {reason}"),
858                                    is_error: false,
859                                });
860                                tool_records.push(ToolCallRecord::new(
861                                    &name,
862                                    DurationMs::ZERO,
863                                    false,
864                                ));
865                                continue;
866                            }
867                            HookAction::ModifyToolInput { new_input } => {
868                                actual_input = new_input;
869                            }
870                            HookAction::Continue => {}
871                            _ => {}
872                        }
873                        let tool_start = Instant::now();
874                        // Execute tool (streaming if supported)
875                        let (mut result_content, is_error, success, tool_duration) = match self
876                            .tools
877                            .get(&name)
878                        {
879                            Some(tool) => {
880                                if let Some(stream) = tool.maybe_streaming() {
881                                    let chunks_arc = std::sync::Arc::new(std::sync::Mutex::new(
882                                        Vec::<String>::new(),
883                                    ));
884                                    let chunks_cb = chunks_arc.clone();
885                                    let res = stream
886                                        .call_streaming(
887                                            actual_input.clone(),
888                                            Box::new(move |c: &str| {
889                                                if let Ok(mut v) = chunks_cb.lock() {
890                                                    v.push(c.to_string());
891                                                }
892                                            }),
893                                        )
894                                        .await;
895                                    let dur = DurationMs::from(tool_start.elapsed());
896                                    if let Ok(chunks) = std::sync::Arc::try_unwrap(chunks_arc)
897                                        .map(|m| m.into_inner().unwrap())
898                                    {
899                                        for ch in &chunks {
900                                            let mut uctx =
901                                                HookContext::new(HookPoint::ToolExecutionUpdate);
902                                            uctx.tool_name = Some(name.clone());
903                                            uctx.tool_chunk = Some(ch.clone());
904                                            uctx.tokens_used = total_tokens_in + total_tokens_out;
905                                            uctx.cost = total_cost;
906                                            uctx.turns_completed = turns_used;
907                                            uctx.elapsed = DurationMs::from(start.elapsed());
908                                            let _ = self.hooks.dispatch(&uctx).await;
909                                        }
910                                        match res {
911                                            Ok(()) => (chunks.concat(), false, true, dur),
912                                            Err(e) => (e.to_string(), true, false, dur),
913                                        }
914                                    } else {
915                                        match res {
916                                            Ok(()) => (String::new(), false, true, dur),
917                                            Err(e) => (e.to_string(), true, false, dur),
918                                        }
919                                    }
920                                } else {
921                                    match tool.call(actual_input.clone()).await {
922                                        Ok(value) => (
923                                            serde_json::to_string(&value).unwrap_or_default(),
924                                            false,
925                                            true,
926                                            DurationMs::from(tool_start.elapsed()),
927                                        ),
928                                        Err(e) => (
929                                            e.to_string(),
930                                            true,
931                                            false,
932                                            DurationMs::from(tool_start.elapsed()),
933                                        ),
934                                    }
935                                }
936                            }
937                            None => (
938                                neuron_tool::ToolError::NotFound(name.clone()).to_string(),
939                                true,
940                                false,
941                                DurationMs::from(tool_start.elapsed()),
942                            ),
943                        };
944                        let mut hook_ctx = HookContext::new(HookPoint::PostToolUse);
945                        hook_ctx.tool_name = Some(name.clone());
946                        hook_ctx.tool_result = Some(result_content.clone());
947                        hook_ctx.tokens_used = total_tokens_in + total_tokens_out;
948                        hook_ctx.cost = total_cost;
949                        hook_ctx.turns_completed = turns_used;
950                        hook_ctx.elapsed = DurationMs::from(start.elapsed());
951                        match self.hooks.dispatch(&hook_ctx).await {
952                            HookAction::Halt { reason } => {
953                                return Ok(Self::make_output(
954                                    parts_to_content(&last_content),
955                                    ExitReason::ObserverHalt { reason },
956                                    self.build_metadata(
957                                        total_tokens_in,
958                                        total_tokens_out,
959                                        total_cost,
960                                        turns_used,
961                                        tool_records,
962                                        DurationMs::from(start.elapsed()),
963                                    ),
964                                    effects,
965                                ));
966                            }
967                            HookAction::ModifyToolOutput { new_output } => {
968                                result_content = new_output.to_string();
969                            }
970                            _ => {}
971                        }
972                        tool_results.push(ContentPart::ToolResult {
973                            tool_use_id: id,
974                            content: result_content,
975                            is_error,
976                        });
977                        tool_records.push(ToolCallRecord::new(name, tool_duration, success));
978                        // Post-exclusive steering poll
979                        if let Some(s) = &self.steering {
980                            let injected = s.drain();
981                            if !injected.is_empty() {
982                                messages.extend(injected);
983                                _steered = true;
984                                break 'batches;
985                            }
986                        }
987                    }
988                }
989            }
990
991            // Add tool results as user message
992            messages.push(ProviderMessage {
993                role: Role::User,
994                content: tool_results,
995            });
996
997            // 8. Check limits
998            if turns_used >= config.max_turns {
999                return Ok(Self::make_output(
1000                    parts_to_content(&last_content),
1001                    ExitReason::MaxTurns,
1002                    self.build_metadata(
1003                        total_tokens_in,
1004                        total_tokens_out,
1005                        total_cost,
1006                        turns_used,
1007                        tool_records,
1008                        DurationMs::from(start.elapsed()),
1009                    ),
1010                    effects,
1011                ));
1012            }
1013
1014            if let Some(max_cost) = &config.max_cost
1015                && total_cost >= *max_cost
1016            {
1017                return Ok(Self::make_output(
1018                    parts_to_content(&last_content),
1019                    ExitReason::BudgetExhausted,
1020                    self.build_metadata(
1021                        total_tokens_in,
1022                        total_tokens_out,
1023                        total_cost,
1024                        turns_used,
1025                        tool_records,
1026                        DurationMs::from(start.elapsed()),
1027                    ),
1028                    effects,
1029                ));
1030            }
1031
1032            if let Some(max_duration) = &config.max_duration
1033                && start.elapsed() >= max_duration.to_std()
1034            {
1035                return Ok(Self::make_output(
1036                    parts_to_content(&last_content),
1037                    ExitReason::Timeout,
1038                    self.build_metadata(
1039                        total_tokens_in,
1040                        total_tokens_out,
1041                        total_cost,
1042                        turns_used,
1043                        tool_records,
1044                        DurationMs::from(start.elapsed()),
1045                    ),
1046                    effects,
1047                ));
1048            }
1049
1050            // 9. Hook: ExitCheck
1051            let hook_ctx = self.build_hook_context(
1052                HookPoint::ExitCheck,
1053                total_tokens_in,
1054                total_tokens_out,
1055                total_cost,
1056                turns_used,
1057                DurationMs::from(start.elapsed()),
1058            );
1059            if let HookAction::Halt { reason } = self.hooks.dispatch(&hook_ctx).await {
1060                return Ok(Self::make_output(
1061                    parts_to_content(&last_content),
1062                    ExitReason::ObserverHalt { reason },
1063                    self.build_metadata(
1064                        total_tokens_in,
1065                        total_tokens_out,
1066                        total_cost,
1067                        turns_used,
1068                        tool_records,
1069                        DurationMs::from(start.elapsed()),
1070                    ),
1071                    effects,
1072                ));
1073            }
1074
1075            // 10. Context compaction
1076            let limit = config.max_tokens as usize * 4;
1077            if self.context_strategy.should_compact(&messages, limit) {
1078                messages = self.context_strategy.compact(messages);
1079            }
1080
1081            // 11. Loop repeats
1082        }
1083    }
1084}
1085
1086/// Schemas for effect tools that the model can call.
1087fn effect_tool_schemas() -> Vec<ToolSchema> {
1088    vec![
1089        ToolSchema {
1090            name: "write_memory".into(),
1091            description: "Write a value to persistent memory.".into(),
1092            input_schema: serde_json::json!({
1093                "type": "object",
1094                "properties": {
1095                    "scope": {"type": "string", "description": "Memory scope (e.g. 'global', 'session:id')"},
1096                    "key": {"type": "string", "description": "Memory key"},
1097                    "value": {"description": "Value to store"}
1098                },
1099                "required": ["scope", "key", "value"]
1100            }),
1101        },
1102        ToolSchema {
1103            name: "delete_memory".into(),
1104            description: "Delete a value from persistent memory.".into(),
1105            input_schema: serde_json::json!({
1106                "type": "object",
1107                "properties": {
1108                    "scope": {"type": "string", "description": "Memory scope"},
1109                    "key": {"type": "string", "description": "Memory key"}
1110                },
1111                "required": ["scope", "key"]
1112            }),
1113        },
1114        ToolSchema {
1115            name: "delegate".into(),
1116            description: "Delegate a task to another agent.".into(),
1117            input_schema: serde_json::json!({
1118                "type": "object",
1119                "properties": {
1120                    "agent": {"type": "string", "description": "Agent ID to delegate to"},
1121                    "message": {"type": "string", "description": "Task description for the agent"}
1122                },
1123                "required": ["agent", "message"]
1124            }),
1125        },
1126        ToolSchema {
1127            name: "handoff".into(),
1128            description: "Hand off the conversation to another agent.".into(),
1129            input_schema: serde_json::json!({
1130                "type": "object",
1131                "properties": {
1132                    "agent": {"type": "string", "description": "Agent ID to hand off to"},
1133                    "state": {"description": "State to pass to the next agent"}
1134                },
1135                "required": ["agent"]
1136            }),
1137        },
1138        ToolSchema {
1139            name: "signal".into(),
1140            description: "Send a signal to another workflow.".into(),
1141            input_schema: serde_json::json!({
1142                "type": "object",
1143                "properties": {
1144                    "target": {"type": "string", "description": "Target workflow ID"},
1145                    "signal_type": {"type": "string", "description": "Signal type identifier"},
1146                    "data": {"description": "Signal payload data"}
1147                },
1148                "required": ["target"]
1149            }),
1150        },
1151    ]
1152}
1153
1154/// Parse a scope string into a layer0 Scope.
1155fn parse_scope(s: &str) -> Scope {
1156    if s == "global" {
1157        return Scope::Global;
1158    }
1159    if let Some(id) = s.strip_prefix("session:") {
1160        return Scope::Session(layer0::SessionId::new(id));
1161    }
1162    if let Some(id) = s.strip_prefix("workflow:") {
1163        return Scope::Workflow(layer0::WorkflowId::new(id));
1164    }
1165    Scope::Custom(s.to_string())
1166}
1167
1168#[cfg(test)]
1169mod tests {
1170    use super::*;
1171    use neuron_hooks::HookRegistry;
1172    use neuron_tool::ToolRegistry;
1173    use neuron_turn::context::NoCompaction;
1174    use neuron_turn::provider::ProviderError;
1175    use serde_json::json;
1176    use std::collections::VecDeque;
1177    use std::sync::Mutex;
1178    use std::sync::atomic::{AtomicUsize, Ordering};
1179
1180    // -- Mock Provider --
1181
1182    struct MockProvider {
1183        responses: Mutex<VecDeque<ProviderResponse>>,
1184        call_count: AtomicUsize,
1185    }
1186
1187    impl MockProvider {
1188        fn new(responses: Vec<ProviderResponse>) -> Self {
1189            Self {
1190                responses: Mutex::new(responses.into()),
1191                call_count: AtomicUsize::new(0),
1192            }
1193        }
1194    }
1195
1196    impl Provider for MockProvider {
1197        fn complete(
1198            &self,
1199            _request: ProviderRequest,
1200        ) -> impl std::future::Future<Output = Result<ProviderResponse, ProviderError>> + Send
1201        {
1202            self.call_count.fetch_add(1, Ordering::SeqCst);
1203            let response = self
1204                .responses
1205                .lock()
1206                .unwrap()
1207                .pop_front()
1208                .expect("MockProvider: no more responses queued");
1209            async move { Ok(response) }
1210        }
1211    }
1212
1213    // -- Mock StateReader --
1214
1215    struct NullStateReader;
1216
1217    #[async_trait]
1218    impl layer0::StateReader for NullStateReader {
1219        async fn read(
1220            &self,
1221            _scope: &Scope,
1222            _key: &str,
1223        ) -> Result<Option<serde_json::Value>, layer0::StateError> {
1224            Ok(None)
1225        }
1226        async fn list(
1227            &self,
1228            _scope: &Scope,
1229            _prefix: &str,
1230        ) -> Result<Vec<String>, layer0::StateError> {
1231            Ok(vec![])
1232        }
1233        async fn search(
1234            &self,
1235            _scope: &Scope,
1236            _query: &str,
1237            _limit: usize,
1238        ) -> Result<Vec<layer0::state::SearchResult>, layer0::StateError> {
1239            Ok(vec![])
1240        }
1241    }
1242
1243    // -- Mock Tool --
1244
1245    struct EchoTool;
1246
1247    impl neuron_tool::ToolDyn for EchoTool {
1248        fn name(&self) -> &str {
1249            "echo"
1250        }
1251        fn description(&self) -> &str {
1252            "Echoes input"
1253        }
1254        fn input_schema(&self) -> serde_json::Value {
1255            json!({"type": "object"})
1256        }
1257        fn call(
1258            &self,
1259            input: serde_json::Value,
1260        ) -> std::pin::Pin<
1261            Box<
1262                dyn std::future::Future<Output = Result<serde_json::Value, neuron_tool::ToolError>>
1263                    + Send
1264                    + '_,
1265            >,
1266        > {
1267            Box::pin(async move { Ok(json!({"echoed": input})) })
1268        }
1269    }
1270
1271    // -- Helpers --
1272
1273    fn simple_text_response(text: &str) -> ProviderResponse {
1274        ProviderResponse {
1275            content: vec![ContentPart::Text {
1276                text: text.to_string(),
1277            }],
1278            stop_reason: StopReason::EndTurn,
1279            usage: TokenUsage {
1280                input_tokens: 10,
1281                output_tokens: 5,
1282                ..Default::default()
1283            },
1284            model: "mock-model".into(),
1285            cost: Some(Decimal::new(1, 4)), // $0.0001
1286            truncated: None,
1287        }
1288    }
1289
1290    fn tool_use_response(
1291        tool_id: &str,
1292        tool_name: &str,
1293        input: serde_json::Value,
1294    ) -> ProviderResponse {
1295        ProviderResponse {
1296            content: vec![ContentPart::ToolUse {
1297                id: tool_id.to_string(),
1298                name: tool_name.to_string(),
1299                input,
1300            }],
1301            stop_reason: StopReason::ToolUse,
1302            usage: TokenUsage {
1303                input_tokens: 10,
1304                output_tokens: 15,
1305                ..Default::default()
1306            },
1307            model: "mock-model".into(),
1308            cost: Some(Decimal::new(2, 4)), // $0.0002
1309            truncated: None,
1310        }
1311    }
1312
1313    fn make_op<P: Provider>(provider: P) -> ReactOperator<P> {
1314        ReactOperator::new(
1315            provider,
1316            ToolRegistry::new(),
1317            Box::new(NoCompaction),
1318            HookRegistry::new(),
1319            Arc::new(NullStateReader),
1320            ReactConfig::default(),
1321        )
1322    }
1323
1324    fn make_op_with_tools<P: Provider>(provider: P, tools: ToolRegistry) -> ReactOperator<P> {
1325        ReactOperator::new(
1326            provider,
1327            tools,
1328            Box::new(NoCompaction),
1329            HookRegistry::new(),
1330            Arc::new(NullStateReader),
1331            ReactConfig::default(),
1332        )
1333    }
1334
1335    fn simple_input(text: &str) -> OperatorInput {
1336        OperatorInput::new(Content::text(text), layer0::operator::TriggerType::User)
1337    }
1338
1339    // -- Tests --
1340
1341    #[tokio::test]
1342    async fn simple_completion() {
1343        let provider = MockProvider::new(vec![simple_text_response("Hello!")]);
1344        let op = make_op(provider);
1345
1346        let output = op.execute(simple_input("Hi")).await.unwrap();
1347
1348        assert_eq!(output.exit_reason, ExitReason::Complete);
1349        assert_eq!(output.message.as_text().unwrap(), "Hello!");
1350        assert_eq!(output.metadata.turns_used, 1);
1351        assert_eq!(output.metadata.tokens_in, 10);
1352        assert_eq!(output.metadata.tokens_out, 5);
1353        assert!(output.effects.is_empty());
1354    }
1355
1356    #[tokio::test]
1357    async fn tool_use_and_followup() {
1358        let provider = MockProvider::new(vec![
1359            tool_use_response("tu_1", "echo", json!({"msg": "test"})),
1360            simple_text_response("Done."),
1361        ]);
1362        let mut tools = ToolRegistry::new();
1363        tools.register(Arc::new(EchoTool));
1364        let op = make_op_with_tools(provider, tools);
1365
1366        let output = op.execute(simple_input("Use echo")).await.unwrap();
1367
1368        assert_eq!(output.exit_reason, ExitReason::Complete);
1369        assert_eq!(output.metadata.turns_used, 2);
1370        assert_eq!(output.metadata.tools_called.len(), 1);
1371        assert_eq!(output.metadata.tools_called[0].name, "echo");
1372    }
1373
1374    #[tokio::test]
1375    async fn unknown_tool_returns_error_result() {
1376        let provider = MockProvider::new(vec![
1377            tool_use_response("tu_1", "nonexistent_tool", json!({})),
1378            simple_text_response("Got an error."),
1379        ]);
1380        let op = make_op(provider);
1381
1382        // Should not panic — unknown tool produces an error result but loop continues
1383        let output = op.execute(simple_input("Use nonexistent")).await.unwrap();
1384        assert_eq!(output.exit_reason, ExitReason::Complete);
1385        // The tool call was recorded
1386        assert_eq!(output.metadata.tools_called.len(), 1);
1387    }
1388
1389    #[tokio::test]
1390    async fn max_turns_enforced() {
1391        // Provider always returns ToolUse — loop should hit max_turns limit
1392        let provider = MockProvider::new(vec![
1393            tool_use_response("tu_1", "echo", json!({})),
1394            tool_use_response("tu_2", "echo", json!({})),
1395            tool_use_response("tu_3", "echo", json!({})),
1396            simple_text_response("never reached"),
1397        ]);
1398        let mut tools = ToolRegistry::new();
1399        tools.register(Arc::new(EchoTool));
1400
1401        let mut op = ReactOperator::new(
1402            provider,
1403            tools,
1404            Box::new(NoCompaction),
1405            HookRegistry::new(),
1406            Arc::new(NullStateReader),
1407            ReactConfig {
1408                default_max_turns: 2,
1409                ..Default::default()
1410            },
1411        );
1412        // Avoid unused warning
1413        let _ = &mut op;
1414
1415        let op = ReactOperator::new(
1416            MockProvider::new(vec![
1417                tool_use_response("tu_1", "echo", json!({})),
1418                tool_use_response("tu_2", "echo", json!({})),
1419                simple_text_response("never reached"),
1420            ]),
1421            {
1422                let mut t = ToolRegistry::new();
1423                t.register(Arc::new(EchoTool));
1424                t
1425            },
1426            Box::new(NoCompaction),
1427            HookRegistry::new(),
1428            Arc::new(NullStateReader),
1429            ReactConfig {
1430                default_max_turns: 2,
1431                ..Default::default()
1432            },
1433        );
1434
1435        let output = op.execute(simple_input("loop")).await.unwrap();
1436        assert_eq!(output.exit_reason, ExitReason::MaxTurns);
1437        assert_eq!(output.metadata.turns_used, 2);
1438    }
1439
1440    #[tokio::test]
1441    async fn budget_exhausted() {
1442        // Two calls, each costing $0.0001, with max_cost = $0.00015
1443        let provider = MockProvider::new(vec![
1444            tool_use_response("tu_1", "echo", json!({})),
1445            simple_text_response("Done"),
1446        ]);
1447        let mut tools = ToolRegistry::new();
1448        tools.register(Arc::new(EchoTool));
1449        let op = ReactOperator::new(
1450            provider,
1451            tools,
1452            Box::new(NoCompaction),
1453            HookRegistry::new(),
1454            Arc::new(NullStateReader),
1455            ReactConfig::default(),
1456        );
1457
1458        let mut input = simple_input("spend");
1459        let mut tc = layer0::operator::OperatorConfig::default();
1460        tc.max_cost = Some(Decimal::new(15, 5)); // $0.00015
1461        input.config = Some(tc);
1462
1463        let output = op.execute(input).await.unwrap();
1464        // First call costs $0.0002 > $0.00015, so BudgetExhausted after second call
1465        assert_eq!(output.exit_reason, ExitReason::BudgetExhausted);
1466    }
1467
1468    #[tokio::test]
1469    async fn max_tokens_returns_model_error() {
1470        let provider = MockProvider::new(vec![ProviderResponse {
1471            content: vec![],
1472            stop_reason: StopReason::MaxTokens,
1473            usage: TokenUsage::default(),
1474            model: "mock".into(),
1475            cost: None,
1476            truncated: None,
1477        }]);
1478        let op = make_op(provider);
1479
1480        let result = op.execute(simple_input("Hi")).await;
1481        assert!(result.is_err());
1482        match result.unwrap_err() {
1483            OperatorError::Model(msg) => assert!(msg.contains("max_tokens")),
1484            other => panic!("expected OperatorError::Model, got {:?}", other),
1485        }
1486    }
1487
1488    #[tokio::test]
1489    async fn content_filter_returns_model_error() {
1490        let provider = MockProvider::new(vec![ProviderResponse {
1491            content: vec![],
1492            stop_reason: StopReason::ContentFilter,
1493            usage: TokenUsage::default(),
1494            model: "mock".into(),
1495            cost: None,
1496            truncated: None,
1497        }]);
1498        let op = make_op(provider);
1499
1500        let result = op.execute(simple_input("Hi")).await;
1501        assert!(result.is_err());
1502        match result.unwrap_err() {
1503            OperatorError::Model(msg) => assert!(msg.contains("content filtered")),
1504            other => panic!("expected OperatorError::Model, got {:?}", other),
1505        }
1506    }
1507
1508    #[tokio::test]
1509    async fn cost_aggregated_across_turns() {
1510        let provider = MockProvider::new(vec![
1511            tool_use_response("tu_1", "echo", json!({})),
1512            simple_text_response("Done"),
1513        ]);
1514        let mut tools = ToolRegistry::new();
1515        tools.register(Arc::new(EchoTool));
1516        let op = make_op_with_tools(provider, tools);
1517
1518        let output = op.execute(simple_input("Hi")).await.unwrap();
1519
1520        // First call: $0.0002, second call: $0.0001
1521        assert_eq!(output.metadata.cost, Decimal::new(3, 4));
1522        assert_eq!(output.metadata.tokens_in, 20);
1523        assert_eq!(output.metadata.tokens_out, 20);
1524    }
1525
1526    #[tokio::test]
1527    async fn operator_config_overrides_defaults() {
1528        let provider = MockProvider::new(vec![simple_text_response("Hi")]);
1529        let op = make_op(provider);
1530
1531        let mut input = simple_input("test");
1532        let mut tc = layer0::operator::OperatorConfig::default();
1533        tc.system_addendum = Some("Be concise.".into());
1534        tc.model = Some("custom-model".into());
1535        tc.max_turns = Some(5);
1536        input.config = Some(tc);
1537
1538        let output = op.execute(input).await.unwrap();
1539        assert_eq!(output.exit_reason, ExitReason::Complete);
1540    }
1541
1542    #[tokio::test]
1543    async fn effect_tool_write_memory() {
1544        let provider = MockProvider::new(vec![
1545            // Model calls write_memory
1546            ProviderResponse {
1547                content: vec![ContentPart::ToolUse {
1548                    id: "tu_1".into(),
1549                    name: "write_memory".into(),
1550                    input: json!({"scope": "global", "key": "test", "value": "hello"}),
1551                }],
1552                stop_reason: StopReason::ToolUse,
1553                usage: TokenUsage {
1554                    input_tokens: 10,
1555                    output_tokens: 5,
1556                    ..Default::default()
1557                },
1558                model: "mock".into(),
1559                cost: None,
1560                truncated: None,
1561            },
1562            simple_text_response("Memory written."),
1563        ]);
1564        let op = make_op(provider);
1565
1566        let output = op.execute(simple_input("Write memory")).await.unwrap();
1567
1568        assert_eq!(output.effects.len(), 1);
1569        match &output.effects[0] {
1570            Effect::WriteMemory { key, .. } => assert_eq!(key, "test"),
1571            _ => panic!("expected WriteMemory"),
1572        }
1573    }
1574
1575    #[test]
1576    fn parse_scope_variants() {
1577        assert_eq!(parse_scope("global"), Scope::Global);
1578        assert_eq!(
1579            parse_scope("session:abc"),
1580            Scope::Session(layer0::SessionId::new("abc"))
1581        );
1582        assert_eq!(
1583            parse_scope("workflow:wf1"),
1584            Scope::Workflow(layer0::WorkflowId::new("wf1"))
1585        );
1586        match parse_scope("other") {
1587            Scope::Custom(s) => assert_eq!(s, "other"),
1588            _ => panic!("expected Custom"),
1589        }
1590    }
1591
1592    #[tokio::test]
1593    async fn effect_tool_delete_memory() {
1594        let provider = MockProvider::new(vec![
1595            ProviderResponse {
1596                content: vec![ContentPart::ToolUse {
1597                    id: "tu_1".into(),
1598                    name: "delete_memory".into(),
1599                    input: json!({"scope": "global", "key": "old_key"}),
1600                }],
1601                stop_reason: StopReason::ToolUse,
1602                usage: TokenUsage::default(),
1603                model: "mock".into(),
1604                cost: None,
1605                truncated: None,
1606            },
1607            simple_text_response("Deleted."),
1608        ]);
1609        let op = make_op(provider);
1610
1611        let output = op.execute(simple_input("Delete memory")).await.unwrap();
1612        assert_eq!(output.effects.len(), 1);
1613        match &output.effects[0] {
1614            Effect::DeleteMemory { key, .. } => assert_eq!(key, "old_key"),
1615            _ => panic!("expected DeleteMemory"),
1616        }
1617    }
1618
1619    #[tokio::test]
1620    async fn effect_tool_delegate() {
1621        let provider = MockProvider::new(vec![
1622            ProviderResponse {
1623                content: vec![ContentPart::ToolUse {
1624                    id: "tu_1".into(),
1625                    name: "delegate".into(),
1626                    input: json!({"agent": "helper", "message": "do this task"}),
1627                }],
1628                stop_reason: StopReason::ToolUse,
1629                usage: TokenUsage::default(),
1630                model: "mock".into(),
1631                cost: None,
1632                truncated: None,
1633            },
1634            simple_text_response("Delegated."),
1635        ]);
1636        let op = make_op(provider);
1637
1638        let output = op.execute(simple_input("Delegate task")).await.unwrap();
1639        assert_eq!(output.effects.len(), 1);
1640        match &output.effects[0] {
1641            Effect::Delegate { agent, input } => {
1642                assert_eq!(agent.as_str(), "helper");
1643                assert_eq!(input.message.as_text().unwrap(), "do this task");
1644            }
1645            _ => panic!("expected Delegate"),
1646        }
1647    }
1648
1649    #[tokio::test]
1650    async fn effect_tool_handoff() {
1651        let provider = MockProvider::new(vec![
1652            ProviderResponse {
1653                content: vec![ContentPart::ToolUse {
1654                    id: "tu_1".into(),
1655                    name: "handoff".into(),
1656                    input: json!({"agent": "specialist", "state": {"context": "data"}}),
1657                }],
1658                stop_reason: StopReason::ToolUse,
1659                usage: TokenUsage::default(),
1660                model: "mock".into(),
1661                cost: None,
1662                truncated: None,
1663            },
1664            simple_text_response("Handed off."),
1665        ]);
1666        let op = make_op(provider);
1667
1668        let output = op.execute(simple_input("Handoff")).await.unwrap();
1669        assert_eq!(output.effects.len(), 1);
1670        match &output.effects[0] {
1671            Effect::Handoff { agent, state } => {
1672                assert_eq!(agent.as_str(), "specialist");
1673                assert_eq!(state["context"], "data");
1674            }
1675            _ => panic!("expected Handoff"),
1676        }
1677    }
1678
1679    #[tokio::test]
1680    async fn effect_tool_signal() {
1681        let provider = MockProvider::new(vec![
1682            ProviderResponse {
1683                content: vec![ContentPart::ToolUse {
1684                    id: "tu_1".into(),
1685                    name: "signal".into(),
1686                    input: json!({"target": "workflow_1", "signal_type": "completed", "data": {"result": "ok"}}),
1687                }],
1688                stop_reason: StopReason::ToolUse,
1689                usage: TokenUsage::default(),
1690                model: "mock".into(),
1691                cost: None,
1692                truncated: None,
1693            },
1694            simple_text_response("Signal sent."),
1695        ]);
1696        let op = make_op(provider);
1697
1698        let output = op.execute(simple_input("Signal")).await.unwrap();
1699        assert_eq!(output.effects.len(), 1);
1700        match &output.effects[0] {
1701            Effect::Signal { target, payload } => {
1702                assert_eq!(target.as_str(), "workflow_1");
1703                assert_eq!(payload.signal_type, "completed");
1704            }
1705            _ => panic!("expected Signal"),
1706        }
1707    }
1708
1709    #[test]
1710    fn effect_tool_schemas_all_present() {
1711        let schemas = effect_tool_schemas();
1712        let names: Vec<&str> = schemas.iter().map(|s| s.name.as_str()).collect();
1713        assert!(names.contains(&"write_memory"));
1714        assert!(names.contains(&"delete_memory"));
1715        assert!(names.contains(&"delegate"));
1716        assert!(names.contains(&"handoff"));
1717        assert!(names.contains(&"signal"));
1718        assert_eq!(schemas.len(), 5);
1719    }
1720
1721    #[test]
1722    fn react_operator_implements_operator_trait() {
1723        // Compile-time check: ReactOperator<MockProvider> implements Operator
1724        fn _assert_operator<T: Operator>() {}
1725        _assert_operator::<ReactOperator<MockProvider>>();
1726    }
1727
1728    #[tokio::test]
1729    async fn react_operator_as_arc_dyn_operator() {
1730        // ReactOperator<P> can be used as Arc<dyn Operator>
1731        let provider = MockProvider::new(vec![simple_text_response("Hello!")]);
1732        let op: Arc<dyn Operator> = Arc::new(ReactOperator::new(
1733            provider,
1734            ToolRegistry::new(),
1735            Box::new(NoCompaction),
1736            HookRegistry::new(),
1737            Arc::new(NullStateReader),
1738            ReactConfig::default(),
1739        ));
1740
1741        let output = op.execute(simple_input("Hi")).await.unwrap();
1742        assert_eq!(output.exit_reason, ExitReason::Complete);
1743    }
1744
1745    #[tokio::test]
1746    async fn provider_retryable_error_maps_to_retryable() {
1747        struct ErrorProvider;
1748        impl Provider for ErrorProvider {
1749            #[allow(clippy::manual_async_fn)]
1750            fn complete(
1751                &self,
1752                _request: ProviderRequest,
1753            ) -> impl std::future::Future<Output = Result<ProviderResponse, ProviderError>> + Send
1754            {
1755                async { Err(ProviderError::RateLimited) }
1756            }
1757        }
1758
1759        let op = ReactOperator::new(
1760            ErrorProvider,
1761            ToolRegistry::new(),
1762            Box::new(NoCompaction),
1763            HookRegistry::new(),
1764            Arc::new(NullStateReader),
1765            ReactConfig::default(),
1766        );
1767
1768        let result = op.execute(simple_input("test")).await;
1769        assert!(matches!(result, Err(OperatorError::Retryable(_))));
1770    }
1771
1772    #[tokio::test]
1773    async fn provider_call_count() {
1774        let provider = MockProvider::new(vec![
1775            tool_use_response("tu_1", "echo", json!({})),
1776            tool_use_response("tu_2", "echo", json!({})),
1777            simple_text_response("Done"),
1778        ]);
1779        let call_count = std::sync::Arc::new(AtomicUsize::new(0));
1780
1781        struct CountingProvider {
1782            inner: MockProvider,
1783            count: std::sync::Arc<AtomicUsize>,
1784        }
1785        impl Provider for CountingProvider {
1786            #[allow(clippy::manual_async_fn)]
1787            fn complete(
1788                &self,
1789                request: ProviderRequest,
1790            ) -> impl std::future::Future<Output = Result<ProviderResponse, ProviderError>> + Send
1791            {
1792                self.count.fetch_add(1, Ordering::SeqCst);
1793                self.inner.complete(request)
1794            }
1795        }
1796
1797        let counting_provider = CountingProvider {
1798            inner: MockProvider::new(vec![
1799                tool_use_response("tu_1", "echo", json!({})),
1800                tool_use_response("tu_2", "echo", json!({})),
1801                simple_text_response("Done"),
1802            ]),
1803            count: call_count.clone(),
1804        };
1805
1806        let mut tools = ToolRegistry::new();
1807        tools.register(Arc::new(EchoTool));
1808        let op = make_op_with_tools(counting_provider, tools);
1809
1810        op.execute(simple_input("Multi-turn")).await.unwrap();
1811        // Only counting_provider was called — provider was called 3 times
1812        assert_eq!(call_count.load(Ordering::SeqCst), 3);
1813        // The unused `provider` variable should not cause issues
1814        drop(provider);
1815    }
1816
1817    // -- Steering Mocks --
1818    struct MockSteering {
1819        seq: Mutex<VecDeque<Vec<ProviderMessage>>>,
1820        calls: AtomicUsize,
1821    }
1822    impl MockSteering {
1823        fn new(seq: Vec<Vec<ProviderMessage>>) -> Self {
1824            Self {
1825                seq: Mutex::new(seq.into()),
1826                calls: AtomicUsize::new(0),
1827            }
1828        }
1829        fn call_count(&self) -> usize {
1830            self.calls.load(Ordering::SeqCst)
1831        }
1832    }
1833    impl SteeringSource for MockSteering {
1834        fn drain(&self) -> Vec<ProviderMessage> {
1835            self.calls.fetch_add(1, Ordering::SeqCst);
1836            self.seq.lock().unwrap().pop_front().unwrap_or_default()
1837        }
1838    }
1839
1840    struct CountingEchoTool {
1841        hits: std::sync::Arc<AtomicUsize>,
1842    }
1843    impl CountingEchoTool {
1844        fn new(h: std::sync::Arc<AtomicUsize>) -> Self {
1845            Self { hits: h }
1846        }
1847    }
1848    impl neuron_tool::ToolDyn for CountingEchoTool {
1849        fn name(&self) -> &str {
1850            "echo"
1851        }
1852        fn description(&self) -> &str {
1853            "Echoes input (counting)"
1854        }
1855        fn input_schema(&self) -> serde_json::Value {
1856            json!({"type":"object"})
1857        }
1858        fn call(
1859            &self,
1860            input: serde_json::Value,
1861        ) -> std::pin::Pin<
1862            Box<
1863                dyn std::future::Future<Output = Result<serde_json::Value, neuron_tool::ToolError>>
1864                    + Send
1865                    + '_,
1866            >,
1867        > {
1868            self.hits.fetch_add(1, Ordering::SeqCst);
1869            Box::pin(async move { Ok(json!({"echoed": input})) })
1870        }
1871    }
1872
1873    struct SharedOnlyDecider;
1874    impl ConcurrencyDecider for SharedOnlyDecider {
1875        fn concurrency(&self, tool_name: &str) -> Concurrency {
1876            if tool_name == "echo" {
1877                Concurrency::Shared
1878            } else {
1879                Concurrency::Exclusive
1880            }
1881        }
1882    }
1883
1884    fn user_msg(text: &str) -> ProviderMessage {
1885        ProviderMessage {
1886            role: Role::User,
1887            content: vec![ContentPart::Text { text: text.into() }],
1888        }
1889    }
1890
1891    #[tokio::test]
1892    async fn steering_skips_remaining_shared_batch() {
1893        // Provider returns two shared tool uses in one response
1894        let first = ProviderResponse {
1895            content: vec![
1896                ContentPart::ToolUse {
1897                    id: "t1".into(),
1898                    name: "echo".into(),
1899                    input: json!({"n":1}),
1900                },
1901                ContentPart::ToolUse {
1902                    id: "t2".into(),
1903                    name: "echo".into(),
1904                    input: json!({"n":2}),
1905                },
1906            ],
1907            stop_reason: StopReason::ToolUse,
1908            usage: TokenUsage {
1909                input_tokens: 10,
1910                output_tokens: 15,
1911                ..Default::default()
1912            },
1913            model: "mock".into(),
1914            cost: None,
1915            truncated: None,
1916        };
1917        let provider = MockProvider::new(vec![first, simple_text_response("Done")]);
1918        let hits = std::sync::Arc::new(AtomicUsize::new(0));
1919        let mut tools = ToolRegistry::new();
1920        tools.register(Arc::new(CountingEchoTool::new(hits.clone())));
1921        let steering = Arc::new(MockSteering::new(vec![
1922            vec![],                  // pre-batch: no steering
1923            vec![user_msg("STEER")], // after first result: trigger steering
1924        ]));
1925        let steering_ref = steering.clone();
1926        let op = make_op_with_tools(provider, tools)
1927            .with_planner(Box::new(BarrierPlanner))
1928            .with_concurrency_decider(Box::new(SharedOnlyDecider))
1929            .with_steering(steering);
1930        let output = op.execute(simple_input("run"));
1931        let output = output.await.unwrap();
1932        assert_eq!(output.exit_reason, ExitReason::Complete);
1933        assert!(steering_ref.call_count() >= 1);
1934        // Only first tool executed
1935        assert_eq!(hits.load(Ordering::SeqCst), 1);
1936        assert_eq!(output.metadata.turns_used, 2);
1937        assert_eq!(output.metadata.tools_called.len(), 2);
1938        assert_eq!(output.metadata.tools_called[0].name, "echo");
1939        assert_eq!(output.metadata.tools_called[1].name, "echo");
1940    }
1941    #[tokio::test]
1942    async fn steering_skips_before_exclusive() {
1943        // Single exclusive tool use, steering triggers before execution
1944        let first = ProviderResponse {
1945            content: vec![ContentPart::ToolUse {
1946                id: "t1".into(),
1947                name: "echo".into(),
1948                input: json!({}),
1949            }],
1950            stop_reason: StopReason::ToolUse,
1951            usage: TokenUsage {
1952                input_tokens: 10,
1953                output_tokens: 15,
1954                ..Default::default()
1955            },
1956            model: "mock".into(),
1957            cost: None,
1958            truncated: None,
1959        };
1960        // Provider should be called again after steering injection
1961        let call_count = std::sync::Arc::new(AtomicUsize::new(0));
1962        struct CountingProvider {
1963            inner: MockProvider,
1964            count: std::sync::Arc<AtomicUsize>,
1965        }
1966        impl Provider for CountingProvider {
1967            #[allow(clippy::manual_async_fn)]
1968            fn complete(
1969                &self,
1970                request: ProviderRequest,
1971            ) -> impl std::future::Future<Output = Result<ProviderResponse, ProviderError>> + Send
1972            {
1973                self.count.fetch_add(1, Ordering::SeqCst);
1974                self.inner.complete(request)
1975            }
1976        }
1977        let counting = CountingProvider {
1978            inner: MockProvider::new(vec![first, simple_text_response("Done")]),
1979            count: call_count.clone(),
1980        };
1981        let hits = std::sync::Arc::new(AtomicUsize::new(0));
1982        let mut tools = ToolRegistry::new();
1983        tools.register(Arc::new(CountingEchoTool::new(hits.clone())));
1984        let steering = Arc::new(MockSteering::new(vec![
1985            vec![user_msg("STEER")], // pre-exclusive: trigger
1986        ]));
1987        let op = ReactOperator::new(
1988            counting,
1989            tools,
1990            Box::new(NoCompaction),
1991            HookRegistry::new(),
1992            Arc::new(NullStateReader),
1993            ReactConfig::default(),
1994        )
1995        .with_steering(steering);
1996        let output = op.execute(simple_input("run"));
1997        let output = output.await.unwrap();
1998        assert_eq!(output.exit_reason, ExitReason::Complete);
1999        // Tool implementation was never called
2000        assert_eq!(hits.load(Ordering::SeqCst), 0);
2001        // Provider called twice (two turns)
2002        assert_eq!(call_count.load(Ordering::SeqCst), 2);
2003        assert_eq!(output.metadata.turns_used, 2);
2004    }
2005
2006    #[tokio::test]
2007    async fn no_steering_default() {
2008        // Two shared tools; without steering both execute
2009        let first = ProviderResponse {
2010            content: vec![
2011                ContentPart::ToolUse {
2012                    id: "t1".into(),
2013                    name: "echo".into(),
2014                    input: json!({}),
2015                },
2016                ContentPart::ToolUse {
2017                    id: "t2".into(),
2018                    name: "echo".into(),
2019                    input: json!({}),
2020                },
2021            ],
2022            stop_reason: StopReason::ToolUse,
2023            usage: TokenUsage {
2024                input_tokens: 10,
2025                output_tokens: 15,
2026                ..Default::default()
2027            },
2028            model: "mock".into(),
2029            cost: None,
2030            truncated: None,
2031        };
2032        let provider = MockProvider::new(vec![first, simple_text_response("Done")]);
2033        let hits = std::sync::Arc::new(AtomicUsize::new(0));
2034        let mut tools = ToolRegistry::new();
2035        tools.register(Arc::new(CountingEchoTool::new(hits.clone())));
2036        let op = make_op_with_tools(provider, tools)
2037            .with_planner(Box::new(BarrierPlanner))
2038            .with_concurrency_decider(Box::new(SharedOnlyDecider));
2039        let output = op.execute(simple_input("run"));
2040        let output = output.await.unwrap();
2041        assert_eq!(output.exit_reason, ExitReason::Complete);
2042        assert_eq!(hits.load(Ordering::SeqCst), 2);
2043        assert_eq!(output.metadata.tools_called.len(), 2);
2044        assert_eq!(output.metadata.turns_used, 2);
2045    }
2046
2047    // -- Streaming Tool + Hook tests --
2048    struct StreamEcho;
2049    impl neuron_tool::ToolDyn for StreamEcho {
2050        fn name(&self) -> &str {
2051            "stream_echo"
2052        }
2053        fn description(&self) -> &str {
2054            "Streams echo chunks"
2055        }
2056        fn input_schema(&self) -> serde_json::Value {
2057            json!({"type":"object"})
2058        }
2059        fn call(
2060            &self,
2061            _input: serde_json::Value,
2062        ) -> std::pin::Pin<
2063            Box<
2064                dyn std::future::Future<Output = Result<serde_json::Value, neuron_tool::ToolError>>
2065                    + Send
2066                    + '_,
2067            >,
2068        > {
2069            Box::pin(async { Ok(serde_json::json!({"note":"non-stream fallback"})) })
2070        }
2071        fn maybe_streaming(&self) -> Option<&dyn neuron_tool::ToolDynStreaming> {
2072            Some(self)
2073        }
2074    }
2075    impl neuron_tool::ToolDynStreaming for StreamEcho {
2076        fn call_streaming<'a>(
2077            &'a self,
2078            _input: serde_json::Value,
2079            on_chunk: Box<dyn Fn(&str) + Send + Sync + 'a>,
2080        ) -> std::pin::Pin<
2081            Box<dyn std::future::Future<Output = Result<(), neuron_tool::ToolError>> + Send + 'a>,
2082        > {
2083            Box::pin(async move {
2084                for ch in ["A", "B", "C"] {
2085                    on_chunk(ch);
2086                }
2087                Ok(())
2088            })
2089        }
2090    }
2091
2092    struct CollectHook {
2093        points: Vec<HookPoint>,
2094        chunks: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
2095        finals: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
2096    }
2097    #[async_trait]
2098    impl layer0::hook::Hook for CollectHook {
2099        fn points(&self) -> &[HookPoint] {
2100            &self.points
2101        }
2102        async fn on_event(
2103            &self,
2104            ctx: &HookContext,
2105        ) -> Result<HookAction, layer0::error::HookError> {
2106            if ctx.point == HookPoint::ToolExecutionUpdate {
2107                if let Some(c) = &ctx.tool_chunk {
2108                    self.chunks.lock().unwrap().push(c.clone());
2109                }
2110                Ok(HookAction::Continue)
2111            } else if ctx.point == HookPoint::PostToolUse {
2112                if let Some(r) = &ctx.tool_result {
2113                    self.finals.lock().unwrap().push(r.clone());
2114                }
2115                Ok(HookAction::Continue)
2116            } else {
2117                Ok(HookAction::Continue)
2118            }
2119        }
2120    }
2121
2122    #[tokio::test]
2123    async fn streaming_chunks_forwarded_and_concatenated() {
2124        // Provider returns a single tool use then an EndTurn
2125        let _provider = MockProvider::new(vec![
2126            tool_use_response("tu_s", "stream_echo", json!({"n":1})),
2127            simple_text_response("OK"),
2128        ]);
2129        let mut tools = ToolRegistry::new();
2130        tools.register(Arc::new(StreamEcho));
2131        // Hook to collect updates
2132        let chunks = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
2133        let finals = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
2134        let mut hooks = HookRegistry::new();
2135        hooks.add(Arc::new(CollectHook {
2136            points: vec![HookPoint::ToolExecutionUpdate, HookPoint::PostToolUse],
2137            chunks: chunks.clone(),
2138            finals: finals.clone(),
2139        }));
2140        let op = ReactOperator::new(
2141            MockProvider::new(vec![
2142                tool_use_response("tu_s", "stream_echo", json!({})),
2143                simple_text_response("OK"),
2144            ]),
2145            tools,
2146            Box::new(NoCompaction),
2147            hooks,
2148            Arc::new(NullStateReader),
2149            ReactConfig::default(),
2150        );
2151        let _ = op.execute(simple_input("run")).await.unwrap();
2152        let got_chunks = chunks.lock().unwrap().clone();
2153        assert_eq!(got_chunks, vec!["A", "B", "C"]);
2154        let got_finals = finals.lock().unwrap().clone();
2155        assert_eq!(got_finals.len(), 1);
2156        assert_eq!(got_finals[0], "ABC");
2157    }
2158
2159    struct CountingSharedEchoTool {
2160        hits: std::sync::Arc<AtomicUsize>,
2161    }
2162    impl CountingSharedEchoTool {
2163        fn new(h: std::sync::Arc<AtomicUsize>) -> Self {
2164            Self { hits: h }
2165        }
2166    }
2167    impl neuron_tool::ToolDyn for CountingSharedEchoTool {
2168        fn name(&self) -> &str {
2169            "meta_echo"
2170        }
2171        fn description(&self) -> &str {
2172            "Echoes input (shared via metadata)"
2173        }
2174        fn input_schema(&self) -> serde_json::Value {
2175            json!({"type":"object"})
2176        }
2177        fn call(
2178            &self,
2179            input: serde_json::Value,
2180        ) -> std::pin::Pin<
2181            Box<
2182                dyn std::future::Future<Output = Result<serde_json::Value, neuron_tool::ToolError>>
2183                    + Send
2184                    + '_,
2185            >,
2186        > {
2187            self.hits.fetch_add(1, Ordering::SeqCst);
2188            Box::pin(async move { Ok(json!({"echoed": input})) })
2189        }
2190        fn concurrency_hint(&self) -> neuron_tool::ToolConcurrencyHint {
2191            neuron_tool::ToolConcurrencyHint::Shared
2192        }
2193    }
2194
2195    #[tokio::test]
2196    async fn metadata_concurrency_batches_shared() {
2197        // Two uses of the same tool should batch as Shared when metadata decider is used
2198        let first = ProviderResponse {
2199            content: vec![
2200                ContentPart::ToolUse {
2201                    id: "t1".into(),
2202                    name: "meta_echo".into(),
2203                    input: json!({}),
2204                },
2205                ContentPart::ToolUse {
2206                    id: "t2".into(),
2207                    name: "meta_echo".into(),
2208                    input: json!({}),
2209                },
2210            ],
2211            stop_reason: StopReason::ToolUse,
2212            usage: TokenUsage {
2213                input_tokens: 10,
2214                output_tokens: 15,
2215                ..Default::default()
2216            },
2217            model: "mock".into(),
2218            cost: None,
2219            truncated: None,
2220        };
2221        let provider = MockProvider::new(vec![first, simple_text_response("Done")]);
2222        let hits = std::sync::Arc::new(AtomicUsize::new(0));
2223        let mut tools = ToolRegistry::new();
2224        tools.register(Arc::new(CountingSharedEchoTool::new(hits.clone())));
2225        let op = make_op_with_tools(provider, tools)
2226            .with_planner(Box::new(BarrierPlanner))
2227            .with_metadata_concurrency();
2228        let output = op.execute(simple_input("run")).await.unwrap();
2229        assert_eq!(output.exit_reason, ExitReason::Complete);
2230        assert_eq!(hits.load(Ordering::SeqCst), 2);
2231        assert_eq!(output.metadata.tools_called.len(), 2);
2232        assert_eq!(output.metadata.turns_used, 2);
2233    }
2234}