converge_core/model_selection.rs
1// Copyright 2024-2025 Aprio One AB, Sweden
2// Author: Kenneth Pernyer, kenneth@aprio.one
3// SPDX-License-Identifier: LicenseRef-Proprietary
4// All rights reserved. This source code is proprietary and confidential.
5// Unauthorized copying, modification, or distribution is strictly prohibited.
6
7//! Model selection based on agent requirements.
8//!
9//! This module provides the **abstract interface** for model selection.
10//! Concrete implementations (with provider-specific metadata) are in `converge-provider`.
11//!
12//! # Architecture
13//!
14//! - **Core (this module)**: Abstract requirements and selection trait
15//! - **Provider crate**: Concrete selector with all provider metadata
16//!
17//! This separation ensures core remains provider-agnostic while allowing
18//! injection of provider-specific selection logic.
19
20use crate::llm::LlmError;
21
/// Relative price tier of a model, used as a budget ceiling during selection.
///
/// Variants are declared from cheapest to most expensive, so the derived
/// ordering satisfies `CostClass::VeryLow < CostClass::VeryHigh`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum CostClass {
    /// Very low cost (e.g., Haiku, GPT-3.5 Turbo, Gemini Flash)
    VeryLow,
    /// Low cost (e.g., Sonnet, GPT-4 Turbo)
    Low,
    /// Medium cost (e.g., Opus, GPT-4)
    Medium,
    /// High cost (e.g., Opus-4, GPT-4o)
    High,
    /// Very high cost (e.g., specialized models)
    VeryHigh,
}

