use std::fmt;
use std::time::Duration;
use async_trait::async_trait;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::context::JobContext;
/// How strongly a tool invocation asks for human approval before running.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApprovalRequirement {
    /// No approval is needed.
    Never,
    /// Approval is needed unless the call was auto-approved (presumably by a
    /// policy elsewhere — `is_required` still reports `true` for this variant).
    UnlessAutoApproved,
    /// Approval is always needed.
    Always,
}

impl ApprovalRequirement {
    /// Returns `true` for every requirement except [`ApprovalRequirement::Never`].
    pub fn is_required(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly.
        match self {
            Self::Never => false,
            Self::UnlessAutoApproved | Self::Always => true,
        }
    }
}
/// Policy context consulted when deciding whether a tool call is blocked.
#[derive(Debug, Clone)]
pub enum ApprovalContext {
    /// Autonomous mode: a tool is blocked unless its name is in `allowed_tools`.
    Autonomous {
        allowed_tools: std::collections::HashSet<String>,
    },
}

impl ApprovalContext {
    /// Autonomous context with an empty allow-list, so every tool is blocked.
    pub fn autonomous() -> Self {
        Self::Autonomous {
            allowed_tools: Default::default(),
        }
    }

    /// Autonomous context that allows exactly the given tool names.
    pub fn autonomous_with_tools(tools: impl IntoIterator<Item = String>) -> Self {
        let allowed_tools = tools.into_iter().collect();
        Self::Autonomous { allowed_tools }
    }

    /// Returns `true` when `tool_name` may not run in this context.
    ///
    /// The approval requirement is not consulted: in autonomous mode the
    /// allow-list alone decides, so even an [`ApprovalRequirement::Never`]
    /// tool is blocked when it is not listed.
    pub fn is_blocked(&self, tool_name: &str, _requirement: ApprovalRequirement) -> bool {
        let Self::Autonomous { allowed_tools } = self;
        !allowed_tools.contains(tool_name)
    }

    /// Like [`is_blocked`](Self::is_blocked) when a context is present.
    ///
    /// Without a context, falls back to [`ApprovalRequirement::is_required`].
    pub fn is_blocked_or_default(
        context: &Option<Self>,
        tool_name: &str,
        requirement: ApprovalRequirement,
    ) -> bool {
        if let Some(ctx) = context {
            ctx.is_blocked(tool_name, requirement)
        } else {
            requirement.is_required()
        }
    }
}
/// Request budget attached to a single tool.
#[derive(Debug, Clone)]
pub struct ToolRateLimitConfig {
    /// Allowed requests per minute.
    pub requests_per_minute: u32,
    /// Allowed requests per hour.
    pub requests_per_hour: u32,
}

impl ToolRateLimitConfig {
    /// Builds a config from explicit per-minute and per-hour budgets.
    pub fn new(requests_per_minute: u32, requests_per_hour: u32) -> Self {
        Self {
            requests_per_hour,
            requests_per_minute,
        }
    }
}

impl Default for ToolRateLimitConfig {
    /// 60 requests per minute and 1000 requests per hour.
    fn default() -> Self {
        Self::new(60, 1000)
    }
}
/// Coarse risk classification of a tool call.
///
/// The derived ordering follows declaration order: `Low < Medium < High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum RiskLevel {
    Low,
    Medium,
    High,
}

