// mnem_sparse_providers/config.rs

1// SPLADE, BGE-M3, sidecar-host proper nouns are well-known external
2// identifiers.
3#![allow(clippy::doc_markdown)]
4
5//! [`ProviderConfig`] + [`open`] factory. Mirrors the pattern in
6//! `mnem-rerank-providers` and `mnem-llm-providers`.
7
8use std::sync::Arc;
9
10use mnem_core::sparse::{SparseEncoder, SparseError};
11use serde::{Deserialize, Serialize};
12
13/// Tagged enum over the shipped backends. TOML shape for each backend:
14///
15/// Sidecar (Python HTTP server, zero native deps in mnem):
16/// ```toml
17/// [sparse]
18/// provider  = "sidecar"
19/// base_url  = "http://localhost:8791"
20/// model     = "opensearch-doc-v3-distill"
21/// vocab_id  = "bert-base-uncased@30522"
22/// ```
23///
24/// Native ONNX (in-process, requires the `onnx` cargo feature):
25/// ```toml
26/// [sparse]
27/// provider   = "onnx"
28/// model      = "opensearch-doc-v3-distill"   # or "opensearch-bi-v2-distill"
29/// # max_length = 512                         # optional; default = model ceiling
30/// ```
31#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
32#[serde(rename_all = "lowercase", tag = "provider")]
33pub enum ProviderConfig {
34    /// HTTP sidecar running the reference Python SPLADE/BGE-M3
35    /// implementation. See `benchmarks/adapters/splade-sidecar/` in
36    /// for a Docker image.
37    Sidecar(SidecarConfig),
38    /// Native in-process ONNX encoder. Requires building with the
39    /// `onnx` feature; otherwise [`open`] returns a
40    /// [`SparseError::Config`] explaining the mismatch so operators
41    /// can either rebuild or switch back to `sidecar`.
42    Onnx(OnnxConfig),
43}
44
45/// Sidecar HTTP config.
46#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
47pub struct SidecarConfig {
48    /// Base URL, e.g. `"http://localhost:8791"`. The adapter
49    /// POSTs to `<base_url>/encode` with `{text, model}` and
50    /// expects `{indices: [u32], values: [f32], vocab_id: str}`.
51    pub base_url: String,
52    /// Bare model id the sidecar exposes, e.g.
53    /// `"opensearch-doc-v3-distill"`. Final fq id is
54    /// `"sidecar:<model>"`.
55    pub model: String,
56    /// Vocabulary id the sidecar's model uses. Must match the
57    /// `vocab_id` stamped on stored `SparseEmbed`s for retrieval to
58    /// work (mnem-core::index::SparseInvertedIndex enforces this).
59    pub vocab_id: String,
60    /// Per-request timeout in seconds. Default 30.
61    #[serde(default = "default_timeout")]
62    pub timeout_secs: u64,
63}
64
65impl Default for SidecarConfig {
66    fn default() -> Self {
67        Self {
68            base_url: "http://localhost:8791".into(),
69            model: "opensearch-doc-v3-distill".into(),
70            vocab_id: "bert-base-uncased@30522".into(),
71            timeout_secs: default_timeout(),
72        }
73    }
74}
75
/// Default per-request sidecar timeout, in seconds.
///
/// Used both as the serde default for [`SidecarConfig::timeout_secs`] and
/// by `SidecarConfig::default()`.
const fn default_timeout() -> u64 {
    30
}

80/// Native in-process ONNX encoder config. Unlike [`SidecarConfig`]
81/// there is no network URL or vocab_id: the `model` string resolves
82/// directly to a compiled-in `onnx::ModelKind` variant (only available
83/// when the `onnx` cargo feature is enabled) and the
84/// encoder stamps the canonical `vocab_id` itself.
85///
86/// Kept visible even when the `onnx` feature is off, so that a
87/// deserialised `[sparse] provider = "onnx"` block round-trips
88/// through TOML and [`open`] can emit an actionable "rebuild with
89/// `--features onnx`" error at construction time.
90#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
91pub struct OnnxConfig {
92    /// Model shortname. Known values map to `onnx::ModelKind` variants
93    /// (only available when the `onnx` cargo feature is enabled):
94    ///   - `"opensearch-doc-v3-distill"` (default; asymmetric, IDF query)
95    ///   - `"opensearch-bi-v2-distill"`  (symmetric, both sides run the net)
96    pub model: String,
97    /// Optional tokenizer `max_length` override. `None` defers to
98    /// the model's `default_max_length()` (DistilBERT = 512). Values
99    /// above the model's `positional_limit()` are clamped with a
100    /// stderr warning.
101    #[serde(default, skip_serializing_if = "Option::is_none")]
102    pub max_length: Option<usize>,
103}
104
105impl Default for OnnxConfig {
106    fn default() -> Self {
107        Self {
108            model: "opensearch-doc-v3-distill".into(),
109            max_length: None,
110        }
111    }
112}
113
114/// Construct a live [`SparseEncoder`] from a [`ProviderConfig`].
115///
116/// # Errors
117///
118/// - [`SparseError::Config`] if the sidecar URL is malformed.
119/// - [`SparseError::Config`] with an actionable remediation string
120///   when `provider = "onnx"` is used in a build without the `onnx`
121///   cargo feature.
122/// - [`SparseError::Config`] on an unknown onnx `model` string.
123pub fn open(cfg: &ProviderConfig) -> Result<Arc<dyn SparseEncoder>, SparseError> {
124    match cfg {
125        ProviderConfig::Sidecar(c) => {
126            let enc = crate::sidecar::SidecarSparseEncoder::from_config(c)?;
127            Ok(Arc::new(enc))
128        }
129        ProviderConfig::Onnx(c) => open_onnx(c),
130    }
131}
132
/// Build the in-process ONNX encoder (only compiled with the `onnx`
/// feature).
///
/// Resolves `c.model` to a `ModelKind` first, so an unknown model name is
/// reported before any encoder construction is attempted.
#[cfg(feature = "onnx")]
fn open_onnx(c: &OnnxConfig) -> Result<Arc<dyn SparseEncoder>, SparseError> {
    let kind = parse_onnx_model(&c.model)?;
    let enc = crate::onnx::OnnxSparseEncoder::with_max_length(kind, c.max_length)?;
    Ok(Arc::new(enc))
}

