1use crate::{error::ColbertError, model::ColBERT};
2use candle_core::Device;
3use hf_hub::{api::sync::Api, Repo, RepoType};
4use std::{convert::TryFrom, fs, path::PathBuf};
5
6#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
12pub struct ColbertBuilder {
13 repo_id: String,
14 query_prefix: Option<String>,
15 document_prefix: Option<String>,
16 mask_token: Option<String>,
17 attend_to_expansion_tokens: Option<bool>,
18 query_length: Option<usize>,
19 document_length: Option<usize>,
20 batch_size: Option<usize>,
21 device: Option<Device>,
22}
23
24#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
25impl ColbertBuilder {
26 pub(crate) fn new(repo_id: &str) -> Self {
28 Self {
29 repo_id: repo_id.to_string(),
30 query_prefix: None,
31 document_prefix: None,
32 mask_token: None,
33 attend_to_expansion_tokens: None,
34 query_length: None,
35 document_length: None,
36 batch_size: None,
37 device: None,
38 }
39 }
40
41 pub fn with_query_prefix(mut self, query_prefix: String) -> Self {
43 self.query_prefix = Some(query_prefix);
44 self
45 }
46
47 pub fn with_document_prefix(mut self, document_prefix: String) -> Self {
49 self.document_prefix = Some(document_prefix);
50 self
51 }
52
53 pub fn with_mask_token(mut self, mask_token: String) -> Self {
55 self.mask_token = Some(mask_token);
56 self
57 }
58
59 pub fn with_attend_to_expansion_tokens(mut self, attend: bool) -> Self {
61 self.attend_to_expansion_tokens = Some(attend);
62 self
63 }
64
65 pub fn with_query_length(mut self, query_length: usize) -> Self {
67 self.query_length = Some(query_length);
68 self
69 }
70
71 pub fn with_document_length(mut self, document_length: usize) -> Self {
73 self.document_length = Some(document_length);
74 self
75 }
76
77 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
79 self.batch_size = Some(batch_size);
80 self
81 }
82
83 pub fn with_device(mut self, device: Device) -> Self {
85 self.device = Some(device);
86 self
87 }
88}
89
90#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
91impl TryFrom<ColbertBuilder> for ColBERT {
92 type Error = ColbertError;
93
94 fn try_from(builder: ColbertBuilder) -> Result<Self, Self::Error> {
96 let device = builder.device.unwrap_or(Device::Cpu);
97
98 let local_path = PathBuf::from(&builder.repo_id);
99 let (
100 tokenizer_path,
101 weights_path,
102 config_path,
103 st_config_path,
104 dense_config_path,
105 dense_weights_path,
106 special_tokens_map_path,
107 ) = if local_path.is_dir() {
108 (
109 local_path.join("tokenizer.json"),
110 local_path.join("model.safetensors"),
111 local_path.join("config.json"),
112 local_path.join("config_sentence_transformers.json"),
113 local_path.join("1_Dense/config.json"),
114 local_path.join("1_Dense/model.safetensors"),
115 local_path.join("special_tokens_map.json"),
116 )
117 } else {
118 let api = Api::new()?;
119 let repo = api.repo(Repo::with_revision(
120 builder.repo_id.clone(),
121 RepoType::Model,
122 "main".to_string(),
123 ));
124 (
125 repo.get("tokenizer.json")?,
126 repo.get("model.safetensors")?,
127 repo.get("config.json")?,
128 repo.get("config_sentence_transformers.json")?,
129 repo.get("1_Dense/config.json")?,
130 repo.get("1_Dense/model.safetensors")?,
131 repo.get("special_tokens_map.json")?,
132 )
133 };
134
135 if local_path.is_dir() {
136 for path in [
137 &tokenizer_path,
138 &weights_path,
139 &config_path,
140 &st_config_path,
141 &dense_config_path,
142 &dense_weights_path,
143 &special_tokens_map_path,
144 ] {
145 if !path.exists() {
146 return Err(ColbertError::Io(std::io::Error::new(
147 std::io::ErrorKind::NotFound,
148 format!("File not found in local directory: {}", path.display()),
149 )));
150 }
151 }
152 }
153
154 let tokenizer_bytes = fs::read(tokenizer_path)?;
155 let weights_bytes = fs::read(weights_path)?;
156 let config_bytes = fs::read(config_path)?;
157 let st_config_bytes = fs::read(st_config_path)?;
158 let dense_config_bytes = fs::read(dense_config_path)?;
159 let dense_weights_bytes = fs::read(dense_weights_path)?;
160 let special_tokens_map_bytes = fs::read(special_tokens_map_path)?;
161
162 let st_config: serde_json::Value = serde_json::from_slice(&st_config_bytes)?;
163 let special_tokens_map: serde_json::Value =
164 serde_json::from_slice(&special_tokens_map_bytes)?;
165
166 let final_query_prefix = builder.query_prefix.unwrap_or_else(|| {
167 st_config["query_prefix"]
168 .as_str()
169 .unwrap_or("[Q]")
170 .to_string()
171 });
172 let final_document_prefix = builder.document_prefix.unwrap_or_else(|| {
173 st_config["document_prefix"]
174 .as_str()
175 .unwrap_or("[D]")
176 .to_string()
177 });
178
179 let mask_token = builder.mask_token.unwrap_or_else(|| {
180 special_tokens_map["mask_token"]
181 .as_str()
182 .unwrap_or("[MASK]")
183 .to_string()
184 });
185
186 let final_attend_to_expansion_tokens =
187 builder.attend_to_expansion_tokens.unwrap_or_else(|| {
188 st_config["attend_to_expansion_tokens"]
189 .as_bool()
190 .unwrap_or(false)
191 });
192 let final_query_length = builder
193 .query_length
194 .or_else(|| st_config["query_length"].as_u64().map(|v| v as usize));
195 let final_document_length = builder
196 .document_length
197 .or_else(|| st_config["document_length"].as_u64().map(|v| v as usize));
198
199 ColBERT::new(
200 weights_bytes,
201 dense_weights_bytes,
202 tokenizer_bytes,
203 config_bytes,
204 dense_config_bytes,
205 final_query_prefix,
206 final_document_prefix,
207 mask_token,
208 final_attend_to_expansion_tokens,
209 final_query_length,
210 final_document_length,
211 builder.batch_size,
212 &device,
213 )
214 }
215}