use super::types::{InputType, OutputType};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Identifies a reasoning approach that an agent can advertise.
///
/// This file only declares the variants and stores them in
/// `AgentCapabilities::reasoning_strategies`; how each variant is executed is
/// decided by code outside this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ReasoningStrategy {
/// The variant returned by `ReasoningStrategy::default()`.
#[default]
Direct,
/// ReAct-style strategy carrying an iteration setting.
ReAct {
/// NOTE(review): presumably an upper bound on reasoning/acting iterations —
/// confirm against the code that consumes this value.
max_iterations: usize,
},
/// Chain-of-thought strategy; carries no parameters.
ChainOfThought,
/// Tree-of-thought strategy carrying a branching setting.
TreeOfThought {
/// NOTE(review): presumably the number of branches explored per step —
/// confirm against the code that consumes this value.
branching_factor: usize,
},
/// Escape hatch for strategies not covered above, identified by a free-form string.
Custom(String),
}
/// Describes what an agent can do: its tags, accepted and produced data
/// types, and a set of feature flags.
///
/// The derived `Default` yields empty collections, `None` for the context
/// length and `false` for every flag.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentCapabilities {
/// Free-form labels; compared against required and preferred tags when matching.
pub tags: HashSet<String>,
/// Input kinds the agent accepts.
pub input_types: HashSet<InputType>,
/// Output kinds the agent can produce.
pub output_types: HashSet<OutputType>,
/// Optional context-length limit; `None` when unspecified.
/// NOTE(review): unit (tokens? characters?) is not established in this file — confirm.
pub max_context_length: Option<usize>,
/// Advertised reasoning strategies. Stored as a `Vec`, so duplicates are possible.
pub reasoning_strategies: Vec<ReasoningStrategy>,
/// Checked against `AgentRequirements::requires_streaming`.
pub supports_streaming: bool,
/// Checked against `AgentRequirements::requires_conversation`.
pub supports_conversation: bool,
/// Checked against `AgentRequirements::requires_tools`.
pub supports_tools: bool,
/// Checked against `AgentRequirements::requires_coordination`.
pub supports_coordination: bool,
/// Arbitrary extra key/value data; not consulted by the matching logic in this file.
pub custom: HashMap<String, serde_json::Value>,
}
impl AgentCapabilities {
    /// Creates an empty capability set (identical to `Self::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a fresh [`AgentCapabilitiesBuilder`] for fluent construction.
    pub fn builder() -> AgentCapabilitiesBuilder {
        AgentCapabilitiesBuilder::default()
    }

    /// Returns `true` when `tag` is present in [`Self::tags`].
    pub fn has_tag(&self, tag: &str) -> bool {
        self.tags.contains(tag)
    }

    /// Returns `true` when `input_type` is one of the accepted input kinds.
    pub fn supports_input(&self, input_type: &InputType) -> bool {
        self.input_types.contains(input_type)
    }

    /// Returns `true` when `output_type` is one of the producible output kinds.
    pub fn supports_output(&self, output_type: &OutputType) -> bool {
        self.output_types.contains(output_type)
    }

    /// Checks whether these capabilities satisfy every hard requirement.
    ///
    /// A match requires that:
    /// * every required tag is present in `self.tags`,
    /// * every required input/output type is supported, and
    /// * each `requires_*` flag that is set has its `supports_*` counterpart set.
    ///
    /// Preferred tags are ignored here; they only influence [`Self::match_score`].
    pub fn matches(&self, requirements: &AgentRequirements) -> bool {
        // A feature flag is satisfied unless it is required but unsupported.
        let satisfied = |required: bool, supported: bool| !required || supported;

        requirements.required_tags.is_subset(&self.tags)
            && requirements.input_types.is_subset(&self.input_types)
            && requirements.output_types.is_subset(&self.output_types)
            && satisfied(requirements.requires_streaming, self.supports_streaming)
            && satisfied(requirements.requires_tools, self.supports_tools)
            && satisfied(
                requirements.requires_conversation,
                self.supports_conversation,
            )
            && satisfied(
                requirements.requires_coordination,
                self.supports_coordination,
            )
    }

