// ares/llm/capabilities.rs

//! Model Capabilities for DIR-43
//!
//! This module provides capability detection and matching for LLM models,
//! enabling intelligent model selection based on task requirements.
//!
//! # Example
//!
//! ```rust,ignore
//! use ares_server::llm::{ModelCapabilities, CapabilityRequirements, ProviderRegistry};
//!
//! // Define what capabilities the task needs
//! let requirements = CapabilityRequirements::builder()
//!     .requires_tools()
//!     .min_context_window(32_000)
//!     .build();
//!
//! // Find a matching model
//! let model = registry.find_model(&requirements)?;
//! ```

21use serde::{Deserialize, Serialize};
22use std::collections::HashSet;
23
24/// Capabilities that an LLM model may support.
25///
26/// These are used for intelligent model selection based on task requirements.
27#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
28pub struct ModelCapabilities {
29    /// Whether the model supports tool/function calling
30    #[serde(default)]
31    pub supports_tools: bool,
32
33    /// Whether the model supports vision/image inputs
34    #[serde(default)]
35    pub supports_vision: bool,
36
37    /// Whether the model supports audio inputs
38    #[serde(default)]
39    pub supports_audio: bool,
40
41    /// Whether the model supports structured output (JSON mode)
42    #[serde(default)]
43    pub supports_json_mode: bool,
44
45    /// Whether the model supports streaming responses
46    #[serde(default = "default_true")]
47    pub supports_streaming: bool,
48
49    /// Whether the model supports system prompts
50    #[serde(default = "default_true")]
51    pub supports_system_prompt: bool,
52
53    /// Maximum context window size in tokens
54    #[serde(default = "default_context_window")]
55    pub context_window: u32,
56
57    /// Maximum output tokens the model can generate
58    #[serde(default = "default_max_output")]
59    pub max_output_tokens: u32,
60
61    /// Whether the model has reasoning/chain-of-thought capabilities
62    #[serde(default)]
63    pub supports_reasoning: bool,
64
65    /// Whether the model supports code execution
66    #[serde(default)]
67    pub supports_code_execution: bool,
68
69    /// Cost tier: "free", "low", "medium", "high", "premium"
70    #[serde(default = "default_cost_tier")]
71    pub cost_tier: String,
72
73    /// Speed tier: "slow", "medium", "fast", "realtime"
74    #[serde(default = "default_speed_tier")]
75    pub speed_tier: String,
76
77    /// Quality tier: "basic", "standard", "high", "premium"
78    #[serde(default = "default_quality_tier")]
79    pub quality_tier: String,
80
81    /// Supported languages (empty = all languages)
82    #[serde(default)]
83    pub languages: HashSet<String>,
84
85    /// Model family (e.g., "gpt-4", "claude-3", "llama-3")
86    #[serde(default)]
87    pub family: Option<String>,
88
89    /// Whether this model is suitable for production use
90    #[serde(default = "default_true")]
91    pub production_ready: bool,
92
93    /// Whether this model runs locally (no API calls)
94    #[serde(default)]
95    pub is_local: bool,
96
97    /// Custom capability tags for extension
98    #[serde(default)]
99    pub tags: HashSet<String>,
100}
101
/// Serde default helper for capability flags that are on unless disabled.
fn default_true() -> bool { true }
105
/// Serde default: a conservative 4k-token context window.
fn default_context_window() -> u32 { 4096 }
109
/// Serde default: a conservative 4k-token output budget.
fn default_max_output() -> u32 { 4096 }
113
/// Serde default: assume mid-range cost when a model doesn't say.
fn default_cost_tier() -> String { String::from("medium") }
117
/// Serde default: assume mid-range speed when a model doesn't say.
fn default_speed_tier() -> String { String::from("medium") }
121
/// Serde default: assume "standard" quality when a model doesn't say.
fn default_quality_tier() -> String { String::from("standard") }
125
impl ModelCapabilities {
    /// Create capabilities for a known model by name.
    ///
    /// This provides sensible defaults for popular models. Matching is
    /// case-insensitive substring matching, so versioned identifiers such as
    /// "claude-3-5-sonnet-20241022" hit the right arm. Arm ORDER matters:
    /// more specific patterns (e.g. "gpt-4-turbo") are checked before generic
    /// ones ("gpt-4"). Unrecognized names fall through to `Self::default()`.
    pub fn for_model(model_name: &str) -> Self {
        let model_lower = model_name.to_lowercase();

        // Claude models
        if model_lower.contains("claude-3-5-sonnet") || model_lower.contains("claude-sonnet-4") {
            return Self {
                supports_tools: true,
                supports_vision: true,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: 200_000,
                max_output_tokens: 8192,
                supports_reasoning: true,
                cost_tier: "high".to_string(),
                speed_tier: "fast".to_string(),
                quality_tier: "premium".to_string(),
                family: Some("claude-3".to_string()),
                production_ready: true,
                ..Default::default()
            };
        }

        if model_lower.contains("claude-3-opus") || model_lower.contains("claude-opus") {
            return Self {
                supports_tools: true,
                supports_vision: true,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: 200_000,
                max_output_tokens: 4096,
                supports_reasoning: true,
                cost_tier: "premium".to_string(),
                speed_tier: "slow".to_string(),
                quality_tier: "premium".to_string(),
                family: Some("claude-3".to_string()),
                production_ready: true,
                ..Default::default()
            };
        }

        if model_lower.contains("claude-3-haiku") || model_lower.contains("claude-haiku") {
            return Self {
                supports_tools: true,
                supports_vision: true,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: 200_000,
                max_output_tokens: 4096,
                supports_reasoning: false,
                cost_tier: "low".to_string(),
                speed_tier: "realtime".to_string(),
                quality_tier: "standard".to_string(),
                family: Some("claude-3".to_string()),
                production_ready: true,
                ..Default::default()
            };
        }

        // GPT models
        if model_lower.contains("gpt-4o") {
            return Self {
                supports_tools: true,
                supports_vision: true,
                supports_audio: true,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: 128_000,
                max_output_tokens: 16384,
                supports_reasoning: true,
                cost_tier: "high".to_string(),
                speed_tier: "fast".to_string(),
                quality_tier: "premium".to_string(),
                family: Some("gpt-4".to_string()),
                production_ready: true,
                ..Default::default()
            };
        }

        if model_lower.contains("gpt-4-turbo") || model_lower.contains("gpt-4-1106") {
            return Self {
                supports_tools: true,
                supports_vision: true,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: 128_000,
                max_output_tokens: 4096,
                supports_reasoning: true,
                cost_tier: "high".to_string(),
                speed_tier: "medium".to_string(),
                quality_tier: "premium".to_string(),
                family: Some("gpt-4".to_string()),
                production_ready: true,
                ..Default::default()
            };
        }

        // Generic GPT-4 (base 8k context). "gpt-4-turbo" was already claimed
        // by the arm above, so the guard here only needs to exclude "gpt-4o".
        if model_lower.contains("gpt-4") && !model_lower.contains("gpt-4o") {
            return Self {
                supports_tools: true,
                supports_vision: false,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: 8192,
                max_output_tokens: 4096,
                supports_reasoning: true,
                cost_tier: "high".to_string(),
                speed_tier: "slow".to_string(),
                quality_tier: "premium".to_string(),
                family: Some("gpt-4".to_string()),
                production_ready: true,
                ..Default::default()
            };
        }

        if model_lower.contains("gpt-3.5") {
            return Self {
                supports_tools: true,
                supports_vision: false,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: 16385,
                max_output_tokens: 4096,
                supports_reasoning: false,
                cost_tier: "low".to_string(),
                speed_tier: "fast".to_string(),
                quality_tier: "standard".to_string(),
                family: Some("gpt-3.5".to_string()),
                production_ready: true,
                ..Default::default()
            };
        }

        // Llama models
        if model_lower.contains("llama-3.3") || model_lower.contains("llama-3.1") {
            // NOTE(review): the 70b branch gets the *smaller* context
            // (128_000 vs 131_072 for other sizes) — confirm these two
            // values aren't swapped.
            let context = if model_lower.contains("70b") {
                128_000
            } else {
                131_072
            };
            return Self {
                supports_tools: true,
                supports_vision: false,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: context,
                max_output_tokens: 4096,
                supports_reasoning: model_lower.contains("70b"),
                cost_tier: "free".to_string(),
                speed_tier: "medium".to_string(),
                quality_tier: if model_lower.contains("70b") {
                    "high".to_string()
                } else {
                    "standard".to_string()
                },
                family: Some("llama-3".to_string()),
                production_ready: true,
                is_local: true,
                ..Default::default()
            };
        }

        // Mistral models. "ministral" does NOT contain the substring
        // "mistral" (the extra 'n' breaks it), so both patterns are needed.
        if model_lower.contains("ministral") || model_lower.contains("mistral") {
            return Self {
                supports_tools: true,
                supports_vision: false,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: 32_000,
                max_output_tokens: 4096,
                supports_reasoning: false,
                cost_tier: "low".to_string(),
                speed_tier: "fast".to_string(),
                quality_tier: "standard".to_string(),
                family: Some("mistral".to_string()),
                production_ready: true,
                is_local: true,
                ..Default::default()
            };
        }

        // Qwen models ("-vl" variants are the vision-language ones)
        if model_lower.contains("qwen") {
            let has_vl = model_lower.contains("-vl");
            return Self {
                supports_tools: true,
                supports_vision: has_vl,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: 128_000,
                max_output_tokens: 8192,
                supports_reasoning: model_lower.contains("qwq") || model_lower.contains("235b"),
                cost_tier: "free".to_string(),
                speed_tier: "medium".to_string(),
                quality_tier: "high".to_string(),
                family: Some("qwen".to_string()),
                production_ready: true,
                is_local: true,
                ..Default::default()
            };
        }

        // DeepSeek models (note: not marked is_local, unlike llama/mistral/qwen)
        if model_lower.contains("deepseek") {
            return Self {
                supports_tools: true,
                supports_vision: false,
                supports_json_mode: true,
                supports_streaming: true,
                supports_system_prompt: true,
                context_window: 128_000,
                max_output_tokens: 8192,
                supports_reasoning: model_lower.contains("v3") || model_lower.contains("r1"),
                cost_tier: "low".to_string(),
                speed_tier: "medium".to_string(),
                quality_tier: "high".to_string(),
                family: Some("deepseek".to_string()),
                production_ready: true,
                ..Default::default()
            };
        }

        // Default capabilities for unknown models
        Self::default()
    }

