Skip to main content

lash_core/provider/
traits.rs

1use super::support::*;
2use crate::LlmTerminalReason;
3
4/// A configured LLM backend: its identity, host-config serialization, its
5/// generation options, and the request transport.
6#[async_trait]
7pub trait Provider: Send + Sync + std::fmt::Debug {
8    fn kind(&self) -> &'static str;
9
10    fn options(&self) -> ProviderOptions;
11    fn set_options(&mut self, options: ProviderOptions);
12
13    /// Emit the provider-specific JSON body used by [`ProviderSpec`]. The
14    /// object must NOT contain a `type` field — [`ProviderSpec::Serialize`]
15    /// layers that on top.
16    fn serialize_config(&self) -> serde_json::Value;
17
18    async fn complete(&mut self, request: LlmRequest) -> Result<LlmResponse, LlmTransportError>;
19
20    fn requires_streaming(&self) -> bool {
21        false
22    }
23
24    /// Release any host-visible transport resources this provider holds —
25    /// cached connections, pooled sockets, background tasks — sending whatever
26    /// graceful close a clean shutdown requires.
27    ///
28    /// Hosts call this before process exit so protocol niceties (e.g. WebSocket
29    /// Close frames) are sent rather than skipped by an abrupt drop. It takes
30    /// `&self` because a provider's reusable transport state lives behind its
31    /// own synchronization; a shared clone can therefore be closed from the
32    /// shutdown path. The default is a no-op: providers that hold no reusable
33    /// transport state have nothing to release.
34    async fn close(&self) -> Result<(), LlmTransportError> {
35        Ok(())
36    }
37
38    fn clone_boxed(&self) -> Box<dyn Provider>;
39}
40
/// Per-provider policy describing which variants a given model supports.
pub trait ProviderModelPolicy: Send + Sync + std::fmt::Debug {
    /// Returns the static list of variant names supported for `model`.
    ///
    /// NOTE(review): what a "variant" denotes is decided by implementors and
    /// is not visible from this trait alone.
    fn supported_variants(&self, model: &str) -> &'static [&'static str];
}
44
45pub trait ProviderFailureClassifier: Send + Sync + std::fmt::Debug {
46    fn classify(&self, failure: ProviderFailure) -> ProviderFailure;
47}
48
49#[derive(Clone, Debug, Default)]
50pub struct DefaultProviderFailureClassifier;
51
52impl ProviderFailureClassifier for DefaultProviderFailureClassifier {
53    fn classify(&self, mut failure: ProviderFailure) -> ProviderFailure {
54        if let Some(status) = failure.status.or_else(|| {
55            failure
56                .code
57                .as_deref()
58                .and_then(|code| code.parse::<u16>().ok())
59        }) {
60            failure.status = Some(status);
61            if failure.kind == ProviderFailureKind::Unknown {
62                failure.kind = ProviderFailureKind::Http;
63            }
64            failure.retryable = matches!(status, 408 | 409 | 425 | 429 | 500 | 502 | 503 | 504);
65            if status == 429 {
66                // Provider-side throttling. `Quota` + `retryable: true` is the
67                // combination `ProviderHandle`'s retry ladder defers to as a
68                // throttle; hard quota exhaustion (the text markers below)
69                // downgrades to `retryable: false`.
70                failure.kind = ProviderFailureKind::Quota;
71            } else if matches!(status, 401 | 403) {
72                failure.kind = ProviderFailureKind::Auth;
73            } else if matches!(status, 400 | 413 | 422) {
74                failure.kind = ProviderFailureKind::Validation;
75            }
76        } else if matches!(
77            failure.kind,
78            ProviderFailureKind::Transport | ProviderFailureKind::Timeout
79        ) {
80            failure.retryable = true;
81        }
82
83        let haystack = format!(
84            "{}\n{}\n{}",
85            failure.code.as_deref().unwrap_or_default(),
86            failure.message,
87            failure.raw.as_deref().unwrap_or_default()
88        )
89        .to_ascii_lowercase();
90        if is_context_overflow_text(&haystack) {
91            failure.kind = ProviderFailureKind::Validation;
92            failure.retryable = false;
93            failure.terminal_reason = LlmTerminalReason::ContextOverflow;
94        }
95        if haystack.contains("insufficient_quota")
96            || haystack.contains("usage_limit_reached")
97            || haystack.contains("usage_not_included")
98            || haystack.contains("quota")
99        {
100            failure.kind = ProviderFailureKind::Quota;
101            failure.retryable = false;
102        }
103        if haystack.contains("content_filter")
104            || haystack.contains("prohibited_content")
105            || haystack.contains("safety")
106            || haystack.contains("sensitive")
107        {
108            failure.terminal_reason = LlmTerminalReason::ContentFilter;
109        }
110        if haystack.contains("model_not_found")
111            || haystack.contains("unsupported model")
112            || haystack.contains("does not exist")
113        {
114            failure.kind = ProviderFailureKind::Unsupported;
115            failure.retryable = false;
116        }
117        failure
118    }
119}
120
/// Reports whether `haystack` reads like a "prompt exceeds the model's context
/// window" error.
///
/// Matching is an ASCII case-insensitive substring search. Throttling and
/// quota wording takes precedence: if any such marker is present the text is
/// never treated as a context overflow, whatever else it mentions.
pub fn is_context_overflow_text(haystack: &str) -> bool {
    // Wording that marks rate limiting or quota problems; it vetoes a match.
    const THROTTLE_MARKERS: &[&str] = &[
        "rate limit",
        "rate_limit",
        "ratelimit",
        "throttle",
        "throttling",
        "too many requests",
        "tokens per minute",
        "tpm",
        "quota",
    ];
    // Wording that on its own indicates a context-window overflow.
    const OVERFLOW_MARKERS: &[&str] = &[
        "context_length_exceeded",
        "context_length",
        "context length",
        "maximum context",
        "max context",
        "context window",
        "context window exceeds limit",
        "exceeds the context window",
        "prompt is too long",
        "prompt too long",
        "request_too_large",
        "maximum prompt length is",
        "reduce the length of the messages",
        "maximum context length is",
        "model's context length",
        "models context length",
        "exceeds the available context size",
        "greater than the context length",
        "exceeded model token limit",
        "too large for model with",
        "model_context_window_exceeded",
        "too many tokens",
        "exceeds the maximum number of tokens",
        "exceeds maximum number of tokens",
        "request too large",
        "input is too long",
        "token limit exceeded",
        "reduce the length of your prompt",
    ];

    /// True when `text` contains at least one of `markers`.
    fn contains_any(text: &str, markers: &[&str]) -> bool {
        markers.iter().any(|marker| text.contains(*marker))
    }

    let text = haystack.to_ascii_lowercase();
    if contains_any(&text, THROTTLE_MARKERS) {
        return false;
    }

    // This pair only counts when both halves appear in the same text.
    let input_count_over_max =
        text.contains("input token count") && text.contains("exceeds the maximum");

    input_count_over_max || contains_any(&text, OVERFLOW_MARKERS)
}