use smol_str::SmolStr;
use std::path::Path;
use crate::error::{
CapExceededPayload, Error, FileIoPayload, FileOp, InvariantViolationPayload, OutOfRangePayload,
ParsePayload, Result, UnknownEnumValuePayload,
};
use super::pooling::PoolingStrategy;
/// Source error attached to every pooling-config parse failure: a byte offset
/// into the config text plus a human-readable description.
#[derive(Debug)]
struct PoolingJsonParseError {
    /// Byte offset into the config text at which the problem was detected.
    pos: usize,
    /// Description of what went wrong at `pos`.
    msg: String,
}

impl std::fmt::Display for PoolingJsonParseError {
    /// Renders as `at byte <pos>: <msg>`.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { pos, msg } = self;
        write!(out, "at byte {pos}: {msg}")
    }
}

impl std::error::Error for PoolingJsonParseError {}
fn parse_err_at(pos: usize, msg: impl Into<String>) -> Error {
Error::Parse(ParsePayload::new(
"embeddings::pooling_from_st_config",
"1_Pooling/config.json",
PoolingJsonParseError {
pos,
msg: msg.into(),
},
))
}
/// Largest `1_Pooling/config.json` (1 MiB) that `pooling_from_st_config_path` will read.
const MAX_ST_POOLING_CONFIG_BYTES: u64 = 1024 * 1024;

/// Legacy boolean `pooling_mode_*` flags, each paired with the pooling-mode name it selects.
///
/// Order is significant: `resolve_strategy` scans this table front to back when
/// looking for unsupported flags, so the first matching entry is the one reported.
const LEGACY_KEYS: &[(&str, &str)] = &[
    ("pooling_mode_cls_token", "cls"),
    ("pooling_mode_max_tokens", "max"),
    ("pooling_mode_mean_tokens", "mean"),
    (
        "pooling_mode_mean_sqrt_len_tokens",
        "mean_sqrt_len_tokens",
    ),
    ("pooling_mode_weightedmean_tokens", "weightedmean"),
    ("pooling_mode_lasttoken", "lasttoken"),
];
/// Pooling settings extracted from a sentence-transformers `1_Pooling/config.json`.
///
/// Built by `pooling_from_st_config_str`, `pooling_from_st_config_bytes` and
/// `pooling_from_st_config_path`, or directly via [`StPoolingConfig::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StPoolingConfig {
    strategy: PoolingStrategy,
    normalize: bool,
    dimension: Option<usize>,
}

impl StPoolingConfig {
    /// Creates a configuration from its three components.
    ///
    /// * `strategy` — how token embeddings are pooled into one vector.
    /// * `normalize` — whether the pooled embedding should be normalized.
    /// * `dimension` — optional matryoshka truncation dimension; `None` keeps the
    ///   full width.
    #[inline]
    pub const fn new(strategy: PoolingStrategy, normalize: bool, dimension: Option<usize>) -> Self {
        Self {
            strategy,
            normalize,
            dimension,
        }
    }

    /// The pooling strategy to apply.
    #[inline]
    pub const fn strategy(&self) -> PoolingStrategy {
        self.strategy
    }

    /// Whether the pooled embedding should be normalized.
    ///
    /// The config parsers in this module always set this to `true`.
    #[inline]
    pub const fn normalize(&self) -> bool {
        self.normalize
    }

    /// The matryoshka truncation dimension, if the config declares one.
    #[inline]
    pub const fn dimension(&self) -> Option<usize> {
        self.dimension
    }
}
/// A top-level JSON value as retained by the pooling-config scanner.
///
/// Strings are fully decoded and numbers keep their verbatim source text for later
/// conversion. Arrays and objects are syntax-checked, but their contents are discarded.
#[derive(Debug)]
pub(crate) enum JVal<'a> {
    /// JSON `null`.
    Null,
    /// JSON `true` or `false`.
    Bool(bool),
    /// A number, borrowed unconverted from the source text.
    Num(&'a str),
    /// A string with every escape sequence decoded.
    Str(SmolStr),
    /// An array whose elements were validated and skipped.
    Array,
    /// An object whose members were validated and skipped.
    Object,
}
/// Top-level keys whose values the scanner keeps. Every other key's value is
/// validated and then skipped.
const KNOWN_KEYS: &[&str] = &[
    "pooling_mode",
    "pooling_mode_cls_token",
    "pooling_mode_max_tokens",
    "pooling_mode_mean_tokens",
    "pooling_mode_mean_sqrt_len_tokens",
    "pooling_mode_weightedmean_tokens",
    "pooling_mode_lasttoken",
    "include_prompt",
    "word_embedding_dimension",
    "embedding_dimension",
];

/// Returns `true` when `k` is listed in `KNOWN_KEYS`.
fn is_known_key(k: &str) -> bool {
    KNOWN_KEYS.iter().any(|&known| known == k)
}

/// Maximum number of simultaneously open objects/arrays the scanner accepts.
const MAX_NESTING_DEPTH: usize = 128;
/// Hand-rolled, single-pass JSON scanner specialised for `1_Pooling/config.json`.
///
/// Values of the top-level keys listed in `KNOWN_KEYS` are materialised as [`JVal`]s.
/// Every other value is validated against RFC 8259 and then discarded. Open
/// objects/arrays are counted, so nesting beyond `MAX_NESTING_DEPTH` is rejected
/// before the recursive `skip_*` helpers can exhaust the stack.
struct Scanner<'a> {
    /// The complete input text, as bytes.
    src: &'a [u8],
    /// Offset of the next unread byte; most errors are reported at this offset.
    pos: usize,
    /// Number of objects/arrays currently open (the top-level object included).
    depth: usize,
}

impl<'a> Scanner<'a> {
    fn new(src: &'a [u8]) -> Self {
        Self {
            src,
            pos: 0,
            depth: 0,
        }
    }

    /// Records entry into a nested object/array.
    ///
    /// Fails once `MAX_NESTING_DEPTH` levels are already open.
    fn enter(&mut self) -> Result<()> {
        if self.depth >= MAX_NESTING_DEPTH {
            return Err(self.err(format!(
                "nested object/array depth exceeds the {MAX_NESTING_DEPTH}-level cap; \
                 refusing to recurse further (defends against stack-overflow on \
                 hostile pooling config input)"
            )));
        }
        self.depth += 1;
        Ok(())
    }

    /// Records leaving the innermost object/array.
    fn leave(&mut self) {
        self.depth = self.depth.saturating_sub(1);
    }

    /// Builds a parse error located at the current position.
    fn err(&self, msg: impl Into<String>) -> Error {
        parse_err_at(self.pos, msg)
    }

    /// Decodes the character that starts at byte offset `at`, for use in error messages.
    ///
    /// A byte >= 0x80 is part of a multi-byte UTF-8 sequence. Printing it with
    /// `byte as char` would show an unrelated Latin-1 character (e.g. `é` as `'Ã'`),
    /// so the whole scalar is decoded instead. ASCII bytes map to themselves.
    /// Out-of-range or undecodable input yields U+FFFD.
    fn char_at(&self, at: usize) -> char {
        // A UTF-8 encoded scalar is at most 4 bytes long.
        let end = self.src.len().min(at.saturating_add(4));
        match self.src.get(at..end) {
            Some(window) => String::from_utf8_lossy(window)
                .chars()
                .next()
                .unwrap_or(char::REPLACEMENT_CHARACTER),
            None => char::REPLACEMENT_CHARACTER,
        }
    }

    /// The byte at the current position, without consuming it.
    fn peek(&self) -> Option<u8> {
        self.src.get(self.pos).copied()
    }

    /// Consumes and returns the byte at the current position.
    fn bump(&mut self) -> Option<u8> {
        let b = self.peek()?;
        self.pos += 1;
        Some(b)
    }

    /// Advances past any whitespace at the current position.
    fn skip_ws(&mut self) {
        self.pos += memspan::skip::skip_whitespace(&self.src[self.pos..]);
    }

    /// Skips whitespace, then consumes `b` or fails with a message mentioning `ctx`.
    fn expect(&mut self, b: u8, ctx: &str) -> Result<()> {
        self.skip_ws();
        match self.peek() {
            Some(c) if c == b => {
                self.pos += 1;
                Ok(())
            }
            Some(_) => Err(self.err(format!(
                "expected {:?} {ctx} but found {:?}",
                b as char,
                self.char_at(self.pos)
            ))),
            None => Err(self.err(format!(
                "expected {:?} {ctx} but reached end of input",
                b as char
            ))),
        }
    }