    /// Check if this model satisfies the given requirements.
    ///
    /// Every `requires_*` flag, minimum size, tier bound, required tag and
    /// excluded family must pass; the first failing check short-circuits.
    pub fn satisfies(&self, requirements: &CapabilityRequirements) -> bool {
        // Check boolean requirements
        if requirements.requires_tools && !self.supports_tools {
            return false;
        }
        if requirements.requires_vision && !self.supports_vision {
            return false;
        }
        if requirements.requires_audio && !self.supports_audio {
            return false;
        }
        if requirements.requires_json_mode && !self.supports_json_mode {
            return false;
        }
        if requirements.requires_streaming && !self.supports_streaming {
            return false;
        }
        if requirements.requires_reasoning && !self.supports_reasoning {
            return false;
        }
        if requirements.requires_code_execution && !self.supports_code_execution {
            return false;
        }
        if requirements.requires_local && !self.is_local {
            return false;
        }
        if requirements.requires_production_ready && !self.production_ready {
            return false;
        }

        // Check numeric requirements (None = unconstrained)
        if let Some(min_context) = requirements.min_context_window {
            if self.context_window < min_context {
                return false;
            }
        }
        if let Some(min_output) = requirements.min_output_tokens {
            if self.max_output_tokens < min_output {
                return false;
            }
        }

        // Check tier requirements.
        // Cost: arguments are (requirement = model's tier, actual = the cap),
        // so this passes when the cap ranks at least as high as the model's
        // cost — i.e. the model is within budget.
        if let Some(ref max_cost) = requirements.max_cost_tier {
            if !tier_satisfies(&self.cost_tier, max_cost) {
                return false;
            }
        }
        // Speed: passes when the model's speed ranks >= the required minimum.
        // NOTE(review): this is only correct if tier_satisfies ranks the
        // speed scale slow < medium < fast < realtime — confirm its table.
        if let Some(ref min_speed) = requirements.min_speed_tier {
            if !tier_satisfies(min_speed, &self.speed_tier) {
                return false;
            }
        }
        // Quality: passes when the model's quality ranks >= the minimum.
        if let Some(ref min_quality) = requirements.min_quality_tier {
            if !tier_satisfies(min_quality, &self.quality_tier) {
                return false;
            }
        }

        // Check required tags (all must be present on the model)
        for tag in &requirements.required_tags {
            if !self.tags.contains(tag) {
                return false;
            }
        }

        // Check excluded families (a model with no family is never excluded)
        if let Some(ref family) = self.family {
            if requirements.excluded_families.contains(family) {
                return false;
            }
        }

        true
    }

