use std::fmt::Debug;
use std::sync::Arc;
use crate::callback::JobReturn;
use crate::config::value::ConfigDict;
/// Error produced by a [`Launcher`] when it cannot be set up or cannot launch jobs.
#[derive(Debug, Clone)]
pub struct LauncherError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl LauncherError {
    /// Builds an error from anything convertible into a `String`
    /// (e.g. a `&str` literal or an owned `String`).
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        LauncherError { message }
    }
}

impl std::fmt::Display for LauncherError {
    /// Renders the error as `LauncherError: <message>`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("LauncherError: ")?;
        formatter.write_str(&self.message)
    }
}

impl std::error::Error for LauncherError {}
/// The override strings (e.g. `"db=mysql"`) that describe a single job.
pub type JobOverrides = Vec<String>;
/// A batch of jobs: one [`JobOverrides`] entry per job to launch.
pub type JobOverrideBatch = Vec<JobOverrides>;

/// A pluggable strategy for running a batch of jobs.
///
/// The `Send + Sync + Debug` bounds let implementors be stored behind an
/// `Arc<dyn Launcher>` and printed for diagnostics.
pub trait Launcher: Send + Sync + Debug {
    /// Short identifier for this launcher (e.g. `"basic"`).
    fn name(&self) -> &str;

    /// Prepares the launcher with the given configuration and task name.
    ///
    /// # Errors
    /// Returns a [`LauncherError`] if the launcher cannot be configured.
    fn setup(&mut self, config: &ConfigDict, task_name: &str) -> Result<(), LauncherError>;

    /// Launches the jobs described by `job_overrides`.
    ///
    /// `initial_job_idx` is the index assigned to the first job in the batch;
    /// `BasicLauncher`, for instance, numbers its jobs `initial_job_idx`,
    /// `initial_job_idx + 1`, and so on.
    ///
    /// # Errors
    /// Returns a [`LauncherError`] if the batch cannot be launched.
    fn launch(
        &self,
        job_overrides: &JobOverrideBatch,
        initial_job_idx: usize,
    ) -> Result<Vec<JobReturn>, LauncherError>;
}
#[derive(Debug, Default)]
pub struct BasicLauncher {
config: Option<ConfigDict>,
task_name: String,
}
impl BasicLauncher {
pub fn new() -> Self {
Self::default()
}
}
impl Launcher for BasicLauncher {
fn setup(&mut self, config: &ConfigDict, task_name: &str) -> Result<(), LauncherError> {
self.config = Some(config.clone());
self.task_name = task_name.to_string();
Ok(())
}
fn launch(
&self,
job_overrides: &JobOverrideBatch,
initial_job_idx: usize,
) -> Result<Vec<JobReturn>, LauncherError> {
let mut results = Vec::with_capacity(job_overrides.len());
for (idx, _overrides) in job_overrides.iter().enumerate() {
let job_idx = initial_job_idx + idx;
let job_return = JobReturn {
return_value: None,
working_dir: std::env::current_dir()
.unwrap_or_default()
.to_string_lossy()
.to_string(),
output_dir: format!("outputs/{}", job_idx),
job_name: format!("job_{}", job_idx),
task_name: self.task_name.clone(),
status_code: 0,
};
results.push(job_return);
}
Ok(results)
}
fn name(&self) -> &str {
"basic"
}
}
#[derive(Default)]
pub struct LauncherManager {
launcher: Option<Arc<dyn Launcher>>,
}
impl LauncherManager {
pub fn new() -> Self {
Self::default()
}
pub fn set_launcher(&mut self, launcher: Arc<dyn Launcher>) {
self.launcher = Some(launcher);
}
pub fn set_basic_launcher(&mut self) {
self.launcher = Some(Arc::new(BasicLauncher::new()));
}
pub fn launcher(&self) -> Option<&Arc<dyn Launcher>> {
self.launcher.as_ref()
}
pub fn launch(
&self,
job_overrides: &JobOverrideBatch,
initial_job_idx: usize,
) -> Result<Vec<JobReturn>, LauncherError> {
match &self.launcher {
Some(l) => l.launch(job_overrides, initial_job_idx),
None => Err(LauncherError::new("No launcher configured")),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_launcher_setup() {
        let mut launcher = BasicLauncher::new();
        let config = ConfigDict::new();
        assert!(launcher.setup(&config, "my_task").is_ok());
        assert_eq!(launcher.name(), "basic");
    }

    #[test]
    fn test_basic_launcher_launch() {
        let mut launcher = BasicLauncher::new();
        let config = ConfigDict::new();
        launcher.setup(&config, "test_task").unwrap();
        let overrides = vec![
            vec!["db=mysql".to_string()],
            vec!["db=postgres".to_string()],
        ];
        let results = launcher.launch(&overrides, 0).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].job_name, "job_0");
        assert_eq!(results[1].job_name, "job_1");
        // Every job reports success, carries the task name, and shares the
        // batch's working directory.
        assert!(results
            .iter()
            .all(|r| r.task_name == "test_task" && r.status_code == 0 && r.return_value.is_none()));
        assert_eq!(results[0].working_dir, results[1].working_dir);
    }

    #[test]
    fn test_basic_launcher_respects_initial_job_idx() {
        let mut launcher = BasicLauncher::new();
        launcher.setup(&ConfigDict::new(), "offset_task").unwrap();
        let overrides = vec![vec!["a=1".to_string()], vec!["a=2".to_string()]];
        let results = launcher.launch(&overrides, 5).unwrap();
        let names: Vec<&str> = results.iter().map(|r| r.job_name.as_str()).collect();
        assert_eq!(names, ["job_5", "job_6"]);
        assert_eq!(results[0].output_dir, "outputs/5");
        assert_eq!(results[1].output_dir, "outputs/6");
    }

    #[test]
    fn test_basic_launcher_empty_batch() {
        let launcher = BasicLauncher::new();
        let overrides: JobOverrideBatch = Vec::new();
        let results = launcher.launch(&overrides, 3).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_launcher_manager() {
        let mut manager = LauncherManager::new();
        manager.set_basic_launcher();
        let overrides = vec![vec!["key=value".to_string()]];
        let results = manager.launch(&overrides, 0).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(manager.launcher().map(|l| l.name()), Some("basic"));
    }

    #[test]
    fn test_launcher_manager_custom_launcher() {
        let mut manager = LauncherManager::new();
        assert!(manager.launcher().is_none());
        manager.set_launcher(Arc::new(BasicLauncher::new()));
        let overrides = vec![vec!["key=value".to_string()]];
        let results = manager.launch(&overrides, 2).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].job_name, "job_2");
    }

    #[test]
    fn test_launcher_manager_no_launcher() {
        let manager = LauncherManager::new();
        let overrides = vec![vec!["key=value".to_string()]];
        assert!(manager.launch(&overrides, 0).is_err());
        let err = manager
            .launch(&overrides, 0)
            .err()
            .expect("launching without a launcher must fail");
        assert_eq!(err.message, "No launcher configured");
        assert_eq!(err.to_string(), "LauncherError: No launcher configured");
    }
}