Skip to main content

converge_provider/
model_selection.rs

1// Copyright 2024-2025 Aprio One AB, Sweden
2// Author: Kenneth Pernyer, kenneth@aprio.one
3// SPDX-License-Identifier: MIT
4// See LICENSE file in the project root for full license information.
5
6//! Model selection implementation with provider-specific metadata.
7//!
8//! This module provides concrete implementations of model selection
9//! with hardcoded knowledge of all providers. The abstract interface
//! (`ModelSelectorTrait`, `AgentRequirements`) lives in the `converge_traits` crate.
11
12use converge_traits::{
13    AgentRequirements, ComplianceLevel, CostClass, DataSovereignty, LlmError, ModelSelectorTrait,
14};
15
/// Per-component scores that make up a model's fitness for a request.
///
/// All component scores are normalised to `0.0..=1.0`, where larger is better.
#[derive(Debug, Clone, PartialEq)]
pub struct FitnessBreakdown {
    /// How cheap the model is (higher = cheaper), derived from its cost class:
    /// VeryLow=1.0, Low=0.8, Medium=0.6, High=0.4, VeryHigh=0.2
    pub cost_score: f64,
    /// How fast the model is relative to the allowed latency (higher = faster):
    /// `1.0 - (typical_latency / max_allowed_latency)`, floored at 0.0.
    pub latency_score: f64,
    /// The model's own quality rating (0.0-1.0).
    pub quality_score: f64,
    /// Weighted sum of the components: 40% cost + 30% latency + 30% quality.
    pub total: f64,
}
25    /// Quality score (0.0-1.0, model's quality rating).
26    pub quality_score: f64,
27    /// Total weighted score: 40% cost + 30% latency + 30% quality.
28    pub total: f64,
29}
30
31impl std::fmt::Display for FitnessBreakdown {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        write!(
34            f,
35            "{:.3} = 40%×cost({:.2}) + 30%×latency({:.2}) + 30%×quality({:.2})",
36            self.total, self.cost_score, self.latency_score, self.quality_score
37        )
38    }
39}
40
41/// Result of model selection with detailed information.
42#[derive(Debug, Clone)]
43pub struct SelectionResult {
44    /// The selected model's metadata.
45    pub selected: ModelMetadata,
46    /// Fitness breakdown for the selected model.
47    pub fitness: FitnessBreakdown,
48    /// All candidates that were considered, with their fitness scores.
49    /// Sorted by fitness (best first).
50    pub candidates: Vec<(ModelMetadata, FitnessBreakdown)>,
51    /// Models that were rejected and why.
52    pub rejected: Vec<(ModelMetadata, RejectionReason)>,
53}
54
55/// Reason why a model was rejected during selection.
56#[derive(Debug, Clone, PartialEq)]
57pub enum RejectionReason {
58    /// Provider not available (no API key).
59    ProviderUnavailable,
60    /// Cost class exceeds budget.
61    CostTooHigh {
62        model_cost: CostClass,
63        max_allowed: CostClass,
64    },
65    /// Latency exceeds limit.
66    LatencyTooHigh {
67        model_latency_ms: u32,
68        max_allowed_ms: u32,
69    },
70    /// Quality below threshold.
71    QualityTooLow {
72        model_quality: f64,
73        min_required: f64,
74    },
75    /// Reasoning required but not supported.
76    ReasoningRequired,
77    /// Web search required but not supported.
78    WebSearchRequired,
79    /// Data sovereignty mismatch.
80    DataSovereigntyMismatch {
81        required: DataSovereignty,
82        model_has: DataSovereignty,
83    },
84    /// Compliance level mismatch.
85    ComplianceMismatch {
86        required: ComplianceLevel,
87        model_has: ComplianceLevel,
88    },
89    /// Multilingual required but not supported.
90    MultilingualRequired,
91}
92
93impl std::fmt::Display for RejectionReason {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        match self {
96            Self::ProviderUnavailable => write!(f, "provider unavailable (no API key)"),
97            Self::CostTooHigh {
98                model_cost,
99                max_allowed,
100            } => {
101                write!(f, "cost {model_cost:?} exceeds max {max_allowed:?}")
102            }
103            Self::LatencyTooHigh {
104                model_latency_ms,
105                max_allowed_ms,
106            } => {
107                write!(
108                    f,
109                    "latency {model_latency_ms}ms exceeds max {max_allowed_ms}ms"
110                )
111            }
112            Self::QualityTooLow {
113                model_quality,
114                min_required,
115            } => {
116                write!(f, "quality {model_quality:.2} below min {min_required:.2}")
117            }
118            Self::ReasoningRequired => write!(f, "reasoning required but not supported"),
119            Self::WebSearchRequired => write!(f, "web search required but not supported"),
120            Self::DataSovereigntyMismatch {
121                required,
122                model_has,
123            } => {
124                write!(f, "data sovereignty {model_has:?} != required {required:?}")
125            }
126            Self::ComplianceMismatch {
127                required,
128                model_has,
129            } => {
130                write!(f, "compliance {model_has:?} != required {required:?}")
131            }
132            Self::MultilingualRequired => write!(f, "multilingual required but not supported"),
133        }
134    }
135}
136
137/// Model metadata for selection.
138#[derive(Debug, Clone, PartialEq)]
139#[allow(clippy::struct_excessive_bools)]
140pub struct ModelMetadata {
141    /// Provider name (e.g., "anthropic", "openai").
142    pub provider: String,
143    /// Model identifier (e.g., "claude-haiku-4-5-20251001").
144    pub model: String,
145    /// Cost class of this model.
146    pub cost_class: CostClass,
147    /// Typical latency in milliseconds.
148    pub typical_latency_ms: u32,
149    /// Quality score (0.0-1.0).
150    pub quality: f64,
151    /// Whether this model has strong reasoning capabilities.
152    pub has_reasoning: bool,
153    /// Whether this model supports web search.
154    pub supports_web_search: bool,
155    /// Data sovereignty region.
156    pub data_sovereignty: DataSovereignty,
157    /// Compliance level.
158    pub compliance: ComplianceLevel,
159    /// Whether this model supports multiple languages.
160    pub supports_multilingual: bool,
161    // === New fields for enhanced selection ===
162    /// Context window size in tokens.
163    pub context_tokens: usize,
164    /// Whether this model supports tool/function calling.
165    pub supports_tool_use: bool,
166    /// Whether this model supports vision/images.
167    pub supports_vision: bool,
168    /// Whether this model supports structured output (JSON mode).
169    pub supports_structured_output: bool,
170    /// Whether this model is specialized for code.
171    pub supports_code: bool,
172    /// Provider's country (ISO code, e.g., "US", "FR", "CN").
173    pub country: String,
174    /// Provider's region (e.g., "US", "EU", "CN", "LOCAL").
175    pub region: String,
176}
177
178impl ModelMetadata {
179    /// Creates new model metadata.
180    #[must_use]
181    pub fn new(
182        provider: impl Into<String>,
183        model: impl Into<String>,
184        cost_class: CostClass,
185        typical_latency_ms: u32,
186        quality: f64,
187    ) -> Self {
188        Self {
189            provider: provider.into(),
190            model: model.into(),
191            cost_class,
192            typical_latency_ms,
193            quality: quality.clamp(0.0, 1.0),
194            has_reasoning: false,
195            supports_web_search: false,
196            data_sovereignty: DataSovereignty::Any,
197            compliance: ComplianceLevel::None,
198            supports_multilingual: false,
199            // New fields with defaults
200            context_tokens: 8192,
201            supports_tool_use: false,
202            supports_vision: false,
203            supports_structured_output: false,
204            supports_code: false,
205            country: "US".to_string(),
206            region: "US".to_string(),
207        }
208    }
209
210    /// Sets reasoning capability.
211    #[must_use]
212    pub fn with_reasoning(mut self, has: bool) -> Self {
213        self.has_reasoning = has;
214        self
215    }
216
217    /// Sets web search support.
218    #[must_use]
219    pub fn with_web_search(mut self, supports: bool) -> Self {
220        self.supports_web_search = supports;
221        self
222    }
223
224    /// Sets data sovereignty.
225    #[must_use]
226    pub fn with_data_sovereignty(mut self, sovereignty: DataSovereignty) -> Self {
227        self.data_sovereignty = sovereignty;
228        self
229    }
230
231    /// Sets compliance level.
232    #[must_use]
233    pub fn with_compliance(mut self, compliance: ComplianceLevel) -> Self {
234        self.compliance = compliance;
235        self
236    }
237
238    /// Sets multilingual support.
239    #[must_use]
240    pub fn with_multilingual(mut self, supports: bool) -> Self {
241        self.supports_multilingual = supports;
242        self
243    }
244
245    /// Sets context window size.
246    #[must_use]
247    pub fn with_context_tokens(mut self, tokens: usize) -> Self {
248        self.context_tokens = tokens;
249        self
250    }
251
252    /// Sets tool/function calling support.
253    #[must_use]
254    pub fn with_tool_use(mut self, supports: bool) -> Self {
255        self.supports_tool_use = supports;
256        self
257    }
258
259    /// Sets vision support.
260    #[must_use]
261    pub fn with_vision(mut self, supports: bool) -> Self {
262        self.supports_vision = supports;
263        self
264    }
265
266    /// Sets structured output support.
267    #[must_use]
268    pub fn with_structured_output(mut self, supports: bool) -> Self {
269        self.supports_structured_output = supports;
270        self
271    }
272
273    /// Sets code capability.
274    #[must_use]
275    pub fn with_code(mut self, supports: bool) -> Self {
276        self.supports_code = supports;
277        self
278    }
279
280    /// Sets provider location (country and region).
281    #[must_use]
282    pub fn with_location(mut self, country: impl Into<String>, region: impl Into<String>) -> Self {
283        self.country = country.into();
284        self.region = region.into();
285        self
286    }
287
288    /// Checks if this model satisfies the given requirements.
289    #[must_use]
290    pub fn satisfies(&self, requirements: &AgentRequirements) -> bool {
291        // Cost check
292        if !requirements
293            .max_cost_class
294            .allowed_classes()
295            .contains(&self.cost_class)
296        {
297            return false;
298        }
299
300        // Latency check
301        if self.typical_latency_ms > requirements.max_latency_ms {
302            return false;
303        }
304
305        // Reasoning check
306        if requirements.requires_reasoning && !self.has_reasoning {
307            return false;
308        }
309
310        // Web search check
311        if requirements.requires_web_search && !self.supports_web_search {
312            return false;
313        }
314
315        // Quality check
316        if self.quality < requirements.min_quality {
317            return false;
318        }
319
320        // Data sovereignty check
321        if requirements.data_sovereignty != DataSovereignty::Any
322            && self.data_sovereignty != requirements.data_sovereignty
323        {
324            return false;
325        }
326
327        // Compliance check
328        if requirements.compliance != ComplianceLevel::None
329            && self.compliance != requirements.compliance
330        {
331            return false;
332        }
333
334        // Multilingual check
335        if requirements.requires_multilingual && !self.supports_multilingual {
336            return false;
337        }
338
339        true
340    }
341
342    /// Calculates a fitness score for matching requirements.
343    ///
344    /// Higher score = better match. Considers:
345    /// - Cost efficiency (lower cost within allowed range)
346    /// - Latency efficiency (faster within allowed range)
347    /// - Quality (higher is better)
348    #[must_use]
349    pub fn fitness_score(&self, requirements: &AgentRequirements) -> f64 {
350        if !self.satisfies(requirements) {
351            return 0.0;
352        }
353
354        // Cost efficiency: prefer lower cost (inverted, normalized)
355        let cost_score = match self.cost_class {
356            CostClass::VeryLow => 1.0,
357            CostClass::Low => 0.8,
358            CostClass::Medium => 0.6,
359            CostClass::High => 0.4,
360            CostClass::VeryHigh => 0.2,
361        };
362
363        // Latency efficiency: prefer faster (inverted, normalized)
364        let latency_ratio =
365            f64::from(self.typical_latency_ms) / f64::from(requirements.max_latency_ms);
366        let latency_score = 1.0 - latency_ratio.min(1.0);
367
368        // Quality score (already 0.0-1.0)
369        let quality_score = self.quality;
370
371        // Weighted combination
372        // Cost: 40%, Latency: 30%, Quality: 30%
373        0.4 * cost_score + 0.3 * latency_score + 0.3 * quality_score
374    }
375
376    /// Calculates a detailed fitness breakdown for matching requirements.
377    ///
378    /// Returns `None` if the model doesn't satisfy requirements.
379    #[must_use]
380    pub fn fitness_breakdown(&self, requirements: &AgentRequirements) -> Option<FitnessBreakdown> {
381        if !self.satisfies(requirements) {
382            return None;
383        }
384
385        let cost_score = match self.cost_class {
386            CostClass::VeryLow => 1.0,
387            CostClass::Low => 0.8,
388            CostClass::Medium => 0.6,
389            CostClass::High => 0.4,
390            CostClass::VeryHigh => 0.2,
391        };
392
393        let latency_ratio =
394            f64::from(self.typical_latency_ms) / f64::from(requirements.max_latency_ms);
395        let latency_score = 1.0 - latency_ratio.min(1.0);
396
397        let quality_score = self.quality;
398
399        let total = 0.4 * cost_score + 0.3 * latency_score + 0.3 * quality_score;
400
401        Some(FitnessBreakdown {
402            cost_score,
403            latency_score,
404            quality_score,
405            total,
406        })
407    }
408
409    /// Determines why this model was rejected for the given requirements.
410    ///
411    /// Returns `None` if the model satisfies all requirements.
412    #[must_use]
413    pub fn rejection_reason(&self, requirements: &AgentRequirements) -> Option<RejectionReason> {
414        // Cost check
415        if !requirements
416            .max_cost_class
417            .allowed_classes()
418            .contains(&self.cost_class)
419        {
420            return Some(RejectionReason::CostTooHigh {
421                model_cost: self.cost_class,
422                max_allowed: requirements.max_cost_class,
423            });
424        }
425
426        // Latency check
427        if self.typical_latency_ms > requirements.max_latency_ms {
428            return Some(RejectionReason::LatencyTooHigh {
429                model_latency_ms: self.typical_latency_ms,
430                max_allowed_ms: requirements.max_latency_ms,
431            });
432        }
433
434        // Reasoning check
435        if requirements.requires_reasoning && !self.has_reasoning {
436            return Some(RejectionReason::ReasoningRequired);
437        }
438
439        // Web search check
440        if requirements.requires_web_search && !self.supports_web_search {
441            return Some(RejectionReason::WebSearchRequired);
442        }
443
444        // Quality check
445        if self.quality < requirements.min_quality {
446            return Some(RejectionReason::QualityTooLow {
447                model_quality: self.quality,
448                min_required: requirements.min_quality,
449            });
450        }
451
452        // Data sovereignty check
453        if requirements.data_sovereignty != DataSovereignty::Any
454            && self.data_sovereignty != requirements.data_sovereignty
455        {
456            return Some(RejectionReason::DataSovereigntyMismatch {
457                required: requirements.data_sovereignty,
458                model_has: self.data_sovereignty,
459            });
460        }
461
462        // Compliance check
463        if requirements.compliance != ComplianceLevel::None
464            && self.compliance != requirements.compliance
465        {
466            return Some(RejectionReason::ComplianceMismatch {
467                required: requirements.compliance,
468                model_has: self.compliance,
469            });
470        }
471
472        // Multilingual check
473        if requirements.requires_multilingual && !self.supports_multilingual {
474            return Some(RejectionReason::MultilingualRequired);
475        }
476
477        None
478    }
479}
480
481/// Model selector that matches requirements to models.
482#[derive(Debug, Clone)]
483pub struct ModelSelector {
484    /// Available models with metadata.
485    models: Vec<ModelMetadata>,
486}
487
488impl ModelSelector {
489    /// Creates a new model selector with default models.
490    #[must_use]
491    pub fn new() -> Self {
492        Self::default()
493    }
494
495    /// Creates an empty selector (add models manually).
496    #[must_use]
497    pub fn empty() -> Self {
498        Self { models: Vec::new() }
499    }
500
501    /// Adds a model to the selector.
502    #[must_use]
503    pub fn with_model(mut self, metadata: ModelMetadata) -> Self {
504        self.models.push(metadata);
505        self
506    }
507
508    /// Lists all models that satisfy the requirements.
509    #[must_use]
510    pub fn list_satisfying(&self, requirements: &AgentRequirements) -> Vec<&ModelMetadata> {
511        self.models
512            .iter()
513            .filter(|m| m.satisfies(requirements))
514            .collect()
515    }
516}
517
518impl ModelSelectorTrait for ModelSelector {
519    fn select(&self, requirements: &AgentRequirements) -> Result<(String, String), LlmError> {
520        let mut candidates: Vec<(&ModelMetadata, f64)> = self
521            .models
522            .iter()
523            .filter_map(|m| {
524                if m.satisfies(requirements) {
525                    Some((m, m.fitness_score(requirements)))
526                } else {
527                    None
528                }
529            })
530            .collect();
531
532        if candidates.is_empty() {
533            return Err(LlmError::provider(format!(
534                "No model found satisfying requirements: cost <= {:?}, latency <= {}ms, reasoning = {}, web_search = {}, quality >= {:.2}, data_sovereignty = {:?}, compliance = {:?}, multilingual = {}",
535                requirements.max_cost_class,
536                requirements.max_latency_ms,
537                requirements.requires_reasoning,
538                requirements.requires_web_search,
539                requirements.min_quality,
540                requirements.data_sovereignty,
541                requirements.compliance,
542                requirements.requires_multilingual
543            )));
544        }
545
546        // Sort by fitness score (descending)
547        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
548
549        // Return best match
550        let best = candidates[0].0;
551        Ok((best.provider.clone(), best.model.clone()))
552    }
553}
554
555impl Default for ModelSelector {
556    #[allow(clippy::too_many_lines)] // Default model catalog is comprehensive by design
557    fn default() -> Self {
558        // Default models with realistic metadata
559        Self {
560            models: vec![
561                // Anthropic
562                #[cfg(feature = "anthropic")]
563                ModelMetadata::new(
564                    "anthropic",
565                    "claude-haiku-4-5-20251001",
566                    CostClass::VeryLow,
567                    1200,
568                    0.78,
569                )
570                .with_tool_use(true)
571                .with_vision(true)
572                .with_context_tokens(200_000),
573                #[cfg(feature = "anthropic")]
574                ModelMetadata::new(
575                    "anthropic",
576                    "claude-sonnet-4-6",
577                    CostClass::Low,
578                    2500,
579                    0.93,
580                )
581                .with_reasoning(true)
582                .with_tool_use(true)
583                .with_vision(true)
584                .with_structured_output(true)
585                .with_code(true)
586                .with_context_tokens(200_000),
587                #[cfg(feature = "anthropic")]
588                ModelMetadata::new(
589                    "anthropic",
590                    "claude-opus-4-6",
591                    CostClass::High,
592                    7000,
593                    0.97,
594                )
595                .with_reasoning(true)
596                .with_tool_use(true)
597                .with_vision(true)
598                .with_structured_output(true)
599                .with_code(true)
600                .with_context_tokens(200_000),
601                // OpenAI
602                #[cfg(feature = "openai")]
603                ModelMetadata::new("openai", "gpt-3.5-turbo", CostClass::VeryLow, 1200, 0.70),
604                #[cfg(feature = "openai")]
605                ModelMetadata::new("openai", "gpt-4", CostClass::Medium, 5000, 0.90)
606                    .with_reasoning(true),
607                #[cfg(feature = "openai")]
608                ModelMetadata::new("openai", "gpt-4-turbo", CostClass::Medium, 4000, 0.92)
609                    .with_reasoning(true),
610                #[cfg(feature = "openai")]
611                ModelMetadata::new("openai", "gpt-5.4-mini", CostClass::Low, 2500, 0.95)
612                    .with_reasoning(true)
613                    .with_web_search(true)
614                    .with_multilingual(true)
615                    .with_context_tokens(1_050_000)
616                    .with_tool_use(true)
617                    .with_vision(true)
618                    .with_structured_output(true)
619                    .with_code(true),
620                #[cfg(feature = "openai")]
621                ModelMetadata::new("openai", "gpt-5.4", CostClass::High, 5500, 0.99)
622                    .with_reasoning(true)
623                    .with_web_search(true)
624                    .with_multilingual(true)
625                    .with_context_tokens(1_050_000)
626                    .with_tool_use(true)
627                    .with_vision(true)
628                    .with_structured_output(true)
629                    .with_code(true),
630                #[cfg(feature = "openai")]
631                ModelMetadata::new("openai", "gpt-5.4-pro", CostClass::VeryHigh, 11000, 1.00)
632                    .with_reasoning(true)
633                    .with_web_search(true)
634                    .with_multilingual(true)
635                    .with_context_tokens(1_050_000)
636                    .with_tool_use(true)
637                    .with_vision(true)
638                    .with_code(true),
639                // Google Gemini
640                #[cfg(feature = "gemini")]
641                ModelMetadata::new("gemini", "gemini-pro", CostClass::Low, 2000, 0.80)
642                    .with_tool_use(true)
643                    .with_structured_output(true)
644                    .with_context_tokens(32000),
645                #[cfg(feature = "gemini")]
646                ModelMetadata::new("gemini", "gemini-1.5-flash", CostClass::VeryLow, 800, 0.78)
647                    .with_tool_use(true)
648                    .with_vision(true)
649                    .with_structured_output(true)
650                    .with_multilingual(true)
651                    .with_context_tokens(1_000_000),
652                #[cfg(feature = "gemini")]
653                ModelMetadata::new("gemini", "gemini-1.5-pro", CostClass::Medium, 3000, 0.88)
654                    .with_tool_use(true)
655                    .with_vision(true)
656                    .with_structured_output(true)
657                    .with_code(true)
658                    .with_reasoning(true)
659                    .with_multilingual(true)
660                    .with_context_tokens(2_000_000),
661                #[cfg(feature = "gemini")]
662                ModelMetadata::new("gemini", "gemini-2.0-flash", CostClass::VeryLow, 700, 0.82)
663                    .with_tool_use(true)
664                    .with_vision(true)
665                    .with_structured_output(true)
666                    .with_code(true)
667                    .with_reasoning(true)
668                    .with_multilingual(true)
669                    .with_context_tokens(1_000_000),
670                #[cfg(feature = "gemini")]
671                ModelMetadata::new("gemini", "gemini-2.5-flash", CostClass::VeryLow, 800, 0.82)
672                    .with_tool_use(true)
673                    .with_vision(true)
674                    .with_structured_output(true)
675                    .with_code(true)
676                    .with_reasoning(true)
677                    .with_multilingual(true)
678                    .with_context_tokens(1_000_000),
679                #[cfg(feature = "gemini")]
680                ModelMetadata::new("gemini", "gemini-3-flash-preview", CostClass::VeryLow, 900, 0.90)
681                    .with_tool_use(true)
682                    .with_vision(true)
683                    .with_structured_output(true)
684                    .with_code(true)
685                    .with_reasoning(true)
686                    .with_multilingual(true)
687                    .with_context_tokens(1_050_000),
688                #[cfg(feature = "gemini")]
689                ModelMetadata::new("gemini", "gemini-3-pro", CostClass::Medium, 2500, 0.96)
690                    .with_tool_use(true)
691                    .with_vision(true)
692                    .with_structured_output(true)
693                    .with_code(true)
694                    .with_reasoning(true)
695                    .with_multilingual(true)
696                    .with_context_tokens(2_000_000),
697                // Perplexity (web search)
698                #[cfg(feature = "perplexity")]
699                ModelMetadata::new(
700                    "perplexity",
701                    "pplx-70b-online",
702                    CostClass::Medium,
703                    4000,
704                    0.90,
705                )
706                .with_reasoning(true)
707                .with_web_search(true),
708                #[cfg(feature = "perplexity")]
709                ModelMetadata::new("perplexity", "pplx-7b-online", CostClass::Low, 2500, 0.75)
710                    .with_web_search(true),
711                // Qwen
712                #[cfg(feature = "qwen")]
713                ModelMetadata::new("qwen", "qwen-turbo", CostClass::VeryLow, 1500, 0.70),
714                #[cfg(feature = "qwen")]
715                ModelMetadata::new("qwen", "qwen-plus", CostClass::Low, 2500, 0.80),
716                // OpenRouter (examples - actual models depend on routing)
717                #[cfg(feature = "openai")]
718                ModelMetadata::new(
719                    "openrouter",
720                    "anthropic/claude-haiku-4-5-20251001",
721                    CostClass::VeryLow,
722                    1200,
723                    0.78,
724                ),
725                #[cfg(feature = "openai")]
726                ModelMetadata::new("openrouter", "openai/gpt-4", CostClass::Medium, 5000, 0.90)
727                    .with_reasoning(true),
728                // MinMax
729                #[cfg(feature = "minmax")]
730                ModelMetadata::new("minmax", "abab5.5-chat", CostClass::Low, 2000, 0.75),
731                // Grok
732                #[cfg(feature = "grok")]
733                ModelMetadata::new("grok", "grok-beta", CostClass::Medium, 3000, 0.80),
734                // Mistral
735                #[cfg(feature = "mistral")]
736                ModelMetadata::new(
737                    "mistral",
738                    "mistral-large-latest",
739                    CostClass::Low,
740                    3000,
741                    0.85,
742                )
743                .with_reasoning(true)
744                .with_multilingual(true),
745                #[cfg(feature = "mistral")]
746                ModelMetadata::new(
747                    "mistral",
748                    "mistral-medium-latest",
749                    CostClass::Medium,
750                    4000,
751                    0.88,
752                )
753                .with_reasoning(true)
754                .with_multilingual(true),
755                // DeepSeek
756                #[cfg(feature = "deepseek")]
757                ModelMetadata::new("deepseek", "deepseek-chat", CostClass::VeryLow, 1500, 0.75)
758                    .with_reasoning(true),
759                #[cfg(feature = "deepseek")]
760                ModelMetadata::new("deepseek", "deepseek-r1", CostClass::Low, 3000, 0.85)
761                    .with_reasoning(true),
762                // Baidu ERNIE (China)
763                #[cfg(feature = "baidu")]
764                ModelMetadata::new("baidu", "ernie-bot", CostClass::Low, 2500, 0.80)
765                    .with_data_sovereignty(DataSovereignty::China)
766                    .with_multilingual(true),
767                #[cfg(feature = "baidu")]
768                ModelMetadata::new("baidu", "ernie-bot-turbo", CostClass::VeryLow, 1500, 0.75)
769                    .with_data_sovereignty(DataSovereignty::China)
770                    .with_multilingual(true),
771                // Zhipu GLM (China)
772                #[cfg(feature = "zhipu")]
773                ModelMetadata::new("zhipu", "glm-4", CostClass::Low, 2500, 0.82)
774                    .with_data_sovereignty(DataSovereignty::China)
775                    .with_multilingual(true),
776                #[cfg(feature = "zhipu")]
777                ModelMetadata::new("zhipu", "glm-4.5", CostClass::Medium, 3000, 0.88)
778                    .with_data_sovereignty(DataSovereignty::China)
779                    .with_reasoning(true)
780                    .with_multilingual(true),
781                // Kimi (Moonshot AI)
782                #[cfg(feature = "kimi")]
783                ModelMetadata::new("kimi", "moonshot-v1-8k", CostClass::Low, 2000, 0.80)
784                    .with_multilingual(true),
785                #[cfg(feature = "kimi")]
786                ModelMetadata::new("kimi", "moonshot-v1-32k", CostClass::Medium, 3000, 0.85)
787                    .with_reasoning(true)
788                    .with_multilingual(true),
789                // Apertus (Switzerland, EU digital sovereignty)
790                #[cfg(feature = "apertus")]
791                ModelMetadata::new("apertus", "apertus-v1", CostClass::Medium, 4000, 0.85)
792                    .with_data_sovereignty(DataSovereignty::Switzerland)
793                    .with_compliance(ComplianceLevel::GDPR)
794                    .with_multilingual(true),
795            ],
796        }
797    }
798}
799
800/// Checks if a provider is available (has API key set).
801///
802/// Returns `true` if the environment variable for the provider is set.
803#[must_use]
804pub fn is_provider_available(provider: &str) -> bool {
805    match provider {
806        #[cfg(feature = "anthropic")]
807        "anthropic" => std::env::var("ANTHROPIC_API_KEY").is_ok(),
808        #[cfg(feature = "openai")]
809        "openai" => std::env::var("OPENAI_API_KEY").is_ok(),
810        #[cfg(feature = "gemini")]
811        "gemini" => std::env::var("GEMINI_API_KEY").is_ok(),
812        #[cfg(feature = "perplexity")]
813        "perplexity" => std::env::var("PERPLEXITY_API_KEY").is_ok(),
814        #[cfg(feature = "openai")]
815        "openrouter" => std::env::var("OPENROUTER_API_KEY").is_ok(),
816        #[cfg(feature = "qwen")]
817        "qwen" => std::env::var("QWEN_API_KEY").is_ok(),
818        #[cfg(feature = "minmax")]
819        "minmax" => std::env::var("MINMAX_API_KEY").is_ok(),
820        #[cfg(feature = "grok")]
821        "grok" => std::env::var("GROK_API_KEY").is_ok(),
822        #[cfg(feature = "mistral")]
823        "mistral" => std::env::var("MISTRAL_API_KEY").is_ok(),
824        #[cfg(feature = "deepseek")]
825        "deepseek" => std::env::var("DEEPSEEK_API_KEY").is_ok(),
826        #[cfg(feature = "baidu")]
827        "baidu" => {
828            std::env::var("BAIDU_API_KEY").is_ok() && std::env::var("BAIDU_SECRET_KEY").is_ok()
829        }
830        #[cfg(feature = "zhipu")]
831        "zhipu" => std::env::var("ZHIPU_API_KEY").is_ok(),
832        #[cfg(feature = "kimi")]
833        "kimi" => std::env::var("KIMI_API_KEY").is_ok(),
834        #[cfg(feature = "apertus")]
835        "apertus" => std::env::var("APERTUS_API_KEY").is_ok(),
836        // Search providers
837        #[cfg(feature = "brave")]
838        "brave" => std::env::var("BRAVE_API_KEY").is_ok(),
839        _ => false,
840    }
841}
842
843/// Checks if Brave Search is available.
844#[must_use]
845pub fn is_brave_available() -> bool {
846    #[cfg(feature = "brave")]
847    {
848        is_provider_available("brave")
849    }
850    #[cfg(not(feature = "brave"))]
851    {
852        false
853    }
854}
855
/// Runtime provider registry that tracks available providers and allows
/// dynamic metadata updates.
///
/// This registry:
/// 1. Filters models by the set of available providers. The set is detected
///    from API keys by [`ProviderRegistry::from_env`], or supplied explicitly
///    via [`ProviderRegistry::with_providers`].
/// 2. Allows dynamic updates to metadata (pricing, latency, etc.)
/// 3. Maintains requirements-based selection logic
#[derive(Debug, Clone)]
pub struct ProviderRegistry {
    /// Base selector with all models (static metadata).
    base_selector: ModelSelector,
    /// Names of the providers that selection may draw models from.
    /// Filled in by the constructors; nothing re-checks the environment
    /// afterwards.
    available_providers: std::collections::HashSet<String>,
    /// Dynamic metadata overrides keyed by `(provider, model)`.
    /// An entry stands in for the base selector's metadata of the same model.
    /// Entries for models the base selector does not contain are not consulted
    /// during selection.
    metadata_overrides: std::collections::HashMap<(String, String), ModelMetadata>,
}
872
873impl ProviderRegistry {
874    /// Creates a new registry that checks available providers from environment.
875    ///
876    /// Only providers with API keys set will be considered for selection.
877    #[must_use]
878    pub fn from_env() -> Self {
879        let base_selector = ModelSelector::new();
880
881        // Check all known providers (LLMs and search)
882        let known_providers = vec![
883            // LLM providers
884            "anthropic",
885            "openai",
886            "gemini",
887            "perplexity",
888            "openrouter",
889            "qwen",
890            "minmax",
891            "grok",
892            "mistral",
893            "deepseek",
894            "baidu",
895            "zhipu",
896            "kimi",
897            "apertus",
898            // Search providers
899            "brave",
900        ];
901
902        let available_providers: std::collections::HashSet<String> = known_providers
903            .into_iter()
904            .filter(|p| is_provider_available(p))
905            .map(std::string::ToString::to_string)
906            .collect();
907
908        Self {
909            base_selector,
910            available_providers,
911            metadata_overrides: std::collections::HashMap::new(),
912        }
913    }
914
915    /// Creates a registry with explicit provider availability.
916    ///
917    /// Use this when you want to control which providers are available
918    /// programmatically (e.g., from a config file or user input).
919    #[must_use]
920    pub fn with_providers(providers: &[&str]) -> Self {
921        let base_selector = ModelSelector::new();
922        let available_providers: std::collections::HashSet<String> = providers
923            .iter()
924            .map(std::string::ToString::to_string)
925            .collect();
926
927        Self {
928            base_selector,
929            available_providers,
930            metadata_overrides: std::collections::HashMap::new(),
931        }
932    }
933
934    /// Updates metadata for a specific model (e.g., pricing, latency).
935    ///
936    /// This allows dynamic updates to model characteristics without
937    /// rebuilding the entire registry.
938    pub fn update_metadata(
939        &mut self,
940        provider: impl Into<String>,
941        model: impl Into<String>,
942        metadata: ModelMetadata,
943    ) {
944        self.metadata_overrides
945            .insert((provider.into(), model.into()), metadata);
946    }
947
948    /// Lists all available models that satisfy the requirements.
949    #[must_use]
950    pub fn list_available(&self, requirements: &AgentRequirements) -> Vec<&ModelMetadata> {
951        self.base_selector
952            .list_satisfying(requirements)
953            .into_iter()
954            .filter(|m| self.available_providers.contains(&m.provider))
955            .collect()
956    }
957
958    /// Gets the list of available providers.
959    #[must_use]
960    pub fn available_providers(&self) -> Vec<&str> {
961        self.available_providers
962            .iter()
963            .map(std::string::String::as_str)
964            .collect()
965    }
966
967    /// Checks if a provider is available.
968    #[must_use]
969    pub fn is_available(&self, provider: &str) -> bool {
970        self.available_providers.contains(provider)
971    }
972
973    /// Selects the best model with detailed information about the selection process.
974    ///
975    /// Returns a `SelectionResult` containing:
976    /// - The selected model and its fitness breakdown
977    /// - All candidates that were considered (sorted by fitness)
978    /// - Models that were rejected and why
979    ///
980    /// # Errors
981    ///
982    /// Returns error if no model satisfies the requirements.
983    pub fn select_with_details(
984        &self,
985        requirements: &AgentRequirements,
986    ) -> Result<SelectionResult, LlmError> {
987        let mut candidates: Vec<(ModelMetadata, FitnessBreakdown)> = Vec::new();
988        let mut rejected: Vec<(ModelMetadata, RejectionReason)> = Vec::new();
989
990        // Process all models in the base selector
991        for model in &self.base_selector.models {
992            // Check provider availability first
993            if !self.available_providers.contains(&model.provider) {
994                rejected.push((model.clone(), RejectionReason::ProviderUnavailable));
995                continue;
996            }
997
998            // Use override if available
999            let metadata = self
1000                .metadata_overrides
1001                .get(&(model.provider.clone(), model.model.clone()))
1002                .unwrap_or(model);
1003
1004            // Check if model satisfies requirements
1005            if let Some(breakdown) = metadata.fitness_breakdown(requirements) {
1006                candidates.push((metadata.clone(), breakdown));
1007            } else if let Some(reason) = metadata.rejection_reason(requirements) {
1008                rejected.push((metadata.clone(), reason));
1009            }
1010        }
1011
1012        if candidates.is_empty() {
1013            let available = self
1014                .available_providers
1015                .iter()
1016                .map(std::string::String::as_str)
1017                .collect::<Vec<_>>()
1018                .join(", ");
1019            return Err(LlmError::provider(format!(
1020                "No available model found satisfying requirements. Available providers: [{}]",
1021                if available.is_empty() {
1022                    "none (set API keys)".to_string()
1023                } else {
1024                    available
1025                }
1026            )));
1027        }
1028
1029        // Sort by fitness score (descending)
1030        candidates.sort_by(|a, b| {
1031            b.1.total
1032                .partial_cmp(&a.1.total)
1033                .unwrap_or(std::cmp::Ordering::Equal)
1034        });
1035
1036        // Extract the best
1037        let (selected, fitness) = candidates[0].clone();
1038
1039        Ok(SelectionResult {
1040            selected,
1041            fitness,
1042            candidates,
1043            rejected,
1044        })
1045    }
1046}
1047
1048impl ModelSelectorTrait for ProviderRegistry {
1049    fn select(&self, requirements: &AgentRequirements) -> Result<(String, String), LlmError> {
1050        // Get all models that satisfy requirements
1051        let all_candidates = self.base_selector.list_satisfying(requirements);
1052
1053        // Filter by available providers and apply overrides
1054        let mut candidates: Vec<(&ModelMetadata, f64)> = all_candidates
1055            .iter()
1056            .filter(|m| self.available_providers.contains(&m.provider))
1057            .map(|m| {
1058                // Use override if available, otherwise use base metadata
1059                let metadata = self
1060                    .metadata_overrides
1061                    .get(&(m.provider.clone(), m.model.clone()))
1062                    .unwrap_or(m);
1063                (metadata, metadata.fitness_score(requirements))
1064            })
1065            .collect();
1066
1067        if candidates.is_empty() {
1068            let available = self
1069                .available_providers
1070                .iter()
1071                .map(std::string::String::as_str)
1072                .collect::<Vec<_>>()
1073                .join(", ");
1074            return Err(LlmError::provider(format!(
1075                "No available model found satisfying requirements. Available providers: [{}]",
1076                if available.is_empty() {
1077                    "none (set API keys)".to_string()
1078                } else {
1079                    available
1080                }
1081            )));
1082        }
1083
1084        // Sort by fitness score (descending)
1085        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1086
1087        // Return best match
1088        let best = candidates[0].0;
1089        Ok((best.provider.clone(), best.model.clone()))
1090    }
1091}
1092
1093impl Default for ProviderRegistry {
1094    fn default() -> Self {
1095        Self::from_env()
1096    }
1097}
1098
#[cfg(test)]
mod tests {
    use super::*;
    use converge_traits::CostClass;

