lash_core/provider/
handle.rs1use super::support::*;
2
3#[derive(Debug)]
5pub struct ProviderComponents {
6 pub provider: Box<dyn Provider>,
7 pub model_policy: Arc<dyn ProviderModelPolicy>,
8 pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
9 pub rate_limiter: Arc<ProviderRateLimiter>,
10}
11
12impl ProviderComponents {
13 pub fn new(provider: Box<dyn Provider>, model_policy: Arc<dyn ProviderModelPolicy>) -> Self {
14 let options = provider.options();
15 Self {
16 provider,
17 model_policy,
18 failure_classifier: Arc::new(DefaultProviderFailureClassifier),
19 rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
20 }
21 }
22
23 pub fn map_provider(
25 mut self,
26 map: impl FnOnce(Box<dyn Provider>) -> Box<dyn Provider>,
27 ) -> Self {
28 self.provider = map(self.provider);
29 self
30 }
31
32 pub fn with_failure_classifier(
33 mut self,
34 classifier: Arc<dyn ProviderFailureClassifier>,
35 ) -> Self {
36 self.failure_classifier = classifier;
37 self
38 }
39}
40
41impl Clone for ProviderComponents {
42 fn clone(&self) -> Self {
43 Self {
44 provider: self.provider.clone_boxed(),
45 model_policy: Arc::clone(&self.model_policy),
46 failure_classifier: Arc::clone(&self.failure_classifier),
47 rate_limiter: Arc::clone(&self.rate_limiter),
48 }
49 }
50}
51
52pub struct ProviderHandle {
55 components: ProviderComponents,
56}
57
58impl ProviderHandle {
59 pub fn new(components: ProviderComponents) -> Self {
60 Self { components }
61 }
62
63 pub fn components(&self) -> &ProviderComponents {
64 &self.components
65 }
66
67 pub fn components_mut(&mut self) -> &mut ProviderComponents {
68 &mut self.components
69 }
70
71 pub fn kind(&self) -> &'static str {
72 self.components.provider.kind()
73 }
74
75 pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
76 self.components.model_policy.supported_variants(model)
77 }
78
79 pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
80 let variants = self.supported_variants(model);
81 if variants.is_empty() {
82 return Err(format!(
83 "Model `{}` on {} does not expose configurable variants.",
84 model,
85 self.kind()
86 ));
87 }
88 if variants.contains(&variant) {
89 return Ok(());
90 }
91 Err(format!(
92 "Unsupported variant `{}` for `{}` on {}. Available: {}",
93 variant,
94 model,
95 self.kind(),
96 variants.join(", ")
97 ))
98 }
99
100 pub fn input_usage_excludes_cached_tokens(&self) -> bool {
101 self.components
102 .model_policy
103 .input_usage_excludes_cached_tokens()
104 }
105
106 pub fn options(&self) -> ProviderOptions {
107 self.components.provider.options()
108 }
109
110 pub fn set_options(&mut self, options: ProviderOptions) {
111 self.components
112 .rate_limiter
113 .configure(options.reliability.rate_limits.clone());
114 self.components.provider.set_options(options)
115 }
116
117 pub fn requires_streaming(&self) -> bool {
118 self.components.provider.requires_streaming()
119 }
120
121 pub async fn complete(
122 &mut self,
123 request: LlmRequest,
124 ) -> Result<LlmResponse, LlmTransportError> {
125 let reliability = self.options().reliability;
126 let attempts = reliability.retry.attempts();
127 let mut attempt = 0;
128 loop {
129 let _permit = self.components.rate_limiter.admit(&request).await;
130 let result = self.components.provider.complete(request.clone()).await;
131 match result {
132 Ok(response) => return Ok(response),
133 Err(failure) => {
134 let failure = self.components.failure_classifier.classify(failure);
135 if attempt + 1 >= attempts || !failure.retryable {
136 return Err(failure);
137 }
138 let delay = reliability
139 .retry
140 .delay_for_attempt(attempt, failure.retry_after);
141 tracing::debug!(
142 target: "lash_core::provider::reliability",
143 provider = self.kind(),
144 attempt = attempt + 1,
145 max_attempts = attempts,
146 delay_ms = delay.as_millis() as u64,
147 err = %failure.message,
148 "provider call failed with retryable failure; sleeping before retry"
149 );
150 if let Some(events) = request.stream_events.as_ref() {
151 events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
152 wait_seconds: delay.as_secs(),
153 attempt: (attempt + 1) as usize,
154 max_attempts: attempts as usize,
155 reason: failure.message.clone(),
156 });
157 }
158 tokio::time::sleep(delay).await;
159 attempt += 1;
160 }
161 }
162 }
163 }
164
165 pub fn to_spec(&self) -> ProviderSpec {
166 ProviderSpec {
167 kind: self.kind().to_string(),
168 config: self.components.provider.serialize_config(),
169 }
170 }
171
172 pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
174 let m = model.trim();
175 if m.is_empty() {
176 return Err("model cannot be empty".to_string());
177 }
178 if m.contains(char::is_whitespace) {
179 return Err("model cannot contain whitespace".to_string());
180 }
181 Ok(())
182 }
183}
184
185impl std::fmt::Debug for ProviderHandle {
186 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187 self.components.fmt(f)
188 }
189}
190
191impl Clone for ProviderHandle {
192 fn clone(&self) -> Self {
193 Self {
194 components: self.components.clone(),
195 }
196 }
197}
198
199impl PartialEq for ProviderHandle {
200 fn eq(&self, other: &Self) -> bool {
201 self.kind() == other.kind() && self.to_spec().config == other.to_spec().config
202 }
203}
204
205impl Eq for ProviderHandle {}
206
207impl Default for ProviderHandle {
208 fn default() -> Self {
209 Self::new(UnconfiguredProvider::default().into_components())
210 }
211}
212
213#[derive(Clone, Debug, Default)]
219pub struct UnconfiguredProvider {
220 options: ProviderOptions,
221}
222
223impl UnconfiguredProvider {
224 fn into_components(self) -> ProviderComponents {
225 ProviderComponents::new(Box::new(self), Arc::new(StaticModelPolicy::new()))
226 }
227}
228
229#[async_trait]
230impl Provider for UnconfiguredProvider {
231 fn kind(&self) -> &'static str {
232 "unconfigured"
233 }
234
235 fn options(&self) -> ProviderOptions {
236 self.options.clone()
237 }
238
239 fn set_options(&mut self, options: ProviderOptions) {
240 self.options = options;
241 }
242
243 fn serialize_config(&self) -> serde_json::Value {
244 serde_json::Value::Object(Default::default())
245 }
246
247 async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
248 Err(LlmTransportError::new(
249 "no provider configured: host must set SessionPolicy.provider before running a turn",
250 ))
251 }
252
253 fn clone_boxed(&self) -> Box<dyn Provider> {
254 Box::new(self.clone())
255 }
256}