use std::{
future::Future,
pin::Pin,
sync::{Arc, Mutex},
};
use hen::{
error::{HenError, HenErrorKind, HenResult},
request,
};
/// Outcome of a plan run as returned by `execute_plan_with_interrupt`.
#[derive(Debug, Clone)]
pub(crate) struct ExecutionState {
/// Records of the requests that completed.
pub(crate) records: Vec<request::ExecutionRecord>,
/// Failures reported for the run; empty when execution returned `Ok`.
pub(crate) failures: Vec<request::RequestFailure>,
/// Snapshot of the trace collected while the plan was running.
pub(crate) trace: Vec<request::ExecutionTraceEntry>,
/// `true` when execution returned an error or was interrupted.
pub(crate) execution_failed: bool,
/// The signal that interrupted the run, or `None` if it ran to completion.
pub(crate) interrupted: Option<request::InterruptSignal>,
}
/// Shared, mutex-guarded accumulators filled by the observer built in
/// `tracking_observer`. Cloning copies the `Arc` handles, so every clone
/// refers to the same underlying collections.
#[derive(Debug, Default, Clone)]
struct CollectedExecution {
/// One entry per `RequestCompleted` event seen by the observer.
records: Arc<Mutex<Vec<request::ExecutionRecord>>>,
/// One entry per `RequestFailed` event seen by the observer.
failures: Arc<Mutex<Vec<request::RequestFailure>>>,
/// Trace collector that is handed every observed event (and interrupts).
trace: Arc<Mutex<request::ExecutionTraceCollector>>,
}
impl CollectedExecution {
    /// Copies out everything collected so far as `(records, failures, trace)`.
    ///
    /// Records are ordered by their `index`. Failures are ordered by
    /// `index()`, with failures that have no index placed last; the sort is
    /// stable, so ties keep their arrival order.
    ///
    /// # Panics
    ///
    /// Panics if any of the underlying mutexes is poisoned.
    fn snapshot(
        &self,
    ) -> (
        Vec<request::ExecutionRecord>,
        Vec<request::RequestFailure>,
        Vec<request::ExecutionTraceEntry>,
    ) {
        // Each lock guard is a temporary that is released at the end of its
        // own `let` statement, so no two locks are ever held at once.
        let records = {
            let mut copied = self.records.lock().unwrap().clone();
            copied.sort_by_key(|entry| entry.index);
            copied
        };
        let failures = {
            let mut copied = self.failures.lock().unwrap().clone();
            copied.sort_by_key(|entry| entry.index().unwrap_or(usize::MAX));
            copied
        };
        let trace = self.trace.lock().unwrap().snapshot();

        (records, failures, trace)
    }
}
/// Boxed, sendable future that resolves to the interrupt signal received, or
/// to an error if waiting for the signal failed.
type InterruptFuture = Pin<Box<dyn Future<Output = HenResult<request::InterruptSignal>> + Send>>;
fn tracking_observer(
downstream: Option<request::ExecutionObserver>,
) -> (request::ExecutionObserver, CollectedExecution) {
let collected = CollectedExecution::default();
let records = Arc::clone(&collected.records);
let failures = Arc::clone(&collected.failures);
let trace = Arc::clone(&collected.trace);
let observer: request::ExecutionObserver = Arc::new(move |event| {
trace.lock().unwrap().record_event(&event);
match &event {
request::ExecutionEvent::RequestCompleted { record } => {
records.lock().unwrap().push(record.clone());
}
request::ExecutionEvent::RequestFailed { failure } => {
failures.lock().unwrap().push(failure.clone());
}
request::ExecutionEvent::RequestWaiting { .. }
| request::ExecutionEvent::RequestStarted { .. }
| request::ExecutionEvent::AssertionPassed { .. } => {}
}
if let Some(callback) = downstream.as_ref() {
callback(event);
}
});
(observer, collected)
}
/// Builds a future that resolves once the process is asked to stop.
///
/// On Unix the future completes with `Sigint` on Ctrl-C or `Sigterm` on
/// SIGTERM, whichever arrives first. On other platforms only Ctrl-C is
/// awaited.
///
/// # Errors
///
/// Returns an error immediately if the SIGTERM handler cannot be installed
/// (Unix only). The returned future itself resolves to an error if waiting
/// for Ctrl-C fails or the SIGTERM stream ends.
fn interrupt_listener() -> HenResult<InterruptFuture> {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{signal, SignalKind};

        // The SIGTERM stream is created eagerly so that a failure to install
        // it is reported to the caller before anything is awaited.
        let mut sigterm = signal(SignalKind::terminate()).map_err(|source| {
            HenError::new(HenErrorKind::Execution, "Failed to install SIGTERM handler")
                .with_detail(source.to_string())
        })?;

        Ok(Box::pin(async move {
            // `ctrl_c()` is only called once this future is first polled.
            tokio::select! {
                outcome = tokio::signal::ctrl_c() => match outcome {
                    Ok(()) => Ok(request::InterruptSignal::Sigint),
                    Err(source) => Err(
                        HenError::new(HenErrorKind::Execution, "Failed while waiting for SIGINT")
                            .with_detail(source.to_string()),
                    ),
                },
                delivered = sigterm.recv() => delivered
                    .map(|()| request::InterruptSignal::Sigterm)
                    .ok_or_else(|| {
                        HenError::new(
                            HenErrorKind::Execution,
                            "SIGTERM listener closed unexpectedly",
                        )
                    }),
            }
        }))
    }
    #[cfg(not(unix))]
    {
        Ok(Box::pin(async move {
            match tokio::signal::ctrl_c().await {
                Ok(()) => Ok(request::InterruptSignal::Sigint),
                Err(source) => Err(
                    HenError::new(HenErrorKind::Execution, "Failed while waiting for Ctrl-C")
                        .with_detail(source.to_string()),
                ),
            }
        }))
    }
}
/// Executes `plan` over `requests`, racing the run against an OS interrupt.
///
/// The caller's `observer` (if any) still receives every event; it is wrapped
/// so that records, failures and the trace are also collected locally.
///
/// * If execution finishes first, the returned state carries the records and
///   failures reported by the executor itself plus the collected trace, with
///   `interrupted` set to `None`.
/// * If an interrupt arrives first, it is recorded in the trace and the state
///   is built from whatever the observer had collected up to that point, with
///   `execution_failed` set to `true`. The in-flight execution future is
///   dropped when this function returns.
///
/// # Errors
///
/// Returns an error if the interrupt listener cannot be installed, or if it
/// fails while waiting for a signal.
///
/// # Panics
///
/// Panics if a collector mutex is poisoned.
pub(crate) async fn execute_plan_with_interrupt(
    requests: &[request::Request],
    plan: &[usize],
    options: request::ExecutionOptions,
    observer: Option<request::ExecutionObserver>,
) -> HenResult<ExecutionState> {
    let (observer, collected) = tracking_observer(observer);
    let execution = request::execute_request_plan_with_observer(
        requests,
        plan,
        options,
        Some(observer),
    );
    let interrupt = interrupt_listener()?;
    tokio::pin!(execution);
    tokio::pin!(interrupt);

    tokio::select! {
        result = &mut execution => {
            // Only the trace is needed here: records and failures come from
            // the executor's own result, so avoid cloning and sorting the
            // collected copies just to throw them away.
            let trace = collected.trace.lock().unwrap().snapshot();
            let (records, failures, execution_failed) = match result {
                Ok(records) => (records, Vec::new(), false),
                Err(err) => {
                    let (failures, completed) = err.into_parts();
                    (completed, failures, true)
                }
            };
            Ok(ExecutionState {
                records,
                failures,
                trace,
                execution_failed,
                interrupted: None,
            })
        }
        result = &mut interrupt => {
            let signal = result?;
            // Record the interrupt before snapshotting so it shows up in the
            // returned trace.
            collected.trace.lock().unwrap().record_interrupt(signal);
            let (records, failures, trace) = collected.snapshot();
            Ok(ExecutionState {
                records,
                failures,
                trace,
                execution_failed: true,
                interrupted: Some(signal),
            })
        }
    }
}