    #[test]
    fn test_provider_availability_check() {
        // Results for known providers depend on the environment and on the
        // enabled features, so only exercise the call for them. Names outside
        // the known set must always be reported unavailable.
        let _ = is_provider_available("anthropic");
        assert!(!is_provider_available(""));
        assert!(!is_provider_available("no-such-provider"));
    }

    #[test]
    fn test_registry_with_explicit_providers() {
        let registry = ProviderRegistry::with_providers(&["anthropic", "openai"]);
        assert!(registry.is_available("anthropic"));
        assert!(registry.is_available("openai"));
        assert!(!registry.is_available("gemini"));

        // Sort locally: the listing order is not what this test checks.
        let mut providers = registry.available_providers();
        providers.sort_unstable();
        assert_eq!(providers, ["anthropic", "openai"]);
    }

    #[test]
    fn test_registry_without_providers_has_no_candidates() {
        let registry = ProviderRegistry::with_providers(&[]);
        let reqs = AgentRequirements::balanced();

        assert!(registry.available_providers().is_empty());
        assert!(registry.list_available(&reqs).is_empty());
        assert!(registry.select(&reqs).is_err());
        assert!(registry.select_with_details(&reqs).is_err());
    }

    #[test]
    fn test_metadata_override() {
        let mut registry = ProviderRegistry::with_providers(&["anthropic"]);

        // Override latency for a model
        let updated = ModelMetadata::new(
            "anthropic",
            "claude-haiku-4-5-20251001",
            CostClass::VeryLow,
            1000, // Updated latency
            0.78,
        );
        registry.update_metadata("anthropic", "claude-haiku-4-5-20251001", updated);

        let reqs = AgentRequirements::fast_cheap();
        let (provider, _model) = registry
            .select(&reqs)
            .expect("the overridden anthropic model should satisfy fast_cheap");
        // Only anthropic is registered, so any successful pick must come from it.
        assert_eq!(provider, "anthropic");
    }

