1use parking_lot::RwLock;
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::sync::{Arc, LazyLock};
8
9use futures_core::Stream;
10use futures_util::StreamExt;
11use serde_json::Value;
12
13use crate::acl::ACL;
14use crate::approval::ApprovalHandler;
15use crate::builtin_steps::{
16 build_internal_strategy, build_minimal_strategy, build_performance_strategy,
17 build_standard_strategy, build_testing_strategy,
18};
19use crate::config::Config;
20use crate::context::{Context, Identity};
21use crate::errors::{ErrorCode, ModuleError};
22use crate::middleware::adapters::{AfterMiddleware, BeforeMiddleware};
23use crate::middleware::base::Middleware;
24use crate::middleware::manager::MiddlewareManager;
25use crate::module::PreflightCheckResult as PfCheck;
26use crate::module::{PreflightCheckResult, PreflightResult};
27use crate::pipeline::{
28 ExecutionStrategy, PipelineContext, PipelineEngine, PipelineTrace, StrategyInfo,
29};
30use crate::registry::registry::{module_id_pattern, Registry};
31
/// Recursion cap for [`deep_merge_value`]. Objects nested at or beyond this depth are not
/// merged key-by-key; the overlay value replaces the base value wholesale instead.
const DEEP_MERGE_MAX_DEPTH: usize = 64;
35
36fn deep_merge_chunks(chunks: &[Value]) -> Value {
42 let mut acc = Value::Null;
43 for chunk in chunks {
44 deep_merge_value(&mut acc, chunk, 0);
45 }
46 acc
47}
48
49fn deep_merge_value(base: &mut Value, overlay: &Value, depth: usize) {
50 if depth >= DEEP_MERGE_MAX_DEPTH {
51 *base = overlay.clone();
53 return;
54 }
55 match (base, overlay) {
56 (Value::Object(base_map), Value::Object(overlay_map)) => {
57 for (k, v) in overlay_map {
58 let entry = base_map.entry(k.clone()).or_insert(Value::Null);
59 deep_merge_value(entry, v, depth + 1);
60 }
61 }
62 (base, overlay) => {
63 *base = overlay.clone();
64 }
65 }
66}
67
68pub fn resolve_strategy_by_name(name: &str) -> Result<ExecutionStrategy, ModuleError> {
72 match name {
73 "standard" => Ok(build_standard_strategy()),
74 "internal" => Ok(build_internal_strategy()),
75 "testing" => Ok(build_testing_strategy()),
76 "performance" => Ok(build_performance_strategy()),
77 "minimal" => Ok(build_minimal_strategy()),
78 _ => Err(ModuleError::new(
79 ErrorCode::GeneralInvalidInput,
80 format!("Unknown strategy name '{name}'. Built-in presets: standard, internal, testing, performance, minimal"),
81 )),
82 }
83}
84
/// Maps a pipeline step name to the preflight check name reported by `validate`.
///
/// Steps without an alias (for example `module_lookup`) are reported under their own name.
fn step_to_check_name(step_name: &str) -> &str {
    // (pipeline step name, preflight check name)
    const ALIASES: &[(&str, &str)] = &[
        ("context_creation", "context"),
        ("call_chain_guard", "call_chain"),
        ("acl_check", "acl"),
        ("approval_gate", "approval"),
        ("middleware_before", "middleware"),
        ("input_validation", "schema"),
    ];
    ALIASES
        .iter()
        .find(|&&(step, _)| step == step_name)
        .map_or(step_name, |&(_, check)| check)
}
98
99fn trace_to_checks(trace: &PipelineTrace) -> Vec<PfCheck> {
101 trace
102 .steps
103 .iter()
104 .filter(|st| !st.skipped)
105 .map(|st| {
106 let check_name = step_to_check_name(&st.name).to_string();
107 let passed = st.result.action != "abort";
108 let error = if passed {
109 None
110 } else {
111 st.result.explanation.as_ref().map(|msg| {
112 serde_json::json!({
113 "code": format!("STEP_{}_FAILED", st.name.to_uppercase()),
114 "message": msg,
115 })
116 })
117 };
118 PfCheck {
119 check: check_name,
120 passed,
121 error,
122 warnings: vec![],
123 }
124 })
125 .collect()
126}
127
128pub fn has_schema(schema: &Value) -> bool {
130 if schema.is_null() {
131 return false;
132 }
133 if let Some(obj) = schema.as_object() {
134 return !obj.is_empty();
135 }
136 true
137}
138
/// Placeholder string written in place of sensitive values by [`redact_sensitive`].
pub const REDACTED_VALUE: &str = "***REDACTED***";
141
142struct StreamSetup {
145 module: Arc<dyn crate::module::Module>,
146 inputs: Value,
147 context: Context<Value>,
148 output_schema: Value,
149 middleware_manager: Option<Arc<MiddlewareManager>>,
150}
151
152fn streaming_not_supported_error(module_id: &str) -> ModuleError {
155 ModuleError::new(
156 ErrorCode::GeneralNotImplemented,
157 format!("Module '{module_id}' does not support streaming (Module::stream returned None)"),
158 )
159}
160
161pub fn validate_against_schema(
164 value: &Value,
165 schema: &Value,
166 direction: &str,
167) -> Result<(), ModuleError> {
168 if !has_schema(schema) {
170 return Ok(());
171 }
172
173 let validator = match jsonschema::validator_for(schema) {
174 Ok(v) => v,
175 Err(e) => {
176 return Err(ModuleError::new(
177 ErrorCode::SchemaValidationError,
178 format!("{direction} schema is invalid: {e}"),
179 ));
180 }
181 };
182
183 if validator.is_valid(value) {
184 return Ok(());
185 }
186
187 let error_list: Vec<HashMap<String, String>> = validator
188 .iter_errors(value)
189 .map(|e| {
190 let mut map = HashMap::new();
191 map.insert("field".to_string(), e.instance_path.to_string());
192 map.insert("message".to_string(), e.to_string());
193 map
194 })
195 .collect();
196
197 let errors_json: Vec<Value> = error_list
198 .iter()
199 .map(|e| {
200 serde_json::to_value(e).expect("HashMap<String, String> serialization is infallible")
202 })
203 .collect();
204 let mut details = HashMap::new();
205 details.insert("errors".to_string(), Value::Array(errors_json));
206
207 Err(ModuleError::new(
208 ErrorCode::SchemaValidationError,
209 format!("{direction} validation failed"),
210 )
211 .with_details(details)
212 .with_ai_guidance(format!(
213 "{direction} failed schema validation. Check the 'errors' field in details for specific validation failures."
214 )))
215}
216
217pub fn redact_sensitive(data: &Value, schema: &Value) -> Value {
222 let mut redacted = data.clone();
223 if let Some(obj) = redacted.as_object_mut() {
224 redact_fields(obj, schema);
225 redact_secret_prefix(obj);
226 }
227 redacted
228}
229
230fn redact_fields(data: &mut serde_json::Map<String, Value>, schema: &Value) {
232 let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
233 return;
234 };
235
236 for (field_name, field_schema) in properties {
237 let value = match data.get(field_name) {
238 Some(v) => v.clone(),
239 None => continue,
240 };
241
242 if field_schema.get("x-sensitive") == Some(&Value::Bool(true)) {
244 if !value.is_null() {
245 data.insert(
246 field_name.clone(),
247 Value::String(REDACTED_VALUE.to_string()),
248 );
249 }
250 continue;
251 }
252
253 if field_schema.get("type") == Some(&Value::String("object".to_string()))
255 && field_schema.get("properties").is_some()
256 {
257 if let Some(obj) = data.get_mut(field_name).and_then(|v| v.as_object_mut()) {
258 redact_fields(obj, field_schema);
259 }
260 continue;
261 }
262
263 if field_schema.get("type") == Some(&Value::String("array".to_string())) {
265 if let Some(items_schema) = field_schema.get("items") {
266 if let Some(arr) = data.get_mut(field_name).and_then(|v| v.as_array_mut()) {
267 if items_schema.get("x-sensitive") == Some(&Value::Bool(true)) {
268 for item in arr.iter_mut() {
269 if !item.is_null() {
270 *item = Value::String(REDACTED_VALUE.to_string());
271 }
272 }
273 } else if items_schema.get("type") == Some(&Value::String("object".to_string()))
274 && items_schema.get("properties").is_some()
275 {
276 for item in arr.iter_mut() {
277 if let Some(obj) = item.as_object_mut() {
278 redact_fields(obj, items_schema);
279 }
280 }
281 }
282 }
283 }
284 }
285 }
286}
287
288fn redact_secret_prefix(data: &mut serde_json::Map<String, Value>) {
290 let keys: Vec<String> = data.keys().cloned().collect();
291 for key in keys {
292 if key.starts_with("_secret_") {
293 if let Some(val) = data.get(&key) {
294 if !val.is_null() {
295 data.insert(key, Value::String(REDACTED_VALUE.to_string()));
296 }
297 }
298 } else if let Some(obj) = data.get_mut(&key).and_then(|v| v.as_object_mut()) {
299 redact_secret_prefix(obj);
300 }
301 }
302}
303
304static STRATEGY_REGISTRY: LazyLock<RwLock<Vec<StrategyInfo>>> =
312 LazyLock::new(|| RwLock::new(Vec::new()));
313
314pub fn register_strategy(info: StrategyInfo) {
318 let mut registry = STRATEGY_REGISTRY.write();
319 if let Some(existing) = registry.iter_mut().find(|s| s.name == info.name) {
321 *existing = info;
322 } else {
323 registry.push(info);
324 }
325}
326
327pub fn list_strategies() -> Vec<StrategyInfo> {
329 STRATEGY_REGISTRY.read().clone()
330}
331
332#[derive(Debug)]
334pub struct Executor {
335 pub registry: Arc<Registry>,
336 pub config: Arc<Config>,
337 pub acl: Option<Arc<ACL>>,
338 pub approval_handler: Option<Arc<dyn ApprovalHandler>>,
339 pub middleware_manager: Arc<MiddlewareManager>,
340 strategy: ExecutionStrategy,
342}
343
344impl Executor {
345 pub fn new(registry: impl Into<Arc<Registry>>, config: impl Into<Arc<Config>>) -> Self {
351 Self {
352 registry: registry.into(),
353 config: config.into(),
354 acl: None,
355 approval_handler: None,
356 middleware_manager: Arc::new(MiddlewareManager::new()),
357 strategy: build_standard_strategy(),
358 }
359 }
360
361 pub fn with_strategy_name(
366 registry: impl Into<Arc<Registry>>,
367 config: impl Into<Arc<Config>>,
368 name: &str,
369 ) -> Result<Self, ModuleError> {
370 let strategy = resolve_strategy_by_name(name)?;
371 Ok(Self {
372 registry: registry.into(),
373 config: config.into(),
374 acl: None,
375 approval_handler: None,
376 middleware_manager: Arc::new(MiddlewareManager::new()),
377 strategy,
378 })
379 }
380
381 pub fn with_strategy(
383 registry: impl Into<Arc<Registry>>,
384 config: impl Into<Arc<Config>>,
385 strategy: ExecutionStrategy,
386 ) -> Self {
387 Self {
388 registry: registry.into(),
389 config: config.into(),
390 acl: None,
391 approval_handler: None,
392 middleware_manager: Arc::new(MiddlewareManager::new()),
393 strategy,
394 }
395 }
396
397 pub fn with_options(
399 registry: impl Into<Arc<Registry>>,
400 config: impl Into<Arc<Config>>,
401 middlewares: Option<Vec<Box<dyn Middleware>>>,
402 acl: Option<ACL>,
403 approval_handler: Option<Box<dyn ApprovalHandler>>,
404 ) -> Self {
405 let middleware_manager = MiddlewareManager::new();
406 if let Some(mws) = middlewares {
407 for mw in mws {
408 if let Err(e) = middleware_manager.add(mw) {
411 tracing::warn!("Skipping middleware during executor construction: {}", e);
412 }
413 }
414 }
415 Self {
416 registry: registry.into(),
417 config: config.into(),
418 acl: acl.map(Arc::new),
419 approval_handler: approval_handler.map(|h| Arc::from(h) as Arc<dyn ApprovalHandler>),
420 middleware_manager: Arc::new(middleware_manager),
421 strategy: build_standard_strategy(),
422 }
423 }
424
425 pub fn registry(&self) -> &Registry {
427 &self.registry
428 }
429
430 pub fn middlewares(&self) -> Vec<String> {
432 self.middleware_manager.snapshot()
433 }
434
435 pub fn set_acl(&mut self, acl: ACL) {
437 self.acl = Some(Arc::new(acl));
438 }
439
440 pub fn set_approval_handler(&mut self, handler: Box<dyn ApprovalHandler>) {
442 self.approval_handler = Some(Arc::from(handler));
443 }
444
445 pub fn use_middleware(&self, middleware: Box<dyn Middleware>) -> Result<(), ModuleError> {
455 self.middleware_manager.add(middleware)
456 }
457
458 pub fn remove(&self, name: &str) -> bool {
460 self.middleware_manager.remove(name)
461 }
462
463 pub fn remove_middleware(&self, name: &str) -> bool {
465 self.remove(name)
466 }
467
468 fn validate_module_id(module_id: &str) -> Result<(), ModuleError> {
475 if module_id.is_empty() || !module_id_pattern().is_match(module_id) {
476 return Err(ModuleError::new(
477 ErrorCode::GeneralInvalidInput,
478 format!(
479 "Invalid module ID: '{}'. Must match pattern: {}",
480 module_id,
481 crate::registry::registry::MODULE_ID_PATTERN,
482 ),
483 ));
484 }
485 Ok(())
486 }
487
488 pub async fn call(
492 &self,
493 module_id: &str,
494 inputs: serde_json::Value,
495 ctx: Option<&Context<serde_json::Value>>,
496 version_hint: Option<&str>,
497 ) -> Result<serde_json::Value, ModuleError> {
498 Self::validate_module_id(module_id)?;
499 let context = match ctx {
500 Some(c) => c.clone(),
501 None => Context::<serde_json::Value>::new(Identity::new(
502 "@external".to_string(),
503 "external".to_string(),
504 vec![],
505 HashMap::new(),
506 )),
507 };
508 let mut pipe_ctx = PipelineContext::new(module_id, inputs, context, self.strategy.name());
509 if let Some(hint) = version_hint {
510 pipe_ctx.version_hint = Some(hint.to_string());
511 }
512 self.inject_resources(&mut pipe_ctx);
513 match PipelineEngine::run(&self.strategy, &mut pipe_ctx).await {
514 Ok((output, trace)) => {
515 if !trace.success {
516 let (aborted_step, explanation) = trace
517 .steps
518 .iter()
519 .find_map(|s| {
520 if s.result.action == "abort" {
521 Some((s.name.as_str(), s.result.explanation.as_deref()))
522 } else {
523 None
524 }
525 })
526 .unwrap_or(("unknown", None));
527 return Err(ModuleError::pipeline_abort(aborted_step, explanation));
528 }
529 Ok(output.unwrap_or(serde_json::Value::Null))
530 }
531 Err(e) => {
532 let executed = pipe_ctx.executed_middlewares.clone();
535 if !executed.is_empty() {
536 if let Some(recovery) = self
537 .middleware_manager
538 .execute_on_error(
539 module_id,
540 pipe_ctx.inputs,
541 &e,
542 &pipe_ctx.context,
543 &executed,
544 )
545 .await
546 {
547 return Ok(recovery);
548 }
549 }
550 Err(e)
551 }
552 }
553 }
554
555 pub async fn validate(
570 &self,
571 module_id: &str,
572 inputs: &serde_json::Value,
573 ctx: Option<&Context<serde_json::Value>>,
574 ) -> Result<PreflightResult, ModuleError> {
575 Self::validate_module_id(module_id)?;
576 let context = ctx.cloned().unwrap_or_else(|| {
577 Context::<serde_json::Value>::new(Identity::new(
578 "@external".to_string(),
579 "external".to_string(),
580 vec![],
581 HashMap::new(),
582 ))
583 });
584 let mut pipe_ctx =
585 PipelineContext::new(module_id, inputs.clone(), context, self.strategy.name());
586 pipe_ctx.dry_run = true;
587 self.inject_resources(&mut pipe_ctx);
588
589 let mut checks: Vec<PreflightCheckResult> = Vec::new();
590 let trace_result = PipelineEngine::run(&self.strategy, &mut pipe_ctx).await;
591 match trace_result {
592 Ok((_output, trace)) => {
593 checks.extend(trace_to_checks(&trace));
594 }
595 Err(e) => {
596 checks.extend(trace_to_checks(&pipe_ctx.trace));
598 let check_name = match e.code {
599 ErrorCode::ModuleNotFound => "module_lookup",
600 ErrorCode::ACLDenied => "acl",
601 ErrorCode::SchemaValidationError | ErrorCode::GeneralInvalidInput => "schema",
602 ErrorCode::CallDepthExceeded | ErrorCode::CircularCall => "call_chain",
603 _ => "unknown",
604 };
605 checks.push(PreflightCheckResult {
606 check: check_name.to_string(),
607 passed: false,
608 error: Some(serde_json::json!({
609 "code": format!("{:?}", e.code),
610 "message": e.message,
611 })),
612 warnings: vec![],
613 });
614 }
615 }
616
617 let mut requires_approval = false;
619 if let Some(desc) = self.registry.get_definition(module_id) {
620 if desc
621 .annotations
622 .as_ref()
623 .is_some_and(|a| a.requires_approval)
624 {
625 requires_approval = true;
626 }
627 }
628
629 let valid = checks.iter().all(|c| c.passed);
630 Ok(PreflightResult {
631 valid,
632 checks,
633 requires_approval,
634 })
635 }
636
637 pub fn from_registry(
639 registry: impl Into<Arc<Registry>>,
640 config: impl Into<Arc<Config>>,
641 ) -> Self {
642 Self::new(registry, config)
643 }
644
645 pub fn stream<'a>(
665 &'a self,
666 module_id: &str,
667 inputs: Value,
668 ctx: Option<&Context<Value>>,
669 version_hint: Option<&str>,
670 ) -> Pin<Box<dyn Stream<Item = Result<Value, ModuleError>> + Send + 'a>> {
671 let module_id_owned = module_id.to_string();
673 let version_hint_owned = version_hint.map(str::to_string);
674 let initial_context = ctx.cloned();
675
676 Box::pin(async_stream::try_stream! {
677 Self::validate_module_id(&module_id_owned)?;
679
680 let mut setup = self
682 .prepare_stream(
683 &module_id_owned,
684 inputs,
685 initial_context,
686 version_hint_owned.as_deref(),
687 )
688 .await?;
689
690 let Some(mut inner) = setup.module.stream(setup.inputs.clone(), &setup.context) else {
695 Err(streaming_not_supported_error(&module_id_owned))?;
696 return;
698 };
699
700 let mut accumulated: Vec<Value> = Vec::new();
701 while let Some(chunk_result) = inner.next().await {
702 let chunk = chunk_result?;
703 accumulated.push(chunk.clone());
704 yield chunk;
705 }
706
707 let merged = deep_merge_chunks(&accumulated);
710 validate_against_schema(&merged, &setup.output_schema, "Output")?;
711 if let Some(ref mm) = setup.middleware_manager {
712 mm.execute_after(&module_id_owned, setup.inputs.clone(), merged, &setup.context)
713 .await?;
714 }
715 let _ = &mut setup; })
719 }
720
721 async fn prepare_stream(
732 &self,
733 module_id: &str,
734 inputs: Value,
735 ctx: Option<Context<Value>>,
736 version_hint: Option<&str>,
737 ) -> Result<StreamSetup, ModuleError> {
738 let context = ctx.unwrap_or_else(|| {
739 Context::<Value>::new(Identity::new(
740 "@external".to_string(),
741 "external".to_string(),
742 vec![],
743 HashMap::new(),
744 ))
745 });
746
747 let mut pipe_ctx = PipelineContext::new(module_id, inputs, context, self.strategy.name());
748 if let Some(hint) = version_hint {
749 pipe_ctx.version_hint = Some(hint.to_string());
750 }
751 self.inject_resources(&mut pipe_ctx);
752
753 let (_output, trace) =
757 PipelineEngine::run_until(&self.strategy, &mut pipe_ctx, "execute").await?;
758
759 if !trace.success {
763 let explanation = trace
764 .steps
765 .iter()
766 .find_map(|s| {
767 if s.result.action == "abort" {
768 s.result.explanation.clone()
769 } else {
770 None
771 }
772 })
773 .unwrap_or_else(|| "pre-stream pipeline aborted".to_string());
774 return Err(ModuleError::new(
775 ErrorCode::GeneralInternalError,
776 explanation,
777 ));
778 }
779
780 let module = pipe_ctx.module.clone().ok_or_else(|| {
781 ModuleError::new(
782 ErrorCode::ModuleNotFound,
783 format!("Module '{module_id}' was not resolved during pre-stream setup"),
784 )
785 })?;
786 let output_schema = module.output_schema();
787
788 Ok(StreamSetup {
789 module,
790 inputs: pipe_ctx.inputs,
791 context: pipe_ctx.context,
792 output_schema,
793 middleware_manager: pipe_ctx.middleware_manager.clone(),
794 })
795 }
796
797 pub fn strategy(&self) -> &ExecutionStrategy {
799 &self.strategy
800 }
801
802 pub fn describe_pipeline(&self) -> StrategyInfo {
810 self.strategy.info()
811 }
812
813 #[deprecated(
817 since = "0.20.0",
818 note = "Use the module-level `register_strategy` function directly."
819 )]
820 pub fn register_strategy(info: StrategyInfo) {
821 register_strategy(info);
822 }
823
824 #[deprecated(
828 since = "0.20.0",
829 note = "Use the module-level `list_strategies` function directly."
830 )]
831 pub fn list_strategies() -> Vec<StrategyInfo> {
832 list_strategies()
833 }
834
835 pub async fn call_with_trace(
840 &self,
841 module_id: &str,
842 inputs: Value,
843 ctx: Option<&Context<Value>>,
844 strategy: Option<&ExecutionStrategy>,
845 ) -> Result<(Value, PipelineTrace), ModuleError> {
846 let effective_strategy = strategy.unwrap_or(&self.strategy);
847
848 let context = match ctx {
849 Some(c) => c.clone(),
850 None => Context::<Value>::new(Identity::new(
851 "@external".to_string(),
852 "external".to_string(),
853 vec![],
854 HashMap::new(),
855 )),
856 };
857
858 let mut pipeline_ctx =
859 PipelineContext::new(module_id, inputs, context, effective_strategy.name());
860 self.inject_resources(&mut pipeline_ctx);
861
862 let (output, trace) = PipelineEngine::run(effective_strategy, &mut pipeline_ctx).await?;
863
864 Ok((output.unwrap_or(Value::Null), trace))
865 }
866
867 fn inject_resources(&self, ctx: &mut PipelineContext) {
870 ctx.registry = Some(Arc::clone(&self.registry));
871 ctx.config = Some(Arc::clone(&self.config));
872 ctx.acl = self.acl.as_ref().map(Arc::clone);
873 ctx.approval_handler = self.approval_handler.as_ref().map(Arc::clone);
874 ctx.middleware_manager = Some(Arc::clone(&self.middleware_manager));
875 }
876
877 pub fn use_before(&self, middleware: Box<dyn BeforeMiddleware>) -> Result<(), ModuleError> {
879 self.middleware_manager
880 .add(Box::new(BoxedBeforeMiddlewareAdapter(middleware)))
881 }
882
883 pub fn use_after(&self, middleware: Box<dyn AfterMiddleware>) -> Result<(), ModuleError> {
885 self.middleware_manager
886 .add(Box::new(BoxedAfterMiddlewareAdapter(middleware)))
887 }
888}
889
890struct BoxedBeforeMiddlewareAdapter(Box<dyn BeforeMiddleware>);
897
898impl std::fmt::Debug for BoxedBeforeMiddlewareAdapter {
899 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
900 f.debug_struct("BoxedBeforeMiddlewareAdapter")
901 .field("name", &self.0.name())
902 .finish()
903 }
904}
905
906#[async_trait::async_trait]
907impl Middleware for BoxedBeforeMiddlewareAdapter {
908 fn name(&self) -> &str {
909 self.0.name()
910 }
911
912 async fn before(
913 &self,
914 module_id: &str,
915 inputs: serde_json::Value,
916 ctx: &Context<serde_json::Value>,
917 ) -> Result<Option<serde_json::Value>, ModuleError> {
918 self.0.before(module_id, inputs, ctx).await
919 }
920
921 async fn after(
922 &self,
923 _module_id: &str,
924 _inputs: serde_json::Value,
925 _output: serde_json::Value,
926 _ctx: &Context<serde_json::Value>,
927 ) -> Result<Option<serde_json::Value>, ModuleError> {
928 Ok(None)
929 }
930
931 async fn on_error(
932 &self,
933 _module_id: &str,
934 _inputs: serde_json::Value,
935 _error: &ModuleError,
936 _ctx: &Context<serde_json::Value>,
937 ) -> Result<Option<serde_json::Value>, ModuleError> {
938 Ok(None)
939 }
940}
941
942struct BoxedAfterMiddlewareAdapter(Box<dyn AfterMiddleware>);
944
945impl std::fmt::Debug for BoxedAfterMiddlewareAdapter {
946 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
947 f.debug_struct("BoxedAfterMiddlewareAdapter")
948 .field("name", &self.0.name())
949 .finish()
950 }
951}
952
953#[async_trait::async_trait]
954impl Middleware for BoxedAfterMiddlewareAdapter {
955 fn name(&self) -> &str {
956 self.0.name()
957 }
958
959 async fn before(
960 &self,
961 _module_id: &str,
962 _inputs: serde_json::Value,
963 _ctx: &Context<serde_json::Value>,
964 ) -> Result<Option<serde_json::Value>, ModuleError> {
965 Ok(None)
966 }
967
968 async fn after(
969 &self,
970 module_id: &str,
971 inputs: serde_json::Value,
972 output: serde_json::Value,
973 ctx: &Context<serde_json::Value>,
974 ) -> Result<Option<serde_json::Value>, ModuleError> {
975 self.0.after(module_id, inputs, output, ctx).await
976 }
977
978 async fn on_error(
979 &self,
980 _module_id: &str,
981 _inputs: serde_json::Value,
982 _error: &ModuleError,
983 _ctx: &Context<serde_json::Value>,
984 ) -> Result<Option<serde_json::Value>, ModuleError> {
985 Ok(None)
986 }
987}
988
#[cfg(test)]
mod tests {
    use super::*;
    use crate::approval::{ApprovalHandler, ApprovalRequest, ApprovalResult};
    use crate::config::Config;
    use crate::context::Context;
    use crate::errors::ErrorCode;
    use crate::module::{Module, ModuleAnnotations};
    use crate::registry::registry::{ModuleDescriptor, Registry};
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::Mutex;

