1#![deny(missing_docs)]
2use async_trait::async_trait;
8use layer0::content::Content;
9use layer0::duration::DurationMs;
10use layer0::effect::{Effect, Scope, SignalPayload};
11use layer0::error::OperatorError;
12use layer0::hook::{HookAction, HookContext, HookPoint};
13use layer0::id::{AgentId, WorkflowId};
14use layer0::operator::{
15 ExitReason, Operator, OperatorInput, OperatorMetadata, OperatorOutput, ToolCallRecord,
16};
17use neuron_hooks::HookRegistry;
18use neuron_tool::{ToolConcurrencyHint, ToolRegistry};
19use neuron_turn::context::ContextStrategy;
20use neuron_turn::convert::{content_to_user_message, parts_to_content};
21use neuron_turn::provider::Provider;
22use neuron_turn::types::*;
23use rust_decimal::Decimal;
24use std::sync::Arc;
25use std::time::Instant;
26
/// Operator-level defaults for a [`ReactOperator`].
///
/// Per-input settings (model, turn limit, system addendum, …) override or extend these
/// values when an input carries its own config.
#[derive(Debug, Clone)]
pub struct ReactConfig {
    /// Base system prompt sent with every request; a per-input `system_addendum` is
    /// appended to it.
    pub system_prompt: String,
    /// Model used when the input does not name one. An empty string means "no default",
    /// in which case no model is put on the request.
    pub default_model: String,
    /// `max_tokens` sent to the provider on every request.
    pub default_max_tokens: u32,
    /// Turn limit used when the input's config does not set `max_turns`.
    pub default_max_turns: u32,
}

