llmsdk_provider/middleware/
provider.rs1use std::sync::Arc;
14
15use crate::error::Result;
16use crate::image_model::ImageModel;
17use crate::language_model::LanguageModel;
18use crate::provider::{DynEmbeddingModel, DynImageModel, DynLanguageModel, Provider};
19
20use super::image_model::{ImageModelMiddleware, wrap_image_model};
21use super::language_model::{LanguageModelMiddleware, wrap_language_model};
22
23#[derive(Default, Clone)]
34pub struct ProviderMiddlewareSet {
35 pub language: Vec<Arc<dyn LanguageModelMiddleware>>,
38 pub image: Vec<Arc<dyn ImageModelMiddleware>>,
41}
42
43impl std::fmt::Debug for ProviderMiddlewareSet {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("ProviderMiddlewareSet")
46 .field("language", &self.language.len())
47 .field("image", &self.image.len())
48 .finish()
49 }
50}
51
52pub fn wrap_provider(inner: Arc<dyn Provider>, set: ProviderMiddlewareSet) -> Arc<dyn Provider> {
63 Arc::new(WrappedProvider { inner, set })
64}
65
66struct WrappedProvider {
67 inner: Arc<dyn Provider>,
68 set: ProviderMiddlewareSet,
69}
70
71impl std::fmt::Debug for WrappedProvider {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("WrappedProvider")
74 .field("inner", &self.inner)
75 .field("middleware", &self.set)
76 .finish()
77 }
78}
79
80impl Provider for WrappedProvider {
81 fn language_model(&self, model_id: &str) -> Result<DynLanguageModel> {
82 let dyn_model = self.inner.language_model(model_id)?;
83 if self.set.language.is_empty() {
84 return Ok(dyn_model);
85 }
86 let arc: Arc<dyn LanguageModel> = dyn_model.into_inner();
87 let wrapped = wrap_language_model(arc, self.set.language.iter().cloned());
88 Ok(DynLanguageModel::from_arc(wrapped))
89 }
90
91 fn embedding_model(&self, model_id: &str) -> Result<DynEmbeddingModel> {
92 self.inner.embedding_model(model_id)
97 }
98
99 fn image_model(&self, model_id: &str) -> Result<DynImageModel> {
100 let dyn_model = self.inner.image_model(model_id)?;
101 if self.set.image.is_empty() {
102 return Ok(dyn_model);
103 }
104 let arc: Arc<dyn ImageModel> = dyn_model.into_inner();
105 let wrapped = wrap_image_model(arc, self.set.image.iter().cloned());
106 Ok(DynImageModel::from_arc(wrapped))
107 }
108}
109
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use async_trait::async_trait;

    use super::*;
    use crate::embedding_model::{EmbedOptions, EmbedResult, EmbeddingModel};
    use crate::language_model::{
        CallOptions, FinishReason, FinishReasonKind, GenerateResult, StreamResult, Usage,
    };

    /// Language model that always succeeds with an empty response.
    #[derive(Debug, Default)]
    struct StubLang;

    #[async_trait]
    impl LanguageModel for StubLang {
        fn provider(&self) -> &'static str {
            "stub"
        }
        fn model_id(&self) -> &'static str {
            "stub-lm"
        }
        async fn do_generate(&self, _options: CallOptions) -> Result<GenerateResult> {
            Ok(GenerateResult {
                content: vec![],
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                usage: Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }
        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            Ok(StreamResult {
                stream: Box::pin(futures::stream::iter(vec![])),
                request: None,
                response: None,
            })
        }
    }

    /// Embedding model that always succeeds with no embeddings.
    #[derive(Debug, Default)]
    struct StubEmbed;

    #[async_trait]
    impl EmbeddingModel for StubEmbed {
        fn provider(&self) -> &'static str {
            "stub"
        }
        fn model_id(&self) -> &'static str {
            "stub-em"
        }
        async fn do_embed(&self, _opts: EmbedOptions) -> Result<EmbedResult> {
            Ok(EmbedResult {
                embeddings: vec![],
                usage: None,
                provider_metadata: None,
                request: None,
                response: None,
            })
        }
    }

    /// Provider exposing only the language and embedding surfaces; the image
    /// surface is left to the trait's default.
    #[derive(Debug, Default)]
    struct StubProvider;

    impl Provider for StubProvider {
        fn language_model(&self, _model_id: &str) -> Result<DynLanguageModel> {
            Ok(DynLanguageModel::new(StubLang))
        }
        fn embedding_model(&self, _model_id: &str) -> Result<DynEmbeddingModel> {
            Ok(DynEmbeddingModel::new(StubEmbed))
        }
    }

    /// Observations recorded by `CountingLang`.
    #[derive(Debug, Default)]
    struct Counter {
        lang_calls: AtomicUsize,
        last_temp: Mutex<Option<f32>>,
    }

    /// Language middleware that counts invocations and forces temperature 0.5.
    #[derive(Debug)]
    struct CountingLang(Arc<Counter>);

    #[async_trait]
    impl LanguageModelMiddleware for CountingLang {
        async fn transform_params(
            &self,
            _kind: super::super::language_model::CallKind,
            mut params: CallOptions,
            _inner: &dyn LanguageModel,
        ) -> Result<CallOptions> {
            self.0.lang_calls.fetch_add(1, Ordering::SeqCst);
            params.temperature = Some(0.5);
            *self.0.last_temp.lock().expect("mutex") = params.temperature;
            Ok(params)
        }
    }

    /// Wraps `StubProvider` with a single `CountingLang` reporting into `counter`.
    fn counting_provider(counter: &Arc<Counter>) -> Arc<dyn Provider> {
        let set = ProviderMiddlewareSet {
            language: vec![Arc::new(CountingLang(Arc::clone(counter)))],
            image: vec![],
        };
        wrap_provider(Arc::new(StubProvider), set)
    }

    #[tokio::test]
    async fn wraps_language_surface_only_embedding_passes_through() {
        let counter = Arc::new(Counter::default());
        let wrapped = counting_provider(&counter);

        let lm = wrapped.language_model("anything").expect("language");
        lm.do_generate(CallOptions::default())
            .await
            .expect("generate");
        assert_eq!(counter.lang_calls.load(Ordering::SeqCst), 1);
        assert_eq!(*counter.last_temp.lock().expect("mutex"), Some(0.5));

        // Embedding has no middleware slot: the call must still succeed, and
        // the language middleware must not run again because of it.
        let em = wrapped.embedding_model("anything").expect("embed");
        em.do_embed(EmbedOptions::default()).await.expect("embed");
        assert_eq!(counter.lang_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn language_middleware_runs_on_every_call() {
        let counter = Arc::new(Counter::default());
        let wrapped = counting_provider(&counter);

        let lm = wrapped.language_model("anything").expect("language");
        for _ in 0..2 {
            lm.do_generate(CallOptions::default())
                .await
                .expect("generate");
        }
        assert_eq!(counter.lang_calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn empty_set_passes_language_model_through() {
        let wrapped = wrap_provider(Arc::new(StubProvider), ProviderMiddlewareSet::default());

        let lm = wrapped.language_model("anything").expect("language");
        let result = lm
            .do_generate(CallOptions::default())
            .await
            .expect("generate");
        assert!(result.content.is_empty());
    }

    #[tokio::test]
    async fn unsupported_surface_propagates_inner_error() {
        let set = ProviderMiddlewareSet::default();
        let wrapped = wrap_provider(Arc::new(StubProvider), set);
        let err = wrapped.image_model("x").expect_err("inner unsupported");
        assert!(err.is_unsupported());
    }
}