Skip to main content

lash_core/provider/
handle.rs

1use super::support::*;
2
3/// Component bundle returned by provider factories.
4#[derive(Debug)]
5pub struct ProviderComponents {
6    pub provider: Box<dyn Provider>,
7    pub model_policy: Arc<dyn ProviderModelPolicy>,
8    pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
9    pub rate_limiter: Arc<ProviderRateLimiter>,
10}
11
12impl ProviderComponents {
13    pub fn new(provider: Box<dyn Provider>, model_policy: Arc<dyn ProviderModelPolicy>) -> Self {
14        let options = provider.options();
15        Self {
16            provider,
17            model_policy,
18            failure_classifier: Arc::new(DefaultProviderFailureClassifier),
19            rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
20        }
21    }
22
23    /// Install a transport-level decorator that wraps the provider.
24    pub fn map_provider(
25        mut self,
26        map: impl FnOnce(Box<dyn Provider>) -> Box<dyn Provider>,
27    ) -> Self {
28        self.provider = map(self.provider);
29        self
30    }
31
32    pub fn with_failure_classifier(
33        mut self,
34        classifier: Arc<dyn ProviderFailureClassifier>,
35    ) -> Self {
36        self.failure_classifier = classifier;
37        self
38    }
39
40    pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
41        let options = self.provider.options();
42        self.rate_limiter = Arc::new(ProviderRateLimiter::with_clock(
43            options.reliability.rate_limits,
44            clock,
45        ));
46        self
47    }
48}
49
50impl Clone for ProviderComponents {
51    fn clone(&self) -> Self {
52        Self {
53            provider: self.provider.clone_boxed(),
54            model_policy: Arc::clone(&self.model_policy),
55            failure_classifier: Arc::clone(&self.failure_classifier),
56            rate_limiter: Arc::clone(&self.rate_limiter),
57        }
58    }
59}
60
61/// Owning handle to provider components. This is an executable transport
62/// handle supplied by the host, not a persistence format.
63pub struct ProviderHandle {
64    components: ProviderComponents,
65}
66
67impl ProviderHandle {
68    pub fn new(components: ProviderComponents) -> Self {
69        Self { components }
70    }
71
72    pub fn unconfigured() -> Self {
73        Self::new(UnconfiguredProvider::default().into_components())
74    }
75
76    pub fn components(&self) -> &ProviderComponents {
77        &self.components
78    }
79
80    pub fn components_mut(&mut self) -> &mut ProviderComponents {
81        &mut self.components
82    }
83
84    pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
85        self.components = self.components.with_clock(clock);
86        self
87    }
88
89    pub fn kind(&self) -> &'static str {
90        self.components.provider.kind()
91    }
92
93    pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
94        self.components.model_policy.supported_variants(model)
95    }
96
97    pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
98        let variants = self.supported_variants(model);
99        if variants.is_empty() {
100            return Err(format!(
101                "Model `{}` on {} does not expose configurable variants.",
102                model,
103                self.kind()
104            ));
105        }
106        if variants.contains(&variant) {
107            return Ok(());
108        }
109        Err(format!(
110            "Unsupported variant `{}` for `{}` on {}. Available: {}",
111            variant,
112            model,
113            self.kind(),
114            variants.join(", ")
115        ))
116    }
117
118    pub fn input_usage_excludes_cached_tokens(&self) -> bool {
119        self.components
120            .model_policy
121            .input_usage_excludes_cached_tokens()
122    }
123
124    pub fn options(&self) -> ProviderOptions {
125        self.components.provider.options()
126    }
127
128    pub fn set_options(&mut self, options: ProviderOptions) {
129        self.components
130            .rate_limiter
131            .configure(options.reliability.rate_limits.clone());
132        self.components.provider.set_options(options)
133    }
134
135    pub fn requires_streaming(&self) -> bool {
136        self.components.provider.requires_streaming()
137    }
138
139    pub async fn complete(
140        &mut self,
141        request: LlmRequest,
142    ) -> Result<LlmResponse, LlmTransportError> {
143        let reliability = self.options().reliability;
144        let attempts = reliability.retry.attempts();
145        let mut attempt = 0;
146        loop {
147            let _permit = self.components.rate_limiter.admit(&request).await;
148            let result = self.components.provider.complete(request.clone()).await;
149            match result {
150                Ok(response) => return Ok(response),
151                Err(failure) => {
152                    let failure = self.components.failure_classifier.classify(failure);
153                    if attempt + 1 >= attempts || !failure.retryable {
154                        return Err(failure);
155                    }
156                    let delay = reliability
157                        .retry
158                        .delay_for_attempt(attempt, failure.retry_after);
159                    tracing::debug!(
160                        target: "lash_core::provider::reliability",
161                        provider = self.kind(),
162                        attempt = attempt + 1,
163                        max_attempts = attempts,
164                        delay_ms = delay.as_millis() as u64,
165                        err = %failure.message,
166                        "provider call failed with retryable failure; sleeping before retry"
167                    );
168                    if let Some(events) = request.stream_events.as_ref() {
169                        events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
170                            wait_seconds: delay.as_secs(),
171                            attempt: (attempt + 1) as usize,
172                            max_attempts: attempts as usize,
173                            reason: failure.message.clone(),
174                        });
175                    }
176                    self.components.rate_limiter.clock().sleep(delay).await;
177                    attempt += 1;
178                }
179            }
180        }
181    }
182
183    pub fn to_spec(&self) -> ProviderSpec {
184        ProviderSpec {
185            kind: self.kind().to_string(),
186            config: self.components.provider.serialize_config(),
187        }
188    }
189
190    /// Validate model syntax only.
191    pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
192        let m = model.trim();
193        if m.is_empty() {
194            return Err("model cannot be empty".to_string());
195        }
196        if m.contains(char::is_whitespace) {
197            return Err("model cannot contain whitespace".to_string());
198        }
199        Ok(())
200    }
201}
202
203impl std::fmt::Debug for ProviderHandle {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        self.components.fmt(f)
206    }
207}
208
209impl Clone for ProviderHandle {
210    fn clone(&self) -> Self {
211        Self {
212            components: self.components.clone(),
213        }
214    }
215}
216
217impl PartialEq for ProviderHandle {
218    fn eq(&self, other: &Self) -> bool {
219        self.kind() == other.kind() && self.to_spec().config == other.to_spec().config
220    }
221}
222
223impl Eq for ProviderHandle {}
224
225/// Placeholder provider used by runtime policy defaults before a host resolver
226/// installs the executable provider. Every transport-level method errors;
227/// calling code MUST replace this before executing a turn.
228#[derive(Clone, Debug, Default)]
229pub struct UnconfiguredProvider {
230    options: ProviderOptions,
231}
232
233impl UnconfiguredProvider {
234    fn into_components(self) -> ProviderComponents {
235        ProviderComponents::new(Box::new(self), Arc::new(StaticModelPolicy::new()))
236    }
237}
238
239#[async_trait]
240impl Provider for UnconfiguredProvider {
241    fn kind(&self) -> &'static str {
242        "unconfigured"
243    }
244
245    fn options(&self) -> ProviderOptions {
246        self.options.clone()
247    }
248
249    fn set_options(&mut self, options: ProviderOptions) {
250        self.options = options;
251    }
252
253    fn serialize_config(&self) -> serde_json::Value {
254        serde_json::Value::Object(Default::default())
255    }
256
257    async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
258        Err(LlmTransportError::new(
259            "no provider configured: host must set SessionPolicy.provider before running a turn",
260        ))
261    }
262
263    fn clone_boxed(&self) -> Box<dyn Provider> {
264        Box::new(self.clone())
265    }
266}