Skip to main content

argyph_embed/
tokenize.rs

1use std::path::Path;
2
3use tokenizers::utils::padding::pad_encodings;
4use tokenizers::{PaddingParams, Tokenizer, TruncationDirection};
5
6use crate::error::{EmbedError, Result};
7
8pub struct BertTokenizer {
9    inner: Tokenizer,
10}
11
/// A batch of tokenized texts flattened into dense, row-major buffers.
///
/// As produced by `BertTokenizer::encode_batch`, both buffers hold
/// `batch_size * seq_len` entries: the values for text `b`, position `s`
/// live at index `b * seq_len + s`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenizedBatch {
    /// Token ids, one row of `seq_len` entries per input text.
    pub input_ids: Vec<i64>,
    /// Attention mask laid out exactly like `input_ids`; entries that are
    /// zero mark positions to ignore (e.g. padding).
    pub attention_mask: Vec<i64>,
    /// Number of positions in every row (length of the longest encoding in
    /// the batch after truncation and padding).
    pub seq_len: usize,
}
17
18impl BertTokenizer {
19    pub fn from_file(path: &Path) -> Result<Self> {
20        let inner = Tokenizer::from_file(path).map_err(|e| {
21            EmbedError::Config(format!(
22                "failed to load tokenizer from {}: {e}",
23                path.display()
24            ))
25        })?;
26        Ok(Self { inner })
27    }
28
29    pub fn encode_batch(&self, texts: &[String], max_len: usize) -> Result<TokenizedBatch> {
30        if texts.is_empty() {
31            return Err(EmbedError::EmptyInput);
32        }
33
34        let mut encodings = self
35            .inner
36            .encode_batch(texts.to_vec(), true)
37            .map_err(|e| EmbedError::Config(format!("tokenization failed: {e}")))?;
38
39        for enc in &mut encodings {
40            if enc.len() > max_len {
41                enc.truncate(max_len, 0, TruncationDirection::Right);
42            }
43        }
44
45        if !encodings.is_empty() {
46            let pad = PaddingParams::default();
47            pad_encodings(&mut encodings, &pad)
48                .map_err(|e| EmbedError::Config(format!("padding failed: {e}")))?;
49        }
50
51        let seq_len = encodings.iter().map(|e| e.len()).max().unwrap_or(0);
52        let batch_size = texts.len();
53        let mut input_ids = Vec::with_capacity(batch_size * seq_len);
54        let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
55
56        for enc in &encodings {
57            let ids = enc.get_ids();
58            let mask = enc.get_attention_mask();
59            for i in 0..seq_len {
60                input_ids.push(*ids.get(i).unwrap_or(&0) as i64);
61                attention_mask.push(*mask.get(i).unwrap_or(&0) as i64);
62            }
63        }
64
65        Ok(TokenizedBatch {
66            input_ids,
67            attention_mask,
68            seq_len,
69        })
70    }
71
72    pub fn mean_pool(
73        last_hidden_state: &[f32],
74        attention_mask: &[i64],
75        batch_size: usize,
76        seq_len: usize,
77        hidden_size: usize,
78    ) -> Vec<Vec<f32>> {
79        let mut result = vec![vec![0.0f32; hidden_size]; batch_size];
80
81        for (b, out_vec) in result.iter_mut().enumerate() {
82            let mut sum = vec![0.0f32; hidden_size];
83            let mut count = 0.0f32;
84
85            for s in 0..seq_len {
86                if attention_mask[b * seq_len + s] != 0 {
87                    let offset = (b * seq_len + s) * hidden_size;
88                    for (h, sum_val) in sum.iter_mut().take(hidden_size).enumerate() {
89                        *sum_val += last_hidden_state[offset + h];
90                    }
91                    count += 1.0;
92                }
93            }
94
95            if count > 0.0 {
96                for (h, out_val) in out_vec.iter_mut().enumerate() {
97                    *out_val = sum[h] / count;
98                }
99            }
100        }
101
102        for out_vec in &mut result {
103            let norm: f32 = out_vec.iter().map(|x| x * x).sum::<f32>().sqrt();
104            if norm > 0.0 {
105                for val in out_vec.iter_mut() {
106                    *val /= norm;
107                }
108            }
109        }
110
111        result
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of an embedding vector.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Asserts element-wise equality within a small absolute tolerance.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-5, "index {i}: expected {e}, got {a}");
        }
    }

    #[test]
    fn mean_pool_l2_normalized() {
        // 1 sequence x 4 tokens x 3 hidden dims; token 1 is masked out.
        let last_hidden = vec![
            1.0_f32, 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        ];
        let mask = vec![1_i64, 0, 1, 1];
        let result = BertTokenizer::mean_pool(&last_hidden, &mask, 1, 4, 3);

        assert_eq!(result.len(), 1);
        let norm = l2_norm(&result[0]);
        assert!(
            (norm - 1.0).abs() < 0.001,
            "expected L2 norm approx 1.0, got {norm}"
        );

        // Unmasked tokens sum to [3, 4, 5]; normalizing the mean gives that
        // direction scaled to unit length.
        let scale = 50.0_f32.sqrt();
        assert_close(&result[0], &[3.0 / scale, 4.0 / scale, 5.0 / scale]);
    }

    #[test]
    fn mean_pool_multi_batch() {
        // 2 sequences x 3 tokens x 3 hidden dims.
        let last_hidden = vec![
            1.0_f32, 0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 3.0,
            0.0, 0.0,
        ];
        // One mask entry per (sequence, token): 2 x 3 = 6 values.
        let mask = vec![1_i64, 1, 1, 1, 0, 0];
        let result = BertTokenizer::mean_pool(&last_hidden, &mask, 2, 3, 3);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 3);
        assert_eq!(result[1].len(), 3);

        assert!((l2_norm(&result[0]) - 1.0).abs() < 0.001);
        assert!((l2_norm(&result[1]) - 1.0).abs() < 0.001);

        // Only the first hidden dimension is populated in both sequences.
        assert_close(&result[0], &[1.0, 0.0, 0.0]);
        assert_close(&result[1], &[1.0, 0.0, 0.0]);
    }

    #[test]
    fn mean_pool_fully_masked_row_is_zero() {
        let last_hidden = vec![5.0_f32, 6.0, 7.0, 8.0];
        let mask = vec![0_i64, 0];
        let result = BertTokenizer::mean_pool(&last_hidden, &mask, 1, 2, 2);

        assert_eq!(result, vec![vec![0.0_f32, 0.0]]);
    }

    #[test]
    fn mean_pool_empty_safe() {
        let last_hidden: Vec<f32> = Vec::new();
        let mask: Vec<i64> = Vec::new();
        let result = BertTokenizer::mean_pool(&last_hidden, &mask, 0, 0, 0);
        assert!(result.is_empty());
    }
}
161}