Skip to main content

lash_core/provider/
handle.rs

1use super::support::*;
2
3/// Component bundle returned by provider factories.
4#[derive(Debug)]
5pub struct ProviderComponents {
6    pub provider: Box<dyn Provider>,
7    pub model_policy: Arc<dyn ProviderModelPolicy>,
8    pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
9    pub rate_limiter: Arc<ProviderRateLimiter>,
10}
11
12impl ProviderComponents {
13    pub fn new(provider: Box<dyn Provider>, model_policy: Arc<dyn ProviderModelPolicy>) -> Self {
14        let options = provider.options();
15        Self {
16            provider,
17            model_policy,
18            failure_classifier: Arc::new(DefaultProviderFailureClassifier),
19            rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
20        }
21    }
22
23    /// Install a transport-level decorator that wraps the provider.
24    pub fn map_provider(
25        mut self,
26        map: impl FnOnce(Box<dyn Provider>) -> Box<dyn Provider>,
27    ) -> Self {
28        self.provider = map(self.provider);
29        self
30    }
31
32    pub fn with_failure_classifier(
33        mut self,
34        classifier: Arc<dyn ProviderFailureClassifier>,
35    ) -> Self {
36        self.failure_classifier = classifier;
37        self
38    }
39
40    pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
41        let options = self.provider.options();
42        self.rate_limiter = Arc::new(ProviderRateLimiter::with_clock(
43            options.reliability.rate_limits,
44            clock,
45        ));
46        self
47    }
48}
49
50impl Clone for ProviderComponents {
51    fn clone(&self) -> Self {
52        Self {
53            provider: self.provider.clone_boxed(),
54            model_policy: Arc::clone(&self.model_policy),
55            failure_classifier: Arc::clone(&self.failure_classifier),
56            rate_limiter: Arc::clone(&self.rate_limiter),
57        }
58    }
59}
60
61/// Owning handle to provider components. This is an executable transport
62/// handle supplied by the host, not a persistence format.
63pub struct ProviderHandle {
64    components: ProviderComponents,
65}
66
67impl ProviderHandle {
68    pub fn new(components: ProviderComponents) -> Self {
69        Self { components }
70    }
71
72    pub fn unconfigured() -> Self {
73        Self::new(UnconfiguredProvider::default().into_components())
74    }
75
76    pub fn components(&self) -> &ProviderComponents {
77        &self.components
78    }
79
80    pub fn components_mut(&mut self) -> &mut ProviderComponents {
81        &mut self.components
82    }
83
84    pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
85        self.components = self.components.with_clock(clock);
86        self
87    }
88
89    pub fn kind(&self) -> &'static str {
90        self.components.provider.kind()
91    }
92
93    pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
94        self.components.model_policy.supported_variants(model)
95    }
96
97    pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
98        let variants = self.supported_variants(model);
99        if variants.is_empty() {
100            return Err(format!(
101                "Model `{}` on {} does not expose configurable variants.",
102                model,
103                self.kind()
104            ));
105        }
106        if variants.contains(&variant) {
107            return Ok(());
108        }
109        Err(format!(
110            "Unsupported variant `{}` for `{}` on {}. Available: {}",
111            variant,
112            model,
113            self.kind(),
114            variants.join(", ")
115        ))
116    }
117
118    pub fn options(&self) -> ProviderOptions {
119        self.components.provider.options()
120    }
121
122    pub fn set_options(&mut self, options: ProviderOptions) {
123        self.components
124            .rate_limiter
125            .configure(options.reliability.rate_limits.clone());
126        self.components.provider.set_options(options)
127    }
128
129    pub fn requires_streaming(&self) -> bool {
130        self.components.provider.requires_streaming()
131    }
132
133    pub async fn complete(
134        &mut self,
135        request: LlmRequest,
136    ) -> Result<LlmResponse, LlmTransportError> {
137        let reliability = self.options().reliability;
138        let attempts = reliability.retry.attempts();
139        let mut attempt = 0;
140        // Cumulative time already spent deferring to provider throttles
141        // without consuming attempts, bounded by the policy's budget.
142        let throttle_budget = Duration::from_millis(reliability.retry.throttle_wait_budget_ms);
143        let mut throttle_waited = Duration::ZERO;
144        loop {
145            let _permit = self.components.rate_limiter.admit(&request).await;
146            let result = self.components.provider.complete(request.clone()).await;
147            match result {
148                Ok(response) => return Ok(response),
149                Err(failure) => {
150                    let failure = self.components.failure_classifier.classify(failure);
151                    // Throttle deference: when the provider signals a throttle
152                    // (retryable `Quota`) AND states how long to back off
153                    // (`Retry-After`), honor the wait without consuming a
154                    // retry attempt — the provider is asking us to come back,
155                    // not failing. The courtesy is bounded: each deferred wait
156                    // charges at least `MIN_THROTTLE_BUDGET_CHARGE` against
157                    // the cumulative `throttle_wait_budget_ms`, and once the
158                    // budget is spent a throttle counts as an ordinary
159                    // retryable failure. A throttle WITHOUT `Retry-After`
160                    // never defers: there is no server-stated wait to honor,
161                    // so the normal backoff-and-count ladder applies.
162                    if failure.retryable
163                        && failure.kind == ProviderFailureKind::Quota
164                        && let Some(retry_after) = failure.retry_after
165                    {
166                        let wait = reliability.retry.cap_retry_after(retry_after);
167                        let charge = wait.max(MIN_THROTTLE_BUDGET_CHARGE);
168                        // Saturating: an absurd uncapped `Retry-After` must
169                        // overflow the budget check, not panic the ladder.
170                        if throttle_waited.saturating_add(charge) <= throttle_budget {
171                            throttle_waited += charge;
172                            tracing::debug!(
173                                target: "lash_core::provider::reliability",
174                                provider = self.kind(),
175                                attempt = attempt + 1,
176                                max_attempts = attempts,
177                                wait_ms = wait.as_millis() as u64,
178                                throttle_waited_ms = throttle_waited.as_millis() as u64,
179                                err = %failure.message,
180                                "provider throttled with retry-after; waiting without consuming a retry attempt"
181                            );
182                            if let Some(events) = request.stream_events.as_ref() {
183                                events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
184                                    wait_seconds: wait.as_secs(),
185                                    attempt: (attempt + 1) as usize,
186                                    max_attempts: attempts as usize,
187                                    reason: failure.message.clone(),
188                                });
189                            }
190                            self.components.rate_limiter.clock().sleep(wait).await;
191                            continue;
192                        }
193                    }
194                    if attempt + 1 >= attempts || !failure.retryable {
195                        return Err(failure);
196                    }
197                    let delay = reliability
198                        .retry
199                        .delay_for_attempt(attempt, failure.retry_after);
200                    tracing::debug!(
201                        target: "lash_core::provider::reliability",
202                        provider = self.kind(),
203                        attempt = attempt + 1,
204                        max_attempts = attempts,
205                        delay_ms = delay.as_millis() as u64,
206                        err = %failure.message,
207                        "provider call failed with retryable failure; sleeping before retry"
208                    );
209                    if let Some(events) = request.stream_events.as_ref() {
210                        events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
211                            wait_seconds: delay.as_secs(),
212                            attempt: (attempt + 1) as usize,
213                            max_attempts: attempts as usize,
214                            reason: failure.message.clone(),
215                        });
216                    }
217                    self.components.rate_limiter.clock().sleep(delay).await;
218                    attempt += 1;
219                }
220            }
221        }
222    }
223
224    /// Release the underlying provider's host-visible transport resources.
225    ///
226    /// This forwards to [`Provider::close`]. Hosts that want a graceful
227    /// transport shutdown (for example, sending WebSocket Close frames on
228    /// cached Codex sessions) retain a clone of the handle they hand to the
229    /// core and call this before process exit. Providers with no reusable
230    /// transport state close as a no-op.
231    pub async fn close(&self) -> Result<(), LlmTransportError> {
232        self.components.provider.close().await
233    }
234
235    pub fn to_spec(&self) -> ProviderSpec {
236        ProviderSpec {
237            kind: self.kind().to_string(),
238            config: self.components.provider.serialize_config(),
239        }
240    }
241
242    /// Validate model syntax only.
243    pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
244        let m = model.trim();
245        if m.is_empty() {
246            return Err("model cannot be empty".to_string());
247        }
248        if m.contains(char::is_whitespace) {
249            return Err("model cannot contain whitespace".to_string());
250        }
251        Ok(())
252    }
253}
254
255impl std::fmt::Debug for ProviderHandle {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        self.components.fmt(f)
258    }
259}
260
261impl Clone for ProviderHandle {
262    fn clone(&self) -> Self {
263        Self {
264            components: self.components.clone(),
265        }
266    }
267}
268
269impl PartialEq for ProviderHandle {
270    fn eq(&self, other: &Self) -> bool {
271        self.kind() == other.kind() && self.to_spec().config == other.to_spec().config
272    }
273}
274
275impl Eq for ProviderHandle {}
276
277/// Placeholder provider used by runtime policy defaults before a host resolver
278/// installs the executable provider. Every transport-level method errors;
279/// calling code MUST replace this before executing a turn.
280#[derive(Clone, Debug, Default)]
281pub struct UnconfiguredProvider {
282    options: ProviderOptions,
283}
284
285impl UnconfiguredProvider {
286    fn into_components(self) -> ProviderComponents {
287        ProviderComponents::new(Box::new(self), Arc::new(StaticModelPolicy::new()))
288    }
289}
290
291#[async_trait]
292impl Provider for UnconfiguredProvider {
293    fn kind(&self) -> &'static str {
294        "unconfigured"
295    }
296
297    fn options(&self) -> ProviderOptions {
298        self.options.clone()
299    }
300
301    fn set_options(&mut self, options: ProviderOptions) {
302        self.options = options;
303    }
304
305    fn serialize_config(&self) -> serde_json::Value {
306        serde_json::Value::Object(Default::default())
307    }
308
309    async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
310        Err(LlmTransportError::new(
311            "no provider configured: host must set SessionPolicy.provider before running a turn",
312        ))
313    }
314
315    fn clone_boxed(&self) -> Box<dyn Provider> {
316        Box::new(self.clone())
317    }
318}