llmsdk_provider/middleware/builtin/
default_settings.rs1use async_trait::async_trait;
11
12use crate::error::Result;
13use crate::language_model::{CallOptions, LanguageModel};
14use crate::middleware::language_model::{CallKind, LanguageModelMiddleware};
15use crate::shared::{Headers, ProviderOptions};
16
17#[derive(Debug, Clone)]
22pub struct DefaultSettingsMiddleware {
23 defaults: CallOptions,
24}
25
26impl DefaultSettingsMiddleware {
27 #[must_use]
29 pub fn new(defaults: CallOptions) -> Self {
30 Self { defaults }
31 }
32}
33
34#[async_trait]
35impl LanguageModelMiddleware for DefaultSettingsMiddleware {
36 async fn transform_params(
37 &self,
38 _kind: CallKind,
39 params: CallOptions,
40 _inner: &dyn LanguageModel,
41 ) -> Result<CallOptions> {
42 Ok(merge_call_options(self.defaults.clone(), params))
43 }
44}
45
46fn merge_call_options(default: CallOptions, caller: CallOptions) -> CallOptions {
47 CallOptions {
48 prompt: if caller.prompt.is_empty() {
49 default.prompt
50 } else {
51 caller.prompt
52 },
53 max_output_tokens: caller.max_output_tokens.or(default.max_output_tokens),
54 temperature: caller.temperature.or(default.temperature),
55 stop_sequences: caller.stop_sequences.or(default.stop_sequences),
56 top_p: caller.top_p.or(default.top_p),
57 top_k: caller.top_k.or(default.top_k),
58 presence_penalty: caller.presence_penalty.or(default.presence_penalty),
59 frequency_penalty: caller.frequency_penalty.or(default.frequency_penalty),
60 response_format: caller.response_format.or(default.response_format),
61 seed: caller.seed.or(default.seed),
62 tools: caller.tools.or(default.tools),
63 tool_choice: caller.tool_choice.or(default.tool_choice),
64 include_raw_chunks: caller.include_raw_chunks.or(default.include_raw_chunks),
65 headers: merge_headers(default.headers, caller.headers),
66 reasoning: caller.reasoning.or(default.reasoning),
67 provider_options: merge_provider_options(default.provider_options, caller.provider_options),
68 }
69}
70
71fn merge_headers(default: Option<Headers>, caller: Option<Headers>) -> Option<Headers> {
72 match (default, caller) {
73 (None, c) => c,
74 (Some(d), None) => Some(d),
75 (Some(mut d), Some(c)) => {
76 d.extend(c);
77 Some(d)
78 }
79 }
80}
81
82fn merge_provider_options(
83 default: Option<ProviderOptions>,
84 caller: Option<ProviderOptions>,
85) -> Option<ProviderOptions> {
86 match (default, caller) {
87 (None, c) => c,
88 (Some(d), None) => Some(d),
89 (Some(mut d), Some(c)) => {
90 for (provider, caller_inner) in c {
91 let entry = d.entry(provider).or_default();
92 for (k, v) in caller_inner {
93 match entry.remove(&k) {
94 Some(base) => {
95 entry.insert(k, deep_merge_value(base, v));
102 }
103 None => {
104 entry.insert(k, v);
105 }
106 }
107 }
108 }
109 Some(d)
110 }
111 }
112}
113
114fn deep_merge_value(base: serde_json::Value, overrides: serde_json::Value) -> serde_json::Value {
119 use serde_json::Value;
120 match (base, overrides) {
121 (Value::Object(mut b), Value::Object(o)) => {
122 for (k, v) in o {
123 match b.remove(&k) {
124 Some(base_v) => {
125 b.insert(k, deep_merge_value(base_v, v));
126 }
127 None => {
128 b.insert(k, v);
129 }
130 }
131 }
132 Value::Object(b)
133 }
134 (_, overrides) => overrides,
135 }
136}
137
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;
    use crate::language_model::{
        FinishReason, FinishReasonKind, GenerateResult, Message, Prompt, StreamResult, Usage,
    };
    use crate::middleware::wrap_language_model;

    /// Fake model that remembers the options passed to the most recent
    /// `do_generate` call so tests can inspect what the middleware produced.
    #[derive(Debug, Default)]
    struct Recorder {
        seen: Mutex<Option<CallOptions>>,
    }

    #[async_trait]
    impl LanguageModel for Recorder {
        fn provider(&self) -> &'static str {
            "rec"
        }

        fn model_id(&self) -> &'static str {
            "rec"
        }

        async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
            *self.seen.lock().expect("mutex") = Some(options);
            Ok(GenerateResult {
                content: Vec::new(),
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                usage: Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: Vec::new(),
            })
        }

        async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
            unimplemented!()
        }
    }

    /// A minimal non-empty prompt consisting of a single system message.
    fn sample_prompt() -> Prompt {
        let system = Message::System {
            content: "sys".into(),
            provider_options: None,
        };
        vec![system]
    }

    /// Provider options holding `feature` under `anthropic -> "feature"`.
    fn anthropic_feature(feature: serde_json::Value) -> ProviderOptions {
        let mut inner = serde_json::Map::new();
        inner.insert("feature".into(), feature);
        let mut options = ProviderOptions::new();
        options.insert("anthropic".into(), inner);
        options
    }

    /// Wraps a fresh [`Recorder`] with the middleware configured with
    /// `defaults`, runs one `do_generate` with `call`, and returns the
    /// options the recorder received.
    async fn generate_and_capture(defaults: CallOptions, call: CallOptions) -> CallOptions {
        let recorder = Arc::new(Recorder::default());
        let middleware =
            Arc::new(DefaultSettingsMiddleware::new(defaults)) as Arc<dyn LanguageModelMiddleware>;
        let wrapped = wrap_language_model(
            Arc::clone(&recorder) as Arc<dyn LanguageModel>,
            [middleware],
        );

        wrapped.do_generate(call).await.expect("generate");

        // Bind before returning so the lock guard is released first.
        let captured = recorder.seen.lock().expect("mutex").clone().expect("params");
        captured
    }

    #[tokio::test]
    async fn caller_fills_missing_fields_from_defaults() {
        let defaults = CallOptions {
            temperature: Some(0.7),
            max_output_tokens: Some(1024),
            ..Default::default()
        };
        let call = CallOptions {
            prompt: sample_prompt(),
            temperature: Some(0.1),
            ..Default::default()
        };

        let captured = generate_and_capture(defaults, call).await;

        assert_eq!(captured.temperature, Some(0.1), "caller wins");
        assert_eq!(captured.max_output_tokens, Some(1024), "default filled");
    }

    #[tokio::test]
    async fn provider_options_merge_is_deep_recursive() {
        // Defaults set two nested keys; the caller overrides only one of them.
        let defaults = CallOptions {
            provider_options: Some(anthropic_feature(
                serde_json::json!({ "enabled": true, "cache": true }),
            )),
            ..Default::default()
        };
        let call = CallOptions {
            prompt: sample_prompt(),
            provider_options: Some(anthropic_feature(serde_json::json!({ "enabled": false }))),
            ..Default::default()
        };

        let captured = generate_and_capture(defaults, call).await;

        let merged = captured.provider_options.expect("provider_options merged");
        let anthropic = merged.get("anthropic").expect("anthropic key present");
        let feature = anthropic.get("feature").expect("feature key present");
        assert_eq!(feature["enabled"], false, "caller override survives");
        assert_eq!(
            feature["cache"], true,
            "sibling key from defaults must survive deep merge"
        );
    }
}