1use std::{
2 convert::TryFrom,
3 fs,
4 path::{Path, PathBuf},
5};
6
7use candle_core::Device;
8use hf_hub::{Repo, RepoType, api::sync::Api};
9use serde::Deserialize;
10
11use crate::{error::ColbertError, model::ColBERT};
12
/// Value of the `type` field in `modules.json` that identifies a dense module entry.
///
/// `discover_dense_module_paths` keeps only the entries whose type equals this string.
const PYLATE_DENSE_TYPE: &str = "pylate.models.Dense.Dense";
15
/// Raw bytes of a single dense module: its configuration file and its weights file.
///
/// The loaders in this file fill `config_bytes` from the module's `config.json` and
/// `weights_bytes` from the module's `model.safetensors`.
///
/// `Debug` is implemented by hand so that formatting a value reports only the buffer
/// sizes instead of printing every byte of the weights.
#[derive(Clone)]
pub struct DenseModuleData {
    /// Contents of the module's `config.json`.
    pub config_bytes: Vec<u8>,
    /// Contents of the module's `model.safetensors`.
    pub weights_bytes: Vec<u8>,
}

impl std::fmt::Debug for DenseModuleData {
    /// Formats the struct with each buffer summarized as `<N bytes>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DenseModuleData")
            .field(
                "config_bytes",
                &format_args!("<{} bytes>", self.config_bytes.len()),
            )
            .field(
                "weights_bytes",
                &format_args!("<{} bytes>", self.weights_bytes.len()),
            )
            .finish()
    }
}
27
/// One element of the `modules.json` array, reduced to the two fields this file reads.
#[derive(Deserialize)]
struct ModuleEntry {
    /// Module location: joined onto the model directory for local loading, or used as
    /// the leading component of the remote file name for hub loading.
    path: String,
    /// The entry's JSON `type` field (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    module_type: String,
}
34
/// Collects the settings used to construct a `ColBERT` model.
///
/// Every optional field left as `None` is resolved during the `TryFrom` conversion,
/// either from the model's JSON configuration files or from a hard-coded default.
pub struct ColbertBuilder {
    /// Local directory path, or hub repository id when no such directory exists.
    repo_id: String,
    /// Query prefix override; fallback is `query_prefix` from the config, then `"[Q]"`.
    query_prefix: Option<String>,
    /// Document prefix override; fallback is `document_prefix` from the config, then `"[D]"`.
    document_prefix: Option<String>,
    /// Query prompt override; fallback is `prompts.query` from the config, then `""`.
    query_prompt: Option<String>,
    /// Document prompt override; fallback is `prompts.document` from the config, then `""`.
    document_prompt: Option<String>,
    /// Mask token override; fallback is `mask_token` from `special_tokens_map.json`,
    /// then `"[MASK]"`.
    mask_token: Option<String>,
    /// Query expansion override; fallback is `do_query_expansion` from the config, then `true`.
    do_query_expansion: Option<bool>,
    /// Override for attending to expansion tokens; fallback is
    /// `attend_to_expansion_tokens` from the config, then `false`.
    attend_to_expansion_tokens: Option<bool>,
    /// Query length override; fallback is `query_length` from the config, else left unset.
    query_length: Option<usize>,
    /// Document length override; fallback is `document_length` from the config, else left unset.
    document_length: Option<usize>,
    /// Batch size, forwarded unchanged to `ColBERT::new`.
    batch_size: Option<usize>,
    /// Device to load the model on; `Device::Cpu` is used when unset.
    device: Option<Device>,
}
53
54impl ColbertBuilder {
55 pub(crate) fn new(repo_id: &str) -> Self {
57 Self {
58 repo_id: repo_id.to_string(),
59 query_prefix: None,
60 document_prefix: None,
61 query_prompt: None,
62 document_prompt: None,
63 mask_token: None,
64 do_query_expansion: None,
65 attend_to_expansion_tokens: None,
66 query_length: None,
67 document_length: None,
68 batch_size: None,
69 device: None,
70 }
71 }
72
73 pub fn with_query_prefix(mut self, query_prefix: String) -> Self {
75 self.query_prefix = Some(query_prefix);
76 self
77 }
78
79 pub fn with_document_prefix(mut self, document_prefix: String) -> Self {
81 self.document_prefix = Some(document_prefix);
82 self
83 }
84
85 pub fn with_query_prompt(mut self, query_prompt: String) -> Self {
88 self.query_prompt = Some(query_prompt);
89 self
90 }
91
92 pub fn with_document_prompt(mut self, document_prompt: String) -> Self {
95 self.document_prompt = Some(document_prompt);
96 self
97 }
98
99 pub fn with_mask_token(mut self, mask_token: String) -> Self {
101 self.mask_token = Some(mask_token);
102 self
103 }
104
105 pub fn with_do_query_expansion(mut self, do_expansion: bool) -> Self {
107 self.do_query_expansion = Some(do_expansion);
108 self
109 }
110
111 pub fn with_attend_to_expansion_tokens(mut self, attend: bool) -> Self {
113 self.attend_to_expansion_tokens = Some(attend);
114 self
115 }
116
117 pub fn with_query_length(mut self, query_length: usize) -> Self {
119 self.query_length = Some(query_length);
120 self
121 }
122
123 pub fn with_document_length(mut self, document_length: usize) -> Self {
125 self.document_length = Some(document_length);
126 self
127 }
128
129 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
131 self.batch_size = Some(batch_size);
132 self
133 }
134
135 pub fn with_device(mut self, device: Device) -> Self {
137 self.device = Some(device);
138 self
139 }
140}
141
142pub(crate) fn discover_dense_module_paths(
148 modules_json: &[u8],
149) -> Result<Vec<String>, ColbertError> {
150 let entries: Vec<ModuleEntry> = serde_json::from_slice(modules_json)?;
151 let dense_paths: Vec<String> = entries
152 .into_iter()
153 .filter(|e| e.module_type == PYLATE_DENSE_TYPE)
154 .map(|e| e.path)
155 .collect();
156 if dense_paths.is_empty() {
157 return Err(ColbertError::Operation(
158 "modules.json declares no pylate.models.Dense.Dense modules".into(),
159 ));
160 }
161 Ok(dense_paths)
162}
163
/// Raw file contents gathered by `load_local_assets` or `load_hub_assets`.
struct LoadedAssets {
    /// Contents of `tokenizer.json`.
    tokenizer: Vec<u8>,
    /// Contents of the top-level `model.safetensors`.
    weights: Vec<u8>,
    /// Contents of the top-level `config.json`.
    config: Vec<u8>,
    /// Contents of `config_sentence_transformers.json`.
    st_config: Vec<u8>,
    /// Contents of `special_tokens_map.json`.
    special_tokens_map: Vec<u8>,
    /// One entry per dense module, in the order returned by `discover_dense_module_paths`.
    dense_modules: Vec<DenseModuleData>,
}
173
174impl TryFrom<ColbertBuilder> for ColBERT {
175 type Error = ColbertError;
176
177 fn try_from(builder: ColbertBuilder) -> Result<Self, Self::Error> {
179 let device = builder.device.unwrap_or(Device::Cpu);
180
181 let local_path = PathBuf::from(&builder.repo_id);
182 let assets = if local_path.is_dir() {
183 load_local_assets(&local_path)?
184 } else {
185 load_hub_assets(&builder.repo_id)?
186 };
187
188 let st_config: serde_json::Value =
189 serde_json::from_slice(&assets.st_config)?;
190 let special_tokens_map: serde_json::Value =
191 serde_json::from_slice(&assets.special_tokens_map)?;
192
193 let final_query_prefix = builder.query_prefix.unwrap_or_else(|| {
194 st_config["query_prefix"]
195 .as_str()
196 .unwrap_or("[Q]")
197 .to_string()
198 });
199 let final_document_prefix =
200 builder.document_prefix.unwrap_or_else(|| {
201 st_config["document_prefix"]
202 .as_str()
203 .unwrap_or("[D]")
204 .to_string()
205 });
206
207 let final_query_prompt = builder.query_prompt.unwrap_or_else(|| {
208 st_config["prompts"]["query"]
209 .as_str()
210 .unwrap_or("")
211 .to_string()
212 });
213 let final_document_prompt =
214 builder.document_prompt.unwrap_or_else(|| {
215 st_config["prompts"]["document"]
216 .as_str()
217 .unwrap_or("")
218 .to_string()
219 });
220
221 let mask_token = builder.mask_token.unwrap_or_else(|| {
222 special_tokens_map["mask_token"]
223 .as_str()
224 .unwrap_or("[MASK]")
225 .to_string()
226 });
227
228 let final_do_query_expansion =
229 builder.do_query_expansion.unwrap_or_else(|| {
230 st_config["do_query_expansion"].as_bool().unwrap_or(true)
231 });
232
233 let final_attend_to_expansion_tokens =
234 builder.attend_to_expansion_tokens.unwrap_or_else(|| {
235 st_config["attend_to_expansion_tokens"]
236 .as_bool()
237 .unwrap_or(false)
238 });
239 let final_query_length = builder
240 .query_length
241 .or_else(|| st_config["query_length"].as_u64().map(|v| v as usize));
242 let final_document_length = builder.document_length.or_else(|| {
243 st_config["document_length"].as_u64().map(|v| v as usize)
244 });
245
246 ColBERT::new(
247 assets.weights,
248 assets.dense_modules,
249 assets.tokenizer,
250 assets.config,
251 final_query_prefix,
252 final_document_prefix,
253 final_query_prompt,
254 final_document_prompt,
255 mask_token,
256 final_do_query_expansion,
257 final_attend_to_expansion_tokens,
258 final_query_length,
259 final_document_length,
260 builder.batch_size,
261 &device,
262 )
263 }
264}
265
266fn load_local_assets(local_path: &Path) -> Result<LoadedAssets, ColbertError> {
268 let modules_path = local_path.join("modules.json");
269 if !modules_path.exists() {
270 return Err(ColbertError::Io(std::io::Error::new(
271 std::io::ErrorKind::NotFound,
272 format!(
273 "modules.json not found in local model directory: {}",
274 modules_path.display()
275 ),
276 )));
277 }
278 let modules_bytes = fs::read(&modules_path)?;
279 let dense_paths = discover_dense_module_paths(&modules_bytes)?;
280
281 let tokenizer_path = local_path.join("tokenizer.json");
282 let weights_path = local_path.join("model.safetensors");
283 let config_path = local_path.join("config.json");
284 let st_config_path = local_path.join("config_sentence_transformers.json");
285 let special_tokens_map_path = local_path.join("special_tokens_map.json");
286 for path in [
287 &tokenizer_path,
288 &weights_path,
289 &config_path,
290 &st_config_path,
291 &special_tokens_map_path,
292 ] {
293 if !path.exists() {
294 return Err(ColbertError::Io(std::io::Error::new(
295 std::io::ErrorKind::NotFound,
296 format!(
297 "File not found in local directory: {}",
298 path.display()
299 ),
300 )));
301 }
302 }
303
304 let mut dense_modules = Vec::with_capacity(dense_paths.len());
305 for rel_path in dense_paths {
306 let dense_dir = local_path.join(&rel_path);
307 let cfg_path = dense_dir.join("config.json");
308 let dense_weights_path = dense_dir.join("model.safetensors");
309 for path in [&cfg_path, &dense_weights_path] {
310 if !path.exists() {
311 return Err(ColbertError::Io(std::io::Error::new(
312 std::io::ErrorKind::NotFound,
313 format!("Dense module file not found: {}", path.display()),
314 )));
315 }
316 }
317 dense_modules.push(DenseModuleData {
318 config_bytes: fs::read(cfg_path)?,
319 weights_bytes: fs::read(dense_weights_path)?,
320 });
321 }
322
323 Ok(LoadedAssets {
324 tokenizer: fs::read(tokenizer_path)?,
325 weights: fs::read(weights_path)?,
326 config: fs::read(config_path)?,
327 st_config: fs::read(st_config_path)?,
328 special_tokens_map: fs::read(special_tokens_map_path)?,
329 dense_modules,
330 })
331}
332
333fn load_hub_assets(repo_id: &str) -> Result<LoadedAssets, ColbertError> {
335 let api = Api::new()?;
336 let repo = api.repo(Repo::with_revision(
337 repo_id.to_string(),
338 RepoType::Model,
339 "main".to_string(),
340 ));
341
342 let modules_path = repo.get("modules.json")?;
343 let modules_bytes = fs::read(&modules_path)?;
344 let dense_paths = discover_dense_module_paths(&modules_bytes)?;
345
346 let mut dense_modules = Vec::with_capacity(dense_paths.len());
347 for rel_path in dense_paths {
348 let cfg_remote = format!("{rel_path}/config.json");
349 let weights_remote = format!("{rel_path}/model.safetensors");
350 let cfg_path = repo.get(&cfg_remote)?;
351 let weights_path = repo.get(&weights_remote)?;
352 dense_modules.push(DenseModuleData {
353 config_bytes: fs::read(cfg_path)?,
354 weights_bytes: fs::read(weights_path)?,
355 });
356 }
357
358 Ok(LoadedAssets {
359 tokenizer: fs::read(repo.get("tokenizer.json")?)?,
360 weights: fs::read(repo.get("model.safetensors")?)?,
361 config: fs::read(repo.get("config.json")?)?,
362 st_config: fs::read(repo.get("config_sentence_transformers.json")?)?,
363 special_tokens_map: fs::read(repo.get("special_tokens_map.json")?)?,
364 dense_modules,
365 })
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs discovery on `json` and unwraps the resulting list of paths.
    fn dense_paths(json: &[u8]) -> Vec<String> {
        discover_dense_module_paths(json).unwrap()
    }

    /// Dense entries come back in the order they are declared.
    #[test]
    fn discovers_dense_modules_in_declaration_order() {
        let found = dense_paths(
            br#"[
            {"idx":0,"name":"0","path":"","type":"sentence_transformers.models.Transformer"},
            {"idx":1,"name":"1","path":"1_Dense","type":"pylate.models.Dense.Dense"},
            {"idx":2,"name":"2","path":"2_Dense","type":"pylate.models.Dense.Dense"},
            {"idx":3,"name":"3","path":"3_Dense","type":"pylate.models.Dense.Dense"}
            ]"#,
        );
        assert_eq!(found, ["1_Dense", "2_Dense", "3_Dense"]);
    }

    /// A single dense entry yields a single path.
    #[test]
    fn discovers_single_dense_module_when_only_one_listed() {
        let found = dense_paths(
            br#"[
            {"idx":0,"name":"0","path":"","type":"sentence_transformers.models.Transformer"},
            {"idx":1,"name":"1","path":"1_Dense","type":"pylate.models.Dense.Dense"}
            ]"#,
        );
        assert_eq!(found, ["1_Dense"]);
    }

    /// With no dense entry at all, discovery fails with an `Operation` error.
    #[test]
    fn errors_when_modules_json_has_no_dense_modules() {
        let outcome = discover_dense_module_paths(
            br#"[
            {"idx":0,"name":"0","path":"","type":"sentence_transformers.models.Transformer"}
            ]"#,
        );
        assert!(matches!(outcome, Err(ColbertError::Operation(_))));
    }

    /// Entries of other module types are skipped without disturbing the order.
    #[test]
    fn ignores_non_dense_module_entries() {
        let found = dense_paths(
            br#"[
            {"idx":0,"name":"0","path":"","type":"sentence_transformers.models.Transformer"},
            {"idx":1,"name":"1","path":"1_Dense","type":"pylate.models.Dense.Dense"},
            {"idx":2,"name":"pool","path":"2_Pooling","type":"sentence_transformers.models.Pooling"},
            {"idx":3,"name":"3","path":"3_Dense","type":"pylate.models.Dense.Dense"}
            ]"#,
        );
        assert_eq!(found, ["1_Dense", "3_Dense"]);
    }
}