impl Default for ReactConfig {
    /// Empty prompt and model, 4096 max tokens, and a 10-turn limit.
    fn default() -> Self {
        Self {
            system_prompt: String::new(),
            default_model: String::new(),
            default_max_tokens: 4096,
            default_max_turns: 10,
        }
    }
}
49
/// Names of the built-in tools whose calls are turned into [`Effect`]s instead of being
/// dispatched to the tool registry.
const EFFECT_TOOL_NAMES: &[&str] = &["write_memory", "delete_memory", "delegate", "handoff", "signal"];
58
/// Settings for a single `execute` call, produced by `resolve_config` from the operator's
/// `ReactConfig` and the input's optional overrides.
struct ResolvedConfig {
    /// Model to request; `None` when neither the input nor `default_model` names one.
    model: Option<String>,
    /// System prompt sent on every request (base prompt plus any per-input addendum).
    system: String,
    /// Number of inference turns after which the loop exits with `MaxTurns`.
    max_turns: u32,
    /// Optional cost ceiling, checked after each tool turn.
    max_cost: Option<Decimal>,
    /// Optional wall-clock limit, checked after each tool turn.
    max_duration: Option<DurationMs>,
    /// Optional allow-list of tool names whose schemas are offered to the model.
    allowed_tools: Option<Vec<String>>,
    /// `max_tokens` passed to the provider on every request.
    max_tokens: u32,
}
69
70pub use neuron_turn_kit::{
72 BarrierPlanner, BatchItem, Concurrency, ConcurrencyDecider, SteeringSource,
73 ToolExecutionPlanner,
74};
75
76struct DefaultDecider;
78impl ConcurrencyDecider for DefaultDecider {
79 fn concurrency(&self, _tool_name: &str) -> Concurrency {
81 Concurrency::Exclusive
82 }
83}
84
85struct MetadataDecider {
87 tools: ToolRegistry,
88}
89impl ConcurrencyDecider for MetadataDecider {
90 fn concurrency(&self, tool_name: &str) -> Concurrency {
91 match self.tools.get(tool_name) {
92 Some(tool) => match tool.concurrency_hint() {
93 ToolConcurrencyHint::Shared => Concurrency::Shared,
94 ToolConcurrencyHint::Exclusive => Concurrency::Exclusive,
95 _ => Concurrency::Exclusive,
96 },
97 None => Concurrency::Exclusive,
98 }
99 }
100}
101
102struct SequentialPlanner;
104impl ToolExecutionPlanner for SequentialPlanner {
105 fn plan(
106 &self,
107 tool_uses: &[(String, String, serde_json::Value)],
108 _decider: &dyn ConcurrencyDecider,
109 ) -> Vec<BatchItem> {
110 tool_uses
111 .iter()
112 .cloned()
113 .map(BatchItem::Exclusive)
114 .collect()
115 }
116}
/// A ReAct-style operator that alternates model inference with tool execution.
///
/// The loop runs until the model ends its turn or a limit or hook stops it.
pub struct ReactOperator<P: Provider> {
    /// Model backend used for every inference call.
    provider: P,
    /// Registered tools the model can call (effect tools are handled separately).
    tools: ToolRegistry,
    /// Decides whether and how the message history is compacted between turns.
    context_strategy: Box<dyn ContextStrategy>,
    /// Hooks dispatched around inference, tool use, and the per-turn exit check.
    hooks: HookRegistry,
    /// Source of persisted session history (read under the `"messages"` key).
    state_reader: Arc<dyn layer0::StateReader>,
    /// Operator-level defaults.
    config: ReactConfig,
    /// Groups a turn's tool calls into batches.
    planner: Box<dyn ToolExecutionPlanner>,
    /// Classifies tools as shared or exclusive for the planner.
    decider: Box<dyn ConcurrencyDecider>,
    /// Optional source of steering messages polled while tools execute.
    steering: Option<Arc<dyn SteeringSource>>,
}
132
133impl<P: Provider> ReactOperator<P> {
134 pub fn new(
136 provider: P,
137 tools: ToolRegistry,
138 context_strategy: Box<dyn ContextStrategy>,
139 hooks: HookRegistry,
140 state_reader: Arc<dyn layer0::StateReader>,
141 config: ReactConfig,
142 ) -> Self {
143 Self {
144 provider,
145 tools,
146 context_strategy,
147 hooks,
148 state_reader,
149 config,
150 planner: Box::new(SequentialPlanner),
151 decider: Box::new(DefaultDecider),
152 steering: None,
153 }
154 }
155 pub fn with_planner(mut self, planner: Box<dyn ToolExecutionPlanner>) -> Self {
157 self.planner = planner;
158 self
159 }
160 pub fn with_concurrency_decider(mut self, decider: Box<dyn ConcurrencyDecider>) -> Self {
162 self.decider = decider;
163 self
164 }
165 pub fn with_metadata_concurrency(mut self) -> Self {
167 self.decider = Box::new(MetadataDecider {
168 tools: self.tools.clone(),
169 });
170 self
171 }
172 pub fn with_steering(mut self, s: Arc<dyn SteeringSource>) -> Self {
174 self.steering = Some(s);
175 self
176 }
177
178 fn resolve_config(&self, input: &OperatorInput) -> ResolvedConfig {
179 let tc = input.config.as_ref();
180 let system = match tc.and_then(|c| c.system_addendum.as_ref()) {
181 Some(addendum) => format!("{}\n{}", self.config.system_prompt, addendum),
182 None => self.config.system_prompt.clone(),
183 };
184 ResolvedConfig {
185 model: tc.and_then(|c| c.model.clone()).or_else(|| {
186 if self.config.default_model.is_empty() {
187 None
188 } else {
189 Some(self.config.default_model.clone())
190 }
191 }),
192 system,
193 max_turns: tc
194 .and_then(|c| c.max_turns)
195 .unwrap_or(self.config.default_max_turns),
196 max_cost: tc.and_then(|c| c.max_cost),
197 max_duration: tc.and_then(|c| c.max_duration),
198 allowed_tools: tc.and_then(|c| c.allowed_tools.clone()),
199 max_tokens: self.config.default_max_tokens,
200 }
201 }
202
203 fn build_tool_schemas(&self, config: &ResolvedConfig) -> Vec<ToolSchema> {
204 let mut schemas: Vec<ToolSchema> = self
205 .tools
206 .iter()
207 .map(|tool| ToolSchema {
208 name: tool.name().to_string(),
209 description: tool.description().to_string(),
210 input_schema: tool.input_schema(),
211 })
212 .collect();
213
214 schemas.extend(effect_tool_schemas());
216
217 if let Some(allowed) = &config.allowed_tools {
219 schemas.retain(|s| allowed.contains(&s.name));
220 }
221
222 schemas
223 }
224
225 async fn assemble_context(
226 &self,
227 input: &OperatorInput,
228 ) -> Result<Vec<ProviderMessage>, OperatorError> {
229 let mut messages = Vec::new();
230
231 if let Some(session) = &input.session {
233 let scope = Scope::Session(session.clone());
234 match self.state_reader.read(&scope, "messages").await {
235 Ok(Some(history)) => {
236 if let Ok(history_messages) =
237 serde_json::from_value::<Vec<ProviderMessage>>(history)
238 {
239 messages = history_messages;
240 }
241 }
242 Ok(None) => {} Err(_) => {} }
245 }
246
247 messages.push(content_to_user_message(&input.message));
249
250 Ok(messages)
251 }
252
253 fn try_as_effect(&self, name: &str, input: &serde_json::Value) -> Option<Effect> {
254 match name {
255 "write_memory" => {
256 let scope_str = input.get("scope")?.as_str()?;
257 let key = input.get("key")?.as_str()?.to_string();
258 let value = input.get("value")?.clone();
259 let scope = parse_scope(scope_str);
260 Some(Effect::WriteMemory { scope, key, value })
261 }
262 "delete_memory" => {
263 let scope_str = input.get("scope")?.as_str()?;
264 let key = input.get("key")?.as_str()?.to_string();
265 let scope = parse_scope(scope_str);
266 Some(Effect::DeleteMemory { scope, key })
267 }
268 "delegate" => {
269 let agent = input.get("agent")?.as_str()?;
270 let message = input.get("message").and_then(|m| m.as_str()).unwrap_or("");
271 let delegate_input =
272 OperatorInput::new(Content::text(message), layer0::operator::TriggerType::Task);
273 Some(Effect::Delegate {
274 agent: AgentId::new(agent),
275 input: Box::new(delegate_input),
276 })
277 }
278 "handoff" => {
279 let agent = input.get("agent")?.as_str()?;
280 let state = input
281 .get("state")
282 .cloned()
283 .unwrap_or(serde_json::Value::Null);
284 Some(Effect::Handoff {
285 agent: AgentId::new(agent),
286 state,
287 })
288 }
289 "signal" => {
290 let target = input.get("target")?.as_str()?;
291 let signal_type = input
292 .get("signal_type")
293 .and_then(|s| s.as_str())
294 .unwrap_or("default");
295 let data = input
296 .get("data")
297 .cloned()
298 .unwrap_or(serde_json::Value::Null);
299 Some(Effect::Signal {
300 target: WorkflowId::new(target),
301 payload: SignalPayload::new(signal_type, data),
302 })
303 }
304 _ => None,
305 }
306 }
307
308 fn build_metadata(
309 &self,
310 tokens_in: u64,
311 tokens_out: u64,
312 cost: Decimal,
313 turns_used: u32,
314 tools_called: Vec<ToolCallRecord>,
315 duration: DurationMs,
316 ) -> OperatorMetadata {
317 let mut meta = OperatorMetadata::default();
318 meta.tokens_in = tokens_in;
319 meta.tokens_out = tokens_out;
320 meta.cost = cost;
321 meta.turns_used = turns_used;
322 meta.tools_called = tools_called;
323 meta.duration = duration;
324 meta
325 }
326
327 fn make_output(
328 message: Content,
329 exit_reason: ExitReason,
330 metadata: OperatorMetadata,
331 effects: Vec<Effect>,
332 ) -> OperatorOutput {
333 let mut output = OperatorOutput::new(message, exit_reason);
334 output.metadata = metadata;
335 output.effects = effects;
336 output
337 }
338
339 fn build_hook_context(
340 &self,
341 point: HookPoint,
342 tokens_in: u64,
343 tokens_out: u64,
344 cost: Decimal,
345 turns_completed: u32,
346 elapsed: DurationMs,
347 ) -> HookContext {
348 let mut ctx = HookContext::new(point);
349 ctx.tokens_used = tokens_in + tokens_out;
350 ctx.cost = cost;
351 ctx.turns_completed = turns_completed;
352 ctx.elapsed = elapsed;
353 ctx
354 }
355}
356
357#[async_trait]
358impl<P: Provider + 'static> Operator for ReactOperator<P> {
359 async fn execute(&self, input: OperatorInput) -> Result<OperatorOutput, OperatorError> {
360 let start = Instant::now();
361 let config = self.resolve_config(&input);
362 let mut messages = self.assemble_context(&input).await?;
363 let tools = self.build_tool_schemas(&config);
364
365 let mut total_tokens_in: u64 = 0;
366 let mut total_tokens_out: u64 = 0;
367 let mut total_cost = Decimal::ZERO;
368 let mut turns_used: u32 = 0;
369 let mut tool_records: Vec<ToolCallRecord> = vec![];
370 let mut effects: Vec<Effect> = vec![];
371 let mut last_content: Vec<ContentPart> = vec![];
372
373 loop {
374 turns_used += 1;
375
376 let hook_ctx = self.build_hook_context(
378 HookPoint::PreInference,
379 total_tokens_in,
380 total_tokens_out,
381 total_cost,
382 turns_used - 1,
383 DurationMs::from(start.elapsed()),
384 );
385 if let HookAction::Halt { reason } = self.hooks.dispatch(&hook_ctx).await {
386 return Ok(Self::make_output(
387 parts_to_content(&last_content),
388 ExitReason::ObserverHalt { reason },
389 self.build_metadata(
390 total_tokens_in,
391 total_tokens_out,
392 total_cost,
393 turns_used,
394 tool_records,
395 DurationMs::from(start.elapsed()),
396 ),
397 effects,
398 ));
399 }
400
401 let request = ProviderRequest {
403 model: config.model.clone(),
404 messages: messages.clone(),
405 tools: tools.clone(),
406 max_tokens: Some(config.max_tokens),
407 temperature: None,
408 system: Some(config.system.clone()),
409 extra: input.metadata.clone(),
410 };
411
412 let response = self.provider.complete(request).await.map_err(|e| {
414 if e.is_retryable() {
415 OperatorError::Retryable(e.to_string())
416 } else {
417 OperatorError::Model(e.to_string())
418 }
419 })?;
420
421 let mut hook_ctx = self.build_hook_context(
423 HookPoint::PostInference,
424 total_tokens_in + response.usage.input_tokens,
425 total_tokens_out + response.usage.output_tokens,
426 total_cost + response.cost.unwrap_or(Decimal::ZERO),
427 turns_used,
428 DurationMs::from(start.elapsed()),
429 );
430 hook_ctx.model_output = Some(parts_to_content(&response.content));
431 if let HookAction::Halt { reason } = self.hooks.dispatch(&hook_ctx).await {
432 return Ok(Self::make_output(
433 parts_to_content(&response.content),
434 ExitReason::ObserverHalt { reason },
435 self.build_metadata(
436 total_tokens_in + response.usage.input_tokens,
437 total_tokens_out + response.usage.output_tokens,
438 total_cost + response.cost.unwrap_or(Decimal::ZERO),
439 turns_used,
440 tool_records,
441 DurationMs::from(start.elapsed()),
442 ),
443 effects,
444 ));
445 }
446
447 total_tokens_in += response.usage.input_tokens;
449 total_tokens_out += response.usage.output_tokens;
450 if let Some(cost) = response.cost {
451 total_cost += cost;
452 }
453
454 last_content.clone_from(&response.content);
455
456 match response.stop_reason {
458 StopReason::MaxTokens => {
459 return Err(OperatorError::Model("output truncated (max_tokens)".into()));
460 }
461 StopReason::ContentFilter => {
462 return Err(OperatorError::Model("content filtered".into()));
463 }
464 StopReason::EndTurn => {
465 return Ok(Self::make_output(
466 parts_to_content(&response.content),
467 ExitReason::Complete,
468 self.build_metadata(
469 total_tokens_in,
470 total_tokens_out,
471 total_cost,
472 turns_used,
473 tool_records,
474 DurationMs::from(start.elapsed()),
475 ),
476 effects,
477 ));
478 }
479 StopReason::ToolUse => {
480 }
482 }
483
484 messages.push(ProviderMessage {
487 role: Role::Assistant,
488 content: response.content.clone(),
489 });
490
491 let _tool_uses: Vec<(String, String, serde_json::Value)> = response
492 .content
493 .iter()
494 .filter_map(|part| match part {
495 ContentPart::ToolUse { id, name, input } => {
496 Some((id.clone(), name.clone(), input.clone()))
497 }
498 _ => None,
499 })
500 .collect();
501 let mut tool_results: Vec<ContentPart> = Vec::new();
502 let planned = {
504 let calls: Vec<(String, String, serde_json::Value)> = response
505 .content
506 .iter()
507 .filter_map(|part| match part {
508 ContentPart::ToolUse { id, name, input } => {
509 Some((id.clone(), name.clone(), input.clone()))
510 }
511 _ => None,
512 })
513 .collect();
514 self.planner.plan(&calls, self.decider.as_ref())
515 };
516
517 let mut _steered = false;
518 'batches: for batch in planned {
519 match batch {
520 BatchItem::Shared(call_group) => {
521 if let Some(s) = &self.steering {
523 let injected = s.drain();
524 if !injected.is_empty() {
525 messages.extend(injected);
526 for (id, name, _input) in call_group.into_iter() {
528 tool_results.push(ContentPart::ToolResult {
529 tool_use_id: id,
530 content: "Skipped due to steering".into(),
531 is_error: false,
532 });
533 tool_records.push(ToolCallRecord::new(
534 &name,
535 DurationMs::ZERO,
536 false,
537 ));
538 }
539 _steered = true;
540 break 'batches;
541 }
542 }
543 let len = call_group.len();
545 for idx in 0..len {
546 if idx > 0
548 && let Some(s) = &self.steering
549 {
550 let injected = s.drain();
551 if !injected.is_empty() {
552 messages.extend(injected);
553 for (rid, rname, _rinput) in
554 call_group.iter().skip(idx).cloned()
555 {
556 tool_results.push(ContentPart::ToolResult {
557 tool_use_id: rid,
558 content: "Skipped due to steering".into(),
559 is_error: false,
560 });
561 tool_records.push(ToolCallRecord::new(
562 &rname,
563 DurationMs::ZERO,
564 false,
565 ));
566 }
567 _steered = true;
568 _steered = true;
569 }
570 }
571 let (id, name, tool_input) = call_group[idx].clone();
572 if EFFECT_TOOL_NAMES.contains(&name.as_str()) {
574 if let Some(effect) = self.try_as_effect(&name, &tool_input) {
575 effects.push(effect);
576 }
577 tool_results.push(ContentPart::ToolResult {
578 tool_use_id: id,
579 content: format!("{name} effect recorded."),
580 is_error: false,
581 });
582 tool_records.push(ToolCallRecord::new(
583 &name,
584 DurationMs::ZERO,
585 true,
586 ));
587 } else {
588 let mut actual_input = tool_input.clone();
590 let mut hook_ctx = HookContext::new(HookPoint::PreToolUse);
591 hook_ctx.tool_name = Some(name.clone());
592 hook_ctx.tool_input = Some(tool_input.clone());
593 hook_ctx.tokens_used = total_tokens_in + total_tokens_out;
594 hook_ctx.cost = total_cost;
595 hook_ctx.turns_completed = turns_used;
596 hook_ctx.elapsed = DurationMs::from(start.elapsed());
597 match self.hooks.dispatch(&hook_ctx).await {
598 HookAction::Halt { reason } => {
599 return Ok(Self::make_output(
600 parts_to_content(&last_content),
601 ExitReason::ObserverHalt { reason },
602 self.build_metadata(
603 total_tokens_in,
604 total_tokens_out,
605 total_cost,
606 turns_used,
607 tool_records,
608 DurationMs::from(start.elapsed()),
609 ),
610 effects,
611 ));
612 }
613 HookAction::SkipTool { reason } => {
614 tool_results.push(ContentPart::ToolResult {
615 tool_use_id: id,
616 content: format!("Skipped: {reason}"),
617 is_error: false,
618 });
619 tool_records.push(ToolCallRecord::new(
620 &name,
621 DurationMs::ZERO,
622 false,
623 ));
624 continue;
625 }
626 HookAction::ModifyToolInput { new_input } => {
627 actual_input = new_input;
628 }
629 HookAction::Continue => {}
630 _ => {}
631 }
632 let tool_start = Instant::now();
634 let (mut result_content, is_error, success, duration) = match self
636 .tools
637 .get(&name)
638 {
639 Some(tool) => {
640 if let Some(stream) = tool.maybe_streaming() {
641 let chunks_arc =
643 std::sync::Arc::new(std::sync::Mutex::new(Vec::<
644 String,
645 >::new(
646 )));
647 let chunks_cb = chunks_arc.clone();
648 let res = stream
649 .call_streaming(
650 actual_input.clone(),
651 Box::new(move |c: &str| {
652 if let Ok(mut v) = chunks_cb.lock() {
653 v.push(c.to_string());
654 }
655 }),
656 )
657 .await;
658 let tool_duration =
659 DurationMs::from(tool_start.elapsed());
660 if let Ok(chunks) =
662 std::sync::Arc::try_unwrap(chunks_arc)
663 .map(|m| m.into_inner().unwrap())
664 {
665 for ch in &chunks {
666 let mut uctx = HookContext::new(
667 HookPoint::ToolExecutionUpdate,
668 );
669 uctx.tool_name = Some(name.clone());
670 uctx.tool_chunk = Some(ch.clone());
671 uctx.tokens_used =
672 total_tokens_in + total_tokens_out;
673 uctx.cost = total_cost;
674 uctx.turns_completed = turns_used;
675 uctx.elapsed =
676 DurationMs::from(start.elapsed());
677 let _ = self.hooks.dispatch(&uctx).await;
678 }
679 match res {
680 Ok(()) => (
681 chunks.concat(),
682 false,
683 true,
684 tool_duration,
685 ),
686 Err(e) => {
687 (e.to_string(), true, false, tool_duration)
688 }
689 }
690 } else {
691 match res {
693 Ok(()) => {
694 (String::new(), false, true, tool_duration)
695 }
696 Err(e) => {
697 (e.to_string(), true, false, tool_duration)
698 }
699 }
700 }
701 } else {
702 match tool.call(actual_input.clone()).await {
704 Ok(value) => (
705 serde_json::to_string(&value)
706 .unwrap_or_default(),
707 false,
708 true,
709 DurationMs::from(tool_start.elapsed()),
710 ),
711 Err(e) => (
712 e.to_string(),
713 true,
714 false,
715 DurationMs::from(tool_start.elapsed()),
716 ),
717 }
718 }
719 }
720 None => (
721 neuron_tool::ToolError::NotFound(name.clone()).to_string(),
722 true,
723 false,
724 DurationMs::from(tool_start.elapsed()),
725 ),
726 };
727 let mut hook_ctx = HookContext::new(HookPoint::PostToolUse);
729 hook_ctx.tool_name = Some(name.clone());
730 hook_ctx.tool_result = Some(result_content.clone());
731 hook_ctx.tokens_used = total_tokens_in + total_tokens_out;
732 hook_ctx.cost = total_cost;
733 hook_ctx.turns_completed = turns_used;
734 hook_ctx.elapsed = DurationMs::from(start.elapsed());
735 match self.hooks.dispatch(&hook_ctx).await {
736 HookAction::Halt { reason } => {
737 return Ok(Self::make_output(
738 parts_to_content(&last_content),
739 ExitReason::ObserverHalt { reason },
740 self.build_metadata(
741 total_tokens_in,
742 total_tokens_out,
743 total_cost,
744 turns_used,
745 tool_records,
746 DurationMs::from(start.elapsed()),
747 ),
748 effects,
749 ));
750 }
751 HookAction::ModifyToolOutput { new_output } => {
752 result_content = new_output.to_string();
753 }
754 _ => {}
755 }
756 tool_results.push(ContentPart::ToolResult {
757 tool_use_id: id,
758 content: result_content,
759 is_error,
760 });
761 tool_records.push(ToolCallRecord::new(name, duration, success));
762 }
763 if let Some(s) = &self.steering {
765 let injected = s.drain();
766 if !injected.is_empty() {
767 messages.extend(injected);
768 }
769 if idx + 1 < len {
770 for (rid, rname, _rinput) in
771 call_group.iter().skip(idx + 1).cloned()
772 {
773 tool_results.push(ContentPart::ToolResult {
774 tool_use_id: rid,
775 content: "Skipped due to steering".into(),
776 is_error: false,
777 });
778 tool_records.push(ToolCallRecord::new(
779 &rname,
780 DurationMs::ZERO,
781 false,
782 ));
783 }
784 break 'batches;
785 }
786 }
787 }
788 if let Some(s) = &self.steering {
790 let injected = s.drain();
791 if !injected.is_empty() {
792 messages.extend(injected);
793 _steered = true;
794 break 'batches;
795 }
796 }
797 }
798 BatchItem::Exclusive((id, name, tool_input)) => {
799 if let Some(s) = &self.steering {
801 let injected = s.drain();
802 if !injected.is_empty() {
803 messages.extend(injected);
804 tool_results.push(ContentPart::ToolResult {
805 tool_use_id: id,
806 content: "Skipped due to steering".into(),
807 is_error: false,
808 });
809 tool_records.push(ToolCallRecord::new(
810 &name,
811 DurationMs::ZERO,
812 false,
813 ));
814 _steered = true;
815 break 'batches;
816 }
817 }
818 if EFFECT_TOOL_NAMES.contains(&name.as_str()) {
819 if let Some(effect) = self.try_as_effect(&name, &tool_input) {
820 effects.push(effect);
821 }
822 tool_results.push(ContentPart::ToolResult {
823 tool_use_id: id,
824 content: format!("{name} effect recorded."),
825 is_error: false,
826 });
827 tool_records.push(ToolCallRecord::new(&name, DurationMs::ZERO, true));
828 continue;
829 }
830 let mut actual_input = tool_input.clone();
831 let mut hook_ctx = HookContext::new(HookPoint::PreToolUse);
832 hook_ctx.tool_name = Some(name.clone());
833 hook_ctx.tool_input = Some(tool_input.clone());
834 hook_ctx.tokens_used = total_tokens_in + total_tokens_out;
835 hook_ctx.cost = total_cost;
836 hook_ctx.turns_completed = turns_used;
837 hook_ctx.elapsed = DurationMs::from(start.elapsed());
838 match self.hooks.dispatch(&hook_ctx).await {
839 HookAction::Halt { reason } => {
840 return Ok(Self::make_output(
841 parts_to_content(&last_content),
842 ExitReason::ObserverHalt { reason },
843 self.build_metadata(
844 total_tokens_in,
845 total_tokens_out,
846 total_cost,
847 turns_used,
848 tool_records,
849 DurationMs::from(start.elapsed()),
850 ),
851 effects,
852 ));
853 }
854 HookAction::SkipTool { reason } => {
855 tool_results.push(ContentPart::ToolResult {
856 tool_use_id: id,
857 content: format!("Skipped: {reason}"),
858 is_error: false,
859 });
860 tool_records.push(ToolCallRecord::new(
861 &name,
862 DurationMs::ZERO,
863 false,
864 ));
865 continue;
866 }
867 HookAction::ModifyToolInput { new_input } => {
868 actual_input = new_input;
869 }
870 HookAction::Continue => {}
871 _ => {}
872 }
873 let tool_start = Instant::now();
874 let (mut result_content, is_error, success, tool_duration) = match self
876 .tools
877 .get(&name)
878 {
879 Some(tool) => {
880 if let Some(stream) = tool.maybe_streaming() {
881 let chunks_arc = std::sync::Arc::new(std::sync::Mutex::new(
882 Vec::<String>::new(),
883 ));
884 let chunks_cb = chunks_arc.clone();
885 let res = stream
886 .call_streaming(
887 actual_input.clone(),
888 Box::new(move |c: &str| {
889 if let Ok(mut v) = chunks_cb.lock() {
890 v.push(c.to_string());
891 }
892 }),
893 )
894 .await;
895 let dur = DurationMs::from(tool_start.elapsed());
896 if let Ok(chunks) = std::sync::Arc::try_unwrap(chunks_arc)
897 .map(|m| m.into_inner().unwrap())
898 {
899 for ch in &chunks {
900 let mut uctx =
901 HookContext::new(HookPoint::ToolExecutionUpdate);
902 uctx.tool_name = Some(name.clone());
903 uctx.tool_chunk = Some(ch.clone());
904 uctx.tokens_used = total_tokens_in + total_tokens_out;
905 uctx.cost = total_cost;
906 uctx.turns_completed = turns_used;
907 uctx.elapsed = DurationMs::from(start.elapsed());
908 let _ = self.hooks.dispatch(&uctx).await;
909 }
910 match res {
911 Ok(()) => (chunks.concat(), false, true, dur),
912 Err(e) => (e.to_string(), true, false, dur),
913 }
914 } else {
915 match res {
916 Ok(()) => (String::new(), false, true, dur),
917 Err(e) => (e.to_string(), true, false, dur),
918 }
919 }
920 } else {
921 match tool.call(actual_input.clone()).await {
922 Ok(value) => (
923 serde_json::to_string(&value).unwrap_or_default(),
924 false,
925 true,
926 DurationMs::from(tool_start.elapsed()),
927 ),
928 Err(e) => (
929 e.to_string(),
930 true,
931 false,
932 DurationMs::from(tool_start.elapsed()),
933 ),
934 }
935 }
936 }
937 None => (
938 neuron_tool::ToolError::NotFound(name.clone()).to_string(),
939 true,
940 false,
941 DurationMs::from(tool_start.elapsed()),
942 ),
943 };
944 let mut hook_ctx = HookContext::new(HookPoint::PostToolUse);
945 hook_ctx.tool_name = Some(name.clone());
946 hook_ctx.tool_result = Some(result_content.clone());
947 hook_ctx.tokens_used = total_tokens_in + total_tokens_out;
948 hook_ctx.cost = total_cost;
949 hook_ctx.turns_completed = turns_used;
950 hook_ctx.elapsed = DurationMs::from(start.elapsed());
951 match self.hooks.dispatch(&hook_ctx).await {
952 HookAction::Halt { reason } => {
953 return Ok(Self::make_output(
954 parts_to_content(&last_content),
955 ExitReason::ObserverHalt { reason },
956 self.build_metadata(
957 total_tokens_in,
958 total_tokens_out,
959 total_cost,
960 turns_used,
961 tool_records,
962 DurationMs::from(start.elapsed()),
963 ),
964 effects,
965 ));
966 }
967 HookAction::ModifyToolOutput { new_output } => {
968 result_content = new_output.to_string();
969 }
970 _ => {}
971 }
972 tool_results.push(ContentPart::ToolResult {
973 tool_use_id: id,
974 content: result_content,
975 is_error,
976 });
977 tool_records.push(ToolCallRecord::new(name, tool_duration, success));
978 if let Some(s) = &self.steering {
980 let injected = s.drain();
981 if !injected.is_empty() {
982 messages.extend(injected);
983 _steered = true;
984 break 'batches;
985 }
986 }
987 }
988 }
989 }
990
991 messages.push(ProviderMessage {
993 role: Role::User,
994 content: tool_results,
995 });
996
997 if turns_used >= config.max_turns {
999 return Ok(Self::make_output(
1000 parts_to_content(&last_content),
1001 ExitReason::MaxTurns,
1002 self.build_metadata(
1003 total_tokens_in,
1004 total_tokens_out,
1005 total_cost,
1006 turns_used,
1007 tool_records,
1008 DurationMs::from(start.elapsed()),
1009 ),
1010 effects,
1011 ));
1012 }
1013
1014 if let Some(max_cost) = &config.max_cost
1015 && total_cost >= *max_cost
1016 {
1017 return Ok(Self::make_output(
1018 parts_to_content(&last_content),
1019 ExitReason::BudgetExhausted,
1020 self.build_metadata(
1021 total_tokens_in,
1022 total_tokens_out,
1023 total_cost,
1024 turns_used,
1025 tool_records,
1026 DurationMs::from(start.elapsed()),
1027 ),
1028 effects,
1029 ));
1030 }
1031
1032 if let Some(max_duration) = &config.max_duration
1033 && start.elapsed() >= max_duration.to_std()
1034 {
1035 return Ok(Self::make_output(
1036 parts_to_content(&last_content),
1037 ExitReason::Timeout,
1038 self.build_metadata(
1039 total_tokens_in,
1040 total_tokens_out,
1041 total_cost,
1042 turns_used,
1043 tool_records,
1044 DurationMs::from(start.elapsed()),
1045 ),
1046 effects,
1047 ));
1048 }
1049
1050 let hook_ctx = self.build_hook_context(
1052 HookPoint::ExitCheck,
1053 total_tokens_in,
1054 total_tokens_out,
1055 total_cost,
1056 turns_used,
1057 DurationMs::from(start.elapsed()),
1058 );
1059 if let HookAction::Halt { reason } = self.hooks.dispatch(&hook_ctx).await {
1060 return Ok(Self::make_output(
1061 parts_to_content(&last_content),
1062 ExitReason::ObserverHalt { reason },
1063 self.build_metadata(
1064 total_tokens_in,
1065 total_tokens_out,
1066 total_cost,
1067 turns_used,
1068 tool_records,
1069 DurationMs::from(start.elapsed()),
1070 ),
1071 effects,
1072 ));
1073 }
1074
1075 let limit = config.max_tokens as usize * 4;
1077 if self.context_strategy.should_compact(&messages, limit) {
1078 messages = self.context_strategy.compact(messages);
1079 }
1080
1081 }
1083 }
1084}
1085
1086fn effect_tool_schemas() -> Vec<ToolSchema> {
1088 vec![
1089 ToolSchema {
1090 name: "write_memory".into(),
1091 description: "Write a value to persistent memory.".into(),
1092 input_schema: serde_json::json!({
1093 "type": "object",
1094 "properties": {
1095 "scope": {"type": "string", "description": "Memory scope (e.g. 'global', 'session:id')"},
1096 "key": {"type": "string", "description": "Memory key"},
1097 "value": {"description": "Value to store"}
1098 },
1099 "required": ["scope", "key", "value"]
1100 }),
1101 },
1102 ToolSchema {
1103 name: "delete_memory".into(),
1104 description: "Delete a value from persistent memory.".into(),
1105 input_schema: serde_json::json!({
1106 "type": "object",
1107 "properties": {
1108 "scope": {"type": "string", "description": "Memory scope"},
1109 "key": {"type": "string", "description": "Memory key"}
1110 },
1111 "required": ["scope", "key"]
1112 }),
1113 },
1114 ToolSchema {
1115 name: "delegate".into(),
1116 description: "Delegate a task to another agent.".into(),
1117 input_schema: serde_json::json!({
1118 "type": "object",
1119 "properties": {
1120 "agent": {"type": "string", "description": "Agent ID to delegate to"},
1121 "message": {"type": "string", "description": "Task description for the agent"}
1122 },
1123 "required": ["agent", "message"]
1124 }),
1125 },
1126 ToolSchema {
1127 name: "handoff".into(),
1128 description: "Hand off the conversation to another agent.".into(),
1129 input_schema: serde_json::json!({
1130 "type": "object",
1131 "properties": {
1132 "agent": {"type": "string", "description": "Agent ID to hand off to"},
1133 "state": {"description": "State to pass to the next agent"}
1134 },
1135 "required": ["agent"]
1136 }),
1137 },
1138 ToolSchema {
1139 name: "signal".into(),
1140 description: "Send a signal to another workflow.".into(),
1141 input_schema: serde_json::json!({
1142 "type": "object",
1143 "properties": {
1144 "target": {"type": "string", "description": "Target workflow ID"},
1145 "signal_type": {"type": "string", "description": "Signal type identifier"},
1146 "data": {"description": "Signal payload data"}
1147 },
1148 "required": ["target"]
1149 }),
1150 },
1151 ]
1152}
1153
1154fn parse_scope(s: &str) -> Scope {
1156 if s == "global" {
1157 return Scope::Global;
1158 }
1159 if let Some(id) = s.strip_prefix("session:") {
1160 return Scope::Session(layer0::SessionId::new(id));
1161 }
1162 if let Some(id) = s.strip_prefix("workflow:") {
1163 return Scope::Workflow(layer0::WorkflowId::new(id));
1164 }
1165 Scope::Custom(s.to_string())
1166}
1167
1168#[cfg(test)]
1169mod tests {
1170 use super::*;
1171 use neuron_hooks::HookRegistry;
1172 use neuron_tool::ToolRegistry;
1173 use neuron_turn::context::NoCompaction;
1174 use neuron_turn::provider::ProviderError;
1175 use serde_json::json;
1176 use std::collections::VecDeque;
1177 use std::sync::Mutex;
1178 use std::sync::atomic::{AtomicUsize, Ordering};
1179
1180 struct MockProvider {
1183 responses: Mutex<VecDeque<ProviderResponse>>,
1184 call_count: AtomicUsize,
1185 }
1186
1187 impl MockProvider {
1188 fn new(responses: Vec<ProviderResponse>) -> Self {
1189 Self {
1190 responses: Mutex::new(responses.into()),
1191 call_count: AtomicUsize::new(0),
1192 }
1193 }
1194 }
1195
1196 impl Provider for MockProvider {
1197 fn complete(
1198 &self,
1199 _request: ProviderRequest,
1200 ) -> impl std::future::Future<Output = Result<ProviderResponse, ProviderError>> + Send
1201 {
1202 self.call_count.fetch_add(1, Ordering::SeqCst);
1203 let response = self
1204 .responses
1205 .lock()
1206 .unwrap()
1207 .pop_front()
1208 .expect("MockProvider: no more responses queued");
1209 async move { Ok(response) }
1210 }
1211 }
1212
1213 struct NullStateReader;
1216
1217 #[async_trait]
1218 impl layer0::StateReader for NullStateReader {
1219 async fn read(
1220 &self,
1221 _scope: &Scope,
1222 _key: &str,
1223 ) -> Result<Option<serde_json::Value>, layer0::StateError> {
1224 Ok(None)
1225 }
1226 async fn list(
1227 &self,
1228 _scope: &Scope,
1229 _prefix: &str,
1230 ) -> Result<Vec<String>, layer0::StateError> {
1231 Ok(vec![])
1232 }
1233 async fn search(
1234 &self,
1235 _scope: &Scope,
1236 _query: &str,
1237 _limit: usize,
1238 ) -> Result<Vec<layer0::state::SearchResult>, layer0::StateError> {
1239 Ok(vec![])
1240 }
1241 }
1242
1243 struct EchoTool;
1246
1247 impl neuron_tool::ToolDyn for EchoTool {
1248 fn name(&self) -> &str {
1249 "echo"
1250 }
1251 fn description(&self) -> &str {
1252 "Echoes input"
1253 }
1254 fn input_schema(&self) -> serde_json::Value {
1255 json!({"type": "object"})
1256 }
1257 fn call(
1258 &self,
1259 input: serde_json::Value,
1260 ) -> std::pin::Pin<
1261 Box<
1262 dyn std::future::Future<Output = Result<serde_json::Value, neuron_tool::ToolError>>
1263 + Send
1264 + '_,
1265 >,
1266 > {
1267 Box::pin(async move { Ok(json!({"echoed": input})) })
1268 }
1269 }
1270
1271 fn simple_text_response(text: &str) -> ProviderResponse {
1274 ProviderResponse {
1275 content: vec![ContentPart::Text {
1276 text: text.to_string(),
1277 }],
1278 stop_reason: StopReason::EndTurn,
1279 usage: TokenUsage {
1280 input_tokens: 10,
1281 output_tokens: 5,
1282 ..Default::default()
1283 },
1284 model: "mock-model".into(),
1285 cost: Some(Decimal::new(1, 4)), truncated: None,
1287 }
1288 }
1289
1290 fn tool_use_response(
1291 tool_id: &str,
1292 tool_name: &str,
1293 input: serde_json::Value,
1294 ) -> ProviderResponse {
1295 ProviderResponse {
1296 content: vec![ContentPart::ToolUse {
1297 id: tool_id.to_string(),
1298 name: tool_name.to_string(),
1299 input,
1300 }],
1301 stop_reason: StopReason::ToolUse,
1302 usage: TokenUsage {
1303 input_tokens: 10,
1304 output_tokens: 15,
1305 ..Default::default()
1306 },
1307 model: "mock-model".into(),
1308 cost: Some(Decimal::new(2, 4)), truncated: None,
1310 }
1311 }
1312
1313 fn make_op<P: Provider>(provider: P) -> ReactOperator<P> {
1314 ReactOperator::new(
1315 provider,
1316 ToolRegistry::new(),
1317 Box::new(NoCompaction),
1318 HookRegistry::new(),
1319 Arc::new(NullStateReader),
1320 ReactConfig::default(),
1321 )
1322 }
1323
1324 fn make_op_with_tools<P: Provider>(provider: P, tools: ToolRegistry) -> ReactOperator<P> {
1325 ReactOperator::new(
1326 provider,
1327 tools,
1328 Box::new(NoCompaction),
1329 HookRegistry::new(),
1330 Arc::new(NullStateReader),
1331 ReactConfig::default(),
1332 )
1333 }
1334
1335 fn simple_input(text: &str) -> OperatorInput {
1336 OperatorInput::new(Content::text(text), layer0::operator::TriggerType::User)
1337 }
1338
1339 #[tokio::test]
1342 async fn simple_completion() {
1343 let provider = MockProvider::new(vec![simple_text_response("Hello!")]);
1344 let op = make_op(provider);
1345
1346 let output = op.execute(simple_input("Hi")).await.unwrap();
1347
1348 assert_eq!(output.exit_reason, ExitReason::Complete);
1349 assert_eq!(output.message.as_text().unwrap(), "Hello!");
1350 assert_eq!(output.metadata.turns_used, 1);
1351 assert_eq!(output.metadata.tokens_in, 10);
1352 assert_eq!(output.metadata.tokens_out, 5);
1353 assert!(output.effects.is_empty());
1354 }
1355
1356 #[tokio::test]
1357 async fn tool_use_and_followup() {
1358 let provider = MockProvider::new(vec![
1359 tool_use_response("tu_1", "echo", json!({"msg": "test"})),
1360 simple_text_response("Done."),
1361 ]);
1362 let mut tools = ToolRegistry::new();
1363 tools.register(Arc::new(EchoTool));
1364 let op = make_op_with_tools(provider, tools);
1365
1366 let output = op.execute(simple_input("Use echo")).await.unwrap();
1367
1368 assert_eq!(output.exit_reason, ExitReason::Complete);
1369 assert_eq!(output.metadata.turns_used, 2);
1370 assert_eq!(output.metadata.tools_called.len(), 1);
1371 assert_eq!(output.metadata.tools_called[0].name, "echo");
1372 }
1373
1374 #[tokio::test]
1375 async fn unknown_tool_returns_error_result() {
1376 let provider = MockProvider::new(vec![
1377 tool_use_response("tu_1", "nonexistent_tool", json!({})),
1378 simple_text_response("Got an error."),
1379 ]);
1380 let op = make_op(provider);
1381
1382 let output = op.execute(simple_input("Use nonexistent")).await.unwrap();
1384 assert_eq!(output.exit_reason, ExitReason::Complete);
1385 assert_eq!(output.metadata.tools_called.len(), 1);
1387 }
1388
1389 #[tokio::test]
1390 async fn max_turns_enforced() {
1391 let provider = MockProvider::new(vec![
1393 tool_use_response("tu_1", "echo", json!({})),
1394 tool_use_response("tu_2", "echo", json!({})),
1395 tool_use_response("tu_3", "echo", json!({})),
1396 simple_text_response("never reached"),
1397 ]);
1398 let mut tools = ToolRegistry::new();
1399 tools.register(Arc::new(EchoTool));
1400
1401 let mut op = ReactOperator::new(
1402 provider,
1403 tools,
1404 Box::new(NoCompaction),
1405 HookRegistry::new(),
1406 Arc::new(NullStateReader),
1407 ReactConfig {
1408 default_max_turns: 2,
1409 ..Default::default()
1410 },
1411 );
1412 let _ = &mut op;
1414
1415 let op = ReactOperator::new(
1416 MockProvider::new(vec![
1417 tool_use_response("tu_1", "echo", json!({})),
1418 tool_use_response("tu_2", "echo", json!({})),
1419 simple_text_response("never reached"),
1420 ]),
1421 {
1422 let mut t = ToolRegistry::new();
1423 t.register(Arc::new(EchoTool));
1424 t
1425 },
1426 Box::new(NoCompaction),
1427 HookRegistry::new(),
1428 Arc::new(NullStateReader),
1429 ReactConfig {
1430 default_max_turns: 2,
1431 ..Default::default()
1432 },
1433 );
1434
1435 let output = op.execute(simple_input("loop")).await.unwrap();
1436 assert_eq!(output.exit_reason, ExitReason::MaxTurns);
1437 assert_eq!(output.metadata.turns_used, 2);
1438 }
1439
1440 #[tokio::test]
1441 async fn budget_exhausted() {
1442 let provider = MockProvider::new(vec![
1444 tool_use_response("tu_1", "echo", json!({})),
1445 simple_text_response("Done"),
1446 ]);
1447 let mut tools = ToolRegistry::new();
1448 tools.register(Arc::new(EchoTool));
1449 let op = ReactOperator::new(
1450 provider,
1451 tools,
1452 Box::new(NoCompaction),
1453 HookRegistry::new(),
1454 Arc::new(NullStateReader),
1455 ReactConfig::default(),
1456 );
1457
1458 let mut input = simple_input("spend");
1459 let mut tc = layer0::operator::OperatorConfig::default();
1460 tc.max_cost = Some(Decimal::new(15, 5)); input.config = Some(tc);
1462
1463 let output = op.execute(input).await.unwrap();
1464 assert_eq!(output.exit_reason, ExitReason::BudgetExhausted);
1466 }
1467
1468 #[tokio::test]
1469 async fn max_tokens_returns_model_error() {
1470 let provider = MockProvider::new(vec![ProviderResponse {
1471 content: vec![],
1472 stop_reason: StopReason::MaxTokens,
1473 usage: TokenUsage::default(),
1474 model: "mock".into(),
1475 cost: None,
1476 truncated: None,
1477 }]);
1478 let op = make_op(provider);
1479
1480 let result = op.execute(simple_input("Hi")).await;
1481 assert!(result.is_err());
1482 match result.unwrap_err() {
1483 OperatorError::Model(msg) => assert!(msg.contains("max_tokens")),
1484 other => panic!("expected OperatorError::Model, got {:?}", other),
1485 }
1486 }
1487
1488 #[tokio::test]
1489 async fn content_filter_returns_model_error() {
1490 let provider = MockProvider::new(vec![ProviderResponse {
1491 content: vec![],
1492 stop_reason: StopReason::ContentFilter,
1493 usage: TokenUsage::default(),
1494 model: "mock".into(),
1495 cost: None,
1496 truncated: None,
1497 }]);
1498 let op = make_op(provider);
1499
1500 let result = op.execute(simple_input("Hi")).await;
1501 assert!(result.is_err());
1502 match result.unwrap_err() {
1503 OperatorError::Model(msg) => assert!(msg.contains("content filtered")),
1504 other => panic!("expected OperatorError::Model, got {:?}", other),
1505 }
1506 }
1507
1508 #[tokio::test]
1509 async fn cost_aggregated_across_turns() {
1510 let provider = MockProvider::new(vec![
1511 tool_use_response("tu_1", "echo", json!({})),
1512 simple_text_response("Done"),
1513 ]);
1514 let mut tools = ToolRegistry::new();
1515 tools.register(Arc::new(EchoTool));
1516 let op = make_op_with_tools(provider, tools);
1517
1518 let output = op.execute(simple_input("Hi")).await.unwrap();
1519
1520 assert_eq!(output.metadata.cost, Decimal::new(3, 4));
1522 assert_eq!(output.metadata.tokens_in, 20);
1523 assert_eq!(output.metadata.tokens_out, 20);
1524 }
1525
1526 #[tokio::test]
1527 async fn operator_config_overrides_defaults() {
1528 let provider = MockProvider::new(vec![simple_text_response("Hi")]);
1529 let op = make_op(provider);
1530
1531 let mut input = simple_input("test");
1532 let mut tc = layer0::operator::OperatorConfig::default();
1533 tc.system_addendum = Some("Be concise.".into());
1534 tc.model = Some("custom-model".into());
1535 tc.max_turns = Some(5);
1536 input.config = Some(tc);
1537
1538 let output = op.execute(input).await.unwrap();
1539 assert_eq!(output.exit_reason, ExitReason::Complete);
1540 }
1541
1542 #[tokio::test]
1543 async fn effect_tool_write_memory() {
1544 let provider = MockProvider::new(vec![
1545 ProviderResponse {
1547 content: vec![ContentPart::ToolUse {
1548 id: "tu_1".into(),
1549 name: "write_memory".into(),
1550 input: json!({"scope": "global", "key": "test", "value": "hello"}),
1551 }],
1552 stop_reason: StopReason::ToolUse,
1553 usage: TokenUsage {
1554 input_tokens: 10,
1555 output_tokens: 5,
1556 ..Default::default()
1557 },
1558 model: "mock".into(),
1559 cost: None,
1560 truncated: None,
1561 },
1562 simple_text_response("Memory written."),
1563 ]);
1564 let op = make_op(provider);
1565
1566 let output = op.execute(simple_input("Write memory")).await.unwrap();
1567
1568 assert_eq!(output.effects.len(), 1);
1569 match &output.effects[0] {
1570 Effect::WriteMemory { key, .. } => assert_eq!(key, "test"),
1571 _ => panic!("expected WriteMemory"),
1572 }
1573 }
1574
1575 #[test]
1576 fn parse_scope_variants() {
1577 assert_eq!(parse_scope("global"), Scope::Global);
1578 assert_eq!(
1579 parse_scope("session:abc"),
1580 Scope::Session(layer0::SessionId::new("abc"))
1581 );
1582 assert_eq!(
1583 parse_scope("workflow:wf1"),
1584 Scope::Workflow(layer0::WorkflowId::new("wf1"))
1585 );
1586 match parse_scope("other") {
1587 Scope::Custom(s) => assert_eq!(s, "other"),
1588 _ => panic!("expected Custom"),
1589 }
1590 }
1591
1592 #[tokio::test]
1593 async fn effect_tool_delete_memory() {
1594 let provider = MockProvider::new(vec![
1595 ProviderResponse {
1596 content: vec![ContentPart::ToolUse {
1597 id: "tu_1".into(),
1598 name: "delete_memory".into(),
1599 input: json!({"scope": "global", "key": "old_key"}),
1600 }],
1601 stop_reason: StopReason::ToolUse,
1602 usage: TokenUsage::default(),
1603 model: "mock".into(),
1604 cost: None,
1605 truncated: None,
1606 },
1607 simple_text_response("Deleted."),
1608 ]);
1609 let op = make_op(provider);
1610
1611 let output = op.execute(simple_input("Delete memory")).await.unwrap();
1612 assert_eq!(output.effects.len(), 1);
1613 match &output.effects[0] {
1614 Effect::DeleteMemory { key, .. } => assert_eq!(key, "old_key"),
1615 _ => panic!("expected DeleteMemory"),
1616 }
1617 }
1618
1619 #[tokio::test]
1620 async fn effect_tool_delegate() {
1621 let provider = MockProvider::new(vec![
1622 ProviderResponse {
1623 content: vec![ContentPart::ToolUse {
1624 id: "tu_1".into(),
1625 name: "delegate".into(),
1626 input: json!({"agent": "helper", "message": "do this task"}),
1627 }],
1628 stop_reason: StopReason::ToolUse,
1629 usage: TokenUsage::default(),
1630 model: "mock".into(),
1631 cost: None,
1632 truncated: None,
1633 },
1634 simple_text_response("Delegated."),
1635 ]);
1636 let op = make_op(provider);
1637
1638 let output = op.execute(simple_input("Delegate task")).await.unwrap();
1639 assert_eq!(output.effects.len(), 1);
1640 match &output.effects[0] {
1641 Effect::Delegate { agent, input } => {
1642 assert_eq!(agent.as_str(), "helper");
1643 assert_eq!(input.message.as_text().unwrap(), "do this task");
1644 }
1645 _ => panic!("expected Delegate"),
1646 }
1647 }
1648
1649 #[tokio::test]
1650 async fn effect_tool_handoff() {
1651 let provider = MockProvider::new(vec![
1652 ProviderResponse {
1653 content: vec![ContentPart::ToolUse {
1654 id: "tu_1".into(),
1655 name: "handoff".into(),
1656 input: json!({"agent": "specialist", "state": {"context": "data"}}),
1657 }],
1658 stop_reason: StopReason::ToolUse,
1659 usage: TokenUsage::default(),
1660 model: "mock".into(),
1661 cost: None,
1662 truncated: None,
1663 },
1664 simple_text_response("Handed off."),
1665 ]);
1666 let op = make_op(provider);
1667
1668 let output = op.execute(simple_input("Handoff")).await.unwrap();
1669 assert_eq!(output.effects.len(), 1);
1670 match &output.effects[0] {
1671 Effect::Handoff { agent, state } => {
1672 assert_eq!(agent.as_str(), "specialist");
1673 assert_eq!(state["context"], "data");
1674 }
1675 _ => panic!("expected Handoff"),
1676 }
1677 }
1678
1679 #[tokio::test]
1680 async fn effect_tool_signal() {
1681 let provider = MockProvider::new(vec![
1682 ProviderResponse {
1683 content: vec![ContentPart::ToolUse {
1684 id: "tu_1".into(),
1685 name: "signal".into(),
1686 input: json!({"target": "workflow_1", "signal_type": "completed", "data": {"result": "ok"}}),
1687 }],
1688 stop_reason: StopReason::ToolUse,
1689 usage: TokenUsage::default(),
1690 model: "mock".into(),
1691 cost: None,
1692 truncated: None,
1693 },
1694 simple_text_response("Signal sent."),
1695 ]);
1696 let op = make_op(provider);
1697
1698 let output = op.execute(simple_input("Signal")).await.unwrap();
1699 assert_eq!(output.effects.len(), 1);
1700 match &output.effects[0] {
1701 Effect::Signal { target, payload } => {
1702 assert_eq!(target.as_str(), "workflow_1");
1703 assert_eq!(payload.signal_type, "completed");
1704 }
1705 _ => panic!("expected Signal"),
1706 }
1707 }
1708
1709 #[test]
1710 fn effect_tool_schemas_all_present() {
1711 let schemas = effect_tool_schemas();
1712 let names: Vec<&str> = schemas.iter().map(|s| s.name.as_str()).collect();
1713 assert!(names.contains(&"write_memory"));
1714 assert!(names.contains(&"delete_memory"));
1715 assert!(names.contains(&"delegate"));
1716 assert!(names.contains(&"handoff"));
1717 assert!(names.contains(&"signal"));
1718 assert_eq!(schemas.len(), 5);
1719 }
1720
1721 #[test]
1722 fn react_operator_implements_operator_trait() {
1723 fn _assert_operator<T: Operator>() {}
1725 _assert_operator::<ReactOperator<MockProvider>>();
1726 }
1727
1728 #[tokio::test]
1729 async fn react_operator_as_arc_dyn_operator() {
1730 let provider = MockProvider::new(vec![simple_text_response("Hello!")]);
1732 let op: Arc<dyn Operator> = Arc::new(ReactOperator::new(
1733 provider,
1734 ToolRegistry::new(),
1735 Box::new(NoCompaction),
1736 HookRegistry::new(),
1737 Arc::new(NullStateReader),
1738 ReactConfig::default(),
1739 ));
1740
1741 let output = op.execute(simple_input("Hi")).await.unwrap();
1742 assert_eq!(output.exit_reason, ExitReason::Complete);
1743 }
1744
1745 #[tokio::test]
1746 async fn provider_retryable_error_maps_to_retryable() {
1747 struct ErrorProvider;
1748 impl Provider for ErrorProvider {
1749 #[allow(clippy::manual_async_fn)]
1750 fn complete(
1751 &self,
1752 _request: ProviderRequest,
1753 ) -> impl std::future::Future<Output = Result<ProviderResponse, ProviderError>> + Send
1754 {
1755 async { Err(ProviderError::RateLimited) }
1756 }
1757 }
1758
1759 let op = ReactOperator::new(
1760 ErrorProvider,
1761 ToolRegistry::new(),
1762 Box::new(NoCompaction),
1763 HookRegistry::new(),
1764 Arc::new(NullStateReader),
1765 ReactConfig::default(),
1766 );
1767
1768 let result = op.execute(simple_input("test")).await;
1769 assert!(matches!(result, Err(OperatorError::Retryable(_))));
1770 }
1771
1772 #[tokio::test]
1773 async fn provider_call_count() {
1774 let provider = MockProvider::new(vec![
1775 tool_use_response("tu_1", "echo", json!({})),
1776 tool_use_response("tu_2", "echo", json!({})),
1777 simple_text_response("Done"),
1778 ]);
1779 let call_count = std::sync::Arc::new(AtomicUsize::new(0));
1780
1781 struct CountingProvider {
1782 inner: MockProvider,
1783 count: std::sync::Arc<AtomicUsize>,
1784 }
1785 impl Provider for CountingProvider {
1786 #[allow(clippy::manual_async_fn)]
1787 fn complete(
1788 &self,
1789 request: ProviderRequest,
1790 ) -> impl std::future::Future<Output = Result<ProviderResponse, ProviderError>> + Send
1791 {
1792 self.count.fetch_add(1, Ordering::SeqCst);
1793 self.inner.complete(request)
1794 }
1795 }
1796
1797 let counting_provider = CountingProvider {
1798 inner: MockProvider::new(vec![
1799 tool_use_response("tu_1", "echo", json!({})),
1800 tool_use_response("tu_2", "echo", json!({})),
1801 simple_text_response("Done"),
1802 ]),
1803 count: call_count.clone(),
1804 };
1805
1806 let mut tools = ToolRegistry::new();
1807 tools.register(Arc::new(EchoTool));
1808 let op = make_op_with_tools(counting_provider, tools);
1809
1810 op.execute(simple_input("Multi-turn")).await.unwrap();
1811 assert_eq!(call_count.load(Ordering::SeqCst), 3);
1813 drop(provider);
1815 }
1816
1817 struct MockSteering {
1819 seq: Mutex<VecDeque<Vec<ProviderMessage>>>,
1820 calls: AtomicUsize,
1821 }
1822 impl MockSteering {
1823 fn new(seq: Vec<Vec<ProviderMessage>>) -> Self {
1824 Self {
1825 seq: Mutex::new(seq.into()),
1826 calls: AtomicUsize::new(0),
1827 }
1828 }
1829 fn call_count(&self) -> usize {
1830 self.calls.load(Ordering::SeqCst)
1831 }
1832 }
1833 impl SteeringSource for MockSteering {
1834 fn drain(&self) -> Vec<ProviderMessage> {
1835 self.calls.fetch_add(1, Ordering::SeqCst);
1836 self.seq.lock().unwrap().pop_front().unwrap_or_default()
1837 }
1838 }
1839
1840 struct CountingEchoTool {
1841 hits: std::sync::Arc<AtomicUsize>,
1842 }
1843 impl CountingEchoTool {
1844 fn new(h: std::sync::Arc<AtomicUsize>) -> Self {
1845 Self { hits: h }
1846 }
1847 }
1848 impl neuron_tool::ToolDyn for CountingEchoTool {
1849 fn name(&self) -> &str {
1850 "echo"
1851 }
1852 fn description(&self) -> &str {
1853 "Echoes input (counting)"
1854 }
1855 fn input_schema(&self) -> serde_json::Value {
1856 json!({"type":"object"})
1857 }
1858 fn call(
1859 &self,
1860 input: serde_json::Value,
1861 ) -> std::pin::Pin<
1862 Box<
1863 dyn std::future::Future<Output = Result<serde_json::Value, neuron_tool::ToolError>>
1864 + Send
1865 + '_,
1866 >,
1867 > {
1868 self.hits.fetch_add(1, Ordering::SeqCst);
1869 Box::pin(async move { Ok(json!({"echoed": input})) })
1870 }
1871 }
1872
1873 struct SharedOnlyDecider;
1874 impl ConcurrencyDecider for SharedOnlyDecider {
1875 fn concurrency(&self, tool_name: &str) -> Concurrency {
1876 if tool_name == "echo" {
1877 Concurrency::Shared
1878 } else {
1879 Concurrency::Exclusive
1880 }
1881 }
1882 }
1883
1884 fn user_msg(text: &str) -> ProviderMessage {
1885 ProviderMessage {
1886 role: Role::User,
1887 content: vec![ContentPart::Text { text: text.into() }],
1888 }
1889 }
1890
1891 #[tokio::test]
1892 async fn steering_skips_remaining_shared_batch() {
1893 let first = ProviderResponse {
1895 content: vec![
1896 ContentPart::ToolUse {
1897 id: "t1".into(),
1898 name: "echo".into(),
1899 input: json!({"n":1}),
1900 },
1901 ContentPart::ToolUse {
1902 id: "t2".into(),
1903 name: "echo".into(),
1904 input: json!({"n":2}),
1905 },
1906 ],
1907 stop_reason: StopReason::ToolUse,
1908 usage: TokenUsage {
1909 input_tokens: 10,
1910 output_tokens: 15,
1911 ..Default::default()
1912 },
1913 model: "mock".into(),
1914 cost: None,
1915 truncated: None,
1916 };
1917 let provider = MockProvider::new(vec![first, simple_text_response("Done")]);
1918 let hits = std::sync::Arc::new(AtomicUsize::new(0));
1919 let mut tools = ToolRegistry::new();
1920 tools.register(Arc::new(CountingEchoTool::new(hits.clone())));
1921 let steering = Arc::new(MockSteering::new(vec![
1922 vec![], vec![user_msg("STEER")], ]));
1925 let steering_ref = steering.clone();
1926 let op = make_op_with_tools(provider, tools)
1927 .with_planner(Box::new(BarrierPlanner))
1928 .with_concurrency_decider(Box::new(SharedOnlyDecider))
1929 .with_steering(steering);
1930 let output = op.execute(simple_input("run"));
1931 let output = output.await.unwrap();
1932 assert_eq!(output.exit_reason, ExitReason::Complete);
1933 assert!(steering_ref.call_count() >= 1);
1934 assert_eq!(hits.load(Ordering::SeqCst), 1);
1936 assert_eq!(output.metadata.turns_used, 2);
1937 assert_eq!(output.metadata.tools_called.len(), 2);
1938 assert_eq!(output.metadata.tools_called[0].name, "echo");
1939 assert_eq!(output.metadata.tools_called[1].name, "echo");
1940 }
1941 #[tokio::test]
1942 async fn steering_skips_before_exclusive() {
1943 let first = ProviderResponse {
1945 content: vec![ContentPart::ToolUse {
1946 id: "t1".into(),
1947 name: "echo".into(),
1948 input: json!({}),
1949 }],
1950 stop_reason: StopReason::ToolUse,
1951 usage: TokenUsage {
1952 input_tokens: 10,
1953 output_tokens: 15,
1954 ..Default::default()
1955 },
1956 model: "mock".into(),
1957 cost: None,
1958 truncated: None,
1959 };
1960 let call_count = std::sync::Arc::new(AtomicUsize::new(0));
1962 struct CountingProvider {
1963 inner: MockProvider,
1964 count: std::sync::Arc<AtomicUsize>,
1965 }
1966 impl Provider for CountingProvider {
1967 #[allow(clippy::manual_async_fn)]
1968 fn complete(
1969 &self,
1970 request: ProviderRequest,
1971 ) -> impl std::future::Future<Output = Result<ProviderResponse, ProviderError>> + Send
1972 {
1973 self.count.fetch_add(1, Ordering::SeqCst);
1974 self.inner.complete(request)
1975 }
1976 }
1977 let counting = CountingProvider {
1978 inner: MockProvider::new(vec![first, simple_text_response("Done")]),
1979 count: call_count.clone(),
1980 };
1981 let hits = std::sync::Arc::new(AtomicUsize::new(0));
1982 let mut tools = ToolRegistry::new();
1983 tools.register(Arc::new(CountingEchoTool::new(hits.clone())));
1984 let steering = Arc::new(MockSteering::new(vec![
1985 vec![user_msg("STEER")], ]));
1987 let op = ReactOperator::new(
1988 counting,
1989 tools,
1990 Box::new(NoCompaction),
1991 HookRegistry::new(),
1992 Arc::new(NullStateReader),
1993 ReactConfig::default(),
1994 )
1995 .with_steering(steering);
1996 let output = op.execute(simple_input("run"));
1997 let output = output.await.unwrap();
1998 assert_eq!(output.exit_reason, ExitReason::Complete);
1999 assert_eq!(hits.load(Ordering::SeqCst), 0);
2001 assert_eq!(call_count.load(Ordering::SeqCst), 2);
2003 assert_eq!(output.metadata.turns_used, 2);
2004 }
2005
2006 #[tokio::test]
2007 async fn no_steering_default() {
2008 let first = ProviderResponse {
2010 content: vec![
2011 ContentPart::ToolUse {
2012 id: "t1".into(),
2013 name: "echo".into(),
2014 input: json!({}),
2015 },
2016 ContentPart::ToolUse {
2017 id: "t2".into(),
2018 name: "echo".into(),
2019 input: json!({}),
2020 },
2021 ],
2022 stop_reason: StopReason::ToolUse,
2023 usage: TokenUsage {
2024 input_tokens: 10,
2025 output_tokens: 15,
2026 ..Default::default()
2027 },
2028 model: "mock".into(),
2029 cost: None,
2030 truncated: None,
2031 };
2032 let provider = MockProvider::new(vec![first, simple_text_response("Done")]);
2033 let hits = std::sync::Arc::new(AtomicUsize::new(0));
2034 let mut tools = ToolRegistry::new();
2035 tools.register(Arc::new(CountingEchoTool::new(hits.clone())));
2036 let op = make_op_with_tools(provider, tools)
2037 .with_planner(Box::new(BarrierPlanner))
2038 .with_concurrency_decider(Box::new(SharedOnlyDecider));
2039 let output = op.execute(simple_input("run"));
2040 let output = output.await.unwrap();
2041 assert_eq!(output.exit_reason, ExitReason::Complete);
2042 assert_eq!(hits.load(Ordering::SeqCst), 2);
2043 assert_eq!(output.metadata.tools_called.len(), 2);
2044 assert_eq!(output.metadata.turns_used, 2);
2045 }
2046
2047 struct StreamEcho;
2049 impl neuron_tool::ToolDyn for StreamEcho {
2050 fn name(&self) -> &str {
2051 "stream_echo"
2052 }
2053 fn description(&self) -> &str {
2054 "Streams echo chunks"
2055 }
2056 fn input_schema(&self) -> serde_json::Value {
2057 json!({"type":"object"})
2058 }
2059 fn call(
2060 &self,
2061 _input: serde_json::Value,
2062 ) -> std::pin::Pin<
2063 Box<
2064 dyn std::future::Future<Output = Result<serde_json::Value, neuron_tool::ToolError>>
2065 + Send
2066 + '_,
2067 >,
2068 > {
2069 Box::pin(async { Ok(serde_json::json!({"note":"non-stream fallback"})) })
2070 }
2071 fn maybe_streaming(&self) -> Option<&dyn neuron_tool::ToolDynStreaming> {
2072 Some(self)
2073 }
2074 }
2075 impl neuron_tool::ToolDynStreaming for StreamEcho {
2076 fn call_streaming<'a>(
2077 &'a self,
2078 _input: serde_json::Value,
2079 on_chunk: Box<dyn Fn(&str) + Send + Sync + 'a>,
2080 ) -> std::pin::Pin<
2081 Box<dyn std::future::Future<Output = Result<(), neuron_tool::ToolError>> + Send + 'a>,
2082 > {
2083 Box::pin(async move {
2084 for ch in ["A", "B", "C"] {
2085 on_chunk(ch);
2086 }
2087 Ok(())
2088 })
2089 }
2090 }
2091
2092 struct CollectHook {
2093 points: Vec<HookPoint>,
2094 chunks: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
2095 finals: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
2096 }
2097 #[async_trait]
2098 impl layer0::hook::Hook for CollectHook {
2099 fn points(&self) -> &[HookPoint] {
2100 &self.points
2101 }
2102 async fn on_event(
2103 &self,
2104 ctx: &HookContext,
2105 ) -> Result<HookAction, layer0::error::HookError> {
2106 if ctx.point == HookPoint::ToolExecutionUpdate {
2107 if let Some(c) = &ctx.tool_chunk {
2108 self.chunks.lock().unwrap().push(c.clone());
2109 }
2110 Ok(HookAction::Continue)
2111 } else if ctx.point == HookPoint::PostToolUse {
2112 if let Some(r) = &ctx.tool_result {
2113 self.finals.lock().unwrap().push(r.clone());
2114 }
2115 Ok(HookAction::Continue)
2116 } else {
2117 Ok(HookAction::Continue)
2118 }
2119 }
2120 }
2121
2122 #[tokio::test]
2123 async fn streaming_chunks_forwarded_and_concatenated() {
2124 let _provider = MockProvider::new(vec![
2126 tool_use_response("tu_s", "stream_echo", json!({"n":1})),
2127 simple_text_response("OK"),
2128 ]);
2129 let mut tools = ToolRegistry::new();
2130 tools.register(Arc::new(StreamEcho));
2131 let chunks = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
2133 let finals = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
2134 let mut hooks = HookRegistry::new();
2135 hooks.add(Arc::new(CollectHook {
2136 points: vec![HookPoint::ToolExecutionUpdate, HookPoint::PostToolUse],
2137 chunks: chunks.clone(),
2138 finals: finals.clone(),
2139 }));
2140 let op = ReactOperator::new(
2141 MockProvider::new(vec![
2142 tool_use_response("tu_s", "stream_echo", json!({})),
2143 simple_text_response("OK"),
2144 ]),
2145 tools,
2146 Box::new(NoCompaction),
2147 hooks,
2148 Arc::new(NullStateReader),
2149 ReactConfig::default(),
2150 );
2151 let _ = op.execute(simple_input("run")).await.unwrap();
2152 let got_chunks = chunks.lock().unwrap().clone();
2153 assert_eq!(got_chunks, vec!["A", "B", "C"]);
2154 let got_finals = finals.lock().unwrap().clone();
2155 assert_eq!(got_finals.len(), 1);
2156 assert_eq!(got_finals[0], "ABC");
2157 }
2158
2159 struct CountingSharedEchoTool {
2160 hits: std::sync::Arc<AtomicUsize>,
2161 }
2162 impl CountingSharedEchoTool {
2163 fn new(h: std::sync::Arc<AtomicUsize>) -> Self {
2164 Self { hits: h }
2165 }
2166 }
2167 impl neuron_tool::ToolDyn for CountingSharedEchoTool {
2168 fn name(&self) -> &str {
2169 "meta_echo"
2170 }
2171 fn description(&self) -> &str {
2172 "Echoes input (shared via metadata)"
2173 }
2174 fn input_schema(&self) -> serde_json::Value {
2175 json!({"type":"object"})
2176 }
2177 fn call(
2178 &self,
2179 input: serde_json::Value,
2180 ) -> std::pin::Pin<
2181 Box<
2182 dyn std::future::Future<Output = Result<serde_json::Value, neuron_tool::ToolError>>
2183 + Send
2184 + '_,
2185 >,
2186 > {
2187 self.hits.fetch_add(1, Ordering::SeqCst);
2188 Box::pin(async move { Ok(json!({"echoed": input})) })
2189 }
2190 fn concurrency_hint(&self) -> neuron_tool::ToolConcurrencyHint {
2191 neuron_tool::ToolConcurrencyHint::Shared
2192 }
2193 }
2194
2195 #[tokio::test]
2196 async fn metadata_concurrency_batches_shared() {
2197 let first = ProviderResponse {
2199 content: vec![
2200 ContentPart::ToolUse {
2201 id: "t1".into(),
2202 name: "meta_echo".into(),
2203 input: json!({}),
2204 },
2205 ContentPart::ToolUse {
2206 id: "t2".into(),
2207 name: "meta_echo".into(),
2208 input: json!({}),
2209 },
2210 ],
2211 stop_reason: StopReason::ToolUse,
2212 usage: TokenUsage {
2213 input_tokens: 10,
2214 output_tokens: 15,
2215 ..Default::default()
2216 },
2217 model: "mock".into(),
2218 cost: None,
2219 truncated: None,
2220 };
2221 let provider = MockProvider::new(vec![first, simple_text_response("Done")]);
2222 let hits = std::sync::Arc::new(AtomicUsize::new(0));
2223 let mut tools = ToolRegistry::new();
2224 tools.register(Arc::new(CountingSharedEchoTool::new(hits.clone())));
2225 let op = make_op_with_tools(provider, tools)
2226 .with_planner(Box::new(BarrierPlanner))
2227 .with_metadata_concurrency();
2228 let output = op.execute(simple_input("run")).await.unwrap();
2229 assert_eq!(output.exit_reason, ExitReason::Complete);
2230 assert_eq!(hits.load(Ordering::SeqCst), 2);
2231 assert_eq!(output.metadata.tools_called.len(), 2);
2232 assert_eq!(output.metadata.turns_used, 2);
2233 }
2234}