inference_lab/config/
workload.rs1use serde::Deserialize;
2
3#[derive(Debug, Clone, Deserialize)]
4pub struct WorkloadConfig {
5 pub arrival_pattern: String,
7
8 pub arrival_rate: f64,
10
11 pub input_len_dist: LengthDistribution,
13
14 pub output_len_dist: LengthDistribution,
16
17 pub num_requests: Option<usize>,
19
20 #[serde(default)]
23 pub num_concurrent_users: Option<usize>,
24
25 pub duration_secs: Option<f64>,
27
28 pub seed: u64,
30}
31
32#[derive(Debug, Clone, Deserialize)]
33#[serde(tag = "type")]
34pub enum LengthDistribution {
35 #[serde(rename = "fixed")]
36 Fixed { value: u32 },
37
38 #[serde(rename = "uniform")]
39 Uniform { min: u32, max: u32 },
40
41 #[serde(rename = "normal")]
42 Normal { mean: f64, std_dev: f64 },
43
44 #[serde(rename = "lognormal")]
45 LogNormal { mean: f64, std_dev: f64 },
46}
47
48impl LengthDistribution {
49 pub fn sample<R: rand::Rng>(&self, rng: &mut R) -> u32 {
51 use rand_distr::Distribution;
52
53 match self {
54 LengthDistribution::Fixed { value } => *value,
55 LengthDistribution::Uniform { min, max } => rng.gen_range(*min..=*max),
56 LengthDistribution::Normal { mean, std_dev } => {
57 let normal = rand_distr::Normal::new(*mean, *std_dev).unwrap();
58 normal.sample(rng).max(1.0) as u32
59 }
60 LengthDistribution::LogNormal { mean, std_dev } => {
61 let lognormal = rand_distr::LogNormal::new(*mean, *std_dev).unwrap();
62 lognormal.sample(rng).max(1.0) as u32
63 }
64 }
65 }
66}