Skip to main content

lash_core/provider/
handle.rs

1use super::support::*;
2
/// Wrapper that lets several owners reach one provider value through a
/// shared `Arc<Mutex<T>>`.
///
/// Cloning the wrapper clones the `Arc`, not the wrapped value, so every
/// clone observes the same underlying `T`.
#[derive(Debug)]
struct SharedProviderComponent<T> {
    /// The shared, lock-protected provider value.
    inner: Arc<Mutex<T>>,
}
7
8impl<T> Clone for SharedProviderComponent<T> {
9    fn clone(&self) -> Self {
10        Self {
11            inner: Arc::clone(&self.inner),
12        }
13    }
14}
15
16impl<T> SharedProviderComponent<T> {
17    fn new(inner: Arc<Mutex<T>>) -> Self {
18        Self { inner }
19    }
20}
21
22impl<T> ProviderState for SharedProviderComponent<T>
23where
24    T: ProviderState + Clone + Send + Sync + std::fmt::Debug + 'static,
25{
26    fn kind(&self) -> &'static str {
27        self.inner.lock().expect("provider state lock").kind()
28    }
29
30    fn options(&self) -> ProviderOptions {
31        self.inner.lock().expect("provider state lock").options()
32    }
33
34    fn set_options(&mut self, options: ProviderOptions) {
35        self.inner
36            .lock()
37            .expect("provider state lock")
38            .set_options(options);
39    }
40
41    fn serialize_config(&self) -> serde_json::Value {
42        self.inner
43            .lock()
44            .expect("provider state lock")
45            .serialize_config()
46    }
47
48    fn clone_boxed(&self) -> Box<dyn ProviderState> {
49        Box::new(self.clone())
50    }
51}
52
53#[async_trait]
54impl<T> ProviderTransport for SharedProviderComponent<T>
55where
56    T: ProviderTransport + Clone + Send + Sync + std::fmt::Debug + 'static,
57{
58    async fn complete(&mut self, request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
59        let mut provider = self.inner.lock().expect("provider transport lock").clone();
60        let result = provider.complete(request).await;
61        *self.inner.lock().expect("provider transport lock") = provider;
62        result
63    }
64
65    fn requires_streaming(&self) -> bool {
66        self.inner
67            .lock()
68            .expect("provider transport lock")
69            .requires_streaming()
70    }
71
72    fn clone_boxed(&self) -> Box<dyn ProviderTransport> {
73        Box::new(self.clone())
74    }
75}
76
77/// Component bundle returned by provider factories.
78#[derive(Debug)]
79pub struct ProviderComponents {
80    pub state: Box<dyn ProviderState>,
81    pub transport: Box<dyn ProviderTransport>,
82    pub model_policy: Arc<dyn ProviderModelPolicy>,
83    pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
84    pub rate_limiter: Arc<ProviderRateLimiter>,
85}
86
87impl ProviderComponents {
88    pub fn new(
89        state: Box<dyn ProviderState>,
90        transport: Box<dyn ProviderTransport>,
91        model_policy: Arc<dyn ProviderModelPolicy>,
92    ) -> Self {
93        let options = state.options();
94        Self {
95            state,
96            transport,
97            model_policy,
98            failure_classifier: Arc::new(DefaultProviderFailureClassifier),
99            rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
100        }
101    }
102
103    pub fn shared<T>(provider: T, model_policy: Arc<dyn ProviderModelPolicy>) -> Self
104    where
105        T: ProviderState + ProviderTransport + Clone + Send + Sync + std::fmt::Debug + 'static,
106    {
107        let inner = Arc::new(Mutex::new(provider));
108        let options = inner.lock().expect("provider state lock").options();
109        Self {
110            state: Box::new(SharedProviderComponent::new(Arc::clone(&inner))),
111            transport: Box::new(SharedProviderComponent::new(inner)),
112            model_policy,
113            failure_classifier: Arc::new(DefaultProviderFailureClassifier),
114            rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
115        }
116    }
117
118    pub fn map_transport(
119        mut self,
120        map: impl FnOnce(Box<dyn ProviderTransport>) -> Box<dyn ProviderTransport>,
121    ) -> Self {
122        self.transport = map(self.transport);
123        self
124    }
125
126    pub fn with_failure_classifier(
127        mut self,
128        classifier: Arc<dyn ProviderFailureClassifier>,
129    ) -> Self {
130        self.failure_classifier = classifier;
131        self
132    }
133}
134
135impl Clone for ProviderComponents {
136    fn clone(&self) -> Self {
137        Self {
138            state: self.state.clone_boxed(),
139            transport: self.transport.clone_boxed(),
140            model_policy: Arc::clone(&self.model_policy),
141            failure_classifier: Arc::clone(&self.failure_classifier),
142            rate_limiter: Arc::clone(&self.rate_limiter),
143        }
144    }
145}
146
147/// Owning handle to provider components. Session state + config store this
148/// so we can add Clone / Serialize / Deserialize impls without running
149/// into orphan-rule conflicts.
150pub struct ProviderHandle {
151    components: ProviderComponents,
152}
153
154impl ProviderHandle {
155    pub fn new(components: ProviderComponents) -> Self {
156        Self { components }
157    }
158
159    pub fn components(&self) -> &ProviderComponents {
160        &self.components
161    }
162
163    pub fn components_mut(&mut self) -> &mut ProviderComponents {
164        &mut self.components
165    }
166
167    pub fn kind(&self) -> &'static str {
168        self.components.state.kind()
169    }
170
171    pub fn default_model(&self) -> &str {
172        self.components.model_policy.default_model()
173    }
174
175    pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
176        self.components.model_policy.supported_variants(model)
177    }
178
179    pub fn default_model_variant(&self, model: &str) -> Option<&'static str> {
180        self.components.model_policy.default_model_variant(model)
181    }
182
183    pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
184        let variants = self.supported_variants(model);
185        if variants.is_empty() {
186            return Err(format!(
187                "Model `{}` on {} does not expose configurable variants.",
188                model,
189                self.kind()
190            ));
191        }
192        if variants.contains(&variant) {
193            return Ok(());
194        }
195        Err(format!(
196            "Unsupported variant `{}` for `{}` on {}. Available: {}",
197            variant,
198            model,
199            self.kind(),
200            variants.join(", ")
201        ))
202    }
203
204    pub fn request_variant_config(
205        &self,
206        model: &str,
207        variant: &str,
208    ) -> Option<VariantRequestConfig> {
209        self.components
210            .model_policy
211            .request_variant_config(model, variant)
212    }
213
214    pub fn default_agent_model(&self, tier: &str) -> Option<AgentModelSelection> {
215        self.components.model_policy.default_agent_model(tier)
216    }
217
218    pub fn resolve_model(&self, model: &str) -> String {
219        self.components.model_policy.resolve_model(model)
220    }
221
222    pub fn context_lookup_model(&self, model: &str) -> String {
223        self.components.model_policy.context_lookup_model(model)
224    }
225
226    pub fn input_usage_excludes_cached_tokens(&self) -> bool {
227        self.components
228            .model_policy
229            .input_usage_excludes_cached_tokens()
230    }
231
232    pub fn options(&self) -> ProviderOptions {
233        self.components.state.options()
234    }
235
236    pub fn set_options(&mut self, options: ProviderOptions) {
237        self.components
238            .rate_limiter
239            .configure(options.reliability.rate_limits.clone());
240        self.components.state.set_options(options)
241    }
242
243    pub fn requires_streaming(&self) -> bool {
244        self.components.transport.requires_streaming()
245    }
246
247    pub async fn complete(
248        &mut self,
249        request: LlmRequest,
250    ) -> Result<LlmResponse, LlmTransportError> {
251        let reliability = self.options().reliability;
252        let attempts = reliability.retry.attempts();
253        let mut attempt = 0;
254        loop {
255            let _permit = self.components.rate_limiter.admit(&request).await;
256            let result = self.components.transport.complete(request.clone()).await;
257            match result {
258                Ok(response) => return Ok(response),
259                Err(failure) => {
260                    let failure = self.components.failure_classifier.classify(failure);
261                    if attempt + 1 >= attempts || !failure.retryable {
262                        return Err(failure);
263                    }
264                    let delay = reliability
265                        .retry
266                        .delay_for_attempt(attempt, failure.retry_after);
267                    tracing::debug!(
268                        target: "lash_core::provider::reliability",
269                        provider = self.kind(),
270                        attempt = attempt + 1,
271                        max_attempts = attempts,
272                        delay_ms = delay.as_millis() as u64,
273                        err = %failure.message,
274                        "provider call failed with retryable failure; sleeping before retry"
275                    );
276                    if let Some(events) = request.stream_events.as_ref() {
277                        events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
278                            wait_seconds: delay.as_secs(),
279                            attempt: (attempt + 1) as usize,
280                            max_attempts: attempts as usize,
281                            reason: failure.message.clone(),
282                        });
283                    }
284                    tokio::time::sleep(delay).await;
285                    attempt += 1;
286                }
287            }
288        }
289    }
290
291    pub fn to_spec(&self) -> ProviderSpec {
292        ProviderSpec {
293            kind: self.kind().to_string(),
294            config: self.components.state.serialize_config(),
295        }
296    }
297
298    /// Validate model syntax only.
299    pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
300        let m = model.trim();
301        if m.is_empty() {
302            return Err("model cannot be empty".to_string());
303        }
304        if m.contains(char::is_whitespace) {
305            return Err("model cannot contain whitespace".to_string());
306        }
307        Ok(())
308    }
309
310    /// Resolve a model against an explicit catalog supplied by the host.
311    pub fn resolve_model_spec(
312        &self,
313        model: &str,
314        catalog: &ModelCatalog,
315    ) -> Result<ResolvedModelSpec, String> {
316        self.validate_model_name(model)?;
317        let configured_model = model.trim();
318        let catalog_model_id = self.context_lookup_model(configured_model);
319        let Some(info) = catalog.get(&catalog_model_id).cloned() else {
320            return Err(format!(
321                "model `{}` has no context-window entry in the supplied model catalog for {}. Provide an explicit model spec or choose a cataloged model.",
322                configured_model,
323                self.kind(),
324            ));
325        };
326        Ok(ResolvedModelSpec {
327            configured_model: configured_model.to_string(),
328            resolved_model: self.resolve_model(configured_model),
329            catalog_model_id,
330            info,
331        })
332    }
333}
334
335impl std::fmt::Debug for ProviderHandle {
336    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        self.components.fmt(f)
338    }
339}
340
341impl Clone for ProviderHandle {
342    fn clone(&self) -> Self {
343        Self {
344            components: self.components.clone(),
345        }
346    }
347}
348
349impl PartialEq for ProviderHandle {
350    fn eq(&self, other: &Self) -> bool {
351        self.kind() == other.kind() && self.to_spec().config == other.to_spec().config
352    }
353}
354
355impl Eq for ProviderHandle {}
356
357impl Serialize for ProviderHandle {
358    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
359        self.to_spec().serialize(serializer)
360    }
361}
362
363impl<'de> Deserialize<'de> for ProviderHandle {
364    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
365        let spec = ProviderSpec::deserialize(deserializer)?;
366        build_provider(&spec)
367            .map(ProviderHandle::new)
368            .map_err(serde::de::Error::custom)
369    }
370}
371
372impl Default for ProviderHandle {
373    fn default() -> Self {
374        Self::new(UnconfiguredProvider::default().into_components())
375    }
376}
377
378/// Placeholder provider used when `SessionPolicy::default()` is
379/// constructed without an explicit provider. Every transport-level
380/// method errors; calling code MUST replace this before executing a
381/// turn. It exists solely so `..Default::default()` shorthand keeps
382/// working in host code that always overrides the provider field.
383#[derive(Clone, Debug, Default)]
384pub struct UnconfiguredProvider {
385    options: ProviderOptions,
386}
387
388impl UnconfiguredProvider {
389    fn into_components(self) -> ProviderComponents {
390        ProviderComponents::shared(self, Arc::new(StaticModelPolicy::new("")))
391    }
392}
393
394impl ProviderState for UnconfiguredProvider {
395    fn kind(&self) -> &'static str {
396        "unconfigured"
397    }
398
399    fn options(&self) -> ProviderOptions {
400        self.options.clone()
401    }
402
403    fn set_options(&mut self, options: ProviderOptions) {
404        self.options = options;
405    }
406
407    fn serialize_config(&self) -> serde_json::Value {
408        serde_json::Value::Object(Default::default())
409    }
410
411    fn clone_boxed(&self) -> Box<dyn ProviderState> {
412        Box::new(self.clone())
413    }
414}
415
416#[async_trait]
417impl ProviderTransport for UnconfiguredProvider {
418    async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
419        Err(LlmTransportError::new(
420            "no provider configured: host must install a provider factory and set SessionPolicy.provider before running a turn",
421        ))
422    }
423
424    fn clone_boxed(&self) -> Box<dyn ProviderTransport> {
425        Box::new(self.clone())
426    }
427}