//! ck_models/lib.rs — model registries (embedding + reranking) and per-project
//! configuration for ck's semantic search.
1use anyhow::{Result, anyhow};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::Path;
5
/// Configuration for a single embedding model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Full model identifier as passed to the backend (e.g. "BAAI/bge-small-en-v1.5").
    pub name: String,
    /// Backend that serves the model (e.g. "fastembed", "mixedbread").
    pub provider: String,
    /// Dimensionality of the embedding vectors this model produces.
    pub dimensions: usize,
    /// Maximum number of input tokens the model accepts.
    pub max_tokens: usize,
    /// Human-readable description of the model.
    pub description: String,
}
14
/// Registry mapping short aliases (e.g. "bge-small") to embedding model configs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRegistry {
    /// Alias -> configuration for every known embedding model.
    pub models: HashMap<String, ModelConfig>,
    /// Alias of the model used when the caller does not request one.
    pub default_model: String,
}
20
21impl Default for ModelRegistry {
22    fn default() -> Self {
23        let mut models = HashMap::new();
24
25        models.insert(
26            "bge-small".to_string(),
27            ModelConfig {
28                name: "BAAI/bge-small-en-v1.5".to_string(),
29                provider: "fastembed".to_string(),
30                dimensions: 384,
31                max_tokens: 512,
32                description: "Small, fast English embedding model".to_string(),
33            },
34        );
35
36        models.insert(
37            "minilm".to_string(),
38            ModelConfig {
39                name: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
40                provider: "fastembed".to_string(),
41                dimensions: 384,
42                max_tokens: 256,
43                description: "Lightweight English embedding model".to_string(),
44            },
45        );
46
47        // Add enhanced models
48        models.insert(
49            "nomic-v1.5".to_string(),
50            ModelConfig {
51                name: "nomic-embed-text-v1.5".to_string(),
52                provider: "fastembed".to_string(),
53                dimensions: 768,
54                max_tokens: 8192,
55                description: "High-quality English embedding model with large context window"
56                    .to_string(),
57            },
58        );
59
60        models.insert(
61            "jina-code".to_string(),
62            ModelConfig {
63                name: "jina-embeddings-v2-base-code".to_string(),
64                provider: "fastembed".to_string(),
65                dimensions: 768,
66                max_tokens: 8192,
67                description: "Code-specific embedding model optimized for programming tasks"
68                    .to_string(),
69            },
70        );
71
72        models.insert(
73            "mxbai-xsmall".to_string(),
74            ModelConfig {
75                name: "mixedbread-ai/mxbai-embed-xsmall-v1".to_string(),
76                provider: "mixedbread".to_string(),
77                dimensions: 384,
78                max_tokens: 4096,
79                description: "Mixedbread xsmall embedding model (4k context, 384 dims) optimized for local semantic search".to_string(),
80            },
81        );
82
83        Self {
84            models,
85            default_model: "bge-small".to_string(), // Keep BGE as default for backward compatibility
86        }
87    }
88}
89
90impl ModelRegistry {
91    fn format_available_models(&self) -> String {
92        self.models.keys().cloned().collect::<Vec<_>>().join(", ")
93    }
94
95    fn resolve_alias_or_name(&self, key: &str) -> Option<(String, &ModelConfig)> {
96        if let Some(config) = self.models.get(key) {
97            return Some((key.to_string(), config));
98        }
99
100        self.models
101            .iter()
102            .find(|(_, config)| config.name == key)
103            .map(|(alias, config)| (alias.clone(), config))
104    }
105
106    pub fn resolve(&self, requested: Option<&str>) -> Result<(String, ModelConfig)> {
107        match requested {
108            Some(name) => {
109                let (alias, config) = self.resolve_alias_or_name(name).ok_or_else(|| {
110                    anyhow!(
111                        "Unknown model '{}'. Available models: {}",
112                        name,
113                        self.format_available_models()
114                    )
115                })?;
116                Ok((alias, config.clone()))
117            }
118            None => {
119                let alias = self.default_model.clone();
120                let config = self
121                    .get_default_model()
122                    .cloned()
123                    .ok_or_else(|| anyhow!("No default model configured in registry"))?;
124                Ok((alias, config))
125            }
126        }
127    }
128
129    pub fn aliases(&self) -> Vec<String> {
130        let mut keys = self.models.keys().cloned().collect::<Vec<_>>();
131        keys.sort();
132        keys
133    }
134
135    pub fn load(path: &Path) -> Result<Self> {
136        if path.exists() {
137            let data = std::fs::read_to_string(path)?;
138            Ok(serde_json::from_str(&data)?)
139        } else {
140            Ok(Self::default())
141        }
142    }
143
144    pub fn save(&self, path: &Path) -> Result<()> {
145        let data = serde_json::to_string_pretty(self)?;
146        std::fs::write(path, data)?;
147        Ok(())
148    }
149
150    pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
151        self.models.get(name)
152    }
153
154    pub fn get_default_model(&self) -> Option<&ModelConfig> {
155        self.models.get(&self.default_model)
156    }
157}
158
/// Configuration for a single reranking model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankModelConfig {
    /// Full model identifier as passed to the backend.
    pub name: String,
    /// Backend that serves the model (e.g. "fastembed", "mixedbread").
    pub provider: String,
    /// Human-readable description of the model.
    pub description: String,
}
165
/// Registry mapping short aliases (e.g. "jina") to reranking model configs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankModelRegistry {
    /// Alias -> configuration for every known reranking model.
    pub models: HashMap<String, RerankModelConfig>,
    /// Alias of the reranker used when the caller does not request one.
    pub default_model: String,
}
171
172impl Default for RerankModelRegistry {
173    fn default() -> Self {
174        let mut models = HashMap::new();
175
176        models.insert(
177            "jina".to_string(),
178            RerankModelConfig {
179                name: "jina-reranker-v1-turbo-en".to_string(),
180                provider: "fastembed".to_string(),
181                description:
182                    "Jina Turbo reranker (default) tuned for English code + text relevance"
183                        .to_string(),
184            },
185        );
186
187        models.insert(
188            "bge".to_string(),
189            RerankModelConfig {
190                name: "BAAI/bge-reranker-base".to_string(),
191                provider: "fastembed".to_string(),
192                description: "BGE reranker base model for multilingual use cases".to_string(),
193            },
194        );
195
196        models.insert(
197            "mxbai".to_string(),
198            RerankModelConfig {
199                name: "mixedbread-ai/mxbai-rerank-xsmall-v1".to_string(),
200                provider: "mixedbread".to_string(),
201                description: "Mixedbread xsmall reranker (quantized) optimized for local inference"
202                    .to_string(),
203            },
204        );
205
206        Self {
207            models,
208            default_model: "jina".to_string(),
209        }
210    }
211}
212
213impl RerankModelRegistry {
214    fn format_available_models(&self) -> String {
215        self.models.keys().cloned().collect::<Vec<_>>().join(", ")
216    }
217
218    fn resolve_alias_or_name(&self, key: &str) -> Option<(String, &RerankModelConfig)> {
219        if let Some(config) = self.models.get(key) {
220            return Some((key.to_string(), config));
221        }
222
223        self.models
224            .iter()
225            .find(|(_, config)| config.name == key)
226            .map(|(alias, config)| (alias.clone(), config))
227    }
228
229    pub fn resolve(&self, requested: Option<&str>) -> Result<(String, RerankModelConfig)> {
230        match requested {
231            Some(name) => {
232                let (alias, config) = self.resolve_alias_or_name(name).ok_or_else(|| {
233                    anyhow!(
234                        "Unknown rerank model '{}'. Available models: {}",
235                        name,
236                        self.format_available_models()
237                    )
238                })?;
239                Ok((alias, config.clone()))
240            }
241            None => {
242                let alias = self.default_model.clone();
243                let config = self
244                    .models
245                    .get(&self.default_model)
246                    .cloned()
247                    .ok_or_else(|| anyhow!("No default reranking model configured"))?;
248                Ok((alias, config))
249            }
250        }
251    }
252
253    pub fn aliases(&self) -> Vec<String> {
254        let mut keys = self.models.keys().cloned().collect::<Vec<_>>();
255        keys.sort();
256        keys
257    }
258}
259
/// Per-project indexing/search settings persisted alongside the index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectConfig {
    /// Alias of the embedding model used for this project.
    pub model: String,
    /// Target chunk size (in tokens) when splitting files for embedding.
    pub chunk_size: usize,
    /// Number of tokens shared between adjacent chunks.
    pub chunk_overlap: usize,
    /// Vector index backend identifier (e.g. "hnsw").
    pub index_backend: String,
}
267
268impl Default for ProjectConfig {
269    fn default() -> Self {
270        Self {
271            model: "bge-small".to_string(),
272            chunk_size: 512,
273            chunk_overlap: 128,
274            index_backend: "hnsw".to_string(),
275        }
276    }
277}
278
279impl ProjectConfig {
280    pub fn load(path: &Path) -> Result<Self> {
281        if path.exists() {
282            let data = std::fs::read_to_string(path)?;
283            Ok(serde_json::from_str(&data)?)
284        } else {
285            Ok(Self::default())
286        }
287    }
288
289    pub fn save(&self, path: &Path) -> Result<()> {
290        let data = serde_json::to_string_pretty(self)?;
291        std::fs::write(path, data)?;
292        Ok(())
293    }
294}