multi_skill/
experiment_config.rs1use crate::data_processing::{get_dataset_by_name, Contest, Dataset};
2use crate::systems::{
3 simulate_contest, CodeforcesSys, EloMMR, EloMMRVariant, Glicko, RatingSystem, TopcoderSys,
4 TrueSkillSPb, BAR,
5};
6
7use crate::metrics::compute_metrics_custom;
8use serde::Deserialize;
9use std::collections::HashMap;
10use std::path::Path;
11
12#[derive(Deserialize, Debug)]
13pub struct SystemParams {
14 pub method: String,
15 pub params: Vec<f64>,
16}
17
/// Serde default for `ExperimentConfig::max_contests`: the largest `usize`,
/// i.e. no practical limit on the number of contests.
fn usize_max() -> usize {
    usize::MAX
}

/// Reports whether `num` still holds the "unbounded" default from [`usize_max`].
// Only referenced from a `skip_serializing_if` attribute on `ExperimentConfig`,
// which derives `Deserialize` but not `Serialize`, so the compiler sees no caller.
#[allow(dead_code)]
fn is_usize_max(num: &usize) -> bool {
    usize_max() == *num
}
26
27#[derive(Deserialize, Debug)]
28pub struct ExperimentConfig {
29 #[serde(default = "usize_max", skip_serializing_if = "is_usize_max")]
30 pub max_contests: usize,
31 pub mu_noob: f64,
32 pub sig_noob: f64,
33 pub system: SystemParams,
34 pub contest_source: String,
35}
36
37pub struct Experiment {
38 pub max_contests: usize,
39 pub mu_noob: f64,
40 pub sig_noob: f64,
41
42 pub system: Box<dyn RatingSystem + Send>,
44 pub dataset: Box<dyn Dataset<Item = Contest> + Send>,
45}
46
47impl Experiment {
48 pub fn from_file(source: impl AsRef<Path>) -> Self {
49 let params_json = std::fs::read_to_string(source).expect("Failed to read parameters file");
50 let params = serde_json::from_str(¶ms_json).expect("Failed to parse params as JSON");
51 Self::from_config(params)
52 }
53
54 pub fn from_config(params: ExperimentConfig) -> Self {
55 println!("Loading rating system:\n{:#?}", params);
56 let dataset = get_dataset_by_name(¶ms.contest_source).unwrap();
57
58 let system: Box<dyn RatingSystem + Send> = match params.system.method.as_str() {
59 "glicko" => Box::new(Glicko {
60 beta: params.system.params[0],
61 sig_drift: params.system.params[1],
62 }),
63 "bar" => Box::new(BAR {
64 beta: params.system.params[0],
65 sig_drift: params.system.params[1],
66 kappa: 1e-4,
67 }),
68 "codeforces" => Box::new(CodeforcesSys {
69 beta: params.system.params[0],
70 weight_multiplier: params.system.params[1],
71 }),
72 "topcoder" => Box::new(TopcoderSys {
73 weight_multiplier: params.system.params[0],
74 }),
75 "trueskill" => Box::new(TrueSkillSPb {
76 eps: params.system.params[0],
77 beta: params.system.params[1],
78 convergence_eps: params.system.params[2],
79 sig_drift: params.system.params[3],
80 }),
81 "mmx" => Box::new(EloMMR {
82 beta: params.system.params[0],
83 sig_limit: params.system.params[1],
84 drift_per_sec: 0.,
85 split_ties: params.system.params[2] > 0.,
86 subsample_size: params.system.params[3] as usize,
87 subsample_bucket: params.system.params[4],
88 variant: EloMMRVariant::Gaussian,
89 }),
90 "mmr" => Box::new(EloMMR {
91 beta: params.system.params[0],
92 sig_limit: params.system.params[1],
93 drift_per_sec: 0.,
94 split_ties: params.system.params[2] > 0.,
95 subsample_size: params.system.params[3] as usize,
96 subsample_bucket: params.system.params[4],
97 variant: EloMMRVariant::Logistic(params.system.params[5]),
98 }),
99 x => panic!("'{}' is not a valid system name!", x),
100 };
101
102 Self {
103 max_contests: params.max_contests,
104 mu_noob: params.mu_noob,
105 sig_noob: params.sig_noob,
106 system,
107 dataset,
108 }
109 }
110
111 pub fn eval(self, mut num_rounds_postpone_eval: usize, tag: &str) {
112 let mut players = HashMap::new();
113 let mut avg_perf = compute_metrics_custom(&mut players, &[]);
114
115 let now = std::time::Instant::now();
117 for (index, contest) in self.dataset.iter().enumerate().take(self.max_contests) {
118 if num_rounds_postpone_eval > 0 {
121 num_rounds_postpone_eval -= 1;
122 } else {
123 avg_perf += compute_metrics_custom(&mut players, &contest.standings);
124 }
125
126 simulate_contest(
128 &mut players,
129 &contest,
130 &*self.system,
131 self.mu_noob,
132 self.sig_noob,
133 index,
134 );
135 }
136 let secs_elapsed = now.elapsed().as_nanos() as f64 * 1e-9;
137
138 let horizontal = "============================================================";
139 let output = format!(
140 "{} {:?}: {}, {}s\n{}",
141 tag, self.system, avg_perf, secs_elapsed, horizontal
142 );
143 println!("{}", output);
144 }
145}