phi_core/provider/traits.rs
1use crate::types::*;
2use async_trait::async_trait;
3use tokio::sync::mpsc;
4
5use super::model::ModelConfig;
6
/*
ARCHITECTURE: The Provider Layer

This module defines the core abstraction for ALL LLM providers:

  StreamProvider trait  — the interface every provider must implement
  StreamEvent enum      — the event protocol sent through the channel
  StreamConfig struct   — the input to every provider call
  ResponseFormat enum   — the requested output shape (free text or JSON)
  ToolDefinition struct — the tool schema advertised to the model
  ProviderError enum    — the error taxonomy

Why stream through a channel instead of returning a Vec of events?
Streaming gives real-time UI updates: the user sees tokens as they arrive,
not after the entire response has been generated. An mpsc channel is the
natural async Rust primitive for this producer-consumer split.

The dual-output pattern:

  provider.stream(config, tx, cancel) → Future<Result<Message, ProviderError>>
                          ↑                    ↑
                  sends StreamEvents    returns the final Message
                  in real time          after the stream completes

The channel carries partial deltas; the return value carries the complete message.
*/
30
31/// Events emitted during LLM streaming.
32/*
33ARCHITECTURE: `content_index` in delta events
34
35LLM responses can contain MULTIPLE content blocks in one message:
36 [Thinking("..."), Text("Hello"), ToolCall({id: "x", name: "bash", args: {...}})]
37
38`content_index` identifies WHICH block a delta belongs to.
39Without it, interleaved deltas from parallel content blocks would be ambiguous.
40
41Example for an extended-thinking response:
42 ThinkingDelta { content_index: 0, delta: "Let me " }
43 ThinkingDelta { content_index: 0, delta: "think..." }
44 TextDelta { content_index: 1, delta: "Here's " }
45 TextDelta { content_index: 1, delta: "my answer." }
46 ToolCallStart { content_index: 2, id: "call_1", name: "bash" }
47 ToolCallDelta { content_index: 2, delta: "{\"cmd\":" }
48 ToolCallEnd { content_index: 2 }
49 Done { message: (complete Message) }
50*/
51#[derive(Debug, Clone)]
52pub enum StreamEvent {
53 /// Stream started — the LLM has begun generating. Consumers should create a placeholder.
54 Start,
55 /// A text token from the response text.
56 TextDelta { content_index: usize, delta: String },
57 /// A chunk from the model's chain-of-thought (extended thinking mode only).
58 ThinkingDelta { content_index: usize, delta: String },
59 /// The LLM began a tool call — id and name are now known.
60 ToolCallStart {
61 content_index: usize,
62 id: String,
63 name: String,
64 },
65 /// A JSON fragment for a tool call's arguments (accumulate until ToolCallEnd).
66 ToolCallDelta { content_index: usize, delta: String },
67 /// The tool call's argument JSON is complete.
68 ToolCallEnd { content_index: usize },
69 /// Stream completed successfully. `message` is the final complete Message.
70 Done { message: Message },
71 /// Stream failed. `message` is a synthetic error Message with stop_reason=Error.
72 Error { message: Message },
73}
74
75/// Configuration for a streaming LLM call
76/*
77ARCHITECTURE: StreamConfig — the "envelope" passed into every provider call
78
79Every `StreamProvider::stream()` call receives exactly one `StreamConfig`.
80It bundles everything the provider needs to make one API request:
81 - model_config — the complete provider identity: id, api_key, base_url, compat flags
82 - messages / system_prompt / tools — the conversation payload
83 - thinking_level / max_tokens / temperature — per-call generation overrides
84 - cache_config — whether to send prompt-caching headers
85
86`model_config` is required (non-optional). Every provider reads at minimum
87`model_config.id` (model name) and `model_config.api_key` (auth credential).
88Providers with custom endpoints also read `model_config.base_url`, `model_config.headers`,
89and (for OpenAI-compat) `model_config.compat`.
90
91Why not pass individual arguments?
92 If `stream()` took 10 positional parameters it would be unergonomic and break
93 callers every time we added a field. A config struct is extensible: adding a
94 field is backward-compatible if the caller can use `Default::default()` for it.
95 Python analogy: kwargs dict passed to a function, or a dataclass payload.
96
97RUST QUIRK: `Option<u32>` and `Option<f32>` — "nullable" fields
98 Rust has no null. `Option<T>` is an explicit "maybe absent" wrapper:
99 `None` → caller didn't set a value; provider uses its own default
100 `Some(v)` → caller explicitly overrides the value
101 Python analogy: `max_tokens: int | None = None`
102*/
103#[derive(Debug, Clone)]
104pub struct StreamConfig {
105 /// Complete provider identity: model id, api_key, base_url, compat flags, cost rates.
106 /// All providers read `model_config.id` and `model_config.api_key`; most also read
107 /// `model_config.base_url` and `model_config.headers`.
108 pub model_config: ModelConfig,
109 pub system_prompt: String,
110 pub messages: Vec<Message>,
111 pub tools: Vec<ToolDefinition>,
112 pub thinking_level: ThinkingLevel,
113 pub max_tokens: Option<u32>, // overrides model_config.max_tokens when Some
114 pub temperature: Option<f32>,
115 /// Prompt caching configuration. Default: enabled with auto strategy.
116 pub cache_config: CacheConfig,
117 /// Desired output shape. `Text` (the default) preserves the historical behaviour;
118 /// `JsonObject` / `JsonSchema` request constrained JSON output from providers that
119 /// support it natively (OpenAI, Google) or via tool-call emulation (Anthropic).
120 /// Bedrock surfaces `ProviderError::SchemaMismatch` when set on a non-Anthropic
121 /// foundation model that lacks structured-output support. See the capability matrix
122 /// in `docs/specs/developer/provider.md` for per-provider coverage.
123 pub response_format: ResponseFormat,
124}
125
126/// Desired output shape for an LLM call.
127///
128/// Default `Text` matches the historical free-form text behaviour. `JsonObject`
129/// constrains output to syntactically valid JSON with no schema enforcement;
130/// `JsonSchema` adds strict-shape enforcement when the provider supports it.
131///
132/// `Message::extract_json::<T>()` is the recommended way to parse the resulting
133/// assistant message back into a typed value — it handles both native JSON-mode
134/// output (text content is JSON) and tool-call emulation (arguments JSON of a
135/// well-known synthetic tool) uniformly.
136#[derive(Debug, Clone, Default)]
137pub enum ResponseFormat {
138 /// Free-form text. Default; providers ignore the field entirely.
139 #[default]
140 Text,
141 /// Constrain output to valid JSON; no schema enforcement.
142 ///
143 /// Maps to:
144 /// - OpenAI Completions / Responses / Azure: `response_format: { type: "json_object" }`
145 /// - Google GenAI / Vertex: `responseMimeType: "application/json"`
146 /// - Anthropic / Bedrock-Anthropic: a synthetic `respond_json` tool with an
147 /// empty-shape schema; the LLM is forced to call it with its answer
148 /// - Bedrock non-Anthropic: not supported — provider returns `SchemaMismatch`
149 JsonObject,
150 /// Strict JSON Schema enforcement. The schema is forwarded to the provider when
151 /// supported natively; otherwise emulated via tool-call shape (Anthropic).
152 JsonSchema {
153 /// JSON Schema (Draft 2020-12 compatible) describing the expected output.
154 schema: serde_json::Value,
155 /// Human-readable schema name (some providers use this in error messages).
156 name: String,
157 /// Whether the provider should enforce strict shape (no extra fields).
158 /// Some providers' strict mode disables defaults — fall back to non-strict
159 /// if unsupported.
160 strict: bool,
161 },
162}
163
164/// Tool definition sent to the LLM (schema only, no execute fn)
165/*
166ARCHITECTURE: ToolDefinition — the schema half of a tool
167
168Every tool has two sides:
169 1. `AgentTool` (types.rs) — the Rust struct that EXECUTES the tool (has code)
170 2. `ToolDefinition` (here) — the JSON schema that gets SENT TO THE LLM
171
172When we call `provider.stream(config, ...)`, only `ToolDefinition` goes to the API.
173The LLM never sees executable code — it only sees name/description/parameters so it
174can decide whether to call the tool and how to format the arguments.
175
176The separation exists because:
177 - The provider layer is pure I/O; it doesn't execute tools
178 - ToolDefinition is serializable (goes over the wire); AgentTool is not
179 - `agent_loop.rs` bridges them: it converts AgentTool → ToolDefinition before
180 calling stream(), then receives ToolCall content and finds the matching AgentTool
181
182RUST QUIRK: `serde_json::Value` — a dynamically typed JSON tree
183 JSON doesn't map to a fixed Rust type. `serde_json::Value` is an enum that
184 can hold any valid JSON structure:
185 Value::Object(Map<String, Value>)
186 Value::Array(Vec<Value>)
187 Value::String(String)
188 Value::Number(Number) — wraps i64/u64/f64
189 Value::Bool(bool)
190 Value::Null
191
192 Tool parameters are represented as a JSON Schema object — a dynamic shape
193 that varies per tool — so `serde_json::Value` is the right type here.
194
195RUST QUIRK: `#[derive(Serialize, Deserialize)]`
196 Requires the `serde` + `serde_json` crates.
197 `Serialize` → can convert this struct TO JSON (for sending to APIs)
198 `Deserialize` → can reconstruct this struct FROM JSON (for round-tripping)
199 Python analogy: combining json.dumps() and json.loads() support automatically.
200*/
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct ToolDefinition {
203 pub name: String,
204 pub description: String,
205 /// JSON Schema object describing the tool's parameters.
206 /// LLMs use this schema to know what arguments to pass when calling the tool.
207 pub parameters: serde_json::Value,
208}
209
210use serde::{Deserialize, Serialize};
211
212/// The core provider trait. Implement this for each LLM backend.
213/*
214ARCHITECTURE: StreamProvider — the single extension point for ALL LLM backends
215
216Every LLM backend (Anthropic, OpenAI, Google, Bedrock, Azure, ...) implements
217this one trait. The rest of the codebase interacts only with `&dyn StreamProvider`
218— it never knows which concrete backend is being used at runtime.
219
220This is the "Strategy" pattern: swap the provider, keep everything else constant.
221
222The dual-output contract:
223 1. `tx` (mpsc channel) — sends StreamEvents in real time as they arrive
224 Consumers subscribe to this channel to update the UI with partial tokens.
225 2. Return value `Result<Message, ProviderError>` — the fully assembled Message
226 Only available after the stream completes. Contains the complete response.
227
228Why both? Because `Message` is only complete when the stream ends, but the UI
229needs to show tokens as they arrive (low latency). The channel handles the
230"streaming display" concern; the return value handles the "final record" concern.
231
232RUST QUIRK: `Send + Sync` trait bounds — thread safety requirements
233 `Send` → values of this type can be transferred across thread boundaries
234 `Sync` → references (&T) can be shared across thread boundaries simultaneously
235
236 Why required on StreamProvider?
237 The provider is stored as `Arc<dyn StreamProvider>` and accessed from
238 async tasks that may run on different OS threads in the tokio thread pool.
239 Without `Send + Sync`, the compiler would reject this as unsafe.
240
241 What do they PREVENT?
242 `Rc<T>` is not `Send` (non-atomic reference count, unsafe to move between threads)
243 `RefCell<T>` is not `Sync` (non-atomic borrow flag, unsafe to share between threads)
244 The bounds ensure implementations can't accidentally use these.
245
246RUST QUIRK: `#[async_trait]` — async methods in traits
247 Rust's native trait system doesn't support `async fn` in traits (as of stable Rust)
248 because `async fn` returns an anonymous `impl Future<Output=T>` — each
249 implementation would return a DIFFERENT type, violating the uniform vtable layout
250 required by `dyn Trait`.
251
252 `#[async_trait]` is a procedural macro from the `async-trait` crate that desugars:
253 async fn stream(&self, ...) -> Result<...>
254 into:
255 fn stream(&self, ...) -> Pin<Box<dyn Future<Output=Result<...>> + Send + '_>>
256
257 The `Pin<Box<dyn Future...>>` is a heap-allocated, type-erased future — same type
258 for every implementation, so the vtable works. The `Send` bound ensures the future
259 itself is thread-safe (can be awaited on any tokio thread).
260 Python analogy: an abstract async method that subclasses override.
261*/
262#[async_trait]
263pub trait StreamProvider: Send + Sync {
264 /// Short, stable identifier for this provider type.
265 ///
266 /// Used as the `provider_id` component of auto-derived `loop_id` signatures:
267 /// `loop_id = "{session_id}.{provider_id}.{model_slug}.{N}"`
268 ///
269 /// Return a lowercase ASCII string with no spaces (e.g. `"anthropic"`, `"openai"`, `"google"`).
270 /// Custom providers should return a unique, stable string.
271 fn provider_id(&self) -> &str;
272
273 /// Stream a completion. Send events through `tx` in real time.
274 /// Returns the final, fully-assembled assistant `Message` after the stream ends.
275 ///
276 /// Implementors must:
277 /// - Send `StreamEvent::Start` when the stream begins
278 /// - Send `StreamEvent::TextDelta` / `ThinkingDelta` / `ToolCall*` as tokens arrive
279 /// - Send `StreamEvent::Done { message }` or `StreamEvent::Error { message }` at the end
280 /// - Honor `cancel` — stop early and return `Err(ProviderError::Cancelled)`
281 async fn stream(
282 &self,
283 config: StreamConfig, // ALL REQUEST PARAMS — model, messages, tools, auth (bundled to avoid 10-arg signature)
284 tx: mpsc::UnboundedSender<StreamEvent>, // OBSERVER — push StreamEvents here in real-time as tokens arrive
285 cancel: tokio_util::sync::CancellationToken, // ABORT — check this; return Err(Cancelled) if triggered
286 ) -> Result<Message, ProviderError>; // final fully-assembled Message (only available after stream ends)
287}
288
289/*
290RUST QUIRK: `thiserror::Error` derive — auto-implementing `std::error::Error`
291
292`std::error::Error` is the standard Rust error trait. Manually implementing it
293requires also implementing `Display` and optionally `source()`. Boilerplate.
294
295`thiserror` is a macro crate that generates all three from annotations:
296 `#[error("API error: {0}")]` on a tuple variant:
297 → Display impl: format!("API error: {}", self.0)
298 → The {0} refers to the first (unnamed) field of the tuple variant.
299
300 `#[error("Rate limited, retry after {retry_after_ms:?}ms")]` on a struct variant:
301 → Display impl using the named field `retry_after_ms`
302 → {:?} uses Debug formatting on the Option<u64> → "Some(60000)" or "None"
303
304 `#[derive(thiserror::Error)]` also requires `#[derive(Debug)]` (already present).
305
306Python analogy:
307 class ProviderError(Exception):
308 pass
309 class ApiError(ProviderError):
310 def __str__(self): return f"API error: {self.message}"
311
312ARCHITECTURE: ProviderError variants — the error taxonomy
313
314Variants map to HTTP status codes + semantic categories:
315 `Api` — 4xx/5xx errors that are NOT special (bad request, server error)
316 `Network` — Transport failures: connection refused, timeout, TLS error
317 `Auth` — 401/403 — bad or missing API key
318 `RateLimited` — 429 — too many requests; includes optional server-specified delay
319 `ContextOverflow`— input too long for the model's context window
320 `Cancelled` — CancellationToken was triggered by the caller
321 `Other` — catch-all for anything that doesn't fit
322
323Why a flat enum rather than a hierarchy?
324 The agent loop has a simple decision tree:
325 is_retryable() → retry (RateLimited, Network)
326 is_context_overflow() → try compaction, then give up
327 is Cancelled → clean shutdown
328 everything else → surface to caller as failure
329 A flat enum with methods makes this dispatch cheap and exhaustive.
330*/
331#[derive(Debug, thiserror::Error)]
332pub enum ProviderError {
333 /// A non-transient API error (bad request, server error, etc.).
334 #[error("API error: {0}")]
335 Api(String),
336 /// Network/transport failure — connection refused, timeout, TLS error, etc.
337 #[error("Network error: {0}")]
338 Network(String),
339 /// Authentication failure — bad or missing API key (HTTP 401/403).
340 #[error("Auth error: {0}")]
341 Auth(String),
342 /// Rate limit hit (HTTP 429). `retry_after_ms` is the server-specified delay if present.
343 #[error("Rate limited, retry after {retry_after_ms:?}ms")]
344 RateLimited { retry_after_ms: Option<u64> },
345 /// Input exceeds the model's context window. Caller should compact and retry.
346 #[error("Context overflow: {message}")]
347 ContextOverflow { message: String },
348 /// The caller cancelled the request via `CancellationToken`.
349 #[error("Cancelled")]
350 Cancelled,
351 /// Catch-all for errors that don't fit another category.
352 #[error("{0}")]
353 Other(String),
354 /// Returned by structured-output paths when the requested `ResponseFormat` is
355 /// unsupported by the provider, or when extracting JSON from a response fails
356 /// (`Message::extract_json::<T>()` returns this on parse / deserialise errors).
357 #[error("Schema mismatch: {reason}")]
358 SchemaMismatch { reason: String },
359}
360
361impl ProviderError {
362 /// Classify an HTTP error response into the appropriate error variant.
363 ///
364 /// Detects context overflow, rate limits, auth errors, and general API errors
365 /// from the HTTP status code and response body.
366 pub fn classify(
367 status: u16, // HTTP status code — 429, 401, 403, 400, 413, 5xx
368 message: &str, // response body text — checked for overflow phrases; may be empty (Cerebras quirk)
369 ) -> Self {
370 if is_context_overflow(status, message) {
371 Self::ContextOverflow {
372 message: message.to_string(),
373 }
374 } else if status == 429 {
375 Self::RateLimited {
376 retry_after_ms: None,
377 }
378 } else if status == 401 || status == 403 {
379 Self::Auth(message.to_string())
380 } else {
381 Self::Api(message.to_string())
382 }
383 }
384
385 /// Returns true if this error indicates a context overflow.
386 pub fn is_context_overflow(&self) -> bool {
387 matches!(self, Self::ContextOverflow { .. })
388 }
389}
390
/// Known phrases that indicate context overflow across LLM providers.
///
/// Covers: Anthropic, OpenAI, Google Gemini, AWS Bedrock, xAI, Groq,
/// OpenRouter, llama.cpp, LM Studio, MiniMax, Kimi, GitHub Copilot,
/// and generic patterns.
///
/// INVARIANT: every entry must be lowercase. `is_context_overflow_message`
/// lowercases only the incoming message, so an entry containing uppercase letters
/// could never match. A compile-time assertion below enforces this.
/*
ARCHITECTURE: Centralised overflow detection — one place, all providers

Context overflow is a universal problem: every LLM has a finite token window.
But every provider expresses overflow differently:
  Anthropic: "prompt is too long: 213462 tokens > 200000 maximum"
  OpenAI:    "Your input exceeds the context window of this model"
  Gemini:    "The input token count (1196265) exceeds the maximum number of tokens allowed"
  Groq:      "Please reduce the length of the messages or completion"
  ...

Centralising these phrases in ONE constant means:
  1. Every provider uses `ProviderError::classify()` — no duplication
  2. Supporting a new provider = adding one phrase to this array
  3. The agent loop only checks `is_context_overflow()` — it doesn't know which provider

The generic entries ("exceeds the maximum", "too many tokens", ...) are broad and
can match unrelated errors; keep new entries as specific as the provider's
message allows.

RUST QUIRK: `const OVERFLOW_PHRASES: &[&str]` — a compile-time constant

`const` — the value is known at compile time and inlined at each use site; the
  string data itself is baked into the binary (typically read-only data).
  Python analogy: a module-level tuple of strings, but truly immutable.

`&[&str]` — a slice of string slices (two levels of reference):
  `&str`    — a reference to a string (UTF-8 bytes, stored somewhere)
  `&[T]`    — a "fat pointer" to a contiguous sequence of T (pointer + length)
  `&[&str]` — a reference to a sequence of `&str` items

  The string literals ("prompt is too long") are `&'static str` — they live
  forever in the binary, so no allocation and no lifetime issues.

Why not `Vec<String>`?
  `Vec<String>` is heap-allocated and built at runtime. A `const &[&str]` has
  zero runtime construction cost.

RUST QUIRK: `&[&str]` instead of a fixed-size array type
  You might expect `const X: [&str; N] = [...]` with N spelled out, but a slice
  reference is more ergonomic: the length lives in the fat pointer, not the
  type, so entries can be added without updating a count, and functions that
  iterate over it don't need to be generic over the array length.
*/
const OVERFLOW_PHRASES: &[&str] = &[
    "prompt is too long",                 // Anthropic
    "input is too long",                  // AWS Bedrock
    "exceeds the context window",         // OpenAI (Completions & Responses)
    "exceeds the maximum",                // Google Gemini ("input token count exceeds the maximum")
    "maximum prompt length",              // xAI
    "reduce the length of the messages",  // Groq
    "maximum context length",             // OpenRouter
    "exceeds the limit of",               // GitHub Copilot
    "exceeds the available context size", // llama.cpp
    "greater than the context length",    // LM Studio
    "context window exceeds limit",       // MiniMax
    "exceeded model token limit",         // Kimi
    "context length exceeded",            // Generic
    "context_length_exceeded",            // Generic (underscore variant)
    "too many tokens",                    // Generic
    "token limit exceeded",               // Generic
];

