Skip to main content

oxillama_runtime/
embedding.rs

1//! Embedding pooling modes and the `pool_hidden_states` kernel.
2//!
3//! An embedding model produces a per-token hidden state matrix of shape
4//! `[seq_len, hidden_size]` (stored row-major as a flat `Vec<f32>`). This
5//! module provides four standard strategies to collapse that matrix into a
6//! single `hidden_size`-dimensional vector suitable for similarity search,
7//! retrieval, and reranking.
8//!
9//! # Modes
10//!
11//! | Mode   | Description |
12//! |--------|-------------|
13//! | `Last` | Return the hidden state of the **last** token (default). Appropriate for causal / decoder-only models such as LLaMA. |
14//! | `Mean` | Elementwise arithmetic mean across **all** tokens. Standard choice for BERT-style models. |
15//! | `Max`  | Elementwise maximum across all tokens. Captures the "most activated" feature in the sequence. |
16//! | `Cls`  | Return the hidden state of the **first** token (CLS). Used by BERT and its variants. |
17//!
18//! # Usage
19//!
20//! ```ignore
21//! use oxillama_runtime::embedding::{pool_hidden_states, PoolingMode};
22//!
23//! // hidden is a flat [seq_len × hidden_size] matrix.
24//! let pooled = pool_hidden_states(&hidden, seq_len, hidden_size, PoolingMode::Mean)?;
25//! ```
26
27use serde::{Deserialize, Serialize};
28
29use crate::error::{RuntimeError, RuntimeResult};
30
/// Strategy for collapsing a sequence of hidden states into a single vector.
///
/// The default mode is [`PoolingMode::Last`], which is appropriate for causal
/// LLMs (the last token's hidden state encodes the full context).
///
/// The type is `Copy` (cheap to pass by value) and derives serde's
/// `Serialize` / `Deserialize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum PoolingMode {
    /// Return the hidden state of the **last** token in the sequence.
    ///
    /// Slice: `states[(seq_len - 1) * hidden_size .. seq_len * hidden_size]`.
    /// This is the natural pooling choice for causal / decoder-only models.
    #[default]
    Last,

    /// Elementwise arithmetic mean across **all** `seq_len` token hidden states.
    ///
    /// `result[j] = (1 / seq_len) * Σ_{i=0}^{seq_len-1} states[i * hidden_size + j]`
    Mean,

    /// Elementwise maximum across **all** `seq_len` token hidden states.
    ///
    /// `result[j] = max_{i=0..seq_len-1} states[i * hidden_size + j]`
    Max,

    /// Return the hidden state of the **first** token (CLS position).
    ///
    /// Slice: `states[0 .. hidden_size]`.
    /// This is the standard pooling choice for BERT-style encoder models.
    Cls,
}
60
61/// Pool a sequence of per-token hidden states into a single vector.
62///
63/// # Arguments
64///
65/// * `states` — Flat `[seq_len × hidden_size]` matrix in row-major order.
66///   Total length must equal `seq_len * hidden_size`.
67/// * `seq_len` — Number of tokens in the sequence (rows). Must be ≥ 1.
68/// * `hidden_size` — Dimensionality of each token's hidden state (columns).
69/// * `mode` — Pooling strategy; see [`PoolingMode`].
70///
71/// # Returns
72///
73/// A `Vec<f32>` of length `hidden_size`.
74///
75/// # Errors
76///
77/// * [`RuntimeError::EmptySequence`] — when `seq_len == 0`.
78/// * [`RuntimeError::SamplingError`] — when `states.len() != seq_len * hidden_size`
79///   (indicates a mismatched buffer from the forward pass).
80pub fn pool_hidden_states(
81    states: &[f32],
82    seq_len: usize,
83    hidden_size: usize,
84    mode: PoolingMode,
85) -> RuntimeResult<Vec<f32>> {
86    if seq_len == 0 {
87        return Err(RuntimeError::EmptySequence);
88    }
89
90    let expected_len = seq_len * hidden_size;
91    if states.len() != expected_len {
92        return Err(RuntimeError::SamplingError {
93            message: format!(
94                "pool_hidden_states: states.len()={} != seq_len({}) * hidden_size({}) = {}",
95                states.len(),
96                seq_len,
97                hidden_size,
98                expected_len,
99            ),
100        });
101    }
102
103    if hidden_size == 0 {
104        return Ok(Vec::new());
105    }
106
107    match mode {
108        PoolingMode::Last => pool_last(states, seq_len, hidden_size),
109        PoolingMode::Mean => pool_mean(states, seq_len, hidden_size),
110        PoolingMode::Max => pool_max(states, seq_len, hidden_size),
111        PoolingMode::Cls => pool_cls(states, hidden_size),
112    }
113}
114
115// ── Per-mode kernels ──────────────────────────────────────────────────────────
116
117/// Return a copy of the last row: `states[(seq_len-1)*hidden_size..]`.
118fn pool_last(states: &[f32], seq_len: usize, hidden_size: usize) -> RuntimeResult<Vec<f32>> {
119    let start = (seq_len - 1) * hidden_size;
120    Ok(states[start..start + hidden_size].to_vec())
121}
122
123/// Elementwise arithmetic mean across all `seq_len` rows.
124fn pool_mean(states: &[f32], seq_len: usize, hidden_size: usize) -> RuntimeResult<Vec<f32>> {
125    let mut result = vec![0.0f32; hidden_size];
126    for row in 0..seq_len {
127        let offset = row * hidden_size;
128        for j in 0..hidden_size {
129            result[j] += states[offset + j];
130        }
131    }
132    let inv_n = 1.0 / seq_len as f32;
133    for v in &mut result {
134        *v *= inv_n;
135    }
136    Ok(result)
137}
138
139/// Elementwise maximum across all `seq_len` rows.
140fn pool_max(states: &[f32], seq_len: usize, hidden_size: usize) -> RuntimeResult<Vec<f32>> {
141    // Initialise with the first row so that we handle negative values correctly.
142    let mut result = states[0..hidden_size].to_vec();
143    for row in 1..seq_len {
144        let offset = row * hidden_size;
145        for j in 0..hidden_size {
146            if states[offset + j] > result[j] {
147                result[j] = states[offset + j];
148            }
149        }
150    }
151    Ok(result)
152}
153
154/// Return a copy of the first row (CLS token): `states[0..hidden_size]`.
155fn pool_cls(states: &[f32], hidden_size: usize) -> RuntimeResult<Vec<f32>> {
156    Ok(states[0..hidden_size].to_vec())
157}
158
159// ── Unit tests ────────────────────────────────────────────────────────────────
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a flat `[seq_len × hidden_size]` matrix where `states[i][j] = (i+1) * (j+1)`.
    fn make_states(seq_len: usize, hidden_size: usize) -> Vec<f32> {
        let mut v = Vec::with_capacity(seq_len * hidden_size);
        for i in 0..seq_len {
            for j in 0..hidden_size {
                v.push(((i + 1) * (j + 1)) as f32);
            }
        }
        v
    }

    /// `PoolingMode::Last` must return exactly the last row of the matrix.
    #[test]
    fn pooling_last_matches_last_row() {
        let seq_len = 4;
        let hidden_size = 3;
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Last)
            .expect("Last pooling must succeed");

        // Last row (i=3): values are (3+1)*(j+1) = 4*(j+1)
        let expected: Vec<f32> = (0..hidden_size).map(|j| 4.0 * (j + 1) as f32).collect();
        assert_eq!(pooled, expected, "Last pooling should return the last row");
    }

    /// `PoolingMode::Mean` must return the elementwise arithmetic mean.
    #[test]
    fn pooling_mean_is_arithmetic_mean() {
        let seq_len = 3;
        let hidden_size = 2;
        // states[i][j] = (i+1)*(j+1)
        // For j=0: values are 1*1=1, 2*1=2, 3*1=3 → mean = 2.0
        // For j=1: values are 1*2=2, 2*2=4, 3*2=6 → mean = 4.0
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Mean)
            .expect("Mean pooling must succeed");

        assert_eq!(
            pooled.len(),
            hidden_size,
            "output length must equal hidden_size"
        );
        assert!(
            (pooled[0] - 2.0).abs() < 1e-5,
            "Mean at j=0 should be 2.0, got {}",
            pooled[0]
        );
        assert!(
            (pooled[1] - 4.0).abs() < 1e-5,
            "Mean at j=1 should be 4.0, got {}",
            pooled[1]
        );
    }

    /// `PoolingMode::Max` must return the elementwise maximum across rows.
    #[test]
    fn pooling_max_is_elementwise_max() {
        let seq_len = 4;
        let hidden_size = 3;
        // states[i][j] = (i+1)*(j+1), so the max across i is at i=3 (last row).
        // Max[j] = 4*(j+1).
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Max)
            .expect("Max pooling must succeed");

        let expected: Vec<f32> = (0..hidden_size).map(|j| 4.0 * (j + 1) as f32).collect();
        assert_eq!(pooled, expected, "Max pooling should return elementwise max");
    }

    /// `PoolingMode::Max` must pick each column's maximum independently, even
    /// when the maxima come from different rows and include negative values.
    ///
    /// `make_states` is monotonic in the row index, so the test above cannot
    /// tell Max apart from Last. This data can, and it also exercises the
    /// "seed from the first row" handling of all-negative columns.
    #[test]
    fn pooling_max_picks_per_column_maximum_across_rows() {
        // Row 0: [-1.0, -5.0,  2.0]
        // Row 1: [-3.0, -2.0, -7.0]
        let states = vec![-1.0f32, -5.0, 2.0, -3.0, -2.0, -7.0];
        let pooled = pool_hidden_states(&states, 2, 3, PoolingMode::Max)
            .expect("max pooling must succeed");
        assert_eq!(
            pooled,
            vec![-1.0, -2.0, 2.0],
            "Max pooling must take each column's maximum from whichever row holds it"
        );
    }

    /// `PoolingMode::Cls` must return exactly the first row of the matrix.
    #[test]
    fn pooling_cls_matches_first_row() {
        let seq_len = 5;
        let hidden_size = 4;
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Cls)
            .expect("Cls pooling must succeed");

        // First row (i=0): values are (0+1)*(j+1) = j+1
        let expected: Vec<f32> = (0..hidden_size).map(|j| (j + 1) as f32).collect();
        assert_eq!(pooled, expected, "Cls pooling should return the first row");
    }

    /// `pool_hidden_states` with `seq_len == 0` must return `EmptySequence`.
    #[test]
    fn pooling_empty_sequence_errors() {
        let result = pool_hidden_states(&[], 0, 4, PoolingMode::Last);
        assert!(
            matches!(result, Err(RuntimeError::EmptySequence)),
            "empty seq_len must produce EmptySequence error"
        );
    }

    /// Mismatched buffer length must return a `SamplingError`.
    #[test]
    fn pooling_wrong_buffer_length_errors() {
        let states = vec![0.0f32; 10]; // 10 ≠ 2*3=6
        let result = pool_hidden_states(&states, 2, 3, PoolingMode::Mean);
        assert!(
            matches!(result, Err(RuntimeError::SamplingError { .. })),
            "mismatched buffer must produce SamplingError"
        );
    }

    /// A zero `hidden_size` with a matching empty buffer must yield an empty
    /// vector for every mode rather than an error.
    #[test]
    fn pooling_zero_hidden_size_returns_empty() {
        for mode in [
            PoolingMode::Last,
            PoolingMode::Mean,
            PoolingMode::Max,
            PoolingMode::Cls,
        ] {
            let pooled = pool_hidden_states(&[], 3, 0, mode)
                .unwrap_or_else(|e| panic!("mode {mode:?} failed: {e}"));
            assert!(
                pooled.is_empty(),
                "zero hidden_size with {mode:?} must produce an empty vector"
            );
        }
    }

    /// All four modes must work correctly for a single-token sequence.
    #[test]
    fn pooling_single_token_sequence() {
        let seq_len = 1;
        let hidden_size = 3;
        let states = vec![1.0f32, 2.0, 3.0];

        for mode in [
            PoolingMode::Last,
            PoolingMode::Mean,
            PoolingMode::Max,
            PoolingMode::Cls,
        ] {
            let pooled = pool_hidden_states(&states, seq_len, hidden_size, mode)
                .unwrap_or_else(|e| panic!("mode {mode:?} failed: {e}"));
            assert_eq!(
                pooled, states,
                "single-token pooling with {mode:?} must return the only row unchanged"
            );
        }
    }

    /// Mean pooling of a matrix where all values in each column are the same
    /// must return those column values.
    #[test]
    fn pooling_mean_constant_columns() {
        // 3 rows, 2 columns. Col 0 all 5.0, col 1 all 7.0.
        let states = vec![5.0f32, 7.0, 5.0, 7.0, 5.0, 7.0];
        let pooled = pool_hidden_states(&states, 3, 2, PoolingMode::Mean)
            .expect("mean pooling must succeed");
        assert!(
            (pooled[0] - 5.0).abs() < 1e-6,
            "mean of constant column 0 must be 5.0"
        );
        assert!(
            (pooled[1] - 7.0).abs() < 1e-6,
            "mean of constant column 1 must be 7.0"
        );
    }
}