use async_trait::async_trait;
use std::collections::HashSet;
use crate::interpreter::{ExecResult, OutputFormat};
use crate::validator::{IssueCode, Severity, ValidationIssue};
pub use kaish_types::{ParamSchema, ToolArgs, ToolSchema};
use super::context::ExecContext;
use crate::ast::Value;
/// Interface implemented by every tool the interpreter can execute.
///
/// Implementors must be `Send + Sync` so a single instance can be shared.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Name identifying this tool.
    fn name(&self) -> &str;

    /// Parameter description used, among other things, by the default
    /// [`Tool::validate`] implementation.
    fn schema(&self) -> ToolSchema;

    /// Runs the tool with already-parsed `args` inside `ctx`.
    async fn execute(&self, args: ToolArgs, ctx: &mut ExecContext) -> ExecResult;

    /// Statically checks `args` before execution.
    ///
    /// The default implementation checks `args` against this tool's own
    /// [`Tool::schema`] via [`validate_against_schema`]. Override it for
    /// tool-specific rules.
    fn validate(&self, args: &ToolArgs) -> Vec<ValidationIssue> {
        let own_schema = self.schema();
        validate_against_schema(args, &own_schema)
    }
}
/// Checks `args` against `schema` and returns every problem found.
///
/// Three independent passes run in this order, and their issues are reported in
/// the same order:
/// 1. every `required` parameter must be supplied: by name, positionally (matched by
///    declaration index), or, for `bool` parameters, as a flag using its name or an alias;
/// 2. every flag must belong to a `bool` parameter (by name or alias) or be a global
///    output flag (see [`is_global_output_flag`]); anything else yields an
///    `UnknownFlag` warning;
/// 3. named and positional values are checked against the declared `param_type`.
///
/// An empty vector means no issues were detected.
pub fn validate_against_schema(args: &ToolArgs, schema: &ToolSchema) -> Vec<ValidationIssue> {
    let mut issues = Vec::new();
    check_required_params(args, schema, &mut issues);
    check_unknown_flags(args, schema, &mut issues);
    check_arg_types(args, schema, &mut issues);
    issues
}

/// Reports each required parameter that was not supplied in any accepted form.
fn check_required_params(args: &ToolArgs, schema: &ToolSchema, issues: &mut Vec<ValidationIssue>) {
    for (i, param) in schema.params.iter().enumerate() {
        if !param.required {
            continue;
        }
        let has_named = args.named.contains_key(&param.name);
        // Positional arguments are matched to parameters by declaration index.
        let has_positional = args.positional.len() > i;
        // A required bool may be given by its name or by any alias. The unknown-flag
        // pass accepts aliases via `matches_flag`, so this pass must accept them too.
        // Otherwise an aliased flag would be reported as a missing parameter.
        let has_flag = param.param_type == "bool"
            && (args.has_flag(&param.name)
                || args.flags.iter().any(|flag| param.matches_flag(flag)));
        if !has_named && !has_positional && !has_flag {
            let code = IssueCode::MissingRequiredArg;
            issues.push(ValidationIssue {
                severity: code.default_severity(),
                code,
                message: format!("required parameter '{}' not provided", param.name),
                span: None,
                suggestion: Some(format!("add {} or {}=<value>", param.name, param.name)),
            });
        }
    }
}

/// Emits an `UnknownFlag` warning for every flag that no `bool` parameter claims
/// and that is not a global output flag.
fn check_unknown_flags(args: &ToolArgs, schema: &ToolSchema, issues: &mut Vec<ValidationIssue>) {
    // Names and aliases of all bool parameters, as written in the schema.
    let known_flags: HashSet<&str> = schema
        .params
        .iter()
        .filter(|p| p.param_type == "bool")
        .flat_map(|p| std::iter::once(p.name.as_str()).chain(p.aliases.iter().map(|a| a.as_str())))
        .collect();

    for flag in &args.flags {
        let flag_name = flag.trim_start_matches('-');
        if is_global_output_flag(flag_name) {
            continue;
        }
        // Accept the flag with or without leading dashes, since schema aliases may
        // be written either way.
        if known_flags.contains(flag_name) || known_flags.contains(flag.as_str()) {
            continue;
        }
        // Final fallback: let the parameter decide whether the flag is one of its forms.
        if schema.params.iter().any(|p| p.matches_flag(flag)) {
            continue;
        }
        issues.push(ValidationIssue {
            severity: Severity::Warning,
            code: IssueCode::UnknownFlag,
            message: format!("unknown flag '{}'", flag),
            span: None,
            suggestion: None,
        });
    }
}

