Skip to main content

lash_core/provider/
handle.rs

1use super::support::*;
2
3/// Component bundle returned by provider factories.
4#[derive(Debug)]
5pub struct ProviderComponents {
6    pub provider: Box<dyn Provider>,
7    pub model_policy: Arc<dyn ProviderModelPolicy>,
8    pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
9    pub rate_limiter: Arc<ProviderRateLimiter>,
10}
11
12impl ProviderComponents {
13    pub fn new(provider: Box<dyn Provider>, model_policy: Arc<dyn ProviderModelPolicy>) -> Self {
14        let options = provider.options();
15        Self {
16            provider,
17            model_policy,
18            failure_classifier: Arc::new(DefaultProviderFailureClassifier),
19            rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
20        }
21    }
22
23    /// Install a transport-level decorator that wraps the provider.
24    pub fn map_provider(
25        mut self,
26        map: impl FnOnce(Box<dyn Provider>) -> Box<dyn Provider>,
27    ) -> Self {
28        self.provider = map(self.provider);
29        self
30    }
31
32    pub fn with_failure_classifier(
33        mut self,
34        classifier: Arc<dyn ProviderFailureClassifier>,
35    ) -> Self {
36        self.failure_classifier = classifier;
37        self
38    }
39
40    pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
41        let options = self.provider.options();
42        self.rate_limiter = Arc::new(ProviderRateLimiter::with_clock(
43            options.reliability.rate_limits,
44            clock,
45        ));
46        self
47    }
48}
49
50impl Clone for ProviderComponents {
51    fn clone(&self) -> Self {
52        Self {
53            provider: self.provider.clone_boxed(),
54            model_policy: Arc::clone(&self.model_policy),
55            failure_classifier: Arc::clone(&self.failure_classifier),
56            rate_limiter: Arc::clone(&self.rate_limiter),
57        }
58    }
59}
60
61/// Owning handle to provider components. This is an executable transport
62/// handle supplied by the host, not a persistence format.
63pub struct ProviderHandle {
64    components: ProviderComponents,
65}
66
67impl ProviderHandle {
68    pub fn new(components: ProviderComponents) -> Self {
69        Self { components }
70    }
71
72    pub fn components(&self) -> &ProviderComponents {
73        &self.components
74    }
75
76    pub fn components_mut(&mut self) -> &mut ProviderComponents {
77        &mut self.components
78    }
79
80    pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
81        self.components = self.components.with_clock(clock);
82        self
83    }
84
85    pub fn kind(&self) -> &'static str {
86        self.components.provider.kind()
87    }
88
89    pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
90        self.components.model_policy.supported_variants(model)
91    }
92
93    pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
94        let variants = self.supported_variants(model);
95        if variants.is_empty() {
96            return Err(format!(
97                "Model `{}` on {} does not expose configurable variants.",
98                model,
99                self.kind()
100            ));
101        }
102        if variants.contains(&variant) {
103            return Ok(());
104        }
105        Err(format!(
106            "Unsupported variant `{}` for `{}` on {}. Available: {}",
107            variant,
108            model,
109            self.kind(),
110            variants.join(", ")
111        ))
112    }
113
114    pub fn input_usage_excludes_cached_tokens(&self) -> bool {
115        self.components
116            .model_policy
117            .input_usage_excludes_cached_tokens()
118    }
119
120    pub fn options(&self) -> ProviderOptions {
121        self.components.provider.options()
122    }
123
124    pub fn set_options(&mut self, options: ProviderOptions) {
125        self.components
126            .rate_limiter
127            .configure(options.reliability.rate_limits.clone());
128        self.components.provider.set_options(options)
129    }
130
131    pub fn requires_streaming(&self) -> bool {
132        self.components.provider.requires_streaming()
133    }
134
135    pub async fn complete(
136        &mut self,
137        request: LlmRequest,
138    ) -> Result<LlmResponse, LlmTransportError> {
139        let reliability = self.options().reliability;
140        let attempts = reliability.retry.attempts();
141        let mut attempt = 0;
142        loop {
143            let _permit = self.components.rate_limiter.admit(&request).await;
144            let result = self.components.provider.complete(request.clone()).await;
145            match result {
146                Ok(response) => return Ok(response),
147                Err(failure) => {
148                    let failure = self.components.failure_classifier.classify(failure);
149                    if attempt + 1 >= attempts || !failure.retryable {
150                        return Err(failure);
151                    }
152                    let delay = reliability
153                        .retry
154                        .delay_for_attempt(attempt, failure.retry_after);
155                    tracing::debug!(
156                        target: "lash_core::provider::reliability",
157                        provider = self.kind(),
158                        attempt = attempt + 1,
159                        max_attempts = attempts,
160                        delay_ms = delay.as_millis() as u64,
161                        err = %failure.message,
162                        "provider call failed with retryable failure; sleeping before retry"
163                    );
164                    if let Some(events) = request.stream_events.as_ref() {
165                        events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
166                            wait_seconds: delay.as_secs(),
167                            attempt: (attempt + 1) as usize,
168                            max_attempts: attempts as usize,
169                            reason: failure.message.clone(),
170                        });
171                    }
172                    self.components.rate_limiter.clock().sleep(delay).await;
173                    attempt += 1;
174                }
175            }
176        }
177    }
178
179    pub fn to_spec(&self) -> ProviderSpec {
180        ProviderSpec {
181            kind: self.kind().to_string(),
182            config: self.components.provider.serialize_config(),
183        }
184    }
185
186    /// Validate model syntax only.
187    pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
188        let m = model.trim();
189        if m.is_empty() {
190            return Err("model cannot be empty".to_string());
191        }
192        if m.contains(char::is_whitespace) {
193            return Err("model cannot contain whitespace".to_string());
194        }
195        Ok(())
196    }
197}
198
199impl std::fmt::Debug for ProviderHandle {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        self.components.fmt(f)
202    }
203}
204
205impl Clone for ProviderHandle {
206    fn clone(&self) -> Self {
207        Self {
208            components: self.components.clone(),
209        }
210    }
211}
212
213impl PartialEq for ProviderHandle {
214    fn eq(&self, other: &Self) -> bool {
215        self.kind() == other.kind() && self.to_spec().config == other.to_spec().config
216    }
217}
218
219impl Eq for ProviderHandle {}
220
221impl Default for ProviderHandle {
222    fn default() -> Self {
223        Self::new(UnconfiguredProvider::default().into_components())
224    }
225}
226
227/// Placeholder provider used when `SessionPolicy::default()` is
228/// constructed without an explicit provider. Every transport-level
229/// method errors; calling code MUST replace this before executing a
230/// turn. It exists solely so `..Default::default()` shorthand keeps
231/// working in host code that always overrides the provider field.
232#[derive(Clone, Debug, Default)]
233pub struct UnconfiguredProvider {
234    options: ProviderOptions,
235}
236
237impl UnconfiguredProvider {
238    fn into_components(self) -> ProviderComponents {
239        ProviderComponents::new(Box::new(self), Arc::new(StaticModelPolicy::new()))
240    }
241}
242
243#[async_trait]
244impl Provider for UnconfiguredProvider {
245    fn kind(&self) -> &'static str {
246        "unconfigured"
247    }
248
249    fn options(&self) -> ProviderOptions {
250        self.options.clone()
251    }
252
253    fn set_options(&mut self, options: ProviderOptions) {
254        self.options = options;
255    }
256
257    fn serialize_config(&self) -> serde_json::Value {
258        serde_json::Value::Object(Default::default())
259    }
260
261    async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
262        Err(LlmTransportError::new(
263            "no provider configured: host must set SessionPolicy.provider before running a turn",
264        ))
265    }
266
267    fn clone_boxed(&self) -> Box<dyn Provider> {
268        Box::new(self.clone())
269    }
270}