    #[test]
    fn test_model_selection() {
        let selector = ModelSelector::new();
        let reqs = AgentRequirements::fast_cheap();

        let (provider, model) = selector.select(&reqs).unwrap();
        // Should select a VeryLow cost, fast model
        assert!(
            ["anthropic", "openai", "gemini", "qwen"].contains(&provider.as_str()),
            "unexpected provider: {}",
            provider
        );
        assert!(
            ["haiku", "flash", "turbo", "qwen"]
                .iter()
                .any(|family| model.contains(family)),
            "unexpected model: {}",
            model
        );
    }

    #[test]
    fn test_selection_requires_reasoning_and_web_search() {
        let selector = ModelSelector::empty()
            .with_model(ModelMetadata::new(
                "alpha",
                "basic",
                CostClass::Low,
                1200,
                0.85,
            ))
            .with_model(
                ModelMetadata::new("beta", "reasoning-only", CostClass::Low, 1400, 0.88)
                    .with_reasoning(true),
            )
            .with_model(
                ModelMetadata::new("gamma", "reasoning-search", CostClass::Low, 1500, 0.87)
                    .with_reasoning(true)
                    .with_web_search(true),
            );

        let reqs = AgentRequirements::new(CostClass::Low, 5000, true).with_web_search(true);
        let (provider, model) = selector.select(&reqs).unwrap();
        assert_eq!(provider, "gamma");
        assert_eq!(model, "reasoning-search");
    }