    /// Calculate a score for how well this model matches requirements.
    ///
    /// Higher score = better match. Used for ranking when multiple models satisfy requirements.
    /// The score is heuristic: headroom over the context minimum, plus fixed
    /// bonuses per tier, locality, and spare capabilities.
    pub fn score(&self, requirements: &CapabilityRequirements) -> u32 {
        let mut score = 0u32;

        // Bonus for exceeding minimum requirements (1 point per 1k tokens of headroom)
        if let Some(min_context) = requirements.min_context_window {
            score += (self.context_window.saturating_sub(min_context)) / 1000;
        }

        // Bonus for speed when not explicitly required
        score += match self.speed_tier.as_str() {
            "realtime" => 40,
            "fast" => 30,
            "medium" => 20,
            "slow" => 10,
            _ => 0,
        };

        // Bonus for quality
        score += match self.quality_tier.as_str() {
            "premium" => 40,
            "high" => 30,
            "standard" => 20,
            "basic" => 10,
            _ => 0,
        };

        // Inverted cost bonus: cheaper tiers earn MORE points, so cheaper
        // models win ties when quality is equal.
        score += match self.cost_tier.as_str() {
            "free" => 50,
            "low" => 40,
            "medium" => 30,
            "high" => 20,
            "premium" => 10,
            _ => 0,
        };

        // Bonus for local models (no network latency/cost)
        if self.is_local {
            score += 20;
        }

        // Bonus for having more capabilities than required
        if self.supports_tools && !requirements.requires_tools {
            score += 5;
        }
        if self.supports_reasoning && !requirements.requires_reasoning {
            score += 10;
        }

        score
    }
}
498
/// Requirements for model capability matching.
///
/// A default-constructed value imposes no constraints; each `requires_*`
/// flag, minimum, cap, tag, or exclusion narrows the candidate set checked
/// by `ModelCapabilities::satisfies`. `Option` fields mean "unconstrained"
/// when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CapabilityRequirements {
    /// Require tool/function calling support
    #[serde(default)]
    pub requires_tools: bool,

    /// Require vision/image input support
    #[serde(default)]
    pub requires_vision: bool,

    /// Require audio input support
    #[serde(default)]
    pub requires_audio: bool,

    /// Require JSON mode/structured output support
    #[serde(default)]
    pub requires_json_mode: bool,

    /// Require streaming response support
    #[serde(default)]
    pub requires_streaming: bool,

    /// Require reasoning/chain-of-thought capabilities
    #[serde(default)]
    pub requires_reasoning: bool,

    /// Require code execution support
    #[serde(default)]
    pub requires_code_execution: bool,

    /// Require the model to run locally
    #[serde(default)]
    pub requires_local: bool,

    /// Require production-ready models only
    #[serde(default)]
    pub requires_production_ready: bool,

    /// Minimum context window size (None = no minimum)
    pub min_context_window: Option<u32>,

    /// Minimum output token capacity (None = no minimum)
    pub min_output_tokens: Option<u32>,

    /// Maximum cost tier (e.g., "medium" means free/low/medium are OK)
    pub max_cost_tier: Option<String>,

    /// Minimum speed tier
    pub min_speed_tier: Option<String>,

    /// Minimum quality tier
    pub min_quality_tier: Option<String>,

    /// Required custom tags (all must be present on the model)
    #[serde(default)]
    pub required_tags: HashSet<String>,

    /// Model families to exclude (e.g., exclude GPT for privacy reasons)
    #[serde(default)]
    pub excluded_families: HashSet<String>,
}
561
562impl CapabilityRequirements {
563    /// Create a new builder for capability requirements.
564    pub fn builder() -> CapabilityRequirementsBuilder {
565        CapabilityRequirementsBuilder::default()
566    }
567
568    /// Create requirements for tool-calling agents.
569    pub fn for_agent() -> Self {
570        Self {
571            requires_tools: true,
572            requires_production_ready: true,
573            min_quality_tier: Some("standard".to_string()),
574            ..Default::default()
575        }
576    }
577
578    /// Create requirements for chat/conversation.
579    pub fn for_chat() -> Self {
580        Self {
581            requires_streaming: true,
582            requires_production_ready: true,
583            ..Default::default()
584        }
585    }
586
587    /// Create requirements for code generation.
588    pub fn for_coding() -> Self {
589        Self {
590            requires_tools: true,
591            requires_reasoning: true,
592            min_context_window: Some(32_000),
593            min_quality_tier: Some("high".to_string()),
594            ..Default::default()
595        }
596    }
597
598    /// Create requirements for vision tasks.
599    pub fn for_vision() -> Self {
600        Self {
601            requires_vision: true,
602            requires_production_ready: true,
603            ..Default::default()
604        }
605    }
606
607    /// Create requirements for local-only inference.
608    pub fn for_local() -> Self {
609        Self {
610            requires_local: true,
611            max_cost_tier: Some("free".to_string()),
612            ..Default::default()
613        }
614    }
615}
616
/// Builder for CapabilityRequirements.
///
/// Wraps a `CapabilityRequirements` that starts unconstrained; each chained
/// method tightens one constraint, and `build` releases the finished value.
#[derive(Debug, Default)]
pub struct CapabilityRequirementsBuilder {
    // The requirement set being assembled; consumed by `build`.
    inner: CapabilityRequirements,
}
622
623impl CapabilityRequirementsBuilder {
624    /// Require tool/function calling support.
625    pub fn requires_tools(mut self) -> Self {
626        self.inner.requires_tools = true;
627        self
628    }
629
630    /// Require vision/image input support.
631    pub fn requires_vision(mut self) -> Self {
632        self.inner.requires_vision = true;
633        self
634    }
635
636    /// Require audio input support.
637    pub fn requires_audio(mut self) -> Self {
638        self.inner.requires_audio = true;
639        self
640    }
641
642    /// Require JSON mode support.
643    pub fn requires_json_mode(mut self) -> Self {
644        self.inner.requires_json_mode = true;
645        self
646    }
647
648    /// Require streaming support.
649    pub fn requires_streaming(mut self) -> Self {
650        self.inner.requires_streaming = true;
651        self
652    }
653
654    /// Require reasoning capabilities.
655    pub fn requires_reasoning(mut self) -> Self {
656        self.inner.requires_reasoning = true;
657        self
658    }
659
660    /// Require code execution support.
661    pub fn requires_code_execution(mut self) -> Self {
662        self.inner.requires_code_execution = true;
663        self
664    }
665
666    /// Require local-only models.
667    pub fn requires_local(mut self) -> Self {
668        self.inner.requires_local = true;
669        self
670    }
671
672    /// Require production-ready models.
673    pub fn requires_production_ready(mut self) -> Self {
674        self.inner.requires_production_ready = true;
675        self
676    }
677
678    /// Set minimum context window size.
679    pub fn min_context_window(mut self, tokens: u32) -> Self {
680        self.inner.min_context_window = Some(tokens);
681        self
682    }
683
684    /// Set minimum output token capacity.
685    pub fn min_output_tokens(mut self, tokens: u32) -> Self {
686        self.inner.min_output_tokens = Some(tokens);
687        self
688    }
689
690    /// Set maximum cost tier.
691    pub fn max_cost_tier(mut self, tier: impl Into<String>) -> Self {
692        self.inner.max_cost_tier = Some(tier.into());
693        self
694    }
695
696    /// Set minimum speed tier.
697    pub fn min_speed_tier(mut self, tier: impl Into<String>) -> Self {
698        self.inner.min_speed_tier = Some(tier.into());
699        self
700    }
701
702    /// Set minimum quality tier.
703    pub fn min_quality_tier(mut self, tier: impl Into<String>) -> Self {
704        self.inner.min_quality_tier = Some(tier.into());
705        self
706    }
707
708    /// Add a required tag.
709    pub fn require_tag(mut self, tag: impl Into<String>) -> Self {
710        self.inner.required_tags.insert(tag.into());
711        self
712    }
713
714    /// Exclude a model family.
715    pub fn exclude_family(mut self, family: impl Into<String>) -> Self {
716        self.inner.excluded_families.insert(family.into());
717        self
718    }
719
720    /// Build the requirements.
721    pub fn build(self) -> CapabilityRequirements {
722        self.inner
723    }
724}
725
/// Check whether `actual` meets or exceeds the tier named by `requirement`.
///
/// Three independent tier scales share this helper; within each scale, later
/// entries outrank earlier ones:
/// - cost:    free < low < medium < high < premium
/// - speed:   slow < medium < fast < realtime
/// - quality: basic < standard < high < premium
///
/// A single comparison only ever mixes names from one scale, so one rank
/// table can serve all three as long as each scale's internal order is
/// preserved. The previous table ranked the speed scale backwards
/// (realtime=0 … slow=3), which made `min_speed_tier` checks reject fast
/// models and accept slow ones.
fn tier_satisfies(requirement: &str, actual: &str) -> bool {
    // Ranks are scale-consistent: cost keeps its original 0..=4 order,
    // quality keeps basic<standard<high<premium, and speed is now ordered
    // slow < medium < fast < realtime.
    fn rank(tier: &str) -> u8 {
        match tier.to_lowercase().as_str() {
            // Lowest rung of each scale.
            "free" | "basic" | "slow" => 0,
            "low" | "standard" => 1,
            "medium" => 2,
            // "fast" and "high" never appear in the same comparison;
            // sharing a rank is harmless.
            "high" | "fast" => 3,
            "premium" | "realtime" => 4,
            // Unknown tiers are treated as middle-of-the-road.
            _ => 2,
        }
    }

    rank(actual) >= rank(requirement)
}
741
/// Model with its capabilities for registry storage.
///
/// Bundles the naming/provider metadata a registry keeps for a model with
/// the `ModelCapabilities` used during selection.
#[derive(Debug, Clone)]
pub struct ModelWithCapabilities {
    /// Model configuration name
    pub name: String,
    /// Provider name
    pub provider: String,
    /// Model identifier for the provider
    pub model_id: String,
    /// Model capabilities
    pub capabilities: ModelCapabilities,
}
754
#[cfg(test)]
mod tests {
    use super::*;

