mofa_kernel/agent/
capabilities.rs1use super::types::{InputType, OutputType};
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
11pub enum ReasoningStrategy {
12 #[default]
14 Direct,
15 ReAct {
17 max_iterations: usize,
19 },
20 ChainOfThought,
22 TreeOfThought {
24 branching_factor: usize,
26 },
27 Custom(String),
29}
30
31#[derive(Debug, Clone, Default, Serialize, Deserialize)]
35pub struct AgentCapabilities {
36 pub tags: HashSet<String>,
38 pub input_types: HashSet<InputType>,
40 pub output_types: HashSet<OutputType>,
42 pub max_context_length: Option<usize>,
44 pub reasoning_strategies: Vec<ReasoningStrategy>,
46 pub supports_streaming: bool,
48 pub supports_conversation: bool,
50 pub supports_tools: bool,
52 pub supports_coordination: bool,
54 pub custom: HashMap<String, serde_json::Value>,
56}
57
58impl AgentCapabilities {
59 pub fn new() -> Self {
61 Self::default()
62 }
63
64 pub fn builder() -> AgentCapabilitiesBuilder {
66 AgentCapabilitiesBuilder::default()
67 }
68
69 pub fn has_tag(&self, tag: &str) -> bool {
71 self.tags.contains(tag)
72 }
73
74 pub fn supports_input(&self, input_type: &InputType) -> bool {
76 self.input_types.contains(input_type)
77 }
78
79 pub fn supports_output(&self, output_type: &OutputType) -> bool {
81 self.output_types.contains(output_type)
82 }
83
84 pub fn matches(&self, requirements: &AgentRequirements) -> bool {
86 if !requirements
88 .required_tags
89 .iter()
90 .all(|t| self.tags.contains(t))
91 {
92 return false;
93 }
94
95 if !requirements
97 .input_types
98 .iter()
99 .all(|t| self.input_types.contains(t))
100 {
101 return false;
102 }
103
104 if !requirements
106 .output_types
107 .iter()
108 .all(|t| self.output_types.contains(t))
109 {
110 return false;
111 }
112
113 if requirements.requires_streaming && !self.supports_streaming {
115 return false;
116 }
117 if requirements.requires_tools && !self.supports_tools {
118 return false;
119 }
120 if requirements.requires_conversation && !self.supports_conversation {
121 return false;
122 }
123 if requirements.requires_coordination && !self.supports_coordination {
124 return false;
125 }
126
127 true
128 }
129
130 pub fn match_score(&self, requirements: &AgentRequirements) -> f64 {
132 if !self.matches(requirements) {
133 return 0.0;
134 }
135
136 let mut score = 0.0;
137 let mut weight = 0.0;
138
139 weight += 1.0;
141 if !requirements.required_tags.is_empty() {
142 let matched = requirements
143 .required_tags
144 .iter()
145 .filter(|t| self.tags.contains(*t))
146 .count();
147 score += matched as f64 / requirements.required_tags.len() as f64;
148 } else {
149 score += 1.0;
150 }
151
152 if !requirements.preferred_tags.is_empty() {
154 weight += 0.5;
155 let matched = requirements
156 .preferred_tags
157 .iter()
158 .filter(|t| self.tags.contains(*t))
159 .count();
160 score += 0.5 * (matched as f64 / requirements.preferred_tags.len() as f64);
161 }
162
163 if self.supports_streaming {
165 score += 0.1;
166 weight += 0.1;
167 }
168 if self.supports_tools {
169 score += 0.1;
170 weight += 0.1;
171 }
172
173 score / weight
174 }
175}
176
177#[derive(Debug, Default)]
179pub struct AgentCapabilitiesBuilder {
180 capabilities: AgentCapabilities,
181}
182
183impl AgentCapabilitiesBuilder {
184 pub fn new() -> Self {
186 Self::default()
187 }
188
189 pub fn tag(mut self, tag: impl Into<String>) -> Self {
191 self.capabilities.tags.insert(tag.into());
192 self
193 }
194
195 pub fn with_tag(self, tag: impl Into<String>) -> Self {
197 self.tag(tag)
198 }
199
200 pub fn tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
202 for tag in tags {
203 self.capabilities.tags.insert(tag.into());
204 }
205 self
206 }
207
208 pub fn input_type(mut self, input_type: InputType) -> Self {
210 self.capabilities.input_types.insert(input_type);
211 self
212 }
213
214 pub fn with_input_type(self, input_type: InputType) -> Self {
216 self.input_type(input_type)
217 }
218
219 pub fn output_type(mut self, output_type: OutputType) -> Self {
221 self.capabilities.output_types.insert(output_type);
222 self
223 }
224
225 pub fn with_output_type(self, output_type: OutputType) -> Self {
227 self.output_type(output_type)
228 }
229
230 pub fn max_context_length(mut self, length: usize) -> Self {
232 self.capabilities.max_context_length = Some(length);
233 self
234 }
235
236 pub fn reasoning_strategy(mut self, strategy: ReasoningStrategy) -> Self {
238 self.capabilities.reasoning_strategies.push(strategy);
239 self
240 }
241
242 pub fn with_reasoning_strategy(self, strategy: ReasoningStrategy) -> Self {
244 self.reasoning_strategy(strategy)
245 }
246
247 pub fn supports_streaming(mut self, supports: bool) -> Self {
249 self.capabilities.supports_streaming = supports;
250 self
251 }
252
253 pub fn supports_conversation(mut self, supports: bool) -> Self {
255 self.capabilities.supports_conversation = supports;
256 self
257 }
258
259 pub fn supports_tools(mut self, supports: bool) -> Self {
261 self.capabilities.supports_tools = supports;
262 self
263 }
264
265 pub fn supports_coordination(mut self, supports: bool) -> Self {
267 self.capabilities.supports_coordination = supports;
268 self
269 }
270
271 pub fn custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
273 self.capabilities.custom.insert(key.into(), value);
274 self
275 }
276
277 pub fn build(self) -> AgentCapabilities {
279 self.capabilities
280 }
281}
282
283#[derive(Debug, Clone, Default, Serialize, Deserialize)]
287pub struct AgentRequirements {
288 pub required_tags: HashSet<String>,
290 pub preferred_tags: HashSet<String>,
292 pub input_types: HashSet<InputType>,
294 pub output_types: HashSet<OutputType>,
296 pub requires_streaming: bool,
298 pub requires_tools: bool,
300 pub requires_conversation: bool,
302 pub requires_coordination: bool,
304}
305
306impl AgentRequirements {
307 pub fn new() -> Self {
309 Self::default()
310 }
311
312 pub fn builder() -> AgentRequirementsBuilder {
314 AgentRequirementsBuilder::default()
315 }
316
317 pub fn matches(&self, capabilities: &AgentCapabilities) -> bool {
319 for tag in &self.required_tags {
321 if !capabilities.tags.contains(tag) {
322 return false;
323 }
324 }
325
326 for input_type in &self.input_types {
328 if !capabilities.input_types.contains(input_type) {
329 return false;
330 }
331 }
332
333 for output_type in &self.output_types {
335 if !capabilities.output_types.contains(output_type) {
336 return false;
337 }
338 }
339
340 if self.requires_streaming && !capabilities.supports_streaming {
342 return false;
343 }
344
345 if self.requires_tools && !capabilities.supports_tools {
347 return false;
348 }
349
350 if self.requires_conversation && !capabilities.supports_conversation {
352 return false;
353 }
354
355 if self.requires_coordination && !capabilities.supports_coordination {
357 return false;
358 }
359
360 true
361 }
362
363 pub fn score(&self, capabilities: &AgentCapabilities) -> f32 {
365 if !self.matches(capabilities) {
366 return 0.0;
367 }
368
369 let mut score = 1.0;
370
371 let preferred_count = self
373 .preferred_tags
374 .iter()
375 .filter(|tag| capabilities.tags.contains(*tag))
376 .count();
377
378 if !self.preferred_tags.is_empty() {
379 score += (preferred_count as f32) / (self.preferred_tags.len() as f32);
380 }
381
382 score
383 }
384}
385
386#[derive(Debug, Default)]
388pub struct AgentRequirementsBuilder {
389 requirements: AgentRequirements,
390}
391
392impl AgentRequirementsBuilder {
393 pub fn new() -> Self {
395 Self::default()
396 }
397
398 pub fn require_tag(mut self, tag: impl Into<String>) -> Self {
400 self.requirements.required_tags.insert(tag.into());
401 self
402 }
403
404 pub fn prefer_tag(mut self, tag: impl Into<String>) -> Self {
406 self.requirements.preferred_tags.insert(tag.into());
407 self
408 }
409
410 pub fn require_input(mut self, input_type: InputType) -> Self {
412 self.requirements.input_types.insert(input_type);
413 self
414 }
415
416 pub fn require_output(mut self, output_type: OutputType) -> Self {
418 self.requirements.output_types.insert(output_type);
419 self
420 }
421
422 pub fn require_streaming(mut self) -> Self {
424 self.requirements.requires_streaming = true;
425 self
426 }
427
428 pub fn require_tools(mut self) -> Self {
430 self.requirements.requires_tools = true;
431 self
432 }
433
434 pub fn require_conversation(mut self) -> Self {
436 self.requirements.requires_conversation = true;
437 self
438 }
439
440 pub fn require_coordination(mut self) -> Self {
442 self.requirements.requires_coordination = true;
443 self
444 }
445
446 pub fn build(self) -> AgentRequirements {
448 self.requirements
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_capabilities_builder() {
        let caps = AgentCapabilities::builder()
            .tag("llm")
            .tag("coding")
            .input_type(InputType::Text)
            .output_type(OutputType::Text)
            .supports_streaming(true)
            .supports_tools(true)
            .build();

        assert!(caps.has_tag("llm"));
        assert!(caps.has_tag("coding"));
        assert!(caps.supports_input(&InputType::Text));
        assert!(caps.supports_output(&OutputType::Text));
        assert!(caps.supports_streaming);
        assert!(caps.supports_tools);
    }

    #[test]
    fn test_capabilities_matching() {
        let caps = AgentCapabilities::builder()
            .tag("llm")
            .tag("coding")
            .input_type(InputType::Text)
            .output_type(OutputType::Text)
            .supports_tools(true)
            .build();

        let requirements = AgentRequirements::builder()
            .require_tag("llm")
            .require_input(InputType::Text)
            .require_tools()
            .build();

        assert!(caps.matches(&requirements));
    }

    #[test]
    fn test_capabilities_mismatch() {
        let caps = AgentCapabilities::builder()
            .tag("llm")
            .input_type(InputType::Text)
            .build();

        let requirements = AgentRequirements::builder().require_tag("coding").build();

        assert!(!caps.matches(&requirements));
    }

    #[test]
    fn test_output_type_requirement() {
        let caps = AgentCapabilities::builder()
            .input_type(InputType::Text)
            .build();

        let requirements = AgentRequirements::builder()
            .require_output(OutputType::Text)
            .build();

        assert!(!caps.matches(&requirements));
    }

    #[test]
    fn test_flag_requirements_are_enforced() {
        let caps = AgentCapabilities::builder()
            .supports_streaming(true)
            .build();

        assert!(caps.matches(&AgentRequirements::builder().require_streaming().build()));
        assert!(!caps.matches(&AgentRequirements::builder().require_tools().build()));
        assert!(!caps.matches(&AgentRequirements::builder().require_conversation().build()));
        assert!(!caps.matches(&AgentRequirements::builder().require_coordination().build()));
    }

    #[test]
    fn test_both_matches_agree() {
        let caps = AgentCapabilities::builder()
            .tag("llm")
            .input_type(InputType::Text)
            .supports_conversation(true)
            .build();

        let cases = vec![
            AgentRequirements::new(),
            AgentRequirements::builder().require_tag("llm").build(),
            AgentRequirements::builder().require_tag("vision").build(),
            AgentRequirements::builder().require_conversation().build(),
            AgentRequirements::builder().require_coordination().build(),
            AgentRequirements::builder()
                .require_output(OutputType::Text)
                .build(),
        ];

        for requirements in &cases {
            assert_eq!(caps.matches(requirements), requirements.matches(&caps));
        }
    }

    #[test]
    fn test_match_score() {
        let caps = AgentCapabilities::builder()
            .tag("llm")
            .tag("coding")
            .tag("research")
            .supports_streaming(true)
            .supports_tools(true)
            .build();

        let requirements = AgentRequirements::builder()
            .require_tag("llm")
            .prefer_tag("coding")
            .prefer_tag("research")
            .build();

        let score = caps.match_score(&requirements);
        assert!(score > 0.8);
        // Every weighted component is fully earned, so the score is exactly 1.0.
        assert!((score - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_match_score_partial_preferences() {
        let caps = AgentCapabilities::builder().tag("llm").tag("coding").build();

        let requirements = AgentRequirements::builder()
            .require_tag("llm")
            .prefer_tag("coding")
            .prefer_tag("research")
            .build();

        // (1.0 + 0.5 * 1/2) / (1.0 + 0.5)
        let score = caps.match_score(&requirements);
        assert!((score - 1.25 / 1.5).abs() < 1e-9);
    }

    #[test]
    fn test_match_score_zero_on_mismatch() {
        let caps = AgentCapabilities::builder().tag("llm").build();
        let requirements = AgentRequirements::builder().require_tag("coding").build();

        assert_eq!(caps.match_score(&requirements), 0.0);
    }

    #[test]
    fn test_requirements_score() {
        let caps = AgentCapabilities::builder().tag("llm").tag("coding").build();

        let requirements = AgentRequirements::builder()
            .require_tag("llm")
            .prefer_tag("coding")
            .prefer_tag("research")
            .build();
        assert!((requirements.score(&caps) - 1.5).abs() < f32::EPSILON);

        let no_preferences = AgentRequirements::builder().require_tag("llm").build();
        assert!((no_preferences.score(&caps) - 1.0).abs() < f32::EPSILON);

        let unmet = AgentRequirements::builder().require_tag("vision").build();
        assert_eq!(unmet.score(&caps), 0.0);
    }
}