use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::time::sleep;
use tracing::debug;
/// Paces requests toward a target rate and measures the rate actually achieved.
///
/// Every field is an `Arc<RwLock<_>>`, so clones share one pacing schedule and one set
/// of counters and can be handed to concurrent tasks.
#[derive(Debug, Clone)]
pub struct RpsController {
    /// Requested rate; zero, negative or NaN disables pacing.
    target_rps: Arc<RwLock<f64>>,
    /// Rate measured when `record_request` last closed a window (0.0 until then).
    current_rps: Arc<RwLock<f64>>,
    /// Requests recorded since `last_reset`.
    request_count: Arc<RwLock<u64>>,
    /// Start of the current measurement window.
    last_reset: Arc<RwLock<Instant>>,
    /// Spacing between slots derived from `target_rps`; zero means "no pacing".
    min_interval: Arc<RwLock<Duration>>,
    /// Earliest instant at which the next `wait_for_slot` caller may proceed.
    next_slot: Arc<RwLock<Instant>>,
}

impl RpsController {
    /// Longest spacing ever enforced between two slots (~30 years). Absurdly small rates
    /// are clamped to this so `1 / rps` cannot overflow `Duration` or `Instant` arithmetic.
    const MAX_INTERVAL: Duration = Duration::from_secs(30 * 365 * 24 * 60 * 60);

    /// Creates a controller targeting `target_rps` requests per second.
    ///
    /// A zero, negative or NaN `target_rps` disables pacing in [`Self::wait_for_slot`].
    pub fn new(target_rps: f64) -> Self {
        let now = Instant::now();
        Self {
            target_rps: Arc::new(RwLock::new(target_rps)),
            current_rps: Arc::new(RwLock::new(0.0)),
            request_count: Arc::new(RwLock::new(0)),
            last_reset: Arc::new(RwLock::new(now)),
            min_interval: Arc::new(RwLock::new(Self::interval_for(target_rps))),
            next_slot: Arc::new(RwLock::new(now)),
        }
    }

    /// Minimum spacing between requests for `rps`; `Duration::ZERO` means unlimited.
    fn interval_for(rps: f64) -> Duration {
        // `rps > 0.0` is false for zero, negatives and NaN, all of which mean "unlimited".
        if rps > 0.0 {
            // Unlike `from_secs_f64`, `try_from_secs_f64` returns an error instead of
            // panicking when `1 / rps` does not fit in a `Duration`.
            Duration::try_from_secs_f64(rps.recip())
                .map_or(Self::MAX_INTERVAL, |d| d.min(Self::MAX_INTERVAL))
        } else {
            Duration::ZERO
        }
    }

    /// Changes the target rate; the next [`Self::wait_for_slot`] call uses the new spacing.
    pub async fn set_target_rps(&self, rps: f64) {
        let interval = Self::interval_for(rps);
        let mut target = self.target_rps.write().await;
        *target = rps;
        // `target` stays locked while the derived state is updated so concurrent setters
        // cannot leave `target_rps` and `min_interval` describing different rates.
        *self.min_interval.write().await = interval;
        {
            // A slot reserved under a slower rate may lie far ahead; pull it in so a rate
            // increase is not delayed by the old spacing.
            let mut next = self.next_slot.write().await;
            *next = (*next).min(Instant::now() + interval);
        }
        debug!("RPS controller: target RPS set to {}", rps);
    }

    /// Returns the configured target rate.
    pub async fn get_target_rps(&self) -> f64 {
        *self.target_rps.read().await
    }

    /// Returns the rate measured when `record_request` last closed a window.
    pub async fn get_current_rps(&self) -> f64 {
        *self.current_rps.read().await
    }

    /// Waits until the caller may issue its next request.
    ///
    /// Each call reserves the next free slot on a schedule shared by all clones. Concurrent
    /// callers are therefore spaced `1 / target_rps` apart instead of all waking together,
    /// and time that has already passed since the previous slot counts toward the wait.
    /// Returns immediately when pacing is disabled.
    pub async fn wait_for_slot(&self) {
        let interval = *self.min_interval.read().await;
        if interval.is_zero() {
            return;
        }
        let slot = {
            let mut next = self.next_slot.write().await;
            // After an idle period `next` is in the past; start from now instead of
            // bursting to "catch up" on missed slots.
            let slot = (*next).max(Instant::now());
            *next = slot + interval;
            slot
        };
        // Sleep after releasing the lock so other callers can reserve later slots meanwhile.
        sleep(slot.saturating_duration_since(Instant::now())).await;
    }

    /// Counts one request. Once at least a second has passed since the window started,
    /// publishes `count / elapsed` as the current rate and opens a new window.
    ///
    /// The measured rate only refreshes when a request arrives, so it can go stale if
    /// traffic stops.
    pub async fn record_request(&self) {
        let mut count = self.request_count.write().await;
        *count += 1;
        let now = Instant::now();
        let mut last_reset = self.last_reset.write().await;
        let elapsed = now.duration_since(*last_reset);
        if elapsed >= Duration::from_secs(1) {
            let rps = *count as f64 / elapsed.as_secs_f64();
            let mut current = self.current_rps.write().await;
            *current = rps;
            *count = 0;
            *last_reset = now;
            debug!("RPS controller: current RPS = {:.2}", rps);
        }
    }

