use std::fmt::Debug;
use std::sync::Arc;
use crate::callback::JobReturn;
use crate::config::value::ConfigDict;
use crate::launcher::{JobOverrideBatch, Launcher, LauncherError};
#[derive(Debug, Clone)]
pub struct SweeperError {
pub message: String,
}
impl SweeperError {
pub fn new(message: impl Into<String>) -> Self {
Self {
message: message.into(),
}
}
}
impl std::fmt::Display for SweeperError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "SweeperError: {}", self.message)
}
}
impl std::error::Error for SweeperError {}
impl From<LauncherError> for SweeperError {
fn from(err: LauncherError) -> Self {
Self::new(err.message)
}
}
pub trait Sweeper: Send + Sync + Debug {
fn setup(
&mut self,
config: &ConfigDict,
launcher: Arc<dyn Launcher>,
) -> Result<(), SweeperError>;
fn sweep(&self, arguments: &[String]) -> Result<Vec<JobReturn>, SweeperError>;
fn name(&self) -> &str;
}
#[derive(Debug)]
pub struct BasicSweeper {
config: Option<ConfigDict>,
launcher: Option<Arc<dyn Launcher>>,
max_batch_size: Option<usize>,
}
impl Default for BasicSweeper {
fn default() -> Self {
Self::new(None)
}
}
impl BasicSweeper {
pub fn new(max_batch_size: Option<usize>) -> Self {
Self {
config: None,
launcher: None,
max_batch_size,
}
}
fn parse_sweep_arguments(&self, arguments: &[String]) -> Vec<Vec<(String, Vec<String>)>> {
let mut param_lists: Vec<(String, Vec<String>)> = Vec::new();
for arg in arguments {
if let Some((key, value)) = arg.split_once('=') {
let values: Vec<String> = if value.contains(',') && !value.contains('[') {
value.split(',').map(|s| s.trim().to_string()).collect()
} else {
vec![value.to_string()]
};
param_lists.push((key.to_string(), values));
}
}
self.cartesian_product(¶m_lists)
}
fn cartesian_product(
&self,
params: &[(String, Vec<String>)],
) -> Vec<Vec<(String, Vec<String>)>> {
if params.is_empty() {
return vec![vec![]];
}
let mut combinations: Vec<Vec<String>> = vec![vec![]];
for (key, values) in params {
let mut new_combinations = Vec::new();
for combo in &combinations {
for value in values {
let mut new_combo = combo.clone();
new_combo.push(format!("{}={}", key, value));
new_combinations.push(new_combo);
}
}
combinations = new_combinations;
}
combinations
.into_iter()
.map(|c| c.into_iter().map(|s| (s.clone(), vec![s])).collect())
.collect()
}
fn split_into_batches(&self, combinations: Vec<Vec<String>>) -> Vec<Vec<Vec<String>>> {
match self.max_batch_size {
None => vec![combinations],
Some(size) if size == 0 => vec![combinations],
Some(size) => combinations.chunks(size).map(|c| c.to_vec()).collect(),
}
}
}
impl Sweeper for BasicSweeper {
fn setup(
&mut self,
config: &ConfigDict,
launcher: Arc<dyn Launcher>,
) -> Result<(), SweeperError> {
self.config = Some(config.clone());
self.launcher = Some(launcher);
Ok(())
}
fn sweep(&self, arguments: &[String]) -> Result<Vec<JobReturn>, SweeperError> {
let launcher = self
.launcher
.as_ref()
.ok_or_else(|| SweeperError::new("Sweeper not set up - no launcher"))?;
let mut all_combinations: Vec<Vec<String>> = Vec::new();
let mut param_values: Vec<(String, Vec<String>)> = Vec::new();
for arg in arguments {
if let Some((key, value)) = arg.split_once('=') {
let values: Vec<String> = if value.contains(',') && !value.starts_with('[') {
value.split(',').map(|s| s.trim().to_string()).collect()
} else {
vec![value.to_string()]
};
param_values.push((key.to_string(), values));
}
}
if param_values.is_empty() {
all_combinations.push(vec![]);
} else {
let mut combos: Vec<Vec<String>> = vec![vec![]];
for (key, values) in ¶m_values {
let mut new_combos = Vec::new();
for combo in &combos {
for value in values {
let mut new_combo = combo.clone();
new_combo.push(format!("{}={}", key, value));
new_combos.push(new_combo);
}
}
combos = new_combos;
}
all_combinations = combos;
}
let batches = self.split_into_batches(all_combinations);
let mut all_results = Vec::new();
let mut job_idx = 0;
for batch in batches {
let results = launcher.launch(&batch, job_idx)?;
job_idx += results.len();
all_results.extend(results);
}
Ok(all_results)
}
fn name(&self) -> &str {
"basic"
}
}
#[derive(Default)]
pub struct SweeperManager {
sweeper: Option<Arc<dyn Sweeper>>,
}
impl SweeperManager {
pub fn new() -> Self {
Self::default()
}
pub fn set_sweeper(&mut self, sweeper: Arc<dyn Sweeper>) {
self.sweeper = Some(sweeper);
}
pub fn set_basic_sweeper(&mut self, max_batch_size: Option<usize>) {
self.sweeper = Some(Arc::new(BasicSweeper::new(max_batch_size)));
}
pub fn sweeper(&self) -> Option<&Arc<dyn Sweeper>> {
self.sweeper.as_ref()
}
pub fn sweep(&self, arguments: &[String]) -> Result<Vec<JobReturn>, SweeperError> {
match &self.sweeper {
Some(s) => s.sweep(arguments),
None => Err(SweeperError::new("No sweeper configured")),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::launcher::BasicLauncher;
#[test]
fn test_basic_sweeper_setup() {
let mut sweeper = BasicSweeper::new(None);
let config = ConfigDict::new();
let launcher = Arc::new(BasicLauncher::new());
assert!(sweeper.setup(&config, launcher.clone()).is_ok());
assert_eq!(sweeper.name(), "basic");
}
#[test]
fn test_basic_sweeper_single_value() {
let mut sweeper = BasicSweeper::new(None);
let config = ConfigDict::new();
let mut launcher = BasicLauncher::new();
launcher.setup(&config, "test").unwrap();
let launcher = Arc::new(launcher);
sweeper.setup(&config, launcher.clone()).unwrap();
let args = vec!["key=value".to_string()];
let results = sweeper.sweep(&args).unwrap();
assert_eq!(results.len(), 1);
}
#[test]
fn test_basic_sweeper_cartesian_product() {
let mut sweeper = BasicSweeper::new(None);
let config = ConfigDict::new();
let mut launcher = BasicLauncher::new();
launcher.setup(&config, "test").unwrap();
let launcher = Arc::new(launcher);
sweeper.setup(&config, launcher.clone()).unwrap();
let args = vec!["a=1,2".to_string(), "b=x,y".to_string()];
let results = sweeper.sweep(&args).unwrap();
assert_eq!(results.len(), 4);
}
#[test]
fn test_basic_sweeper_with_batch_size() {
let mut sweeper = BasicSweeper::new(Some(2));
let config = ConfigDict::new();
let mut launcher = BasicLauncher::new();
launcher.setup(&config, "test").unwrap();
let launcher = Arc::new(launcher);
sweeper.setup(&config, launcher.clone()).unwrap();
let args = vec!["a=1,2".to_string(), "b=x,y".to_string()];
let results = sweeper.sweep(&args).unwrap();
assert_eq!(results.len(), 4);
}
#[test]
fn test_sweeper_manager() {
let mut manager = SweeperManager::new();
manager.set_basic_sweeper(None);
assert!(manager.sweeper().is_some());
}
#[test]
fn test_sweeper_manager_no_sweeper() {
let manager = SweeperManager::new();
assert!(manager.sweep(&[]).is_err());
}
}