    /// Scores how well these capabilities fit `requirements`.
    ///
    /// Returns `0.0` when [`Self::matches`] fails. Otherwise the result is a
    /// weighted average of the following components:
    ///
    /// | component                      | weight | value                         |
    /// |--------------------------------|--------|-------------------------------|
    /// | required tags                  | 1.0    | always 1.0 (see below)        |
    /// | preferred tags (if any given)  | 0.5    | fraction of preferred present |
    /// | streaming supported            | 0.1    | 1.0                           |
    /// | tools supported                | 0.1    | 1.0                           |
    ///
    /// The streaming and tools bonuses are applied whenever the capability is
    /// present, regardless of whether the requirements ask for it.
    pub fn match_score(&self, requirements: &AgentRequirements) -> f64 {
        if !self.matches(requirements) {
            return 0.0;
        }

        // Past the `matches` gate every required tag is known to be present,
        // so the required-tag component is always a full 1.0 of weight 1.0;
        // recomputing the ratio here would be dead code.
        let mut score = 1.0;
        let mut weight = 1.0;

        if !requirements.preferred_tags.is_empty() {
            let preferred_hits = requirements
                .preferred_tags
                .iter()
                .filter(|tag| self.tags.contains(*tag))
                .count();
            weight += 0.5;
            score += 0.5 * (preferred_hits as f64 / requirements.preferred_tags.len() as f64);
        }

        if self.supports_streaming {
            score += 0.1;
            weight += 0.1;
        }

        if self.supports_tools {
            score += 0.1;
            weight += 0.1;
        }

        score / weight
    }
}
/// Fluent builder that accumulates an `AgentCapabilities` value.
///
/// The derived `Default` starts from `AgentCapabilities::default()`.
#[derive(Debug, Default)]
pub struct AgentCapabilitiesBuilder {
// The value under construction; returned as-is by `build`.
capabilities: AgentCapabilities,
}
impl AgentCapabilitiesBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn tag(mut self, tag: impl Into<String>) -> Self {
self.capabilities.tags.insert(tag.into());
self
}
pub fn with_tag(self, tag: impl Into<String>) -> Self {
self.tag(tag)
}
pub fn tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
for tag in tags {
self.capabilities.tags.insert(tag.into());
}
self
}
pub fn input_type(mut self, input_type: InputType) -> Self {
self.capabilities.input_types.insert(input_type);
self
}
pub fn with_input_type(self, input_type: InputType) -> Self {
self.input_type(input_type)
}
pub fn output_type(mut self, output_type: OutputType) -> Self {
self.capabilities.output_types.insert(output_type);
self
}
pub fn with_output_type(self, output_type: OutputType) -> Self {
self.output_type(output_type)
}
pub fn max_context_length(mut self, length: usize) -> Self {
self.capabilities.max_context_length = Some(length);
self
}
pub fn reasoning_strategy(mut self, strategy: ReasoningStrategy) -> Self {
self.capabilities.reasoning_strategies.push(strategy);
self
}
pub fn with_reasoning_strategy(self, strategy: ReasoningStrategy) -> Self {
self.reasoning_strategy(strategy)
}
pub fn supports_streaming(mut self, supports: bool) -> Self {
self.capabilities.supports_streaming = supports;
self
}
pub fn supports_conversation(mut self, supports: bool) -> Self {
self.capabilities.supports_conversation = supports;
self
}
pub fn supports_tools(mut self, supports: bool) -> Self {
self.capabilities.supports_tools = supports;
self
}
pub fn supports_coordination(mut self, supports: bool) -> Self {
self.capabilities.supports_coordination = supports;
self
}
pub fn custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.capabilities.custom.insert(key.into(), value);
self
}
pub fn build(self) -> AgentCapabilities {
self.capabilities
}
}
/// Describes what a caller needs from an agent.
///
/// Everything except `preferred_tags` is a hard constraint checked by
/// `matches`; `preferred_tags` only affects scoring. The derived `Default`
/// yields no constraints at all.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentRequirements {
/// Tags the agent must carry.
pub required_tags: HashSet<String>,
/// Tags that raise the score when present but never cause a mismatch.
pub preferred_tags: HashSet<String>,
/// Input kinds the agent must accept.
pub input_types: HashSet<InputType>,
/// Output kinds the agent must be able to produce.
pub output_types: HashSet<OutputType>,
/// When set, the agent must have `supports_streaming`.
pub requires_streaming: bool,
/// When set, the agent must have `supports_tools`.
pub requires_tools: bool,
/// When set, the agent must have `supports_conversation`.
pub requires_conversation: bool,
/// When set, the agent must have `supports_coordination`.
pub requires_coordination: bool,
}
impl AgentRequirements {
    /// Creates an empty requirement set that every agent satisfies.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a fresh [`AgentRequirementsBuilder`] for fluent construction.
    pub fn builder() -> AgentRequirementsBuilder {
        AgentRequirementsBuilder::default()
    }

