pub mod resp;
use self::resp::{RespParser, RespSerializer, RespValue};
use super::Transport;
use crate::actor::RateLimiterHandle;
use crate::metrics::{Metrics, Transport as MetricsTransport};
use crate::types::ThrottleRequest;
use anyhow::{Context, Result};
use async_trait::async_trait;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tracing::{debug, error, info};
/// RESP (Redis protocol) front end for the rate limiter.
///
/// Listens on a TCP address and serves RESP commands, delegating throttling
/// decisions to the rate limiter and reporting every dispatched command to
/// [`Metrics`].
pub struct RedisTransport {
    /// Address the TCP listener binds to.
    addr: SocketAddr,
    /// Shared metrics sink; each accepted connection receives a clone of the `Arc`.
    metrics: Arc<Metrics>,
}
impl RedisTransport {
    /// Creates a transport that will listen on `host:port`.
    ///
    /// `host` must be an IP address literal. Accepted forms include IPv4 (`127.0.0.1`),
    /// bare IPv6 (`::1`, `::`) and bracketed IPv6 (`[::1]`). Hostnames such as
    /// `localhost` are not resolved.
    ///
    /// # Errors
    ///
    /// Returns an error if `host` cannot be parsed as an IP address literal.
    pub fn new(host: &str, port: u16, metrics: Arc<Metrics>) -> Result<Self> {
        let addr = format!("{host}:{port}")
            .parse::<SocketAddr>()
            // A bare IPv6 literal (e.g. `::1`) cannot simply be joined with `:port`:
            // `SocketAddr` requires brackets around IPv6. Fall back to parsing the host
            // as an IP and attaching the port directly.
            .or_else(|_| {
                host.parse::<std::net::IpAddr>()
                    .map(|ip| SocketAddr::new(ip, port))
            })
            .with_context(|| format!("Invalid address: {host}:{port}"))?;
        Ok(Self { addr, metrics })
    }
}
#[async_trait]
impl Transport for RedisTransport {
async fn start(self, limiter: RateLimiterHandle) -> Result<()> {
let listener = TcpListener::bind(&self.addr)
.await
.with_context(|| format!("Failed to bind to {}", self.addr))?;
info!("Redis transport listening on {}", self.addr);
loop {
let (socket, addr) = listener.accept().await?;
let limiter = limiter.clone();
let metrics = Arc::clone(&self.metrics);
tokio::spawn(async move {
if let Err(e) = handle_connection(socket, addr, limiter, metrics).await {
error!("Error handling Redis connection from {}: {}", addr, e);
}
});
}
}
}
const MAX_BUFFER_SIZE: usize = 64 * 1024;
async fn handle_connection(
mut socket: TcpStream,
addr: SocketAddr,
limiter: RateLimiterHandle,
metrics: Arc<Metrics>,
) -> Result<()> {
debug!("New Redis connection from {}", addr);
let mut buffer = Vec::new();
let mut parser = RespParser::new();
loop {
let mut temp_buf = vec![0; 1024];
let read_timeout = Duration::from_secs(300);
let n = match timeout(read_timeout, socket.read(&mut temp_buf)).await {
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(e.into()),
Err(_) => {
debug!(
"Redis connection {} timed out after 5 minutes of inactivity",
addr
);
return Ok(());
}
};
if n == 0 {
debug!("Redis connection closed by client {}", addr);
return Ok(());
}
buffer.extend_from_slice(&temp_buf[..n]);
if buffer.len() > MAX_BUFFER_SIZE {
error!("Redis connection {} exceeded buffer size limit", addr);
return Err(anyhow::anyhow!("Buffer size limit exceeded"));
}
while let Some((value, consumed)) = parser.parse(&buffer)? {
buffer.drain(..consumed);
let is_quit = matches!(&value, RespValue::Array(arr) if arr.first().map(|v| {
matches!(v, RespValue::BulkString(Some(cmd)) if cmd.to_uppercase() == "QUIT")
}).unwrap_or(false));
let response = process_command(value, &limiter, &metrics).await;
let response_bytes = RespSerializer::serialize(&response);
socket.write_all(&response_bytes).await?;
if is_quit {
debug!("Closing Redis connection for {} after QUIT", addr);
return Ok(());
}
}
}
}
/// Executes one parsed RESP command and returns the reply to send back.
///
/// Supported commands (name matched case-insensitively): `PING [message]`,
/// `THROTTLE ...` and `QUIT`. Anything else yields an `ERR unknown command` reply.
///
/// If the value is not an array, is an empty array, or its first element is not a
/// bulk string, an error reply is returned immediately and nothing is recorded.
/// Otherwise exactly one request is recorded in `metrics` for the Redis transport.
/// Only `THROTTLE` can be recorded as denied. It is attributed to its key whenever
/// the key argument is a bulk string, even if other arguments are invalid.
pub(super) async fn process_command(
    value: RespValue,
    limiter: &RateLimiterHandle,
    metrics: &Arc<Metrics>,
) -> RespValue {
    let command_array = match value {
        RespValue::Array(arr) => arr,
        _ => return RespValue::Error("ERR expected array of commands".to_string()),
    };
    if command_array.is_empty() {
        return RespValue::Error("ERR empty command".to_string());
    }
    let command = match &command_array[0] {
        RespValue::BulkString(Some(cmd)) => cmd.to_uppercase(),
        _ => return RespValue::Error("ERR invalid command format".to_string()),
    };
    // `allowed` is decided per command rather than inferred from the reply's shape,
    // so e.g. a PING echoing an array payload can never be counted as a denial.
    let (result, allowed, key) = match command.as_str() {
        "PING" => (handle_ping(&command_array), true, None),
        "THROTTLE" => {
            let key = match command_array.get(1) {
                Some(RespValue::BulkString(Some(k))) => Some(k.clone()),
                _ => None,
            };
            let result = handle_throttle(&command_array, limiter, metrics).await;
            // A successful THROTTLE reply is `[allowed, limit, remaining, reset_after,
            // retry_after]`. Error replies are recorded as allowed.
            let allowed = match &result {
                RespValue::Array(values) if values.len() >= 5 => {
                    matches!(&values[0], RespValue::Integer(1))
                }
                _ => true,
            };
            (result, allowed, key)
        }
        "QUIT" => (RespValue::SimpleString("OK".to_string()), true, None),
        _ => (
            RespValue::Error(format!("ERR unknown command '{command}'")),
            true,
            None,
        ),
    };
    match key {
        Some(key) => {
            metrics.record_request_with_key(MetricsTransport::Redis, allowed, &key);
        }
        None => {
            metrics.record_request(MetricsTransport::Redis, allowed);
        }
    }
    result
}
fn handle_ping(args: &[RespValue]) -> RespValue {
if args.len() == 1 {
RespValue::SimpleString("PONG".to_string())
} else if args.len() == 2 {
args[1].clone()
} else {
RespValue::Error("ERR wrong number of arguments for 'ping' command".to_string())
}
}
async fn handle_throttle(
args: &[RespValue],
limiter: &RateLimiterHandle,
_metrics: &Arc<Metrics>,
) -> RespValue {
if args.len() < 5 || args.len() > 6 {
return RespValue::Error(
"ERR wrong number of arguments for 'throttle' command".to_string(),
);
}
let key = match &args[1] {
RespValue::BulkString(Some(s)) => s.clone(),
_ => return RespValue::Error("ERR invalid key".to_string()),
};
let max_burst = match parse_integer(&args[2]) {
Some(n) => n,
None => return RespValue::Error("ERR invalid max_burst".to_string()),
};
let count_per_period = match parse_integer(&args[3]) {
Some(n) => n,
None => return RespValue::Error("ERR invalid count_per_period".to_string()),
};
let period = match parse_integer(&args[4]) {
Some(n) => n,
None => return RespValue::Error("ERR invalid period".to_string()),
};
let quantity = if args.len() == 6 {
match parse_integer(&args[5]) {
Some(n) => n,
None => return RespValue::Error("ERR invalid quantity".to_string()),
}
} else {
1
};
let request = ThrottleRequest {
key,
max_burst,
count_per_period,
period,
quantity,
timestamp: SystemTime::now(),
};
match limiter.throttle(request).await {
Ok(response) => {
RespValue::Array(vec![
RespValue::Integer(if response.allowed { 1 } else { 0 }),
RespValue::Integer(response.limit),
RespValue::Integer(response.remaining),
RespValue::Integer(response.reset_after),
RespValue::Integer(response.retry_after),
])
}
Err(e) => RespValue::Error(format!("ERR {e}")),
}
}
/// Interprets a RESP value as a signed 64-bit integer.
///
/// RESP integers are taken as-is. Bulk strings are parsed with `str::parse::<i64>`,
/// so surrounding whitespace or non-numeric text yields `None`. Every other variant,
/// including a null bulk string, also yields `None`.
fn parse_integer(value: &RespValue) -> Option<i64> {
    if let RespValue::Integer(number) = value {
        return Some(*number);
    }
    match value {
        RespValue::BulkString(Some(text)) => text.parse::<i64>().ok(),
        _ => None,
    }
}