pylate_rs/
builder.rs

1use crate::{error::ColbertError, model::ColBERT};
2use candle_core::Device;
3use hf_hub::{api::sync::Api, Repo, RepoType};
4use std::{convert::TryFrom, fs, path::PathBuf};
5
/// A builder for configuring and creating a `ColBERT` model.
///
/// The builder collects optional overrides; the actual file loading and model
/// construction happen in `ColBERT::try_from(builder)`. Every field except
/// `repo_id` is optional, and an unset field falls back to a value read from
/// the model's configuration files or to a hard-coded default (see the
/// per-field notes below).
///
/// This type is compiled out when `target_arch` is `wasm32` and `target_os`
/// is `unknown` (see the `cfg` attribute below).
/// NOTE(review): this comment previously said the type is only available with
/// the `hf-hub` feature enabled; no feature gate is visible on this item —
/// confirm whether one exists at the module declaration.
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub struct ColbertBuilder {
    /// Hub repository id, or a path to a local directory: `try_from` treats
    /// this value as a local model directory when it names an existing one.
    repo_id: String,
    /// Query prefix override; otherwise `query_prefix` from
    /// `config_sentence_transformers.json`, then `"[Q]"`.
    query_prefix: Option<String>,
    /// Document prefix override; otherwise `document_prefix` from
    /// `config_sentence_transformers.json`, then `"[D]"`.
    document_prefix: Option<String>,
    /// Mask token override; otherwise `mask_token` from
    /// `special_tokens_map.json`, then `"[MASK]"`.
    mask_token: Option<String>,
    /// Override for `attend_to_expansion_tokens`; otherwise the value from
    /// `config_sentence_transformers.json`, then `false`.
    attend_to_expansion_tokens: Option<bool>,
    /// Query length override; otherwise `query_length` from
    /// `config_sentence_transformers.json`, if present.
    query_length: Option<usize>,
    /// Document length override; otherwise `document_length` from
    /// `config_sentence_transformers.json`, if present.
    document_length: Option<usize>,
    /// Batch size, forwarded unchanged (possibly `None`) to `ColBERT::new`.
    batch_size: Option<usize>,
    /// Device to load the model on; `Device::Cpu` when unset.
    device: Option<Device>,
}
23
24#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
25impl ColbertBuilder {
26    /// Creates a new `ColbertBuilder`.
27    pub(crate) fn new(repo_id: &str) -> Self {
28        Self {
29            repo_id: repo_id.to_string(),
30            query_prefix: None,
31            document_prefix: None,
32            mask_token: None,
33            attend_to_expansion_tokens: None,
34            query_length: None,
35            document_length: None,
36            batch_size: None,
37            device: None,
38        }
39    }
40
41    /// Sets the query prefix token. Overrides the value from the config file.
42    pub fn with_query_prefix(mut self, query_prefix: String) -> Self {
43        self.query_prefix = Some(query_prefix);
44        self
45    }
46
47    /// Sets the document prefix token. Overrides the value from the config file.
48    pub fn with_document_prefix(mut self, document_prefix: String) -> Self {
49        self.document_prefix = Some(document_prefix);
50        self
51    }
52
53    /// Sets the mask token. Overrides the value from the `special_tokens_map.json` file.
54    pub fn with_mask_token(mut self, mask_token: String) -> Self {
55        self.mask_token = Some(mask_token);
56        self
57    }
58
59    /// Sets whether to attend to expansion tokens. Overrides the value from the config file.
60    pub fn with_attend_to_expansion_tokens(mut self, attend: bool) -> Self {
61        self.attend_to_expansion_tokens = Some(attend);
62        self
63    }
64
65    /// Sets the maximum query length. Overrides the value from the config file.
66    pub fn with_query_length(mut self, query_length: usize) -> Self {
67        self.query_length = Some(query_length);
68        self
69    }
70
71    /// Sets the maximum document length. Overrides the value from the config file.
72    pub fn with_document_length(mut self, document_length: usize) -> Self {
73        self.document_length = Some(document_length);
74        self
75    }
76
77    /// Sets the batch size for encoding. Defaults to 32.
78    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
79        self.batch_size = Some(batch_size);
80        self
81    }
82
83    /// Sets the device to run the model on.
84    pub fn with_device(mut self, device: Device) -> Self {
85        self.device = Some(device);
86        self
87    }
88}
89
90#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
91impl TryFrom<ColbertBuilder> for ColBERT {
92    type Error = ColbertError;
93
94    /// Builds the `ColBERT` model by downloading files from the hub and initializing the model.
95    fn try_from(builder: ColbertBuilder) -> Result<Self, Self::Error> {
96        let device = builder.device.unwrap_or(Device::Cpu);
97
98        let local_path = PathBuf::from(&builder.repo_id);
99        let (
100            tokenizer_path,
101            weights_path,
102            config_path,
103            st_config_path,
104            dense_config_path,
105            dense_weights_path,
106            special_tokens_map_path,
107        ) = if local_path.is_dir() {
108            (
109                local_path.join("tokenizer.json"),
110                local_path.join("model.safetensors"),
111                local_path.join("config.json"),
112                local_path.join("config_sentence_transformers.json"),
113                local_path.join("1_Dense/config.json"),
114                local_path.join("1_Dense/model.safetensors"),
115                local_path.join("special_tokens_map.json"),
116            )
117        } else {
118            let api = Api::new()?;
119            let repo = api.repo(Repo::with_revision(
120                builder.repo_id.clone(),
121                RepoType::Model,
122                "main".to_string(),
123            ));
124            (
125                repo.get("tokenizer.json")?,
126                repo.get("model.safetensors")?,
127                repo.get("config.json")?,
128                repo.get("config_sentence_transformers.json")?,
129                repo.get("1_Dense/config.json")?,
130                repo.get("1_Dense/model.safetensors")?,
131                repo.get("special_tokens_map.json")?,
132            )
133        };
134
135        if local_path.is_dir() {
136            for path in [
137                &tokenizer_path,
138                &weights_path,
139                &config_path,
140                &st_config_path,
141                &dense_config_path,
142                &dense_weights_path,
143                &special_tokens_map_path,
144            ] {
145                if !path.exists() {
146                    return Err(ColbertError::Io(std::io::Error::new(
147                        std::io::ErrorKind::NotFound,
148                        format!("File not found in local directory: {}", path.display()),
149                    )));
150                }
151            }
152        }
153
154        let tokenizer_bytes = fs::read(tokenizer_path)?;
155        let weights_bytes = fs::read(weights_path)?;
156        let config_bytes = fs::read(config_path)?;
157        let st_config_bytes = fs::read(st_config_path)?;
158        let dense_config_bytes = fs::read(dense_config_path)?;
159        let dense_weights_bytes = fs::read(dense_weights_path)?;
160        let special_tokens_map_bytes = fs::read(special_tokens_map_path)?;
161
162        let st_config: serde_json::Value = serde_json::from_slice(&st_config_bytes)?;
163        let special_tokens_map: serde_json::Value =
164            serde_json::from_slice(&special_tokens_map_bytes)?;
165
166        let final_query_prefix = builder.query_prefix.unwrap_or_else(|| {
167            st_config["query_prefix"]
168                .as_str()
169                .unwrap_or("[Q]")
170                .to_string()
171        });
172        let final_document_prefix = builder.document_prefix.unwrap_or_else(|| {
173            st_config["document_prefix"]
174                .as_str()
175                .unwrap_or("[D]")
176                .to_string()
177        });
178
179        let mask_token = builder.mask_token.unwrap_or_else(|| {
180            special_tokens_map["mask_token"]
181                .as_str()
182                .unwrap_or("[MASK]")
183                .to_string()
184        });
185
186        let final_attend_to_expansion_tokens =
187            builder.attend_to_expansion_tokens.unwrap_or_else(|| {
188                st_config["attend_to_expansion_tokens"]
189                    .as_bool()
190                    .unwrap_or(false)
191            });
192        let final_query_length = builder
193            .query_length
194            .or_else(|| st_config["query_length"].as_u64().map(|v| v as usize));
195        let final_document_length = builder
196            .document_length
197            .or_else(|| st_config["document_length"].as_u64().map(|v| v as usize));
198
199        ColBERT::new(
200            weights_bytes,
201            dense_weights_bytes,
202            tokenizer_bytes,
203            config_bytes,
204            dense_config_bytes,
205            final_query_prefix,
206            final_document_prefix,
207            mask_token,
208            final_attend_to_expansion_tokens,
209            final_query_length,
210            final_document_length,
211            builder.batch_size,
212            &device,
213        )
214    }
215}