Skip to main content

lash_core/provider/
handle.rs

1use super::support::*;
2
3/// Component bundle returned by provider factories.
4#[derive(Debug)]
5pub struct ProviderComponents {
6    pub provider: Box<dyn Provider>,
7    pub model_policy: Arc<dyn ProviderModelPolicy>,
8    pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
9    pub rate_limiter: Arc<ProviderRateLimiter>,
10}
11
12impl ProviderComponents {
13    pub fn new(provider: Box<dyn Provider>, model_policy: Arc<dyn ProviderModelPolicy>) -> Self {
14        let options = provider.options();
15        Self {
16            provider,
17            model_policy,
18            failure_classifier: Arc::new(DefaultProviderFailureClassifier),
19            rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
20        }
21    }
22
23    /// Install a transport-level decorator that wraps the provider.
24    pub fn map_provider(
25        mut self,
26        map: impl FnOnce(Box<dyn Provider>) -> Box<dyn Provider>,
27    ) -> Self {
28        self.provider = map(self.provider);
29        self
30    }
31
32    pub fn with_failure_classifier(
33        mut self,
34        classifier: Arc<dyn ProviderFailureClassifier>,
35    ) -> Self {
36        self.failure_classifier = classifier;
37        self
38    }
39}
40
41impl Clone for ProviderComponents {
42    fn clone(&self) -> Self {
43        Self {
44            provider: self.provider.clone_boxed(),
45            model_policy: Arc::clone(&self.model_policy),
46            failure_classifier: Arc::clone(&self.failure_classifier),
47            rate_limiter: Arc::clone(&self.rate_limiter),
48        }
49    }
50}
51
52/// Owning handle to provider components. This is an executable transport
53/// handle supplied by the host, not a persistence format.
54pub struct ProviderHandle {
55    components: ProviderComponents,
56}
57
58impl ProviderHandle {
59    pub fn new(components: ProviderComponents) -> Self {
60        Self { components }
61    }
62
63    pub fn components(&self) -> &ProviderComponents {
64        &self.components
65    }
66
67    pub fn components_mut(&mut self) -> &mut ProviderComponents {
68        &mut self.components
69    }
70
71    pub fn kind(&self) -> &'static str {
72        self.components.provider.kind()
73    }
74
75    pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
76        self.components.model_policy.supported_variants(model)
77    }
78
79    pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
80        let variants = self.supported_variants(model);
81        if variants.is_empty() {
82            return Err(format!(
83                "Model `{}` on {} does not expose configurable variants.",
84                model,
85                self.kind()
86            ));
87        }
88        if variants.contains(&variant) {
89            return Ok(());
90        }
91        Err(format!(
92            "Unsupported variant `{}` for `{}` on {}. Available: {}",
93            variant,
94            model,
95            self.kind(),
96            variants.join(", ")
97        ))
98    }
99
100    pub fn input_usage_excludes_cached_tokens(&self) -> bool {
101        self.components
102            .model_policy
103            .input_usage_excludes_cached_tokens()
104    }
105
106    pub fn options(&self) -> ProviderOptions {
107        self.components.provider.options()
108    }
109
110    pub fn set_options(&mut self, options: ProviderOptions) {
111        self.components
112            .rate_limiter
113            .configure(options.reliability.rate_limits.clone());
114        self.components.provider.set_options(options)
115    }
116
117    pub fn requires_streaming(&self) -> bool {
118        self.components.provider.requires_streaming()
119    }
120
121    pub async fn complete(
122        &mut self,
123        request: LlmRequest,
124    ) -> Result<LlmResponse, LlmTransportError> {
125        let reliability = self.options().reliability;
126        let attempts = reliability.retry.attempts();
127        let mut attempt = 0;
128        loop {
129            let _permit = self.components.rate_limiter.admit(&request).await;
130            let result = self.components.provider.complete(request.clone()).await;
131            match result {
132                Ok(response) => return Ok(response),
133                Err(failure) => {
134                    let failure = self.components.failure_classifier.classify(failure);
135                    if attempt + 1 >= attempts || !failure.retryable {
136                        return Err(failure);
137                    }
138                    let delay = reliability
139                        .retry
140                        .delay_for_attempt(attempt, failure.retry_after);
141                    tracing::debug!(
142                        target: "lash_core::provider::reliability",
143                        provider = self.kind(),
144                        attempt = attempt + 1,
145                        max_attempts = attempts,
146                        delay_ms = delay.as_millis() as u64,
147                        err = %failure.message,
148                        "provider call failed with retryable failure; sleeping before retry"
149                    );
150                    if let Some(events) = request.stream_events.as_ref() {
151                        events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
152                            wait_seconds: delay.as_secs(),
153                            attempt: (attempt + 1) as usize,
154                            max_attempts: attempts as usize,
155                            reason: failure.message.clone(),
156                        });
157                    }
158                    tokio::time::sleep(delay).await;
159                    attempt += 1;
160                }
161            }
162        }
163    }
164
165    pub fn to_spec(&self) -> ProviderSpec {
166        ProviderSpec {
167            kind: self.kind().to_string(),
168            config: self.components.provider.serialize_config(),
169        }
170    }
171
172    /// Validate model syntax only.
173    pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
174        let m = model.trim();
175        if m.is_empty() {
176            return Err("model cannot be empty".to_string());
177        }
178        if m.contains(char::is_whitespace) {
179            return Err("model cannot contain whitespace".to_string());
180        }
181        Ok(())
182    }
183}
184
185impl std::fmt::Debug for ProviderHandle {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        self.components.fmt(f)
188    }
189}
190
191impl Clone for ProviderHandle {
192    fn clone(&self) -> Self {
193        Self {
194            components: self.components.clone(),
195        }
196    }
197}
198
199impl PartialEq for ProviderHandle {
200    fn eq(&self, other: &Self) -> bool {
201        self.kind() == other.kind() && self.to_spec().config == other.to_spec().config
202    }
203}
204
205impl Eq for ProviderHandle {}
206
207impl Default for ProviderHandle {
208    fn default() -> Self {
209        Self::new(UnconfiguredProvider::default().into_components())
210    }
211}
212
213/// Placeholder provider used when `SessionPolicy::default()` is
214/// constructed without an explicit provider. Every transport-level
215/// method errors; calling code MUST replace this before executing a
216/// turn. It exists solely so `..Default::default()` shorthand keeps
217/// working in host code that always overrides the provider field.
218#[derive(Clone, Debug, Default)]
219pub struct UnconfiguredProvider {
220    options: ProviderOptions,
221}
222
223impl UnconfiguredProvider {
224    fn into_components(self) -> ProviderComponents {
225        ProviderComponents::new(Box::new(self), Arc::new(StaticModelPolicy::new()))
226    }
227}
228
229#[async_trait]
230impl Provider for UnconfiguredProvider {
231    fn kind(&self) -> &'static str {
232        "unconfigured"
233    }
234
235    fn options(&self) -> ProviderOptions {
236        self.options.clone()
237    }
238
239    fn set_options(&mut self, options: ProviderOptions) {
240        self.options = options;
241    }
242
243    fn serialize_config(&self) -> serde_json::Value {
244        serde_json::Value::Object(Default::default())
245    }
246
247    async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
248        Err(LlmTransportError::new(
249            "no provider configured: host must set SessionPolicy.provider before running a turn",
250        ))
251    }
252
253    fn clone_boxed(&self) -> Box<dyn Provider> {
254        Box::new(self.clone())
255    }
256}