converge_core/
model_selection.rs

1// Copyright 2024-2025 Aprio One AB, Sweden
2// Author: Kenneth Pernyer, kenneth@aprio.one
3// SPDX-License-Identifier: LicenseRef-Proprietary
4// All rights reserved. This source code is proprietary and confidential.
5// Unauthorized copying, modification, or distribution is strictly prohibited.
6
7//! Model selection based on agent requirements.
8//!
9//! This module provides the **abstract interface** for model selection.
10//! Concrete implementations (with provider-specific metadata) are in `converge-provider`.
11//!
12//! # Architecture
13//!
14//! - **Core (this module)**: Abstract requirements and selection trait
15//! - **Provider crate**: Concrete selector with all provider metadata
16//!
17//! This separation ensures core remains provider-agnostic while allowing
18//! injection of provider-specific selection logic.
19
20use crate::llm::LlmError;
21
/// Relative price tier of a model, used to cap what an agent may spend.
///
/// Variants are declared from cheapest to most expensive. The derived
/// `PartialOrd`/`Ord` follow declaration order, so `a <= b` reads as
/// "`a` is no more expensive than `b`". Keep that order when adding tiers.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum CostClass {
    /// Cheapest tier (e.g., Haiku, GPT-3.5 Turbo, Gemini Flash).
    VeryLow,
    /// Low tier (e.g., Sonnet, GPT-4 Turbo).
    Low,
    /// Mid tier (e.g., Opus, GPT-4).
    Medium,
    /// High tier (e.g., Opus-4, GPT-4o).
    High,
    /// Most expensive tier (e.g., specialized models).
    VeryHigh,
}
36
37impl CostClass {
38    /// Returns the maximum cost class that satisfies this requirement.
39    ///
40    /// For example, `CostClass::Low` allows VeryLow and Low.
41    #[must_use]
42    pub fn allowed_classes(self) -> Vec<CostClass> {
43        match self {
44            Self::VeryLow => vec![Self::VeryLow],
45            Self::Low => vec![Self::VeryLow, Self::Low],
46            Self::Medium => vec![Self::VeryLow, Self::Low, Self::Medium],
47            Self::High => vec![Self::VeryLow, Self::Low, Self::Medium, Self::High],
48            Self::VeryHigh => vec![
49                Self::VeryLow,
50                Self::Low,
51                Self::Medium,
52                Self::High,
53                Self::VeryHigh,
54            ],
55        }
56    }
57}
58
/// Where a model provider is permitted to process and store an agent's data.
///
/// Defaults to [`DataSovereignty::Any`], i.e. no residency constraint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DataSovereignty {
    /// No residency requirement (the default).
    #[default]
    Any,
    /// Data must remain in the EU/EEA.
    EU,
    /// Data must remain in Switzerland.
    Switzerland,
    /// Data must remain in China.
    China,
    /// Data must remain in the US.
    US,
    /// Self-hosted or on-premises deployment only.
    OnPremises,
}
75
/// Compliance and explainability requirement an agent places on its model.
///
/// Defaults to [`ComplianceLevel::None`], i.e. no specific requirement.
/// NOTE(review): this is a single level, not a set, so one value cannot
/// express a combination such as GDPR plus HIPAA.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ComplianceLevel {
    /// No specific compliance requirement (the default).
    #[default]
    None,
    /// GDPR compliance required.
    GDPR,
    /// SOC 2 compliance required.
    SOC2,
    /// HIPAA compliance required.
    HIPAA,
    /// High explainability (audit trails, provenance).
    HighExplainability,
}
90
91/// Requirements for an agent's LLM usage.
92///
93/// Agents specify their requirements, and the model selector
94/// finds the best matching model.
95#[derive(Debug, Clone, PartialEq)]
96pub struct AgentRequirements {
97    /// Maximum acceptable cost class.
98    pub max_cost_class: CostClass,
99    /// Maximum acceptable latency in milliseconds.
100    pub max_latency_ms: u32,
101    /// Whether the agent requires advanced reasoning capabilities.
102    pub requires_reasoning: bool,
103    /// Whether the agent requires web search capabilities.
104    pub requires_web_search: bool,
105    /// Minimum quality threshold (0.0-1.0).
106    /// Higher values prefer more capable models.
107    pub min_quality: f64,
108    /// Data sovereignty requirements.
109    pub data_sovereignty: DataSovereignty,
110    /// Compliance and explainability requirements.
111    pub compliance: ComplianceLevel,
112    /// Whether the agent requires multi-language support.
113    pub requires_multilingual: bool,
114}
115
116impl AgentRequirements {
117    /// Creates requirements for a fast, cheap agent (many instances).
118    ///
119    /// Use case: High-volume agents that need quick, cost-effective responses.
120    #[must_use]
121    pub fn fast_cheap() -> Self {
122        Self {
123            max_cost_class: CostClass::VeryLow,
124            max_latency_ms: 2000,
125            requires_reasoning: false,
126            requires_web_search: false,
127            min_quality: 0.6,
128            data_sovereignty: DataSovereignty::Any,
129            compliance: ComplianceLevel::None,
130            requires_multilingual: false,
131        }
132    }
133
134    /// Creates requirements for a deep research agent.
135    ///
136    /// Use case: Agents that need thorough analysis and reasoning.
137    #[must_use]
138    pub fn deep_research() -> Self {
139        Self {
140            max_cost_class: CostClass::High,
141            max_latency_ms: 30000, // 30 seconds
142            requires_reasoning: true,
143            requires_web_search: true,
144            min_quality: 0.9,
145            data_sovereignty: DataSovereignty::Any,
146            compliance: ComplianceLevel::None,
147            requires_multilingual: false,
148        }
149    }
150
151    /// Creates requirements for a balanced agent.
152    ///
153    /// Use case: General-purpose agents with moderate requirements.
154    #[must_use]
155    pub fn balanced() -> Self {
156        Self {
157            max_cost_class: CostClass::Medium,
158            max_latency_ms: 5000,
159            requires_reasoning: false,
160            requires_web_search: false,
161            min_quality: 0.7,
162            data_sovereignty: DataSovereignty::Any,
163            compliance: ComplianceLevel::None,
164            requires_multilingual: false,
165        }
166    }
167
168    /// Creates custom requirements.
169    #[must_use]
170    pub fn new(
171        max_cost_class: CostClass,
172        max_latency_ms: u32,
173        requires_reasoning: bool,
174    ) -> Self {
175        Self {
176            max_cost_class,
177            max_latency_ms,
178            requires_reasoning,
179            requires_web_search: false,
180            min_quality: 0.7,
181            data_sovereignty: DataSovereignty::Any,
182            compliance: ComplianceLevel::None,
183            requires_multilingual: false,
184        }
185    }
186
187    /// Sets web search requirement.
188    #[must_use]
189    pub fn with_web_search(mut self, requires: bool) -> Self {
190        self.requires_web_search = requires;
191        self
192    }
193
194    /// Sets minimum quality threshold.
195    #[must_use]
196    pub fn with_min_quality(mut self, quality: f64) -> Self {
197        self.min_quality = quality.clamp(0.0, 1.0);
198        self
199    }
200
201    /// Sets data sovereignty requirement.
202    #[must_use]
203    pub fn with_data_sovereignty(mut self, sovereignty: DataSovereignty) -> Self {
204        self.data_sovereignty = sovereignty;
205        self
206    }
207
208    /// Sets compliance requirement.
209    #[must_use]
210    pub fn with_compliance(mut self, compliance: ComplianceLevel) -> Self {
211        self.compliance = compliance;
212        self
213    }
214
215    /// Sets multilingual requirement.
216    #[must_use]
217    pub fn with_multilingual(mut self, requires: bool) -> Self {
218        self.requires_multilingual = requires;
219        self
220    }
221}
222
223// ModelMetadata and ModelSelector implementations are in converge-provider.
224// Core only provides the abstract interface (ModelSelectorTrait) and requirements.
225
226/// Trait for model selection based on agent requirements.
227///
228/// This trait allows injecting provider-specific model selection logic
229/// without coupling core to concrete providers.
230///
231/// Concrete implementations (with provider metadata) are in `converge-provider`.
232pub trait ModelSelectorTrait: Send + Sync {
233    /// Selects the best model for the given requirements.
234    ///
235    /// Returns `(provider, model)` if a suitable model is found.
236    ///
237    /// # Errors
238    ///
239    /// Returns error if no model satisfies the requirements.
240    fn select(
241        &self,
242        requirements: &AgentRequirements,
243    ) -> Result<(String, String), LlmError>;
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// All cost classes, cheapest first; shared by the ordering tests below.
    const ALL_CLASSES: [CostClass; 5] = [
        CostClass::VeryLow,
        CostClass::Low,
        CostClass::Medium,
        CostClass::High,
        CostClass::VeryHigh,
    ];

    #[test]
    fn test_fast_cheap_requirements() {
        let reqs = AgentRequirements::fast_cheap();
        assert_eq!(reqs.max_cost_class, CostClass::VeryLow);
        assert_eq!(reqs.max_latency_ms, 2000);
        assert!(!reqs.requires_reasoning);
        assert!(!reqs.requires_web_search);
        assert_eq!(reqs.min_quality, 0.6);
    }

    #[test]
    fn test_deep_research_requirements() {
        let reqs = AgentRequirements::deep_research();
        assert!(reqs.max_cost_class >= CostClass::High);
        assert!(reqs.requires_reasoning);
        assert!(reqs.requires_web_search);
        assert_eq!(reqs.max_latency_ms, 30000);
        assert_eq!(reqs.min_quality, 0.9);
    }

    #[test]
    fn test_balanced_requirements() {
        let reqs = AgentRequirements::balanced();
        assert_eq!(reqs.max_cost_class, CostClass::Medium);
        assert_eq!(reqs.max_latency_ms, 5000);
        assert!(!reqs.requires_reasoning);
        assert!(!reqs.requires_web_search);
        assert_eq!(reqs.min_quality, 0.7);
    }

    #[test]
    fn test_new_uses_documented_defaults() {
        let reqs = AgentRequirements::new(CostClass::Low, 1234, true);
        assert_eq!(reqs.max_cost_class, CostClass::Low);
        assert_eq!(reqs.max_latency_ms, 1234);
        assert!(reqs.requires_reasoning);
        assert!(!reqs.requires_web_search);
        assert_eq!(reqs.min_quality, 0.7);
        assert_eq!(reqs.data_sovereignty, DataSovereignty::Any);
        assert_eq!(reqs.compliance, ComplianceLevel::None);
        assert!(!reqs.requires_multilingual);
    }

    #[test]
    fn test_min_quality_is_clamped() {
        let base = AgentRequirements::balanced();
        assert_eq!(base.clone().with_min_quality(1.5).min_quality, 1.0);
        assert_eq!(base.clone().with_min_quality(-0.25).min_quality, 0.0);
        assert_eq!(base.with_min_quality(0.42).min_quality, 0.42);
    }

    #[test]
    fn test_builders_set_fields() {
        let reqs = AgentRequirements::balanced()
            .with_web_search(true)
            .with_data_sovereignty(DataSovereignty::Switzerland)
            .with_compliance(ComplianceLevel::HighExplainability)
            .with_multilingual(true);
        assert!(reqs.requires_web_search);
        assert_eq!(reqs.data_sovereignty, DataSovereignty::Switzerland);
        assert_eq!(reqs.compliance, ComplianceLevel::HighExplainability);
        assert!(reqs.requires_multilingual);
        // Builders must not disturb the fields they do not touch.
        assert_eq!(reqs.max_cost_class, CostClass::Medium);
        assert_eq!(reqs.max_latency_ms, 5000);
    }

    #[test]
    fn test_cost_class_allowed() {
        assert_eq!(CostClass::VeryLow.allowed_classes().len(), 1);
        assert_eq!(CostClass::Low.allowed_classes().len(), 2);
        assert_eq!(CostClass::VeryHigh.allowed_classes().len(), 5);
        assert_eq!(
            CostClass::Medium.allowed_classes(),
            vec![CostClass::VeryLow, CostClass::Low, CostClass::Medium]
        );
    }

    #[test]
    fn test_allowed_classes_are_ascending_and_capped() {
        for ceiling in ALL_CLASSES {
            let allowed = ceiling.allowed_classes();
            assert_eq!(allowed.last(), Some(&ceiling));
            assert!(allowed.windows(2).all(|pair| pair[0] < pair[1]));
            assert!(allowed.iter().all(|&c| c <= ceiling));
        }
    }

    #[test]
    fn test_cost_class_ordering() {
        assert!(ALL_CLASSES.windows(2).all(|pair| pair[0] < pair[1]));
    }
}
273