#![allow(clippy::doc_markdown)]
use std::sync::Arc;
use mnem_core::sparse::{SparseEncoder, SparseError};
use serde::{Deserialize, Serialize};
/// Which sparse-encoder backend `open` should construct, together with that backend's
/// settings.
///
/// Serialised as an internally tagged enum: the `provider` key holds the lowercased
/// variant name (`"sidecar"` or `"onnx"`), and the remaining keys are the fields of the
/// wrapped config struct, e.g.
///
/// ```toml
/// provider = "onnx"
/// model = "opensearch-doc-v3-distill"
/// ```
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase", tag = "provider")]
pub enum ProviderConfig {
    /// Build a `crate::sidecar::SidecarSparseEncoder` from the given [`SidecarConfig`].
    Sidecar(SidecarConfig),
    /// Build a `crate::onnx::OnnxSparseEncoder` from the given [`OnnxConfig`].
    /// Requires the `onnx` cargo feature; without it `open` returns `SparseError::Config`.
    Onnx(OnnxConfig),
}
/// Settings for the sidecar sparse-encoder backend.
///
/// Only `timeout_secs` has a serde default. `base_url`, `model` and `vocab_id` must be
/// present when deserializing, even though the `Default` impl supplies values for them.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct SidecarConfig {
    /// Base URL of the sidecar (default: `http://localhost:8791`).
    pub base_url: String,
    /// Model identifier (default: `opensearch-doc-v3-distill`).
    /// NOTE(review): presumably forwarded to the sidecar; confirm in
    /// `SidecarSparseEncoder::from_config`.
    pub model: String,
    /// Vocabulary identifier (default: `bert-base-uncased@30522`).
    /// NOTE(review): the default looks like `<tokenizer>@<vocab size>`; confirm the format.
    pub vocab_id: String,
    /// Timeout in seconds. If the key is omitted from the serialized config, it falls back
    /// to `default_timeout()` (30).
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
}
impl Default for SidecarConfig {
    /// Defaults that target a sidecar on `localhost:8791` serving the
    /// `opensearch-doc-v3-distill` model, with the standard 30-second timeout.
    fn default() -> Self {
        let base_url = String::from("http://localhost:8791");
        let model = String::from("opensearch-doc-v3-distill");
        let vocab_id = String::from("bert-base-uncased@30522");
        Self {
            base_url,
            model,
            vocab_id,
            timeout_secs: default_timeout(),
        }
    }
}
/// Serde fallback for `SidecarConfig::timeout_secs` when the key is absent: 30 seconds.
///
/// Kept as a `const fn` so it is usable in constant contexts too.
const fn default_timeout() -> u64 {
    const THIRTY_SECONDS: u64 = 30;
    THIRTY_SECONDS
}
/// Settings for the ONNX sparse-encoder backend (used only when built with the `onnx`
/// feature).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OnnxConfig {
    /// Model name (default: `opensearch-doc-v3-distill`). With the `onnx` feature enabled it
    /// must be `opensearch-doc-v3-distill` or `opensearch-bi-v2-distill`. Any other value
    /// makes `open` fail with `SparseError::Config`. Required when deserializing.
    pub model: String,
    /// Optional max length forwarded to `OnnxSparseEncoder::with_max_length`.
    /// `None` (the default) is left out of serialized output and defers the choice to the
    /// encoder. NOTE(review): unit is presumably tokens; confirm in the encoder.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_length: Option<usize>,
}
impl Default for OnnxConfig {
    /// Uses the `opensearch-doc-v3-distill` model and leaves `max_length` unset so the
    /// encoder picks its own limit.
    fn default() -> Self {
        let model = String::from("opensearch-doc-v3-distill");
        Self {
            model,
            max_length: None,
        }
    }
}
/// Construct the sparse encoder described by `cfg`, behind a shared trait object.
///
/// # Errors
///
/// - `SparseError::Config` when `cfg` selects ONNX but the crate was built without the
///   `onnx` feature, or when the ONNX model name is not recognised.
/// - Any error raised while building the selected backend
///   (`SidecarSparseEncoder::from_config` / `OnnxSparseEncoder::with_max_length`).
pub fn open(cfg: &ProviderConfig) -> Result<Arc<dyn SparseEncoder>, SparseError> {
    match cfg {
        ProviderConfig::Onnx(onnx) => open_onnx(onnx),
        ProviderConfig::Sidecar(sidecar) => {
            let encoder: Arc<dyn SparseEncoder> =
                Arc::new(crate::sidecar::SidecarSparseEncoder::from_config(sidecar)?);
            Ok(encoder)
        }
    }
}
/// Build the ONNX backend. This variant is compiled only with the `onnx` feature.
///
/// # Errors
///
/// `SparseError::Config` for an unrecognised `model` (see `parse_onnx_model`). Otherwise,
/// any error from `OnnxSparseEncoder::with_max_length`.
#[cfg(feature = "onnx")]
fn open_onnx(c: &OnnxConfig) -> Result<Arc<dyn SparseEncoder>, SparseError> {
    let OnnxConfig { model, max_length } = c;
    let model_kind = parse_onnx_model(model)?;
    let encoder = crate::onnx::OnnxSparseEncoder::with_max_length(model_kind, *max_length)?;
    Ok(Arc::new(encoder))
}
/// Stand-in for the ONNX backend in builds without the `onnx` feature.
///
/// # Errors
///
/// Always returns `SparseError::Config` with a message telling the user to either rebuild
/// with `--features onnx` or switch to the sidecar provider.
#[cfg(not(feature = "onnx"))]
fn open_onnx(_c: &OnnxConfig) -> Result<Arc<dyn SparseEncoder>, SparseError> {
    const MISSING_FEATURE: &str =
        "sparse.provider = \"onnx\" but this binary was built without the `onnx` feature. \
         Rebuild with `--features onnx` or set sparse.provider = \"sidecar\".";
    Err(SparseError::Config(MISSING_FEATURE.into()))
}
/// Map a configured ONNX model name onto `crate::onnx::ModelKind`.
///
/// Matching is exact: case- and whitespace-sensitive.
///
/// # Errors
///
/// Returns `SparseError::Config` naming the rejected value and listing the accepted names
/// when `s` is not a recognised model.
#[cfg(feature = "onnx")]
fn parse_onnx_model(s: &str) -> Result<crate::onnx::ModelKind, SparseError> {
    use crate::onnx::ModelKind as Kind;
    let kind = match s {
        "opensearch-doc-v3-distill" => Kind::OpensearchDocV3Distill,
        "opensearch-bi-v2-distill" => Kind::OpensearchBiV2Distill,
        unknown => {
            return Err(SparseError::Config(format!(
                "unknown onnx sparse model `{unknown}`; known: \
                 opensearch-doc-v3-distill, opensearch-bi-v2-distill"
            )));
        }
    };
    Ok(kind)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sidecar_config_toml_round_trip() {
        let cfg = ProviderConfig::Sidecar(SidecarConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(s.contains("provider = \"sidecar\""));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    #[test]
    fn sidecar_default_has_sane_values() {
        let c = SidecarConfig::default();
        assert!(c.base_url.starts_with("http"));
        assert!(!c.model.is_empty());
        assert!(!c.vocab_id.is_empty());
        assert_eq!(c.timeout_secs, 30);
    }

    #[test]
    fn sidecar_timeout_defaults_when_omitted() {
        // Round-trips always emit `timeout_secs`, so the `#[serde(default = ...)]` path is
        // only exercised when the key is genuinely absent from the input.
        let s = "provider = \"sidecar\"\n\
                 base_url = \"http://localhost:8791\"\n\
                 model = \"opensearch-doc-v3-distill\"\n\
                 vocab_id = \"bert-base-uncased@30522\"\n";
        let back: ProviderConfig = toml::from_str(s).unwrap();
        assert_eq!(back, ProviderConfig::Sidecar(SidecarConfig::default()));
    }

    #[test]
    fn unknown_provider_tag_is_rejected() {
        let res = toml::from_str::<ProviderConfig>("provider = \"nope\"\nmodel = \"x\"\n");
        assert!(res.is_err(), "unknown provider tag must not deserialize");
    }

    #[test]
    fn onnx_config_toml_round_trip() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(
            s.contains("provider = \"onnx\""),
            "onnx tag must serialise as provider = \"onnx\"; got:\n{s}"
        );
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    #[test]
    fn onnx_config_default_skips_max_length() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(
            !s.contains("max_length"),
            "default OnnxConfig should not emit max_length (let the encoder pick). Got:\n{s}"
        );
    }

    #[test]
    fn onnx_config_max_length_round_trip() {
        let cfg = ProviderConfig::Onnx(OnnxConfig {
            model: "opensearch-bi-v2-distill".into(),
            max_length: Some(256),
        });
        let s = toml::to_string(&cfg).unwrap();
        assert!(s.contains("max_length = 256"));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    #[cfg(not(feature = "onnx"))]
    #[test]
    fn open_onnx_without_feature_returns_actionable_error() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let err = open(&cfg).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("--features onnx"),
            "err should point at the feature rebuild; got: {msg}"
        );
    }

    #[cfg(feature = "onnx")]
    #[test]
    fn parse_onnx_model_rejects_unknown() {
        let err = parse_onnx_model("made-up-model").unwrap_err();
        assert!(format!("{err}").contains("unknown onnx sparse model"));
    }
}