/// Returns `true` when no phrase contains an ASCII uppercase letter.
/// `const fn` so the invariant on `OVERFLOW_PHRASES` is checked at compile time.
const fn phrases_are_lowercase(phrases: &[&str]) -> bool {
    let mut i = 0;
    while i < phrases.len() {
        let bytes = phrases[i].as_bytes();
        let mut j = 0;
        while j < bytes.len() {
            if bytes[j].is_ascii_uppercase() {
                return false;
            }
            j += 1;
        }
        i += 1;
    }
    true
}

// Compile-time guard: the build fails if a phrase with uppercase letters is added.
const _: () = assert!(phrases_are_lowercase(OVERFLOW_PHRASES));
454
455/// Check if an error message indicates context overflow (for use by types.rs).
456/*
457RUST QUIRK: `pub(crate)` — "public within this crate only"
458
459`pub(crate)` sits between fully public (`pub`) and private (default).
460 - `pub` → anyone importing this crate can call it
461 - `pub(crate)` → only modules within THIS crate can call it
462 - (no modifier) → only this module can call it
463
464`is_context_overflow_message` is needed by `types.rs` (to classify SSE errors
465embedded in the stream — not just HTTP status errors) but shouldn't be part of
466the public library API. `pub(crate)` is the right scope.
467
468RUST QUIRK: `.iter().any(|phrase| lower.contains(phrase))`
469 `.iter()` — returns an iterator over `&&&str` (references to &str elements)
470 `.any(predicate)` — short-circuits: returns `true` as soon as predicate is true
471 `lower.contains(phrase)` — substring search (case-sensitive, but `lower` is already
472 lowercased so we get case-insensitive matching for free)
473 Python analogy: `any(phrase in lower for phrase in OVERFLOW_PHRASES)`
474*/
475pub(crate) fn is_context_overflow_message(message: &str) -> bool {
476 let lower = message.to_lowercase(); // normalize to lowercase for case-insensitive matching
477 OVERFLOW_PHRASES.iter().any(|phrase| lower.contains(phrase))
478}
479
480/// Check if an HTTP error response indicates context overflow.
481/*
482ARCHITECTURE: Two-path overflow detection
483
484Path 1 — Empty body (Cerebras, Mistral quirk):
485 Some providers return HTTP 400/413 with an EMPTY body when the input is too long.
486 We can't match a phrase, so we infer overflow from (status=400|413) + empty body.
487
488Path 2 — Phrase matching:
489 All other providers include a descriptive message. Delegate to is_context_overflow_message().
490
491The two paths are checked in order: empty-body first (cheaper), phrase-match second.
492
493RUST QUIRK: `message.trim().is_empty()`
494 `.trim()` removes leading/trailing whitespace, returning a `&str` slice of the original.
495 `.is_empty()` returns true if the slice has length 0.
496 Together: "is this message blank (or just whitespace)?"
497 Python analogy: `not message.strip()`
498*/
499fn is_context_overflow(
500 status: u16, // HTTP status — 400/413 with empty body → overflow even without a phrase
501 message: &str, // response body — matched against OVERFLOW_PHRASES; may be empty
502) -> bool {
503 // Some providers (Cerebras, Mistral) return 400/413 with empty body on overflow
504 if (status == 400 || status == 413) && message.trim().is_empty() {
505 return true;
506 }
507 is_context_overflow_message(message)
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `classify(status, body)` produces a context-overflow error.
    fn assert_overflow(status: u16, body: &str) {
        let err = ProviderError::classify(status, body);
        assert!(
            err.is_context_overflow(),
            "expected overflow for {status} {body:?}, got {err:?}"
        );
    }

    #[test]
    fn classify_anthropic_overflow() {
        assert_overflow(400, "prompt is too long: 213462 tokens > 200000 maximum");
    }

    #[test]
    fn classify_openai_overflow() {
        assert_overflow(400, "Your input exceeds the context window of this model");
    }

    #[test]
    fn classify_google_overflow() {
        assert_overflow(
            400,
            "The input token count (1196265) exceeds the maximum number of tokens allowed",
        );
    }

    #[test]
    fn classify_bedrock_overflow() {
        assert_overflow(400, "input is too long for requested model");
    }

    #[test]
    fn classify_xai_overflow() {
        assert_overflow(
            400,
            "This model's maximum prompt length is 131072 but request contains 537812 tokens",
        );
    }

    #[test]
    fn classify_groq_overflow() {
        assert_overflow(400, "Please reduce the length of the messages or completion");
    }

    #[test]
    fn classify_empty_body_overflow() {
        // Cerebras/Mistral return 400/413 with an empty (or blank) body.
        for (status, body) in [(413, ""), (400, " ")] {
            assert_overflow(status, body);
        }
    }

    #[test]
    fn classify_rate_limit() {
        let classified = ProviderError::classify(429, "rate limit exceeded");
        assert!(matches!(classified, ProviderError::RateLimited { .. }));
    }

    #[test]
    fn classify_auth_error() {
        for (status, body) in [(401, "invalid api key"), (403, "forbidden")] {
            let classified = ProviderError::classify(status, body);
            assert!(matches!(classified, ProviderError::Auth(_)), "{status}: {classified:?}");
        }
    }

    #[test]
    fn classify_regular_api_error() {
        let classified = ProviderError::classify(400, "invalid request format");
        assert!(!classified.is_context_overflow());
        assert!(matches!(classified, ProviderError::Api(_)));
    }

    #[test]
    fn overflow_message_case_insensitive() {
        for msg in ["PROMPT IS TOO LONG", "Too Many Tokens in request"] {
            assert!(is_context_overflow_message(msg), "{msg:?}");
        }
    }

    #[test]
    fn non_overflow_messages() {
        for msg in ["invalid api key", "internal server error", ""] {
            assert!(!is_context_overflow_message(msg), "{msg:?}");
        }
    }
}