lash_core/provider/
traits.rs1use super::support::*;
2use crate::LlmTerminalReason;
3
4pub trait ProviderState: Send + Sync + std::fmt::Debug {
5 fn kind(&self) -> &'static str;
6
7 fn options(&self) -> ProviderOptions;
8 fn set_options(&mut self, options: ProviderOptions);
9
10 fn serialize_config(&self) -> serde_json::Value;
14
15 fn clone_boxed(&self) -> Box<dyn ProviderState>;
16}
17
18#[async_trait]
19pub trait ProviderTransport: Send + Sync + std::fmt::Debug {
20 async fn complete(&mut self, request: LlmRequest) -> Result<LlmResponse, LlmTransportError>;
21
22 fn requires_streaming(&self) -> bool {
23 false
24 }
25
26 fn clone_boxed(&self) -> Box<dyn ProviderTransport>;
27}
28
29pub trait ProviderModelPolicy: Send + Sync + std::fmt::Debug {
30 fn default_model(&self) -> &str;
31 fn supported_variants(&self, model: &str) -> &'static [&'static str];
32 fn default_model_variant(&self, model: &str) -> Option<&'static str>;
33 fn request_variant_config(&self, model: &str, variant: &str) -> Option<VariantRequestConfig>;
34 fn default_agent_model(&self, tier: &str) -> Option<AgentModelSelection>;
35
36 fn resolve_model(&self, model: &str) -> String {
37 model.to_string()
38 }
39
40 fn context_lookup_model(&self, model: &str) -> String {
41 model.to_string()
42 }
43
44 fn input_usage_excludes_cached_tokens(&self) -> bool {
45 false
46 }
47}
48
49pub trait ProviderFailureClassifier: Send + Sync + std::fmt::Debug {
50 fn classify(&self, failure: ProviderFailure) -> ProviderFailure;
51}
52
/// Stateless unit type that carries the stock failure-classification rules.
#[derive(Debug, Default, Clone)]
pub struct DefaultProviderFailureClassifier;
55
56impl ProviderFailureClassifier for DefaultProviderFailureClassifier {
57 fn classify(&self, mut failure: ProviderFailure) -> ProviderFailure {
58 if let Some(status) = failure.status.or_else(|| {
59 failure
60 .code
61 .as_deref()
62 .and_then(|code| code.parse::<u16>().ok())
63 }) {
64 failure.status = Some(status);
65 if failure.kind == ProviderFailureKind::Unknown {
66 failure.kind = ProviderFailureKind::Http;
67 }
68 failure.retryable = matches!(status, 408 | 409 | 425 | 429 | 500 | 502 | 503 | 504);
69 if matches!(status, 401 | 403) {
70 failure.kind = ProviderFailureKind::Auth;
71 } else if matches!(status, 400 | 413 | 422) {
72 failure.kind = ProviderFailureKind::Validation;
73 }
74 } else if matches!(
75 failure.kind,
76 ProviderFailureKind::Transport | ProviderFailureKind::Timeout
77 ) {
78 failure.retryable = true;
79 }
80
81 let haystack = format!(
82 "{}\n{}\n{}",
83 failure.code.as_deref().unwrap_or_default(),
84 failure.message,
85 failure.raw.as_deref().unwrap_or_default()
86 )
87 .to_ascii_lowercase();
88 if is_context_overflow_text(&haystack) {
89 failure.kind = ProviderFailureKind::Validation;
90 failure.retryable = false;
91 failure.terminal_reason = LlmTerminalReason::ContextOverflow;
92 }
93 if haystack.contains("insufficient_quota")
94 || haystack.contains("usage_limit_reached")
95 || haystack.contains("usage_not_included")
96 || haystack.contains("quota")
97 {
98 failure.kind = ProviderFailureKind::Quota;
99 failure.retryable = false;
100 }
101 if haystack.contains("content_filter")
102 || haystack.contains("prohibited_content")
103 || haystack.contains("safety")
104 || haystack.contains("sensitive")
105 {
106 failure.terminal_reason = LlmTerminalReason::ContentFilter;
107 }
108 if haystack.contains("model_not_found")
109 || haystack.contains("unsupported model")
110 || haystack.contains("does not exist")
111 {
112 failure.kind = ProviderFailureKind::Unsupported;
113 failure.retryable = false;
114 }
115 failure
116 }
117}
118
/// Reports whether `haystack` reads like a "context window exceeded" error.
///
/// Matching is ASCII-case-insensitive substring search. Text that mentions
/// rate limiting, throttling or quota is never treated as an overflow, even
/// when it also contains an overflow phrase.
pub fn is_context_overflow_text(haystack: &str) -> bool {
    // Any of these vetoes an overflow verdict.
    const EXCLUSION_MARKERS: &[&str] = &[
        "rate limit",
        "rate_limit",
        "ratelimit",
        "throttle",
        "throttling",
        "too many requests",
        "tokens per minute",
        "tpm",
        "quota",
    ];
    // Any single one of these is enough for an overflow verdict.
    const OVERFLOW_MARKERS: &[&str] = &[
        "context_length_exceeded",
        "context_length",
        "context length",
        "maximum context",
        "max context",
        "context window",
        "context window exceeds limit",
        "exceeds the context window",
        "prompt is too long",
        "prompt too long",
        "request_too_large",
        "maximum prompt length is",
        "reduce the length of the messages",
        "maximum context length is",
        "model's context length",
        "models context length",
        "exceeds the available context size",
        "greater than the context length",
        "exceeded model token limit",
        "too large for model with",
        "model_context_window_exceeded",
        "too many tokens",
        "exceeds the maximum number of tokens",
        "exceeds maximum number of tokens",
        "request too large",
        "input is too long",
        "token limit exceeded",
        "reduce the length of your prompt",
    ];

    let text = haystack.to_ascii_lowercase();

    let excluded = EXCLUSION_MARKERS
        .iter()
        .any(|marker| text.contains(*marker));
    if excluded {
        return false;
    }

    // This phrase pair only counts when both halves are present.
    let token_count_over_max =
        text.contains("input token count") && text.contains("exceeds the maximum");

    token_count_over_max
        || OVERFLOW_MARKERS
            .iter()
            .any(|marker| text.contains(*marker))
}