    // All model fixtures below come from `ModelCapabilities::for_model`'s
    // hard-coded arms, so these tests pin the static capability tables.

    #[test]
    fn test_claude_capabilities() {
        let caps = ModelCapabilities::for_model("claude-3-5-sonnet-20241022");
        assert!(caps.supports_tools);
        assert!(caps.supports_vision);
        assert_eq!(caps.context_window, 200_000);
        assert_eq!(caps.quality_tier, "premium");
    }

    #[test]
    fn test_gpt4o_capabilities() {
        let caps = ModelCapabilities::for_model("gpt-4o-2024-08-06");
        assert!(caps.supports_tools);
        assert!(caps.supports_vision);
        assert!(caps.supports_audio);
        assert_eq!(caps.context_window, 128_000);
    }

    #[test]
    fn test_llama_capabilities() {
        let caps = ModelCapabilities::for_model("llama-3.3-70b-instruct");
        assert!(caps.supports_tools);
        assert!(!caps.supports_vision);
        assert!(caps.is_local);
        assert_eq!(caps.cost_tier, "free");
    }

    #[test]
    fn test_requirements_builder() {
        let reqs = CapabilityRequirements::builder()
            .requires_tools()
            .requires_vision()
            .min_context_window(100_000)
            .max_cost_tier("high")
            .build();

        assert!(reqs.requires_tools);
        assert!(reqs.requires_vision);
        assert_eq!(reqs.min_context_window, Some(100_000));
        assert_eq!(reqs.max_cost_tier, Some("high".to_string()));
    }

