inference_lab/config/
workload.rs

1use serde::Deserialize;
2
3#[derive(Debug, Clone, Deserialize)]
4pub struct WorkloadConfig {
5    /// Arrival pattern: "poisson", "uniform", "burst", "fixed_rate", "closed_loop", "batched"
6    pub arrival_pattern: String,
7
8    /// Mean arrival rate (requests per second)
9    pub arrival_rate: f64,
10
11    /// Input sequence length distribution
12    pub input_len_dist: LengthDistribution,
13
14    /// Output sequence length distribution
15    pub output_len_dist: LengthDistribution,
16
17    /// Total number of requests to simulate (None = run until duration)
18    pub num_requests: Option<usize>,
19
20    /// Number of concurrent users for closed-loop pattern
21    /// Each user immediately sends a new request when their previous one completes
22    #[serde(default)]
23    pub num_concurrent_users: Option<usize>,
24
25    /// Simulation duration in seconds (None = run until num_requests)
26    pub duration_secs: Option<f64>,
27
28    /// Random seed for reproducibility
29    pub seed: u64,
30}
31
32#[derive(Debug, Clone, Deserialize)]
33#[serde(tag = "type")]
34pub enum LengthDistribution {
35    #[serde(rename = "fixed")]
36    Fixed { value: u32 },
37
38    #[serde(rename = "uniform")]
39    Uniform { min: u32, max: u32 },
40
41    #[serde(rename = "normal")]
42    Normal { mean: f64, std_dev: f64 },
43
44    #[serde(rename = "lognormal")]
45    LogNormal { mean: f64, std_dev: f64 },
46}
47
48impl LengthDistribution {
49    /// Sample a value from this distribution
50    pub fn sample<R: rand::Rng>(&self, rng: &mut R) -> u32 {
51        use rand_distr::Distribution;
52
53        match self {
54            LengthDistribution::Fixed { value } => *value,
55            LengthDistribution::Uniform { min, max } => {
56                rng.gen_range(*min..=*max)
57            }
58            LengthDistribution::Normal { mean, std_dev } => {
59                let normal = rand_distr::Normal::new(*mean, *std_dev).unwrap();
60                normal.sample(rng).max(1.0) as u32
61            }
62            LengthDistribution::LogNormal { mean, std_dev } => {
63                let lognormal = rand_distr::LogNormal::new(*mean, *std_dev).unwrap();
64                lognormal.sample(rng).max(1.0) as u32
65            }
66        }
67    }
68}