Skip to main content

llmsdk_provider/middleware/builtin/
default_settings.rs

//! Fill missing [`CallOptions`] fields with provider-level defaults.
//!
//! Mirrors `@ai-sdk/ai/src/middleware/default-settings-middleware.ts`. Caller
//! values always win — defaults only apply to fields the caller left `None` /
//! unspecified (for `prompt`, an empty prompt counts as unspecified).
//! `headers` and `provider_options` are *merged*: caller wins per key,
//! defaults supply the rest, and nested JSON objects inside
//! `provider_options` are merged recursively. Every other field — including
//! `tools` and `tool_choice` — is taken whole from the caller when set and
//! from the defaults otherwise.
// Rust guideline compliant 2026-02-21
9
10use async_trait::async_trait;
11
12use crate::error::Result;
13use crate::language_model::{CallOptions, LanguageModel};
14use crate::middleware::language_model::{CallKind, LanguageModelMiddleware};
15use crate::shared::{Headers, ProviderOptions};
16
17/// Middleware applying a baseline [`CallOptions`] to every call.
18///
19/// Construct once with the defaults you want and attach via
20/// [`crate::wrap_language_model`].
21#[derive(Debug, Clone)]
22pub struct DefaultSettingsMiddleware {
23    defaults: CallOptions,
24}
25
26impl DefaultSettingsMiddleware {
27    /// Build with the given default options.
28    #[must_use]
29    pub fn new(defaults: CallOptions) -> Self {
30        Self { defaults }
31    }
32}
33
34#[async_trait]
35impl LanguageModelMiddleware for DefaultSettingsMiddleware {
36    async fn transform_params(
37        &self,
38        _kind: CallKind,
39        params: CallOptions,
40        _inner: &dyn LanguageModel,
41    ) -> Result<CallOptions> {
42        Ok(merge_call_options(self.defaults.clone(), params))
43    }
44}
45
46fn merge_call_options(default: CallOptions, caller: CallOptions) -> CallOptions {
47    CallOptions {
48        prompt: if caller.prompt.is_empty() {
49            default.prompt
50        } else {
51            caller.prompt
52        },
53        max_output_tokens: caller.max_output_tokens.or(default.max_output_tokens),
54        temperature: caller.temperature.or(default.temperature),
55        stop_sequences: caller.stop_sequences.or(default.stop_sequences),
56        top_p: caller.top_p.or(default.top_p),
57        top_k: caller.top_k.or(default.top_k),
58        presence_penalty: caller.presence_penalty.or(default.presence_penalty),
59        frequency_penalty: caller.frequency_penalty.or(default.frequency_penalty),
60        response_format: caller.response_format.or(default.response_format),
61        seed: caller.seed.or(default.seed),
62        tools: caller.tools.or(default.tools),
63        tool_choice: caller.tool_choice.or(default.tool_choice),
64        include_raw_chunks: caller.include_raw_chunks.or(default.include_raw_chunks),
65        headers: merge_headers(default.headers, caller.headers),
66        reasoning: caller.reasoning.or(default.reasoning),
67        provider_options: merge_provider_options(default.provider_options, caller.provider_options),
68    }
69}
70
71fn merge_headers(default: Option<Headers>, caller: Option<Headers>) -> Option<Headers> {
72    match (default, caller) {
73        (None, c) => c,
74        (Some(d), None) => Some(d),
75        (Some(mut d), Some(c)) => {
76            d.extend(c);
77            Some(d)
78        }
79    }
80}
81
82fn merge_provider_options(
83    default: Option<ProviderOptions>,
84    caller: Option<ProviderOptions>,
85) -> Option<ProviderOptions> {
86    match (default, caller) {
87        (None, c) => c,
88        (Some(d), None) => Some(d),
89        (Some(mut d), Some(c)) => {
90            for (provider, caller_inner) in c {
91                let entry = d.entry(provider).or_default();
92                for (k, v) in caller_inner {
93                    match entry.remove(&k) {
94                        Some(base) => {
95                            // Mirror upstream `mergeObjects` deep recursion
96                            // (`packages/ai/src/util/merge-objects.ts:14-84`):
97                            // when both sides are JSON objects (not arrays,
98                            // not dates), merge recursively so per-feature
99                            // overrides do not clobber sibling keys the
100                            // caller did not mention.
101                            entry.insert(k, deep_merge_value(base, v));
102                        }
103                        None => {
104                            entry.insert(k, v);
105                        }
106                    }
107                }
108            }
109            Some(d)
110        }
111    }
112}
113
114/// Deep merge two JSON values mirroring upstream `mergeObjects`.
115///
116/// When both sides are JSON objects, recurse per-key. Otherwise the
117/// `overrides` value wins. Arrays / scalars / nulls do not merge.
118fn deep_merge_value(base: serde_json::Value, overrides: serde_json::Value) -> serde_json::Value {
119    use serde_json::Value;
120    match (base, overrides) {
121        (Value::Object(mut b), Value::Object(o)) => {
122            for (k, v) in o {
123                match b.remove(&k) {
124                    Some(base_v) => {
125                        b.insert(k, deep_merge_value(base_v, v));
126                    }
127                    None => {
128                        b.insert(k, v);
129                    }
130                }
131            }
132            Value::Object(b)
133        }
134        (_, overrides) => overrides,
135    }
136}
137
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;
    use crate::language_model::{GenerateResult, Message, Prompt, StreamResult};
    use crate::middleware::wrap_language_model;

    /// Test model that stores the last `CallOptions` passed to `do_generate`.
    #[derive(Debug, Default)]
    struct Recorder(Mutex<Option<CallOptions>>);

    #[async_trait]
    impl LanguageModel for Recorder {
        fn provider(&self) -> &'static str {
            "rec"
        }
        fn model_id(&self) -> &'static str {
            "rec"
        }
        async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
            *self.0.lock().expect("mutex") = Some(options);
            Ok(GenerateResult {
                content: vec![],
                finish_reason: crate::language_model::FinishReason::new(
                    crate::language_model::FinishReasonKind::Stop,
                ),
                usage: crate::language_model::Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }
        async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
            unimplemented!()
        }
    }

    /// Non-empty one-message prompt. It holds a *system* message, hence the
    /// name.
    fn system_prompt() -> Prompt {
        vec![Message::System {
            content: "sys".into(),
            provider_options: None,
        }]
    }

    #[tokio::test]
    async fn caller_fills_missing_fields_from_defaults() {
        let rec = Arc::new(Recorder::default());
        let defaults = CallOptions {
            temperature: Some(0.7),
            max_output_tokens: Some(1024),
            ..Default::default()
        };
        let wrapped = wrap_language_model(
            Arc::clone(&rec) as Arc<dyn LanguageModel>,
            [Arc::new(DefaultSettingsMiddleware::new(defaults))
                as Arc<dyn LanguageModelMiddleware>],
        );

        wrapped
            .do_generate(CallOptions {
                prompt: system_prompt(),
                temperature: Some(0.1),
                ..Default::default()
            })
            .await
            .expect("generate");

        let captured = rec.0.lock().expect("mutex").clone().expect("params");
        assert_eq!(captured.temperature, Some(0.1), "caller wins");
        assert_eq!(captured.max_output_tokens, Some(1024), "default filled");
    }

    #[tokio::test]
    async fn empty_caller_prompt_falls_back_to_default_prompt() {
        // `merge_call_options` treats an empty caller prompt as unspecified
        // and substitutes the default prompt.
        let rec = Arc::new(Recorder::default());
        let defaults = CallOptions {
            prompt: system_prompt(),
            ..Default::default()
        };
        let wrapped = wrap_language_model(
            Arc::clone(&rec) as Arc<dyn LanguageModel>,
            [Arc::new(DefaultSettingsMiddleware::new(defaults))
                as Arc<dyn LanguageModelMiddleware>],
        );

        wrapped
            .do_generate(CallOptions::default())
            .await
            .expect("generate");

        let captured = rec.0.lock().expect("mutex").clone().expect("params");
        assert_eq!(captured.prompt.len(), 1, "default prompt substituted");
        assert!(
            matches!(captured.prompt.first(), Some(Message::System { .. })),
            "substituted prompt is the default system message"
        );
    }

    #[tokio::test]
    async fn provider_options_merge_is_deep_recursive() {
        // Mirrors upstream `mergeObjects` semantics tested in
        // `packages/ai/src/util/merge-objects.test.ts`: when the caller
        // overrides a nested key, sibling keys at the *same nested level*
        // must survive from the defaults — a shallow per-key insert would
        // wipe them out. Catches the prior bug where Rust did per-key
        // insert on the inner Map but did not recurse into the JSON
        // value payload.
        let rec = Arc::new(Recorder::default());

        let mut defaults_inner = serde_json::Map::new();
        defaults_inner.insert(
            "feature".into(),
            serde_json::json!({ "enabled": true, "cache": true }),
        );
        let mut defaults_po = ProviderOptions::new();
        defaults_po.insert("anthropic".into(), defaults_inner);

        let defaults = CallOptions {
            provider_options: Some(defaults_po),
            ..Default::default()
        };
        let wrapped = wrap_language_model(
            Arc::clone(&rec) as Arc<dyn LanguageModel>,
            [Arc::new(DefaultSettingsMiddleware::new(defaults))
                as Arc<dyn LanguageModelMiddleware>],
        );

        let mut caller_inner = serde_json::Map::new();
        caller_inner.insert("feature".into(), serde_json::json!({ "enabled": false }));
        let mut caller_po = ProviderOptions::new();
        caller_po.insert("anthropic".into(), caller_inner);

        wrapped
            .do_generate(CallOptions {
                prompt: system_prompt(),
                provider_options: Some(caller_po),
                ..Default::default()
            })
            .await
            .expect("generate");

        let captured = rec.0.lock().expect("mutex").clone().expect("params");
        let merged = captured.provider_options.expect("provider_options merged");
        let anthropic = merged.get("anthropic").expect("anthropic key present");
        let feature = anthropic.get("feature").expect("feature key present");
        assert_eq!(feature["enabled"], false, "caller override survives");
        assert_eq!(
            feature["cache"], true,
            "sibling key from defaults must survive deep merge"
        );
    }
}