use std::collections::HashMap;
use std::sync::Arc;
use symbi_runtime::integrations::schemapin::VerificationResult;
use symbi_runtime::integrations::{
DefaultToolInvocationEnforcer, EnforcementPolicy, InvocationContext,
InvocationEnforcementConfig, LocalKeyStore, McpClient, McpTool, MockMcpClient,
MockNativeSchemaPinClient, SecureMcpClient, ToolInvocationEnforcer, ToolInvocationError,
ToolProvider, VerificationStatus,
};
use symbi_runtime::types::AgentId;
fn create_test_tool_with_status(name: &str, status: VerificationStatus) -> McpTool {
McpTool {
name: name.to_string(),
description: format!("Test tool: {}", name),
schema: serde_json::json!({
"type": "object",
"properties": {
"input": {"type": "string"}
}
}),
provider: ToolProvider {
identifier: "test.example.com".to_string(),
name: "Test Provider".to_string(),
public_key_url: "https://test.example.com/pubkey".to_string(),
version: Some("1.0.0".to_string()),
},
verification_status: status,
metadata: None,
}
}
fn create_verified_tool(name: &str) -> McpTool {
create_test_tool_with_status(
name,
VerificationStatus::Verified {
result: Box::new(VerificationResult {
success: true,
message: "Test verification successful".to_string(),
schema_hash: Some("test_hash".to_string()),
public_key_url: Some("https://test.example.com/pubkey".to_string()),
signature: None,
metadata: None,
timestamp: Some("2024-01-01T00:00:00Z".to_string()),
}),
verified_at: "2024-01-01T00:00:00Z".to_string(),
},
)
}
fn create_failed_tool(name: &str) -> McpTool {
create_test_tool_with_status(
name,
VerificationStatus::Failed {
reason: "Test verification failed".to_string(),
failed_at: "2024-01-01T00:00:00Z".to_string(),
},
)
}
fn create_pending_tool(name: &str) -> McpTool {
create_test_tool_with_status(name, VerificationStatus::Pending)
}
fn create_skipped_tool(name: &str) -> McpTool {
create_test_tool_with_status(
name,
VerificationStatus::Skipped {
reason: "Test skipped".to_string(),
},
)
}
fn create_test_context(tool_name: &str) -> InvocationContext {
InvocationContext {
agent_id: AgentId::new(),
tool_name: tool_name.to_string(),
arguments: serde_json::json!({"input": "test"}),
timestamp: chrono::Utc::now(),
metadata: HashMap::new(),
}
}
#[tokio::test]
async fn test_strict_mode_allows_verified_tools() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Strict,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_verified_tool("verified_tool");
let context = create_test_context("verified_tool");
let decision = enforcer
.check_invocation_allowed(&tool, &context)
.await
.unwrap();
assert!(matches!(
decision,
symbi_runtime::integrations::EnforcementDecision::Allow
));
}
#[tokio::test]
async fn test_strict_mode_blocks_unverified_tools() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Strict,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_pending_tool("pending_tool");
let context = create_test_context("pending_tool");
let decision = enforcer
.check_invocation_allowed(&tool, &context)
.await
.unwrap();
assert!(matches!(
decision,
symbi_runtime::integrations::EnforcementDecision::Block { .. }
));
}
#[tokio::test]
async fn test_strict_mode_blocks_failed_tools() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Strict,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_failed_tool("failed_tool");
let context = create_test_context("failed_tool");
let decision = enforcer
.check_invocation_allowed(&tool, &context)
.await
.unwrap();
assert!(matches!(
decision,
symbi_runtime::integrations::EnforcementDecision::Block { .. }
));
}
#[tokio::test]
async fn test_permissive_mode_allows_with_warnings() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Permissive,
block_pending_verification: false,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_pending_tool("pending_tool");
let context = create_test_context("pending_tool");
let decision = enforcer
.check_invocation_allowed(&tool, &context)
.await
.unwrap();
assert!(matches!(
decision,
symbi_runtime::integrations::EnforcementDecision::AllowWithWarnings { .. }
));
}
#[tokio::test]
async fn test_development_mode_allows_skipped_tools() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Development,
allow_skipped_in_dev: true,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_skipped_tool("skipped_tool");
let context = create_test_context("skipped_tool");
let decision = enforcer
.check_invocation_allowed(&tool, &context)
.await
.unwrap();
assert!(matches!(
decision,
symbi_runtime::integrations::EnforcementDecision::AllowWithWarnings { .. }
));
}
#[tokio::test]
async fn test_disabled_mode_allows_all_tools() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Disabled,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_failed_tool("failed_tool");
let context = create_test_context("failed_tool");
let decision = enforcer
.check_invocation_allowed(&tool, &context)
.await
.unwrap();
assert!(matches!(
decision,
symbi_runtime::integrations::EnforcementDecision::Allow
));
}
#[tokio::test]
async fn test_execute_tool_blocks_unverified_in_strict_mode() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Strict,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_pending_tool("pending_tool");
let context = create_test_context("pending_tool");
let result = enforcer.execute_tool_with_enforcement(&tool, context).await;
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
ToolInvocationError::InvocationBlocked { .. }
));
}
#[tokio::test]
async fn test_execute_tool_succeeds_with_verified_tool() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Strict,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_verified_tool("verified_tool");
let context = create_test_context("verified_tool");
let result = enforcer
.execute_tool_with_enforcement(&tool, context)
.await
.unwrap();
assert!(result.success);
assert!(result.warnings.is_empty());
}
#[tokio::test]
async fn test_execute_tool_succeeds_with_warnings_in_permissive_mode() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Permissive,
block_pending_verification: false,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_pending_tool("pending_tool");
let context = create_test_context("pending_tool");
let result = enforcer
.execute_tool_with_enforcement(&tool, context)
.await
.unwrap();
assert!(result.success);
assert!(!result.warnings.is_empty());
}
#[tokio::test]
async fn test_mcp_client_integration_strict_mode() {
let client = MockMcpClient::new_success();
let tool = create_verified_tool("test_tool");
let _ = client.discover_tool(tool.clone()).await.unwrap();
let context = create_test_context("test_tool");
let result = client
.invoke_tool("test_tool", serde_json::json!({"input": "test"}), context)
.await
.unwrap();
assert!(result.success);
}
#[tokio::test]
async fn test_mcp_client_integration_blocks_unverified() {
let client = MockMcpClient::new_failure();
let tool = create_pending_tool("test_tool");
let result = client.discover_tool(tool).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_secure_mcp_client_with_enforcer() {
let config = symbi_runtime::integrations::McpClientConfig::default();
let schema_pin = Arc::new(MockNativeSchemaPinClient::new_success());
let key_store = Arc::new(LocalKeyStore::new().unwrap());
let client = SecureMcpClient::new(config, schema_pin, key_store);
let tool = create_test_tool_with_status("test_tool", VerificationStatus::Pending);
let event = client.discover_tool(tool).await.unwrap();
assert!(event.tool.verification_status.is_verified());
let context = create_test_context("test_tool");
let result = client
.invoke_tool("test_tool", serde_json::json!({"input": "test"}), context)
.await
.unwrap();
assert!(result.success);
}
#[tokio::test]
async fn test_enforcement_policy_configuration() {
let mut config = InvocationEnforcementConfig::default();
assert_eq!(config.policy, EnforcementPolicy::Strict);
config.policy = EnforcementPolicy::Permissive;
let enforcer = DefaultToolInvocationEnforcer::with_config(config.clone());
assert_eq!(
enforcer.get_enforcement_config().policy,
EnforcementPolicy::Permissive
);
let mut mutable_enforcer = DefaultToolInvocationEnforcer::new();
mutable_enforcer.update_enforcement_config(config);
assert_eq!(
mutable_enforcer.get_enforcement_config().policy,
EnforcementPolicy::Permissive
);
}
#[tokio::test]
async fn test_error_message_clarity() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Strict,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_failed_tool("failed_tool");
let context = create_test_context("failed_tool");
let result = enforcer.execute_tool_with_enforcement(&tool, context).await;
assert!(result.is_err());
if let Err(ToolInvocationError::InvocationBlocked { tool_name, reason }) = result {
assert_eq!(tool_name, "failed_tool");
assert!(reason.contains("verification failed"));
} else {
panic!("Expected InvocationBlocked error");
}
}
#[tokio::test]
async fn test_warning_escalation() {
let config = InvocationEnforcementConfig {
policy: EnforcementPolicy::Permissive,
block_pending_verification: false,
max_warnings_before_escalation: 2,
..Default::default()
};
let enforcer = DefaultToolInvocationEnforcer::with_config(config);
let tool = create_pending_tool("pending_tool");
let context1 = create_test_context("pending_tool");
let result1 = enforcer
.execute_tool_with_enforcement(&tool, context1)
.await
.unwrap();
assert!(result1.success);
assert!(!result1.warnings.is_empty());
assert!(!result1.metadata.contains_key("escalated"));
let context2 = create_test_context("pending_tool");
let result2 = enforcer
.execute_tool_with_enforcement(&tool, context2)
.await
.unwrap();
assert!(result2.success);
assert!(!result2.warnings.is_empty());
assert!(result2.metadata.contains_key("escalated"));
}