    /// Returns the number of requests recorded in the current window. The count is reset to 0
    /// whenever `record_request` closes a window.
    pub async fn get_request_count(&self) -> u64 {
        *self.request_count.read().await
    }
}
/// A named load profile made of an ordered list of [`RpsStage`]s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpsProfile {
/// Human-readable label for the whole profile (the built-in constructors generate one).
pub name: String,
/// Stages in order. NOTE(review): presumably executed back-to-back by the load runner; confirm.
pub stages: Vec<RpsStage>,
}
/// One segment of an [`RpsProfile`]: a target rate held for a duration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpsStage {
/// Stage length in seconds. `RpsProfile::constant` sets this to 0.
/// NOTE(review): presumably 0 means "no time limit" to the runner; confirm.
pub duration_secs: u64,
/// Requests per second to target during this stage.
pub target_rps: f64,
/// Optional human-readable label (the built-in constructors always set one).
pub name: Option<String>,
}
impl RpsProfile {
    /// Length, in seconds, of each stage produced by [`RpsProfile::ramp_up`]. The last stage
    /// may be longer because it absorbs any remainder.
    const RAMP_STAGE_SECS: u64 = 10;

    /// Profile with a single stage held at `rps`.
    ///
    /// The stage's `duration_secs` is 0. NOTE(review): presumably the runner treats that as
    /// "run until stopped"; confirm against the consumer.
    pub fn constant(rps: f64) -> Self {
        Self {
            name: format!("Constant {} RPS", rps),
            stages: vec![RpsStage {
                duration_secs: 0,
                target_rps: rps,
                name: Some("Constant".to_string()),
            }],
        }
    }

    /// Ramps linearly from `start_rps` to `end_rps` in 10-second stages.
    ///
    /// - Produces `duration_secs / 10` stages, with at least one.
    /// - Targets are evenly spaced: the first stage runs at `start_rps` and the last at
    ///   `end_rps`. A single stage uses only `start_rps`.
    /// - Any remainder of `duration_secs` not divisible by 10 is added to the last stage.
    ///   The stages therefore sum to `duration_secs` whenever it is at least 10; shorter
    ///   requests still get one 10-second stage.
    pub fn ramp_up(start_rps: f64, end_rps: f64, duration_secs: u64) -> Self {
        let steps = (duration_secs / Self::RAMP_STAGE_SECS).max(1);
        let last = steps - 1;
        let stages = (0..steps)
            .map(|i| {
                let target_rps = if last == 0 {
                    start_rps
                } else {
                    let t = i as f64 / last as f64;
                    // For finite inputs this form yields exactly `start_rps` at t = 0 and
                    // exactly `end_rps` at t = 1, unlike accumulating `start + i * step`.
                    start_rps * (1.0 - t) + end_rps * t
                };
                let stage_secs = if i == last {
                    Self::RAMP_STAGE_SECS
                        + duration_secs.saturating_sub(steps * Self::RAMP_STAGE_SECS)
                } else {
                    Self::RAMP_STAGE_SECS
                };
                RpsStage {
                    duration_secs: stage_secs,
                    target_rps,
                    name: Some(format!("Ramp {}/{}: {} RPS", i + 1, steps, target_rps)),
                }
            })
            .collect();
        Self {
            name: format!("Ramp up {} -> {} RPS", start_rps, end_rps),
            stages,
        }
    }

    /// Runs the following three stages in order:
    /// 1. Holds `base_rps` for 30 s.
    /// 2. Jumps to `spike_rps` for `spike_duration_secs`.
    /// 3. Returns to `base_rps` for 30 s.
    pub fn spike(base_rps: f64, spike_rps: f64, spike_duration_secs: u64) -> Self {
        Self {
            name: format!("Spike {} -> {} RPS", base_rps, spike_rps),
            stages: vec![
                RpsStage {
                    duration_secs: 30,
                    target_rps: base_rps,
                    name: Some("Base".to_string()),
                },
                RpsStage {
                    duration_secs: spike_duration_secs,
                    target_rps: spike_rps,
                    name: Some("Spike".to_string()),
                },
                RpsStage {
                    duration_secs: 30,
                    target_rps: base_rps,
                    name: Some("Recovery".to_string()),
                },
            ],
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh controller reports its target and counts the requests recorded on it.
    #[tokio::test]
    async fn test_rps_controller() {
        let rc = RpsController::new(10.0);
        assert_eq!(rc.get_target_rps().await, 10.0);

        let mut pending = 5;
        while pending > 0 {
            rc.record_request().await;
            pending -= 1;
        }
        assert!(rc.get_request_count().await > 0);
    }

    /// `constant` yields exactly one stage at the requested rate.
    #[test]
    fn test_rps_profile_constant() {
        let p = RpsProfile::constant(100.0);
        assert_eq!(p.stages.len(), 1);
        let only = &p.stages[0];
        assert_eq!(only.target_rps, 100.0);
    }

    /// `ramp_up` produces at least one stage, and the first one runs at the start rate.
    #[test]
    fn test_rps_profile_ramp_up() {
        let p = RpsProfile::ramp_up(10.0, 100.0, 60);
        let first = p
            .stages
            .first()
            .expect("ramp_up should produce at least one stage");
        assert_eq!(first.target_rps, 10.0);
    }

    /// `spike` goes base -> spike -> base.
    #[test]
    fn test_rps_profile_spike() {
        let p = RpsProfile::spike(50.0, 200.0, 10);
        let targets: Vec<f64> = p.stages.iter().map(|s| s.target_rps).collect();
        assert_eq!(targets, [50.0, 200.0, 50.0]);
    }
}