    // --- Fixtures -----------------------------------------------------------------------

    /// Module stub with fixed schemas whose `execute` always returns a clone of `output`.
    struct MockModule {
        input_schema: Value,
        output_schema: Value,
        output: Value,
    }

    impl MockModule {
        fn new(input_schema: Value, output_schema: Value, output: Value) -> Self {
            Self {
                input_schema,
                output_schema,
                output,
            }
        }

        /// Schema-less module that answers every call with `{"ok": true}`.
        fn echo() -> Self {
            Self::new(json!({}), json!({}), json!({"ok": true}))
        }
    }

    #[async_trait]
    impl Module for MockModule {
        fn input_schema(&self) -> Value {
            self.input_schema.clone()
        }
        fn output_schema(&self) -> Value {
            self.output_schema.clone()
        }
        fn description(&self) -> &'static str {
            "mock module"
        }
        async fn execute(
            &self,
            _inputs: Value,
            _ctx: &Context<Value>,
        ) -> Result<Value, ModuleError> {
            Ok(self.output.clone())
        }
    }

    /// Approval handler that answers every request and check with the same canned result,
    /// logging each invocation in `calls`.
    #[derive(Debug)]
    struct MockApprovalHandler {
        result: ApprovalResult,
        calls: Mutex<Vec<String>>,
    }

    impl MockApprovalHandler {
        fn with_status(status: &str) -> Self {
            let result = ApprovalResult {
                status: status.to_string(),
                approved_by: None,
                reason: Some(format!("mock-{status}")),
                approval_id: None,
                metadata: None,
            };
            Self {
                result,
                calls: Mutex::new(vec![]),
            }
        }
    }

    #[async_trait]
    impl ApprovalHandler for MockApprovalHandler {
        async fn request_approval(
            &self,
            _request: &ApprovalRequest,
        ) -> Result<ApprovalResult, ModuleError> {
            self.calls.lock().unwrap().push("request".to_string());
            Ok(self.result.clone())
        }

        async fn check_approval(&self, approval_id: &str) -> Result<ApprovalResult, ModuleError> {
            let entry = format!("check:{approval_id}");
            self.calls.lock().unwrap().push(entry);
            Ok(self.result.clone())
        }
    }

    /// Registers `module` as `test_mod` (version 1.0.0, given annotations) in a fresh
    /// registry and wraps that registry in a default-config executor.
    fn build_executor_with_module(module: MockModule, annotations: ModuleAnnotations) -> Executor {
        let descriptor = ModuleDescriptor {
            module_id: "test_mod".to_string(),
            name: None,
            description: module.description().to_string(),
            documentation: None,
            input_schema: module.input_schema(),
            output_schema: module.output_schema(),
            version: "1.0.0".to_string(),
            tags: vec![],
            annotations: Some(annotations),
            examples: vec![],
            metadata: std::collections::HashMap::new(),
            display: None,
            sunset_date: None,
            dependencies: vec![],
            enabled: true,
        };
        let registry = Registry::new();
        registry
            .register("test_mod", Box::new(module), descriptor)
            .unwrap();
        Executor::new(registry, Config::default())
    }

    /// Annotations that flag the module as requiring approval.
    fn approval_required() -> ModuleAnnotations {
        ModuleAnnotations {
            requires_approval: true,
            ..Default::default()
        }
    }

    /// Echo-module executor whose approval handler always answers with `status`.
    fn executor_with_approval_status(status: &str) -> Executor {
        let mut executor = build_executor_with_module(MockModule::echo(), approval_required());
        executor.set_approval_handler(Box::new(MockApprovalHandler::with_status(status)));
        executor
    }

    // --- validate_against_schema --------------------------------------------------------

    #[test]
    fn test_validate_against_schema_valid_input_passes() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "required": ["name"]
        });
        let outcome = validate_against_schema(&json!({"name": "Alice"}), &schema, "Input");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_against_schema_invalid_input_returns_error_with_details() {
        let schema = json!({
            "type": "object",
            "properties": {
                "age": {"type": "integer"}
            },
            "required": ["age"]
        });
        let err = validate_against_schema(&json!({"age": "not-a-number"}), &schema, "Input")
            .unwrap_err();
        assert_eq!(err.code, ErrorCode::SchemaValidationError);
        assert!(err.details.contains_key("errors"));
    }

    #[test]
    fn test_validate_against_schema_null_schema_skips() {
        let payload = json!({"anything": 123});
        assert!(validate_against_schema(&payload, &Value::Null, "Input").is_ok());
    }

    #[test]
    fn test_validate_against_schema_empty_object_schema_skips() {
        let payload = json!({"anything": 123});
        assert!(validate_against_schema(&payload, &json!({}), "Input").is_ok());
    }

    // --- redact_sensitive ---------------------------------------------------------------

    #[test]
    fn test_redact_sensitive_basic_field() {
        let schema = json!({
            "properties": {
                "password": {"type": "string", "x-sensitive": true},
                "username": {"type": "string"}
            }
        });
        let redacted = redact_sensitive(&json!({"password": "s3cret", "username": "alice"}), &schema);
        assert_eq!(redacted["password"], REDACTED_VALUE);
        assert_eq!(redacted["username"], "alice");
    }

    #[test]
    fn test_redact_sensitive_nested_object() {
        let schema = json!({
            "properties": {
                "credentials": {
                    "type": "object",
                    "properties": {
                        "token": {"type": "string", "x-sensitive": true},
                        "scope": {"type": "string"}
                    }
                }
            }
        });
        let data = json!({"credentials": {"token": "abc123", "scope": "read"}});
        let redacted = redact_sensitive(&data, &schema);
        assert_eq!(redacted["credentials"]["token"], REDACTED_VALUE);
        assert_eq!(redacted["credentials"]["scope"], "read");
    }

    #[test]
    fn test_redact_sensitive_array_items() {
        let schema = json!({
            "properties": {
                "tokens": {
                    "type": "array",
                    "items": {"type": "string", "x-sensitive": true}
                }
            }
        });
        let redacted = redact_sensitive(&json!({"tokens": ["a", "b", "c"]}), &schema);
        let tokens = redacted["tokens"].as_array().unwrap();
        assert!(tokens.iter().all(|item| item == REDACTED_VALUE));
    }

    #[test]
    fn test_redact_sensitive_secret_prefix_keys() {
        let data = json!({
            "_secret_api_key": "key123",
            "public_field": "visible"
        });
        let redacted = redact_sensitive(&data, &json!({}));
        assert_eq!(redacted["_secret_api_key"], REDACTED_VALUE);
        assert_eq!(redacted["public_field"], "visible");
    }

    #[test]
    fn test_redact_sensitive_null_values_preserved() {
        let schema = json!({
            "properties": {
                "password": {"type": "string", "x-sensitive": true}
            }
        });
        let redacted = redact_sensitive(&json!({"password": null}), &schema);
        assert!(redacted["password"].is_null());
    }

    #[test]
    fn test_redact_sensitive_no_schema_no_redaction() {
        let data = json!({"password": "s3cret"});
        assert_eq!(redact_sensitive(&data, &Value::Null), data);
    }

    // --- approval -----------------------------------------------------------------------

    #[tokio::test]
    async fn test_approval_token_stripped_from_inputs_and_check_called() {
        let executor = executor_with_approval_status("approved");
        let inputs = json!({"_approval_token": "tok-123", "data": "hello"});
        let outcome = executor.call("test_mod", inputs, None, None).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_approval_no_token_calls_request_approval() {
        let executor = executor_with_approval_status("approved");
        let outcome = executor
            .call("test_mod", json!({"data": "hello"}), None, None)
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_validate_notes_requires_approval_without_gating() {
        // A "timeout" handler would block `call`, but `validate` only reports the flag.
        let executor = executor_with_approval_status("timeout");
        let preflight = executor
            .validate("test_mod", &json!({}), None)
            .await
            .unwrap();
        assert!(preflight.valid);
        assert!(preflight.requires_approval);
    }
}