locus_sdk/application/
routing_config.rs1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::domain::ai::{AiTask, EmbedRequest, ScoreAvecRequest};
6
7#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
8pub struct ProviderModelProfile {
9 pub semantic_model: Option<String>,
10 pub avec_embedding_model: Option<String>,
11 pub avec_scoring_model: Option<String>,
12}
13
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct AiRoutingConfig {
16 pub default_provider_id: Option<String>,
17 pub providers: HashMap<String, ProviderModelProfile>,
18}
19
20impl AiRoutingConfig {
21 pub fn profile_for(&self, provider_id: &str) -> Option<&ProviderModelProfile> {
22 self.providers.get(provider_id)
23 }
24
25 pub fn provider_for(&self, requested_provider_id: Option<&str>) -> Option<String> {
26 requested_provider_id
27 .filter(|value| !value.trim().is_empty())
28 .map(ToString::to_string)
29 .or_else(|| self.default_provider_id.clone())
30 }
31
32 pub fn model_for(&self, provider_id: Option<&str>, task: AiTask) -> Option<&str> {
33 let provider_id = self.provider_for(provider_id)?;
34 let profile = self.providers.get(&provider_id)?;
35
36 match task {
37 AiTask::SemanticEmbedding => profile.semantic_model.as_deref(),
38 AiTask::AvecEmbedding => profile.avec_embedding_model.as_deref(),
39 AiTask::AvecScoring => profile.avec_scoring_model.as_deref(),
40 }
41 }
42
43 pub fn apply_to_embed_request(&self, request: &EmbedRequest) -> EmbedRequest {
44 let provider_id = request
45 .provider_id
46 .as_deref()
47 .map(ToString::to_string)
48 .or_else(|| self.default_provider_id.clone());
49 let model = request.model.clone().or_else(|| {
50 self.model_for(provider_id.as_deref(), request.task)
51 .map(ToString::to_string)
52 });
53
54 EmbedRequest {
55 text: request.text.clone(),
56 task: request.task,
57 provider_id,
58 model,
59 policy: request.policy,
60 }
61 }
62
63 pub fn apply_to_score_request(&self, request: &ScoreAvecRequest) -> ScoreAvecRequest {
64 let provider_id = request
65 .provider_id
66 .as_deref()
67 .map(ToString::to_string)
68 .or_else(|| self.default_provider_id.clone());
69 let model = request.model.clone().or_else(|| {
70 self.model_for(provider_id.as_deref(), AiTask::AvecScoring)
71 .map(ToString::to_string)
72 });
73
74 ScoreAvecRequest {
75 text: request.text.clone(),
76 provider_id,
77 model,
78 policy: request.policy,
79 }
80 }
81}
82
#[cfg(test)]
mod tests {
    use super::{AiRoutingConfig, ProviderModelProfile};
    use crate::domain::ai::{AiTask, EmbedRequest, ProviderPolicy, ScoreAvecRequest};

    /// Builds a config whose single provider, `genai`, is also the default.
    fn fixture_config() -> AiRoutingConfig {
        let genai = ProviderModelProfile {
            semantic_model: Some("text-embedding-3-small".to_string()),
            avec_embedding_model: Some("text-embedding-3-large".to_string()),
            avec_scoring_model: Some("gpt-4o-mini".to_string()),
        };
        let providers: std::collections::HashMap<String, ProviderModelProfile> =
            std::iter::once(("genai".to_string(), genai)).collect();

        AiRoutingConfig {
            default_provider_id: Some("genai".to_string()),
            providers,
        }
    }

    /// Shorthand for an embedding request over the text `"hello"`.
    fn embed_request(
        task: AiTask,
        provider_id: Option<&str>,
        model: Option<&str>,
        policy: ProviderPolicy,
    ) -> EmbedRequest {
        EmbedRequest {
            text: "hello".to_string(),
            task,
            provider_id: provider_id.map(str::to_string),
            model: model.map(str::to_string),
            policy,
        }
    }

    #[test]
    fn apply_to_embed_request_fills_default_provider_and_model() {
        let input = embed_request(AiTask::SemanticEmbedding, None, None, ProviderPolicy::Auto);

        let output = fixture_config().apply_to_embed_request(&input);

        assert_eq!(output.provider_id.as_deref(), Some("genai"));
        assert_eq!(output.model.as_deref(), Some("text-embedding-3-small"));
    }

    #[test]
    fn apply_to_embed_request_keeps_explicit_model() {
        let input = embed_request(
            AiTask::AvecEmbedding,
            Some("genai"),
            Some("my-custom-model"),
            ProviderPolicy::Preferred,
        );

        let output = fixture_config().apply_to_embed_request(&input);

        assert_eq!(output.model.as_deref(), Some("my-custom-model"));
    }

    #[test]
    fn apply_to_score_request_resolves_scoring_model() {
        let input = ScoreAvecRequest {
            text: "score this".to_string(),
            provider_id: None,
            model: None,
            policy: ProviderPolicy::Auto,
        };

        let output = fixture_config().apply_to_score_request(&input);

        assert_eq!(output.provider_id.as_deref(), Some("genai"));
        assert_eq!(output.model.as_deref(), Some("gpt-4o-mini"));
    }
}