use crate::error::CognisError;
/// Strategy for deciding whether a failed operation is worth retrying.
///
/// Implementors must be `Send + Sync`, as required by the supertrait bounds.
pub trait RetryClassifier: Send + Sync {
    /// Returns `true` if `error` is considered retryable by this classifier.
    fn is_retryable(&self, error: &CognisError) -> bool;
}
/// Stateless classifier implementing the default retry policy.
///
/// Being a zero-sized unit struct, it derives the common traits so it can be
/// debug-printed, freely copied, and constructed through `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultRetryClassifier;
impl RetryClassifier for DefaultRetryClassifier {
    /// Returns `true` for the error shapes this policy treats as retryable:
    ///
    /// * any `CognisError::IoError`;
    /// * a `CognisError::HttpError` whose status is 429 or at least 500;
    /// * a `CognisError::Other` whose lowercased message contains
    ///   `"rate limit"`, `"timeout"`, or `"connection"`.
    ///
    /// Every other error is reported as not retryable.
    fn is_retryable(&self, error: &CognisError) -> bool {
        // Substrings that mark a free-form message as retryable.
        const RETRYABLE_MARKERS: [&str; 3] = ["rate limit", "timeout", "connection"];

        match error {
            CognisError::IoError(_) => true,
            CognisError::HttpError { status, .. } => {
                let code = *status;
                code == 429 || code >= 500
            }
            CognisError::Other(msg) => {
                // Lowercase once so the marker search ignores case.
                let normalized = msg.to_lowercase();
                RETRYABLE_MARKERS
                    .iter()
                    .any(|marker| normalized.contains(*marker))
            }
            _ => false,
        }
    }
}
/// Stateless classifier whose policy is to retry on every error.
///
/// Being a zero-sized unit struct, it derives the common traits so it can be
/// debug-printed, freely copied, and constructed through `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AlwaysRetryClassifier;
impl RetryClassifier for AlwaysRetryClassifier {
    /// Unconditionally returns `true`; the error is never inspected.
    fn is_retryable(&self, _: &CognisError) -> bool {
        true
    }
}
/// Stateless classifier whose policy is to never retry.
///
/// Being a zero-sized unit struct, it derives the common traits so it can be
/// debug-printed, freely copied, and constructed through `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NeverRetryClassifier;
impl RetryClassifier for NeverRetryClassifier {
    /// Unconditionally returns `false`; the error is never inspected.
    fn is_retryable(&self, _: &CognisError) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the three classifiers defined in the parent module.

    use super::*;

    /// HTTP 429 is retryable under the default policy.
    #[test]
    fn test_retry_classifier_rate_limit() {
        let too_many_requests = CognisError::HttpError {
            status: 429,
            body: "Too Many Requests".into(),
        };
        assert!(DefaultRetryClassifier.is_retryable(&too_many_requests));
    }

    /// A sample of 5xx statuses is retryable under the default policy.
    #[test]
    fn test_retry_classifier_server_errors() {
        let policy = DefaultRetryClassifier;
        for status in [500, 502, 503, 504] {
            let server_error = CognisError::HttpError {
                status,
                body: "Server Error".into(),
            };
            let verdict = policy.is_retryable(&server_error);
            assert!(verdict, "Expected status {} to be retryable", status);
        }
    }

    /// A sample of 4xx statuses other than 429 is not retryable.
    #[test]
    fn test_retry_classifier_client_errors_not_retryable() {
        let policy = DefaultRetryClassifier;
        for status in [400, 401, 403, 404, 422] {
            let client_error = CognisError::HttpError {
                status,
                body: "Client Error".into(),
            };
            let verdict = policy.is_retryable(&client_error);
            assert!(!verdict, "Expected status {} to NOT be retryable", status);
        }
    }

    /// `Other` errors carrying one of the recognised keywords are retryable.
    #[test]
    fn test_retry_classifier_other_with_keywords() {
        let policy = DefaultRetryClassifier;
        let messages = [
            "rate limit exceeded",
            "request timeout after 30s",
            "connection reset by peer",
        ];
        for message in messages {
            let err = CognisError::Other(message.into());
            assert!(policy.is_retryable(&err));
        }
    }

    /// Parser errors, tool exceptions and keyword-free messages are rejected.
    #[test]
    fn test_retry_classifier_non_retryable_errors() {
        let policy = DefaultRetryClassifier;
        let rejected = [
            CognisError::OutputParserError {
                message: "bad format".into(),
                observation: None,
                llm_output: None,
            },
            CognisError::ToolException("tool failed".into()),
            CognisError::Other("some unknown error".into()),
        ];
        for err in &rejected {
            assert!(!policy.is_retryable(err));
        }
    }

    /// The always-retry policy accepts an arbitrary error.
    #[test]
    fn test_always_retry_classifier() {
        let arbitrary = CognisError::Other("anything".into());
        assert!(AlwaysRetryClassifier.is_retryable(&arbitrary));
    }

    /// The never-retry policy rejects even an HTTP 429.
    #[test]
    fn test_never_retry_classifier() {
        let rate_limited = CognisError::HttpError {
            status: 429,
            body: "rate limited".into(),
        };
        assert!(!NeverRetryClassifier.is_retryable(&rate_limited));
    }
}