Skip to main content

lash_core/provider/
traits.rs

1use super::support::*;
2use crate::LlmTerminalReason;
3
4pub trait ProviderState: Send + Sync + std::fmt::Debug {
5    fn kind(&self) -> &'static str;
6
7    fn options(&self) -> ProviderOptions;
8    fn set_options(&mut self, options: ProviderOptions);
9
10    /// Emit the provider-specific JSON body used by [`ProviderSpec`]. The
11    /// object must NOT contain a `type` field — [`ProviderSpec::Serialize`]
12    /// layers that on top.
13    fn serialize_config(&self) -> serde_json::Value;
14
15    fn clone_boxed(&self) -> Box<dyn ProviderState>;
16}
17
18#[async_trait]
19pub trait ProviderTransport: Send + Sync + std::fmt::Debug {
20    async fn complete(&mut self, request: LlmRequest) -> Result<LlmResponse, LlmTransportError>;
21
22    fn requires_streaming(&self) -> bool {
23        false
24    }
25
26    fn clone_boxed(&self) -> Box<dyn ProviderTransport>;
27}
28
29pub trait ProviderModelPolicy: Send + Sync + std::fmt::Debug {
30    fn default_model(&self) -> &str;
31    fn supported_variants(&self, model: &str) -> &'static [&'static str];
32    fn default_model_variant(&self, model: &str) -> Option<&'static str>;
33    fn request_variant_config(&self, model: &str, variant: &str) -> Option<VariantRequestConfig>;
34    fn default_agent_model(&self, tier: &str) -> Option<AgentModelSelection>;
35
36    fn resolve_model(&self, model: &str) -> String {
37        model.to_string()
38    }
39
40    fn context_lookup_model(&self, model: &str) -> String {
41        model.to_string()
42    }
43
44    fn input_usage_excludes_cached_tokens(&self) -> bool {
45        false
46    }
47}
48
49pub trait ProviderFailureClassifier: Send + Sync + std::fmt::Debug {
50    fn classify(&self, failure: ProviderFailure) -> ProviderFailure;
51}
52
53#[derive(Clone, Debug, Default)]
54pub struct DefaultProviderFailureClassifier;
55
56impl ProviderFailureClassifier for DefaultProviderFailureClassifier {
57    fn classify(&self, mut failure: ProviderFailure) -> ProviderFailure {
58        if let Some(status) = failure.status.or_else(|| {
59            failure
60                .code
61                .as_deref()
62                .and_then(|code| code.parse::<u16>().ok())
63        }) {
64            failure.status = Some(status);
65            if failure.kind == ProviderFailureKind::Unknown {
66                failure.kind = ProviderFailureKind::Http;
67            }
68            failure.retryable = matches!(status, 408 | 409 | 425 | 429 | 500 | 502 | 503 | 504);
69            if matches!(status, 401 | 403) {
70                failure.kind = ProviderFailureKind::Auth;
71            } else if matches!(status, 400 | 413 | 422) {
72                failure.kind = ProviderFailureKind::Validation;
73            }
74        } else if matches!(
75            failure.kind,
76            ProviderFailureKind::Transport | ProviderFailureKind::Timeout
77        ) {
78            failure.retryable = true;
79        }
80
81        let haystack = format!(
82            "{}\n{}\n{}",
83            failure.code.as_deref().unwrap_or_default(),
84            failure.message,
85            failure.raw.as_deref().unwrap_or_default()
86        )
87        .to_ascii_lowercase();
88        if is_context_overflow_text(&haystack) {
89            failure.kind = ProviderFailureKind::Validation;
90            failure.retryable = false;
91            failure.terminal_reason = LlmTerminalReason::ContextOverflow;
92        }
93        if haystack.contains("insufficient_quota")
94            || haystack.contains("usage_limit_reached")
95            || haystack.contains("usage_not_included")
96            || haystack.contains("quota")
97        {
98            failure.kind = ProviderFailureKind::Quota;
99            failure.retryable = false;
100        }
101        if haystack.contains("content_filter")
102            || haystack.contains("prohibited_content")
103            || haystack.contains("safety")
104            || haystack.contains("sensitive")
105        {
106            failure.terminal_reason = LlmTerminalReason::ContentFilter;
107        }
108        if haystack.contains("model_not_found")
109            || haystack.contains("unsupported model")
110            || haystack.contains("does not exist")
111        {
112            failure.kind = ProviderFailureKind::Unsupported;
113            failure.retryable = false;
114        }
115        failure
116    }
117}
118
/// Heuristically decides whether an error text reports that the request
/// exceeded the model's context window / maximum prompt length.
///
/// Matching is ASCII case-insensitive substring search, so `haystack` may be
/// passed in any case. Text that mentions rate limiting, throttling or quota
/// is rejected first: such errors also talk about tokens and limits, but
/// shrinking the prompt would not fix them.
///
/// Returns `true` only when no rate-limit marker is present and at least one
/// overflow phrasing matches.
pub fn is_context_overflow_text(haystack: &str) -> bool {
    // Any of these vetoes an overflow classification.
    const RATE_LIMIT_MARKERS: &[&str] = &[
        "rate limit",
        "rate_limit",
        "ratelimit",
        "throttle",
        "throttling",
        "too many requests",
        "tokens per minute",
        "tpm",
        "quota",
    ];

    // Overflow phrasings seen across providers. Broad markers also cover more
    // specific messages:
    // - "context_length" matches `context_length_exceeded`;
    // - "context length" matches "model's context length" and
    //   "greater than the context length";
    // - "maximum context" matches "maximum context length is";
    // - "context window" matches "exceeds the context window".
    const CONTEXT_OVERFLOW_MARKERS: &[&str] = &[
        "context_length",
        "context length",
        "maximum context",
        "max context",
        "context window",
        "model_context_window_exceeded",
        "exceeds the available context size",
        "prompt is too long",
        "prompt too long",
        "maximum prompt length is",
        "input is too long",
        "request_too_large",
        "request too large",
        "too large for model with",
        "too many tokens",
        "token limit exceeded",
        "exceeded model token limit",
        "exceeds the maximum number of tokens",
        "exceeds maximum number of tokens",
        "reduce the length of the messages",
        "reduce the length of your prompt",
    ];

    let lower = haystack.to_ascii_lowercase();

    if RATE_LIMIT_MARKERS
        .iter()
        .any(|marker| lower.contains(*marker))
    {
        return false;
    }

    CONTEXT_OVERFLOW_MARKERS
        .iter()
        .any(|marker| lower.contains(*marker))
        // Two-part phrasing, e.g. "The input token count (N) exceeds the maximum ...".
        || (lower.contains("input token count") && lower.contains("exceeds the maximum"))
}
160        || lower.contains("input is too long")
161        || lower.contains("token limit exceeded")
162        || lower.contains("reduce the length of the messages")
163        || lower.contains("reduce the length of your prompt")
164}