reflex/embedding/
utils.rs1use std::io;
2use std::path::Path;
3use tokenizers::Tokenizer;
4
5pub fn load_tokenizer(model_path: &Path) -> io::Result<Tokenizer> {
7 let tokenizer_path = if model_path
8 .file_name()
9 .is_some_and(|name| name == std::ffi::OsStr::new("tokenizer.json"))
10 {
11 model_path.to_path_buf()
12 } else if model_path.is_dir() {
13 model_path.join("tokenizer.json")
14 } else {
15 model_path
16 .parent()
17 .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Model path has no parent"))?
18 .join("tokenizer.json")
19 };
20
21 Tokenizer::from_file(&tokenizer_path).map_err(io::Error::other)
22}
23
24pub fn load_tokenizer_with_truncation(model_path: &Path, max_len: usize) -> io::Result<Tokenizer> {
29 use tokenizers::TruncationParams;
30
31 let mut tokenizer = load_tokenizer(model_path)?;
32
33 let truncation = TruncationParams {
34 max_length: max_len,
35 ..Default::default()
36 };
37
38 tokenizer
39 .with_truncation(Some(truncation))
40 .map_err(|e| io::Error::other(format!("Failed to configure truncation: {}", e)))?;
41
42 Ok(tokenizer)
43}