use std::time::{Duration, Instant};
use tracing::{info, warn};
use crate::rpc::config;
/// Maximum length, in bytes, an error message keeps before truncation for logging.
pub const MAX_ERROR_MSG_LEN: usize = 512;
/// Number of errors inside the sliding window that triggers provider suspension.
/// Tied to the retry budget so a provider is benched once retries are exhausted.
pub const MAX_ERRORS_PER_WINDOW: usize = config::MAX_RETRIES as usize;
/// Width of the sliding error window, in seconds (10 minutes).
pub const ERROR_WINDOW_SECONDS: u64 = 600;
/// Truncate an error's display string to at most `MAX_ERROR_MSG_LEN` bytes,
/// appending a `... [truncated N chars]` suffix reporting how many bytes were
/// dropped.
///
/// The cut point is moved back to the nearest UTF-8 character boundary:
/// `&msg[..MAX_ERROR_MSG_LEN]` panics if the index splits a multi-byte
/// character, which is easy to hit with non-ASCII error text.
pub fn truncate_error(err: impl std::fmt::Display) -> String {
    let msg = err.to_string();
    if msg.len() <= MAX_ERROR_MSG_LEN {
        return msg;
    }
    // Walk back to a valid boundary so the slice below cannot panic.
    let mut cut = MAX_ERROR_MSG_LEN;
    while cut > 0 && !msg.is_char_boundary(cut) {
        cut -= 1;
    }
    // NOTE: "chars" in the suffix is actually a byte count, kept for
    // compatibility with existing log consumers.
    format!("{}... [truncated {} chars]", &msg[..cut], msg.len() - cut)
}
/// Truncate an error's display string to at most `max_len` bytes, with no
/// suffix marker.
///
/// The cut point is moved back to the nearest UTF-8 character boundary:
/// `msg[..max_len]` panics if the index splits a multi-byte character, which
/// is easy to hit with non-ASCII error text.
pub fn truncate_error_short(err: impl std::fmt::Display, max_len: usize) -> String {
    let msg = err.to_string();
    if msg.len() <= max_len {
        return msg;
    }
    // Walk back to a valid boundary so the slice below cannot panic.
    let mut cut = max_len;
    while cut > 0 && !msg.is_char_boundary(cut) {
        cut -= 1;
    }
    msg[..cut].to_string()
}
/// Coarse classification of RPC provider errors, derived from the raw
/// error message text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCategory {
    /// The response exceeded a size limit or was cut off mid-transfer.
    ResponseTooLarge,
    /// The provider is rate-limiting us.
    RateLimit,
    /// Transport-level failure (reset, timeout, ...).
    Connection,
    /// Anything not matched by the above.
    Other,
}

impl ErrorCategory {
    /// Classify `error_msg` by substring matching (case-sensitive); the
    /// first matching group wins, falling back to [`Self::Other`].
    pub fn from_error_msg(error_msg: &str) -> Self {
        const SIZE_MARKERS: [&str; 7] = [
            "32081",
            "too large",
            "max results",
            "deserialization",
            "EOF",
            "truncated",
            "decoding response",
        ];
        const RATE_MARKERS: [&str; 4] = ["429", "Too Many Requests", "rate limit", "Rate limit"];
        const CONN_MARKERS: [&str; 4] = ["Connection reset", "connection", "timeout", "Timeout"];

        let matches_any = |markers: &[&str]| markers.iter().any(|m| error_msg.contains(m));

        if matches_any(&SIZE_MARKERS) {
            Self::ResponseTooLarge
        } else if matches_any(&RATE_MARKERS) {
            Self::RateLimit
        } else if matches_any(&CONN_MARKERS) {
            Self::Connection
        } else {
            Self::Other
        }
    }
}
/// Sliding-window error tracker for a single RPC provider.
///
/// The provider is suspended once `MAX_ERRORS_PER_WINDOW` errors occur within
/// `ERROR_WINDOW_SECONDS`; suspension lifts automatically when enough errors
/// age out of the window, or immediately on a recorded success.
pub struct ProviderErrorTracker {
    // Instants of recent errors; entries older than the window are pruned lazily.
    error_timestamps: Vec<Instant>,
    // True while the provider is considered unusable.
    suspended: bool,
    // Human-readable provider name, used only in log messages.
    identifier: String,
}
impl ProviderErrorTracker {
    /// Create a tracker for the provider named by `identifier`, starting with
    /// a clean history and not suspended.
    pub fn new(identifier: impl Into<String>) -> Self {
        Self {
            error_timestamps: Vec::new(),
            suspended: false,
            identifier: identifier.into(),
        }
    }

    /// Record one error at the current instant.
    ///
    /// Returns `true` if the provider is now suspended (i.e. it has
    /// accumulated `MAX_ERRORS_PER_WINDOW` or more errors within the window).
    pub fn record_error(&mut self) -> bool {
        let now = Instant::now();
        self.cleanup_old_errors(now);
        self.error_timestamps.push(now);
        if self.error_timestamps.len() >= MAX_ERRORS_PER_WINDOW {
            // Log only on the transition into suspension so repeated errors
            // while already suspended don't spam the log.
            if !self.suspended {
                warn!(
                    "Provider {} suspended: {} errors in {} seconds",
                    self.identifier,
                    self.error_timestamps.len(),
                    ERROR_WINDOW_SECONDS
                );
            }
            self.suspended = true;
        }
        self.suspended
    }

