oxillama_runtime/embedding.rs
1//! Embedding pooling modes and the `pool_hidden_states` kernel.
2//!
3//! An embedding model produces a per-token hidden state matrix of shape
4//! `[seq_len, hidden_size]` (stored row-major as a flat `Vec<f32>`). This
5//! module provides four standard strategies to collapse that matrix into a
6//! single `hidden_size`-dimensional vector suitable for similarity search,
7//! retrieval, and reranking.
8//!
9//! # Modes
10//!
11//! | Mode | Description |
12//! |--------|-------------|
13//! | `Last` | Return the hidden state of the **last** token (default). Appropriate for causal / decoder-only models such as LLaMA. |
14//! | `Mean` | Elementwise arithmetic mean across **all** tokens. Standard choice for BERT-style models. |
15//! | `Max` | Elementwise maximum across all tokens. Captures the "most activated" feature in the sequence. |
16//! | `Cls` | Return the hidden state of the **first** token (CLS). Used by BERT and its variants. |
17//!
18//! # Usage
19//!
20//! ```ignore
21//! use oxillama_runtime::embedding::{pool_hidden_states, PoolingMode};
22//!
23//! // hidden is a flat [seq_len × hidden_size] matrix.
24//! let pooled = pool_hidden_states(&hidden, seq_len, hidden_size, PoolingMode::Mean)?;
25//! ```
26
27use serde::{Deserialize, Serialize};
28
29use crate::error::{RuntimeError, RuntimeResult};
30
31/// Strategy for collapsing a sequence of hidden states into a single vector.
32///
33/// The default mode is [`PoolingMode::Last`], which is appropriate for causal
34/// LLMs (the last token's hidden state encodes the full context).
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
36pub enum PoolingMode {
37 /// Return the hidden state of the **last** token in the sequence.
38 ///
39 /// Shape: `states[(seq_len - 1) * hidden_size .. seq_len * hidden_size]`.
40 /// This is the natural pooling choice for causal / decoder-only models.
41 #[default]
42 Last,
43
44 /// Elementwise arithmetic mean across **all** `seq_len` token hidden states.
45 ///
46 /// `result[j] = (1 / seq_len) * Σ_{i=0}^{seq_len-1} states[i * hidden_size + j]`
47 Mean,
48
49 /// Elementwise maximum across **all** `seq_len` token hidden states.
50 ///
51 /// `result[j] = max_{i=0..seq_len-1} states[i * hidden_size + j]`
52 Max,
53
54 /// Return the hidden state of the **first** token (CLS position).
55 ///
56 /// Shape: `states[0 .. hidden_size]`.
57 /// This is the standard pooling choice for BERT-style encoder models.
58 Cls,
59}
60
61/// Pool a sequence of per-token hidden states into a single vector.
62///
63/// # Arguments
64///
65/// * `states` — Flat `[seq_len × hidden_size]` matrix in row-major order.
66/// Total length must equal `seq_len * hidden_size`.
67/// * `seq_len` — Number of tokens in the sequence (rows). Must be ≥ 1.
68/// * `hidden_size` — Dimensionality of each token's hidden state (columns).
69/// * `mode` — Pooling strategy; see [`PoolingMode`].
70///
71/// # Returns
72///
73/// A `Vec<f32>` of length `hidden_size`.
74///
75/// # Errors
76///
77/// * [`RuntimeError::EmptySequence`] — when `seq_len == 0`.
78/// * [`RuntimeError::SamplingError`] — when `states.len() != seq_len * hidden_size`
79/// (indicates a mismatched buffer from the forward pass).
80pub fn pool_hidden_states(
81 states: &[f32],
82 seq_len: usize,
83 hidden_size: usize,
84 mode: PoolingMode,
85) -> RuntimeResult<Vec<f32>> {
86 if seq_len == 0 {
87 return Err(RuntimeError::EmptySequence);
88 }
89
90 let expected_len = seq_len * hidden_size;
91 if states.len() != expected_len {
92 return Err(RuntimeError::SamplingError {
93 message: format!(
94 "pool_hidden_states: states.len()={} != seq_len({}) * hidden_size({}) = {}",
95 states.len(),
96 seq_len,
97 hidden_size,
98 expected_len,
99 ),
100 });
101 }
102
103 if hidden_size == 0 {
104 return Ok(Vec::new());
105 }
106
107 match mode {
108 PoolingMode::Last => pool_last(states, seq_len, hidden_size),
109 PoolingMode::Mean => pool_mean(states, seq_len, hidden_size),
110 PoolingMode::Max => pool_max(states, seq_len, hidden_size),
111 PoolingMode::Cls => pool_cls(states, hidden_size),
112 }
113}
114
115// ── Per-mode kernels ──────────────────────────────────────────────────────────
116
117/// Return a copy of the last row: `states[(seq_len-1)*hidden_size..]`.
118fn pool_last(states: &[f32], seq_len: usize, hidden_size: usize) -> RuntimeResult<Vec<f32>> {
119 let start = (seq_len - 1) * hidden_size;
120 Ok(states[start..start + hidden_size].to_vec())
121}
122
123/// Elementwise arithmetic mean across all `seq_len` rows.
124fn pool_mean(states: &[f32], seq_len: usize, hidden_size: usize) -> RuntimeResult<Vec<f32>> {
125 let mut result = vec![0.0f32; hidden_size];
126 for row in 0..seq_len {
127 let offset = row * hidden_size;
128 for j in 0..hidden_size {
129 result[j] += states[offset + j];
130 }
131 }
132 let inv_n = 1.0 / seq_len as f32;
133 for v in &mut result {
134 *v *= inv_n;
135 }
136 Ok(result)
137}
138
139/// Elementwise maximum across all `seq_len` rows.
140fn pool_max(states: &[f32], seq_len: usize, hidden_size: usize) -> RuntimeResult<Vec<f32>> {
141 // Initialise with the first row so that we handle negative values correctly.
142 let mut result = states[0..hidden_size].to_vec();
143 for row in 1..seq_len {
144 let offset = row * hidden_size;
145 for j in 0..hidden_size {
146 if states[offset + j] > result[j] {
147 result[j] = states[offset + j];
148 }
149 }
150 }
151 Ok(result)
152}
153
154/// Return a copy of the first row (CLS token): `states[0..hidden_size]`.
155fn pool_cls(states: &[f32], hidden_size: usize) -> RuntimeResult<Vec<f32>> {
156 Ok(states[0..hidden_size].to_vec())
157}
158
159// ── Unit tests ────────────────────────────────────────────────────────────────
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a flat `[seq_len × hidden_size]` matrix where `states[i][j] = (i+1) * (j+1)`.
    fn make_states(seq_len: usize, hidden_size: usize) -> Vec<f32> {
        let mut v = Vec::with_capacity(seq_len * hidden_size);
        for i in 0..seq_len {
            for j in 0..hidden_size {
                v.push(((i + 1) * (j + 1)) as f32);
            }
        }
        v
    }

    /// `PoolingMode::Last` must return exactly the last row of the matrix.
    #[test]
    fn pooling_last_matches_last_row() {
        let seq_len = 4;
        let hidden_size = 3;
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Last)
            .expect("Last pooling must succeed");

        // Last row (i=3): values are (3+1)*(j+1) = 4*(j+1)
        let expected: Vec<f32> = (0..hidden_size).map(|j| 4.0 * (j + 1) as f32).collect();
        assert_eq!(pooled, expected, "Last pooling should return the last row");
    }

    /// `PoolingMode::Mean` must return the elementwise arithmetic mean.
    #[test]
    fn pooling_mean_is_arithmetic_mean() {
        let seq_len = 3;
        let hidden_size = 2;
        // states[i][j] = (i+1)*(j+1)
        // For j=0: values are 1*1=1, 2*1=2, 3*1=3 → mean = 2.0
        // For j=1: values are 1*2=2, 2*2=4, 3*2=6 → mean = 4.0
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Mean)
            .expect("Mean pooling must succeed");

        assert_eq!(
            pooled.len(),
            hidden_size,
            "output length must equal hidden_size"
        );
        assert!(
            (pooled[0] - 2.0).abs() < 1e-5,
            "Mean at j=0 should be 2.0, got {}",
            pooled[0]
        );
        assert!(
            (pooled[1] - 4.0).abs() < 1e-5,
            "Mean at j=1 should be 4.0, got {}",
            pooled[1]
        );
    }

    /// `PoolingMode::Max` must return the elementwise maximum across rows.
    #[test]
    fn pooling_max_is_elementwise_max() {
        let seq_len = 4;
        let hidden_size = 3;
        // states[i][j] = (i+1)*(j+1), so the max across i is at i=3 (last row).
        // Max[j] = 4*(j+1).
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Max)
            .expect("Max pooling must succeed");

        let expected: Vec<f32> = (0..hidden_size).map(|j| 4.0 * (j + 1) as f32).collect();
        assert_eq!(
            pooled, expected,
            "Max pooling should return elementwise max"
        );
    }

    /// In `make_states` data the maximum always sits in the last row, so
    /// `pooling_max_is_elementwise_max` cannot tell Max apart from Last. Here
    /// each column's maximum is in a different row, none of them is the
    /// first or last row as a whole, and one column is entirely negative.
    #[test]
    fn pooling_max_picks_per_column_maximum_across_rows() {
        #[rustfmt::skip]
        let states = vec![
            1.0f32, -5.0,  0.0,
            4.0,    -2.0, -1.0,
            2.0,    -3.0,  9.0,
        ];
        let pooled = pool_hidden_states(&states, 3, 3, PoolingMode::Max)
            .expect("Max pooling must succeed");
        assert_eq!(
            pooled,
            vec![4.0, -2.0, 9.0],
            "Max pooling must take each column's maximum independently"
        );
    }

    /// `PoolingMode::Cls` must return exactly the first row of the matrix.
    #[test]
    fn pooling_cls_matches_first_row() {
        let seq_len = 5;
        let hidden_size = 4;
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Cls)
            .expect("Cls pooling must succeed");

        // First row (i=0): values are (0+1)*(j+1) = j+1
        let expected: Vec<f32> = (0..hidden_size).map(|j| (j + 1) as f32).collect();
        assert_eq!(pooled, expected, "Cls pooling should return the first row");
    }

    /// `pool_hidden_states` with `seq_len == 0` must return `EmptySequence`.
    #[test]
    fn pooling_empty_sequence_errors() {
        let result = pool_hidden_states(&[], 0, 4, PoolingMode::Last);
        assert!(
            matches!(result, Err(RuntimeError::EmptySequence)),
            "empty seq_len must produce EmptySequence error"
        );
    }

    /// Mismatched buffer length must return a `SamplingError`.
    #[test]
    fn pooling_wrong_buffer_length_errors() {
        let states = vec![0.0f32; 10]; // 10 ≠ 2*3=6
        let result = pool_hidden_states(&states, 2, 3, PoolingMode::Mean);
        assert!(
            matches!(result, Err(RuntimeError::SamplingError { .. })),
            "mismatched buffer must produce SamplingError"
        );
    }

    /// A zero-width hidden state (with an empty buffer) pools to an empty vector
    /// in every mode.
    #[test]
    fn pooling_zero_hidden_size_returns_empty() {
        for mode in [
            PoolingMode::Last,
            PoolingMode::Mean,
            PoolingMode::Max,
            PoolingMode::Cls,
        ] {
            let pooled = pool_hidden_states(&[], 3, 0, mode)
                .unwrap_or_else(|e| panic!("mode {mode:?} failed: {e}"));
            assert!(
                pooled.is_empty(),
                "hidden_size == 0 with {mode:?} must yield an empty vector"
            );
        }
    }

    /// All four modes must work correctly for a single-token sequence.
    #[test]
    fn pooling_single_token_sequence() {
        let seq_len = 1;
        let hidden_size = 3;
        let states = vec![1.0f32, 2.0, 3.0];

        for mode in [
            PoolingMode::Last,
            PoolingMode::Mean,
            PoolingMode::Max,
            PoolingMode::Cls,
        ] {
            let pooled = pool_hidden_states(&states, seq_len, hidden_size, mode)
                .unwrap_or_else(|e| panic!("mode {mode:?} failed: {e}"));
            assert_eq!(
                pooled, states,
                "single-token pooling with {mode:?} must return the only row unchanged"
            );
        }
    }

    /// Mean pooling of a matrix where all values in each column are the same
    /// must return those column values.
    #[test]
    fn pooling_mean_constant_columns() {
        // 3 rows, 2 columns. Col 0 all 5.0, col 1 all 7.0.
        let states = vec![5.0f32, 7.0, 5.0, 7.0, 5.0, 7.0];
        let pooled = pool_hidden_states(&states, 3, 2, PoolingMode::Mean)
            .expect("mean pooling must succeed");
        assert!(
            (pooled[0] - 5.0).abs() < 1e-6,
            "mean of constant column 0 must be 5.0"
        );
        assert!(
            (pooled[1] - 7.0).abs() < 1e-6,
            "mean of constant column 1 must be 7.0"
        );
    }

    /// The documented default mode is `Last`.
    #[test]
    fn pooling_mode_default_is_last() {
        assert_eq!(PoolingMode::default(), PoolingMode::Last);
    }
}
311}