use std::collections::HashSet;
use async_trait::async_trait;
use kaish_types::{ExecResult, ParamSchema, ToolArgs, ToolSchema, Value};
use crate::ctx::ToolCtx;
use crate::issue::{IssueCode, Severity, ValidationIssue};
/// A named, schema-described command that can be validated and executed.
///
/// Implementors must be `Send + Sync`. The `#[async_trait]` attribute makes
/// the async `execute` method usable through this trait.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Returns this tool's name.
    fn name(&self) -> &str;

    /// Returns the schema describing the parameters this tool accepts.
    fn schema(&self) -> ToolSchema;

    /// Checks `args` and returns any issues found.
    ///
    /// The default implementation fetches [`Tool::schema`] and delegates to
    /// [`validate_against_schema`]. Implementors may override it to add
    /// tool-specific checks.
    fn validate(&self, args: &ToolArgs) -> Vec<ValidationIssue> {
        let schema = self.schema();
        validate_against_schema(args, &schema)
    }

    /// Runs the tool with the given arguments and execution context.
    async fn execute(&self, args: ToolArgs, ctx: &mut dyn ToolCtx) -> ExecResult;
}
pub fn validate_against_schema(args: &ToolArgs, schema: &ToolSchema) -> Vec<ValidationIssue> {
let mut issues = Vec::new();
let positional_params: Vec<&ParamSchema> = schema.params.iter().filter(|p| p.positional).collect();
let flag_params: Vec<&ParamSchema> = schema.params.iter().filter(|p| !p.positional).collect();
for (slot, param) in positional_params.iter().enumerate() {
if !param.required {
continue;
}
let has_positional = args.positional.len() > slot;
let has_named = args.named.contains_key(¶m.name);
if !has_positional && !has_named {
let code = IssueCode::MissingRequiredArg;
issues.push(ValidationIssue {
severity: code.default_severity(),
code,
message: format!("required parameter '{}' not provided", param.name),
span: None,
suggestion: Some(format!("add {} or {}=<value>", param.name, param.name)),
});
}
}
for param in &flag_params {
if !param.required {
continue;
}
let has_named = args.named.contains_key(¶m.name);
let has_flag = param.param_type == "bool" && args.has_flag(¶m.name);
if !has_named && !has_flag {
let code = IssueCode::MissingRequiredArg;
issues.push(ValidationIssue {
severity: code.default_severity(),
code,
message: format!("required parameter '{}' not provided", param.name),
span: None,
suggestion: Some(format!("add --{} <value>", param.name)),
});
}
}
let known_flags: HashSet<&str> = flag_params
.iter()
.filter(|p| p.param_type == "bool")
.flat_map(|p| {
std::iter::once(p.name.as_str())
.chain(p.aliases.iter().map(|a| a.as_str()))
})
.collect();
for flag in &args.flags {
let flag_name = flag.trim_start_matches('-');
if is_global_output_flag(flag_name) {
continue;
}
if !known_flags.contains(flag_name) && !known_flags.contains(flag.as_str()) {
let matches_alias = flag_params.iter().any(|p| p.matches_flag(flag));
if !matches_alias {
issues.push(ValidationIssue {
severity: Severity::Warning,
code: IssueCode::UnknownFlag,
message: format!("unknown flag '{}'", flag),
span: None,
suggestion: None,
});
}
}
}
for (key, value) in &args.named {
if let Some(param) = schema.params.iter().find(|p| &p.name == key)
&& let Some(issue) = check_type_compatibility(key, value, ¶m.param_type) {
issues.push(issue);
}
}
for (slot, value) in args.positional.iter().enumerate() {
if let Some(param) = positional_params.get(slot)
&& let Some(issue) = check_type_compatibility(¶m.name, value, ¶m.param_type) {
issues.push(issue);
}
}
issues
}
/// Reports whether `name` is a global output flag that every tool accepts.
///
/// `name` must already have its leading dashes removed. The comparison is
/// exact and case-sensitive, and `json` is currently the only such flag.
pub fn is_global_output_flag(name: &str) -> bool {
    matches!(name, "json")
}
/// Checks whether `value` is acceptable for a param declared as `expected_type`.
///
/// Returns `None` when the value is acceptable, or an `InvalidArgType` issue
/// naming the argument otherwise.
///
/// Acceptance rules:
/// - `Value::String` is accepted for every type name.
/// - `int` also accepts `Int`.
/// - `float` also accepts `Float` and `Int`.
/// - `bool` also accepts `Bool`.
/// - `array` and `object` accept only strings.
/// - `any`, `string`, and unrecognised type names accept every value.
fn check_type_compatibility(name: &str, value: &Value, expected_type: &str) -> Option<ValidationIssue> {
    let is_string = matches!(value, Value::String(_));
    let accepted = match expected_type {
        "int" => is_string || matches!(value, Value::Int(_)),
        "float" => is_string || matches!(value, Value::Float(_) | Value::Int(_)),
        "bool" => is_string || matches!(value, Value::Bool(_)),
        "array" | "object" => is_string,
        // "any", "string", and any type name we don't recognise.
        _ => true,
    };

    (!accepted).then(|| {
        let code = IssueCode::InvalidArgType;
        ValidationIssue {
            severity: code.default_severity(),
            code,
            message: format!(
                "argument '{}' has type {:?}, expected {}",
                name, value, expected_type
            ),
            span: None,
            suggestion: None,
        }
    })
}
#[cfg(test)]
mod validate_tests {
    use super::*;
    use kaish_types::{ParamSchema, ToolSchema};