impl CostClass {
    /// Lists every cost class that fits within this budget ceiling.
    ///
    /// The result is ordered from cheapest to most expensive and always ends
    /// with `self`. For example, `CostClass::Low` yields `[VeryLow, Low]`.
    #[must_use]
    pub fn allowed_classes(self) -> Vec<CostClass> {
        // Every tier in ascending order; must mirror the enum declaration.
        const ASCENDING: [CostClass; 5] = [
            CostClass::VeryLow,
            CostClass::Low,
            CostClass::Medium,
            CostClass::High,
            CostClass::VeryHigh,
        ];

        // The derived `Ord` follows declaration order, so "at or below the
        // ceiling" is a plain comparison.
        ASCENDING
            .iter()
            .copied()
            .filter(|tier| *tier <= self)
            .collect()
    }
}
58
/// Data sovereignty requirements.
///
/// Expresses where an agent's data is required to remain.
/// `DataSovereignty::default()` is [`DataSovereignty::Any`], i.e. no restriction.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum DataSovereignty {
    /// No specific requirements (default).
    #[default]
    Any,
    /// Data must remain in EU/EEA.
    EU,
    /// Data must remain in Switzerland.
    Switzerland,
    /// Data must remain in China.
    China,
    /// Data must remain in US.
    US,
    /// Self-hosted or on-premises.
    OnPremises,
}
75
/// Compliance and explainability requirements.
///
/// `ComplianceLevel::default()` is [`ComplianceLevel::None`], i.e. no
/// compliance constraint.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ComplianceLevel {
    /// No specific compliance requirements (default).
    #[default]
    None,
    /// GDPR compliance required.
    GDPR,
    /// SOC 2 compliance required.
    SOC2,
    /// HIPAA compliance required.
    HIPAA,
    /// High explainability (audit trails, provenance).
    HighExplainability,
}
90
91/// Requirements for an agent's LLM usage.
92///
93/// Agents specify their requirements, and the model selector
94/// finds the best matching model.
95#[derive(Debug, Clone, PartialEq)]
96pub struct AgentRequirements {
97 /// Maximum acceptable cost class.
98 pub max_cost_class: CostClass,
99 /// Maximum acceptable latency in milliseconds.
100 pub max_latency_ms: u32,
101 /// Whether the agent requires advanced reasoning capabilities.
102 pub requires_reasoning: bool,
103 /// Whether the agent requires web search capabilities.
104 pub requires_web_search: bool,
105 /// Minimum quality threshold (0.0-1.0).
106 /// Higher values prefer more capable models.
107 pub min_quality: f64,
108 /// Data sovereignty requirements.
109 pub data_sovereignty: DataSovereignty,
110 /// Compliance and explainability requirements.
111 pub compliance: ComplianceLevel,
112 /// Whether the agent requires multi-language support.
113 pub requires_multilingual: bool,
114}
115
116impl AgentRequirements {
117 /// Creates requirements for a fast, cheap agent (many instances).
118 ///
119 /// Use case: High-volume agents that need quick, cost-effective responses.
120 #[must_use]
121 pub fn fast_cheap() -> Self {
122 Self {
123 max_cost_class: CostClass::VeryLow,
124 max_latency_ms: 2000,
125 requires_reasoning: false,
126 requires_web_search: false,
127 min_quality: 0.6,
128 data_sovereignty: DataSovereignty::Any,
129 compliance: ComplianceLevel::None,
130 requires_multilingual: false,
131 }
132 }
133
134 /// Creates requirements for a deep research agent.
135 ///
136 /// Use case: Agents that need thorough analysis and reasoning.
137 #[must_use]
138 pub fn deep_research() -> Self {
139 Self {
140 max_cost_class: CostClass::High,
141 max_latency_ms: 30000, // 30 seconds
142 requires_reasoning: true,
143 requires_web_search: true,
144 min_quality: 0.9,
145 data_sovereignty: DataSovereignty::Any,
146 compliance: ComplianceLevel::None,
147 requires_multilingual: false,
148 }
149 }
150
151 /// Creates requirements for a balanced agent.
152 ///
153 /// Use case: General-purpose agents with moderate requirements.
154 #[must_use]
155 pub fn balanced() -> Self {
156 Self {
157 max_cost_class: CostClass::Medium,
158 max_latency_ms: 5000,
159 requires_reasoning: false,
160 requires_web_search: false,
161 min_quality: 0.7,
162 data_sovereignty: DataSovereignty::Any,
163 compliance: ComplianceLevel::None,
164 requires_multilingual: false,
165 }
166 }
167
168 /// Creates custom requirements.
169 #[must_use]
170 pub fn new(
171 max_cost_class: CostClass,
172 max_latency_ms: u32,
173 requires_reasoning: bool,
174 ) -> Self {
175 Self {
176 max_cost_class,
177 max_latency_ms,
178 requires_reasoning,
179 requires_web_search: false,
180 min_quality: 0.7,
181 data_sovereignty: DataSovereignty::Any,
182 compliance: ComplianceLevel::None,
183 requires_multilingual: false,
184 }
185 }
186
187 /// Sets web search requirement.
188 #[must_use]
189 pub fn with_web_search(mut self, requires: bool) -> Self {
190 self.requires_web_search = requires;
191 self
192 }
193
194 /// Sets minimum quality threshold.
195 #[must_use]
196 pub fn with_min_quality(mut self, quality: f64) -> Self {
197 self.min_quality = quality.clamp(0.0, 1.0);
198 self
199 }
200
201 /// Sets data sovereignty requirement.
202 #[must_use]
203 pub fn with_data_sovereignty(mut self, sovereignty: DataSovereignty) -> Self {
204 self.data_sovereignty = sovereignty;
205 self
206 }
207
208 /// Sets compliance requirement.
209 #[must_use]
210 pub fn with_compliance(mut self, compliance: ComplianceLevel) -> Self {
211 self.compliance = compliance;
212 self
213 }
214
215 /// Sets multilingual requirement.
216 #[must_use]
217 pub fn with_multilingual(mut self, requires: bool) -> Self {
218 self.requires_multilingual = requires;
219 self
220 }
221}
222
223// ModelMetadata and ModelSelector implementations are in converge-provider.
224// Core only provides the abstract interface (ModelSelectorTrait) and requirements.
225
226/// Trait for model selection based on agent requirements.
227///
228/// This trait allows injecting provider-specific model selection logic
229/// without coupling core to concrete providers.
230///
231/// Concrete implementations (with provider metadata) are in `converge-provider`.
232pub trait ModelSelectorTrait: Send + Sync {
233 /// Selects the best model for the given requirements.
234 ///
235 /// Returns `(provider, model)` if a suitable model is found.
236 ///
237 /// # Errors
238 ///
239 /// Returns error if no model satisfies the requirements.
240 fn select(
241 &self,
242 requirements: &AgentRequirements,
243 ) -> Result<(String, String), LlmError>;
244}
245
#[cfg(test)]
mod tests {
    //! Unit tests for the requirement presets, the builder methods and
    //! `CostClass::allowed_classes`.

