use super::context::{get_current_context, increment_processed, set_current_file, AnalysisContext};
use rayon::prelude::*;
use std::path::Path;
use tracing::{debug_span, Span};
// Per-thread slot holding an optional `AnalysisContext`, initialised to `None`.
// NOTE(review): nothing in this file reads or writes `PARALLEL_CONTEXT` — the
// propagation implemented here goes through `super::context::CURRENT_CONTEXT`.
// Presumably another module in the crate uses it; confirm, or remove it if dead.
thread_local! {
pub(crate) static PARALLEL_CONTEXT: std::cell::RefCell<Option<AnalysisContext>> =
const { std::cell::RefCell::new(None) };
}
/// Snapshot of the caller's tracing span and analysis context, for re-installing
/// on rayon worker threads.
///
/// Thread-local state does not travel with work that rayon schedules on its pool.
/// Capture a `ParallelContext` on the calling thread with [`ParallelContext::capture`],
/// then call [`ParallelContext::enter`] (or [`with_parallel_context`]) inside each
/// parallel closure.
#[derive(Clone)]
pub struct ParallelContext {
    /// Span that was current on the capturing thread.
    span: Span,
    /// Copy of the capturing thread's analysis context.
    analysis_context: AnalysisContext,
}

impl ParallelContext {
    /// Captures `Span::current()` and the current thread's analysis context.
    #[must_use]
    pub fn capture() -> Self {
        Self {
            span: Span::current(),
            analysis_context: get_current_context(),
        }
    }

    /// Installs this snapshot on the current thread for the lifetime of the returned guard.
    ///
    /// Enters the captured span. Then swaps a clone of the captured analysis context
    /// into the thread-local `CURRENT_CONTEXT`, remembering the value it replaced.
    ///
    /// Dropping the guard writes that previous value back and then exits the span.
    /// This keeps the snapshot from lingering on a rayon worker after the item
    /// finishes, where it would leak into later, unrelated work on the same thread.
    #[must_use]
    pub fn enter(&self) -> ParallelContextGuard {
        let span_guard = self.span.clone().entered();
        let previous_context = super::context::CURRENT_CONTEXT
            .with(|ctx| std::mem::replace(&mut *ctx.borrow_mut(), self.analysis_context.clone()));
        ParallelContextGuard {
            previous_context: Some(previous_context),
            _span: span_guard,
        }
    }

    /// The captured analysis context.
    #[must_use]
    pub fn analysis_context(&self) -> &AnalysisContext {
        &self.analysis_context
    }

    /// The captured span.
    #[must_use]
    pub fn span(&self) -> &Span {
        &self.span
    }
}

/// Guard returned by [`ParallelContext::enter`].
///
/// On drop it first restores the analysis context that was installed on this
/// thread before `enter` was called. Then it exits the captured span. This is
/// the reverse of the order in which `enter` set them up.
pub struct ParallelContextGuard {
    /// Context to put back on drop. It is an `Option` so `Drop` can move it out.
    previous_context: Option<AnalysisContext>,
    /// Keeps the captured span entered. Fields are dropped after `Drop::drop`
    /// runs, so the span is exited only after the context has been restored.
    _span: tracing::span::EnteredSpan,
}

impl Drop for ParallelContextGuard {
    fn drop(&mut self) {
        if let Some(previous) = self.previous_context.take() {
            super::context::CURRENT_CONTEXT.with(|ctx| *ctx.borrow_mut() = previous);
        }
    }
}
/// Runs `f` on the current thread with `ctx` installed and returns its result.
///
/// The guard from [`ParallelContext::enter`] is held for the whole call to `f`.
/// It is released right after `f` returns, or during unwinding if `f` panics.
#[inline]
pub fn with_parallel_context<T, F>(ctx: &ParallelContext, f: F) -> T
where
    F: FnOnce() -> T,
{
    let installed = ctx.enter();
    let output = f();
    drop(installed);
    output
}
pub fn process_file_with_context<T, F>(path: &Path, parent_ctx: &ParallelContext, f: F) -> T
where
F: FnOnce() -> T,
{
let _parent = parent_ctx.enter();
let _file = set_current_file(path);
let _span = debug_span!("process_file", path = %path.display()).entered();
increment_processed();
f()
}
/// Context-propagating variants of common rayon adaptors.
///
/// Each method captures a [`ParallelContext`] on the calling thread when the
/// adaptor is built. It then runs the user closure for every item through
/// [`with_parallel_context`]. The closure therefore sees the caller's span and
/// analysis context even when rayon executes it on another thread.
pub trait ParallelContextExt<T>: ParallelIterator<Item = T> + Sized {
    /// Like `ParallelIterator::map`, but runs `f` under the caller's captured context.
    fn map_with_context<R, F>(self, f: F) -> impl ParallelIterator<Item = R>
    where
        F: Fn(T) -> R + Sync + Send,
        R: Send,
    {
        let snapshot = ParallelContext::capture();
        let run_in_context = move |item: T| with_parallel_context(&snapshot, || f(item));
        self.map(run_in_context)
    }

    /// Like `ParallelIterator::filter_map`, but runs `f` under the caller's captured context.
    fn filter_map_with_context<R, F>(self, f: F) -> impl ParallelIterator<Item = R>
    where
        F: Fn(T) -> Option<R> + Sync + Send,
        R: Send,
    {
        let snapshot = ParallelContext::capture();
        let run_in_context = move |item: T| with_parallel_context(&snapshot, || f(item));
        self.filter_map(run_in_context)
    }

