lash_core/provider/
traits.rs1use super::support::*;
2use crate::LlmTerminalReason;
3
4#[async_trait]
7pub trait Provider: Send + Sync + std::fmt::Debug {
8 fn kind(&self) -> &'static str;
9
10 fn options(&self) -> ProviderOptions;
11 fn set_options(&mut self, options: ProviderOptions);
12
13 fn serialize_config(&self) -> serde_json::Value;
17
18 async fn complete(&mut self, request: LlmRequest) -> Result<LlmResponse, LlmTransportError>;
19
20 fn requires_streaming(&self) -> bool {
21 false
22 }
23
24 fn clone_boxed(&self) -> Box<dyn Provider>;
25}
26
/// Per-model capabilities and accounting rules for a provider.
pub trait ProviderModelPolicy: Send + Sync + std::fmt::Debug {
    /// Returns the variant names this policy allows for `model`.
    fn supported_variants(&self, model: &str) -> &'static [&'static str];

    /// Whether the provider's reported input-token usage already excludes cached tokens.
    ///
    /// The default implementation returns `false`.
    fn input_usage_excludes_cached_tokens(&self) -> bool {
        false
    }
}
34
35pub trait ProviderFailureClassifier: Send + Sync + std::fmt::Debug {
36 fn classify(&self, failure: ProviderFailure) -> ProviderFailure;
37}
38
/// Built-in [`ProviderFailureClassifier`] that uses HTTP statuses and well-known
/// error-text markers.
///
/// Stateless unit struct.
#[derive(Debug, Default, Clone)]
pub struct DefaultProviderFailureClassifier;
41
42impl ProviderFailureClassifier for DefaultProviderFailureClassifier {
43 fn classify(&self, mut failure: ProviderFailure) -> ProviderFailure {
44 if let Some(status) = failure.status.or_else(|| {
45 failure
46 .code
47 .as_deref()
48 .and_then(|code| code.parse::<u16>().ok())
49 }) {
50 failure.status = Some(status);
51 if failure.kind == ProviderFailureKind::Unknown {
52 failure.kind = ProviderFailureKind::Http;
53 }
54 failure.retryable = matches!(status, 408 | 409 | 425 | 429 | 500 | 502 | 503 | 504);
55 if matches!(status, 401 | 403) {
56 failure.kind = ProviderFailureKind::Auth;
57 } else if matches!(status, 400 | 413 | 422) {
58 failure.kind = ProviderFailureKind::Validation;
59 }
60 } else if matches!(
61 failure.kind,
62 ProviderFailureKind::Transport | ProviderFailureKind::Timeout
63 ) {
64 failure.retryable = true;
65 }
66
67 let haystack = format!(
68 "{}\n{}\n{}",
69 failure.code.as_deref().unwrap_or_default(),
70 failure.message,
71 failure.raw.as_deref().unwrap_or_default()
72 )
73 .to_ascii_lowercase();
74 if is_context_overflow_text(&haystack) {
75 failure.kind = ProviderFailureKind::Validation;
76 failure.retryable = false;
77 failure.terminal_reason = LlmTerminalReason::ContextOverflow;
78 }
79 if haystack.contains("insufficient_quota")
80 || haystack.contains("usage_limit_reached")
81 || haystack.contains("usage_not_included")
82 || haystack.contains("quota")
83 {
84 failure.kind = ProviderFailureKind::Quota;
85 failure.retryable = false;
86 }
87 if haystack.contains("content_filter")
88 || haystack.contains("prohibited_content")
89 || haystack.contains("safety")
90 || haystack.contains("sensitive")
91 {
92 failure.terminal_reason = LlmTerminalReason::ContentFilter;
93 }
94 if haystack.contains("model_not_found")
95 || haystack.contains("unsupported model")
96 || haystack.contains("does not exist")
97 {
98 failure.kind = ProviderFailureKind::Unsupported;
99 failure.retryable = false;
100 }
101 failure
102 }
103}
104
/// Heuristically decides whether provider error text describes a context-window overflow.
///
/// Matching is case-insensitive for ASCII because the input is lowercased first.
///
/// Text containing any rate-limit or quota marker is rejected before the overflow markers
/// are checked. Those messages can mention tokens or request size while describing a
/// throughput or billing limit rather than an over-long prompt.
pub fn is_context_overflow_text(haystack: &str) -> bool {
    const RATE_LIMIT_MARKERS: &[&str] = &[
        "rate limit",
        "rate_limit",
        "ratelimit",
        "throttle",
        "throttling",
        "too many requests",
        "tokens per minute",
        "tpm",
        "quota",
    ];

    const OVERFLOW_MARKERS: &[&str] = &[
        "context_length_exceeded",
        "context_length",
        "context length",
        "maximum context",
        "max context",
        "context window",
        "context window exceeds limit",
        "exceeds the context window",
        "prompt is too long",
        "prompt too long",
        "request_too_large",
        "maximum prompt length is",
        "reduce the length of the messages",
        "maximum context length is",
        "model's context length",
        "models context length",
        "exceeds the available context size",
        "greater than the context length",
        "exceeded model token limit",
        "too large for model with",
        "model_context_window_exceeded",
        "too many tokens",
        "exceeds the maximum number of tokens",
        "exceeds maximum number of tokens",
        "request too large",
        "input is too long",
        "token limit exceeded",
        "reduce the length of your prompt",
    ];

    let text = haystack.to_ascii_lowercase();

    if RATE_LIMIT_MARKERS.iter().any(|marker| text.contains(*marker)) {
        return false;
    }

    // The "input token count" marker alone is too generic; require the "exceeds" phrase too.
    let input_count_exceeded =
        text.contains("input token count") && text.contains("exceeds the maximum");

    input_count_exceeded || OVERFLOW_MARKERS.iter().any(|marker| text.contains(*marker))
}