    /// Schema with two flag-style params (`verbose`/`v`, `lines`/`n`) declared
    /// before a required positional `path`.
    fn schema_with_positionals_after_flags() -> ToolSchema {
        ToolSchema::new("demo", "demo")
            .param(
                ParamSchema::new("verbose", "bool")
                    .with_default(Some(Value::Bool(false)))
                    .with_aliases(["v"]),
            )
            .param(ParamSchema::new("lines", "int").with_aliases(["n"]))
            .param(
                ParamSchema::new("path", "string")
                    .with_required(true)
                    .positional(),
            )
    }

    /// Schema whose first declared param is a `bool` flag, followed by required
    /// positionals `count: int` and `name: string`.
    ///
    /// Each positional's struct index is one greater than its positional slot,
    /// so the slot-mapping tests can tell the two apart.
    fn schema_with_flag_before_typed_positionals() -> ToolSchema {
        ToolSchema::new("demo", "demo")
            .param(ParamSchema::new("verbose", "bool").with_default(Some(Value::Bool(false))))
            .param(
                ParamSchema::new("count", "int")
                    .with_required(true)
                    .positional(),
            )
            .param(
                ParamSchema::new("name", "string")
                    .with_required(true)
                    .positional(),
            )
    }

    #[test]
    fn required_positional_satisfied_when_positional_sits_after_flags() {
        let schema = schema_with_positionals_after_flags();
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("foo.txt".into()));
        let issues = validate_against_schema(&args, &schema);
        assert!(
            !issues.iter().any(|i| i.code == IssueCode::MissingRequiredArg),
            "required positional should be satisfied by positional[0]; got {:?}",
            issues
        );
    }

    #[test]
    fn required_positional_satisfied_by_named_value() {
        let schema = schema_with_positionals_after_flags();
        let mut args = ToolArgs::new();
        args.named
            .insert("path".into(), Value::String("foo.txt".into()));
        let issues = validate_against_schema(&args, &schema);
        assert!(
            !issues.iter().any(|i| i.code == IssueCode::MissingRequiredArg),
            "required positional should be satisfied by path=<value>; got {:?}",
            issues
        );
    }

    #[test]
    fn required_positional_missing_when_no_positional_given() {
        let schema = schema_with_positionals_after_flags();
        let mut args = ToolArgs::new();
        args.flags.insert("verbose".into());
        let issues = validate_against_schema(&args, &schema);
        assert!(
            issues.iter().any(|i| i.code == IssueCode::MissingRequiredArg),
            "missing required positional should error; got {:?}",
            issues
        );
    }

    #[test]
    fn positional_type_check_targets_positional_slot_not_struct_index() {
        let schema = schema_with_flag_before_typed_positionals();
        let mut args = ToolArgs::new();
        args.positional.push(Value::Int(5));
        args.positional.push(Value::String("widget".into()));
        let issues = validate_against_schema(&args, &schema);
        assert!(
            !issues.iter().any(|i| i.code == IssueCode::InvalidArgType),
            "int->int and string->string slots should validate clean; got {:?}",
            issues
        );
    }

    #[test]
    fn positional_type_mismatch_is_reported_against_its_slot() {
        // Bool in slot 0 must be checked against `count: int` and rejected.
        // Checking it against struct index 0 (`verbose: bool`) would wrongly accept it.
        let schema = schema_with_flag_before_typed_positionals();
        let mut args = ToolArgs::new();
        args.positional.push(Value::Bool(true));
        args.positional.push(Value::String("widget".into()));
        let issues = validate_against_schema(&args, &schema);
        let mismatches: Vec<&ValidationIssue> = issues
            .iter()
            .filter(|i| i.code == IssueCode::InvalidArgType)
            .collect();
        assert_eq!(
            mismatches.len(),
            1,
            "exactly one positional type mismatch expected; got {:?}",
            issues
        );
        assert!(
            mismatches[0].message.contains("'count'"),
            "mismatch should be attributed to 'count'; got {:?}",
            mismatches[0]
        );
    }

    #[test]
    fn alias_and_global_flags_are_not_unknown() {
        let schema = schema_with_positionals_after_flags();
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("foo.txt".into()));
        args.flags.insert("v".into());
        args.flags.insert("json".into());
        let issues = validate_against_schema(&args, &schema);
        assert!(
            !issues.iter().any(|i| i.code == IssueCode::UnknownFlag),
            "bool alias and global output flag should be accepted; got {:?}",
            issues
        );
    }

    #[test]
    fn unrecognised_flag_is_a_warning() {
        let schema = schema_with_positionals_after_flags();
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("foo.txt".into()));
        args.flags.insert("bogus".into());
        let issues = validate_against_schema(&args, &schema);
        let unknown: Vec<&ValidationIssue> = issues
            .iter()
            .filter(|i| i.code == IssueCode::UnknownFlag)
            .collect();
        assert_eq!(unknown.len(), 1, "expected one unknown flag; got {:?}", issues);
        assert!(
            matches!(unknown[0].severity, Severity::Warning),
            "unknown flags should be warnings; got {:?}",
            unknown[0]
        );
    }

    #[test]
    fn required_flag_still_errors_when_missing() {
        let schema = ToolSchema::new("demo", "demo").param(
            ParamSchema::new("output", "string")
                .with_required(true)
                .with_aliases(["o"]),
        );
        let args = ToolArgs::new();
        let issues = validate_against_schema(&args, &schema);
        assert!(
            issues.iter().any(|i| i.code == IssueCode::MissingRequiredArg),
            "required flag should error when missing; got {:?}",
            issues
        );
    }
}