// inference_lab/config/workload.rs

use serde::Deserialize;

/// Serde fallback for `WorkloadConfig::arrival_rate`, used when the field is
/// omitted from the config.
fn default_arrival_rate() -> f64 {
    1.0_f64
}

7#[derive(Debug, Clone, Deserialize)]
8pub struct WorkloadConfig {
9 #[serde(default)]
12 pub dataset_path: Option<String>,
13
14 pub arrival_pattern: String,
16
17 #[serde(default = "default_arrival_rate")]
20 pub arrival_rate: f64,
21
22 pub input_len_dist: LengthDistribution,
24
25 pub output_len_dist: LengthDistribution,
27
28 pub num_requests: Option<usize>,
30
31 #[serde(default)]
34 pub num_concurrent_users: Option<usize>,
35
36 pub duration_secs: Option<f64>,
38
39 pub seed: u64,
41}
42
43#[derive(Debug, Clone, Deserialize)]
44#[serde(tag = "type")]
45pub enum LengthDistribution {
46 #[serde(rename = "fixed")]
47 Fixed { value: u32 },
48
49 #[serde(rename = "uniform")]
50 Uniform { min: u32, max: u32 },
51
52 #[serde(rename = "normal")]
53 Normal { mean: f64, std_dev: f64 },
54
55 #[serde(rename = "lognormal")]
56 LogNormal { mean: f64, std_dev: f64 },
57}
58
59impl LengthDistribution {
60 pub fn sample<R: rand::Rng>(&self, rng: &mut R) -> u32 {
62 use rand_distr::Distribution;
63
64 match self {
65 LengthDistribution::Fixed { value } => *value,
66 LengthDistribution::Uniform { min, max } => rng.gen_range(*min..=*max),
67 LengthDistribution::Normal { mean, std_dev } => {
68 let normal = rand_distr::Normal::new(*mean, *std_dev).unwrap();
69 normal.sample(rng).max(1.0) as u32
70 }
71 LengthDistribution::LogNormal { mean, std_dev } => {
72 let lognormal = rand_distr::LogNormal::new(*mean, *std_dev).unwrap();
73 lognormal.sample(rng).max(1.0) as u32
74 }
75 }
76 }
77}