    #[test]
    fn test_capability_matching() {
        let claude = ModelCapabilities::for_model("claude-3-5-sonnet-20241022");
        let gpt35 = ModelCapabilities::for_model("gpt-3.5-turbo");

        let vision_reqs = CapabilityRequirements::builder().requires_vision().build();

        // Claude 3.5 Sonnet has vision; GPT-3.5 does not.
        assert!(claude.satisfies(&vision_reqs));
        assert!(!gpt35.satisfies(&vision_reqs));
    }

    #[test]
    fn test_context_window_matching() {
        let claude = ModelCapabilities::for_model("claude-3-5-sonnet-20241022");
        let gpt4 = ModelCapabilities::for_model("gpt-4");

        let long_context_reqs = CapabilityRequirements::builder()
            .min_context_window(100_000)
            .build();

        assert!(claude.satisfies(&long_context_reqs));
        assert!(!gpt4.satisfies(&long_context_reqs)); // gpt-4 base has 8k context
    }

    #[test]
    fn test_local_model_matching() {
        let llama = ModelCapabilities::for_model("llama-3.3-70b-instruct");
        let claude = ModelCapabilities::for_model("claude-3-5-sonnet-20241022");

        let local_reqs = CapabilityRequirements::for_local();

        // Only llama is flagged is_local in the capability tables.
        assert!(llama.satisfies(&local_reqs));
        assert!(!claude.satisfies(&local_reqs));
    }