    /// Like `ParallelIterator::for_each`, but runs `f` under the caller's captured context.
    fn for_each_with_context<F>(self, f: F)
    where
        F: Fn(T) + Sync + Send,
    {
        let snapshot = ParallelContext::capture();
        let run_in_context = move |item: T| with_parallel_context(&snapshot, || f(item));
        self.for_each(run_in_context)
    }
}

/// Every rayon parallel iterator gets the context-propagating adaptors.
impl<T, I> ParallelContextExt<T> for I where I: ParallelIterator<Item = T> {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observability::context::{reset_context, reset_progress, set_phase, AnalysisPhase};
    use std::path::PathBuf;

    /// Calls `reset_context` and `reset_progress` so each test starts from a clean slate.
    fn fresh_state() {
        reset_context();
        reset_progress();
    }

    #[test]
    fn test_context_capture() {
        fresh_state();
        let _phase = set_phase(AnalysisPhase::Parsing);

        let captured = ParallelContext::capture();

        assert_eq!(
            captured.analysis_context().phase,
            Some(AnalysisPhase::Parsing),
            "Captured context should have Parsing phase"
        );
    }

    #[test]
    fn test_context_propagates_to_workers() {
        fresh_state();
        let _phase = set_phase(AnalysisPhase::DebtScoring);
        let captured = ParallelContext::capture();

        let observed: Vec<_> = (0..10)
            .into_par_iter()
            .map(|idx| {
                let _entered = captured.enter();
                let snapshot = get_current_context();
                (idx, snapshot.phase)
            })
            .collect();

        for phase in observed.into_iter().map(|(_, phase)| phase) {
            assert_eq!(
                phase,
                Some(AnalysisPhase::DebtScoring),
                "Phase should propagate to workers"
            );
        }
    }

    #[test]
    fn test_file_context_per_item() {
        fresh_state();
        let captured = ParallelContext::capture();
        let files: Vec<PathBuf> = ["a.rs", "b.rs", "c.rs"].iter().map(PathBuf::from).collect();

        files.par_iter().for_each(|expected| {
            process_file_with_context(expected, &captured, || {
                let snapshot = get_current_context();
                assert_eq!(
                    snapshot.current_file.as_ref(),
                    Some(expected),
                    "Current file should be set in context"
                );
            });
        });
    }

    #[test]
    fn test_with_parallel_context_helper() {
        fresh_state();
        let _phase = set_phase(AnalysisPhase::Parsing);
        let captured = ParallelContext::capture();

        let total: i32 = (0..100)
            .into_par_iter()
            .map(|n| with_parallel_context(&captured, || n * 2))
            .sum();

        assert_eq!(
            total, 9900,
            "Computation should work correctly with context"
        );
    }

    #[test]
    fn test_map_with_context_extension() {
        fresh_state();
        let _phase = set_phase(AnalysisPhase::PurityAnalysis);

        let observed: Vec<_> = (0..10)
            .into_par_iter()
            .map_with_context(|idx| {
                let snapshot = get_current_context();
                (idx, snapshot.phase)
            })
            .collect();

        for phase in observed.into_iter().map(|(_, phase)| phase) {
            assert_eq!(
                phase,
                Some(AnalysisPhase::PurityAnalysis),
                "map_with_context should propagate phase"
            );
        }
    }

    #[test]
    fn test_filter_map_with_context_extension() {
        fresh_state();
        let _phase = set_phase(AnalysisPhase::CoverageLoading);

        let evens: Vec<_> = (0..20)
            .into_par_iter()
            .filter_map_with_context(|idx| {
                let snapshot = get_current_context();
                (idx % 2 == 0).then(|| (idx, snapshot.phase))
            })
            .collect();

        assert_eq!(evens.len(), 10, "Should filter half the items");
        for phase in evens.into_iter().map(|(_, phase)| phase) {
            assert_eq!(
                phase,
                Some(AnalysisPhase::CoverageLoading),
                "filter_map_with_context should propagate phase"
            );
        }
    }

    #[test]
    fn test_for_each_with_context_extension() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        fresh_state();
        let _phase = set_phase(AnalysisPhase::OutputGeneration);
        let matches = AtomicUsize::new(0);

        (0..10).into_par_iter().for_each_with_context(|_| {
            let snapshot = get_current_context();
            if snapshot.phase == Some(AnalysisPhase::OutputGeneration) {
                matches.fetch_add(1, Ordering::Relaxed);
            }
        });

        assert_eq!(
            matches.load(Ordering::Relaxed),
            10,
            "All items should have correct phase"
        );
    }

    #[test]
    fn test_nested_context_in_parallel() {
        fresh_state();
        let _phase = set_phase(AnalysisPhase::Parsing);
        let captured = ParallelContext::capture();

        let observed: Vec<_> = (0..5)
            .into_par_iter()
            .map(|idx| {
                let _entered = captured.enter();
                let _file = set_current_file(format!("file_{idx}.rs"));
                let snapshot = get_current_context();
                let file_name = snapshot
                    .current_file
                    .as_deref()
                    .map(|p| p.to_string_lossy().into_owned());
                (idx, snapshot.phase, file_name)
            })
            .collect();

        for (idx, phase, file) in observed {
            assert_eq!(phase, Some(AnalysisPhase::Parsing));
            assert_eq!(file, Some(format!("file_{idx}.rs")));
        }
    }
}