mcai_benchmark/
configuration.rs1use crate::error::Result;
2use itertools::iproduct;
3use serde::{Deserialize, Serialize};
4use serde_valid::Validate;
5use serde_yaml;
6use std::{collections::HashMap, fs};
7
8#[derive(Clone, Debug, Deserialize, Serialize)]
9#[serde(tag = "version")]
10pub enum Configuration {
11 #[serde(rename = "1")]
12 Version1(Version1),
13}
14
15impl Configuration {
16 pub fn get_worker_docker_image(&self) -> String {
17 let Configuration::Version1(config) = self;
18 config.worker_docker_image.clone()
19 }
20
21 pub fn get_output_folder(&self) -> Option<String> {
22 let Configuration::Version1(config) = self;
23 config.output_folder.clone()
24 }
25
26 pub fn get_envs_for_job(&self, job_name: &str) -> HashMap<String, String> {
27 let Configuration::Version1(config) = self;
28 config
29 .benchmarks
30 .get(job_name)
31 .map(|config| config.envs.clone())
32 .unwrap_or_default()
33 }
34
35 pub fn get_volumes_for_job(&self, job_name: &str) -> Vec<VolumeConfig> {
36 let Configuration::Version1(config) = self;
37 config
38 .benchmarks
39 .get(job_name)
40 .map(|config| config.volumes.clone())
41 .unwrap_or_default()
42 }
43}
44
45#[derive(Clone, Debug, Deserialize, Serialize)]
46pub struct Version1 {
47 pub version: usize,
48 pub worker_docker_image: String,
49 pub output_folder: Option<String>,
50 #[serde(flatten)]
51 pub benchmarks: HashMap<String, BenchmarkConfig>,
52}
53
54impl Version1 {
55 pub fn read_from_file(filepath: &str) -> Result<Version1> {
56 let content = fs::read_to_string(filepath)?;
57 let config = serde_yaml::from_str(&content)?;
58 Ok(config)
59 }
60}
61
/// Iteration count used when a benchmark section omits `iterations`.
const DEFAULT_ITERATIONS: i64 = 10;

/// Serde default provider for the `iterations` field of a benchmark section.
fn default_iterations() -> i64 {
    DEFAULT_ITERATIONS
}
65
66#[derive(Clone, Debug, Deserialize, Serialize, Validate)]
67pub struct HardwareConfig {
68 #[validate(min_items = 1)]
69 pub cpu: Vec<f32>,
70 #[validate(min_items = 1)]
71 pub memory: Vec<i64>,
72}
73
74#[derive(Clone, Debug, Deserialize, Serialize)]
75pub struct VolumeConfig {
76 pub host: String,
77 pub container: String,
78 #[serde(default)]
79 pub readonly: bool,
80}
81
82#[derive(Clone, Debug, Deserialize, Serialize)]
83pub struct BenchmarkConfig {
84 pub source_order: String,
85 #[serde(default = "default_iterations")]
86 pub iterations: i64,
87 pub hardware: HardwareConfig,
88 #[serde(default)]
89 pub envs: HashMap<String, String>,
90 #[serde(default)]
91 pub volumes: Vec<VolumeConfig>,
92}
93
94impl BenchmarkConfig {
95 pub fn get_hardware_configurations(&self) -> Result<Vec<(i64, f32)>> {
96 let configurations: Vec<(i64, f32)> =
97 iproduct!(self.hardware.memory.clone(), self.hardware.cpu.clone())
98 .map(|(memory, cpu)| (memory, cpu))
99 .collect();
100 Ok(configurations)
101 }
102}