mnem_sparse_providers/
config.rs1#![allow(clippy::doc_markdown)]
4
5use std::sync::Arc;
9
10use mnem_core::sparse::{SparseEncoder, SparseError};
11use serde::{Deserialize, Serialize};
12
13#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
32#[serde(rename_all = "lowercase", tag = "provider")]
33pub enum ProviderConfig {
34 Sidecar(SidecarConfig),
38 Onnx(OnnxConfig),
43}
44
45#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
47pub struct SidecarConfig {
48 pub base_url: String,
52 pub model: String,
56 pub vocab_id: String,
60 #[serde(default = "default_timeout")]
62 pub timeout_secs: u64,
63}
64
65impl Default for SidecarConfig {
66 fn default() -> Self {
67 Self {
68 base_url: "http://localhost:8791".into(),
69 model: "opensearch-doc-v3-distill".into(),
70 vocab_id: "bert-base-uncased@30522".into(),
71 timeout_secs: default_timeout(),
72 }
73 }
74}
75
/// Serde default for [`SidecarConfig::timeout_secs`]: 30 seconds.
///
/// Declared `const` so it is also usable in constant contexts.
const fn default_timeout() -> u64 {
    const SECS: u64 = 30;
    SECS
}
79
80#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
91pub struct OnnxConfig {
92 pub model: String,
97 #[serde(default, skip_serializing_if = "Option::is_none")]
102 pub max_length: Option<usize>,
103}
104
105impl Default for OnnxConfig {
106 fn default() -> Self {
107 Self {
108 model: "opensearch-doc-v3-distill".into(),
109 max_length: None,
110 }
111 }
112}
113
114pub fn open(cfg: &ProviderConfig) -> Result<Arc<dyn SparseEncoder>, SparseError> {
124 match cfg {
125 ProviderConfig::Sidecar(c) => {
126 let enc = crate::sidecar::SidecarSparseEncoder::from_config(c)?;
127 Ok(Arc::new(enc))
128 }
129 ProviderConfig::Onnx(c) => open_onnx(c),
130 }
131}
132
/// Builds the in-process ONNX encoder described by `c`.
///
/// Only compiled when the `onnx` feature is enabled.
///
/// # Errors
///
/// Returns [`SparseError::Config`] if `c.model` is not a known model name or
/// if `c.max_length` is `Some(0)`. Propagates any error from
/// `OnnxSparseEncoder::with_max_length`.
#[cfg(feature = "onnx")]
fn open_onnx(c: &OnnxConfig) -> Result<Arc<dyn SparseEncoder>, SparseError> {
    // Validate the model name first so unknown-model errors keep their precedence.
    let kind = parse_onnx_model(&c.model)?;
    // A cap of zero leaves no room for any input. Reject it with a config-level
    // message rather than forwarding it to the encoder.
    if c.max_length == Some(0) {
        return Err(SparseError::Config(
            "sparse.max_length must be greater than 0 (omit it to use the encoder default)"
                .into(),
        ));
    }
    let enc = crate::onnx::OnnxSparseEncoder::with_max_length(kind, c.max_length)?;
    Ok(Arc::new(enc))
}
139
140#[cfg(not(feature = "onnx"))]
141fn open_onnx(_c: &OnnxConfig) -> Result<Arc<dyn SparseEncoder>, SparseError> {
142 Err(SparseError::Config(
143 "sparse.provider = \"onnx\" but this binary was built without the `onnx` feature. \
144 Rebuild with `--features onnx` or set sparse.provider = \"sidecar\"."
145 .into(),
146 ))
147}
148
/// Maps a configured ONNX model name onto a `crate::onnx::ModelKind`.
///
/// Names are matched exactly (case-sensitive, no trimming).
///
/// # Errors
///
/// Returns [`SparseError::Config`] naming the unrecognised value and listing
/// the accepted names.
#[cfg(feature = "onnx")]
fn parse_onnx_model(s: &str) -> Result<crate::onnx::ModelKind, SparseError> {
    let kind = match s {
        "opensearch-doc-v3-distill" => crate::onnx::ModelKind::OpensearchDocV3Distill,
        "opensearch-bi-v2-distill" => crate::onnx::ModelKind::OpensearchBiV2Distill,
        unknown => {
            return Err(SparseError::Config(format!(
                "unknown onnx sparse model `{unknown}`; known: \
                 opensearch-doc-v3-distill, opensearch-bi-v2-distill"
            )));
        }
    };
    Ok(kind)
}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises `cfg` to TOML, asserts it parses back to an equal value,
    /// and hands the TOML text back for further inspection.
    fn assert_toml_round_trip(cfg: &ProviderConfig) -> String {
        let text = toml::to_string(cfg).unwrap();
        let back: ProviderConfig = toml::from_str(&text).unwrap();
        assert_eq!(cfg, &back);
        text
    }

    #[test]
    fn sidecar_config_toml_round_trip() {
        let text = assert_toml_round_trip(&ProviderConfig::Sidecar(SidecarConfig::default()));
        assert!(text.contains("provider = \"sidecar\""));
    }

    #[test]
    fn sidecar_default_has_sane_values() {
        let SidecarConfig {
            base_url,
            model,
            vocab_id,
            timeout_secs,
        } = SidecarConfig::default();
        assert!(base_url.starts_with("http"));
        assert!(!model.is_empty());
        assert!(!vocab_id.is_empty());
        assert_eq!(timeout_secs, 30);
    }

    #[test]
    fn onnx_config_toml_round_trip() {
        let text = assert_toml_round_trip(&ProviderConfig::Onnx(OnnxConfig::default()));
        assert!(
            text.contains("provider = \"onnx\""),
            "onnx tag must serialise as provider = \"onnx\"; got:\n{text}"
        );
    }

    #[test]
    fn onnx_config_default_skips_max_length() {
        let text = toml::to_string(&ProviderConfig::Onnx(OnnxConfig::default())).unwrap();
        assert!(
            !text.contains("max_length"),
            "default OnnxConfig should not emit max_length (let the encoder pick). Got:\n{text}"
        );
    }

    #[test]
    fn onnx_config_max_length_round_trip() {
        let cfg = ProviderConfig::Onnx(OnnxConfig {
            model: "opensearch-bi-v2-distill".into(),
            max_length: Some(256),
        });
        let text = assert_toml_round_trip(&cfg);
        assert!(text.contains("max_length = 256"));
    }

    #[cfg(not(feature = "onnx"))]
    #[test]
    fn open_onnx_without_feature_returns_actionable_error() {
        let msg = open(&ProviderConfig::Onnx(OnnxConfig::default()))
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("--features onnx"),
            "err should point at the feature rebuild; got: {msg}"
        );
    }

    #[cfg(feature = "onnx")]
    #[test]
    fn parse_onnx_model_rejects_unknown() {
        let err = parse_onnx_model("made-up-model").unwrap_err();
        assert!(err.to_string().contains("unknown onnx sparse model"));
    }
}