Skip to main content

docbert_pylate/
builder.rs

1use std::{
2    convert::TryFrom,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use candle_core::Device;
8use hf_hub::{Repo, RepoType, api::sync::Api};
9use serde::Deserialize;
10
11use crate::{error::ColbertError, model::ColBERT};
12
/// SentenceTransformers `pylate.models.Dense.Dense` module type marker, as it
/// appears in the `type` field of a `modules.json` entry.
const PYLATE_DENSE_TYPE: &str = "pylate.models.Dense.Dense";

/// Raw bytes for one Dense projection layer in the SentenceTransformers
/// pipeline, in the order they appear in `modules.json`.
///
/// The `Debug` output reports only byte lengths, because the weight payload
/// is typically far too large to be useful in logs.
#[derive(Clone)]
pub struct DenseModuleData {
    /// Contents of `<path>/config.json` (in_features, out_features,
    /// activation_function, bias, optional use_residual).
    pub config_bytes: Vec<u8>,
    /// Contents of `<path>/model.safetensors` (always contains
    /// `linear.weight`; also contains `residual.weight` when the module's
    /// `use_residual` is true).
    pub weights_bytes: Vec<u8>,
}

impl std::fmt::Debug for DenseModuleData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Summarise the buffers instead of printing every byte.
        f.debug_struct("DenseModuleData")
            .field(
                "config_bytes",
                &format_args!("<{} bytes>", self.config_bytes.len()),
            )
            .field(
                "weights_bytes",
                &format_args!("<{} bytes>", self.weights_bytes.len()),
            )
            .finish()
    }
}
27
28#[derive(Deserialize)]
29struct ModuleEntry {
30    path: String,
31    #[serde(rename = "type")]
32    module_type: String,
33}
34
35/// A builder for configuring and creating a `ColBERT` model from the Hugging Face Hub.
36///
37/// This struct provides an interface to set various configuration options
38/// before downloading the model files and initializing the `ColBERT` instance.
39pub struct ColbertBuilder {
40    repo_id: String,
41    query_prefix: Option<String>,
42    document_prefix: Option<String>,
43    query_prompt: Option<String>,
44    document_prompt: Option<String>,
45    mask_token: Option<String>,
46    do_query_expansion: Option<bool>,
47    attend_to_expansion_tokens: Option<bool>,
48    query_length: Option<usize>,
49    document_length: Option<usize>,
50    batch_size: Option<usize>,
51    device: Option<Device>,
52}
53
54impl ColbertBuilder {
55    /// Creates a new `ColbertBuilder`.
56    pub(crate) fn new(repo_id: &str) -> Self {
57        Self {
58            repo_id: repo_id.to_string(),
59            query_prefix: None,
60            document_prefix: None,
61            query_prompt: None,
62            document_prompt: None,
63            mask_token: None,
64            do_query_expansion: None,
65            attend_to_expansion_tokens: None,
66            query_length: None,
67            document_length: None,
68            batch_size: None,
69            device: None,
70        }
71    }
72
73    /// Sets the query prefix token. Overrides the value from the config file.
74    pub fn with_query_prefix(mut self, query_prefix: String) -> Self {
75        self.query_prefix = Some(query_prefix);
76        self
77    }
78
79    /// Sets the document prefix token. Overrides the value from the config file.
80    pub fn with_document_prefix(mut self, document_prefix: String) -> Self {
81        self.document_prefix = Some(document_prefix);
82        self
83    }
84
85    /// Sets the SentenceTransformers-style query prompt (e.g. `"search_query: "`).
86    /// Overrides the `prompts.query` field from `config_sentence_transformers.json`.
87    pub fn with_query_prompt(mut self, query_prompt: String) -> Self {
88        self.query_prompt = Some(query_prompt);
89        self
90    }
91
92    /// Sets the SentenceTransformers-style document prompt (e.g. `"search_document: "`).
93    /// Overrides the `prompts.document` field from `config_sentence_transformers.json`.
94    pub fn with_document_prompt(mut self, document_prompt: String) -> Self {
95        self.document_prompt = Some(document_prompt);
96        self
97    }
98
99    /// Sets the mask token. Overrides the value from the `special_tokens_map.json` file.
100    pub fn with_mask_token(mut self, mask_token: String) -> Self {
101        self.mask_token = Some(mask_token);
102        self
103    }
104
105    /// Sets whether to perform query expansion. Overrides the value from the config file.
106    pub fn with_do_query_expansion(mut self, do_expansion: bool) -> Self {
107        self.do_query_expansion = Some(do_expansion);
108        self
109    }
110
111    /// Sets whether to attend to expansion tokens. Overrides the value from the config file.
112    pub fn with_attend_to_expansion_tokens(mut self, attend: bool) -> Self {
113        self.attend_to_expansion_tokens = Some(attend);
114        self
115    }
116
117    /// Sets the maximum query length. Overrides the value from the config file.
118    pub fn with_query_length(mut self, query_length: usize) -> Self {
119        self.query_length = Some(query_length);
120        self
121    }
122
123    /// Sets the maximum document length. Overrides the value from the config file.
124    pub fn with_document_length(mut self, document_length: usize) -> Self {
125        self.document_length = Some(document_length);
126        self
127    }
128
129    /// Sets the batch size for encoding. Defaults to 32.
130    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
131        self.batch_size = Some(batch_size);
132        self
133    }
134
135    /// Sets the device to run the model on.
136    pub fn with_device(mut self, device: Device) -> Self {
137        self.device = Some(device);
138        self
139    }
140}
141
142/// Parses a `modules.json` payload and returns the relative paths of every
143/// `pylate.models.Dense.Dense` module, in declaration order.
144///
145/// Errors when the file isn't a JSON array of objects with `path` and `type`
146/// fields, or when no Dense modules are listed.
147pub(crate) fn discover_dense_module_paths(
148    modules_json: &[u8],
149) -> Result<Vec<String>, ColbertError> {
150    let entries: Vec<ModuleEntry> = serde_json::from_slice(modules_json)?;
151    let dense_paths: Vec<String> = entries
152        .into_iter()
153        .filter(|e| e.module_type == PYLATE_DENSE_TYPE)
154        .map(|e| e.path)
155        .collect();
156    if dense_paths.is_empty() {
157        return Err(ColbertError::Operation(
158            "modules.json declares no pylate.models.Dense.Dense modules".into(),
159        ));
160    }
161    Ok(dense_paths)
162}
163
164/// Bag of bytes the builder hands to [`ColBERT::new`].
165struct LoadedAssets {
166    tokenizer: Vec<u8>,
167    weights: Vec<u8>,
168    config: Vec<u8>,
169    st_config: Vec<u8>,
170    special_tokens_map: Vec<u8>,
171    dense_modules: Vec<DenseModuleData>,
172}
173
174impl TryFrom<ColbertBuilder> for ColBERT {
175    type Error = ColbertError;
176
177    /// Builds the `ColBERT` model by downloading files from the hub and initializing the model.
178    fn try_from(builder: ColbertBuilder) -> Result<Self, Self::Error> {
179        let device = builder.device.unwrap_or(Device::Cpu);
180
181        let local_path = PathBuf::from(&builder.repo_id);
182        let assets = if local_path.is_dir() {
183            load_local_assets(&local_path)?
184        } else {
185            load_hub_assets(&builder.repo_id)?
186        };
187
188        let st_config: serde_json::Value =
189            serde_json::from_slice(&assets.st_config)?;
190        let special_tokens_map: serde_json::Value =
191            serde_json::from_slice(&assets.special_tokens_map)?;
192
193        let final_query_prefix = builder.query_prefix.unwrap_or_else(|| {
194            st_config["query_prefix"]
195                .as_str()
196                .unwrap_or("[Q]")
197                .to_string()
198        });
199        let final_document_prefix =
200            builder.document_prefix.unwrap_or_else(|| {
201                st_config["document_prefix"]
202                    .as_str()
203                    .unwrap_or("[D]")
204                    .to_string()
205            });
206
207        let final_query_prompt = builder.query_prompt.unwrap_or_else(|| {
208            st_config["prompts"]["query"]
209                .as_str()
210                .unwrap_or("")
211                .to_string()
212        });
213        let final_document_prompt =
214            builder.document_prompt.unwrap_or_else(|| {
215                st_config["prompts"]["document"]
216                    .as_str()
217                    .unwrap_or("")
218                    .to_string()
219            });
220
221        let mask_token = builder.mask_token.unwrap_or_else(|| {
222            special_tokens_map["mask_token"]
223                .as_str()
224                .unwrap_or("[MASK]")
225                .to_string()
226        });
227
228        let final_do_query_expansion =
229            builder.do_query_expansion.unwrap_or_else(|| {
230                st_config["do_query_expansion"].as_bool().unwrap_or(true)
231            });
232
233        let final_attend_to_expansion_tokens =
234            builder.attend_to_expansion_tokens.unwrap_or_else(|| {
235                st_config["attend_to_expansion_tokens"]
236                    .as_bool()
237                    .unwrap_or(false)
238            });
239        let final_query_length = builder
240            .query_length
241            .or_else(|| st_config["query_length"].as_u64().map(|v| v as usize));
242        let final_document_length = builder.document_length.or_else(|| {
243            st_config["document_length"].as_u64().map(|v| v as usize)
244        });
245
246        ColBERT::new(
247            assets.weights,
248            assets.dense_modules,
249            assets.tokenizer,
250            assets.config,
251            final_query_prefix,
252            final_document_prefix,
253            final_query_prompt,
254            final_document_prompt,
255            mask_token,
256            final_do_query_expansion,
257            final_attend_to_expansion_tokens,
258            final_query_length,
259            final_document_length,
260            builder.batch_size,
261            &device,
262        )
263    }
264}
265
266/// Reads every required asset from a local model directory.
267fn load_local_assets(local_path: &Path) -> Result<LoadedAssets, ColbertError> {
268    let modules_path = local_path.join("modules.json");
269    if !modules_path.exists() {
270        return Err(ColbertError::Io(std::io::Error::new(
271            std::io::ErrorKind::NotFound,
272            format!(
273                "modules.json not found in local model directory: {}",
274                modules_path.display()
275            ),
276        )));
277    }
278    let modules_bytes = fs::read(&modules_path)?;
279    let dense_paths = discover_dense_module_paths(&modules_bytes)?;
280
281    let tokenizer_path = local_path.join("tokenizer.json");
282    let weights_path = local_path.join("model.safetensors");
283    let config_path = local_path.join("config.json");
284    let st_config_path = local_path.join("config_sentence_transformers.json");
285    let special_tokens_map_path = local_path.join("special_tokens_map.json");
286    for path in [
287        &tokenizer_path,
288        &weights_path,
289        &config_path,
290        &st_config_path,
291        &special_tokens_map_path,
292    ] {
293        if !path.exists() {
294            return Err(ColbertError::Io(std::io::Error::new(
295                std::io::ErrorKind::NotFound,
296                format!(
297                    "File not found in local directory: {}",
298                    path.display()
299                ),
300            )));
301        }
302    }
303
304    let mut dense_modules = Vec::with_capacity(dense_paths.len());
305    for rel_path in dense_paths {
306        let dense_dir = local_path.join(&rel_path);
307        let cfg_path = dense_dir.join("config.json");
308        let dense_weights_path = dense_dir.join("model.safetensors");
309        for path in [&cfg_path, &dense_weights_path] {
310            if !path.exists() {
311                return Err(ColbertError::Io(std::io::Error::new(
312                    std::io::ErrorKind::NotFound,
313                    format!("Dense module file not found: {}", path.display()),
314                )));
315            }
316        }
317        dense_modules.push(DenseModuleData {
318            config_bytes: fs::read(cfg_path)?,
319            weights_bytes: fs::read(dense_weights_path)?,
320        });
321    }
322
323    Ok(LoadedAssets {
324        tokenizer: fs::read(tokenizer_path)?,
325        weights: fs::read(weights_path)?,
326        config: fs::read(config_path)?,
327        st_config: fs::read(st_config_path)?,
328        special_tokens_map: fs::read(special_tokens_map_path)?,
329        dense_modules,
330    })
331}
332
333/// Downloads every required asset from the Hugging Face Hub.
334fn load_hub_assets(repo_id: &str) -> Result<LoadedAssets, ColbertError> {
335    let api = Api::new()?;
336    let repo = api.repo(Repo::with_revision(
337        repo_id.to_string(),
338        RepoType::Model,
339        "main".to_string(),
340    ));
341
342    let modules_path = repo.get("modules.json")?;
343    let modules_bytes = fs::read(&modules_path)?;
344    let dense_paths = discover_dense_module_paths(&modules_bytes)?;
345
346    let mut dense_modules = Vec::with_capacity(dense_paths.len());
347    for rel_path in dense_paths {
348        let cfg_remote = format!("{rel_path}/config.json");
349        let weights_remote = format!("{rel_path}/model.safetensors");
350        let cfg_path = repo.get(&cfg_remote)?;
351        let weights_path = repo.get(&weights_remote)?;
352        dense_modules.push(DenseModuleData {
353            config_bytes: fs::read(cfg_path)?,
354            weights_bytes: fs::read(weights_path)?,
355        });
356    }
357
358    Ok(LoadedAssets {
359        tokenizer: fs::read(repo.get("tokenizer.json")?)?,
360        weights: fs::read(repo.get("model.safetensors")?)?,
361        config: fs::read(repo.get("config.json")?)?,
362        st_config: fs::read(repo.get("config_sentence_transformers.json")?)?,
363        special_tokens_map: fs::read(repo.get("special_tokens_map.json")?)?,
364        dense_modules,
365    })
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn discovers_dense_modules_in_declaration_order() {
        let modules_json = br#"[
            {"idx":0,"name":"0","path":"","type":"sentence_transformers.models.Transformer"},
            {"idx":1,"name":"1","path":"1_Dense","type":"pylate.models.Dense.Dense"},
            {"idx":2,"name":"2","path":"2_Dense","type":"pylate.models.Dense.Dense"},
            {"idx":3,"name":"3","path":"3_Dense","type":"pylate.models.Dense.Dense"}
        ]"#;
        let paths = discover_dense_module_paths(modules_json).unwrap();
        assert_eq!(paths, vec!["1_Dense", "2_Dense", "3_Dense"]);
    }

    #[test]
    fn discovers_single_dense_module_when_only_one_listed() {
        let modules_json = br#"[
            {"idx":0,"name":"0","path":"","type":"sentence_transformers.models.Transformer"},
            {"idx":1,"name":"1","path":"1_Dense","type":"pylate.models.Dense.Dense"}
        ]"#;
        let paths = discover_dense_module_paths(modules_json).unwrap();
        assert_eq!(paths, vec!["1_Dense"]);
    }

    #[test]
    fn errors_when_modules_json_has_no_dense_modules() {
        let modules_json = br#"[
            {"idx":0,"name":"0","path":"","type":"sentence_transformers.models.Transformer"}
        ]"#;
        let err = discover_dense_module_paths(modules_json).unwrap_err();
        assert!(matches!(err, ColbertError::Operation(_)));
    }

    #[test]
    fn ignores_non_dense_module_entries() {
        let modules_json = br#"[
            {"idx":0,"name":"0","path":"","type":"sentence_transformers.models.Transformer"},
            {"idx":1,"name":"1","path":"1_Dense","type":"pylate.models.Dense.Dense"},
            {"idx":2,"name":"pool","path":"2_Pooling","type":"sentence_transformers.models.Pooling"},
            {"idx":3,"name":"3","path":"3_Dense","type":"pylate.models.Dense.Dense"}
        ]"#;
        let paths = discover_dense_module_paths(modules_json).unwrap();
        assert_eq!(paths, vec!["1_Dense", "3_Dense"]);
    }

    #[test]
    fn errors_when_modules_json_is_empty_array() {
        let err = discover_dense_module_paths(b"[]").unwrap_err();
        assert!(matches!(err, ColbertError::Operation(_)));
    }

    #[test]
    fn errors_when_modules_json_is_not_an_array() {
        let modules_json =
            br#"{"path":"1_Dense","type":"pylate.models.Dense.Dense"}"#;
        assert!(discover_dense_module_paths(modules_json).is_err());
    }

    #[test]
    fn errors_when_entry_lacks_path_field() {
        let modules_json =
            br#"[{"idx":1,"name":"1","type":"pylate.models.Dense.Dense"}]"#;
        assert!(discover_dense_module_paths(modules_json).is_err());
    }

    #[test]
    fn matches_dense_type_exactly() {
        // Near-miss type strings (prefix match, different case) must not count.
        let modules_json = br#"[
            {"path":"a","type":"pylate.models.Dense.DenseExtra"},
            {"path":"b","type":"Pylate.models.Dense.Dense"},
            {"path":"c","type":"pylate.models.Dense.Dense"}
        ]"#;
        let paths = discover_dense_module_paths(modules_json).unwrap();
        assert_eq!(paths, vec!["c"]);
    }
}