Skip to main content

llmsdk_provider/middleware/builtin/
default_embedding_settings.rs

1//! Fill missing [`EmbedOptions`] fields with provider-level defaults.
2//!
3//! Embedding analogue of [`super::default_settings`]. Surface is narrower:
4//! `headers` and `provider_options` are the only mergeable fields; `values`
5//! is caller-only.
6// Rust guideline compliant 2026-02-21
7
8use async_trait::async_trait;
9
10use crate::embedding_model::{EmbedOptions, EmbeddingModel};
11use crate::error::Result;
12use crate::middleware::embedding_model::EmbeddingModelMiddleware;
13use crate::shared::{Headers, ProviderOptions};
14
15/// Middleware applying a baseline [`EmbedOptions`] to every embed call.
16#[derive(Debug, Clone, Default)]
17pub struct DefaultEmbeddingSettingsMiddleware {
18    defaults: EmbedOptions,
19}
20
21impl DefaultEmbeddingSettingsMiddleware {
22    /// Build with the given default options.
23    #[must_use]
24    pub fn new(defaults: EmbedOptions) -> Self {
25        Self { defaults }
26    }
27}
28
29#[async_trait]
30impl EmbeddingModelMiddleware for DefaultEmbeddingSettingsMiddleware {
31    async fn transform_params(
32        &self,
33        params: EmbedOptions,
34        _inner: &dyn EmbeddingModel,
35    ) -> Result<EmbedOptions> {
36        Ok(EmbedOptions {
37            values: if params.values.is_empty() {
38                self.defaults.values.clone()
39            } else {
40                params.values
41            },
42            headers: merge_headers(self.defaults.headers.clone(), params.headers),
43            provider_options: merge_provider_options(
44                self.defaults.provider_options.clone(),
45                params.provider_options,
46            ),
47        })
48    }
49}
50
51fn merge_headers(default: Option<Headers>, caller: Option<Headers>) -> Option<Headers> {
52    match (default, caller) {
53        (None, c) => c,
54        (Some(d), None) => Some(d),
55        (Some(mut d), Some(c)) => {
56            d.extend(c);
57            Some(d)
58        }
59    }
60}
61
62fn merge_provider_options(
63    default: Option<ProviderOptions>,
64    caller: Option<ProviderOptions>,
65) -> Option<ProviderOptions> {
66    match (default, caller) {
67        (None, c) => c,
68        (Some(d), None) => Some(d),
69        (Some(mut d), Some(c)) => {
70            for (provider, caller_inner) in c {
71                let entry = d.entry(provider).or_default();
72                for (k, v) in caller_inner {
73                    entry.insert(k, v);
74                }
75            }
76            Some(d)
77        }
78    }
79}
80
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;
    use crate::embedding_model::EmbedResult;
    use crate::middleware::wrap_embedding_model;

    /// Fake embedding model that stores the options it was last called with
    /// and returns an empty result.
    #[derive(Debug, Default)]
    struct Recorder(Mutex<Option<EmbedOptions>>);

    #[async_trait]
    impl EmbeddingModel for Recorder {
        fn provider(&self) -> &'static str {
            "rec"
        }
        fn model_id(&self) -> &'static str {
            "rec"
        }
        async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult> {
            *self.0.lock().expect("mutex") = Some(options);
            Ok(EmbedResult {
                embeddings: vec![],
                usage: None,
                provider_metadata: None,
                request: None,
                response: None,
            })
        }
    }

    /// Build a `ProviderOptions` holding a single provider whose settings are
    /// the fields of the JSON object `settings`.
    fn provider_options(provider: &'static str, settings: serde_json::Value) -> ProviderOptions {
        let mut po = ProviderOptions::default();
        po.insert(
            provider.into(),
            settings
                .as_object()
                .cloned()
                .expect("settings must be a JSON object"),
        );
        po
    }

    /// Run one embed call through the middleware configured with `defaults`
    /// and return the options that reached the wrapped model.
    async fn embed_with(defaults: EmbedOptions, params: EmbedOptions) -> EmbedOptions {
        let rec = Arc::new(Recorder::default());
        let wrapped = wrap_embedding_model(
            Arc::clone(&rec) as Arc<dyn EmbeddingModel>,
            [Arc::new(DefaultEmbeddingSettingsMiddleware::new(defaults))
                as Arc<dyn EmbeddingModelMiddleware>],
        );

        wrapped.do_embed(params).await.expect("embed");

        let mut slot = rec.0.lock().expect("mutex");
        slot.take().expect("params")
    }

    #[tokio::test]
    async fn defaults_fill_missing_provider_options() {
        let defaults = EmbedOptions {
            provider_options: Some(provider_options(
                "openai",
                serde_json::json!({"dimensions": 256}),
            )),
            ..Default::default()
        };

        let captured = embed_with(
            defaults,
            EmbedOptions {
                values: vec!["x".into()],
                ..Default::default()
            },
        )
        .await;

        let po = captured.provider_options.expect("po set");
        let openai = po.get("openai").expect("openai key");
        assert_eq!(
            openai.get("dimensions").and_then(serde_json::Value::as_i64),
            Some(256)
        );
    }

    #[tokio::test]
    async fn caller_provider_options_override_defaults() {
        let defaults = EmbedOptions {
            provider_options: Some(provider_options(
                "openai",
                serde_json::json!({"dimensions": 256, "user": "default-user"}),
            )),
            ..Default::default()
        };

        let captured = embed_with(
            defaults,
            EmbedOptions {
                values: vec!["x".into()],
                provider_options: Some(provider_options(
                    "openai",
                    serde_json::json!({"dimensions": 512}),
                )),
                ..Default::default()
            },
        )
        .await;

        let po = captured.provider_options.expect("po set");
        let openai = po.get("openai").expect("openai key");
        // The caller's value replaces the default stored under the same key.
        assert_eq!(
            openai.get("dimensions").and_then(serde_json::Value::as_i64),
            Some(512)
        );
        // A key present only in the defaults is kept.
        assert_eq!(
            openai.get("user").and_then(serde_json::Value::as_str),
            Some("default-user")
        );
    }
}