impl fmt::Display for RiskLevel {
    /// Writes the lowercase label: `low`, `medium` or `high`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Low => "low",
            Self::Medium => "medium",
            Self::High => "high",
        };
        f.write_str(label)
    }
}
/// Domain a tool belongs to (see `Tool::domain`, which defaults to `Orchestrator`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ToolDomain {
    Orchestrator,
    Container,
}
#[derive(Debug, Error)]
pub enum ToolError {
#[error("Invalid parameters: {0}")]
InvalidParameters(String),
#[error("Execution failed: {0}")]
ExecutionFailed(String),
#[error("Timeout after {0:?}")]
Timeout(Duration),
#[error("Not authorized: {0}")]
NotAuthorized(String),
#[error("Rate limited, retry after {0:?}")]
RateLimited(Option<Duration>),
#[error("External service error: {0}")]
ExternalService(String),
#[error("Sandbox error: {0}")]
Sandbox(String),
}
/// What a tool returns from a successful execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolOutput {
    /// JSON payload produced by the tool.
    pub result: serde_json::Value,
    /// Cost of the call, when known.
    pub cost: Option<Decimal>,
    /// How long the call took, as reported by the tool.
    pub duration: Duration,
    /// Optional raw output; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub raw: Option<String>,
}

impl ToolOutput {
    /// Wraps `result` with no cost and no raw output.
    pub fn success(result: serde_json::Value, duration: Duration) -> Self {
        Self {
            result,
            duration,
            cost: None,
            raw: None,
        }
    }

    /// Like [`success`](Self::success), with `text` stored as a JSON string.
    pub fn text(text: impl Into<String>, duration: Duration) -> Self {
        Self::success(serde_json::Value::String(text.into()), duration)
    }

    /// Returns `self` with the cost set to `cost`.
    pub fn with_cost(self, cost: Decimal) -> Self {
        Self {
            cost: Some(cost),
            ..self
        }
    }

    /// Returns `self` with the raw output set to `raw`.
    pub fn with_raw(self, raw: impl Into<String>) -> Self {
        Self {
            raw: Some(raw.into()),
            ..self
        }
    }
}
/// Name, description and parameter schema advertised for a tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSchema {
    pub name: String,
    pub description: String,
    /// JSON schema for the tool's parameters.
    pub parameters: serde_json::Value,
}