    use super::*;

    /// Float comparison helper; the values under test are stored literals or
    /// clamp bounds, so a machine-epsilon tolerance is sufficient.
    fn approx_eq(left: f64, right: f64) -> bool {
        (left - right).abs() < f64::EPSILON
    }

    #[test]
    fn test_fast_cheap_requirements() {
        let reqs = AgentRequirements::fast_cheap();
        assert_eq!(reqs.max_cost_class, CostClass::VeryLow);
        assert_eq!(reqs.max_latency_ms, 2000);
        assert!(!reqs.requires_reasoning);
    }

    #[test]
    fn test_deep_research_requirements() {
        let reqs = AgentRequirements::deep_research();
        assert!(reqs.max_cost_class >= CostClass::High);
        assert!(reqs.requires_reasoning);
        assert!(reqs.requires_web_search);
    }

    #[test]
    fn test_balanced_requirements() {
        let reqs = AgentRequirements::balanced();
        assert_eq!(reqs.max_cost_class, CostClass::Medium);
        assert_eq!(reqs.max_latency_ms, 5000);
        assert!(!reqs.requires_reasoning);
        assert!(!reqs.requires_web_search);
    }

    #[test]
    fn test_new_applies_defaults() {
        let reqs = AgentRequirements::new(CostClass::Low, 1500, true);
        assert_eq!(reqs.max_cost_class, CostClass::Low);
        assert_eq!(reqs.max_latency_ms, 1500);
        assert!(reqs.requires_reasoning);
        assert!(!reqs.requires_web_search);
        assert!(approx_eq(reqs.min_quality, 0.7));
        assert_eq!(reqs.data_sovereignty, DataSovereignty::Any);
        assert_eq!(reqs.compliance, ComplianceLevel::None);
        assert!(!reqs.requires_multilingual);
    }

    #[test]
    fn test_builder_overrides() {
        let reqs = AgentRequirements::balanced()
            .with_web_search(true)
            .with_data_sovereignty(DataSovereignty::EU)
            .with_compliance(ComplianceLevel::GDPR)
            .with_multilingual(true);
        assert!(reqs.requires_web_search);
        assert_eq!(reqs.data_sovereignty, DataSovereignty::EU);
        assert_eq!(reqs.compliance, ComplianceLevel::GDPR);
        assert!(reqs.requires_multilingual);
    }

    #[test]
    fn test_min_quality_is_clamped() {
        let too_high = AgentRequirements::balanced().with_min_quality(1.5);
        assert!(approx_eq(too_high.min_quality, 1.0));

        let too_low = AgentRequirements::balanced().with_min_quality(-0.5);
        assert!(approx_eq(too_low.min_quality, 0.0));

        let in_range = AgentRequirements::balanced().with_min_quality(0.85);
        assert!(approx_eq(in_range.min_quality, 0.85));
    }

    #[test]
    fn test_cost_class_allowed() {
        // Check contents, not just lengths, so a wrong or misordered list fails.
        assert_eq!(CostClass::VeryLow.allowed_classes(), vec![CostClass::VeryLow]);
        assert_eq!(
            CostClass::Low.allowed_classes(),
            vec![CostClass::VeryLow, CostClass::Low]
        );
        assert_eq!(
            CostClass::Medium.allowed_classes(),
            vec![CostClass::VeryLow, CostClass::Low, CostClass::Medium]
        );
        assert_eq!(
            CostClass::High.allowed_classes(),
            vec![
                CostClass::VeryLow,
                CostClass::Low,
                CostClass::Medium,
                CostClass::High,
            ]
        );
        assert_eq!(
            CostClass::VeryHigh.allowed_classes(),
            vec![
                CostClass::VeryLow,
                CostClass::Low,
                CostClass::Medium,
                CostClass::High,
                CostClass::VeryHigh,
            ]
        );
    }
}
273