    #[test]
    fn test_selection_respects_data_sovereignty_and_compliance() {
        let selector = ModelSelector::empty()
            .with_model(
                ModelMetadata::new("us", "us-model", CostClass::Low, 1500, 0.85)
                    .with_data_sovereignty(DataSovereignty::US),
            )
            .with_model(
                ModelMetadata::new("eu", "eu-gdpr", CostClass::Low, 1800, 0.86)
                    .with_data_sovereignty(DataSovereignty::EU)
                    .with_compliance(ComplianceLevel::GDPR),
            );

        let reqs = AgentRequirements::balanced()
            .with_data_sovereignty(DataSovereignty::EU)
            .with_compliance(ComplianceLevel::GDPR);
        let (provider, model) = selector.select(&reqs).unwrap();
        assert_eq!(provider, "eu");
        assert_eq!(model, "eu-gdpr");
    }

    #[test]
    fn test_selection_requires_multilingual() {
        let selector = ModelSelector::empty()
            .with_model(
                ModelMetadata::new("mono", "fast", CostClass::VeryLow, 800, 0.80)
                    .with_multilingual(false),
            )
            .with_model(
                ModelMetadata::new("multi", "polyglot", CostClass::Low, 1200, 0.82)
                    .with_multilingual(true),
            );

        let reqs = AgentRequirements::new(CostClass::Low, 2000, false).with_multilingual(true);
        let (provider, model) = selector.select(&reqs).unwrap();
        assert_eq!(provider, "multi");
        assert_eq!(model, "polyglot");
    }
}