Skip to main content

aster_bench/runners/
model_runner.rs

1use crate::bench_config::{BenchEval, BenchModel, BenchRunConfig};
2use crate::eval_suites::EvaluationSuite;
3use crate::reporting::{BenchmarkResults, SuiteResult};
4use crate::runners::eval_runner::EvalRunner;
5use crate::utilities::{await_process_exits, parallel_bench_cmd};
6use anyhow::{Context, Result};
7use dotenvy::from_path_iter;
8use std::collections::HashMap;
9use std::fs::read_to_string;
10use std::path::PathBuf;
11use std::process::Child;
12use std::thread;
13use tracing;
14
/// Drives benchmark runs for a model based on a parsed `BenchRunConfig`.
///
/// The runner is `Clone` so that a copy can be moved into each worker thread
/// spawned by `run`.
#[derive(Clone)]
pub struct ModelRunner {
    /// Run configuration (models, evals, repeat count, env file, output filenames).
    config: BenchRunConfig,
}
19
20impl ModelRunner {
21    pub fn from(config: String) -> Result<ModelRunner> {
22        let config =
23            BenchRunConfig::from_string(config).context("Failed to parse configuration")?;
24        Ok(ModelRunner { config })
25    }
26
27    pub fn run(&self) -> Result<()> {
28        let model = self
29            .config
30            .models
31            .first()
32            .context("No model specified in config")?;
33        let suites = self.collect_evals_for_run();
34
35        let mut handles = vec![];
36
37        for i in 0..self.config.repeat.unwrap_or(1) {
38            let self_copy = self.clone();
39            let model_clone = model.clone();
40            let suites_clone = suites.clone();
41            let handle = thread::spawn(move || -> Result<()> {
42                self_copy.run_benchmark(&model_clone, suites_clone, i.to_string())
43            });
44            handles.push(handle);
45        }
46        await_process_exits(&mut Vec::new(), handles);
47
48        let mut all_runs_results: Vec<BenchmarkResults> = Vec::new();
49        for i in 0..self.config.repeat.unwrap_or(1) {
50            match self.collect_run_results(model.clone(), suites.clone(), i.to_string()) {
51                Ok(run_results) => all_runs_results.push(run_results),
52                Err(e) => {
53                    tracing::error!("Failed to collect results for run {}: {}", i, e)
54                }
55            }
56        }
57
58        Ok(())
59    }
60
61    fn run_benchmark(
62        &self,
63        model: &BenchModel,
64        suites: HashMap<String, Vec<BenchEval>>,
65        run_id: String,
66    ) -> Result<()> {
67        let mut results_handles = HashMap::<String, Vec<Child>>::new();
68
69        // Load environment variables from file if specified
70        let mut envs = self.toolshim_envs();
71        if let Some(env_file) = &self.config.env_file {
72            let env_vars = ModelRunner::load_env_file(env_file).context(format!(
73                "Failed to load environment file: {}",
74                env_file.display()
75            ))?;
76            envs.extend(env_vars);
77        }
78        envs.push(("ASTER_MODEL".to_string(), model.clone().name));
79        envs.push(("ASTER_PROVIDER".to_string(), model.clone().provider));
80
81        // Only run in parallel if the model is parallel_safe
82        let run_parallel = model.parallel_safe;
83
84        for (suite, evals) in suites.iter() {
85            results_handles.insert((*suite).clone(), Vec::new());
86
87            // Group evaluations by parallel_safe
88            let mut parallel_evals = Vec::new();
89            let mut sequential_evals = Vec::new();
90
91            for eval in evals {
92                if eval.parallel_safe && run_parallel {
93                    parallel_evals.push(eval);
94                } else {
95                    sequential_evals.push(eval);
96                }
97            }
98
99            // Run parallel-safe evaluations in parallel
100            if !parallel_evals.is_empty() {
101                for eval_selector in &parallel_evals {
102                    let mut config_copy = self.config.clone();
103                    config_copy.run_id = Some(run_id.clone());
104                    config_copy.evals = vec![(*eval_selector).clone()];
105                    let cfg = config_copy
106                        .to_string()
107                        .context("Failed to serialize configuration")?;
108
109                    let handle = parallel_bench_cmd("exec-eval".to_string(), cfg, envs.clone());
110                    results_handles.get_mut(suite).unwrap().push(handle);
111                }
112            }
113
114            // Run non-parallel-safe evaluations sequentially
115            for eval_selector in &sequential_evals {
116                let mut config_copy = self.config.clone();
117                config_copy.run_id = Some(run_id.clone());
118                config_copy.evals = vec![(*eval_selector).clone()];
119                let cfg = config_copy
120                    .to_string()
121                    .context("Failed to serialize configuration")?;
122
123                let handle = parallel_bench_cmd("exec-eval".to_string(), cfg, envs.clone());
124
125                // Wait for this process to complete before starting the next one
126                let mut child_procs = vec![handle];
127                await_process_exits(&mut child_procs, Vec::new());
128            }
129        }
130
131        // Wait for any remaining parallel processes to complete
132        for (_, child_procs) in results_handles.iter_mut() {
133            await_process_exits(child_procs, Vec::new());
134        }
135
136        Ok(())
137    }
138
139    fn collect_run_results(
140        &self,
141        model: BenchModel,
142        suites: HashMap<String, Vec<BenchEval>>,
143        run_id: String,
144    ) -> Result<BenchmarkResults> {
145        let mut results = BenchmarkResults::new(model.provider.clone());
146
147        let mut summary_path: Option<PathBuf> = None;
148
149        for (suite, evals) in suites.iter() {
150            let mut suite_result = SuiteResult::new(suite.clone());
151            for eval_selector in evals {
152                let mut eval_path =
153                    EvalRunner::path_for_eval(&model, eval_selector, run_id.clone());
154                eval_path.push(self.config.eval_result_filename.clone());
155
156                let content = read_to_string(&eval_path).with_context(|| {
157                    format!(
158                        "Failed to read evaluation results from {}",
159                        eval_path.display()
160                    )
161                })?;
162
163                let eval_result = serde_json::from_str(&content)
164                    .context("Failed to parse evaluation results JSON")?;
165
166                suite_result.add_evaluation(eval_result);
167
168                // use current eval to determine where the summary should be written
169                if summary_path.is_none() {
170                    let mut result = PathBuf::new();
171                    let mut iter = eval_path.components();
172                    if let Some(first) = iter.next() {
173                        result.push(first);
174                        if let Some(second) = iter.next() {
175                            result.push(second);
176                        }
177                    }
178                    summary_path = Some(result);
179                }
180            }
181            results.add_suite(suite_result);
182        }
183
184        if let Some(path) = summary_path {
185            let mut run_summary = PathBuf::new();
186            run_summary.push(path);
187            run_summary.push(&self.config.run_summary_filename);
188
189            let output_str = serde_json::to_string_pretty(&results)
190                .context("Failed to serialize benchmark results to JSON")?;
191
192            std::fs::write(&run_summary, &output_str).with_context(|| {
193                format!(
194                    "Failed to write results summary to {}",
195                    run_summary.display()
196                )
197            })?;
198        }
199
200        Ok(results)
201    }
202
203    fn collect_evals_for_run(&self) -> HashMap<String, Vec<BenchEval>> {
204        // convert suites map {suite_name => [eval_selector_str] to map suite_name => [BenchEval]
205        let mut result: HashMap<String, Vec<BenchEval>> = HashMap::new();
206        for eval in self.config.evals.iter() {
207            let selected_suites = EvaluationSuite::select(vec![eval.selector.clone()]);
208            for (suite, evals) in selected_suites {
209                let entry: &mut Vec<BenchEval> = result.entry(suite).or_default();
210                entry.reserve(evals.len());
211                for suite_eval in evals {
212                    let mut updated_eval = eval.clone();
213                    updated_eval.selector = suite_eval.to_string();
214                    entry.push(updated_eval);
215                }
216            }
217        }
218        result
219    }
220
221    fn toolshim_envs(&self) -> Vec<(String, String)> {
222        // read tool-shim preference from config, set respective env vars accordingly
223        let mut shim_envs: Vec<(String, String)> = Vec::new();
224        if let Some(model) = self.config.models.first() {
225            if let Some(shim_opt) = &model.tool_shim {
226                if shim_opt.use_tool_shim {
227                    shim_envs.push(("ASTER_TOOLSHIM".to_string(), "true".to_string()));
228                    if let Some(shim_model) = &shim_opt.tool_shim_model {
229                        shim_envs.push((
230                            "ASTER_TOOLSHIM_OLLAMA_MODEL".to_string(),
231                            shim_model.clone(),
232                        ));
233                    }
234                }
235            }
236        }
237        shim_envs
238    }
239
240    fn load_env_file(path: &PathBuf) -> Result<Vec<(String, String)>> {
241        let iter =
242            from_path_iter(path).context("Failed to read environment variables from file")?;
243        let env_vars = iter
244            .map(|item| item.context("Failed to parse environment variable"))
245            .collect::<Result<_, _>>()?;
246        Ok(env_vars)
247    }
248}