/// Reports values whose type is incompatible with the matching parameter's
/// declared `param_type`.
fn check_arg_types(args: &ToolArgs, schema: &ToolSchema, issues: &mut Vec<ValidationIssue>) {
    for (key, value) in &args.named {
        if let Some(param) = schema.params.iter().find(|p| &p.name == key)
            && let Some(issue) = check_type_compatibility(key, value, &param.param_type)
        {
            issues.push(issue);
        }
    }
    // Positional values are paired with parameters by declaration index. Extra
    // positionals beyond the schema are not type-checked.
    for (i, value) in args.positional.iter().enumerate() {
        if let Some(param) = schema.params.get(i)
            && let Some(issue) = check_type_compatibility(&param.name, value, &param.param_type)
        {
            issues.push(issue);
        }
    }
}
/// Flags that select an output format rather than configuring the tool itself.
///
/// Each entry pairs a flag name (without leading dashes) with the format it selects.
/// [`extract_output_format`] strips these flags from a tool's arguments.
const GLOBAL_OUTPUT_FLAGS: &[(&str, OutputFormat)] = &[("json", OutputFormat::Json)];

/// Returns `true` when `name` exactly equals one of the [`GLOBAL_OUTPUT_FLAGS`] names.
///
/// The comparison is exact, so callers must strip leading dashes first
/// (`"json"` matches, `"--json"` does not).
pub fn is_global_output_flag(name: &str) -> bool {
    GLOBAL_OUTPUT_FLAGS
        .iter()
        .map(|&(flag, _)| flag)
        .any(|flag| flag == name)
}
/// Removes the first global output flag present in `args.flags` and returns the
/// format it selects.
///
/// Only tools that have a schema take part. When `schema` is `None`, this returns
/// `None` without touching `args`. Entries of [`GLOBAL_OUTPUT_FLAGS`] are tried in
/// order, and only the first one found is removed. If none is present, `args` is
/// left unchanged and `None` is returned.
pub fn extract_output_format(
    args: &mut ToolArgs,
    schema: Option<&ToolSchema>,
) -> Option<OutputFormat> {
    if schema.is_none() {
        return None;
    }
    GLOBAL_OUTPUT_FLAGS
        .iter()
        .find(|(flag_name, _)| args.flags.remove(*flag_name))
        .map(|&(_, format)| format)
}
/// Checks whether `value` may be passed to the parameter `name` declared as
/// `expected_type`.
///
/// Returns `None` when the value is acceptable. Otherwise returns an
/// `InvalidArgType` issue using that code's default severity.
///
/// String values are accepted for every type name handled here. Presumably they
/// are converted later; this is not visible in this function, so TODO confirm.
/// `"any"`, `"string"` and unrecognised type names accept every value.
fn check_type_compatibility(
    name: &str,
    value: &Value,
    expected_type: &str,
) -> Option<ValidationIssue> {
    let accepted = match expected_type {
        "int" => matches!(value, Value::Int(_) | Value::String(_)),
        "float" => matches!(value, Value::Float(_) | Value::Int(_) | Value::String(_)),
        "bool" => matches!(value, Value::Bool(_) | Value::String(_)),
        "array" | "object" => matches!(value, Value::String(_)),
        // "any", "string", and any unknown type name impose no constraint.
        _ => true,
    };
    if accepted {
        return None;
    }

    let code = IssueCode::InvalidArgType;
    Some(ValidationIssue {
        severity: code.default_severity(),
        code,
        message: format!("argument '{name}' has type {value:?}, expected {expected_type}"),
        span: None,
        suggestion: None,
    })
}