1use serde::{Deserialize, Serialize};
22use std::collections::HashSet;
23
24#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
28pub struct ModelCapabilities {
29 #[serde(default)]
31 pub supports_tools: bool,
32
33 #[serde(default)]
35 pub supports_vision: bool,
36
37 #[serde(default)]
39 pub supports_audio: bool,
40
41 #[serde(default)]
43 pub supports_json_mode: bool,
44
45 #[serde(default = "default_true")]
47 pub supports_streaming: bool,
48
49 #[serde(default = "default_true")]
51 pub supports_system_prompt: bool,
52
53 #[serde(default = "default_context_window")]
55 pub context_window: u32,
56
57 #[serde(default = "default_max_output")]
59 pub max_output_tokens: u32,
60
61 #[serde(default)]
63 pub supports_reasoning: bool,
64
65 #[serde(default)]
67 pub supports_code_execution: bool,
68
69 #[serde(default = "default_cost_tier")]
71 pub cost_tier: String,
72
73 #[serde(default = "default_speed_tier")]
75 pub speed_tier: String,
76
77 #[serde(default = "default_quality_tier")]
79 pub quality_tier: String,
80
81 #[serde(default)]
83 pub languages: HashSet<String>,
84
85 #[serde(default)]
87 pub family: Option<String>,
88
89 #[serde(default = "default_true")]
91 pub production_ready: bool,
92
93 #[serde(default)]
95 pub is_local: bool,
96
97 #[serde(default)]
99 pub tags: HashSet<String>,
100}
101
102fn default_true() -> bool {
103 true
104}
105
106fn default_context_window() -> u32 {
107 4096
108}
109
110fn default_max_output() -> u32 {
111 4096
112}
113
114fn default_cost_tier() -> String {
115 "medium".to_string()
116}
117
118fn default_speed_tier() -> String {
119 "medium".to_string()
120}
121
122fn default_quality_tier() -> String {
123 "standard".to_string()
124}
125
126impl ModelCapabilities {
127 pub fn for_model(model_name: &str) -> Self {
131 let model_lower = model_name.to_lowercase();
132
133 if model_lower.contains("claude-3-5-sonnet") || model_lower.contains("claude-sonnet-4") {
135 return Self {
136 supports_tools: true,
137 supports_vision: true,
138 supports_json_mode: true,
139 supports_streaming: true,
140 supports_system_prompt: true,
141 context_window: 200_000,
142 max_output_tokens: 8192,
143 supports_reasoning: true,
144 cost_tier: "high".to_string(),
145 speed_tier: "fast".to_string(),
146 quality_tier: "premium".to_string(),
147 family: Some("claude-3".to_string()),
148 production_ready: true,
149 ..Default::default()
150 };
151 }
152
153 if model_lower.contains("claude-3-opus") || model_lower.contains("claude-opus") {
154 return Self {
155 supports_tools: true,
156 supports_vision: true,
157 supports_json_mode: true,
158 supports_streaming: true,
159 supports_system_prompt: true,
160 context_window: 200_000,
161 max_output_tokens: 4096,
162 supports_reasoning: true,
163 cost_tier: "premium".to_string(),
164 speed_tier: "slow".to_string(),
165 quality_tier: "premium".to_string(),
166 family: Some("claude-3".to_string()),
167 production_ready: true,
168 ..Default::default()
169 };
170 }
171
172 if model_lower.contains("claude-3-haiku") || model_lower.contains("claude-haiku") {
173 return Self {
174 supports_tools: true,
175 supports_vision: true,
176 supports_json_mode: true,
177 supports_streaming: true,
178 supports_system_prompt: true,
179 context_window: 200_000,
180 max_output_tokens: 4096,
181 supports_reasoning: false,
182 cost_tier: "low".to_string(),
183 speed_tier: "realtime".to_string(),
184 quality_tier: "standard".to_string(),
185 family: Some("claude-3".to_string()),
186 production_ready: true,
187 ..Default::default()
188 };
189 }
190
191 if model_lower.contains("gpt-4o") {
193 return Self {
194 supports_tools: true,
195 supports_vision: true,
196 supports_audio: true,
197 supports_json_mode: true,
198 supports_streaming: true,
199 supports_system_prompt: true,
200 context_window: 128_000,
201 max_output_tokens: 16384,
202 supports_reasoning: true,
203 cost_tier: "high".to_string(),
204 speed_tier: "fast".to_string(),
205 quality_tier: "premium".to_string(),
206 family: Some("gpt-4".to_string()),
207 production_ready: true,
208 ..Default::default()
209 };
210 }
211
212 if model_lower.contains("gpt-4-turbo") || model_lower.contains("gpt-4-1106") {
213 return Self {
214 supports_tools: true,
215 supports_vision: true,
216 supports_json_mode: true,
217 supports_streaming: true,
218 supports_system_prompt: true,
219 context_window: 128_000,
220 max_output_tokens: 4096,
221 supports_reasoning: true,
222 cost_tier: "high".to_string(),
223 speed_tier: "medium".to_string(),
224 quality_tier: "premium".to_string(),
225 family: Some("gpt-4".to_string()),
226 production_ready: true,
227 ..Default::default()
228 };
229 }
230
231 if model_lower.contains("gpt-4") && !model_lower.contains("gpt-4o") {
232 return Self {
233 supports_tools: true,
234 supports_vision: false,
235 supports_json_mode: true,
236 supports_streaming: true,
237 supports_system_prompt: true,
238 context_window: 8192,
239 max_output_tokens: 4096,
240 supports_reasoning: true,
241 cost_tier: "high".to_string(),
242 speed_tier: "slow".to_string(),
243 quality_tier: "premium".to_string(),
244 family: Some("gpt-4".to_string()),
245 production_ready: true,
246 ..Default::default()
247 };
248 }
249
250 if model_lower.contains("gpt-3.5") {
251 return Self {
252 supports_tools: true,
253 supports_vision: false,
254 supports_json_mode: true,
255 supports_streaming: true,
256 supports_system_prompt: true,
257 context_window: 16385,
258 max_output_tokens: 4096,
259 supports_reasoning: false,
260 cost_tier: "low".to_string(),
261 speed_tier: "fast".to_string(),
262 quality_tier: "standard".to_string(),
263 family: Some("gpt-3.5".to_string()),
264 production_ready: true,
265 ..Default::default()
266 };
267 }
268
269 if model_lower.contains("llama-3.3") || model_lower.contains("llama-3.1") {
271 let context = if model_lower.contains("70b") {
272 128_000
273 } else {
274 131_072
275 };
276 return Self {
277 supports_tools: true,
278 supports_vision: false,
279 supports_json_mode: true,
280 supports_streaming: true,
281 supports_system_prompt: true,
282 context_window: context,
283 max_output_tokens: 4096,
284 supports_reasoning: model_lower.contains("70b"),
285 cost_tier: "free".to_string(),
286 speed_tier: "medium".to_string(),
287 quality_tier: if model_lower.contains("70b") {
288 "high".to_string()
289 } else {
290 "standard".to_string()
291 },
292 family: Some("llama-3".to_string()),
293 production_ready: true,
294 is_local: true,
295 ..Default::default()
296 };
297 }
298
299 if model_lower.contains("ministral") || model_lower.contains("mistral") {
301 return Self {
302 supports_tools: true,
303 supports_vision: false,
304 supports_json_mode: true,
305 supports_streaming: true,
306 supports_system_prompt: true,
307 context_window: 32_000,
308 max_output_tokens: 4096,
309 supports_reasoning: false,
310 cost_tier: "low".to_string(),
311 speed_tier: "fast".to_string(),
312 quality_tier: "standard".to_string(),
313 family: Some("mistral".to_string()),
314 production_ready: true,
315 is_local: true,
316 ..Default::default()
317 };
318 }
319
320 if model_lower.contains("qwen") {
322 let has_vl = model_lower.contains("-vl");
323 return Self {
324 supports_tools: true,
325 supports_vision: has_vl,
326 supports_json_mode: true,
327 supports_streaming: true,
328 supports_system_prompt: true,
329 context_window: 128_000,
330 max_output_tokens: 8192,
331 supports_reasoning: model_lower.contains("qwq") || model_lower.contains("235b"),
332 cost_tier: "free".to_string(),
333 speed_tier: "medium".to_string(),
334 quality_tier: "high".to_string(),
335 family: Some("qwen".to_string()),
336 production_ready: true,
337 is_local: true,
338 ..Default::default()
339 };
340 }
341
342 if model_lower.contains("deepseek") {
344 return Self {
345 supports_tools: true,
346 supports_vision: false,
347 supports_json_mode: true,
348 supports_streaming: true,
349 supports_system_prompt: true,
350 context_window: 128_000,
351 max_output_tokens: 8192,
352 supports_reasoning: model_lower.contains("v3") || model_lower.contains("r1"),
353 cost_tier: "low".to_string(),
354 speed_tier: "medium".to_string(),
355 quality_tier: "high".to_string(),
356 family: Some("deepseek".to_string()),
357 production_ready: true,
358 ..Default::default()
359 };
360 }
361
362 Self::default()
364 }
365
366 pub fn satisfies(&self, requirements: &CapabilityRequirements) -> bool {
368 if requirements.requires_tools && !self.supports_tools {
370 return false;
371 }
372 if requirements.requires_vision && !self.supports_vision {
373 return false;
374 }
375 if requirements.requires_audio && !self.supports_audio {
376 return false;
377 }
378 if requirements.requires_json_mode && !self.supports_json_mode {
379 return false;
380 }
381 if requirements.requires_streaming && !self.supports_streaming {
382 return false;
383 }
384 if requirements.requires_reasoning && !self.supports_reasoning {
385 return false;
386 }
387 if requirements.requires_code_execution && !self.supports_code_execution {
388 return false;
389 }
390 if requirements.requires_local && !self.is_local {
391 return false;
392 }
393 if requirements.requires_production_ready && !self.production_ready {
394 return false;
395 }
396
397 if let Some(min_context) = requirements.min_context_window {
399 if self.context_window < min_context {
400 return false;
401 }
402 }
403 if let Some(min_output) = requirements.min_output_tokens {
404 if self.max_output_tokens < min_output {
405 return false;
406 }
407 }
408
409 if let Some(ref max_cost) = requirements.max_cost_tier {
411 if !tier_satisfies(&self.cost_tier, max_cost) {
412 return false;
413 }
414 }
415 if let Some(ref min_speed) = requirements.min_speed_tier {
416 if !tier_satisfies(min_speed, &self.speed_tier) {
417 return false;
418 }
419 }
420 if let Some(ref min_quality) = requirements.min_quality_tier {
421 if !tier_satisfies(min_quality, &self.quality_tier) {
422 return false;
423 }
424 }
425
426 for tag in &requirements.required_tags {
428 if !self.tags.contains(tag) {
429 return false;
430 }
431 }
432
433 if let Some(ref family) = self.family {
435 if requirements.excluded_families.contains(family) {
436 return false;
437 }
438 }
439
440 true
441 }
442
443 pub fn score(&self, requirements: &CapabilityRequirements) -> u32 {
447 let mut score = 0u32;
448
449 if let Some(min_context) = requirements.min_context_window {
451 score += (self.context_window.saturating_sub(min_context)) / 1000;
452 }
453
454 score += match self.speed_tier.as_str() {
456 "realtime" => 40,
457 "fast" => 30,
458 "medium" => 20,
459 "slow" => 10,
460 _ => 0,
461 };
462
463 score += match self.quality_tier.as_str() {
465 "premium" => 40,
466 "high" => 30,
467 "standard" => 20,
468 "basic" => 10,
469 _ => 0,
470 };
471
472 score += match self.cost_tier.as_str() {
474 "free" => 50,
475 "low" => 40,
476 "medium" => 30,
477 "high" => 20,
478 "premium" => 10,
479 _ => 0,
480 };
481
482 if self.is_local {
484 score += 20;
485 }
486
487 if self.supports_tools && !requirements.requires_tools {
489 score += 5;
490 }
491 if self.supports_reasoning && !requirements.requires_reasoning {
492 score += 10;
493 }
494
495 score
496 }
497}
498
499#[derive(Debug, Clone, Default, Serialize, Deserialize)]
501pub struct CapabilityRequirements {
502 #[serde(default)]
504 pub requires_tools: bool,
505
506 #[serde(default)]
508 pub requires_vision: bool,
509
510 #[serde(default)]
512 pub requires_audio: bool,
513
514 #[serde(default)]
516 pub requires_json_mode: bool,
517
518 #[serde(default)]
520 pub requires_streaming: bool,
521
522 #[serde(default)]
524 pub requires_reasoning: bool,
525
526 #[serde(default)]
528 pub requires_code_execution: bool,
529
530 #[serde(default)]
532 pub requires_local: bool,
533
534 #[serde(default)]
536 pub requires_production_ready: bool,
537
538 pub min_context_window: Option<u32>,
540
541 pub min_output_tokens: Option<u32>,
543
544 pub max_cost_tier: Option<String>,
546
547 pub min_speed_tier: Option<String>,
549
550 pub min_quality_tier: Option<String>,
552
553 #[serde(default)]
555 pub required_tags: HashSet<String>,
556
557 #[serde(default)]
559 pub excluded_families: HashSet<String>,
560}
561
562impl CapabilityRequirements {
563 pub fn builder() -> CapabilityRequirementsBuilder {
565 CapabilityRequirementsBuilder::default()
566 }
567
568 pub fn for_agent() -> Self {
570 Self {
571 requires_tools: true,
572 requires_production_ready: true,
573 min_quality_tier: Some("standard".to_string()),
574 ..Default::default()
575 }
576 }
577
578 pub fn for_chat() -> Self {
580 Self {
581 requires_streaming: true,
582 requires_production_ready: true,
583 ..Default::default()
584 }
585 }
586
587 pub fn for_coding() -> Self {
589 Self {
590 requires_tools: true,
591 requires_reasoning: true,
592 min_context_window: Some(32_000),
593 min_quality_tier: Some("high".to_string()),
594 ..Default::default()
595 }
596 }
597
598 pub fn for_vision() -> Self {
600 Self {
601 requires_vision: true,
602 requires_production_ready: true,
603 ..Default::default()
604 }
605 }
606
607 pub fn for_local() -> Self {
609 Self {
610 requires_local: true,
611 max_cost_tier: Some("free".to_string()),
612 ..Default::default()
613 }
614 }
615}
616
617#[derive(Debug, Default)]
619pub struct CapabilityRequirementsBuilder {
620 inner: CapabilityRequirements,
621}
622
623impl CapabilityRequirementsBuilder {
624 pub fn requires_tools(mut self) -> Self {
626 self.inner.requires_tools = true;
627 self
628 }
629
630 pub fn requires_vision(mut self) -> Self {
632 self.inner.requires_vision = true;
633 self
634 }
635
636 pub fn requires_audio(mut self) -> Self {
638 self.inner.requires_audio = true;
639 self
640 }
641
642 pub fn requires_json_mode(mut self) -> Self {
644 self.inner.requires_json_mode = true;
645 self
646 }
647
648 pub fn requires_streaming(mut self) -> Self {
650 self.inner.requires_streaming = true;
651 self
652 }
653
654 pub fn requires_reasoning(mut self) -> Self {
656 self.inner.requires_reasoning = true;
657 self
658 }
659
660 pub fn requires_code_execution(mut self) -> Self {
662 self.inner.requires_code_execution = true;
663 self
664 }
665
666 pub fn requires_local(mut self) -> Self {
668 self.inner.requires_local = true;
669 self
670 }
671
672 pub fn requires_production_ready(mut self) -> Self {
674 self.inner.requires_production_ready = true;
675 self
676 }
677
678 pub fn min_context_window(mut self, tokens: u32) -> Self {
680 self.inner.min_context_window = Some(tokens);
681 self
682 }
683
684 pub fn min_output_tokens(mut self, tokens: u32) -> Self {
686 self.inner.min_output_tokens = Some(tokens);
687 self
688 }
689
690 pub fn max_cost_tier(mut self, tier: impl Into<String>) -> Self {
692 self.inner.max_cost_tier = Some(tier.into());
693 self
694 }
695
696 pub fn min_speed_tier(mut self, tier: impl Into<String>) -> Self {
698 self.inner.min_speed_tier = Some(tier.into());
699 self
700 }
701
702 pub fn min_quality_tier(mut self, tier: impl Into<String>) -> Self {
704 self.inner.min_quality_tier = Some(tier.into());
705 self
706 }
707
708 pub fn require_tag(mut self, tag: impl Into<String>) -> Self {
710 self.inner.required_tags.insert(tag.into());
711 self
712 }
713
714 pub fn exclude_family(mut self, family: impl Into<String>) -> Self {
716 self.inner.excluded_families.insert(family.into());
717 self
718 }
719
720 pub fn build(self) -> CapabilityRequirements {
722 self.inner
723 }
724}
725
/// Returns `true` when `actual` ranks at or above `requirement` on the shared
/// tier ladder.
///
/// Cost, speed and quality tiers share a single numeric scale:
/// - `free` / `realtime` / `basic` = 0
/// - `low` / `fast` / `standard` = 1
/// - `medium` = 2
/// - `high` / `slow` = 3
/// - `premium` = 4
///
/// Matching is case-insensitive, and any unrecognized tier is ranked 2.
fn tier_satisfies(requirement: &str, actual: &str) -> bool {
    fn rank(tier: &str) -> i32 {
        match tier.to_lowercase().as_str() {
            "free" | "realtime" | "basic" => 0,
            "low" | "fast" | "standard" => 1,
            "high" | "slow" => 3,
            "premium" => 4,
            // "medium" and anything unrecognized sit in the middle of the ladder.
            _ => 2,
        }
    }

    rank(requirement) <= rank(actual)
}
741
742#[derive(Debug, Clone)]
744pub struct ModelWithCapabilities {
745 pub name: String,
747 pub provider: String,
749 pub model_id: String,
751 pub capabilities: ModelCapabilities,
753}
754
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the built-in profile of a catalog model name.
    fn caps(name: &str) -> ModelCapabilities {
        ModelCapabilities::for_model(name)
    }

    #[test]
    fn test_claude_capabilities() {
        let sonnet = caps("claude-3-5-sonnet-20241022");
        assert!(sonnet.supports_tools);
        assert!(sonnet.supports_vision);
        assert_eq!(sonnet.context_window, 200_000);
        assert_eq!(sonnet.quality_tier, "premium");
    }

    #[test]
    fn test_gpt4o_capabilities() {
        let omni = caps("gpt-4o-2024-08-06");
        assert!(omni.supports_tools);
        assert!(omni.supports_vision);
        assert!(omni.supports_audio);
        assert_eq!(omni.context_window, 128_000);
    }

    #[test]
    fn test_llama_capabilities() {
        let llama = caps("llama-3.3-70b-instruct");
        assert!(llama.supports_tools);
        assert!(!llama.supports_vision);
        assert!(llama.is_local);
        assert_eq!(llama.cost_tier, "free");
    }

    #[test]
    fn test_requirements_builder() {
        let built = CapabilityRequirements::builder()
            .requires_tools()
            .requires_vision()
            .min_context_window(100_000)
            .max_cost_tier("high")
            .build();

        assert!(built.requires_tools);
        assert!(built.requires_vision);
        assert_eq!(built.min_context_window, Some(100_000));
        assert_eq!(built.max_cost_tier.as_deref(), Some("high"));
    }

    #[test]
    fn test_capability_matching() {
        let needs_vision = CapabilityRequirements::builder().requires_vision().build();

        assert!(caps("claude-3-5-sonnet-20241022").satisfies(&needs_vision));
        assert!(!caps("gpt-3.5-turbo").satisfies(&needs_vision));
    }

    #[test]
    fn test_context_window_matching() {
        let needs_long_context = CapabilityRequirements::builder()
            .min_context_window(100_000)
            .build();

        assert!(caps("claude-3-5-sonnet-20241022").satisfies(&needs_long_context));
        assert!(!caps("gpt-4").satisfies(&needs_long_context));
    }

    #[test]
    fn test_local_model_matching() {
        let local_only = CapabilityRequirements::for_local();

        assert!(caps("llama-3.3-70b-instruct").satisfies(&local_only));
        assert!(!caps("claude-3-5-sonnet-20241022").satisfies(&local_only));
    }

    #[test]
    fn test_scoring() {
        let tools_only = CapabilityRequirements::builder().requires_tools().build();
        let sonnet = caps("claude-3-5-sonnet-20241022");
        let haiku = caps("claude-3-haiku-20240307");
        let llama = caps("llama-3.3-70b-instruct");

        for candidate in [&sonnet, &haiku, &llama].iter() {
            assert!(candidate.satisfies(&tools_only));
        }

        let sonnet_score = sonnet.score(&tools_only);
        let llama_score = llama.score(&tools_only);
        assert!(
            llama_score > sonnet_score,
            "Llama (free, local) should score higher than Claude (high cost)"
        );

        assert!(haiku.score(&tools_only) > 0, "Haiku should have a positive score");
        assert!(sonnet_score > 0, "Claude should have a positive score");
    }

    #[test]
    fn test_preset_requirements() {
        let agent = CapabilityRequirements::for_agent();
        assert!(agent.requires_tools);
        assert!(agent.requires_production_ready);

        let coding = CapabilityRequirements::for_coding();
        assert!(coding.requires_tools);
        assert!(coding.requires_reasoning);
        assert_eq!(coding.min_context_window, Some(32_000));

        let vision = CapabilityRequirements::for_vision();
        assert!(vision.requires_vision);
    }

    #[test]
    fn test_tier_comparison() {
        let cases = [
            ("low", "medium", true),
            ("medium", "high", true),
            ("high", "low", false),
            ("standard", "premium", true),
        ];
        for &(requirement, actual, expected) in &cases {
            assert_eq!(tier_satisfies(requirement, actual), expected);
        }
    }
}