Skip to main content

lash_core/provider/
traits.rs

1use super::support::*;
2use crate::LlmTerminalReason;
3
4/// A configured LLM backend: its identity, host-config serialization, its
5/// generation options, and the request transport.
6#[async_trait]
7pub trait Provider: Send + Sync + std::fmt::Debug {
8    fn kind(&self) -> &'static str;
9
10    fn options(&self) -> ProviderOptions;
11    fn set_options(&mut self, options: ProviderOptions);
12
13    /// Emit the provider-specific JSON body used by [`ProviderSpec`]. The
14    /// object must NOT contain a `type` field — [`ProviderSpec::Serialize`]
15    /// layers that on top.
16    fn serialize_config(&self) -> serde_json::Value;
17
18    async fn complete(&mut self, request: LlmRequest) -> Result<LlmResponse, LlmTransportError>;
19
20    fn requires_streaming(&self) -> bool {
21        false
22    }
23
24    fn clone_boxed(&self) -> Box<dyn Provider>;
25}
26
/// Model-related policy hooks for a provider.
pub trait ProviderModelPolicy: std::fmt::Debug + Send + Sync {
    /// Returns the variant names this policy lists as supported for `model`.
    fn supported_variants(&self, model: &str) -> &'static [&'static str];

    /// Whether reported input-token usage already excludes cached tokens.
    ///
    /// The default implementation returns `false`.
    fn input_usage_excludes_cached_tokens(&self) -> bool {
        false
    }
}
34
35pub trait ProviderFailureClassifier: Send + Sync + std::fmt::Debug {
36    fn classify(&self, failure: ProviderFailure) -> ProviderFailure;
37}
38
/// Stateless classifier applying the built-in HTTP-status and error-text
/// heuristics.
#[derive(Debug, Default, Clone)]
pub struct DefaultProviderFailureClassifier;
41
42impl ProviderFailureClassifier for DefaultProviderFailureClassifier {
43    fn classify(&self, mut failure: ProviderFailure) -> ProviderFailure {
44        if let Some(status) = failure.status.or_else(|| {
45            failure
46                .code
47                .as_deref()
48                .and_then(|code| code.parse::<u16>().ok())
49        }) {
50            failure.status = Some(status);
51            if failure.kind == ProviderFailureKind::Unknown {
52                failure.kind = ProviderFailureKind::Http;
53            }
54            failure.retryable = matches!(status, 408 | 409 | 425 | 429 | 500 | 502 | 503 | 504);
55            if matches!(status, 401 | 403) {
56                failure.kind = ProviderFailureKind::Auth;
57            } else if matches!(status, 400 | 413 | 422) {
58                failure.kind = ProviderFailureKind::Validation;
59            }
60        } else if matches!(
61            failure.kind,
62            ProviderFailureKind::Transport | ProviderFailureKind::Timeout
63        ) {
64            failure.retryable = true;
65        }
66
67        let haystack = format!(
68            "{}\n{}\n{}",
69            failure.code.as_deref().unwrap_or_default(),
70            failure.message,
71            failure.raw.as_deref().unwrap_or_default()
72        )
73        .to_ascii_lowercase();
74        if is_context_overflow_text(&haystack) {
75            failure.kind = ProviderFailureKind::Validation;
76            failure.retryable = false;
77            failure.terminal_reason = LlmTerminalReason::ContextOverflow;
78        }
79        if haystack.contains("insufficient_quota")
80            || haystack.contains("usage_limit_reached")
81            || haystack.contains("usage_not_included")
82            || haystack.contains("quota")
83        {
84            failure.kind = ProviderFailureKind::Quota;
85            failure.retryable = false;
86        }
87        if haystack.contains("content_filter")
88            || haystack.contains("prohibited_content")
89            || haystack.contains("safety")
90            || haystack.contains("sensitive")
91        {
92            failure.terminal_reason = LlmTerminalReason::ContentFilter;
93        }
94        if haystack.contains("model_not_found")
95            || haystack.contains("unsupported model")
96            || haystack.contains("does not exist")
97        {
98            failure.kind = ProviderFailureKind::Unsupported;
99            failure.retryable = false;
100        }
101        failure
102    }
103}
104
/// Returns `true` when `haystack` looks like a context-window overflow error.
///
/// Matching is ASCII case-insensitive, so callers may pass text in any case.
/// Text that also looks like rate limiting or quota exhaustion is rejected
/// first, so throttling errors are never reported as overflows.
pub fn is_context_overflow_text(haystack: &str) -> bool {
    // Substrings identifying throttling or quota errors.
    const RATE_LIMIT_MARKERS: &[&str] = &[
        "rate limit",
        "rate_limit",
        "ratelimit",
        "throttle",
        "throttling",
        "too many requests",
        "tokens per minute",
        "quota",
    ];
    // Substrings identifying a context overflow. Some entries are implied by
    // shorter ones (e.g. "context_length_exceeded" by "context_length"). They
    // are kept as a record of the concrete provider messages being matched.
    const OVERFLOW_MARKERS: &[&str] = &[
        "context_length_exceeded",
        "context_length",
        "context length",
        "maximum context",
        "max context",
        "context window",
        "context window exceeds limit",
        "exceeds the context window",
        "prompt is too long",
        "prompt too long",
        "request_too_large",
        "maximum prompt length is",
        "reduce the length of the messages",
        "maximum context length is",
        "model's context length",
        "models context length",
        "exceeds the available context size",
        "greater than the context length",
        "exceeded model token limit",
        "too large for model with",
        "model_context_window_exceeded",
        "too many tokens",
        "exceeds the maximum number of tokens",
        "exceeds maximum number of tokens",
        "request too large",
        "input is too long",
        "token limit exceeded",
        "reduce the length of your prompt",
    ];

    let lower = haystack.to_ascii_lowercase();

    // "tpm" (tokens per minute) only counts as a standalone run of letters,
    // e.g. "tpm_limit", "(TPM)" or "30000TPM". As a raw substring it would also
    // hit identifiers such as "req_01HtTpMx" or "HttpMethod" in a raw body and
    // hide a genuine overflow.
    let is_throttling = RATE_LIMIT_MARKERS
        .iter()
        .copied()
        .any(|marker| lower.contains(marker))
        || lower
            .split(|c: char| !c.is_ascii_alphabetic())
            .any(|word| word == "tpm");
    if is_throttling {
        return false;
    }

    OVERFLOW_MARKERS
        .iter()
        .copied()
        .any(|marker| lower.contains(marker))
        || (lower.contains("input token count") && lower.contains("exceeds the maximum"))
}
146        || lower.contains("input is too long")
147        || lower.contains("token limit exceeded")
148        || lower.contains("reduce the length of the messages")
149        || lower.contains("reduce the length of your prompt")
150}