use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Declarative description of a tool: its JSON Schemas plus policy flags consulted by
/// [`ContractValidator`] and [`get_tool_verification_status`].
///
/// Usually constructed through [`ToolContract::builder`]; unspecified fields take the
/// values from this type's `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolContract {
    /// Tool name; [`ContractValidator::register`] uses it as the lookup key.
    pub name: String,
    /// Free-form human-readable description.
    pub description: String,
    /// JSON Schema that [`ContractValidator::validate_input`] checks payloads against.
    pub input_schema: serde_json::Value,
    /// JSON Schema that [`ContractValidator::validate_output`] checks payloads against.
    pub output_schema: serde_json::Value,
    /// When `true`, [`get_tool_verification_status`] reports the tool as `Verified`.
    pub require_proofguard: bool,
    /// Confidence threshold; the builder clamps it to `[0.0, 1.0]`, but values set
    /// directly or via deserialization are not checked. Nothing in this module
    /// enforces the threshold itself.
    pub min_confidence: f64,
    /// Reported by [`ContractValidator::is_unsafe_tool`].
    pub is_unsafe: bool,
    /// Capability names queried by [`ToolContract::requires_capability`].
    pub required_capabilities: Vec<String>,
    /// Timeout in milliseconds (per the field name); defaults to 30000. Not enforced
    /// anywhere in this module.
    pub timeout_ms: u64,
}
impl Default for ToolContract {
    /// An unnamed, permissive contract: both schemas accept any JSON object, no
    /// ProofGuard, no confidence floor, not unsafe, no capabilities, 30 s timeout.
    fn default() -> Self {
        // Input and output share the same "any object" schema shape.
        let any_object = || serde_json::json!({ "type": "object" });
        Self {
            name: String::default(),
            description: String::default(),
            input_schema: any_object(),
            output_schema: any_object(),
            require_proofguard: false,
            min_confidence: 0.0,
            is_unsafe: false,
            required_capabilities: Vec::default(),
            timeout_ms: 30_000,
        }
    }
}
impl ToolContract {
    /// Starts a [`ToolContractBuilder`] for a tool called `name`; all other fields
    /// begin at their `Default` values.
    pub fn builder(name: impl Into<String>) -> ToolContractBuilder {
        ToolContractBuilder::new(name)
    }

