Skip to main content

ripvec_core/
tokenize.rs

1//! `HuggingFace` tokenizer wrapper.
2//!
3//! Downloads and caches the tokenizer.json from a `HuggingFace` model
4//! repository using hf-hub, then loads it for fast encoding.
5
6use hf_hub::api::sync::Api;
7use tokenizers::Tokenizer;
8
9/// Load a tokenizer from a `HuggingFace` model repository.
10///
11/// Downloads `tokenizer.json` on first call; subsequent calls use the cache.
12///
13/// # Errors
14///
15/// Returns an error if the tokenizer file cannot be downloaded or parsed.
16pub fn load_tokenizer(model_repo: &str) -> crate::Result<Tokenizer> {
17    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
18    let repo = api.model(model_repo.to_string());
19    let tokenizer_path = repo
20        .get("tokenizer.json")
21        .map_err(|e| crate::Error::Download(e.to_string()))?;
22    Tokenizer::from_file(tokenizer_path).map_err(|e| crate::Error::Tokenization(e.to_string()))
23}
24
25/// Tokenize a query string for embedding, truncating to `model_max_tokens`.
26///
27/// Returns an [`crate::backend::Encoding`] with `input_ids`, `attention_mask`,
28/// and `token_type_ids` cast to `i64`, ready for ONNX inference.
29///
30/// # Errors
31///
32/// Returns an error if the tokenizer fails to encode the text.
33pub fn tokenize_query(
34    text: &str,
35    tokenizer: &tokenizers::Tokenizer,
36    model_max_tokens: usize,
37) -> crate::Result<crate::backend::Encoding> {
38    let encoding = tokenizer
39        .encode(text, true)
40        .map_err(|e| crate::Error::Tokenization(e.to_string()))?;
41
42    let len = encoding.get_ids().len().min(model_max_tokens);
43    Ok(crate::backend::Encoding {
44        input_ids: encoding.get_ids()[..len]
45            .iter()
46            .map(|&x| i64::from(x))
47            .collect(),
48        attention_mask: encoding.get_attention_mask()[..len]
49            .iter()
50            .map(|&x| i64::from(x))
51            .collect(),
52        token_type_ids: encoding.get_type_ids()[..len]
53            .iter()
54            .map(|&x| i64::from(x))
55            .collect(),
56    })
57}