impl ToolSchema {
    /// Creates a schema whose parameters are an object with no properties and
    /// no required keys.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let name = name.into();
        let description = description.into();
        let parameters = serde_json::json!({
            "type": "object",
            "properties": {},
            "required": []
        });
        Self {
            name,
            description,
            parameters,
        }
    }

    /// Returns `self` with the parameter schema replaced by `parameters`.
    pub fn with_parameters(self, parameters: serde_json::Value) -> Self {
        Self { parameters, ..self }
    }
}
/// Structured rules and examples a tool can offer via `Tool::discovery_summary`.
///
/// Every list defaults to empty when absent on deserialization and is left
/// out of the serialized form when empty.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct ToolDiscoverySummary {
    /// Always-required items, as free-form text.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub always_required: Vec<String>,
    /// Conditional requirements, as free-form text.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub conditional_requirements: Vec<String>,
    /// Additional free-form notes.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub notes: Vec<String>,
    /// Example payloads as JSON values.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub examples: Vec<serde_json::Value>,
}
/// An action the agent can invoke with JSON parameters.
///
/// Implementors must supply `name`, `description`, `parameters_schema` and
/// `execute`; every other method has a default that may be overridden.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Identifier the tool is exposed under.
    fn name(&self) -> &str;

    /// Human-readable description of the tool.
    fn description(&self) -> &str;

    /// JSON schema describing the accepted parameters.
    fn parameters_schema(&self) -> serde_json::Value;

    /// Runs the tool with `params` in the job described by `ctx`.
    ///
    /// # Errors
    /// Returns a [`ToolError`] when the call cannot be completed.
    async fn execute(
        &self,
        params: serde_json::Value,
        ctx: &JobContext,
    ) -> Result<ToolOutput, ToolError>;

    /// Up-front cost estimate for `params`; `None` (unknown) by default.
    fn estimated_cost(&self, _params: &serde_json::Value) -> Option<Decimal> {
        None
    }

    /// Up-front duration estimate for `params`; `None` (unknown) by default.
    fn estimated_duration(&self, _params: &serde_json::Value) -> Option<Duration> {
        None
    }

    /// Whether the tool's output requires sanitization; `true` by default.
    fn requires_sanitization(&self) -> bool {
        true
    }

    /// Risk classification for a call with `params`; `RiskLevel::Low` by default.
    fn risk_level_for(&self, _params: &serde_json::Value) -> RiskLevel {
        RiskLevel::Low
    }

    /// Approval requirement for a call with `params`; `Never` by default.
    fn requires_approval(&self, _params: &serde_json::Value) -> ApprovalRequirement {
        ApprovalRequirement::Never
    }

    /// Timeout for a single execution; 60 seconds by default.
    /// (Enforcement happens in the caller, which is not shown here.)
    fn execution_timeout(&self) -> Duration {
        const DEFAULT_TIMEOUT_SECS: u64 = 60;
        Duration::from_secs(DEFAULT_TIMEOUT_SECS)
    }

    /// Domain the tool belongs to; `ToolDomain::Orchestrator` by default.
    fn domain(&self) -> ToolDomain {
        ToolDomain::Orchestrator
    }

    /// Names of parameters treated as sensitive (cf. `redact_params`); none by default.
    fn sensitive_params(&self) -> &[&str] {
        &[]
    }

    /// Optional per-tool rate limit; none by default.
    fn rate_limit_config(&self) -> Option<ToolRateLimitConfig> {
        None
    }

    /// Webhook capability exposed by the tool, if any; none by default.
    fn webhook_capability(&self) -> Option<crate::tools::wasm::WebhookCapability> {
        None
    }

    /// Full schema offered for discovery; defaults to `parameters_schema()`.
    fn discovery_schema(&self) -> serde_json::Value {
        self.parameters_schema()
    }

    /// Structured rules/examples offered for discovery; none by default.
    fn discovery_summary(&self) -> Option<ToolDiscoverySummary> {
        None
    }

    /// Builds the [`ToolSchema`] advertised for this tool.
    ///
    /// When the tool has a discovery summary, or a discovery schema that
    /// differs from its parameters schema, the description gains a hint telling
    /// the caller to use `tool_info` for the extra material.
    fn schema(&self) -> ToolSchema {
        let parameters = self.parameters_schema();
        // Short-circuit: `discovery_schema` is consulted only when there is no summary.
        let has_discovery_hint =
            self.discovery_summary().is_some() || self.discovery_schema() != parameters;
        let summary = self.description();
        let description = if has_discovery_hint {
            format!(
                "{} (call tool_info(name: \"{}\", detail: \"summary\") for rules/examples or detail: \"schema\" for the full discovery schema)",
                summary,
                self.name()
            )
        } else {
            summary.to_owned()
        };
        ToolSchema {
            name: self.name().to_owned(),
            description,
            parameters,
        }
    }
}
/// Fetches the string parameter `name` from `params`.
///
/// # Errors
/// [`ToolError::InvalidParameters`] with the message `missing '<name>' parameter`
/// when the key is absent *or* holds a non-string value; the message does not
/// distinguish the two cases.
pub fn require_str<'a>(params: &'a serde_json::Value, name: &str) -> Result<&'a str, ToolError> {
    let Some(value) = params.get(name).and_then(serde_json::Value::as_str) else {
        return Err(ToolError::InvalidParameters(format!(
            "missing '{}' parameter",
            name
        )));
    };
    Ok(value)
}
/// Looks up the parameter `name` in `params`, whatever its JSON type.
///
/// # Errors
/// [`ToolError::InvalidParameters`] with the message `missing '<name>' parameter`
/// when the key is absent (including when `params` is not a JSON object).
pub fn require_param<'a>(
    params: &'a serde_json::Value,
    name: &str,
) -> Result<&'a serde_json::Value, ToolError> {
    match params.get(name) {
        Some(value) => Ok(value),
        None => Err(ToolError::InvalidParameters(format!(
            "missing '{}' parameter",
            name
        ))),
    }
}
/// Returns a copy of `params` in which the value of every top-level key listed
/// in `sensitive` is replaced by the string `"[REDACTED]"`.
///
/// Only keys that are already present are touched; no keys are added. Nested
/// objects are not searched, and non-object `params` come back unchanged. The
/// input value is never modified.
pub fn redact_params(params: &serde_json::Value, sensitive: &[&str]) -> serde_json::Value {
    let mut redacted = params.clone();
    if let Some(obj) = redacted.as_object_mut() {
        for &key in sensitive {
            // Overwrite in place: a single lookup and no owned-key allocation
            // (the previous `contains_key` + `insert(key.to_string(), ..)` needed both).
            if let Some(value) = obj.get_mut(key) {
                *value = serde_json::Value::String("[REDACTED]".into());
            }
        }
    }
    redacted
}
/// Deepest nesting level the tool-schema validator descends into. Each nested
/// object property, object-typed array `items` and object-like combinator
/// variant adds one level. Beyond this, a single "maximum depth" error is
/// reported instead of recursing further.
const MAX_SCHEMA_DEPTH: usize = 16;
fn has_object_combinator_variants(schema: &serde_json::Value) -> bool {
for key in ["oneOf", "anyOf", "allOf"] {
if let Some(variants) = schema.get(key).and_then(|v| v.as_array())
&& variants.iter().any(|v| {
v.get("type").and_then(|t| t.as_str()) == Some("object")
|| v.get("properties").is_some()
})
{
return true;
}
}
false
}
/// Checks that `schema` is an object-typed JSON schema usable as tool parameters.
///
/// Returns one human-readable message per problem, each prefixed with `path`
/// (or a nested path derived from it); an empty vector means the schema passed.
pub fn validate_tool_schema(schema: &serde_json::Value, path: &str) -> Vec<String> {
    const ROOT_DEPTH: usize = 0;
    validate_tool_schema_inner(schema, path, ROOT_DEPTH)
}
/// Recursive worker behind `validate_tool_schema`.
///
/// `depth` counts the nested schemas entered so far: object properties,
/// object-typed array `items`, and object-like combinator variants. Problems
/// are reported in this order:
/// 1. non-array combinators;
/// 2. the `type`;
/// 3. problems inside object-like combinator variants;
/// 4. the `properties` / `required` relationship;
/// 5. problems inside nested object properties and object-typed array items.
fn validate_tool_schema_inner(schema: &serde_json::Value, path: &str, depth: usize) -> Vec<String> {
    if depth > MAX_SCHEMA_DEPTH {
        return vec![format!(
            "{path}: schema nesting exceeds maximum depth of {MAX_SCHEMA_DEPTH}"
        )];
    }

    // Malformed combinators are reported but do not stop validation.
    let mut problems: Vec<String> = ["oneOf", "anyOf", "allOf"]
        .iter()
        .copied()
        .filter(|combinator| schema.get(*combinator).is_some_and(|value| !value.is_array()))
        .map(|combinator| format!("{path}: \"{combinator}\" must be an array"))
        .collect();

    // Object-like combinator variants may stand in for a top-level `type` / `properties`.
    let has_combinators = has_object_combinator_variants(schema);

    match schema.get("type").and_then(serde_json::Value::as_str) {
        Some("object") => {}
        Some(other) => {
            problems.push(format!("{path}: expected type \"object\", got \"{other}\""));
            return problems;
        }
        None if has_combinators => {}
        None => {
            problems.push(format!("{path}: missing \"type\": \"object\""));
            return problems;
        }
    }

    for combinator in ["allOf", "oneOf", "anyOf"] {
        let Some(variants) = schema.get(combinator).and_then(serde_json::Value::as_array) else {
            continue;
        };
        for (index, variant) in variants.iter().enumerate() {
            if is_object_like_variant(variant) {
                let variant_path = format!("{path}.{combinator}[{index}]");
                problems.extend(validate_tool_schema_inner(variant, &variant_path, depth + 1));
            }
        }
    }

    let Some(properties) = schema.get("properties").and_then(serde_json::Value::as_object) else {
        if !has_combinators {
            problems.push(format!("{path}: missing or non-object \"properties\""));
            return problems;
        }
        // Without top-level properties, a required key may come from any variant.
        if let Some(required) = schema.get("required").and_then(serde_json::Value::as_array) {
            let variant_keys = combinator_property_keys(schema);
            for key in required.iter().filter_map(serde_json::Value::as_str) {
                if !variant_keys.contains(key) {
                    problems.push(format!(
                        "{path}: required key \"{key}\" not found in any combinator variant properties"
                    ));
                }
            }
        }
        return problems;
    };

    if let Some(required) = schema.get("required").and_then(serde_json::Value::as_array) {
        for key in required.iter().filter_map(serde_json::Value::as_str) {
            if !properties.contains_key(key) {
                problems.push(format!(
                    "{path}: required key \"{key}\" not found in properties"
                ));
            }
        }
    }

    for (name, property) in properties {
        let prop_path = format!("{path}.{name}");
        match property.get("type").and_then(serde_json::Value::as_str) {
            Some("object") => {
                problems.extend(validate_tool_schema_inner(property, &prop_path, depth + 1));
            }
            Some("array") => match property.get("items") {
                None => problems.push(format!("{prop_path}: array property missing \"items\"")),
                Some(items) => {
                    if items.get("type").and_then(serde_json::Value::as_str) == Some("object") {
                        let items_path = format!("{prop_path}.items");
                        problems.extend(validate_tool_schema_inner(items, &items_path, depth + 1));
                    }
                }
            },
            _ => {}
        }
    }

    problems
}

