1use std::path::Path;
2
3use tokenizers::utils::padding::pad_encodings;
4use tokenizers::{PaddingParams, Tokenizer, TruncationDirection};
5
6use crate::error::{EmbedError, Result};
7
8pub struct BertTokenizer {
9 inner: Tokenizer,
10}
11
/// A tokenized, padded batch flattened row-major for model input.
///
/// `input_ids` and `attention_mask` each hold one row of `seq_len` values per input
/// text, so row `b` occupies the index range `b * seq_len .. (b + 1) * seq_len`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenizedBatch {
    /// Token ids for every row; padding positions hold id `0`.
    pub input_ids: Vec<i64>,
    /// `1` for positions the model should attend to, `0` for padding.
    pub attention_mask: Vec<i64>,
    /// Length of every row (the padded sequence length of the batch).
    pub seq_len: usize,
}
17
18impl BertTokenizer {
19 pub fn from_file(path: &Path) -> Result<Self> {
20 let inner = Tokenizer::from_file(path).map_err(|e| {
21 EmbedError::Config(format!(
22 "failed to load tokenizer from {}: {e}",
23 path.display()
24 ))
25 })?;
26 Ok(Self { inner })
27 }
28
29 pub fn encode_batch(&self, texts: &[String], max_len: usize) -> Result<TokenizedBatch> {
30 if texts.is_empty() {
31 return Err(EmbedError::EmptyInput);
32 }
33
34 let mut encodings = self
35 .inner
36 .encode_batch(texts.to_vec(), true)
37 .map_err(|e| EmbedError::Config(format!("tokenization failed: {e}")))?;
38
39 for enc in &mut encodings {
40 if enc.len() > max_len {
41 enc.truncate(max_len, 0, TruncationDirection::Right);
42 }
43 }
44
45 if !encodings.is_empty() {
46 let pad = PaddingParams::default();
47 pad_encodings(&mut encodings, &pad)
48 .map_err(|e| EmbedError::Config(format!("padding failed: {e}")))?;
49 }
50
51 let seq_len = encodings.iter().map(|e| e.len()).max().unwrap_or(0);
52 let batch_size = texts.len();
53 let mut input_ids = Vec::with_capacity(batch_size * seq_len);
54 let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
55
56 for enc in &encodings {
57 let ids = enc.get_ids();
58 let mask = enc.get_attention_mask();
59 for i in 0..seq_len {
60 input_ids.push(*ids.get(i).unwrap_or(&0) as i64);
61 attention_mask.push(*mask.get(i).unwrap_or(&0) as i64);
62 }
63 }
64
65 Ok(TokenizedBatch {
66 input_ids,
67 attention_mask,
68 seq_len,
69 })
70 }
71
72 pub fn mean_pool(
73 last_hidden_state: &[f32],
74 attention_mask: &[i64],
75 batch_size: usize,
76 seq_len: usize,
77 hidden_size: usize,
78 ) -> Vec<Vec<f32>> {
79 let mut result = vec![vec![0.0f32; hidden_size]; batch_size];
80
81 for (b, out_vec) in result.iter_mut().enumerate() {
82 let mut sum = vec![0.0f32; hidden_size];
83 let mut count = 0.0f32;
84
85 for s in 0..seq_len {
86 if attention_mask[b * seq_len + s] != 0 {
87 let offset = (b * seq_len + s) * hidden_size;
88 for (h, sum_val) in sum.iter_mut().take(hidden_size).enumerate() {
89 *sum_val += last_hidden_state[offset + h];
90 }
91 count += 1.0;
92 }
93 }
94
95 if count > 0.0 {
96 for (h, out_val) in out_vec.iter_mut().enumerate() {
97 *out_val = sum[h] / count;
98 }
99 }
100 }
101
102 for out_vec in &mut result {
103 let norm: f32 = out_vec.iter().map(|x| x * x).sum::<f32>().sqrt();
104 if norm > 0.0 {
105 for val in out_vec.iter_mut() {
106 *val /= norm;
107 }
108 }
109 }
110
111 result
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two vectors agree element-wise within a small tolerance.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-5, "index {i}: expected {e}, got {a}");
        }
    }

    /// Euclidean norm of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn mean_pool_l2_normalized() {
        // 1 sequence x 4 tokens x 3 dims; token 1 ([4, 1, 1]) is masked out.
        let last_hidden = vec![
            1.0_f32, 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        ];
        let mask = vec![1_i64, 0, 1, 1];
        let result = BertTokenizer::mean_pool(&last_hidden, &mask, 1, 4, 3);

        assert_eq!(result.len(), 1);
        let norm = l2_norm(&result[0]);
        assert!(
            (norm - 1.0).abs() < 0.001,
            "expected L2 norm approx 1.0, got {norm}"
        );
        // The mean of tokens 0, 2, 3 is [1, 4/3, 5/3], i.e. proportional to [3, 4, 5];
        // including the masked token would give a different direction.
        let scale = 50.0_f32.sqrt();
        assert_close(&result[0], &[3.0 / scale, 4.0 / scale, 5.0 / scale]);
    }

    #[test]
    fn mean_pool_multi_batch() {
        let last_hidden = vec![
            1.0_f32, 0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 3.0,
            0.0, 0.0,
        ];
        let mask = vec![1_i64, 1, 1, 1, 0, 0, 1, 1, 1];
        let result = BertTokenizer::mean_pool(&last_hidden, &mask, 2, 3, 3);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 3);
        assert_eq!(result[1].len(), 3);
        assert!((l2_norm(&result[0]) - 1.0).abs() < 0.001);
        assert!((l2_norm(&result[1]) - 1.0).abs() < 0.001);

        // Both rows average to a positive multiple of the first basis vector.
        assert_close(&result[0], &[1.0, 0.0, 0.0]);
        assert_close(&result[1], &[1.0, 0.0, 0.0]);
    }

    #[test]
    fn mean_pool_fully_masked_row_is_zero() {
        let last_hidden = vec![1.0_f32, 2.0, 3.0, 4.0];
        let mask = vec![0_i64, 0];
        let result = BertTokenizer::mean_pool(&last_hidden, &mask, 1, 2, 2);
        assert_eq!(result, vec![vec![0.0_f32, 0.0]]);
    }

    #[test]
    fn mean_pool_empty_safe() {
        let last_hidden: Vec<f32> = Vec::new();
        let mask: Vec<i64> = Vec::new();
        let result = BertTokenizer::mean_pool(&last_hidden, &mask, 0, 0, 0);
        assert!(result.is_empty());
    }
}