use std::sync::Arc;
use tokio::sync::{mpsc, Semaphore};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use crate::config::Config;
use crate::prompts::PromptMode;
use super::EvalResult;
/// Maximum number of trials that `run_parallel_evals` lets run at the same time
/// (the size of its semaphore).
const DEFAULT_PARALLEL_LIMIT: usize = 3;
/// A progress notification emitted by a single trial task onto the event
/// channel handed to `run_parallel_evals`.
#[derive(Clone, Debug)]
pub struct TrialEvent {
    /// Prompt mode the reporting trial is evaluating.
    pub mode: PromptMode,
    /// Trial index within `mode`, counted from 1.
    pub trial_num: u32,
    /// The lifecycle stage being reported.
    pub event: TrialEventKind,
}
/// Lifecycle stages a trial reports through a [`TrialEvent`].
#[derive(Clone, Debug)]
pub enum TrialEventKind {
    /// The trial obtained a concurrency slot and is beginning.
    Started,
    /// Sent just before the trial is handed to the evaluation runner.
    Planning,
    /// Progress update: build iteration `iteration` out of `max_iterations`.
    Building { iteration: u32, max_iterations: u32 },
    /// Testing stage.
    /// NOTE(review): not emitted anywhere in this module — confirm whether a
    /// producer exists elsewhere or whether the variant is still needed.
    Testing,
    /// The trial finished successfully. Boxed so this variant stays
    /// pointer-sized instead of inlining the whole result.
    Complete { result: Box<TrialResult> },
    /// The trial failed; `error` holds the error's display text.
    Failed { error: String },
}
/// Outcome of a trial that ran to completion without error.
#[derive(Clone, Debug)]
pub struct TrialResult {
    /// Prompt mode that was evaluated.
    pub mode: PromptMode,
    /// Trial index within `mode`, counted from 1.
    pub trial_num: u32,
    /// Evaluation output produced by the trial run.
    pub eval_result: EvalResult,
}
#[allow(clippy::too_many_arguments)]
pub async fn run_parallel_evals(
modes: Vec<PromptMode>,
trials_per_mode: u32,
project_name: String,
keep: bool,
no_tui: bool,
config: Config,
event_tx: mpsc::UnboundedSender<TrialEvent>,
cancel_token: CancellationToken,
) -> Vec<TrialResult> {
let semaphore = Arc::new(Semaphore::new(DEFAULT_PARALLEL_LIMIT));
let mut set = JoinSet::new();
for mode in &modes {
for trial_num in 1..=trials_per_mode {
let permit = semaphore.clone();
let tx = event_tx.clone();
let mode = *mode;
let config = config.clone();
let project_name = project_name.clone();
let cancel = cancel_token.clone();
set.spawn(async move {
let _permit = permit.acquire().await.expect("semaphore closed");
let _ = tx.send(TrialEvent {
mode,
trial_num,
event: TrialEventKind::Started,
});
run_single_trial_parallel(
mode,
trial_num,
&project_name,
keep,
no_tui,
&config,
tx.clone(),
cancel,
)
.await
});
}
}
let mut results = Vec::new();
while let Some(result) = set.join_next().await {
match result {
Ok(Ok(trial_result)) => results.push(trial_result),
Ok(Err(e)) => eprintln!("Trial failed: {}", e),
Err(e) => eprintln!("Task panicked: {}", e),
}
}
results
}
#[allow(clippy::too_many_arguments)]
async fn run_single_trial_parallel(
mode: PromptMode,
trial_num: u32,
project_name: &str,
_keep: bool,
no_tui: bool,
config: &Config,
event_tx: mpsc::UnboundedSender<TrialEvent>,
cancel_token: CancellationToken,
) -> color_eyre::Result<TrialResult> {
use super::command::run_single_trial_with_mode;
let _ = event_tx.send(TrialEvent {
mode,
trial_num,
event: TrialEventKind::Planning,
});
let tx_for_callback = event_tx.clone();
let max_iterations = config.max_iterations;
let progress_callback: super::command::ProgressCallback = Arc::new(move |iteration, _total| {
let _ = tx_for_callback.send(TrialEvent {
mode,
trial_num,
event: TrialEventKind::Building {
iteration,
max_iterations,
},
});
});
let result = run_single_trial_with_mode(
project_name,
trial_num,
mode,
no_tui,
config,
cancel_token,
Some(progress_callback),
)
.await;
match &result {
Ok(eval_result) => {
let trial_result = TrialResult {
mode,
trial_num,
eval_result: eval_result.clone(),
};
let _ = event_tx.send(TrialEvent {
mode,
trial_num,
event: TrialEventKind::Complete {
result: Box::new(trial_result.clone()),
},
});
Ok(trial_result)
}
Err(e) => {
let _ = event_tx.send(TrialEvent {
mode,
trial_num,
event: TrialEventKind::Failed {
error: e.to_string(),
},
});
Err(color_eyre::eyre::eyre!("Trial failed: {}", e))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trial_event_creation() {
        let event = TrialEvent {
            mode: PromptMode::Basic,
            trial_num: 1,
            event: TrialEventKind::Started,
        };
        assert_eq!(event.mode, PromptMode::Basic);
        assert_eq!(event.trial_num, 1);
        // `matches!` only evaluates to a bool; it must be wrapped in `assert!`
        // for the variant check to actually be enforced.
        assert!(matches!(event.event, TrialEventKind::Started));
    }

    #[test]
    fn test_trial_event_building() {
        let event = TrialEvent {
            mode: PromptMode::Gsd,
            trial_num: 2,
            event: TrialEventKind::Building {
                iteration: 3,
                max_iterations: 10,
            },
        };
        assert_eq!(event.mode, PromptMode::Gsd);
        assert_eq!(event.trial_num, 2);
        let TrialEventKind::Building {
            iteration,
            max_iterations,
        } = event.event
        else {
            panic!("Expected Building event");
        };
        assert_eq!(iteration, 3);
        assert_eq!(max_iterations, 10);
    }

    #[test]
    fn test_trial_event_failed() {
        let event = TrialEvent {
            mode: PromptMode::GsdTdd,
            trial_num: 5,
            event: TrialEventKind::Failed {
                error: "Test error".to_string(),
            },
        };
        assert_eq!(event.mode, PromptMode::GsdTdd);
        assert_eq!(event.trial_num, 5);
        let TrialEventKind::Failed { error } = event.event else {
            panic!("Expected Failed event");
        };
        assert_eq!(error, "Test error");
    }

    #[test]
    fn test_default_parallel_limit() {
        const { assert!(DEFAULT_PARALLEL_LIMIT >= 1) };
        const { assert!(DEFAULT_PARALLEL_LIMIT <= 10) };
    }
}