140#[cfg(not(feature = "onnx"))]
141fn open_onnx(_c: &OnnxConfig) -> Result<Arc<dyn SparseEncoder>, SparseError> {
142    Err(SparseError::Config(
143        "sparse.provider = \"onnx\" but this binary was built without the `onnx` feature. \
144         Rebuild with `--features onnx` or set sparse.provider = \"sidecar\"."
145            .into(),
146    ))
147}
148
/// Map a config `model` shortname to its compiled-in `ModelKind`.
///
/// The match is exact (case-sensitive, no trimming).
///
/// # Errors
///
/// [`SparseError::Config`] naming the rejected value and listing the
/// known shortnames when `s` is not one of them.
#[cfg(feature = "onnx")]
fn parse_onnx_model(s: &str) -> Result<crate::onnx::ModelKind, SparseError> {
    use crate::onnx::ModelKind;
    match s {
        "opensearch-doc-v3-distill" => Ok(ModelKind::OpensearchDocV3Distill),
        "opensearch-bi-v2-distill" => Ok(ModelKind::OpensearchBiV2Distill),
        other => Err(SparseError::Config(format!(
            "unknown onnx sparse model `{other}`; known: \
             opensearch-doc-v3-distill, opensearch-bi-v2-distill"
        ))),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The default sidecar config serialises with the lowercase tag and
    /// deserialises back to an equal value.
    #[test]
    fn sidecar_config_toml_round_trip() {
        let cfg = ProviderConfig::Sidecar(SidecarConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(s.contains("provider = \"sidecar\""));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    /// Every field of the default sidecar config is populated.
    #[test]
    fn sidecar_default_has_sane_values() {
        let c = SidecarConfig::default();
        assert!(c.base_url.starts_with("http"));
        assert!(!c.model.is_empty());
        assert!(!c.vocab_id.is_empty());
        assert_eq!(c.timeout_secs, 30);
    }

    /// The default onnx config serialises with the lowercase tag and
    /// deserialises back to an equal value.
    #[test]
    fn onnx_config_toml_round_trip() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(
            s.contains("provider = \"onnx\""),
            "onnx tag must serialise as provider = \"onnx\"; got:\n{s}"
        );
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    /// `max_length: None` must not appear in the serialised output.
    #[test]
    fn onnx_config_default_skips_max_length() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(
            !s.contains("max_length"),
            "default OnnxConfig should not emit max_length (let the encoder pick). Got:\n{s}"
        );
    }

    /// An explicit `max_length` is emitted and survives a round trip.
    #[test]
    fn onnx_config_max_length_round_trip() {
        let cfg = ProviderConfig::Onnx(OnnxConfig {
            model: "opensearch-bi-v2-distill".into(),
            max_length: Some(256),
        });
        let s = toml::to_string(&cfg).unwrap();
        assert!(s.contains("max_length = 256"));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    /// Without the `onnx` feature, `open` must point operators at the
    /// feature flag rather than failing opaquely.
    #[cfg(not(feature = "onnx"))]
    #[test]
    fn open_onnx_without_feature_returns_actionable_error() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let err = open(&cfg).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("--features onnx"),
            "err should point at the feature rebuild; got: {msg}"
        );
    }

    /// Unknown model shortnames are rejected with a descriptive error.
    #[cfg(feature = "onnx")]
    #[test]
    fn parse_onnx_model_rejects_unknown() {
        let err = parse_onnx_model("made-up-model").unwrap_err();
        assert!(err.to_string().contains("unknown onnx sparse model"));
    }
}