llmsdk_provider/middleware/builtin/
default_embedding_settings.rs1use async_trait::async_trait;
9
10use crate::embedding_model::{EmbedOptions, EmbeddingModel};
11use crate::error::Result;
12use crate::middleware::embedding_model::EmbeddingModelMiddleware;
13use crate::shared::{Headers, ProviderOptions};
14
15#[derive(Debug, Clone, Default)]
17pub struct DefaultEmbeddingSettingsMiddleware {
18 defaults: EmbedOptions,
19}
20
21impl DefaultEmbeddingSettingsMiddleware {
22 #[must_use]
24 pub fn new(defaults: EmbedOptions) -> Self {
25 Self { defaults }
26 }
27}
28
29#[async_trait]
30impl EmbeddingModelMiddleware for DefaultEmbeddingSettingsMiddleware {
31 async fn transform_params(
32 &self,
33 params: EmbedOptions,
34 _inner: &dyn EmbeddingModel,
35 ) -> Result<EmbedOptions> {
36 Ok(EmbedOptions {
37 values: if params.values.is_empty() {
38 self.defaults.values.clone()
39 } else {
40 params.values
41 },
42 headers: merge_headers(self.defaults.headers.clone(), params.headers),
43 provider_options: merge_provider_options(
44 self.defaults.provider_options.clone(),
45 params.provider_options,
46 ),
47 })
48 }
49}
50
51fn merge_headers(default: Option<Headers>, caller: Option<Headers>) -> Option<Headers> {
52 match (default, caller) {
53 (None, c) => c,
54 (Some(d), None) => Some(d),
55 (Some(mut d), Some(c)) => {
56 d.extend(c);
57 Some(d)
58 }
59 }
60}
61
62fn merge_provider_options(
63 default: Option<ProviderOptions>,
64 caller: Option<ProviderOptions>,
65) -> Option<ProviderOptions> {
66 match (default, caller) {
67 (None, c) => c,
68 (Some(d), None) => Some(d),
69 (Some(mut d), Some(c)) => {
70 for (provider, caller_inner) in c {
71 let entry = d.entry(provider).or_default();
72 for (k, v) in caller_inner {
73 entry.insert(k, v);
74 }
75 }
76 Some(d)
77 }
78 }
79}
80
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;
    use crate::embedding_model::EmbedResult;
    use crate::middleware::wrap_embedding_model;

    /// Inner model that stores the last `EmbedOptions` it received and returns
    /// an empty result, so tests can inspect what the middleware forwarded.
    #[derive(Debug, Default)]
    struct Recorder(Mutex<Option<EmbedOptions>>);

    #[async_trait]
    impl EmbeddingModel for Recorder {
        fn provider(&self) -> &'static str {
            "rec"
        }
        fn model_id(&self) -> &'static str {
            "rec"
        }
        async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult> {
            *self.0.lock().expect("mutex") = Some(options);
            Ok(EmbedResult {
                embeddings: vec![],
                usage: None,
                provider_metadata: None,
                request: None,
                response: None,
            })
        }
    }

    /// Builds `ProviderOptions` holding a single provider entry taken from a
    /// JSON object.
    ///
    /// Panics if `value` is not a JSON object.
    fn provider_options(provider: &'static str, value: serde_json::Value) -> ProviderOptions {
        let mut po = ProviderOptions::default();
        po.insert(
            provider.into(),
            value
                .as_object()
                .cloned()
                .expect("provider options must be a JSON object"),
        );
        po
    }

    /// Sends `params` through a `Recorder` wrapped with the default-settings
    /// middleware (configured with `defaults`) and returns the options the
    /// recorder received.
    async fn forwarded_params(defaults: EmbedOptions, params: EmbedOptions) -> EmbedOptions {
        let rec = Arc::new(Recorder::default());
        let wrapped = wrap_embedding_model(
            Arc::clone(&rec) as Arc<dyn EmbeddingModel>,
            [Arc::new(DefaultEmbeddingSettingsMiddleware::new(defaults))
                as Arc<dyn EmbeddingModelMiddleware>],
        );

        wrapped.do_embed(params).await.expect("embed");

        // Bind first so the mutex guard is dropped here, not after `rec` goes
        // out of scope at the end of the function.
        let captured = rec.0.lock().expect("mutex").clone();
        captured.expect("params")
    }

    #[tokio::test]
    async fn defaults_fill_missing_provider_options() {
        let defaults = EmbedOptions {
            provider_options: Some(provider_options(
                "openai",
                serde_json::json!({"dimensions": 256}),
            )),
            ..Default::default()
        };
        let params = EmbedOptions {
            values: vec!["x".into()],
            ..Default::default()
        };

        let captured = forwarded_params(defaults, params).await;

        let po = captured.provider_options.expect("po set");
        let openai = po.get("openai").expect("openai key");
        assert_eq!(
            openai.get("dimensions").and_then(serde_json::Value::as_i64),
            Some(256)
        );
    }

    #[tokio::test]
    async fn caller_provider_options_override_defaults_per_key() {
        let defaults = EmbedOptions {
            provider_options: Some(provider_options(
                "openai",
                serde_json::json!({"dimensions": 256, "user": "default-user"}),
            )),
            ..Default::default()
        };
        let params = EmbedOptions {
            values: vec!["x".into()],
            provider_options: Some(provider_options(
                "openai",
                serde_json::json!({"dimensions": 512}),
            )),
            ..Default::default()
        };

        let captured = forwarded_params(defaults, params).await;

        let po = captured.provider_options.expect("po set");
        let openai = po.get("openai").expect("openai key");
        // The caller's value wins for a key both sides define...
        assert_eq!(
            openai.get("dimensions").and_then(serde_json::Value::as_i64),
            Some(512)
        );
        // ...while keys that only the defaults define are preserved.
        assert_eq!(
            openai.get("user").and_then(serde_json::Value::as_str),
            Some("default-user")
        );
    }
}