Skip to main content

inference_lab/config/
workload.rs

1use serde::Deserialize;
2
/// Serde fallback for `WorkloadConfig::arrival_rate`: one request per second.
fn default_arrival_rate() -> f64 {
    1.0_f64
}
6
7#[derive(Debug, Clone, Deserialize)]
8pub struct WorkloadConfig {
9    /// Path to dataset file (JSONL in OpenAI batch API format)
10    /// If provided, dataset mode is used instead of synthetic workload
11    #[serde(default)]
12    pub dataset_path: Option<String>,
13
14    /// Arrival pattern: "poisson", "uniform", "burst", "fixed_rate", "closed_loop", "batched"
15    pub arrival_pattern: String,
16
17    /// Mean arrival rate (requests per second)
18    /// Not used for "closed_loop" or "batched" patterns
19    #[serde(default = "default_arrival_rate")]
20    pub arrival_rate: f64,
21
22    /// Input sequence length distribution (ignored in dataset mode)
23    pub input_len_dist: LengthDistribution,
24
25    /// Output sequence length distribution (ignored in dataset mode)
26    pub output_len_dist: LengthDistribution,
27
28    /// Total number of requests to simulate (None = run until duration)
29    pub num_requests: Option<usize>,
30
31    /// Number of concurrent users for closed-loop pattern
32    /// Each user immediately sends a new request when their previous one completes
33    #[serde(default)]
34    pub num_concurrent_users: Option<usize>,
35
36    /// Simulation duration in seconds (None = run until num_requests)
37    pub duration_secs: Option<f64>,
38
39    /// Random seed for reproducibility
40    pub seed: u64,
41}
42
43#[derive(Debug, Clone, Deserialize)]
44#[serde(tag = "type")]
45pub enum LengthDistribution {
46    #[serde(rename = "fixed")]
47    Fixed { value: u32 },
48
49    #[serde(rename = "uniform")]
50    Uniform { min: u32, max: u32 },
51
52    #[serde(rename = "normal")]
53    Normal { mean: f64, std_dev: f64 },
54
55    #[serde(rename = "lognormal")]
56    LogNormal { mean: f64, std_dev: f64 },
57}
58
59impl LengthDistribution {
60    /// Sample a value from this distribution
61    pub fn sample<R: rand::Rng>(&self, rng: &mut R) -> u32 {
62        use rand_distr::Distribution;
63
64        match self {
65            LengthDistribution::Fixed { value } => *value,
66            LengthDistribution::Uniform { min, max } => rng.gen_range(*min..=*max),
67            LengthDistribution::Normal { mean, std_dev } => {
68                let normal = rand_distr::Normal::new(*mean, *std_dev).unwrap();
69                normal.sample(rng).max(1.0) as u32
70            }
71            LengthDistribution::LogNormal { mean, std_dev } => {
72                let lognormal = rand_distr::LogNormal::new(*mean, *std_dev).unwrap();
73                lognormal.sample(rng).max(1.0) as u32
74            }
75        }
76    }
77}