    /// Parses the top-level object and returns the `(key, value)` pairs for known keys.
    ///
    /// Unknown keys are validated and dropped. A repeated known key keeps only its
    /// last value (JSON "last duplicate wins").
    fn parse_top_object(&mut self) -> Result<Vec<(SmolStr, JVal<'a>)>> {
        self.skip_ws();
        self.expect(b'{', "at start of pooling config")?;
        self.enter()?;
        // At most one entry per known key survives de-duplication.
        let mut out: Vec<(SmolStr, JVal<'a>)> = Vec::with_capacity(KNOWN_KEYS.len());
        self.skip_ws();
        if self.peek() == Some(b'}') {
            self.pos += 1;
            self.leave();
            return Ok(out);
        }
        loop {
            self.skip_ws();
            let key = self.parse_string("for object key")?;
            self.expect(b':', &format!("after key {key:?}"))?;
            self.skip_ws();
            if is_known_key(&key) {
                let val = self.parse_value(&format!("for value of key {key:?}"))?;
                if let Some(idx) = out
                    .iter()
                    .position(|(k, _): &(SmolStr, JVal<'_>)| k == &key)
                {
                    out.swap_remove(idx);
                }
                out.push((key, val));
            } else {
                self.skip_value(&format!("for value of key {key:?}"))?;
            }
            self.skip_ws();
            match self.peek() {
                Some(b',') => {
                    self.pos += 1;
                    self.skip_ws();
                    if self.peek() == Some(b'}') {
                        return Err(self.err("trailing comma before `}` is not valid JSON"));
                    }
                }
                Some(b'}') => {
                    self.pos += 1;
                    self.leave();
                    return Ok(out);
                }
                Some(_) => {
                    return Err(self.err(format!(
                        "expected `,` or `}}` in object but found {:?}",
                        self.char_at(self.pos)
                    )));
                }
                None => return Err(self.err("expected `,` or `}` in object but reached end of input")),
            }
        }
    }

    /// Parses one value.
    ///
    /// Strings are decoded, numbers are captured as raw text, and containers are
    /// validated and then reduced to `JVal::Array` / `JVal::Object`.
    fn parse_value(&mut self, ctx: &str) -> Result<JVal<'a>> {
        self.skip_ws();
        match self.peek() {
            Some(b'"') => Ok(JVal::Str(self.parse_string(ctx)?)),
            Some(b'{') => {
                self.skip_object()?;
                Ok(JVal::Object)
            }
            Some(b'[') => {
                self.skip_array()?;
                Ok(JVal::Array)
            }
            Some(b't') | Some(b'f') => Ok(JVal::Bool(self.parse_bool(ctx)?)),
            Some(b'n') => {
                self.parse_keyword("null", ctx)?;
                Ok(JVal::Null)
            }
            Some(c) if c == b'-' || c.is_ascii_digit() => Ok(JVal::Num(self.parse_number_slice(ctx)?)),
            Some(_) => Err(self.err(format!(
                "unexpected character {:?} while parsing value {ctx}",
                self.char_at(self.pos)
            ))),
            None => Err(self.err(format!("unexpected end of input while parsing value {ctx}"))),
        }
    }

    /// Validates and discards one value of any type.
    fn skip_value(&mut self, ctx: &str) -> Result<()> {
        self.skip_ws();
        match self.peek() {
            Some(b'"') => self.skip_string(ctx),
            Some(b'{') => self.skip_object(),
            Some(b'[') => self.skip_array(),
            Some(b't') | Some(b'f') => {
                self.parse_bool(ctx)?;
                Ok(())
            }
            Some(b'n') => self.parse_keyword("null", ctx),
            Some(c) if c == b'-' || c.is_ascii_digit() => {
                self.parse_number_slice(ctx)?;
                Ok(())
            }
            Some(_) => Err(self.err(format!(
                "unexpected character {:?} while skipping value {ctx}",
                self.char_at(self.pos)
            ))),
            None => Err(self.err(format!(
                "unexpected end of input while skipping value {ctx}"
            ))),
        }
    }

    /// Validates and discards one string, without building its decoded value.
    fn skip_string(&mut self, ctx: &str) -> Result<()> {
        self.skip_ws();
        if self.peek() != Some(b'"') {
            return Err(self.err(format!("expected string {ctx}")));
        }
        let open_pos = self.pos;
        self.pos += 1;
        self.scan_string_body(ctx, open_pos, None)
    }

    /// Parses one string and returns its decoded contents.
    ///
    /// The common escape-free case is sliced straight out of the input. Anything
    /// containing a backslash is handed off to `parse_string_with_escapes`.
    fn parse_string(&mut self, ctx: &str) -> Result<SmolStr> {
        self.skip_ws();
        if self.peek() != Some(b'"') {
            return Err(self.err(format!("expected string {ctx}")));
        }
        let open_pos = self.pos;
        self.pos += 1;
        let body_start = self.pos;
        let tail = &self.src[body_start..];
        let first = memspan::skip::skip_until(tail, *b"\"\\");
        match first {
            Some(n) if tail[n] == b'"' => {
                let body = &tail[..n];
                if let Some(bad) = body.iter().position(|&b| b < 0x20) {
                    return Err(parse_err_at(
                        body_start + bad,
                        format!(
                            "control character 0x{:02X} in string body {ctx} (RFC 8259 §7)",
                            body[bad]
                        ),
                    ));
                }
                let s = std::str::from_utf8(body)
                    .map_err(|e| self.err(format!("string is not valid UTF-8 {ctx}: {e}")))?;
                let out = SmolStr::new(s);
                self.pos = body_start + n + 1;
                Ok(out)
            }
            // First boundary is a backslash: rescan from `body_start` with decoding.
            Some(_) => self.parse_string_with_escapes(ctx, open_pos),
            None => Err(parse_err_at(
                open_pos,
                "unterminated string starting at this position".to_string(),
            )),
        }
    }

    /// Decodes a string body that contains escape sequences, starting at `self.pos`.
    fn parse_string_with_escapes(&mut self, ctx: &str, open_pos: usize) -> Result<SmolStr> {
        let mut buf: Vec<u8> = Vec::new();
        self.scan_string_body(ctx, open_pos, Some(&mut buf))?;
        std::str::from_utf8(&buf)
            .map(SmolStr::new)
            .map_err(|e| self.err(format!("string is not valid UTF-8 {ctx}: {e}")))
    }

    /// Scans a string body from `self.pos` through its closing quote.
    ///
    /// The body is validated against RFC 8259 (no raw control characters, only
    /// legal escapes, well-paired surrogates). When `sink` is `Some`, the decoded
    /// bytes are appended to it. `open_pos` is the offset of the opening quote and
    /// locates unterminated-string errors. On success `self.pos` sits just past
    /// the closing quote.
    fn scan_string_body(
        &mut self,
        ctx: &str,
        open_pos: usize,
        mut sink: Option<&mut Vec<u8>>,
    ) -> Result<()> {
        loop {
            let tail = &self.src[self.pos..];
            let n = memspan::skip::skip_until(tail, *b"\"\\").ok_or_else(|| {
                parse_err_at(
                    open_pos,
                    "unterminated string starting at this position".to_string(),
                )
            })?;
            let chunk = &tail[..n];
            if let Some(bad) = chunk.iter().position(|&b| b < 0x20) {
                return Err(parse_err_at(
                    self.pos + bad,
                    format!(
                        "control character 0x{:02X} in string body {ctx} (RFC 8259 §7)",
                        chunk[bad]
                    ),
                ));
            }
            if let Some(buf) = sink.as_deref_mut() {
                buf.extend_from_slice(chunk);
            }
            self.pos += n;
            // `skip_until` stopped on a needle byte, so there is always one to consume.
            let boundary = self
                .bump()
                .expect("skip_until found a needle so self.bump cannot be None");
            if boundary == b'"' {
                return Ok(());
            }
            let decoded = self.parse_escape(ctx, open_pos)?;
            if let Some(buf) = sink.as_deref_mut() {
                let mut utf8 = [0u8; 4];
                buf.extend_from_slice(decoded.encode_utf8(&mut utf8).as_bytes());
            }
        }
    }

    /// Decodes the escape sequence that follows a backslash which has just been consumed.
    fn parse_escape(&mut self, ctx: &str, open_pos: usize) -> Result<char> {
        let esc = self.bump().ok_or_else(|| {
            parse_err_at(
                open_pos,
                "unterminated escape in string starting at this position".to_string(),
            )
        })?;
        match esc {
            b'"' => Ok('"'),
            b'\\' => Ok('\\'),
            b'/' => Ok('/'),
            b'b' => Ok('\u{08}'),
            b'f' => Ok('\u{0C}'),
            b'n' => Ok('\n'),
            b'r' => Ok('\r'),
            b't' => Ok('\t'),
            b'u' => self.parse_unicode_scalar(),
            _ => Err(self.err(format!(
                "invalid escape sequence `\\{}` in string {ctx}",
                self.char_at(self.pos - 1)
            ))),
        }
    }

    /// Decodes the hex digits after `\u` into a scalar value.
    ///
    /// A high surrogate must be followed immediately by a `\uDCxx` low surrogate.
    /// A lone low surrogate is rejected.
    fn parse_unicode_scalar(&mut self) -> Result<char> {
        let cp = self.parse_unicode_escape()?;
        if (0xD800..=0xDBFF).contains(&cp) {
            // `||` short-circuits, so the second byte is only consumed when the first is `\`.
            if self.bump() != Some(b'\\') || self.bump() != Some(b'u') {
                return Err(self.err("expected low surrogate (`\\uDCxx`) after high surrogate"));
            }
            let low = self.parse_unicode_escape()?;
            if !(0xDC00..=0xDFFF).contains(&low) {
                return Err(self.err(format!("expected low surrogate but got U+{low:04X}")));
            }
            let combined = 0x10000 + ((cp - 0xD800) << 10) + (low - 0xDC00);
            char::from_u32(combined)
                .ok_or_else(|| self.err(format!("invalid surrogate-pair codepoint U+{combined:X}")))
        } else if (0xDC00..=0xDFFF).contains(&cp) {
            Err(self.err(format!("unpaired low surrogate U+{cp:04X}")))
        } else {
            char::from_u32(cp).ok_or_else(|| self.err(format!("invalid Unicode codepoint U+{cp:X}")))
        }
    }

    /// Reads exactly four hex digits and returns the resulting UTF-16 code unit.
    fn parse_unicode_escape(&mut self) -> Result<u32> {
        let mut cp: u32 = 0;
        for _ in 0..4 {
            let b = self
                .bump()
                .ok_or_else(|| self.err("incomplete `\\uXXXX` escape"))?;
            let nib = match b {
                b'0'..=b'9' => b - b'0',
                b'a'..=b'f' => b - b'a' + 10,
                b'A'..=b'F' => b - b'A' + 10,
                _ => {
                    return Err(self.err(format!(
                        "invalid hex digit {:?} in `\\uXXXX` escape",
                        self.char_at(self.pos - 1)
                    )));
                }
            };
            cp = (cp << 4) | u32::from(nib);
        }
        Ok(cp)
    }

    /// Parses the keyword `true` or `false`.
    fn parse_bool(&mut self, ctx: &str) -> Result<bool> {
        match self.peek() {
            Some(b't') => {
                self.parse_keyword("true", ctx)?;
                Ok(true)
            }
            Some(b'f') => {
                self.parse_keyword("false", ctx)?;
                Ok(false)
            }
            Some(_) => Err(self.err(format!(
                "expected bool {ctx} but found {:?}",
                self.char_at(self.pos)
            ))),
            None => Err(self.err(format!("expected bool {ctx} but reached end of input"))),
        }
    }

    /// Consumes the exact byte sequence `kw` at the current position.
    fn parse_keyword(&mut self, kw: &str, ctx: &str) -> Result<()> {
        let bytes = kw.as_bytes();
        if self.pos + bytes.len() > self.src.len() {
            return Err(self.err(format!(
                "expected keyword `{kw}` {ctx} but reached end of input"
            )));
        }
        if &self.src[self.pos..self.pos + bytes.len()] != bytes {
            return Err(self.err(format!("expected keyword `{kw}` {ctx}")));
        }
        self.pos += bytes.len();
        Ok(())
    }

    /// Consumes an RFC 8259 number and returns its verbatim source text.
    ///
    /// The grammar is: optional `-`, then `0` or a non-zero-led digit run, then an
    /// optional fraction, then an optional exponent. A leading zero followed by more
    /// digits stops after the `0`; the caller then rejects the leftover digit.
    fn parse_number_slice(&mut self, ctx: &str) -> Result<&'a str> {
        let start = self.pos;
        if self.peek() == Some(b'-') {
            self.pos += 1;
        }
        match self.peek() {
            Some(b'0') => {
                self.pos += 1;
            }
            Some(c) if c.is_ascii_digit() => {
                while matches!(self.peek(), Some(c) if c.is_ascii_digit()) {
                    self.pos += 1;
                }
            }
            Some(_) => {
                return Err(self.err(format!(
                    "invalid number {ctx}: unexpected {:?}",
                    self.char_at(self.pos)
                )));
            }
            None => return Err(self.err(format!("invalid number {ctx}: reached end of input"))),
        }
        if self.peek() == Some(b'.') {
            self.pos += 1;
            let frac_start = self.pos;
            while matches!(self.peek(), Some(c) if c.is_ascii_digit()) {
                self.pos += 1;
            }
            if self.pos == frac_start {
                return Err(self.err(format!("invalid number {ctx}: expected digit after `.`")));
            }
        }
        if matches!(self.peek(), Some(b'e') | Some(b'E')) {
            self.pos += 1;
            if matches!(self.peek(), Some(b'+') | Some(b'-')) {
                self.pos += 1;
            }
            let exp_start = self.pos;
            while matches!(self.peek(), Some(c) if c.is_ascii_digit()) {
                self.pos += 1;
            }
            if self.pos == exp_start {
                return Err(self.err(format!(
                    "invalid number {ctx}: expected digit after exponent marker"
                )));
            }
        }
        Ok(std::str::from_utf8(&self.src[start..self.pos])
            .expect("number bytes are ASCII by construction"))
    }

    /// Validates and discards one object, counting it toward the nesting cap.
    fn skip_object(&mut self) -> Result<()> {
        self.expect(b'{', "at start of nested object")?;
        self.enter()?;
        self.skip_ws();
        if self.peek() == Some(b'}') {
            self.pos += 1;
            self.leave();
            return Ok(());
        }
        loop {
            self.skip_ws();
            self.skip_string("for nested object key")?;
            self.expect(b':', "after nested object key")?;
            self.skip_ws();
            self.skip_value("inside nested object")?;
            self.skip_ws();
            match self.peek() {
                Some(b',') => {
                    self.pos += 1;
                    self.skip_ws();
                    if self.peek() == Some(b'}') {
                        return Err(self.err("trailing comma before `}` is not valid JSON"));
                    }
                }
                Some(b'}') => {
                    self.pos += 1;
                    self.leave();
                    return Ok(());
                }
                Some(_) => {
                    return Err(self.err(format!(
                        "expected `,` or `}}` in nested object but found {:?}",
                        self.char_at(self.pos)
                    )));
                }
                None => {
                    return Err(self.err("expected `,` or `}` in nested object but reached end of input"));
                }
            }
        }
    }

    /// Validates and discards one array, counting it toward the nesting cap.
    fn skip_array(&mut self) -> Result<()> {
        self.expect(b'[', "at start of array")?;
        self.enter()?;
        self.skip_ws();
        if self.peek() == Some(b']') {
            self.pos += 1;
            self.leave();
            return Ok(());
        }
        loop {
            self.skip_ws();
            self.skip_value("inside array")?;
            self.skip_ws();
            match self.peek() {
                Some(b',') => {
                    self.pos += 1;
                    self.skip_ws();
                    if self.peek() == Some(b']') {
                        return Err(self.err("trailing comma before `]` is not valid JSON"));
                    }
                }
                Some(b']') => {
                    self.pos += 1;
                    self.leave();
                    return Ok(());
                }
                Some(_) => {
                    return Err(self.err(format!(
                        "expected `,` or `]` in array but found {:?}",
                        self.char_at(self.pos)
                    )));
                }
                None => return Err(self.err("expected `,` or `]` in array but reached end of input")),
            }
        }
    }
}
/// Parses `src` as a single top-level JSON object and returns the known-key pairs.
///
/// Only whitespace may follow the closing `}`; anything else is reported as
/// trailing data.
pub(crate) fn parse_pooling_json(src: &str) -> Result<Vec<(SmolStr, JVal<'_>)>> {
    let mut scanner = Scanner::new(src.as_bytes());
    let pairs = scanner.parse_top_object()?;
    scanner.skip_ws();
    let consumed_everything = scanner.pos == scanner.src.len();
    if consumed_everything {
        Ok(pairs)
    } else {
        Err(scanner.err("trailing data after top-level object"))
    }
}
/// Returns the value stored under `key`, preferring the last occurrence.
fn find<'a, 'b>(cfg: &'a [(SmolStr, JVal<'b>)], key: &str) -> Option<&'a JVal<'b>> {
    cfg.iter()
        .rev()
        .find_map(|(name, value)| (name == key).then_some(value))
}
/// Determines the pooling strategy declared by a parsed pooling config.
///
/// A modern `pooling_mode` entry is authoritative whenever it is present:
/// * a string is mapped via `PoolingStrategy::from_mode`;
/// * a list (concatenated pooling) is rejected;
/// * any other type is rejected.
///
/// Otherwise the legacy `pooling_mode_*` boolean flags are consulted in the fixed
/// precedence cls → mean → max → lasttoken. A set `weightedmean` or
/// `mean_sqrt_len_tokens` flag is an error. A config that mentions legacy flags but
/// sets none of them falls back to mean pooling. A config with no pooling keys at
/// all is an error.
fn resolve_strategy(cfg: &[(SmolStr, JVal<'_>)]) -> Result<PoolingStrategy> {
    if let Some(mode_value) = find(cfg, "pooling_mode") {
        let descr = match mode_value {
            JVal::Str(mode) => return PoolingStrategy::from_mode(mode),
            JVal::Array => {
                return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                    "pooling config: `pooling_mode`",
                    "concatenated pooling mode (list) is not supported; only a single pooling mode is allowed",
                )));
            }
            JVal::Null => "null".to_string(),
            JVal::Bool(b) => format!("bool {b}"),
            JVal::Num(n) => format!("number {n}"),
            JVal::Object => "object".to_string(),
        };
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "pooling config: `pooling_mode`",
            "must be a string or list of strings (a malformed pooling mode is rejected; \
             python `pool_by_config` raises `ValueError` for a non-string/non-list mode \
             rather than silently falling back to a different strategy)",
            descr,
        )));
    }

    // Only an explicit JSON `true` counts as a set flag.
    let flag_set = |key: &str| matches!(find(cfg, key), Some(JVal::Bool(true)));

    // Supported legacy flags, listed in precedence order.
    let supported = [
        ("pooling_mode_cls_token", PoolingStrategy::Cls),
        ("pooling_mode_mean_tokens", PoolingStrategy::Mean),
        ("pooling_mode_max_tokens", PoolingStrategy::Max),
        ("pooling_mode_lasttoken", PoolingStrategy::Last),
    ];
    if let Some(strategy) = supported
        .iter()
        .find_map(|&(key, strategy)| flag_set(key).then_some(strategy))
    {
        return Ok(strategy);
    }

    let unsupported = LEGACY_KEYS.iter().find(|(key, name)| {
        matches!(*name, "weightedmean" | "mean_sqrt_len_tokens") && flag_set(*key)
    });
    if let Some((_, name)) = unsupported {
        return Err(Error::UnknownEnumValue(UnknownEnumValuePayload::new(
            "embeddings::PoolingStrategy (legacy `pooling_mode_*` flag)",
            *name,
            &["cls", "lasttoken", "max", "mean"],
        )));
    }

    let mentions_legacy = cfg
        .iter()
        .any(|(name, _)| LEGACY_KEYS.iter().any(|(legacy, _)| name == *legacy));
    if mentions_legacy {
        Ok(PoolingStrategy::Mean)
    } else {
        Err(Error::InvariantViolation(InvariantViolationPayload::new(
            "pooling config",
            "declares no pooling mode (no `pooling_mode` and no legacy `pooling_mode_*` flags)",
        )))
    }
}
/// Converts the raw JSON number text stored under `key` into a matryoshka dimension.
///
/// `raw` is the verbatim number slice captured by the scanner. Only a positive
/// integer literal that fits in `usize` is accepted. A malformed value is rejected
/// rather than silently skipping truncation.
///
/// # Errors
///
/// * `Error::Parse` when `raw` is negative or has a fraction/exponent part.
/// * `Error::OutOfRange` when the integer overflows `u64`/`usize` or is zero.
fn parse_dim_number(key: &str, raw: &str) -> Result<usize> {
    // A leading `-`, a fraction, or an exponent all mean the literal is not a
    // non-negative integer; each case gets the same message.
    let not_unsigned_integer =
        raw.starts_with('-') || raw.bytes().any(|b| matches!(b, b'.' | b'e' | b'E'));
    if not_unsigned_integer {
        // NOTE(review): this helper does not know where `raw` sits in the source, so
        // the error is tagged with byte 0 and reported as `Error::Parse`. The other
        // malformed-dimension cases use `Error::OutOfRange`. The variant is kept
        // because callers may match on it; consider threading the real offset through.
        return Err(parse_err_at(
            0,
            format!(
                "`{key}` is present but not a non-negative integer (got {raw}); \
                 a malformed matryoshka dimension is rejected rather than \
                 silently skipping truncation (which would return a \
                 full-width embedding the model author did not request)"
            ),
        ));
    }
    // After the checks above `raw` is a pure digit run, so parsing fails only on overflow.
    let v: u64 = raw.parse().map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "pooling config: matryoshka dimension",
            "must fit in usize (a u64-overflowing integer literal cannot be a matryoshka dimension)",
            smol_str::format_smolstr!("{key} = {raw}"),
        ))
    })?;
    let v = usize::try_from(v).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "pooling config: matryoshka dimension",
            "must fit in usize (a u64 value exceeding usize::MAX cannot be a matryoshka dimension)",
            smol_str::format_smolstr!("{key} = {v}"),
        ))
    })?;
    if v == 0 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "pooling config: matryoshka dimension",
            "must be > 0 (a zero matryoshka dimension would produce an empty embedding; \
             rejected rather than silently skipped)",
            smol_str::format_smolstr!("{key} = 0"),
        )));
    }
    Ok(v)
}
/// Interprets the parsed known-key pairs as an [`StPoolingConfig`].
///
/// Processing order:
/// 1. Reject `include_prompt: false` (prompt-aware pooling is unsupported).
/// 2. Resolve the pooling strategy.
/// 3. Read an optional matryoshka dimension from `word_embedding_dimension`,
///    falling back to `embedding_dimension`.
///
/// Normalization is always enabled in the result.
fn parse_pairs(cfg: &[(SmolStr, JVal<'_>)]) -> Result<StPoolingConfig> {
    let excludes_prompt = matches!(find(cfg, "include_prompt"), Some(JVal::Bool(false)));
    if excludes_prompt {
        return Err(Error::InvariantViolation(InvariantViolationPayload::new(
            "pooling config: `include_prompt`",
            "prompt-aware pooling (include_prompt=false) is not supported",
        )));
    }

    let strategy = resolve_strategy(cfg)?;

    // `word_embedding_dimension` takes precedence over `embedding_dimension`.
    let dim_entry = ["word_embedding_dimension", "embedding_dimension"]
        .iter()
        .copied()
        .find_map(|key| find(cfg, key).map(|value| (key, value)));

    let dimension = if let Some((key, value)) = dim_entry {
        if let JVal::Num(raw) = value {
            Some(parse_dim_number(key, raw)?)
        } else {
            let descr = match value {
                JVal::Null => "null".to_string(),
                JVal::Bool(b) => format!("bool {b}"),
                JVal::Str(s) => format!("string {s:?}"),
                JVal::Array => "array".to_string(),
                JVal::Object => "object".to_string(),
                JVal::Num(_) => "number".to_string(),
            };
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "pooling config: matryoshka dimension",
                "must be a non-negative integer (a malformed matryoshka dimension is rejected \
                 rather than silently skipping truncation, which would return a full-width \
                 embedding the model author did not request)",
                smol_str::format_smolstr!("{key} = {descr}"),
            )));
        }
    } else {
        None
    };

    Ok(StPoolingConfig::new(strategy, true, dimension))
}
/// Parses the text of a sentence-transformers `1_Pooling/config.json` into an
/// [`StPoolingConfig`].
///
/// # Errors
///
/// Returns an error when the text is not a single valid JSON object, or when the
/// pooling settings it declares are unsupported or malformed.
pub fn pooling_from_st_config_str(json: &str) -> Result<StPoolingConfig> {
    parse_pooling_json(json).and_then(|pairs| parse_pairs(&pairs))
}
pub fn pooling_from_st_config_bytes(json: &[u8]) -> Result<StPoolingConfig> {
let s = std::str::from_utf8(json).map_err(|e| {
Error::Parse(ParsePayload::new(
"embeddings::pooling_from_st_config_bytes",
"1_Pooling/config.json",
e,
))
})?;
pooling_from_st_config_str(s)
}
pub fn pooling_from_st_config_path(model_dir: impl AsRef<Path>) -> Result<StPoolingConfig> {
use std::io::Read;
let path = model_dir.as_ref().join("1_Pooling").join("config.json");
#[cfg(unix)]
let file = {
use std::os::unix::fs::OpenOptionsExt;
std::fs::OpenOptions::new()
.read(true)
.custom_flags(libc::O_NONBLOCK | libc::O_CLOEXEC)
.open(&path)
.map_err(|e| {
Error::FileIo(FileIoPayload::new(
"cannot open pooling config",
FileOp::Open,
path.to_path_buf(),
e,
))
})?
};
#[cfg(not(unix))]
let file = std::fs::File::open(&path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"cannot open pooling config",
FileOp::Open,
path.to_path_buf(),
e,
))
})?;
let meta = file.metadata().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"cannot stat opened pooling config",
FileOp::Stat,
path.to_path_buf(),
e,
))
})?;
if !meta.is_file() {
return Err(Error::FileIo(FileIoPayload::new(
"pooling config: opened handle is not a regular file",
FileOp::Stat,
path,
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"not a regular file (FIFO/device/directory/symlink-to-special)",
),
)));
}
let mut bytes = Vec::new();
file
.take(MAX_ST_POOLING_CONFIG_BYTES + 1)
.read_to_end(&mut bytes)
.map_err(|e| {
Error::FileIo(FileIoPayload::new(
"cannot read pooling config",
FileOp::Read,
path.to_path_buf(),
e,
))
})?;
if bytes.len() as u64 > MAX_ST_POOLING_CONFIG_BYTES {
return Err(Error::CapExceeded(CapExceededPayload::new(
"embeddings::pooling_from_st_config_path",
"MAX_ST_POOLING_CONFIG_BYTES",
MAX_ST_POOLING_CONFIG_BYTES,
bytes.len() as u64,
)));
}
pooling_from_st_config_bytes(&bytes)
}
#[cfg(test)]
mod tests {
use super::{JVal, parse_pooling_json, pooling_from_st_config_str};
#[test]
fn parse_pooling_json_mean_only() {
let pairs = parse_pooling_json(r#"{"pooling_mode_mean_tokens": true}"#).unwrap();
assert_eq!(pairs.len(), 1);
assert_eq!(pairs[0].0, "pooling_mode_mean_tokens");
assert!(matches!(pairs[0].1, JVal::Bool(true)));
}
#[test]
fn parse_pooling_json_modern_pooling_mode_string() {
let pairs = parse_pooling_json(r#"{"pooling_mode": "mean"}"#).unwrap();
assert_eq!(pairs.len(), 1);
assert_eq!(pairs[0].0, "pooling_mode");
match &pairs[0].1 {
JVal::Str(s) => assert_eq!(s, "mean"),
other => panic!("expected Str, got {other:?}"),
}
}
#[test]
fn parse_pooling_json_ignores_unknown_keys() {
let json =
r#"{"pooling_mode_mean_tokens": true, "future_field": 42, "nested": {"x": [1, 2, 3]}}"#;
let cfg = pooling_from_st_config_str(json).unwrap();
assert_eq!(cfg.strategy(), super::PoolingStrategy::Mean);
}
#[test]
fn parse_pooling_json_rejects_unterminated_string() {
let err = parse_pooling_json(r#"{"pooling_mode": "unfinished"#).unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("unterminated string"),
"expected unterminated-string error, got: {msg}"
);
assert!(
msg.contains("byte 17"),
"expected byte-offset 17, got: {msg}"
);
}
#[test]
fn parse_pooling_json_rejects_trailing_comma_object() {
let err = parse_pooling_json(r#"{"pooling_mode_mean_tokens": true,}"#).unwrap_err();
assert!(
format!("{err}").contains("trailing comma"),
"expected trailing-comma error, got: {err}"
);
}
#[test]
fn parse_pooling_json_rejects_trailing_comma_array() {
let err = parse_pooling_json(r#"{"pooling_mode": ["mean",]}"#).unwrap_err();
assert!(
format!("{err}").contains("trailing comma"),
"expected trailing-comma error (in array), got: {err}"
);
}
#[test]
fn parse_pooling_json_rejects_unexpected_char_at_value() {
let err = parse_pooling_json(r#"{"pooling_mode": @}"#).unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("unexpected character") && msg.contains("'@'"),
"expected unexpected-character error citing '@', got: {msg}"
);
}
#[test]
fn parse_pooling_json_rejects_trailing_data() {
let err = parse_pooling_json(r#"{"pooling_mode": "mean"} junk"#).unwrap_err();
assert!(
format!("{err}").contains("trailing data"),
"expected trailing-data error, got: {err}"
);
}
#[test]
fn parse_pooling_json_handles_string_escapes() {
let pairs = parse_pooling_json(r#"{"pooling_mode": "a\"b\\c\nd"}"#).unwrap();
match &pairs[0].1 {
JVal::Str(s) => assert_eq!(s, "a\"b\\c\nd"),
other => panic!("expected Str, got {other:?}"),
}
}
#[test]
fn parse_pooling_json_handles_unicode_escape() {
let pairs = parse_pooling_json(r#"{"pooling_mode": "caf\u00E9"}"#).unwrap();
match &pairs[0].1 {
JVal::Str(s) => assert_eq!(s, "café"),
other => panic!("expected Str, got {other:?}"),
}
}
#[test]
fn parse_pooling_json_handles_utf16_surrogate_pair() {
let pairs = parse_pooling_json(r#"{"pooling_mode": "\uD83D\uDE00"}"#).unwrap();
match &pairs[0].1 {
JVal::Str(s) => assert_eq!(s, "\u{1F600}"),
other => panic!("expected Str, got {other:?}"),
}
}
#[test]
fn parse_pooling_json_rejects_control_char_in_string() {
let src = "{\"pooling_mode\": \"a\x01b\"}";
let err = parse_pooling_json(src).unwrap_err();
assert!(
format!("{err}").contains("control character"),
"expected control-char rejection, got: {err}"
);
}
#[test]
fn parse_pooling_json_rejects_overly_deep_nesting() {
let deep_array = {
let opens = "[".repeat(super::MAX_NESTING_DEPTH + 16);
let closes = "]".repeat(super::MAX_NESTING_DEPTH + 16);
format!(r#"{{"pooling_mode": {opens}{closes}}}"#)
};
let err = parse_pooling_json(&deep_array).unwrap_err();
assert!(
format!("{err}").contains("depth exceeds"),
"deep-array nesting must yield depth-cap error, got: {err}"
);
let deep_object = {
let mut s = String::new();
for _ in 0..(super::MAX_NESTING_DEPTH + 16) {
s.push_str("{\"k\":");
}
s.push_str("true");
for _ in 0..(super::MAX_NESTING_DEPTH + 16) {
s.push('}');
}
format!(r#"{{"future_field": {s}}}"#)
};
let err = parse_pooling_json(&deep_object).unwrap_err();
assert!(
format!("{err}").contains("depth exceeds"),
"deep-object nesting must yield depth-cap error, got: {err}"
);
}
#[test]
fn parse_pooling_json_allows_shallow_nesting_within_cap() {
let json = r#"{"pooling_mode_mean_tokens": true, "future_field": {"a": [{"b": [1,2,3]}]}}"#;
let cfg = pooling_from_st_config_str(json).unwrap();
assert_eq!(cfg.strategy(), super::PoolingStrategy::Mean);
}
#[test]
fn parse_pooling_json_handles_many_unknown_keys_linearly() {
let mut json = String::from("{\"pooling_mode_mean_tokens\": true");
for i in 0..10_000 {
json.push_str(&format!(", \"unknown_field_{i:05}\": {i}"));
}
json.push('}');
let cfg = pooling_from_st_config_str(&json).unwrap();
assert_eq!(cfg.strategy(), super::PoolingStrategy::Mean);
}
#[test]
fn parse_pooling_json_empty_object_is_no_keys() {
let pairs = parse_pooling_json("{}").unwrap();
assert!(pairs.is_empty(), "expected no pairs, got {pairs:?}");
}
#[test]
fn parse_pooling_json_empty_object_with_leading_ws() {
let pairs = parse_pooling_json(" \t\n {}").unwrap();
assert!(pairs.is_empty(), "expected no pairs, got {pairs:?}");
}
#[test]
fn parse_pooling_json_duplicate_known_key_last_wins() {
let pairs = parse_pooling_json(r#"{"pooling_mode": "mean", "pooling_mode": "cls"}"#).unwrap();
assert_eq!(pairs.len(), 1, "duplicate key must collapse to one pair");
match &pairs[0].1 {
JVal::Str(s) => assert_eq!(s, "cls", "last duplicate must win"),
other => panic!("expected Str(\"cls\"), got {other:?}"),
}
assert_eq!(
pooling_from_st_config_str(r#"{"pooling_mode": "mean", "pooling_mode": "cls"}"#)
.unwrap()
.strategy(),
super::PoolingStrategy::Cls,
);
}
#[test]
fn parse_pooling_json_expect_colon_wrong_char() {
let err = parse_pooling_json(r#"{"k" "v"}"#).unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("expected ':'") && msg.contains("found '\"'"),
"expected colon-not-found-citing-quote error, got: {msg}"
);
assert!(msg.contains("byte 5"), "expected byte 5, got: {msg}");
}
#[test]
fn parse_pooling_json_object_unexpected_char_after_pair() {
let err = parse_pooling_json(r#"{"pooling_mode": "mean" @}"#).unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("expected `,` or `}`") && msg.contains("'@'"),
"expected comma-or-brace error citing '@', got: {msg}"
);
}
#[test]
fn parse_pooling_json_object_unterminated_after_pair() {
let err = parse_pooling_json(r#"{"pooling_mode": "mean""#).unwrap_err();
assert!(
format!("{err}").contains("reached end of input"),
"expected end-of-input error, got: {err}"
);
}
#[test]
fn parse_pooling_json_known_key_object_value() {
let pairs = parse_pooling_json(r#"{"pooling_mode": {"a": 1}}"#).unwrap();
assert_eq!(pairs.len(), 1);
assert!(matches!(pairs[0].1, JVal::Object), "expected Object");
}
#[test]
fn parse_pooling_json_known_key_array_value() {
let pairs = parse_pooling_json(r#"{"pooling_mode": ["cls", "mean"]}"#).unwrap();
assert_eq!(pairs.len(), 1);
assert!(matches!(pairs[0].1, JVal::Array), "expected Array");
}
#[test]
fn parse_pooling_json_known_key_null_value() {
let pairs = parse_pooling_json(r#"{"pooling_mode": null}"#).unwrap();
assert_eq!(pairs.len(), 1);
assert!(matches!(pairs[0].1, JVal::Null), "expected Null");
}
#[test]
fn parse_pooling_json_known_key_value_eof() {
let err = parse_pooling_json(r#"{"pooling_mode":"#).unwrap_err();
assert!(
format!("{err}").contains("unexpected end of input while parsing value"),
"expected parse-value EOF error, got: {err}"
);
}
#[test]
fn parse_pooling_json_known_key_value_unexpected_char() {
let err = parse_pooling_json(r#"{"pooling_mode": %}"#).unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("unexpected character") && msg.contains("'%'"),
"expected unexpected-char-while-parsing-value error, got: {msg}"
);
}
#[test]
fn parse_pooling_json_unknown_key_bool_value_dropped() {
let pairs = parse_pooling_json(r#"{"future_flag": true}"#).unwrap();
assert!(
pairs.is_empty(),
"unknown key must be dropped, got {pairs:?}"
);
}
#[test]
fn parse_pooling_json_unknown_key_null_value_dropped() {
let pairs = parse_pooling_json(r#"{"future_flag": null}"#).unwrap();
assert!(pairs.is_empty(), "unknown null value must be dropped");
}
#[test]
fn parse_pooling_json_unknown_key_number_value_dropped() {
let pairs = parse_pooling_json(r#"{"future_flag": 42}"#).unwrap();
assert!(pairs.is_empty(), "unknown number value must be dropped");
}
#[test]
fn parse_pooling_json_unknown_key_string_value_dropped() {
let pairs = parse_pooling_json(r#"{"future_flag": "hello"}"#).unwrap();
assert!(pairs.is_empty(), "unknown string value must be dropped");
}
#[test]
fn parse_pooling_json_unknown_key_value_unexpected_char() {
let err = parse_pooling_json(r#"{"future_flag": @}"#).unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("unexpected character") && msg.contains("skipping value") && msg.contains("'@'"),
"expected skip-value unexpected-char error, got: {msg}"
);
}
#[test]
fn parse_pooling_json_unknown_key_value_eof() {
let err = parse_pooling_json(r#"{"future_flag":"#).unwrap_err();
assert!(
format!("{err}").contains("unexpected end of input while skipping value"),
"expected skip-value EOF error, got: {err}"
);
}
#[test]
fn parse_pooling_json_unknown_key_string_escapes_dropped() {
let pairs = parse_pooling_json(r#"{"future_field": "a\"b\\c\nd"}"#).unwrap();
assert!(pairs.is_empty(), "escaped unknown string must be dropped");
}
#[test]
fn parse_pooling_json_unknown_key_string_surrogate_pair_dropped() {
let pairs = parse_pooling_json(r#"{"future_field": "\uD83D\uDE00"}"#).unwrap();
assert!(
pairs.is_empty(),
"surrogate-pair unknown string must be dropped"
);
}
#[test]
fn parse_pooling_json_unknown_key_string_unpaired_high_surrogate() {
let err = parse_pooling_json(r#"{"future_field": "\uD83Dx"}"#).unwrap_err();
assert!(
format!("{err}").contains("expected low surrogate"),
"expected unpaired-high-surrogate error, got: {err}"
);
}
/// A lone low surrogate is rejected even inside a skipped value.
#[test]
fn parse_pooling_json_unknown_key_string_unpaired_low_surrogate() {
    let err = parse_pooling_json(r#"{"future_field": "\uDC00"}"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("unpaired low surrogate"),
        "expected unpaired-low-surrogate error, got: {err}"
    );
}
/// `\x` is not a JSON escape; the skipper must reject it.
#[test]
fn parse_pooling_json_unknown_key_string_invalid_escape() {
    let err = parse_pooling_json(r#"{"future_field": "a\xb"}"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("invalid escape sequence"),
        "expected invalid-escape error, got: {err}"
    );
}
/// A raw U+0001 inside a skipped string is rejected as a control character.
#[test]
fn parse_pooling_json_unknown_key_string_control_char() {
    // The Rust escape `\u{1}` embeds the raw control byte in the JSON text.
    let json = "{\"future_field\": \"a\u{1}b\"}";
    let err = parse_pooling_json(json).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("control character"),
        "expected control-char rejection in skip_string, got: {err}"
    );
}
/// A skipped string with no closing quote reports an unterminated-string error.
#[test]
fn parse_pooling_json_unknown_key_string_unterminated() {
    // The raw literal ends before any closing `"` for the value.
    let err = parse_pooling_json(r#"{"future_field": "no close"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("unterminated string"),
        "expected unterminated-string error (skip_string), got: {err}"
    );
}
/// Nested objects inside skipped values still require string keys.
#[test]
fn parse_pooling_json_nested_object_non_string_key() {
    let err = parse_pooling_json(r#"{"future_field": {42: 1}}"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("expected string"),
        "expected non-string-nested-key error, got: {err}"
    );
}
/// A non-hex character in a `\u` escape of a known key's string is rejected,
/// and the error names that character.
#[test]
fn parse_pooling_json_invalid_hex_in_unicode_escape() {
    let err = parse_pooling_json(r#"{"pooling_mode": "\uXYZW"}"#).unwrap_err();
    // Render once instead of formatting the error for every check.
    let rendered = err.to_string();
    let names_hex_problem = rendered.contains("invalid hex digit");
    let cites_char = rendered.contains("'X'");
    assert!(
        names_hex_problem && cites_char,
        "expected invalid-hex-digit error citing 'X', got: {err}"
    );
}
/// A `\u` escape with fewer than four hex digits is rejected. Either an
/// invalid-digit or an incomplete-escape message is accepted.
#[test]
fn parse_pooling_json_incomplete_unicode_escape() {
    let msg = parse_pooling_json(r#"{"pooling_mode": "\u00"}"#)
        .unwrap_err()
        .to_string();
    let acceptable = msg.contains("invalid hex digit") || msg.contains("incomplete");
    assert!(
        acceptable,
        "expected incomplete/invalid \\u escape error, got: {msg}"
    );
}
/// A `true` keyword cut short by end of input reports both the keyword and EOF.
#[test]
fn parse_pooling_json_keyword_truncated_eof() {
    let err = parse_pooling_json(r#"{"pooling_mode": tru"#).unwrap_err();
    let rendered = err.to_string();
    let names_keyword = rendered.contains("keyword `true`");
    let mentions_eof = rendered.contains("end of input");
    assert!(
        names_keyword && mentions_eof,
        "expected truncated-keyword EOF error, got: {err}"
    );
}
/// A token starting like `true` but diverging (`trux`) is a keyword mismatch.
#[test]
fn parse_pooling_json_keyword_mismatch() {
    let err = parse_pooling_json(r#"{"pooling_mode": trux}"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("expected keyword `true`"),
        "expected keyword-mismatch error, got: {err}"
    );
}
/// `false` under a known key is kept and decoded as `JVal::Bool(false)`.
#[test]
fn parse_pooling_json_false_keyword_parsed() {
    let pairs = parse_pooling_json(r#"{"include_prompt": false}"#).unwrap();
    assert_eq!(pairs.len(), 1);
    let value = &pairs[0].1;
    assert!(
        matches!(value, JVal::Bool(false)),
        "expected Bool(false), got {:?}",
        value
    );
}
/// A `-` not followed by a digit is an invalid number.
#[test]
fn parse_pooling_json_number_bad_first_char_after_minus() {
    let err = parse_pooling_json(r#"{"future_field": -}"#).unwrap_err();
    let rendered = err.to_string();
    let is_number_error = rendered.contains("invalid number");
    let is_unexpected = rendered.contains("unexpected");
    assert!(
        is_number_error && is_unexpected,
        "expected invalid-number (bad char after -) error, got: {err}"
    );
}
/// Input ending right after a `-` is an invalid number reported at EOF.
#[test]
fn parse_pooling_json_number_eof_after_minus() {
    let err = parse_pooling_json(r#"{"future_field": -"#).unwrap_err();
    let rendered = err.to_string();
    let is_number_error = rendered.contains("invalid number");
    let mentions_eof = rendered.contains("end of input");
    assert!(
        is_number_error && mentions_eof,
        "expected invalid-number (EOF after -) error, got: {err}"
    );
}
/// `1.` with no fractional digits is rejected.
#[test]
fn parse_pooling_json_number_fraction_without_digit() {
    let err = parse_pooling_json(r#"{"future_field": 1.}"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("expected digit after `.`"),
        "expected fraction-without-digit error, got: {err}"
    );
}
/// `1e` with no exponent digits is rejected.
#[test]
fn parse_pooling_json_number_exponent_without_digit() {
    let err = parse_pooling_json(r#"{"future_field": 1e}"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("expected digit after exponent"),
        "expected exponent-without-digit error, got: {err}"
    );
}
/// Floats with a sign, a fraction and a signed exponent are accepted under an
/// unknown key and then discarded.
#[test]
fn parse_pooling_json_number_full_float_with_signed_exponent_dropped() {
    let inputs = [
        r#"{"future_field": -1.5e+10}"#,
        r#"{"future_field": 1.5e-3}"#,
    ];
    for input in inputs {
        let kept = parse_pooling_json(input).unwrap();
        assert!(kept.is_empty(), "unknown float value must be dropped");
    }
}
/// `{}` as an unknown value is skipped.
#[test]
fn parse_pooling_json_nested_empty_object_dropped() {
    let input = r#"{"future_field": {}}"#;
    let kept = parse_pooling_json(input).unwrap();
    assert!(kept.is_empty(), "nested empty object must be dropped");
}
/// `[]` as an unknown value is skipped.
#[test]
fn parse_pooling_json_nested_empty_array_dropped() {
    let input = r#"{"future_field": []}"#;
    let kept = parse_pooling_json(input).unwrap();
    assert!(kept.is_empty(), "nested empty array must be dropped");
}
/// A nested object with several members under an unknown key is skipped whole.
#[test]
fn parse_pooling_json_nested_object_multi_pair_dropped() {
    let input = r#"{"future_field": {"a": 1, "b": 2}}"#;
    let kept = parse_pooling_json(input).unwrap();
    assert!(kept.is_empty(), "multi-pair nested object must be dropped");
}
/// A trailing comma inside a skipped nested object is rejected.
#[test]
fn parse_pooling_json_nested_object_trailing_comma() {
    let err = parse_pooling_json(r#"{"future_field": {"a": 1,}}"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("trailing comma"),
        "expected nested-object trailing-comma error, got: {err}"
    );
}
/// A stray character after a nested member is reported as a missing `,`/`}`
/// in a nested object, quoting the character.
#[test]
fn parse_pooling_json_nested_object_unexpected_char() {
    let msg = parse_pooling_json(r#"{"future_field": {"a": 1 @}}"#)
        .unwrap_err()
        .to_string();
    let wants_separator = msg.contains("expected `,` or `}`");
    let in_nested = msg.contains("nested object");
    let quotes_char = msg.contains("'@'");
    assert!(
        wants_separator && in_nested && quotes_char,
        "expected nested-object comma-or-brace error, got: {msg}"
    );
}
/// Running out of input inside a skipped nested object reports EOF in context.
#[test]
fn parse_pooling_json_nested_object_unterminated() {
    let msg = parse_pooling_json(r#"{"future_field": {"a": 1"#)
        .unwrap_err()
        .to_string();
    let in_nested = msg.contains("nested object");
    let mentions_eof = msg.contains("end of input");
    assert!(
        in_nested && mentions_eof,
        "expected nested-object EOF error, got: {msg}"
    );
}
/// A multi-element array under an unknown key is skipped whole.
#[test]
fn parse_pooling_json_nested_array_multi_value_dropped() {
    let input = r#"{"future_field": [1, 2, 3]}"#;
    let kept = parse_pooling_json(input).unwrap();
    assert!(kept.is_empty(), "multi-value nested array must be dropped");
}
/// Two array elements without a separating comma are rejected, quoting the
/// second element's first character.
#[test]
fn parse_pooling_json_nested_array_unexpected_char() {
    let msg = parse_pooling_json(r#"{"future_field": [1 2]}"#)
        .unwrap_err()
        .to_string();
    let wants_separator = msg.contains("expected `,` or `]`");
    let in_array = msg.contains("array");
    let quotes_char = msg.contains("'2'");
    assert!(
        wants_separator && in_array && quotes_char,
        "expected array comma-or-bracket error, got: {msg}"
    );
}
/// Running out of input inside a skipped array reports EOF in context.
#[test]
fn parse_pooling_json_nested_array_unterminated() {
    let msg = parse_pooling_json(r#"{"future_field": [1"#)
        .unwrap_err()
        .to_string();
    let in_array = msg.contains("array");
    let mentions_eof = msg.contains("end of input");
    assert!(
        in_array && mentions_eof,
        "expected array EOF error, got: {msg}"
    );
}
/// A config with neither `pooling_mode` nor any legacy flag is an invariant
/// violation, not a silent default.
#[test]
fn pooling_from_st_config_str_no_mode_invariant_violation() {
    use crate::error::Error;
    let inputs = ["{}", r#"{"include_prompt": true}"#];
    for json in inputs {
        let err = pooling_from_st_config_str(json).unwrap_err();
        match &err {
            Error::InvariantViolation(p) => {
                assert_eq!(p.context(), "pooling config");
                let requirement = p.requirement();
                assert!(
                    requirement.contains("declares no pooling mode"),
                    "unexpected requirement: {}",
                    requirement
                );
            }
            other => panic!("expected InvariantViolation for {json:?}, got {other:?}"),
        }
    }
}
/// A `pooling_mode` key that is present but not a string (or array) is
/// rejected as out of range, with a short description of the bad value.
#[test]
fn pooling_from_st_config_str_present_nonstring_mode_out_of_range() {
    use crate::error::Error;
    let cases = [
        (r#"{"pooling_mode": null}"#, "null"),
        (r#"{"pooling_mode": true}"#, "bool true"),
        (r#"{"pooling_mode": false}"#, "bool false"),
        (r#"{"pooling_mode": 7}"#, "number 7"),
        (r#"{"pooling_mode": {"a": 1}}"#, "object"),
    ];
    for (json, descr) in cases {
        let err = pooling_from_st_config_str(json).unwrap_err();
        match &err {
            Error::OutOfRange(p) => {
                assert_eq!(p.context(), "pooling config: `pooling_mode`");
                assert_eq!(p.value(), descr, "value descr mismatch for {json:?}");
            }
            other => panic!("expected OutOfRange for {json:?}, got {other:?}"),
        }
    }
}
/// An array `pooling_mode` (several modes concatenated) is an invariant
/// violation.
#[test]
fn pooling_from_st_config_str_concatenated_mode_array_rejected() {
    use crate::error::Error;
    let json = r#"{"pooling_mode": ["cls", "mean"]}"#;
    let err = pooling_from_st_config_str(json).unwrap_err();
    match &err {
        Error::InvariantViolation(p) => {
            assert_eq!(p.context(), "pooling config: `pooling_mode`");
            let requirement = p.requirement();
            assert!(
                requirement.contains("concatenated pooling mode"),
                "unexpected requirement: {}",
                requirement
            );
        }
        other => panic!("expected InvariantViolation, got {other:?}"),
    }
}
/// Legacy boolean flags for modes this crate does not implement are rejected
/// as unknown enum values, listing the supported modes.
#[test]
fn pooling_from_st_config_str_unsupported_legacy_flag_unknown_enum() {
    use crate::error::Error;
    let cases = [
        (
            r#"{"pooling_mode_weightedmean_tokens": true}"#,
            "weightedmean",
        ),
        (
            r#"{"pooling_mode_mean_sqrt_len_tokens": true}"#,
            "mean_sqrt_len_tokens",
        ),
    ];
    for (json, name) in cases {
        let err = pooling_from_st_config_str(json).unwrap_err();
        match &err {
            Error::UnknownEnumValue(p) => {
                assert_eq!(p.value(), name);
                assert_eq!(p.supported(), &["cls", "lasttoken", "max", "mean"]);
            }
            other => panic!("expected UnknownEnumValue for {json:?}, got {other:?}"),
        }
    }
}
/// An unsupported string `pooling_mode` is an unknown `PoolingStrategy` value.
#[test]
fn pooling_from_st_config_str_unsupported_modern_mode_unknown_enum() {
    use crate::error::Error;
    let json = r#"{"pooling_mode": "weightedmean"}"#;
    let err = pooling_from_st_config_str(json).unwrap_err();
    match &err {
        Error::UnknownEnumValue(p) => {
            assert_eq!(p.type_name(), "embeddings::PoolingStrategy");
            assert_eq!(p.value(), "weightedmean");
        }
        other => panic!("expected UnknownEnumValue, got {other:?}"),
    }
}
/// When several legacy flags are `true`, the winner follows the fixed
/// priority cls > mean > max > lasttoken. Each case drops the previous winner.
#[test]
fn pooling_from_st_config_str_legacy_priority_cls_over_mean() {
    let cases = [
        (
            r#"{"pooling_mode_cls_token": true, "pooling_mode_mean_tokens": true, "pooling_mode_max_tokens": true, "pooling_mode_lasttoken": true}"#,
            super::PoolingStrategy::Cls,
        ),
        (
            r#"{"pooling_mode_mean_tokens": true, "pooling_mode_max_tokens": true, "pooling_mode_lasttoken": true}"#,
            super::PoolingStrategy::Mean,
        ),
        (
            r#"{"pooling_mode_max_tokens": true, "pooling_mode_lasttoken": true}"#,
            super::PoolingStrategy::Max,
        ),
        (
            r#"{"pooling_mode_lasttoken": true}"#,
            super::PoolingStrategy::Last,
        ),
    ];
    for (json, expected) in cases {
        let cfg = pooling_from_st_config_str(json).unwrap();
        assert_eq!(cfg.strategy(), expected);
    }
}
/// Legacy flags that are present but all `false` fall back to mean pooling,
/// with normalisation on and no dimension.
#[test]
fn pooling_from_st_config_str_all_false_legacy_defaults_mean() {
    let json = r#"{"pooling_mode_cls_token": false}"#;
    let cfg = pooling_from_st_config_str(json).unwrap();
    assert_eq!(cfg.strategy(), super::PoolingStrategy::Mean);
    assert!(cfg.normalize(), "normalize is always true for ST configs");
    assert_eq!(cfg.dimension(), None);
}
/// `include_prompt: false` is rejected as an invariant violation.
#[test]
fn pooling_from_st_config_str_include_prompt_false_rejected() {
    use crate::error::Error;
    let json = r#"{"pooling_mode": "mean", "include_prompt": false}"#;
    let err = pooling_from_st_config_str(json).unwrap_err();
    match &err {
        Error::InvariantViolation(p) => {
            assert_eq!(p.context(), "pooling config: `include_prompt`");
            let requirement = p.requirement();
            assert!(
                requirement.contains("include_prompt=false"),
                "unexpected requirement: {}",
                requirement
            );
        }
        other => panic!("expected InvariantViolation, got {other:?}"),
    }
}
/// A positive `word_embedding_dimension` is carried through as the dimension.
#[test]
fn pooling_from_st_config_str_valid_word_embedding_dimension() {
    let json = r#"{"pooling_mode": "cls", "word_embedding_dimension": 384}"#;
    let cfg = pooling_from_st_config_str(json).unwrap();
    assert_eq!(cfg.strategy(), super::PoolingStrategy::Cls);
    assert_eq!(cfg.dimension(), Some(384));
}
/// When both dimension keys are present, `word_embedding_dimension` wins.
#[test]
fn pooling_from_st_config_str_word_dim_precedence_over_embedding_dim() {
    let json = r#"{"pooling_mode": "mean", "word_embedding_dimension": 128, "embedding_dimension": 256}"#;
    let cfg = pooling_from_st_config_str(json).unwrap();
    assert_eq!(cfg.dimension(), Some(128));
}
/// `embedding_dimension` alone is accepted as an alias for the dimension.
#[test]
fn pooling_from_st_config_str_embedding_dimension_alias() {
    let json = r#"{"pooling_mode": "mean", "embedding_dimension": 64}"#;
    let cfg = pooling_from_st_config_str(json).unwrap();
    assert_eq!(cfg.dimension(), Some(64));
}
/// A zero dimension is out of range and the error echoes the offending key.
#[test]
fn pooling_from_st_config_str_zero_dimension_out_of_range() {
    use crate::error::Error;
    let json = r#"{"pooling_mode": "mean", "word_embedding_dimension": 0}"#;
    let err = pooling_from_st_config_str(json).unwrap_err();
    match &err {
        Error::OutOfRange(p) => {
            assert_eq!(p.context(), "pooling config: matryoshka dimension");
            let requirement = p.requirement();
            assert!(
                requirement.contains("must be > 0"),
                "unexpected requirement: {}",
                requirement
            );
            assert_eq!(p.value(), "word_embedding_dimension = 0");
        }
        other => panic!("expected OutOfRange for zero dim, got {other:?}"),
    }
}
/// An integer too large for `u64` is reported as out of range rather than
/// wrapping or panicking.
#[test]
fn pooling_from_st_config_str_u64_overflow_dimension_out_of_range() {
    use crate::error::Error;
    let json = r#"{"pooling_mode": "mean", "embedding_dimension": 99999999999999999999999999}"#;
    let err = pooling_from_st_config_str(json).unwrap_err();
    match &err {
        Error::OutOfRange(p) => {
            assert_eq!(p.context(), "pooling config: matryoshka dimension");
            let requirement = p.requirement();
            assert!(
                requirement.contains("u64-overflowing"),
                "unexpected requirement: {}",
                requirement
            );
        }
        other => panic!("expected OutOfRange for u64-overflow dim, got {other:?}"),
    }
}
/// A negative dimension is a parse error that asks for a non-negative integer.
#[test]
fn pooling_from_st_config_str_negative_dimension_parse_error() {
    use crate::error::Error;
    let json = r#"{"pooling_mode": "mean", "word_embedding_dimension": -1}"#;
    let err = pooling_from_st_config_str(json).unwrap_err();
    let is_parse = matches!(err, Error::Parse(_));
    assert!(is_parse, "expected Parse for negative dim, got {err:?}");
    let rendered = err.to_string();
    assert!(
        rendered.contains("non-negative integer"),
        "expected non-negative-integer message, got: {err}"
    );
}
/// A fractional dimension is a parse error.
#[test]
fn pooling_from_st_config_str_fractional_dimension_parse_error() {
    use crate::error::Error;
    let json = r#"{"pooling_mode": "mean", "word_embedding_dimension": 1.5}"#;
    let err = pooling_from_st_config_str(json).unwrap_err();
    let is_parse = matches!(err, Error::Parse(_));
    assert!(is_parse, "expected Parse for fractional dim, got {err:?}");
}
/// Non-numeric dimension values (string, bool, null, array, object) are out of
/// range, and each error carries a `key = <kind>` description.
#[test]
fn pooling_from_st_config_str_nonnumber_dimension_out_of_range() {
    use crate::error::Error;
    let cases = [
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": "384"}"#,
            r#"word_embedding_dimension = string "384""#,
        ),
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": true}"#,
            "word_embedding_dimension = bool true",
        ),
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": null}"#,
            "word_embedding_dimension = null",
        ),
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": [1, 2]}"#,
            "word_embedding_dimension = array",
        ),
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": {"a": 1}}"#,
            "word_embedding_dimension = object",
        ),
    ];
    for (json, expect_descr) in cases {
        let err = pooling_from_st_config_str(json).unwrap_err();
        match &err {
            Error::OutOfRange(p) => {
                assert_eq!(p.context(), "pooling config: matryoshka dimension");
                assert_eq!(p.value(), expect_descr, "descr mismatch for {json:?}");
            }
            other => panic!("expected OutOfRange for {json:?}, got {other:?}"),
        }
    }
}
/// An invalid `word_embedding_dimension` must fail the whole config. It must
/// not quietly fall back to a valid `embedding_dimension`.
#[test]
fn pooling_from_st_config_str_invalid_primary_dim_no_fallback_to_alias() {
    use crate::error::Error;
    let json = r#"{"pooling_mode": "mean", "word_embedding_dimension": -1, "embedding_dimension": 384}"#;
    let err = pooling_from_st_config_str(json).unwrap_err();
    let rejected = matches!(err, Error::Parse(_) | Error::OutOfRange(_));
    assert!(
        rejected,
        "invalid primary dim must reject, not fall back, got {err:?}"
    );
}
/// Bytes that are not valid UTF-8 are a parse error from the bytes entry point.
#[test]
fn pooling_from_st_config_bytes_invalid_utf8_parse_error() {
    use super::pooling_from_st_config_bytes;
    use crate::error::Error;
    // 0xFF and 0xFE can never appear in well-formed UTF-8.
    let bad_utf8 = [b'{', 0xFF, 0xFE, b'}'];
    let err = pooling_from_st_config_bytes(&bad_utf8).unwrap_err();
    let is_parse = matches!(err, Error::Parse(_));
    assert!(
        is_parse,
        "expected Parse for invalid UTF-8 bytes, got {err:?}"
    );
}
/// Valid UTF-8 bytes produce the same config the string entry point would.
#[test]
fn pooling_from_st_config_bytes_valid_roundtrips_to_str() {
    use super::pooling_from_st_config_bytes;
    let raw = br#"{"pooling_mode": "cls"}"#;
    let cfg = pooling_from_st_config_bytes(raw).unwrap();
    assert_eq!(cfg.strategy(), super::PoolingStrategy::Cls);
    assert!(cfg.normalize());
    assert_eq!(cfg.dimension(), None);
}
/// Input ending right after a key reports that `:` was expected at EOF.
#[test]
fn parse_pooling_json_expect_colon_reached_eof() {
    let msg = parse_pooling_json(r#"{"k""#).unwrap_err().to_string();
    let wants_colon = msg.contains("expected ':'");
    let mentions_eof = msg.contains("reached end of input");
    assert!(
        wants_colon && mentions_eof,
        "expected colon-EOF error from expect(), got: {msg}"
    );
}
/// The single-character JSON escapes (`\n \/ \b \f \r \t`) in a known key's
/// string decode to their literal characters.
#[test]
fn parse_pooling_json_known_key_string_simple_escapes_decoded() {
    let pairs = parse_pooling_json(r#"{"pooling_mode": "\n\/\b\f\r\t"}"#).unwrap();
    assert_eq!(pairs.len(), 1);
    let expected = "\n/\u{8}\u{c}\r\t";
    match &pairs[0].1 {
        super::JVal::Str(decoded) => {
            assert_eq!(decoded.as_str(), expected, "decoded escape bytes mismatch");
        }
        other => panic!("expected Str, got {other:?}"),
    }
}
/// Lower-case hex digits in a `\u` escape (`\u00ab`) decode correctly.
#[test]
fn parse_pooling_json_known_key_string_lowercase_hex_unicode_escape() {
    let json = r#"{"pooling_mode": "x\u00abz"}"#;
    let pairs = parse_pooling_json(json).unwrap();
    // Guard the count before indexing so an empty result fails with a clear
    // assertion instead of an index-out-of-bounds panic.
    assert_eq!(pairs.len(), 1);
    match &pairs[0].1 {
        super::JVal::Str(s) => assert_eq!(s.as_str(), "x\u{00AB}z"),
        other => panic!("expected Str, got {other:?}"),
    }
}
/// A raw (unescaped) 4-byte UTF-8 scalar in a known key's string comes through
/// unchanged.
#[test]
fn parse_pooling_json_known_key_string_raw_4byte_scalar_fast_path() {
    let pairs = parse_pooling_json(r#"{"pooling_mode": "😀"}"#).unwrap();
    // Guard the count before indexing so an empty result fails with a clear
    // assertion instead of an index-out-of-bounds panic.
    assert_eq!(pairs.len(), 1);
    match &pairs[0].1 {
        super::JVal::Str(s) => assert_eq!(s.as_str(), "\u{1F600}"),
        other => panic!("expected Str, got {other:?}"),
    }
}
/// In a known key's string, a high surrogate followed by a plain character is
/// rejected.
#[test]
fn parse_pooling_json_known_key_string_high_surrogate_not_followed_by_escape() {
    let err = parse_pooling_json(r#"{"pooling_mode": "\uD83Dx"}"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("expected low surrogate"),
        "expected high-surrogate-without-low error (parse_string slow path), got: {err}"
    );
}
/// A high surrogate followed by a `\u` escape outside the low-surrogate range is
/// rejected, and the error cites the code point that was found.
#[test]
fn parse_pooling_json_known_key_string_high_surrogate_then_out_of_range_low() {
    let json = r#"{"pooling_mode": "\uD83D\u0041"}"#;
    let msg = parse_pooling_json(json).unwrap_err().to_string();
    let names_problem = msg.contains("expected low surrogate but got");
    let cites_code_point = msg.contains("U+0041");
    assert!(
        names_problem && cites_code_point,
        "expected out-of-range-low-surrogate error citing U+0041, got: {msg}"
    );
}
/// A lone low surrogate after another escape in a known key's string is
/// rejected.
#[test]
fn parse_pooling_json_known_key_string_unpaired_low_surrogate() {
    // The leading `\n` makes the string contain an escape before the surrogate.
    let err = parse_pooling_json(r#"{"pooling_mode": "\n\uDC00"}"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("unpaired low surrogate"),
        "expected unpaired-low-surrogate error (parse_string slow path), got: {err}"
    );
}
/// `\x` in a known key's string is rejected as an invalid escape.
#[test]
fn parse_pooling_json_known_key_string_invalid_escape() {
    let err = parse_pooling_json(r#"{"pooling_mode": "a\xb"}"#).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("invalid escape sequence"),
        "expected invalid-escape error (parse_string slow path), got: {err}"
    );
}
/// Input ending on a lone backslash inside a known key's string reports an
/// unterminated escape.
#[test]
fn parse_pooling_json_known_key_string_unterminated_escape_eof() {
    // The JSON text ends with a backslash right after `\n`.
    let json = r#"{"pooling_mode": "\n\"#;
    let err = parse_pooling_json(json).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("unterminated escape"),
        "expected unterminated-escape error (parse_string slow path), got: {err}"
    );
}
/// Input ending on a lone backslash inside a skipped string reports an
/// unterminated escape.
#[test]
fn parse_pooling_json_unknown_key_string_unterminated_escape_eof() {
    // The JSON text ends with `a` followed by a single backslash.
    let json = r#"{"future_field": "a\"#;
    let err = parse_pooling_json(json).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("unterminated escape"),
        "expected unterminated-escape error (skip_string), got: {err}"
    );
}
}