lash_core/provider/
traits.rs1use super::support::*;
2use crate::LlmTerminalReason;
3
4#[async_trait]
7pub trait Provider: Send + Sync + std::fmt::Debug {
8 fn kind(&self) -> &'static str;
9
10 fn options(&self) -> ProviderOptions;
11 fn set_options(&mut self, options: ProviderOptions);
12
13 fn serialize_config(&self) -> serde_json::Value;
17
18 async fn complete(&mut self, request: LlmRequest) -> Result<LlmResponse, LlmTransportError>;
19
20 fn requires_streaming(&self) -> bool {
21 false
22 }
23
24 async fn close(&self) -> Result<(), LlmTransportError> {
35 Ok(())
36 }
37
38 fn clone_boxed(&self) -> Box<dyn Provider>;
39}
40
/// Per-provider policy describing which variants a model supports.
///
/// Implementors must be `Send + Sync` and implement `Debug`.
pub trait ProviderModelPolicy: Send + Sync + std::fmt::Debug {
    /// Returns the static list of variant names supported for `model`.
    fn supported_variants(&self, model: &str) -> &'static [&'static str];
}
44
45pub trait ProviderFailureClassifier: Send + Sync + std::fmt::Debug {
46 fn classify(&self, failure: ProviderFailure) -> ProviderFailure;
47}
48
/// Stateless, built-in implementation of `ProviderFailureClassifier`.
///
/// A unit struct: it carries no configuration and can be created with
/// `Default::default()`.
#[derive(Debug, Default, Clone)]
pub struct DefaultProviderFailureClassifier;
51
52impl ProviderFailureClassifier for DefaultProviderFailureClassifier {
53 fn classify(&self, mut failure: ProviderFailure) -> ProviderFailure {
54 if let Some(status) = failure.status.or_else(|| {
55 failure
56 .code
57 .as_deref()
58 .and_then(|code| code.parse::<u16>().ok())
59 }) {
60 failure.status = Some(status);
61 if failure.kind == ProviderFailureKind::Unknown {
62 failure.kind = ProviderFailureKind::Http;
63 }
64 failure.retryable = matches!(status, 408 | 409 | 425 | 429 | 500 | 502 | 503 | 504);
65 if status == 429 {
66 failure.kind = ProviderFailureKind::Quota;
71 } else if matches!(status, 401 | 403) {
72 failure.kind = ProviderFailureKind::Auth;
73 } else if matches!(status, 400 | 413 | 422) {
74 failure.kind = ProviderFailureKind::Validation;
75 }
76 } else if matches!(
77 failure.kind,
78 ProviderFailureKind::Transport | ProviderFailureKind::Timeout
79 ) {
80 failure.retryable = true;
81 }
82
83 let haystack = format!(
84 "{}\n{}\n{}",
85 failure.code.as_deref().unwrap_or_default(),
86 failure.message,
87 failure.raw.as_deref().unwrap_or_default()
88 )
89 .to_ascii_lowercase();
90 if is_context_overflow_text(&haystack) {
91 failure.kind = ProviderFailureKind::Validation;
92 failure.retryable = false;
93 failure.terminal_reason = LlmTerminalReason::ContextOverflow;
94 }
95 if haystack.contains("insufficient_quota")
96 || haystack.contains("usage_limit_reached")
97 || haystack.contains("usage_not_included")
98 || haystack.contains("quota")
99 {
100 failure.kind = ProviderFailureKind::Quota;
101 failure.retryable = false;
102 }
103 if haystack.contains("content_filter")
104 || haystack.contains("prohibited_content")
105 || haystack.contains("safety")
106 || haystack.contains("sensitive")
107 {
108 failure.terminal_reason = LlmTerminalReason::ContentFilter;
109 }
110 if haystack.contains("model_not_found")
111 || haystack.contains("unsupported model")
112 || haystack.contains("does not exist")
113 {
114 failure.kind = ProviderFailureKind::Unsupported;
115 failure.retryable = false;
116 }
117 failure
118 }
119}
120
/// Returns `true` when `haystack` looks like a "context window exceeded"
/// error message.
///
/// Matching is a case-insensitive (ASCII) substring search. Text that mentions
/// rate limiting, throttling or quota is never treated as a context overflow,
/// even if it also contains an overflow marker.
pub fn is_context_overflow_text(haystack: &str) -> bool {
    /// Any of these marks the text as a rate-limit/quota error instead.
    const RATE_LIMIT_MARKERS: &[&str] = &[
        "rate limit",
        "rate_limit",
        "ratelimit",
        "throttle",
        "throttling",
        "too many requests",
        "tokens per minute",
        "tpm",
        "quota",
    ];
    /// Any single one of these marks the text as a context overflow.
    const OVERFLOW_MARKERS: &[&str] = &[
        "context_length_exceeded",
        "context_length",
        "context length",
        "maximum context",
        "max context",
        "context window",
        "context window exceeds limit",
        "exceeds the context window",
        "prompt is too long",
        "prompt too long",
        "request_too_large",
        "maximum prompt length is",
        "reduce the length of the messages",
        "maximum context length is",
        "model's context length",
        "models context length",
        "exceeds the available context size",
        "greater than the context length",
        "exceeded model token limit",
        "too large for model with",
        "model_context_window_exceeded",
        "too many tokens",
        "exceeds the maximum number of tokens",
        "exceeds maximum number of tokens",
        "request too large",
        "input is too long",
        "token limit exceeded",
        "reduce the length of your prompt",
    ];

    let normalized = haystack.to_ascii_lowercase();
    let mentions_any =
        |markers: &[&str]| markers.iter().any(|&marker| normalized.contains(marker));

    if mentions_any(RATE_LIMIT_MARKERS) {
        return false;
    }

    // This pair only counts when both fragments appear in the text.
    let input_count_over_max = normalized.contains("input token count")
        && normalized.contains("exceeds the maximum");

    mentions_any(OVERFLOW_MARKERS) || input_count_over_max
}