use std::time::Duration;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
/// Persisted rate-limiting state.
///
/// Serialization is derived (timestamps as RFC 3339 strings). Deserialization is
/// implemented by hand further down in this module so that missing fields fall
/// back to defaults.
#[derive(Debug, Clone, Serialize)]
pub struct RateLimitState {
    /// On-disk format version; set to `2` by `Default` and by deserialization.
    pub schema: u8,
    /// Presumably the earliest instant the next request may be issued
    /// (confirm against callers). Serialized as RFC 3339.
    #[serde(with = "time::serde::rfc3339")]
    pub next_allowed_at: OffsetDateTime,
    /// Optional block deadline; serialized as an RFC 3339 string or `null`.
    #[serde(with = "time::serde::rfc3339::option")]
    pub blocked_until: Option<OffsetDateTime>,
    /// Optional slowdown deadline; serialized as an RFC 3339 string or `null`.
    #[serde(with = "time::serde::rfc3339::option")]
    pub slowdown_until: Option<OffsetDateTime>,
    /// Count of consecutive blocks; starts at `0` in the default state.
    pub consecutive_blocks: u32,
    /// Free-form description of the most recent block, if any.
    pub last_block_reason: Option<String>,
}
impl Default for RateLimitState {
fn default() -> Self {
Self {
schema: 2,
next_allowed_at: OffsetDateTime::now_utc(),
blocked_until: None,
slowdown_until: None,
consecutive_blocks: 0,
last_block_reason: None,
}
}
}
impl RateLimitState {
    /// How far past `now` a stored timestamp may lie before
    /// [`RateLimitState::sanitize`] treats it as corrupt (24 hours).
    pub const MAX_FUTURE_HORIZON: Duration = Duration::from_secs(24 * 60 * 60);

    /// Repairs implausible future timestamps using
    /// [`RateLimitState::MAX_FUTURE_HORIZON`] as the cutoff.
    ///
    /// See [`RateLimitState::sanitize_with_horizon`] for the exact rules.
    ///
    /// # Panics
    /// May panic if `now + MAX_FUTURE_HORIZON` overflows the range of `OffsetDateTime`.
    pub fn sanitize(&mut self, now: OffsetDateTime) {
        self.sanitize_with_horizon(now, Self::MAX_FUTURE_HORIZON);
    }

    /// Repairs timestamps that lie more than `horizon` after `now`.
    ///
    /// - If `next_allowed_at` is beyond the cutoff, the whole state is reset to
    ///   its default, with `next_allowed_at = now`.
    /// - Otherwise, `blocked_until` / `slowdown_until` are cleared individually
    ///   when they are beyond the cutoff. `consecutive_blocks` and
    ///   `last_block_reason` are left untouched.
    ///
    /// A timestamp exactly at the cutoff is kept.
    ///
    /// # Panics
    /// May panic if `now + horizon` overflows the range of `OffsetDateTime`.
    pub fn sanitize_with_horizon(&mut self, now: OffsetDateTime, horizon: Duration) {
        let cap = now + horizon;
        if self.next_allowed_at > cap {
            // The gating timestamp itself is implausible: discard everything.
            *self = Self {
                next_allowed_at: now,
                ..Self::default()
            };
            return;
        }
        // Keep each deadline only while it is within the cutoff.
        self.blocked_until = self.blocked_until.filter(|&t| t <= cap);
        self.slowdown_until = self.slowdown_until.filter(|&t| t <= cap);
    }
}
/// Lenient wire shape used by the manual `Deserialize` impl of `RateLimitState`.
///
/// Every field is optional on input. A missing field takes its value from the
/// derived `Default`: `0`, `None`, or `0` respectively. Field order is kept
/// as-is for non-self-describing formats.
#[derive(Default, Deserialize)]
#[serde(default)]
struct RawState {
    /// Stored schema version; read but not used.
    #[serde(rename = "schema")]
    _schema: u8,
    #[serde(with = "time::serde::rfc3339::option")]
    next_allowed_at: Option<OffsetDateTime>,
    #[serde(with = "time::serde::rfc3339::option")]
    blocked_until: Option<OffsetDateTime>,
    #[serde(with = "time::serde::rfc3339::option")]
    slowdown_until: Option<OffsetDateTime>,
    consecutive_blocks: u32,
    last_block_reason: Option<String>,
}
impl<'de> Deserialize<'de> for RateLimitState {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = RawState::deserialize(deserializer)?;
Ok(Self {
schema: 2,
next_allowed_at: raw.next_allowed_at.unwrap_or_else(OffsetDateTime::now_utc),
blocked_until: raw.blocked_until,
slowdown_until: raw.slowdown_until,
consecutive_blocks: raw.consecutive_blocks,
last_block_reason: raw.last_block_reason,
})
}
}
// Unit tests live in the sibling file `state_tests.rs` (test builds only).
#[cfg(test)]
#[path = "state_tests.rs"]
mod tests;