    #[test]
    fn test_scoring() {
        let claude = ModelCapabilities::for_model("claude-3-5-sonnet-20241022");
        let haiku = ModelCapabilities::for_model("claude-3-haiku-20240307");
        let llama = ModelCapabilities::for_model("llama-3.3-70b-instruct");

        let basic_reqs = CapabilityRequirements::builder().requires_tools().build();

        // All satisfy the basic requirements
        assert!(claude.satisfies(&basic_reqs));
        assert!(haiku.satisfies(&basic_reqs));
        assert!(llama.satisfies(&basic_reqs));

        // Llama should score higher due to being free + local
        let claude_score = claude.score(&basic_reqs);
        let llama_score = llama.score(&basic_reqs);
        assert!(
            llama_score > claude_score,
            "Llama (free, local) should score higher than Claude (high cost)"
        );

        // Haiku and Claude scores depend on the scoring algorithm's weight factors
        let haiku_score = haiku.score(&basic_reqs);
        // Both should produce valid non-zero scores
        assert!(haiku_score > 0, "Haiku should have a positive score");
        assert!(claude_score > 0, "Claude should have a positive score");
    }

    #[test]
    fn test_preset_requirements() {
        let agent_reqs = CapabilityRequirements::for_agent();
        assert!(agent_reqs.requires_tools);
        assert!(agent_reqs.requires_production_ready);

        let coding_reqs = CapabilityRequirements::for_coding();
        assert!(coding_reqs.requires_tools);
        assert!(coding_reqs.requires_reasoning);
        assert_eq!(coding_reqs.min_context_window, Some(32_000));

        let vision_reqs = CapabilityRequirements::for_vision();
        assert!(vision_reqs.requires_vision);
    }

    #[test]
    fn test_tier_comparison() {
        // Cost and quality scale comparisons (argument order: requirement, actual).
        assert!(tier_satisfies("low", "medium")); // medium >= low
        assert!(tier_satisfies("medium", "high")); // high >= medium
        assert!(!tier_satisfies("high", "low")); // low < high
        assert!(tier_satisfies("standard", "premium")); // premium >= standard
    }
}