/// Whether a combinator variant declares `"type": "object"` or has
/// `properties`, i.e. whether it gets validated as an object schema itself.
fn is_object_like_variant(variant: &serde_json::Value) -> bool {
    variant.get("type").and_then(serde_json::Value::as_str) == Some("object")
        || variant.get("properties").is_some()
}

/// Every property name declared by any `allOf` / `oneOf` / `anyOf` variant of `schema`.
fn combinator_property_keys(schema: &serde_json::Value) -> std::collections::HashSet<String> {
    ["allOf", "oneOf", "anyOf"]
        .iter()
        .filter_map(|combinator| schema.get(*combinator).and_then(serde_json::Value::as_array))
        .flatten()
        .filter_map(|variant| variant.get("properties").and_then(serde_json::Value::as_object))
        .flat_map(|props| props.keys().cloned())
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::credentials::TEST_REDACT_SECRET;

    /// Minimal tool used to exercise the `Tool` trait defaults.
    #[derive(Debug)]
    pub struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes back the input message. Useful for testing."
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "The message to echo back"
                    }
                },
                "required": ["message"]
            })
        }

        async fn execute(
            &self,
            params: serde_json::Value,
            _ctx: &JobContext,
        ) -> Result<ToolOutput, ToolError> {
            let message = require_str(&params, "message")?;
            Ok(ToolOutput::text(message, Duration::from_millis(1)))
        }

        fn requires_sanitization(&self) -> bool {
            false
        }
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let tool = EchoTool;
        let ctx = JobContext::default();
        let result = tool
            .execute(serde_json::json!({"message": "hello"}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.result, serde_json::json!("hello"));
    }

    #[test]
    fn test_tool_schema() {
        let tool = EchoTool;
        let schema = tool.schema();
        assert_eq!(schema.name, "echo");
        assert!(!schema.description.is_empty());
    }

    #[test]
    fn test_tool_schema_without_discovery_hint_keeps_description() {
        let tool = EchoTool;
        let schema = tool.schema();
        assert_eq!(schema.description, tool.description());
        assert_eq!(schema.parameters, tool.parameters_schema());
    }

    #[test]
    fn test_tool_schema_new_defaults_to_empty_object() {
        let schema = ToolSchema::new("noop", "Does nothing");
        assert_eq!(
            schema.parameters,
            serde_json::json!({"type": "object", "properties": {}, "required": []})
        );
        assert!(validate_tool_schema(&schema.parameters, "noop").is_empty());
    }

    #[test]
    fn test_tool_output_builders() {
        let output = ToolOutput::text("hi", Duration::from_millis(5))
            .with_cost(Decimal::new(15, 2))
            .with_raw("raw body");
        assert_eq!(output.result, serde_json::json!("hi"));
        assert_eq!(output.cost, Some(Decimal::new(15, 2)));
        assert_eq!(output.duration, Duration::from_millis(5));
        assert_eq!(output.raw.as_deref(), Some("raw body"));
    }

    #[test]
    fn test_risk_level_display_and_ordering() {
        assert_eq!(RiskLevel::Low.to_string(), "low");
        assert_eq!(RiskLevel::Medium.to_string(), "medium");
        assert_eq!(RiskLevel::High.to_string(), "high");
        assert!(RiskLevel::Low < RiskLevel::Medium);
        assert!(RiskLevel::Medium < RiskLevel::High);
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = ToolRateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 60);
        assert_eq!(config.requests_per_hour, 1000);
    }

    #[test]
    fn test_execution_timeout_default() {
        let tool = EchoTool;
        assert_eq!(tool.execution_timeout(), Duration::from_secs(60));
    }

    #[test]
    fn test_require_str_present() {
        let params = serde_json::json!({"name": "alice"});
        assert_eq!(require_str(&params, "name").unwrap(), "alice");
    }

    #[test]
    fn test_require_str_missing() {
        let params = serde_json::json!({});
        let err = require_str(&params, "name").unwrap_err();
        assert!(err.to_string().contains("missing 'name'"));
    }

    #[test]
    fn test_require_str_wrong_type() {
        let params = serde_json::json!({"name": 42});
        let err = require_str(&params, "name").unwrap_err();
        assert!(err.to_string().contains("missing 'name'"));
    }

    #[test]
    fn test_require_param_present() {
        let params = serde_json::json!({"data": [1, 2, 3]});
        assert_eq!(
            require_param(&params, "data").unwrap(),
            &serde_json::json!([1, 2, 3])
        );
    }

    #[test]
    fn test_require_param_missing() {
        let params = serde_json::json!({});
        let err = require_param(&params, "data").unwrap_err();
        assert!(err.to_string().contains("missing 'data'"));
    }

    #[test]
    fn test_requires_approval_default() {
        let tool = EchoTool;
        assert_eq!(
            tool.requires_approval(&serde_json::json!({"message": "hi"})),
            ApprovalRequirement::Never
        );
        assert!(!ApprovalRequirement::Never.is_required());
        assert!(ApprovalRequirement::UnlessAutoApproved.is_required());
        assert!(ApprovalRequirement::Always.is_required());
    }

    #[test]
    fn test_redact_params_replaces_sensitive_key() {
        let params = serde_json::json!({"name": "openai_key", "value": TEST_REDACT_SECRET});
        let redacted = redact_params(&params, &["value"]);
        assert_eq!(redacted["name"], "openai_key");
        assert_eq!(redacted["value"], "[REDACTED]");
        assert_eq!(params["value"], TEST_REDACT_SECRET);
    }

    #[test]
    fn test_redact_params_multiple_keys() {
        let params = serde_json::json!({"a": 1, "b": 2, "c": 3});
        let redacted = redact_params(&params, &["a", "c", "absent"]);
        assert_eq!(
            redacted,
            serde_json::json!({"a": "[REDACTED]", "b": 2, "c": "[REDACTED]"})
        );
    }

    #[test]
    fn test_redact_params_empty_sensitive_is_noop() {
        let params = serde_json::json!({"name": "key", "value": "secret"});
        let redacted = redact_params(&params, &[]);
        assert_eq!(redacted, params);
    }

    #[test]
    fn test_redact_params_missing_key_is_noop() {
        let params = serde_json::json!({"name": "key"});
        let redacted = redact_params(&params, &["value"]);
        assert_eq!(redacted, params);
    }

    #[test]
    fn test_redact_params_non_object_is_passthrough() {
        let params = serde_json::json!("just a string");
        let redacted = redact_params(&params, &["value"]);
        assert_eq!(redacted, params);
    }

    #[test]
    fn test_validate_schema_valid() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string", "description": "A name" }
            },
            "required": ["name"]
        });
        let errors = validate_tool_schema(&schema, "test");
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");
    }

    #[test]
    fn test_validate_schema_missing_type() {
        let schema = serde_json::json!({
            "properties": {
                "name": { "type": "string" }
            }
        });
        let errors = validate_tool_schema(&schema, "test");
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("missing \"type\": \"object\""));
    }

    #[test]
    fn test_validate_schema_wrong_type() {
        let schema = serde_json::json!({
            "type": "string"
        });
        let errors = validate_tool_schema(&schema, "test");
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("expected type \"object\""));
    }

    #[test]
    fn test_validate_schema_required_not_in_properties() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name", "age"]
        });
        let errors = validate_tool_schema(&schema, "test");
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("\"age\" not found in properties"));
    }

    #[test]
    fn test_validate_schema_nested_object() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "config": {
                    "type": "object",
                    "properties": {
                        "key": { "type": "string" }
                    },
                    "required": ["key", "missing"]
                }
            }
        });
        let errors = validate_tool_schema(&schema, "test");
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("test.config"));
        assert!(errors[0].contains("\"missing\" not found"));
    }

    #[test]
    fn test_validate_schema_array_missing_items() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "tags": { "type": "array", "description": "Tags" }
            }
        });
        let errors = validate_tool_schema(&schema, "test");
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("array property missing \"items\""));
    }

    #[test]
    fn test_validate_schema_array_with_items_ok() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "tags": {
                    "type": "array",
                    "items": { "type": "string" }
                }
            }
        });
        let errors = validate_tool_schema(&schema, "test");
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");
    }

    #[test]
    fn test_validate_schema_freeform_property_allowed() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "data": { "description": "Any JSON value" }
            },
            "required": ["data"]
        });
        let errors = validate_tool_schema(&schema, "test");
        assert!(
            errors.is_empty(),
            "freeform property should be allowed: {errors:?}"
        );
    }

    #[test]
    fn test_validate_schema_nested_array_items_object() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "headers": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": { "type": "string" },
                            "value": { "type": "string" }
                        },
                        "required": ["name", "value"]
                    }
                }
            }
        });
        let errors = validate_tool_schema(&schema, "test");
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");
    }

    #[test]
    fn test_validate_schema_nested_array_items_object_bad() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "headers": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": { "type": "string" }
                        },
                        "required": ["name", "missing_field"]
                    }
                }
            }
        });
        let errors = validate_tool_schema(&schema, "test");
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("headers.items"));
        assert!(errors[0].contains("\"missing_field\""));
    }

    #[test]
    fn test_validate_schema_depth_limit() {
        let mut schema = serde_json::json!({
            "type": "object",
            "properties": {
                "leaf": { "type": "string" }
            }
        });
        for _ in 0..20 {
            schema = serde_json::json!({
                "type": "object",
                "properties": {
                    "nested": schema
                }
            });
        }
        let errors = validate_tool_schema(&schema, "test");
        assert!(
            errors.iter().any(|e| e.contains("maximum depth")),
            "expected depth limit error, got: {errors:?}"
        );
    }

    #[test]
    fn test_validate_schema_combinator_without_top_level_properties() {
        let schema = serde_json::json!({
            "type": "object",
            "oneOf": [
                { "type": "object", "properties": { "a": { "type": "string" } } },
                { "type": "object", "properties": { "b": { "type": "string" } } }
            ],
            "required": ["a"]
        });
        let errors = validate_tool_schema(&schema, "test");
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");
    }

    #[test]
    fn test_validate_schema_combinator_required_key_missing_everywhere() {
        let schema = serde_json::json!({
            "type": "object",
            "oneOf": [
                { "type": "object", "properties": { "a": { "type": "string" } } }
            ],
            "required": ["a", "c"]
        });
        let errors = validate_tool_schema(&schema, "test");
        assert_eq!(errors.len(), 1, "got: {errors:?}");
        assert!(errors[0].contains("\"c\" not found in any combinator variant properties"));
    }

    #[test]
    fn test_validate_schema_combinator_without_type_is_allowed() {
        let schema = serde_json::json!({
            "anyOf": [
                {
                    "type": "object",
                    "properties": { "x": { "type": "string" } },
                    "required": ["x"]
                }
            ]
        });
        let errors = validate_tool_schema(&schema, "test");
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");
    }

    #[test]
    fn test_validate_schema_combinator_variant_errors_are_reported() {
        let schema = serde_json::json!({
            "type": "object",
            "allOf": [
                {
                    "type": "object",
                    "properties": { "k": { "type": "string" } },
                    "required": ["k", "nope"]
                }
            ]
        });
        let errors = validate_tool_schema(&schema, "test");
        assert_eq!(errors.len(), 1, "got: {errors:?}");
        assert!(errors[0].contains("test.allOf[0]"));
        assert!(errors[0].contains("\"nope\""));
    }

    #[test]
    fn test_validate_schema_non_array_combinator() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {},
            "oneOf": {}
        });
        let errors = validate_tool_schema(&schema, "test");
        assert_eq!(errors.len(), 1, "got: {errors:?}");
        assert!(errors[0].contains("\"oneOf\" must be an array"));
    }

    #[test]
    fn test_approval_context_autonomous_blocks_tools_not_in_scope() {
        let ctx = ApprovalContext::autonomous();
        assert!(ctx.is_blocked("shell", ApprovalRequirement::Never));
        assert!(ctx.is_blocked("shell", ApprovalRequirement::UnlessAutoApproved));
        assert!(ctx.is_blocked("shell", ApprovalRequirement::Always));
    }

    #[test]
    fn test_approval_context_autonomous_with_tools_allows_registered_name() {
        let ctx =
            ApprovalContext::autonomous_with_tools(["shell".to_string(), "message".to_string()]);
        assert!(!ctx.is_blocked("shell", ApprovalRequirement::Never));
        assert!(!ctx.is_blocked("shell", ApprovalRequirement::Always));
        assert!(!ctx.is_blocked("message", ApprovalRequirement::Always));
        assert!(ctx.is_blocked("http", ApprovalRequirement::Always));
    }

    #[test]
    fn test_approval_context_blocks_never_when_not_in_scope() {
        let ctx = ApprovalContext::autonomous();
        assert!(ctx.is_blocked("any_tool", ApprovalRequirement::Never));
    }

    #[test]
    fn test_is_blocked_or_default_with_none_uses_legacy() {
        assert!(!ApprovalContext::is_blocked_or_default(
            &None,
            "any",
            ApprovalRequirement::Never
        ));
        assert!(ApprovalContext::is_blocked_or_default(
            &None,
            "any",
            ApprovalRequirement::UnlessAutoApproved
        ));
        assert!(ApprovalContext::is_blocked_or_default(
            &None,
            "any",
            ApprovalRequirement::Always
        ));
    }

    #[test]
    fn test_is_blocked_or_default_with_some_delegates() {
        let ctx = Some(ApprovalContext::autonomous_with_tools(
            ["shell".to_string()],
        ));
        assert!(!ApprovalContext::is_blocked_or_default(
            &ctx,
            "shell",
            ApprovalRequirement::Always
        ));
        assert!(ApprovalContext::is_blocked_or_default(
            &ctx,
            "other",
            ApprovalRequirement::Always
        ));
        assert!(ApprovalContext::is_blocked_or_default(
            &ctx,
            "any",
            ApprovalRequirement::UnlessAutoApproved
        ));
    }
}