use crate::errors::AnalysisError;
use crate::progress::traits::HasProgress;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use stillwater::effect::prelude::*;
use stillwater::Effect;
/// Wraps `effect` so that it runs inside a named progress stage.
///
/// `start_stage(stage_name)` is emitted on the environment's progress sink
/// before `effect` runs, and `complete_stage(stage_name)` after it finishes.
/// The completion event fires whether the effect returns `Ok` or `Err`, so a
/// failing effect never leaves its stage open. The effect's result is passed
/// through unchanged.
///
/// Note: there is no unwind handling here, so a panic inside `effect`
/// bypasses `complete_stage`.
pub fn with_stage<T, Err, Env, Eff>(
    stage_name: &str,
    effect: Eff,
) -> impl Effect<Output = T, Error = Err, Env = Env>
where
    Env: HasProgress + Clone + Send + Sync + 'static,
    T: Send + 'static,
    Err: Send + 'static,
    Eff: Effect<Output = T, Error = Err, Env = Env> + Send + 'static,
{
    // One owned copy suffices: the same name labels both the start and the
    // completion event.
    let name = stage_name.to_string();
    from_async(move |env: &Env| {
        let env = env.clone();
        let stage = name.clone();
        async move {
            env.progress().start_stage(&stage);
            let result = effect.run(&env).await;
            env.progress().complete_stage(&stage);
            result
        }
    })
}
/// Applies `f` to every item in order and collects the outputs, reporting
/// progress under `stage_name`.
///
/// The stage is started once. Before each item is processed,
/// `report(stage_name, index, total)` is emitted with that item's zero-based
/// index. The stage is then completed exactly once, after all items succeed or
/// after the first failure.
///
/// # Errors
/// Returns the first `AnalysisError` produced by an effect from `f`; items
/// after the failing one are not processed.
pub fn traverse_with_progress<T, U, Env, F, Eff>(
    items: Vec<T>,
    stage_name: &str,
    f: F,
) -> impl Effect<Output = Vec<U>, Error = AnalysisError, Env = Env>
where
    T: Send + 'static,
    U: Send + 'static,
    Env: HasProgress + Clone + Send + Sync + 'static,
    F: Fn(T) -> Eff + Send + Sync + 'static,
    Eff: Effect<Output = U, Error = AnalysisError, Env = Env> + Send,
{
    let label = stage_name.to_owned();
    let count = items.len();
    from_async(move |env: &Env| {
        let env = env.clone();
        let label = label.clone();
        async move {
            env.progress().start_stage(&label);
            // The inner block short-circuits with `?`, letting the stage be
            // completed in a single place for both outcomes.
            let outcome = async {
                let mut collected = Vec::with_capacity(count);
                for (index, item) in items.into_iter().enumerate() {
                    env.progress().report(&label, index, count);
                    collected.push(f(item).run(&env).await?);
                }
                Ok::<_, AnalysisError>(collected)
            }
            .await;
            env.progress().complete_stage(&label);
            outcome
        }
    })
}
/// Applies a cloneable `f` to every item and collects the outputs in input
/// order, reporting progress under `stage_name`.
///
/// NOTE(review): despite the `par_` prefix, this implementation awaits each
/// item's effect one after another; no two items are in flight at once. The
/// atomic counter therefore yields the same zero-based indices `enumerate`
/// would. Presumably it is kept as scaffolding for a future concurrent
/// version — confirm before relying on concurrency.
///
/// The stage is started once and `report(stage_name, index, total)` is
/// emitted before each item. The stage is then completed exactly once, after
/// all items succeed or after the first failure.
///
/// # Errors
/// Returns the first `AnalysisError` produced by an effect from `f`; items
/// after the failing one are not processed.
pub fn par_traverse_with_progress<T, U, Env, F, Eff>(
    items: Vec<T>,
    stage_name: &str,
    f: F,
) -> impl Effect<Output = Vec<U>, Error = AnalysisError, Env = Env>
where
    T: Send + 'static,
    U: Send + 'static,
    Env: HasProgress + Clone + Send + Sync + 'static,
    F: Fn(T) -> Eff + Send + Sync + Clone + 'static,
    Eff: Effect<Output = U, Error = AnalysisError, Env = Env> + Send,
{
    let label = stage_name.to_owned();
    let count = items.len();
    from_async(move |env: &Env| {
        let env = env.clone();
        let label = label.clone();
        let op = f.clone();
        async move {
            env.progress().start_stage(&label);
            let dispatched = Arc::new(AtomicUsize::new(0));
            let outcome = async {
                let mut collected = Vec::with_capacity(count);
                for item in items {
                    let index = dispatched.fetch_add(1, Ordering::Relaxed);
                    env.progress().report(&label, index, count);
                    let value = op(item).run(&env).await?;
                    collected.push(value);
                }
                Ok::<_, AnalysisError>(collected)
            }
            .await;
            env.progress().complete_stage(&label);
            outcome
        }
    })
}
/// Builds an effect that emits one progress report for `stage`.
///
/// The environment's progress sink receives `report(stage, current, total)`
/// with the values forwarded unchanged. No stage is started or completed, and
/// the closure passed to `asks` has no failure path of its own.
pub fn report_progress<Env>(
    stage: &str,
    current: usize,
    total: usize,
) -> impl Effect<Output = (), Error = AnalysisError, Env = Env>
where
    Env: HasProgress + Clone + Send + Sync + 'static,
{
    let stage_label = String::from(stage);
    stillwater::asks(move |env: &Env| {
        let sink = env.progress();
        sink.report(&stage_label, current, total);
    })
}
/// Builds an effect that forwards `message` to the progress sink's `warn`.
///
/// The message is copied into an owned `String` up front so the returned
/// effect does not borrow from the caller.
pub fn warn_progress<Env>(
    message: &str,
) -> impl Effect<Output = (), Error = AnalysisError, Env = Env>
where
    Env: HasProgress + Clone + Send + Sync + 'static,
{
    let text = message.to_owned();
    stillwater::asks(move |env: &Env| {
        let sink = env.progress();
        sink.warn(&text);
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::RealEnv;
    use crate::progress::implementations::{ProgressEvent, RecordingProgressSink};
    use stillwater::EffectExt;

    /// Builds a `RealEnv` whose progress sink records every event for inspection.
    fn test_env() -> (RealEnv, Arc<RecordingProgressSink>) {
        let recorder = Arc::new(RecordingProgressSink::new());
        let env = RealEnv::with_progress(crate::config::DebtmapConfig::default(), recorder.clone());
        (env, recorder)
    }

    #[tokio::test]
    async fn test_with_stage_calls_start_and_complete() {
        let (env, recorder) = test_env();
        let effect = with_stage("Test Stage", pure::<_, AnalysisError, _>(42));
        let result = effect.run(&env).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(recorder.stages(), vec!["Test Stage"]);
        assert_eq!(recorder.completed_stages(), vec!["Test Stage"]);
    }

    #[tokio::test]
    async fn test_with_stage_completes_on_error() {
        let (env, recorder) = test_env();
        let effect = with_stage(
            "Failing Stage",
            fail::<i32, AnalysisError, RealEnv>(AnalysisError::other("test error")),
        );
        let result = effect.run(&env).await;
        assert!(result.is_err());
        assert_eq!(recorder.stages(), vec!["Failing Stage"]);
        assert_eq!(recorder.completed_stages(), vec!["Failing Stage"]);
    }

    #[tokio::test]
    async fn test_traverse_with_progress_reports_each_item() {
        let (env, recorder) = test_env();
        let items = vec![1, 2, 3];
        let effect = traverse_with_progress(items, "Processing", |n| pure(n * 2));
        let result = effect.run(&env).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), vec![2, 4, 6]);
        assert_eq!(recorder.stages(), vec!["Processing"]);
        assert_eq!(recorder.completed_stages(), vec!["Processing"]);
        // Each item is announced with its zero-based index and the total count.
        let positions: Vec<_> = recorder
            .events()
            .into_iter()
            .filter_map(|e| match e {
                ProgressEvent::Report { current, total, .. } => Some((current, total)),
                _ => None,
            })
            .collect();
        assert_eq!(positions, vec![(0, 3), (1, 3), (2, 3)]);
    }

    #[tokio::test]
    async fn test_traverse_with_progress_empty_input() {
        let (env, recorder) = test_env();
        let effect = traverse_with_progress(Vec::<i32>::new(), "Empty", |n| pure(n + 1));
        let result = effect.run(&env).await;
        assert_eq!(result.unwrap(), Vec::<i32>::new());
        assert_eq!(recorder.stages(), vec!["Empty"]);
        assert_eq!(recorder.completed_stages(), vec!["Empty"]);
        assert!(!recorder
            .events()
            .iter()
            .any(|e| matches!(e, ProgressEvent::Report { .. })));
    }

    #[tokio::test]
    async fn test_traverse_with_progress_completes_on_error() {
        let (env, recorder) = test_env();
        let items = vec![1, 2, 3];
        let effect = traverse_with_progress(items, "Failing", |n| {
            if n == 2 {
                fail::<i32, AnalysisError, RealEnv>(AnalysisError::other("failed at 2")).boxed()
            } else {
                pure::<_, AnalysisError, RealEnv>(n).boxed()
            }
        });
        let result = effect.run(&env).await;
        assert!(result.is_err());
        assert_eq!(recorder.stages(), vec!["Failing"]);
        assert_eq!(recorder.completed_stages(), vec!["Failing"]);
        // Processing stops at the failing item, so the third item is never reported.
        let reports = recorder
            .events()
            .into_iter()
            .filter(|e| matches!(e, ProgressEvent::Report { .. }))
            .count();
        assert_eq!(reports, 2);
    }

    #[tokio::test]
    async fn test_report_progress_emits_event() {
        let (env, recorder) = test_env();
        let effect = report_progress("Manual", 5, 10);
        let result = effect.run(&env).await;
        assert!(result.is_ok());
        let events = recorder.events();
        assert_eq!(events.len(), 1);
        assert!(matches!(
            &events[0],
            ProgressEvent::Report { stage, current: 5, total: 10 } if stage == "Manual"
        ));
    }

    #[tokio::test]
    async fn test_warn_progress_emits_warning() {
        let (env, recorder) = test_env();
        let effect = warn_progress("Test warning message");
        let result = effect.run(&env).await;
        assert!(result.is_ok());
        let warnings = recorder.warnings();
        assert_eq!(warnings, vec!["Test warning message"]);
    }

    #[tokio::test]
    async fn test_nested_stages() {
        let (env, recorder) = test_env();
        let inner = with_stage("Inner", pure::<_, AnalysisError, _>(1));
        let outer = with_stage("Outer", inner.map(|n| n + 1));
        let result = outer.run(&env).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 2);
        let stages = recorder.stages();
        assert!(stages.contains(&"Outer".to_string()));
        assert!(stages.contains(&"Inner".to_string()));
        // Both the inner and the outer stage must be closed.
        let completed = recorder.completed_stages();
        assert!(completed.contains(&"Outer".to_string()));
        assert!(completed.contains(&"Inner".to_string()));
    }

    #[tokio::test]
    async fn test_par_traverse_with_progress() {
        let (env, recorder) = test_env();
        let items = vec![1, 2, 3, 4, 5];
        let effect = par_traverse_with_progress(items, "Parallel", |n| pure(n * 2));
        let result = effect.run(&env).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), vec![2, 4, 6, 8, 10]);
        assert_eq!(recorder.stages(), vec!["Parallel"]);
        assert_eq!(recorder.completed_stages(), vec!["Parallel"]);
        let reports: Vec<_> = recorder
            .events()
            .into_iter()
            .filter(|e| matches!(e, ProgressEvent::Report { .. }))
            .collect();
        assert_eq!(reports.len(), 5);
    }

    #[tokio::test]
    async fn test_par_traverse_with_progress_completes_on_error() {
        let (env, recorder) = test_env();
        let items = vec![1, 2, 3];
        let effect = par_traverse_with_progress(items, "Parallel Failing", |n| {
            if n == 2 {
                fail::<i32, AnalysisError, RealEnv>(AnalysisError::other("failed at 2")).boxed()
            } else {
                pure::<_, AnalysisError, RealEnv>(n).boxed()
            }
        });
        let result = effect.run(&env).await;
        assert!(result.is_err());
        assert_eq!(recorder.stages(), vec!["Parallel Failing"]);
        assert_eq!(recorder.completed_stages(), vec!["Parallel Failing"]);
    }
}