    /// A success wipes the error history; the next cleanup will un-suspend.
    pub fn record_success(&mut self) {
        self.error_timestamps.clear();
    }

    /// Drop timestamps older than the window and lift the suspension once the
    /// remaining count falls back under the threshold.
    fn cleanup_old_errors(&mut self, now: Instant) {
        let window = Duration::from_secs(ERROR_WINDOW_SECONDS);
        self.error_timestamps
            .retain(|&ts| now.duration_since(ts) < window);
        if self.suspended && self.error_timestamps.len() < MAX_ERRORS_PER_WINDOW {
            info!("Provider {} un-suspended: errors aged out", self.identifier);
            self.suspended = false;
        }
    }

    /// Whether the provider is currently suspended (after pruning old errors).
    pub fn is_suspended(&mut self) -> bool {
        self.cleanup_old_errors(Instant::now());
        self.suspended
    }

    /// Number of errors currently inside the window.
    pub fn error_count(&mut self) -> usize {
        self.cleanup_old_errors(Instant::now());
        self.error_timestamps.len()
    }

    /// The provider name this tracker was created with.
    pub fn identifier(&self) -> &str {
        &self.identifier
    }

    /// Exponential backoff: `BASE_BACKOFF_SECS ^ error_count` seconds.
    ///
    /// While suspended the timestamp list keeps growing on each
    /// `record_error`, so the exponent is NOT bounded by
    /// `MAX_ERRORS_PER_WINDOW`; a plain `pow` would panic in debug builds
    /// (or wrap in release) once the power overflows. Saturate instead.
    pub fn backoff_duration(&mut self) -> Duration {
        let error_count = self.error_count();
        let exponent = u32::try_from(error_count).unwrap_or(u32::MAX);
        Duration::from_secs(config::BASE_BACKOFF_SECS.saturating_pow(exponent))
    }
}
/// Find the first non-suspended provider, scanning round-robin from
/// `start_index` (wrapping past the end of the slice).
///
/// Returns the index of the first active provider, or `None` when every
/// provider is suspended or the slice is empty.
pub fn find_active_provider(
    trackers: &mut [ProviderErrorTracker],
    start_index: usize,
) -> Option<usize> {
    let num_providers = trackers.len();
    (0..num_providers)
        .map(|offset| (start_index + offset) % num_providers)
        .find(|&idx| !trackers[idx].is_suspended())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every message in `messages` classifies as `expected`.
    fn assert_all_categorized(messages: &[&str], expected: ErrorCategory) {
        for msg in messages {
            assert_eq!(ErrorCategory::from_error_msg(msg), expected, "message: {}", msg);
        }
    }

    #[test]
    fn test_error_category_size_errors() {
        assert_all_categorized(
            &[
                "Response 32081 too large",
                "max results exceeded",
                "deserialization failed",
                "error decoding response body",
            ],
            ErrorCategory::ResponseTooLarge,
        );
    }

    #[test]
    fn test_error_category_rate_limit() {
        assert_all_categorized(
            &["HTTP error: 429 Too Many Requests", "rate limit exceeded"],
            ErrorCategory::RateLimit,
        );
    }

    #[test]
    fn test_error_category_connection() {
        assert_all_categorized(
            &["Connection reset without closing handshake", "request timeout"],
            ErrorCategory::Connection,
        );
    }

    #[test]
    fn test_provider_error_tracker() {
        let mut tracker = ProviderErrorTracker::new("test-provider");
        assert!(!tracker.is_suspended());
        assert_eq!(tracker.error_count(), 0);

        // One error short of the limit keeps the provider active.
        for _ in 1..MAX_ERRORS_PER_WINDOW {
            assert!(!tracker.record_error());
        }
        assert!(!tracker.is_suspended());
        assert_eq!(tracker.error_count(), MAX_ERRORS_PER_WINDOW - 1);

        // The final error trips the suspension.
        assert!(tracker.record_error());
        assert!(tracker.is_suspended());

        // A success wipes the history.
        tracker.record_success();
        assert_eq!(tracker.error_count(), 0);
    }

    #[test]
    fn test_truncate_error() {
        let short_msg = "short error";
        assert_eq!(truncate_error(short_msg), short_msg);

        let long_msg = "x".repeat(1000);
        let truncated = truncate_error(&long_msg);
        assert!(truncated.len() < long_msg.len());
        assert!(truncated.contains("[truncated"));
    }

    #[test]
    fn test_find_active_provider() {
        let mut trackers = vec![
            ProviderErrorTracker::new("provider-0"),
            ProviderErrorTracker::new("provider-1"),
            ProviderErrorTracker::new("provider-2"),
        ];
        assert_eq!(find_active_provider(&mut trackers, 0), Some(0));
        assert_eq!(find_active_provider(&mut trackers, 1), Some(1));

        // Suspend provider 0; the round-robin scan should land on provider 1.
        for _ in 0..MAX_ERRORS_PER_WINDOW {
            trackers[0].record_error();
        }
        assert_eq!(find_active_provider(&mut trackers, 0), Some(1));

        // Suspend everything; no provider is available.
        for tracker in &mut trackers {
            for _ in 0..MAX_ERRORS_PER_WINDOW {
                tracker.record_error();
            }
        }
        assert_eq!(find_active_provider(&mut trackers, 0), None);
    }
}