// locus_sdk/application/routing_config.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::domain::ai::{AiTask, EmbedRequest, ScoreAvecRequest};
6
7#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
8pub struct ProviderModelProfile {
9    pub semantic_model: Option<String>,
10    pub avec_embedding_model: Option<String>,
11    pub avec_scoring_model: Option<String>,
12}
13
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct AiRoutingConfig {
16    pub default_provider_id: Option<String>,
17    pub providers: HashMap<String, ProviderModelProfile>,
18}
19
20impl AiRoutingConfig {
21    pub fn profile_for(&self, provider_id: &str) -> Option<&ProviderModelProfile> {
22        self.providers.get(provider_id)
23    }
24
25    pub fn provider_for(&self, requested_provider_id: Option<&str>) -> Option<String> {
26        requested_provider_id
27            .filter(|value| !value.trim().is_empty())
28            .map(ToString::to_string)
29            .or_else(|| self.default_provider_id.clone())
30    }
31
32    pub fn model_for(&self, provider_id: Option<&str>, task: AiTask) -> Option<&str> {
33        let provider_id = self.provider_for(provider_id)?;
34        let profile = self.providers.get(&provider_id)?;
35
36        match task {
37            AiTask::SemanticEmbedding => profile.semantic_model.as_deref(),
38            AiTask::AvecEmbedding => profile.avec_embedding_model.as_deref(),
39            AiTask::AvecScoring => profile.avec_scoring_model.as_deref(),
40        }
41    }
42
43    pub fn apply_to_embed_request(&self, request: &EmbedRequest) -> EmbedRequest {
44        let provider_id = request
45            .provider_id
46            .as_deref()
47            .map(ToString::to_string)
48            .or_else(|| self.default_provider_id.clone());
49        let model = request.model.clone().or_else(|| {
50            self.model_for(provider_id.as_deref(), request.task)
51                .map(ToString::to_string)
52        });
53
54        EmbedRequest {
55            text: request.text.clone(),
56            task: request.task,
57            provider_id,
58            model,
59            policy: request.policy,
60        }
61    }
62
63    pub fn apply_to_score_request(&self, request: &ScoreAvecRequest) -> ScoreAvecRequest {
64        let provider_id = request
65            .provider_id
66            .as_deref()
67            .map(ToString::to_string)
68            .or_else(|| self.default_provider_id.clone());
69        let model = request.model.clone().or_else(|| {
70            self.model_for(provider_id.as_deref(), AiTask::AvecScoring)
71                .map(ToString::to_string)
72        });
73
74        ScoreAvecRequest {
75            text: request.text.clone(),
76            provider_id,
77            model,
78            policy: request.policy,
79        }
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::{AiRoutingConfig, ProviderModelProfile};
    use crate::domain::ai::{AiTask, EmbedRequest, ProviderPolicy, ScoreAvecRequest};

    // Single "genai" provider with every task model set, and "genai" as the default.
    fn fixture_config() -> AiRoutingConfig {
        let mut config = AiRoutingConfig {
            default_provider_id: Some("genai".to_string()),
            providers: std::collections::HashMap::new(),
        };

        config.providers.insert(
            "genai".to_string(),
            ProviderModelProfile {
                semantic_model: Some("text-embedding-3-small".to_string()),
                avec_embedding_model: Some("text-embedding-3-large".to_string()),
                avec_scoring_model: Some("gpt-4o-mini".to_string()),
            },
        );

        config
    }

    #[test]
    fn apply_to_embed_request_fills_default_provider_and_model() {
        let config = fixture_config();
        let request = EmbedRequest {
            text: "hello".to_string(),
            task: AiTask::SemanticEmbedding,
            provider_id: None,
            model: None,
            policy: ProviderPolicy::Auto,
        };

        let resolved = config.apply_to_embed_request(&request);

        assert_eq!(resolved.provider_id.as_deref(), Some("genai"));
        assert_eq!(
            resolved.model.as_deref(),
            Some("text-embedding-3-small")
        );
        assert_eq!(resolved.text, "hello");
    }

    #[test]
    fn apply_to_embed_request_keeps_explicit_model() {
        let config = fixture_config();
        let request = EmbedRequest {
            text: "hello".to_string(),
            task: AiTask::AvecEmbedding,
            provider_id: Some("genai".to_string()),
            model: Some("my-custom-model".to_string()),
            policy: ProviderPolicy::Preferred,
        };

        let resolved = config.apply_to_embed_request(&request);

        assert_eq!(resolved.model.as_deref(), Some("my-custom-model"));
    }

    #[test]
    fn apply_to_embed_request_leaves_model_unset_for_unknown_provider() {
        let config = fixture_config();
        let request = EmbedRequest {
            text: "hello".to_string(),
            task: AiTask::SemanticEmbedding,
            provider_id: Some("other".to_string()),
            model: None,
            policy: ProviderPolicy::Auto,
        };

        let resolved = config.apply_to_embed_request(&request);

        assert_eq!(resolved.provider_id.as_deref(), Some("other"));
        assert_eq!(resolved.model, None);
    }

    #[test]
    fn apply_to_score_request_resolves_scoring_model() {
        let config = fixture_config();
        let request = ScoreAvecRequest {
            text: "score this".to_string(),
            provider_id: None,
            model: None,
            policy: ProviderPolicy::Auto,
        };

        let resolved = config.apply_to_score_request(&request);

        assert_eq!(resolved.provider_id.as_deref(), Some("genai"));
        assert_eq!(resolved.model.as_deref(), Some("gpt-4o-mini"));
    }

    #[test]
    fn apply_to_score_request_keeps_explicit_provider_and_model() {
        let config = fixture_config();
        let request = ScoreAvecRequest {
            text: "score this".to_string(),
            provider_id: Some("other".to_string()),
            model: Some("custom-scorer".to_string()),
            policy: ProviderPolicy::Auto,
        };

        let resolved = config.apply_to_score_request(&request);

        assert_eq!(resolved.provider_id.as_deref(), Some("other"));
        assert_eq!(resolved.model.as_deref(), Some("custom-scorer"));
    }

    #[test]
    fn provider_for_falls_back_to_default_for_missing_or_blank_ids() {
        let config = fixture_config();

        assert_eq!(config.provider_for(None).as_deref(), Some("genai"));
        assert_eq!(config.provider_for(Some("")).as_deref(), Some("genai"));
        assert_eq!(config.provider_for(Some("   ")).as_deref(), Some("genai"));
        assert_eq!(config.provider_for(Some("other")).as_deref(), Some("other"));
    }

    #[test]
    fn model_for_resolves_per_task_and_rejects_unknown_provider() {
        let config = fixture_config();

        assert_eq!(
            config.model_for(None, AiTask::AvecEmbedding),
            Some("text-embedding-3-large")
        );
        assert_eq!(
            config.model_for(Some("genai"), AiTask::AvecScoring),
            Some("gpt-4o-mini")
        );
        assert_eq!(
            config.model_for(Some("missing"), AiTask::SemanticEmbedding),
            None
        );
    }
}