reflex/embedding/utils.rs

use std::io;
use std::path::Path;
use tokenizers::Tokenizer;

/// Loads a tokenizer from a model path.
///
/// Accepts an explicit `tokenizer.json` path, a model directory containing
/// `tokenizer.json`, or another model file whose parent directory contains
/// `tokenizer.json`.
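///
/// # Example
///
/// A minimal usage sketch. The model directory below is a placeholder, and the
/// import path assumes this module is exposed as `reflex::embedding::utils`:
///
/// ```no_run
/// use std::path::Path;
/// use reflex::embedding::utils::load_tokenizer;
///
/// # fn main() -> std::io::Result<()> {
/// // "models/my-embedder" is a hypothetical directory containing tokenizer.json;
/// // passing the tokenizer.json path directly would work as well.
/// let tokenizer = load_tokenizer(Path::new("models/my-embedder"))?;
/// println!("vocab size: {}", tokenizer.get_vocab_size(true));
/// # Ok(())
/// # }
/// ```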
pub fn load_tokenizer(model_path: &Path) -> io::Result<Tokenizer> {
    let tokenizer_path = if model_path
        .file_name()
        .is_some_and(|name| name == std::ffi::OsStr::new("tokenizer.json"))
    {
        // The caller passed an explicit tokenizer.json path; use it directly.
        model_path.to_path_buf()
    } else if model_path.is_dir() {
        // A model directory: expect tokenizer.json inside it.
        model_path.join("tokenizer.json")
    } else {
        // Some other model file: expect tokenizer.json next to it.
        model_path
            .parent()
            .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Model path has no parent"))?
            .join("tokenizer.json")
    };

    Tokenizer::from_file(&tokenizer_path).map_err(io::Error::other)
}

/// Loads a tokenizer with truncation enabled for a maximum sequence length.
///
/// This is important for cross-encoder models that have a fixed maximum sequence length.
/// Inputs exceeding `max_len` will be truncated to fit.
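///
/// # Example
///
/// A minimal usage sketch. The model path and the 512-token cap are placeholders,
/// and the import path assumes this module is exposed as
/// `reflex::embedding::utils`:
///
/// ```no_run
/// use std::path::Path;
/// use reflex::embedding::utils::load_tokenizer_with_truncation;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
/// // Hypothetical cross-encoder with a 512-token limit.
/// let tokenizer = load_tokenizer_with_truncation(Path::new("models/my-reranker"), 512)?;
///
/// // Query/passage pairs longer than the cap are truncated rather than rejected.
/// let encoding = tokenizer.encode(("example query", "a long candidate passage"), true)?;
/// assert!(encoding.get_ids().len() <= 512);
/// # Ok(())
/// # }
/// ```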
pub fn load_tokenizer_with_truncation(model_path: &Path, max_len: usize) -> io::Result<Tokenizer> {
    use tokenizers::TruncationParams;

    let mut tokenizer = load_tokenizer(model_path)?;

    // Cap sequences at `max_len` tokens; all other truncation settings keep their defaults.
    let truncation = TruncationParams {
        max_length: max_len,
        ..Default::default()
    };

    tokenizer
        .with_truncation(Some(truncation))
        .map_err(|e| io::Error::other(format!("Failed to configure truncation: {}", e)))?;

    Ok(tokenizer)
}