    /// Checks whether `capabilities` satisfies every hard requirement.
    ///
    /// This is the mirror image of [`AgentCapabilities::matches`] and simply
    /// delegates to it, so the two entry points share one rule set and cannot
    /// drift apart. Preferred tags are not considered.
    pub fn matches(&self, capabilities: &AgentCapabilities) -> bool {
        capabilities.matches(self)
    }

    /// Scores `capabilities` against these requirements.
    ///
    /// * `0.0` when the hard requirements are not met.
    /// * `1.0` when they are met and no preferred tags were specified.
    /// * `1.0 + fraction of preferred tags present` otherwise, i.e. a value in
    ///   `1.0..=2.0`.
    ///
    /// Note that this uses a different scale (and `f32`) than
    /// [`AgentCapabilities::match_score`]; the two results are not comparable.
    pub fn score(&self, capabilities: &AgentCapabilities) -> f32 {
        if !self.matches(capabilities) {
            return 0.0;
        }

        // Without preferences there is nothing to add on top of the base score,
        // and this guard also keeps the division below away from a zero length.
        if self.preferred_tags.is_empty() {
            return 1.0;
        }

        let preferred_hits = self
            .preferred_tags
            .iter()
            .filter(|tag| capabilities.tags.contains(*tag))
            .count();

        1.0 + (preferred_hits as f32) / (self.preferred_tags.len() as f32)
    }
}
/// Fluent builder that accumulates an `AgentRequirements` value.
///
/// The derived `Default` starts from `AgentRequirements::default()`.
#[derive(Debug, Default)]
pub struct AgentRequirementsBuilder {
// The value under construction; returned as-is by `build`.
requirements: AgentRequirements,
}
impl AgentRequirementsBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn require_tag(mut self, tag: impl Into<String>) -> Self {
self.requirements.required_tags.insert(tag.into());
self
}
pub fn prefer_tag(mut self, tag: impl Into<String>) -> Self {
self.requirements.preferred_tags.insert(tag.into());
self
}
pub fn require_input(mut self, input_type: InputType) -> Self {
self.requirements.input_types.insert(input_type);
self
}
pub fn require_output(mut self, output_type: OutputType) -> Self {
self.requirements.output_types.insert(output_type);
self
}
pub fn require_streaming(mut self) -> Self {
self.requirements.requires_streaming = true;
self
}
pub fn require_tools(mut self) -> Self {
self.requirements.requires_tools = true;
self
}
pub fn require_conversation(mut self) -> Self {
self.requirements.requires_conversation = true;
self
}
pub fn require_coordination(mut self) -> Self {
self.requirements.requires_coordination = true;
self
}
pub fn build(self) -> AgentRequirements {
self.requirements
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records tags, input types and feature flags.
    #[test]
    fn test_capabilities_builder() {
        let built = AgentCapabilities::builder()
            .tag("llm")
            .tag("coding")
            .input_type(InputType::Text)
            .output_type(OutputType::Text)
            .supports_streaming(true)
            .supports_tools(true)
            .build();

        assert!(built.has_tag("llm"));
        assert!(built.has_tag("coding"));
        assert!(built.supports_input(&InputType::Text));
        assert!(built.supports_streaming);
        assert!(built.supports_tools);
    }

    /// Capabilities that are a superset of the requirements match.
    #[test]
    fn test_capabilities_matching() {
        let wanted = AgentRequirements::builder()
            .require_tag("llm")
            .require_input(InputType::Text)
            .require_tools()
            .build();

        let offered = AgentCapabilities::builder()
            .tag("llm")
            .tag("coding")
            .input_type(InputType::Text)
            .output_type(OutputType::Text)
            .supports_tools(true)
            .build();

        assert!(offered.matches(&wanted));
    }

    /// A missing required tag causes a mismatch.
    #[test]
    fn test_capabilities_mismatch() {
        let wanted = AgentRequirements::builder()
            .require_tag("coding")
            .build();

        let offered = AgentCapabilities::builder()
            .tag("llm")
            .input_type(InputType::Text)
            .build();

        assert!(!offered.matches(&wanted));
    }

    /// All required and preferred tags present yields a high score.
    #[test]
    fn test_match_score() {
        let wanted = AgentRequirements::builder()
            .require_tag("llm")
            .prefer_tag("coding")
            .prefer_tag("research")
            .build();

        let offered = AgentCapabilities::builder()
            .tag("llm")
            .tag("coding")
            .tag("research")
            .supports_streaming(true)
            .supports_tools(true)
            .build();

        let fit = offered.match_score(&wanted);
        assert!(fit > 0.8);
    }
}