use serde::{Deserialize, Serialize};
use crate::error::{RuntimeError, RuntimeResult};
/// Strategy used by [`pool_hidden_states`] to reduce a sequence of per-token
/// hidden-state rows to a single vector of length `hidden_size`.
// NOTE: keep the variant order stable; index-based serde formats encode
// variants by position, so reordering could change serialized values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum PoolingMode {
    /// Copy the last row (token index `seq_len - 1`). This is the default mode.
    #[default]
    Last,
    /// Arithmetic mean of each hidden dimension across all rows.
    Mean,
    /// Element-wise maximum of each hidden dimension across all rows.
    Max,
    /// Copy the first row (token index 0).
    Cls,
}
/// Pools a row-major `seq_len x hidden_size` buffer of hidden states into a
/// single vector of length `hidden_size`, using the strategy chosen by `mode`.
///
/// Row `i` (one per token) occupies `states[i * hidden_size..(i + 1) * hidden_size]`.
///
/// # Errors
///
/// * [`RuntimeError::EmptySequence`] if `seq_len == 0`.
/// * [`RuntimeError::SamplingError`] if `seq_len * hidden_size` overflows
///   `usize`, or if it does not equal `states.len()`.
///
/// If `hidden_size == 0` (which, after validation, implies `states` is empty),
/// an empty vector is returned regardless of `mode`.
pub fn pool_hidden_states(
    states: &[f32],
    seq_len: usize,
    hidden_size: usize,
    mode: PoolingMode,
) -> RuntimeResult<Vec<f32>> {
    if seq_len == 0 {
        return Err(RuntimeError::EmptySequence);
    }
    // Checked multiplication: an unchecked product panics in debug builds and
    // wraps in release builds. A wrapped product could spuriously equal
    // `states.len()` and let the helpers below read the wrong rows or index
    // out of bounds.
    let expected_len = match seq_len.checked_mul(hidden_size) {
        Some(n) => n,
        None => {
            return Err(RuntimeError::SamplingError {
                message: format!(
                    "pool_hidden_states: seq_len({}) * hidden_size({}) overflows usize",
                    seq_len, hidden_size,
                ),
            });
        }
    };
    if states.len() != expected_len {
        return Err(RuntimeError::SamplingError {
            message: format!(
                "pool_hidden_states: states.len()={} != seq_len({}) * hidden_size({}) = {}",
                states.len(),
                seq_len,
                hidden_size,
                expected_len,
            ),
        });
    }
    if hidden_size == 0 {
        return Ok(Vec::new());
    }
    // Every helper below relies on the invariants established above:
    // seq_len >= 1, hidden_size >= 1 and states.len() == seq_len * hidden_size.
    match mode {
        PoolingMode::Last => pool_last(states, seq_len, hidden_size),
        PoolingMode::Mean => pool_mean(states, seq_len, hidden_size),
        PoolingMode::Max => pool_max(states, seq_len, hidden_size),
        PoolingMode::Cls => pool_cls(states, hidden_size),
    }
}
/// Returns a copy of the final row (token index `seq_len - 1`).
///
/// Expects `seq_len >= 1` and `states.len() >= seq_len * hidden_size`;
/// `pool_hidden_states` validates both before dispatching here.
fn pool_last(states: &[f32], seq_len: usize, hidden_size: usize) -> RuntimeResult<Vec<f32>> {
    let last_row_begin = (seq_len - 1) * hidden_size;
    let last_row = &states[last_row_begin..][..hidden_size];
    Ok(last_row.to_vec())
}
/// Arithmetic mean of each hidden dimension across all `seq_len` rows.
///
/// Rows are accumulated in order into an `f32` running sum per column, which
/// is then scaled by `1 / seq_len`. Expects `seq_len >= 1` and
/// `states.len() >= seq_len * hidden_size`; `pool_hidden_states` validates both.
fn pool_mean(states: &[f32], seq_len: usize, hidden_size: usize) -> RuntimeResult<Vec<f32>> {
    let mut sums = vec![0.0f32; hidden_size];
    let rows = (0..seq_len).map(|r| &states[r * hidden_size..][..hidden_size]);
    for row in rows {
        for (acc, &value) in sums.iter_mut().zip(row) {
            *acc += value;
        }
    }
    let scale = 1.0 / seq_len as f32;
    sums.iter_mut().for_each(|v| *v *= scale);
    Ok(sums)
}
/// Element-wise maximum of each hidden dimension across all `seq_len` rows.
///
/// NaN entries are ignored (`f32::max` returns the non-NaN operand), so the
/// result does not depend on which row a NaN occurs in. A column is NaN only
/// when every row holds NaN there. A plain `>` comparison would keep a NaN
/// from row 0 but silently drop NaNs from later rows.
///
/// Expects `seq_len >= 1` and `states.len() >= seq_len * hidden_size`;
/// [`pool_hidden_states`] validates both before dispatching here.
fn pool_max(states: &[f32], seq_len: usize, hidden_size: usize) -> RuntimeResult<Vec<f32>> {
    let mut result = states[..hidden_size].to_vec();
    for row in 1..seq_len {
        let offset = row * hidden_size;
        let values = &states[offset..offset + hidden_size];
        for (acc, &x) in result.iter_mut().zip(values) {
            *acc = f32::max(*acc, x);
        }
    }
    Ok(result)
}
/// Returns a copy of the first row (token index 0), used by the `Cls` mode.
///
/// Expects `states.len() >= hidden_size`; `pool_hidden_states` guarantees it.
fn pool_cls(states: &[f32], hidden_size: usize) -> RuntimeResult<Vec<f32>> {
    let (first_row, _remaining_rows) = states.split_at(hidden_size);
    Ok(first_row.to_owned())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every pooling mode, for tests that must hold regardless of strategy.
    const ALL_MODES: [PoolingMode; 4] = [
        PoolingMode::Last,
        PoolingMode::Mean,
        PoolingMode::Max,
        PoolingMode::Cls,
    ];

    /// Builds a row-major `seq_len x hidden_size` buffer where element
    /// `(i, j)` is `(i + 1) * (j + 1)`, so each column strictly increases with
    /// the row index (row 0 is the minimum, the last row is the maximum).
    fn make_states(seq_len: usize, hidden_size: usize) -> Vec<f32> {
        (0..seq_len)
            .flat_map(|i| (0..hidden_size).map(move |j| ((i + 1) * (j + 1)) as f32))
            .collect()
    }

    #[test]
    fn pooling_last_matches_last_row() {
        let seq_len = 4;
        let hidden_size = 3;
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Last)
            .expect("Last pooling must succeed");
        let expected: Vec<f32> = (0..hidden_size).map(|j| 4.0 * (j + 1) as f32).collect();
        assert_eq!(pooled, expected, "Last pooling should return the last row");
    }

    #[test]
    fn pooling_mean_is_arithmetic_mean() {
        let seq_len = 3;
        let hidden_size = 2;
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Mean)
            .expect("Mean pooling must succeed");
        assert_eq!(
            pooled.len(),
            hidden_size,
            "output length must equal hidden_size"
        );
        // Column j holds (j + 1) * [1, 2, 3], whose mean is 2 * (j + 1).
        assert!(
            (pooled[0] - 2.0).abs() < 1e-5,
            "Mean at j=0 should be 2.0, got {}",
            pooled[0]
        );
        assert!(
            (pooled[1] - 4.0).abs() < 1e-5,
            "Mean at j=1 should be 4.0, got {}",
            pooled[1]
        );
    }

    #[test]
    fn pooling_max_is_elementwise_max() {
        let seq_len = 4;
        let hidden_size = 3;
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Max)
            .expect("Max pooling must succeed");
        let expected: Vec<f32> = (0..hidden_size).map(|j| 4.0 * (j + 1) as f32).collect();
        assert_eq!(
            pooled, expected,
            "Max pooling should return elementwise max"
        );
    }

    #[test]
    fn pooling_cls_matches_first_row() {
        let seq_len = 5;
        let hidden_size = 4;
        let states = make_states(seq_len, hidden_size);
        let pooled = pool_hidden_states(&states, seq_len, hidden_size, PoolingMode::Cls)
            .expect("Cls pooling must succeed");
        let expected: Vec<f32> = (0..hidden_size).map(|j| (j + 1) as f32).collect();
        assert_eq!(pooled, expected, "Cls pooling should return the first row");
    }

    #[test]
    fn pooling_empty_sequence_errors() {
        let result = pool_hidden_states(&[], 0, 4, PoolingMode::Last);
        assert!(
            matches!(result, Err(RuntimeError::EmptySequence)),
            "empty seq_len must produce EmptySequence error"
        );
    }

    #[test]
    fn pooling_wrong_buffer_length_errors() {
        let states = vec![0.0f32; 10];
        let result = pool_hidden_states(&states, 2, 3, PoolingMode::Mean);
        assert!(
            matches!(result, Err(RuntimeError::SamplingError { .. })),
            "mismatched buffer must produce SamplingError"
        );
    }

    #[test]
    fn pooling_single_token_sequence() {
        let seq_len = 1;
        let hidden_size = 3;
        let states = vec![1.0f32, 2.0, 3.0];
        for mode in ALL_MODES {
            let pooled = pool_hidden_states(&states, seq_len, hidden_size, mode)
                .unwrap_or_else(|e| panic!("mode {mode:?} failed: {e}"));
            assert_eq!(
                pooled, states,
                "single-token pooling with {mode:?} must return the only row unchanged"
            );
        }
    }

    #[test]
    fn pooling_mean_constant_columns() {
        let states = vec![5.0f32, 7.0, 5.0, 7.0, 5.0, 7.0];
        let pooled = pool_hidden_states(&states, 3, 2, PoolingMode::Mean)
            .expect("mean pooling must succeed");
        assert_eq!(pooled.len(), 2, "output length must equal hidden_size");
        assert!(
            (pooled[0] - 5.0).abs() < 1e-6,
            "mean of constant column 0 must be 5.0"
        );
        assert!(
            (pooled[1] - 7.0).abs() < 1e-6,
            "mean of constant column 1 must be 7.0"
        );
    }

    #[test]
    fn pooling_default_mode_is_last() {
        assert_eq!(PoolingMode::default(), PoolingMode::Last);
    }

    #[test]
    fn pooling_zero_hidden_size_returns_empty() {
        for mode in ALL_MODES {
            let pooled = pool_hidden_states(&[], 3, 0, mode)
                .unwrap_or_else(|e| panic!("mode {mode:?} failed: {e}"));
            assert!(
                pooled.is_empty(),
                "hidden_size == 0 with {mode:?} must yield an empty vector"
            );
        }
    }
}