    /// Returns `true` if `cap` is one of this contract's required capabilities
    /// (exact, case-sensitive match).
    pub fn requires_capability(&self, cap: &str) -> bool {
        self.required_capabilities
            .iter()
            .map(String::as_str)
            .any(|required| required == cap)
    }
}
/// Fluent builder for [`ToolContract`].
///
/// Obtain one via [`ToolContract::builder`] or [`ToolContractBuilder::new`], chain
/// setters, and finish with [`ToolContractBuilder::build`].
#[derive(Debug, Clone)]
#[must_use = "a ToolContractBuilder does nothing until `.build()` is called"]
pub struct ToolContractBuilder {
    /// The contract being assembled; returned as-is by `build`.
    contract: ToolContract,
}
impl ToolContractBuilder {
    /// Creates a builder for a tool called `name`, with every other field at its
    /// `ToolContract::default()` value.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            contract: ToolContract {
                name: name.into(),
                ..Default::default()
            },
        }
    }

    /// Sets the human-readable description.
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.contract.description = desc.into();
        self
    }

    /// Sets the JSON Schema used to validate tool input.
    pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
        self.contract.input_schema = schema;
        self
    }

    /// Sets the JSON Schema used to validate tool output.
    pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
        self.contract.output_schema = schema;
        self
    }

    /// Sets whether the tool requires ProofGuard.
    pub fn require_proofguard(mut self, required: bool) -> Self {
        self.contract.require_proofguard = required;
        self
    }

    /// Sets the minimum confidence, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN `threshold` is replaced by `0.0` (the default).
    pub fn min_confidence(mut self, threshold: f64) -> Self {
        // `f64::clamp` returns NaN unchanged. A NaN threshold would make every later
        // comparison against it false, so fall back to the default instead.
        self.contract.min_confidence = if threshold.is_nan() {
            0.0
        } else {
            threshold.clamp(0.0, 1.0)
        };
        self
    }

    /// Marks the tool as unsafe. The builder offers no way to clear the flag again.
    pub fn unsafe_tool(mut self) -> Self {
        self.contract.is_unsafe = true;
        self
    }

    /// Appends `cap` to the required capabilities. Duplicates are not filtered.
    pub fn require_capability(mut self, cap: impl Into<String>) -> Self {
        self.contract.required_capabilities.push(cap.into());
        self
    }

    /// Sets the timeout in milliseconds.
    pub fn timeout_ms(mut self, ms: u64) -> Self {
        self.contract.timeout_ms = ms;
        self
    }

    /// Finishes building and returns the assembled contract.
    pub fn build(self) -> ToolContract {
        self.contract
    }
}
/// Outcome of validating a payload against a tool contract's schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResult {
    /// `true` if the payload satisfied the schema, or if no contract was registered
    /// and the validator is not in strict mode.
    pub valid: bool,
    /// Empty when `valid`; otherwise one entry per reported problem.
    pub errors: Vec<ValidationError>,
    /// A clone of the payload when `valid`, `None` otherwise.
    pub validated_data: Option<serde_json::Value>,
}
/// A single validation problem reported in a [`ValidationResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationError {
    /// Location of the offending value, taken from the jsonschema error's
    /// `instance_path`; empty when the schema itself failed to compile.
    pub path: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// Machine-readable code; this module emits `"VALIDATION_ERROR"` and
    /// `"INVALID_SCHEMA"`.
    pub code: String,
}
/// Registry of [`ToolContract`]s that validates tool input/output against them.
#[derive(Debug, Clone)]
pub struct ContractValidator {
    /// Contracts keyed by `ToolContract::name`.
    contracts: HashMap<String, ToolContract>,
    /// When `true`, validating a tool with no registered contract is an error.
    strict_mode: bool,
}
impl ContractValidator {
pub fn new() -> Self {
Self {
contracts: HashMap::new(),
strict_mode: false,
}
}
pub fn strict(mut self) -> Self {
self.strict_mode = true;
self
}
pub fn register(&mut self, contract: ToolContract) {
self.contracts.insert(contract.name.clone(), contract);
}
pub fn get_contract(&self, name: &str) -> Option<&ToolContract> {
self.contracts.get(name)
}
pub fn validate_input(
&self,
tool_name: &str,
input: &serde_json::Value,
) -> Result<ValidationResult> {
let contract = match self.contracts.get(tool_name) {
Some(c) => c,
None => {
if self.strict_mode {
return Err(Error::Mcp(format!(
"No contract registered for tool: {}",
tool_name
)));
}
return Ok(ValidationResult {
valid: true,
errors: Vec::new(),
validated_data: Some(input.clone()),
});
}
};
self.validate_against_schema(input, &contract.input_schema)
}
pub fn validate_output(
&self,
tool_name: &str,
output: &serde_json::Value,
) -> Result<ValidationResult> {
let contract = match self.contracts.get(tool_name) {
Some(c) => c,
None => {
if self.strict_mode {
return Err(Error::Mcp(format!(
"No contract registered for tool: {}",
tool_name
)));
}
return Ok(ValidationResult {
valid: true,
errors: Vec::new(),
validated_data: Some(output.clone()),
});
}
};
self.validate_against_schema(output, &contract.output_schema)
}
pub fn requires_proofguard(&self, tool_name: &str) -> bool {
self.contracts
.get(tool_name)
.map(|c| c.require_proofguard)
.unwrap_or(false)
}
pub fn is_unsafe_tool(&self, tool_name: &str) -> bool {
self.contracts
.get(tool_name)
.map(|c| c.is_unsafe)
.unwrap_or(true) }
pub fn registered_tools(&self) -> Vec<&str> {
self.contracts.keys().map(|s| s.as_str()).collect()
}
fn validate_against_schema(
&self,
data: &serde_json::Value,
schema: &serde_json::Value,
) -> Result<ValidationResult> {
let compiled = match jsonschema::JSONSchema::compile(schema) {
Ok(v) => v,
Err(e) => {
return Ok(ValidationResult {
valid: false,
errors: vec![ValidationError {
path: "".to_string(),
message: format!("Invalid schema: {}", e),
code: "INVALID_SCHEMA".to_string(),
}],
validated_data: None,
});
}
};
let validation_result = compiled.validate(data);
match validation_result {
Ok(_) => Ok(ValidationResult {
valid: true,
errors: Vec::new(),
validated_data: Some(data.clone()),
}),
Err(errors) => {
let validation_errors: Vec<ValidationError> = errors
.map(|e| ValidationError {
path: e.instance_path.to_string(),
message: e.to_string(),
code: "VALIDATION_ERROR".to_string(),
})
.collect();
Ok(ValidationResult {
valid: false,
errors: validation_errors,
validated_data: None,
})
}
}
}
}
impl Default for ContractValidator {
fn default() -> Self {
Self::new()
}
}
/// Verification level of a tool, as computed by [`get_tool_verification_status`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ToolVerificationStatus {
    /// A contract is registered and it requires ProofGuard.
    Verified,
    /// A contract is registered but does not require ProofGuard.
    PartiallyVerified,
    /// No contract is registered for the tool.
    Unverified,
}
/// Classifies `tool_name` by how strongly its registered contract constrains it.
///
/// Returns `Unverified` when `validator` has no contract for the tool,
/// `Verified` when the contract requires ProofGuard, and `PartiallyVerified`
/// otherwise.
pub fn get_tool_verification_status(
    validator: &ContractValidator,
    tool_name: &str,
) -> ToolVerificationStatus {
    validator
        .get_contract(tool_name)
        .map_or(ToolVerificationStatus::Unverified, |contract| {
            if contract.require_proofguard {
                ToolVerificationStatus::Verified
            } else {
                ToolVerificationStatus::PartiallyVerified
            }
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_contract_builder() {
        let contract = ToolContract::builder("search")
            .description("Search for information")
            .input_schema(json!({
                "type": "object",
                "properties": {
                    "query": { "type": "string" }
                },
                "required": ["query"]
            }))
            .output_schema(json!({
                "type": "object",
                "properties": {
                    "results": { "type": "array" }
                }
            }))
            .require_proofguard(true)
            .min_confidence(0.8)
            .require_capability("network")
            .timeout_ms(60000)
            .build();
        assert_eq!(contract.name, "search");
        assert!(contract.require_proofguard);
        assert!((contract.min_confidence - 0.8).abs() < f64::EPSILON);
        assert!(contract.requires_capability("network"));
        assert!(!contract.requires_capability("filesystem"));
        assert_eq!(contract.timeout_ms, 60000);
    }

    #[test]
    fn test_min_confidence_is_clamped() {
        let high = ToolContract::builder("t").min_confidence(1.5).build();
        assert!((high.min_confidence - 1.0).abs() < f64::EPSILON);
        let low = ToolContract::builder("t").min_confidence(-0.5).build();
        assert!(low.min_confidence.abs() < f64::EPSILON);
    }

    #[test]
    fn test_input_validation() {
        let mut validator = ContractValidator::new();
        validator.register(
            ToolContract::builder("test_tool")
                .input_schema(json!({
                    "type": "object",
                    "properties": {
                        "query": { "type": "string" }
                    },
                    "required": ["query"]
                }))
                .build(),
        );
        let valid_input = json!({ "query": "test" });
        let result = validator.validate_input("test_tool", &valid_input).unwrap();
        assert!(result.valid);
        assert_eq!(result.validated_data, Some(valid_input.clone()));
        let invalid_input = json!({ "other": "value" });
        let result = validator
            .validate_input("test_tool", &invalid_input)
            .unwrap();
        assert!(!result.valid);
        assert!(!result.errors.is_empty());
        assert!(result.validated_data.is_none());
    }

    #[test]
    fn test_output_validation() {
        let mut validator = ContractValidator::new();
        validator.register(
            ToolContract::builder("search")
                .output_schema(json!({
                    "type": "object",
                    "properties": { "results": { "type": "array" } },
                    "required": ["results"]
                }))
                .build(),
        );
        let good = json!({ "results": [] });
        let result = validator.validate_output("search", &good).unwrap();
        assert!(result.valid);
        assert_eq!(result.validated_data, Some(good.clone()));

        let bad = json!({ "results": "not-an-array" });
        let result = validator.validate_output("search", &bad).unwrap();
        assert!(!result.valid);
        assert!(result.validated_data.is_none());
        assert!(!result.errors.is_empty());
        assert!(result.errors.iter().all(|e| e.code == "VALIDATION_ERROR"));
    }

    #[test]
    fn test_strict_mode() {
        let validator = ContractValidator::new().strict();
        assert!(validator.validate_input("unknown_tool", &json!({})).is_err());
        assert!(validator.validate_output("unknown_tool", &json!({})).is_err());
    }

    #[test]
    fn test_lenient_mode_passes_unknown_tools_through() {
        let validator = ContractValidator::new();
        let payload = json!({ "anything": 1 });
        for result in [
            validator.validate_input("unknown_tool", &payload).unwrap(),
            validator.validate_output("unknown_tool", &payload).unwrap(),
        ] {
            assert!(result.valid);
            assert!(result.errors.is_empty());
            assert_eq!(result.validated_data, Some(payload.clone()));
        }
    }

    #[test]
    fn test_safety_flags_and_registry() {
        let mut validator = ContractValidator::new();
        validator.register(ToolContract::builder("safe").build());
        validator.register(ToolContract::builder("danger").unsafe_tool().build());
        assert!(!validator.is_unsafe_tool("safe"));
        assert!(validator.is_unsafe_tool("danger"));
        // Unknown tools fail safe.
        assert!(validator.is_unsafe_tool("never_registered"));
        assert!(!validator.requires_proofguard("never_registered"));

        let mut tools = validator.registered_tools();
        tools.sort_unstable();
        assert_eq!(tools, vec!["danger", "safe"]);
    }

    #[test]
    fn test_verification_status() {
        let mut validator = ContractValidator::new();
        validator.register(
            ToolContract::builder("verified_tool")
                .require_proofguard(true)
                .build(),
        );
        validator.register(
            ToolContract::builder("partial_tool")
                .require_proofguard(false)
                .build(),
        );
        assert_eq!(
            get_tool_verification_status(&validator, "verified_tool"),
            ToolVerificationStatus::Verified
        );
        assert_eq!(
            get_tool_verification_status(&validator, "partial_tool"),
            ToolVerificationStatus::PartiallyVerified
        );
        assert_eq!(
            get_tool_verification_status(&validator, "unknown"),
            ToolVerificationStatus::Unverified
        );
    }
}