1use super::support::*;
2
/// Adapter that lets one provider value back both the `ProviderState` and the
/// `ProviderTransport` half of a `ProviderComponents`.
///
/// The provider lives behind an `Arc<Mutex<_>>`; cloning this wrapper clones the
/// `Arc`, so every clone reads and writes the same provider instance.
#[derive(Debug)]
struct SharedProviderComponent<T> {
    /// The provider instance shared by every clone of this wrapper.
    inner: Arc<Mutex<T>>,
}
7
8impl<T> Clone for SharedProviderComponent<T> {
9 fn clone(&self) -> Self {
10 Self {
11 inner: Arc::clone(&self.inner),
12 }
13 }
14}
15
16impl<T> SharedProviderComponent<T> {
17 fn new(inner: Arc<Mutex<T>>) -> Self {
18 Self { inner }
19 }
20}
21
22impl<T> ProviderState for SharedProviderComponent<T>
23where
24 T: ProviderState + Clone + Send + Sync + std::fmt::Debug + 'static,
25{
26 fn kind(&self) -> &'static str {
27 self.inner.lock().expect("provider state lock").kind()
28 }
29
30 fn options(&self) -> ProviderOptions {
31 self.inner.lock().expect("provider state lock").options()
32 }
33
34 fn set_options(&mut self, options: ProviderOptions) {
35 self.inner
36 .lock()
37 .expect("provider state lock")
38 .set_options(options);
39 }
40
41 fn serialize_config(&self) -> serde_json::Value {
42 self.inner
43 .lock()
44 .expect("provider state lock")
45 .serialize_config()
46 }
47
48 fn clone_boxed(&self) -> Box<dyn ProviderState> {
49 Box::new(self.clone())
50 }
51}
52
53#[async_trait]
54impl<T> ProviderTransport for SharedProviderComponent<T>
55where
56 T: ProviderTransport + Clone + Send + Sync + std::fmt::Debug + 'static,
57{
58 async fn complete(&mut self, request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
59 let mut provider = self.inner.lock().expect("provider transport lock").clone();
60 let result = provider.complete(request).await;
61 *self.inner.lock().expect("provider transport lock") = provider;
62 result
63 }
64
65 fn requires_streaming(&self) -> bool {
66 self.inner
67 .lock()
68 .expect("provider transport lock")
69 .requires_streaming()
70 }
71
72 fn clone_boxed(&self) -> Box<dyn ProviderTransport> {
73 Box::new(self.clone())
74 }
75}
76
77#[derive(Debug)]
79pub struct ProviderComponents {
80 pub state: Box<dyn ProviderState>,
81 pub transport: Box<dyn ProviderTransport>,
82 pub model_policy: Arc<dyn ProviderModelPolicy>,
83 pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
84 pub rate_limiter: Arc<ProviderRateLimiter>,
85}
86
87impl ProviderComponents {
88 pub fn new(
89 state: Box<dyn ProviderState>,
90 transport: Box<dyn ProviderTransport>,
91 model_policy: Arc<dyn ProviderModelPolicy>,
92 ) -> Self {
93 let options = state.options();
94 Self {
95 state,
96 transport,
97 model_policy,
98 failure_classifier: Arc::new(DefaultProviderFailureClassifier),
99 rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
100 }
101 }
102
103 pub fn shared<T>(provider: T, model_policy: Arc<dyn ProviderModelPolicy>) -> Self
104 where
105 T: ProviderState + ProviderTransport + Clone + Send + Sync + std::fmt::Debug + 'static,
106 {
107 let inner = Arc::new(Mutex::new(provider));
108 let options = inner.lock().expect("provider state lock").options();
109 Self {
110 state: Box::new(SharedProviderComponent::new(Arc::clone(&inner))),
111 transport: Box::new(SharedProviderComponent::new(inner)),
112 model_policy,
113 failure_classifier: Arc::new(DefaultProviderFailureClassifier),
114 rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
115 }
116 }
117
118 pub fn map_transport(
119 mut self,
120 map: impl FnOnce(Box<dyn ProviderTransport>) -> Box<dyn ProviderTransport>,
121 ) -> Self {
122 self.transport = map(self.transport);
123 self
124 }
125
126 pub fn with_failure_classifier(
127 mut self,
128 classifier: Arc<dyn ProviderFailureClassifier>,
129 ) -> Self {
130 self.failure_classifier = classifier;
131 self
132 }
133}
134
135impl Clone for ProviderComponents {
136 fn clone(&self) -> Self {
137 Self {
138 state: self.state.clone_boxed(),
139 transport: self.transport.clone_boxed(),
140 model_policy: Arc::clone(&self.model_policy),
141 failure_classifier: Arc::clone(&self.failure_classifier),
142 rate_limiter: Arc::clone(&self.rate_limiter),
143 }
144 }
145}
146
147pub struct ProviderHandle {
151 components: ProviderComponents,
152}
153
154impl ProviderHandle {
155 pub fn new(components: ProviderComponents) -> Self {
156 Self { components }
157 }
158
159 pub fn components(&self) -> &ProviderComponents {
160 &self.components
161 }
162
163 pub fn components_mut(&mut self) -> &mut ProviderComponents {
164 &mut self.components
165 }
166
167 pub fn kind(&self) -> &'static str {
168 self.components.state.kind()
169 }
170
171 pub fn default_model(&self) -> &str {
172 self.components.model_policy.default_model()
173 }
174
175 pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
176 self.components.model_policy.supported_variants(model)
177 }
178
179 pub fn default_model_variant(&self, model: &str) -> Option<&'static str> {
180 self.components.model_policy.default_model_variant(model)
181 }
182
183 pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
184 let variants = self.supported_variants(model);
185 if variants.is_empty() {
186 return Err(format!(
187 "Model `{}` on {} does not expose configurable variants.",
188 model,
189 self.kind()
190 ));
191 }
192 if variants.contains(&variant) {
193 return Ok(());
194 }
195 Err(format!(
196 "Unsupported variant `{}` for `{}` on {}. Available: {}",
197 variant,
198 model,
199 self.kind(),
200 variants.join(", ")
201 ))
202 }
203
204 pub fn request_variant_config(
205 &self,
206 model: &str,
207 variant: &str,
208 ) -> Option<VariantRequestConfig> {
209 self.components
210 .model_policy
211 .request_variant_config(model, variant)
212 }
213
214 pub fn default_agent_model(&self, tier: &str) -> Option<AgentModelSelection> {
215 self.components.model_policy.default_agent_model(tier)
216 }
217
218 pub fn resolve_model(&self, model: &str) -> String {
219 self.components.model_policy.resolve_model(model)
220 }
221
222 pub fn context_lookup_model(&self, model: &str) -> String {
223 self.components.model_policy.context_lookup_model(model)
224 }
225
226 pub fn input_usage_excludes_cached_tokens(&self) -> bool {
227 self.components
228 .model_policy
229 .input_usage_excludes_cached_tokens()
230 }
231
232 pub fn options(&self) -> ProviderOptions {
233 self.components.state.options()
234 }
235
236 pub fn set_options(&mut self, options: ProviderOptions) {
237 self.components
238 .rate_limiter
239 .configure(options.reliability.rate_limits.clone());
240 self.components.state.set_options(options)
241 }
242
243 pub fn requires_streaming(&self) -> bool {
244 self.components.transport.requires_streaming()
245 }
246
247 pub async fn complete(
248 &mut self,
249 request: LlmRequest,
250 ) -> Result<LlmResponse, LlmTransportError> {
251 let reliability = self.options().reliability;
252 let attempts = reliability.retry.attempts();
253 let mut attempt = 0;
254 loop {
255 let _permit = self.components.rate_limiter.admit(&request).await;
256 let result = self.components.transport.complete(request.clone()).await;
257 match result {
258 Ok(response) => return Ok(response),
259 Err(failure) => {
260 let failure = self.components.failure_classifier.classify(failure);
261 if attempt + 1 >= attempts || !failure.retryable {
262 return Err(failure);
263 }
264 let delay = reliability
265 .retry
266 .delay_for_attempt(attempt, failure.retry_after);
267 tracing::debug!(
268 target: "lash_core::provider::reliability",
269 provider = self.kind(),
270 attempt = attempt + 1,
271 max_attempts = attempts,
272 delay_ms = delay.as_millis() as u64,
273 err = %failure.message,
274 "provider call failed with retryable failure; sleeping before retry"
275 );
276 if let Some(events) = request.stream_events.as_ref() {
277 events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
278 wait_seconds: delay.as_secs(),
279 attempt: (attempt + 1) as usize,
280 max_attempts: attempts as usize,
281 reason: failure.message.clone(),
282 });
283 }
284 tokio::time::sleep(delay).await;
285 attempt += 1;
286 }
287 }
288 }
289 }
290
291 pub fn to_spec(&self) -> ProviderSpec {
292 ProviderSpec {
293 kind: self.kind().to_string(),
294 config: self.components.state.serialize_config(),
295 }
296 }
297
298 pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
300 let m = model.trim();
301 if m.is_empty() {
302 return Err("model cannot be empty".to_string());
303 }
304 if m.contains(char::is_whitespace) {
305 return Err("model cannot contain whitespace".to_string());
306 }
307 Ok(())
308 }
309
310 pub fn resolve_model_spec(
312 &self,
313 model: &str,
314 catalog: &ModelCatalog,
315 ) -> Result<ResolvedModelSpec, String> {
316 self.validate_model_name(model)?;
317 let configured_model = model.trim();
318 let catalog_model_id = self.context_lookup_model(configured_model);
319 let Some(info) = catalog.get(&catalog_model_id).cloned() else {
320 return Err(format!(
321 "model `{}` has no context-window entry in the supplied model catalog for {}. Provide an explicit model spec or choose a cataloged model.",
322 configured_model,
323 self.kind(),
324 ));
325 };
326 Ok(ResolvedModelSpec {
327 configured_model: configured_model.to_string(),
328 resolved_model: self.resolve_model(configured_model),
329 catalog_model_id,
330 info,
331 })
332 }
333}
334
335impl std::fmt::Debug for ProviderHandle {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 self.components.fmt(f)
338 }
339}
340
341impl Clone for ProviderHandle {
342 fn clone(&self) -> Self {
343 Self {
344 components: self.components.clone(),
345 }
346 }
347}
348
349impl PartialEq for ProviderHandle {
350 fn eq(&self, other: &Self) -> bool {
351 self.kind() == other.kind() && self.to_spec().config == other.to_spec().config
352 }
353}
354
355impl Eq for ProviderHandle {}
356
357impl Serialize for ProviderHandle {
358 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
359 self.to_spec().serialize(serializer)
360 }
361}
362
363impl<'de> Deserialize<'de> for ProviderHandle {
364 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
365 let spec = ProviderSpec::deserialize(deserializer)?;
366 build_provider(&spec)
367 .map(ProviderHandle::new)
368 .map_err(serde::de::Error::custom)
369 }
370}
371
372impl Default for ProviderHandle {
373 fn default() -> Self {
374 Self::new(UnconfiguredProvider::default().into_components())
375 }
376}
377
378#[derive(Clone, Debug, Default)]
384pub struct UnconfiguredProvider {
385 options: ProviderOptions,
386}
387
388impl UnconfiguredProvider {
389 fn into_components(self) -> ProviderComponents {
390 ProviderComponents::shared(self, Arc::new(StaticModelPolicy::new("")))
391 }
392}
393
394impl ProviderState for UnconfiguredProvider {
395 fn kind(&self) -> &'static str {
396 "unconfigured"
397 }
398
399 fn options(&self) -> ProviderOptions {
400 self.options.clone()
401 }
402
403 fn set_options(&mut self, options: ProviderOptions) {
404 self.options = options;
405 }
406
407 fn serialize_config(&self) -> serde_json::Value {
408 serde_json::Value::Object(Default::default())
409 }
410
411 fn clone_boxed(&self) -> Box<dyn ProviderState> {
412 Box::new(self.clone())
413 }
414}
415
416#[async_trait]
417impl ProviderTransport for UnconfiguredProvider {
418 async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
419 Err(LlmTransportError::new(
420 "no provider configured: host must install a provider factory and set SessionPolicy.provider before running a turn",
421 ))
422 }
423
424 fn clone_boxed(&self) -> Box<dyn ProviderTransport> {
425 Box::new(self.clone())
426 }
427}