use backon::Retryable;
use eyre::WrapErr;
use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
use itertools::Itertools;
use once_cell::sync::Lazy;
use std::{
collections::HashMap,
ops::Deref,
pin::Pin,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Mutex,
},
time::{Duration, SystemTime},
};
use tokio::sync::broadcast;
use tracing::*;
use crate::{
config::{self, get_tanu_config, CaptureHttpMode, ProjectConfig},
http,
reporter::Reporter,
Config, ModuleName, ProjectName,
};
tokio::task_local! {
    // Task-local handle to the metadata of the test case currently running
    // on this task; set by `execute_test` via `TEST_INFO.scope(..)`.
    pub(crate) static TEST_INFO: Arc<TestInfo>;
}
/// Returns the metadata of the currently running test case.
///
/// Panics if called outside a `TEST_INFO` scope (i.e. not from within a
/// running test task) — same behavior as `tokio::task_local!`'s `with`.
pub(crate) fn get_test_info() -> Arc<TestInfo> {
    TEST_INFO.with(|info| Arc::clone(info))
}
/// Wraps `fut` so that the caller's task-local `PROJECT` and `TEST_INFO`
/// context (whichever parts are currently set) is re-established around it.
///
/// Useful when handing work to `tokio::spawn`, which would otherwise lose
/// the task-local scopes. Context that is absent at call time is simply not
/// propagated.
pub fn scope_current<F>(fut: F) -> impl std::future::Future<Output = F::Output> + Send
where
    F: std::future::Future + Send,
    F::Output: Send,
{
    // Capture whatever context exists right now, before the future is polled.
    let captured_project = crate::config::PROJECT.try_with(Arc::clone).ok();
    let captured_info = TEST_INFO.try_with(Arc::clone).ok();
    async move {
        match (captured_project, captured_info) {
            (Some(proj), Some(info)) => {
                crate::config::PROJECT
                    .scope(proj, TEST_INFO.scope(info, fut))
                    .await
            }
            (Some(proj), None) => crate::config::PROJECT.scope(proj, fut).await,
            (None, Some(info)) => TEST_INFO.scope(info, fut).await,
            (None, None) => fut.await,
        }
    }
}
// Global broadcast channel fanning runner events out to reporters.
// Wrapped in `Mutex<Option<..>>` so the channel can be torn down
// (see `Options::terminate_channel`) to signal background tasks to stop.
// Capacity 1000: slow subscribers past that lag and drop events.
#[allow(clippy::type_complexity)]
pub(crate) static CHANNEL: Lazy<
    Mutex<Option<(broadcast::Sender<Event>, broadcast::Receiver<Event>)>>,
> = Lazy::new(|| Mutex::new(Some(broadcast::channel(1000))));
// Barrier that holds the runner back until every spawned reporter has
// subscribed; `None` while no run is in progress.
pub(crate) static REPORTER_BARRIER: Lazy<Mutex<Option<Arc<tokio::sync::Barrier>>>> =
    Lazy::new(|| Mutex::new(None));
pub fn publish(e: impl Into<Event>) -> eyre::Result<()> {
let Ok(guard) = CHANNEL.lock() else {
eyre::bail!("failed to acquire runner channel lock");
};
let Some((tx, _)) = guard.deref() else {
eyre::bail!("runner channel has been already closed");
};
tx.send(e.into())
.wrap_err("failed to publish message to the runner channel")?;
Ok(())
}
pub fn subscribe() -> eyre::Result<broadcast::Receiver<Event>> {
let Ok(guard) = CHANNEL.lock() else {
eyre::bail!("failed to acquire runner channel lock");
};
let Some((tx, _)) = guard.deref() else {
eyre::bail!("runner channel has been already closed");
};
Ok(tx.subscribe())
}
pub(crate) fn setup_reporter_barrier(count: usize) -> eyre::Result<()> {
let Ok(mut barrier) = REPORTER_BARRIER.lock() else {
eyre::bail!("failed to acquire reporter barrier lock");
};
*barrier = Some(Arc::new(tokio::sync::Barrier::new(count + 1)));
Ok(())
}
pub(crate) async fn wait_reporter_barrier() {
let barrier = match REPORTER_BARRIER.lock() {
Ok(guard) => guard.clone(),
Err(e) => {
error!("failed to acquire reporter barrier lock (poisoned): {e}");
return;
}
};
if let Some(b) = barrier {
b.wait().await;
}
}
/// Runs one test case inside the `PROJECT`/`TEST_INFO` task-local scopes,
/// applying the project's retry policy and catching panics.
///
/// * `serial_mutex` — when `Some`, each attempt holds this lock so members
///   of the same serial group never overlap.
/// * `worker_id` — reporter-facing worker slot id; `-1` means "no slot".
///
/// Publishes `Start`, one `Retry` per failed attempt that will be retried,
/// and `End`. Returns the final `Test` record; errors only if event
/// publishing itself fails.
async fn execute_test(
    project: Arc<ProjectConfig>,
    info: Arc<TestInfo>,
    factory: TestCaseFactory,
    serial_mutex: Option<Arc<tokio::sync::Mutex<()>>>,
    worker_id: isize,
) -> eyre::Result<Test> {
    let project_for_scope = Arc::clone(&project);
    let info_for_scope = Arc::clone(&info);
    config::PROJECT
        .scope(project_for_scope, async {
            TEST_INFO
                .scope(info_for_scope, async {
                    let test_name = info.name.clone();
                    publish(EventBody::Start)?;
                    // Remaining retries; decremented after each failed attempt.
                    let retry_count = AtomicUsize::new(project.retry.count.unwrap_or(0));
                    let serial_mutex_clone = serial_mutex.clone();
                    // One attempt of the test body; `backon` re-invokes this
                    // closure according to the configured backoff policy.
                    let f = || async {
                        // Hold the serial-group lock for the whole attempt.
                        let _serial_guard = if let Some(ref mutex) = serial_mutex_clone {
                            Some(mutex.lock().await)
                        } else {
                            None
                        };
                        let started_at = SystemTime::now();
                        let request_started = std::time::Instant::now();
                        let res = factory().await;
                        let ended_at = SystemTime::now();
                        // If this attempt failed and retries remain, publish a
                        // Retry event carrying the per-attempt timing/result.
                        if res.is_err() && retry_count.load(Ordering::SeqCst) > 0 {
                            let test_result = match &res {
                                Ok(_) => Ok(()),
                                Err(e) => Err(Error::ErrorReturned(format!("{e:?}"))),
                            };
                            let test = Test {
                                result: test_result,
                                info: Arc::clone(&info),
                                worker_id,
                                started_at,
                                ended_at,
                                request_time: request_started.elapsed(),
                            };
                            publish(EventBody::Retry(test))?;
                            retry_count.fetch_sub(1, Ordering::SeqCst);
                        };
                        res
                    };
                    let started_at = SystemTime::now();
                    let started = std::time::Instant::now();
                    // Apply the retry policy, then catch panics so a panicking
                    // test is recorded as a failure instead of killing the task.
                    let fut = f.retry(project.retry.backoff());
                    let fut = std::panic::AssertUnwindSafe(fut).catch_unwind();
                    let res = fut.await;
                    let request_time = started.elapsed();
                    let ended_at = SystemTime::now();
                    let result = match res {
                        Ok(Ok(_)) => {
                            debug!("{test_name} ok");
                            Ok(())
                        }
                        Ok(Err(e)) => {
                            debug!("{test_name} failed: {e:#}");
                            Err(Error::ErrorReturned(format!("{e:?}")))
                        }
                        Err(e) => {
                            // Panic payloads are typically `&str` or `String`.
                            let panic_message =
                                if let Some(panic_message) = e.downcast_ref::<&str>() {
                                    format!("{test_name} failed with message: {panic_message}")
                                } else if let Some(panic_message) = e.downcast_ref::<String>() {
                                    format!("{test_name} failed with message: {panic_message}")
                                } else {
                                    format!("{test_name} failed with unknown message")
                                };
                            let e = eyre::eyre!(panic_message);
                            Err(Error::Panicked(format!("{e:?}")))
                        }
                    };
                    let test = Test {
                        result,
                        info: Arc::clone(&info),
                        worker_id,
                        started_at,
                        ended_at,
                        request_time,
                    };
                    publish(EventBody::End(test.clone()))?;
                    eyre::Ok(test)
                })
                .await
        })
        .await
}
/// Removes the reporter barrier at the end of a run; a poisoned lock is
/// logged rather than propagated.
pub(crate) fn clear_reporter_barrier() {
    if let Err(e) = REPORTER_BARRIER.lock().map(|mut guard| *guard = None) {
        error!("failed to clear reporter barrier (poisoned lock): {e}");
    }
}
/// Failure modes of a single test-case execution.
#[derive(Debug, Clone, thiserror::Error)]
pub enum Error {
    /// The test body panicked; payload is the formatted panic report.
    #[error("panic: {0}")]
    Panicked(String),
    /// The test body returned an `Err`; payload is the formatted error.
    #[error("error: {0}")]
    ErrorReturned(String),
}
/// Outcome of a single assertion evaluated inside a test.
#[derive(Debug, Clone)]
pub struct Check {
    /// `true` when the assertion held.
    pub result: bool,
    /// Source text of the asserted expression.
    pub expr: String,
}
impl Check {
    /// Records a passing assertion for `expr`.
    pub fn success(expr: impl Into<String>) -> Check {
        Self::build(true, expr)
    }
    /// Records a failing assertion for `expr`.
    pub fn error(expr: impl Into<String>) -> Check {
        Self::build(false, expr)
    }
    // Shared constructor backing both public builders.
    fn build(result: bool, expr: impl Into<String>) -> Check {
        Check {
            result,
            expr: expr.into(),
        }
    }
}
/// An event on the runner channel, tagged with the identity of the test
/// that produced it (empty strings for run-level events such as `Summary`).
#[derive(Debug, Clone)]
pub struct Event {
    pub project: ProjectName,
    pub module: ModuleName,
    pub test: ModuleName,
    pub body: EventBody,
}
/// A captured outbound call made by a test; boxed to keep the enum small.
#[derive(Debug, Clone)]
pub enum CallLog {
    Http(Box<http::Log>),
    #[cfg(feature = "grpc")]
    Grpc(Box<crate::grpc::Log>),
}
/// Payload of a runner event.
#[derive(Debug, Clone)]
pub enum EventBody {
    /// A test case began executing.
    Start,
    /// An assertion was evaluated within a test.
    Check(Box<Check>),
    /// An outbound call (e.g. HTTP) was captured.
    Call(CallLog),
    /// An attempt failed and will be retried; carries the attempt's record.
    Retry(Test),
    /// A test case finished; carries its final record.
    End(Test),
    /// Whole-run statistics, emitted once at the end.
    Summary(TestSummary),
}
impl From<EventBody> for Event {
    /// Stamps an event body with the project/module/test identity taken
    /// from the current task-local context.
    ///
    /// NOTE(review): relies on `get_config()`/`get_test_info()`, which read
    /// task-locals — presumably only valid inside a test scope; confirm.
    fn from(body: EventBody) -> Self {
        let cfg = crate::config::get_config();
        let current = crate::runner::get_test_info();
        Self {
            project: cfg.name.clone(),
            module: current.module.clone(),
            test: current.name.clone(),
            body,
        }
    }
}
/// Record of one finished test attempt or run.
#[derive(Debug, Clone)]
pub struct Test {
    pub info: Arc<TestInfo>,
    /// Worker slot that ran the test; `-1` when concurrency is unbounded.
    pub worker_id: isize,
    pub started_at: SystemTime,
    pub ended_at: SystemTime,
    /// Wall-clock duration of the test body (including retries for the
    /// final record).
    pub request_time: Duration,
    pub result: Result<(), Error>,
}
/// Whole-run statistics published as the final `Summary` event.
#[derive(Debug, Clone)]
pub struct TestSummary {
    pub total_tests: usize,
    pub passed_tests: usize,
    pub failed_tests: usize,
    /// Tests never awaited, e.g. abandoned by fail-fast cancellation.
    pub skipped_tests: usize,
    pub total_time: Duration,
    /// Time spent filtering/partitioning tests and spawning their tasks.
    pub test_prep_time: Duration,
}
/// Static metadata describing a registered test case.
#[derive(Debug, Clone, Default)]
pub struct TestInfo {
    pub module: String,
    pub name: String,
    /// Tests sharing a group name never run concurrently with each other.
    pub serial_group: Option<String>,
    /// Source line, used to order tests within an ordered group.
    pub line: u32,
    /// Whether this test participates in in-order group execution.
    pub ordered: bool,
}
impl TestInfo {
    /// Returns `"<module>::<name>"`.
    pub fn full_name(&self) -> String {
        [self.module.as_str(), self.name.as_str()].join("::")
    }
    /// Returns `"<project>::<module>::<name>"`, unique across projects.
    pub fn unique_name(&self, project: &str) -> String {
        format!("{project}::{}", self.full_name())
    }
}
/// A small pool of reusable worker slot ids (`0..concurrency`).
///
/// When concurrency is unbounded the pool is disabled and every acquire
/// yields the sentinel `-1`.
#[derive(Debug)]
pub struct WorkerIds {
    enabled: bool,
    ids: Mutex<Vec<isize>>,
}
impl WorkerIds {
    /// Builds a pool with `concurrency` ids, or a disabled pool for `None`.
    pub fn new(concurrency: Option<usize>) -> Self {
        let (enabled, pool) = match concurrency {
            Some(c) => (true, (0..c as isize).collect()),
            None => (false, Vec::new()),
        };
        Self {
            enabled,
            ids: Mutex::new(pool),
        }
    }
    /// Takes an id from the pool; returns `-1` when disabled, exhausted,
    /// or the lock is poisoned.
    pub fn acquire(&self) -> isize {
        if !self.enabled {
            return -1;
        }
        match self.ids.lock() {
            Ok(mut pool) => pool.pop().unwrap_or(-1),
            Err(_) => -1,
        }
    }
    /// Returns an id to the pool; the `-1` sentinel is ignored.
    pub fn release(&self, id: isize) {
        if self.enabled && id >= 0 {
            if let Ok(mut pool) = self.ids.lock() {
                pool.push(id);
            }
        }
    }
}
// Shared factory producing a fresh future for each attempt of a test body,
// so the retry machinery can re-run the same test.
type TestCaseFactory = Arc<
    dyn Fn() -> Pin<Box<dyn futures::Future<Output = eyre::Result<()>> + Send + 'static>>
        + Sync
        + Send
        + 'static,
>;
/// Runner behavior toggles; see `Default` for the defaults.
#[derive(Debug, Clone)]
pub struct Options {
    pub debug: bool,
    /// Which HTTP calls get captured as `Call` events.
    pub capture_http: CaptureHttpMode,
    /// Install a `tracing_subscriber` fmt layer for Rust-side logs.
    pub capture_rust: bool,
    /// Drop the global channel after the run to stop background tasks.
    pub terminate_channel: bool,
    /// Max concurrently running tests; `None` = unbounded.
    pub concurrency: Option<usize>,
    /// Mask sensitive query params/headers in captured logs.
    pub mask_sensitive: bool,
    /// Cancel remaining tests after the first failure.
    pub fail_fast: bool,
}
impl Default for Options {
fn default() -> Self {
Self {
debug: false,
capture_http: CaptureHttpMode::Off,
capture_rust: false,
terminate_channel: false,
concurrency: None,
mask_sensitive: true, fail_fast: false,
}
}
}
/// Predicate deciding whether a (project, test) pair is selected to run.
pub trait Filter {
    /// Returns `true` to keep the test, `false` to exclude it.
    fn filter(&self, project: &ProjectConfig, info: &TestInfo) -> bool;
}
/// Keeps only tests whose project name is in `project_names`;
/// an empty list disables the filter.
pub struct ProjectFilter<'a> {
    project_names: &'a [String],
}
impl Filter for ProjectFilter<'_> {
    fn filter(&self, project: &ProjectConfig, _info: &TestInfo) -> bool {
        self.project_names.is_empty() || self.project_names.contains(&project.name)
    }
}
/// Keeps only tests whose module name is in `module_names`;
/// an empty list disables the filter.
pub struct ModuleFilter<'a> {
    module_names: &'a [String],
}
impl Filter for ModuleFilter<'_> {
    fn filter(&self, _project: &ProjectConfig, info: &TestInfo) -> bool {
        self.module_names.is_empty() || self.module_names.contains(&info.module)
    }
}
/// Keeps only tests whose `module::name` is in `test_names`;
/// an empty list disables the filter.
pub struct TestNameFilter<'a> {
    test_names: &'a [String],
}
impl Filter for TestNameFilter<'_> {
    fn filter(&self, _project: &ProjectConfig, info: &TestInfo) -> bool {
        if self.test_names.is_empty() {
            return true;
        }
        // Build the full name once rather than per candidate.
        let full_name = info.full_name();
        self.test_names.contains(&full_name)
    }
}
/// Excludes tests listed in each project's `test_ignore` configuration.
pub struct TestIgnoreFilter {
    // project name -> fully-qualified test names to skip
    test_ignores: HashMap<String, Vec<String>>,
}
impl Default for TestIgnoreFilter {
    /// Builds the ignore map from the global tanu configuration.
    fn default() -> TestIgnoreFilter {
        let test_ignores = get_tanu_config()
            .projects
            .iter()
            .map(|proj| (proj.name.clone(), proj.test_ignore.clone()))
            .collect();
        TestIgnoreFilter { test_ignores }
    }
}
impl Filter for TestIgnoreFilter {
    fn filter(&self, project: &ProjectConfig, info: &TestInfo) -> bool {
        match self.test_ignores.get(&project.name) {
            // Projects with no ignore list keep everything.
            None => true,
            Some(ignored) => {
                let full_name = info.full_name();
                !ignored.iter().any(|name| name == &full_name)
            }
        }
    }
}
/// Orchestrates test execution: holds the config, options, registered test
/// cases, and the reporters that consume the event stream.
#[derive(Default)]
pub struct Runner {
    cfg: Config,
    options: Options,
    test_cases: Vec<(Arc<TestInfo>, TestCaseFactory)>,
    reporters: Vec<Box<dyn Reporter + Send>>,
}
impl Runner {
pub fn new() -> Runner {
Runner::with_config(get_tanu_config().clone())
}
pub fn with_config(cfg: Config) -> Runner {
Runner {
cfg,
options: Options::default(),
test_cases: Vec::new(),
reporters: Vec::new(),
}
}
/// Captures all HTTP request/response logs
/// (shorthand for `set_capture_http_mode(CaptureHttpMode::All)`).
pub fn capture_http(&mut self) {
    self.options.capture_http = CaptureHttpMode::All;
}
/// Sets the HTTP capture mode explicitly.
pub fn set_capture_http_mode(&mut self, mode: CaptureHttpMode) {
    self.options.capture_http = mode;
}
/// Enables capturing Rust-side `tracing` logs.
pub fn capture_rust(&mut self) {
    self.options.capture_rust = true;
}
/// Tears down the global event channel when the run finishes, signalling
/// background consumers to stop.
pub fn terminate_channel(&mut self) {
    self.options.terminate_channel = true;
}
/// Registers a reporter that will consume the event stream during `run`.
pub fn add_reporter(&mut self, reporter: impl Reporter + 'static + Send) {
    self.reporters.push(Box::new(reporter));
}
/// Registers an already-boxed reporter.
pub fn add_boxed_reporter(&mut self, reporter: Box<dyn Reporter + 'static + Send>) {
    self.reporters.push(reporter);
}
/// Registers a test case.
///
/// * `serial_group` — tests in the same group never run concurrently.
/// * `line` — source line, used to order tests inside an ordered group.
/// * `ordered` — run in line order within its serial group.
/// * `factory` — produces a fresh future per attempt (enables retries).
pub fn add_test(
    &mut self,
    name: &str,
    module: &str,
    serial_group: Option<&str>,
    line: u32,
    ordered: bool,
    factory: TestCaseFactory,
) {
    self.test_cases.push((
        Arc::new(TestInfo {
            name: name.into(),
            module: module.into(),
            serial_group: serial_group.map(|s| s.to_string()),
            line,
            ordered,
        }),
        factory,
    ));
}
/// Caps the number of concurrently running tests.
pub fn set_concurrency(&mut self, concurrency: usize) {
    self.options.concurrency = Some(concurrency);
}
/// Disables masking of sensitive values in captured logs.
pub fn show_sensitive(&mut self) {
    self.options.mask_sensitive = false;
}
/// Enables/disables cancelling remaining tests after the first failure.
pub fn set_fail_fast(&mut self, fail_fast: bool) {
    self.options.fail_fast = fail_fast;
}
/// Executes every registered test case against every configured project,
/// streaming progress events to the subscribed reporters.
///
/// `project_names` / `module_names` / `test_names` are allow-list filters;
/// an empty slice disables that filter. Ordered tests run sequentially per
/// (project, serial group) in line order; all other tests run concurrently,
/// bounded by `Options::concurrency`. Returns an error when any test fails
/// or panics, or when channel/reporter plumbing breaks.
///
/// Fix vs. previous revision: the "created handles" debug log interpolated
/// the elapsed seconds where its message promised a test-case count; it now
/// logs both the count and the prep time.
#[allow(clippy::too_many_lines)]
pub async fn run(
    &mut self,
    project_names: &[String],
    module_names: &[String],
    test_names: &[String],
) -> eyre::Result<()> {
    crate::masking::set_mask_sensitive(self.options.mask_sensitive);
    if self.options.capture_rust {
        tracing_subscriber::fmt::init();
    }
    // Spawn reporters first, then wait on the barrier so every reporter has
    // subscribed before any test event is published.
    let reporters = std::mem::take(&mut self.reporters);
    setup_reporter_barrier(reporters.len())?;
    let reporter_handles: Vec<_> = reporters
        .into_iter()
        .map(|mut reporter| tokio::spawn(async move { reporter.run().await }))
        .collect();
    wait_reporter_barrier().await;
    let project_filter = ProjectFilter { project_names };
    let module_filter = ModuleFilter { module_names };
    let test_name_filter = TestNameFilter { test_names };
    let test_ignore_filter = TestIgnoreFilter::default();
    let start = std::time::Instant::now();
    let fail_fast = self.options.fail_fast;
    // Shared flag checked by every pending task; set on fail-fast.
    let cancelled = Arc::new(AtomicBool::new(false));
    let handles: FuturesUnordered<_> = {
        let concurrency = self.options.concurrency;
        let semaphore = Arc::new(tokio::sync::Semaphore::new(
            concurrency.unwrap_or(tokio::sync::Semaphore::MAX_PERMITS),
        ));
        let worker_ids = Arc::new(WorkerIds::new(concurrency));
        // Lazily-created per-group mutexes serializing same-group tests.
        let serial_groups: Arc<
            tokio::sync::RwLock<std::collections::HashMap<String, Arc<tokio::sync::Mutex<()>>>>,
        > = Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new()));
        let projects = self.cfg.projects.clone();
        let projects = if projects.is_empty() {
            vec![Arc::new(ProjectConfig {
                name: "default".into(),
                ..Default::default()
            })]
        } else {
            projects
        };
        // Every test runs once per project; then apply the allow-list and
        // ignore filters.
        let mut all_tests: Vec<_> = self
            .test_cases
            .iter()
            .cartesian_product(projects)
            .map(|((info, factory), project)| (project, Arc::clone(info), factory.clone()))
            .filter(move |(project, info, _)| test_name_filter.filter(project, info))
            .filter(move |(project, info, _)| module_filter.filter(project, info))
            .filter(move |(project, info, _)| project_filter.filter(project, info))
            .filter(move |(project, info, _)| test_ignore_filter.filter(project, info))
            .collect();
        let (mut ordered_tests, non_ordered_tests): (Vec<_>, Vec<_>) =
            all_tests.drain(..).partition(|(_, info, _)| info.ordered);
        ordered_tests
            .sort_by_key(|(_project, info, _factory)| (info.serial_group.clone(), info.line));
        // Bucket ordered tests by "project::serial_group".
        let mut ordered_groups: std::collections::HashMap<String, Vec<_>> =
            std::collections::HashMap::new();
        for (project, info, factory) in ordered_tests {
            let key = format!(
                "{}::{}",
                project.name,
                info.serial_group.as_deref().unwrap_or("")
            );
            ordered_groups
                .entry(key)
                .or_default()
                .push((project, info, factory));
        }
        // One task per ordered group; its tests run sequentially in order.
        let ordered_handles = ordered_groups.into_iter().map(|(group_key, tests)| {
            let semaphore = semaphore.clone();
            let worker_ids = worker_ids.clone();
            let serial_groups = serial_groups.clone();
            let cancelled = cancelled.clone();
            tokio::spawn(async move {
                let serial_mutex = {
                    let mut write_lock = serial_groups.write().await;
                    write_lock
                        .entry(group_key.clone())
                        .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
                        .clone()
                };
                let mut group_failed = false;
                let mut group_error: Option<eyre::Report> = None;
                for (project, info, factory) in tests {
                    if cancelled.load(Ordering::Relaxed) {
                        break;
                    }
                    let _permit = semaphore
                        .acquire()
                        .await
                        .map_err(|e| eyre::eyre!("failed to acquire semaphore: {e}"));
                    // A closed semaphore skips the test rather than aborting
                    // the whole group.
                    if _permit.is_err() {
                        continue;
                    }
                    let worker_id = worker_ids.acquire();
                    let result = execute_test(
                        project,
                        info,
                        factory,
                        Some(serial_mutex.clone()),
                        worker_id,
                    )
                    .await;
                    worker_ids.release(worker_id);
                    match result {
                        Ok(test) => {
                            if test.result.is_err() {
                                group_failed = true;
                            }
                        }
                        Err(e) => {
                            group_failed = true;
                            // Keep only the first infrastructure error.
                            if group_error.is_none() {
                                group_error = Some(e);
                            }
                        }
                    }
                }
                if group_failed {
                    if let Some(e) = group_error {
                        return Err(e);
                    }
                    eyre::bail!("one or more tests failed");
                }
                eyre::Ok(())
            })
        });
        // One task per non-ordered test; serial groups still share a mutex.
        let non_ordered_handles =
            non_ordered_tests
                .into_iter()
                .map(|(project, info, factory)| {
                    let semaphore = semaphore.clone();
                    let worker_ids = worker_ids.clone();
                    let serial_groups = serial_groups.clone();
                    let cancelled = cancelled.clone();
                    tokio::spawn(async move {
                        if cancelled.load(Ordering::Relaxed) {
                            return Ok(());
                        }
                        // Look up (or lazily create) the group mutex; take the
                        // write lock only on the miss path.
                        let serial_mutex = match &info.serial_group {
                            Some(group_name) => {
                                let key = format!("{}::{}", project.name, group_name);
                                let read_lock = serial_groups.read().await;
                                if let Some(mutex) = read_lock.get(&key) {
                                    Some(Arc::clone(mutex))
                                } else {
                                    drop(read_lock);
                                    let mut write_lock = serial_groups.write().await;
                                    Some(
                                        write_lock
                                            .entry(key)
                                            .or_insert_with(|| {
                                                Arc::new(tokio::sync::Mutex::new(()))
                                            })
                                            .clone(),
                                    )
                                }
                            }
                            None => None,
                        };
                        let _permit = semaphore
                            .acquire()
                            .await
                            .map_err(|e| eyre::eyre!("failed to acquire semaphore: {e}"))?;
                        let worker_id = worker_ids.acquire();
                        let result = execute_test(
                            project,
                            info,
                            factory,
                            serial_mutex.clone(),
                            worker_id,
                        )
                        .await
                        .and_then(|test| {
                            let is_err = test.result.is_err();
                            eyre::ensure!(!is_err);
                            eyre::Ok(())
                        });
                        worker_ids.release(worker_id);
                        result
                    })
                });
        let all_handles = FuturesUnordered::new();
        for handle in ordered_handles {
            all_handles.push(handle);
        }
        for handle in non_ordered_handles {
            all_handles.push(handle);
        }
        all_handles
    };
    let test_prep_time = start.elapsed();
    // BUGFIX: previously this message interpolated the elapsed seconds
    // where it promised a test-case count; log both values.
    debug!(
        "created handles for {} test cases in {:.3}s",
        handles.len(),
        test_prep_time.as_secs_f32()
    );
    let mut has_any_error = false;
    let total_tests = handles.len();
    let options = self.options.clone();
    let runner = async move {
        let mut handles = handles;
        let mut failed_tests = 0;
        let mut processed_tests = 0;
        // Drain task results as they complete; on fail-fast, flag
        // cancellation and stop waiting for the rest.
        while let Some(result) = handles.next().await {
            processed_tests += 1;
            match result {
                Ok(res) => {
                    if let Err(e) = res {
                        debug!("test case failed: {e:#}");
                        has_any_error = true;
                        failed_tests += 1;
                        if fail_fast {
                            cancelled.store(true, Ordering::Relaxed);
                            break;
                        }
                    }
                }
                Err(e) => {
                    // JoinError: only panics count as failures; cancelled
                    // tasks are ignored here.
                    if e.is_panic() {
                        error!("{e}");
                        has_any_error = true;
                        failed_tests += 1;
                        if fail_fast {
                            cancelled.store(true, Ordering::Relaxed);
                            break;
                        }
                    }
                }
            }
        }
        if total_tests == 0 {
            console::Term::stdout().write_line("no test cases found")?;
        }
        let skipped_tests = total_tests - processed_tests;
        let passed_tests = total_tests - failed_tests - skipped_tests;
        let total_time = start.elapsed();
        let summary = TestSummary {
            total_tests,
            passed_tests,
            failed_tests,
            skipped_tests,
            total_time,
            test_prep_time,
        };
        // Summary is run-level, so the identity fields stay empty.
        let summary_event = Event {
            project: "".to_string(),
            module: "".to_string(),
            test: "".to_string(),
            body: EventBody::Summary(summary),
        };
        if let Ok(guard) = CHANNEL.lock() {
            if let Some((tx, _)) = guard.as_ref() {
                let _ = tx.send(summary_event);
            }
        }
        debug!("all test finished. sending stop signal to the background tasks.");
        if options.terminate_channel {
            let Ok(mut guard) = CHANNEL.lock() else {
                eyre::bail!("failed to acquire runner channel lock");
            };
            // Dropping the sender closes the channel for all subscribers.
            guard.take();
        }
        if has_any_error {
            eyre::bail!("one or more tests failed");
        }
        eyre::Ok(())
    };
    let runner_result = runner.await;
    // Let reporters drain and finish before reporting the run outcome.
    for handle in reporter_handles {
        match handle.await {
            Ok(Ok(())) => {}
            Ok(Err(e)) => error!("reporter failed: {e:#}"),
            Err(e) => error!("reporter task panicked: {e:#}"),
        }
    }
    clear_reporter_barrier();
    debug!("runner stopped");
    runner_result
}
/// Lists the metadata of all registered test cases, in registration order.
pub fn list(&self) -> Vec<&TestInfo> {
    let mut infos = Vec::with_capacity(self.test_cases.len());
    for (info, _factory) in &self.test_cases {
        infos.push(info.as_ref());
    }
    infos
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::config::RetryConfig;
use crate::ProjectConfig;
use std::sync::Arc;
// Minimal config: a single "default" project with no retries configured.
fn create_config() -> Config {
    Config {
        projects: vec![Arc::new(ProjectConfig {
            name: "default".into(),
            ..Default::default()
        })],
        ..Default::default()
    }
}
// Same as `create_config`, but the project allows exactly one retry.
fn create_config_with_retry() -> Config {
    Config {
        projects: vec![Arc::new(ProjectConfig {
            name: "default".into(),
            retry: RetryConfig {
                count: Some(1),
                ..Default::default()
            },
            ..Default::default()
        })],
        ..Default::default()
    }
}
// With no retry configured, a single 500 response fails the run and the
// second (200) mock must never be hit.
#[tokio::test]
async fn runner_fail_because_no_retry_configured() -> eyre::Result<()> {
    let mut server = mockito::Server::new_async().await;
    let m1 = server
        .mock("GET", "/")
        .with_status(500)
        .expect(1)
        .create_async()
        .await;
    let m2 = server
        .mock("GET", "/")
        .with_status(200)
        .expect(0)
        .create_async()
        .await;
    let factory: TestCaseFactory = Arc::new(move || {
        let url = server.url();
        Box::pin(async move {
            let client = crate::http::Client::new();
            let res = client.get(&url).send().await?;
            if res.status().is_success() {
                Ok(())
            } else {
                eyre::bail!("request failed")
            }
        })
    });
    // Keep a receiver alive so publishing does not fail.
    let _runner_rx = subscribe()?;
    let mut runner = Runner::with_config(create_config());
    runner.add_test("retry_test", "module", None, 0, false, factory);
    let result = runner.run(&[], &[], &[]).await;
    m1.assert_async().await;
    m2.assert_async().await;
    assert!(result.is_err());
    Ok(())
}
// With one retry configured, a 500 followed by a 200 yields a passing run,
// and both mocks are hit exactly once.
#[tokio::test]
async fn runner_retry_successful_after_failure() -> eyre::Result<()> {
    let mut server = mockito::Server::new_async().await;
    let m1 = server
        .mock("GET", "/")
        .with_status(500)
        .expect(1)
        .create_async()
        .await;
    let m2 = server
        .mock("GET", "/")
        .with_status(200)
        .expect(1)
        .create_async()
        .await;
    let factory: TestCaseFactory = Arc::new(move || {
        let url = server.url();
        Box::pin(async move {
            let client = crate::http::Client::new();
            let res = client.get(&url).send().await?;
            if res.status().is_success() {
                Ok(())
            } else {
                eyre::bail!("request failed")
            }
        })
    });
    // Keep a receiver alive so publishing does not fail.
    let _runner_rx = subscribe()?;
    let mut runner = Runner::with_config(create_config_with_retry());
    runner.add_test("retry_test", "module", None, 0, false, factory);
    let result = runner.run(&[], &[], &[]).await;
    m1.assert_async().await;
    m2.assert_async().await;
    assert!(result.is_ok());
    Ok(())
}
// Demonstrates the problem `scope_current` solves: a plain `tokio::spawn`
// loses the task-local context, so `get_config()` panics inside the task.
#[tokio::test]
async fn spawned_task_panics_without_task_local_context() {
    let project = Arc::new(ProjectConfig {
        name: "default".to_string(),
        ..Default::default()
    });
    let test_info = Arc::new(TestInfo {
        module: "mod".to_string(),
        name: "test".to_string(),
        serial_group: None,
        line: 0,
        ordered: false,
    });
    crate::config::PROJECT
        .scope(
            project,
            TEST_INFO.scope(test_info, async move {
                let handle = tokio::spawn(async move {
                    let _ = crate::config::get_config();
                });
                let join_err = handle.await.expect_err("spawned task should panic");
                assert!(join_err.is_panic());
            }),
        )
        .await;
}
// Wrapping the spawned future in `scope_current` carries the task-local
// context across the spawn, so both accessors succeed inside the task.
#[tokio::test]
async fn scope_current_propagates_task_local_context_into_spawned_task() {
    let project = Arc::new(ProjectConfig {
        name: "default".to_string(),
        ..Default::default()
    });
    let test_info = Arc::new(TestInfo {
        module: "mod".to_string(),
        name: "test".to_string(),
        serial_group: None,
        line: 0,
        ordered: false,
    });
    crate::config::PROJECT
        .scope(
            project,
            TEST_INFO.scope(test_info, async move {
                let handle = tokio::spawn(super::scope_current(async move {
                    let _ = crate::config::get_config();
                    let _ = super::get_test_info();
                }));
                handle.await.expect("spawned task should not panic");
            }),
        )
        .await;
}
// With masking enabled, sensitive query parameters (access_token) are
// replaced with "*****" in captured HTTP logs while benign ones survive.
// Serialized because masking state is global.
#[tokio::test]
#[serial_test::serial]
async fn masking_masks_sensitive_query_params_in_http_logs() -> eyre::Result<()> {
    use crate::masking;
    masking::set_mask_sensitive(true);
    let mut server = mockito::Server::new_async().await;
    let _mock = server
        .mock("GET", mockito::Matcher::Any)
        .with_status(200)
        .create_async()
        .await;
    let factory: TestCaseFactory = Arc::new(move || {
        let url = server.url();
        Box::pin(async move {
            let client = crate::http::Client::new();
            let _res = client
                .get(format!("{url}?access_token=secret_token_123&user=john"))
                .send()
                .await?;
            Ok(())
        })
    });
    let mut rx = subscribe()?;
    let mut runner = Runner::with_config(create_config());
    runner.add_test(
        "masking_query_test",
        "masking_module",
        None,
        0,
        false,
        factory,
    );
    runner.run(&[], &[], &[]).await?;
    // Scan buffered events for this test's HTTP call log.
    let mut found_http_event = false;
    while let Ok(event) = rx.try_recv() {
        if event.test != "masking_query_test" {
            continue;
        }
        if let EventBody::Call(CallLog::Http(log)) = event.body {
            found_http_event = true;
            let url_str = log.request.url.to_string();
            assert!(
                url_str.contains("access_token=*****"),
                "access_token should be masked, got: {url_str}"
            );
            assert!(
                url_str.contains("user=john"),
                "user should not be masked, got: {url_str}"
            );
        }
    }
    assert!(found_http_event, "Should have received HTTP event");
    Ok(())
}
// With masking enabled, sensitive headers (authorization, x-api-key) are
// replaced with "*****" in captured HTTP logs; content-type is untouched.
// Serialized because masking state is global.
#[tokio::test]
#[serial_test::serial]
async fn masking_masks_sensitive_headers_in_http_logs() -> eyre::Result<()> {
    use crate::masking;
    masking::set_mask_sensitive(true);
    let mut server = mockito::Server::new_async().await;
    let _mock = server
        .mock("GET", "/")
        .with_status(200)
        .create_async()
        .await;
    let factory: TestCaseFactory = Arc::new(move || {
        let url = server.url();
        Box::pin(async move {
            let client = crate::http::Client::new();
            let _res = client
                .get(&url)
                .header("authorization", "Bearer secret_bearer_token")
                .header("x-api-key", "my_secret_api_key")
                .header("content-type", "application/json")
                .send()
                .await?;
            Ok(())
        })
    });
    let mut rx = subscribe()?;
    let mut runner = Runner::with_config(create_config());
    runner.add_test(
        "masking_headers_test",
        "masking_module",
        None,
        0,
        false,
        factory,
    );
    runner.run(&[], &[], &[]).await?;
    // Scan buffered events for this test's HTTP call log.
    let mut found_http_event = false;
    while let Ok(event) = rx.try_recv() {
        if event.test != "masking_headers_test" {
            continue;
        }
        if let EventBody::Call(CallLog::Http(log)) = event.body {
            found_http_event = true;
            if let Some(auth) = log.request.headers.get("authorization") {
                assert_eq!(
                    auth.to_str().unwrap(),
                    "*****",
                    "authorization header should be masked"
                );
            }
            if let Some(api_key) = log.request.headers.get("x-api-key") {
                assert_eq!(
                    api_key.to_str().unwrap(),
                    "*****",
                    "x-api-key header should be masked"
                );
            }
            if let Some(content_type) = log.request.headers.get("content-type") {
                assert_eq!(
                    content_type.to_str().unwrap(),
                    "application/json",
                    "content-type header should not be masked"
                );
            }
        }
    }
    assert!(found_http_event, "Should have received HTTP event");
    Ok(())
}
// `show_sensitive()` must override the global masking flag for this run:
// both the query token and the authorization header appear verbatim.
// Serialized because masking state is global.
#[tokio::test]
#[serial_test::serial]
async fn masking_show_sensitive_disables_masking_in_http_logs() -> eyre::Result<()> {
    use crate::masking;
    masking::set_mask_sensitive(true);
    let mut server = mockito::Server::new_async().await;
    let _mock = server
        .mock("GET", "/")
        .with_status(200)
        .create_async()
        .await;
    let factory: TestCaseFactory = Arc::new(move || {
        let url = server.url();
        Box::pin(async move {
            let client = crate::http::Client::new();
            let _res = client
                .get(format!("{url}?access_token=secret_token_123"))
                .header("authorization", "Bearer secret_bearer_token")
                .send()
                .await?;
            Ok(())
        })
    });
    let mut rx = subscribe()?;
    let mut runner = Runner::with_config(create_config());
    runner.capture_http();
    runner.show_sensitive();
    runner.add_test(
        "show_sensitive_test",
        "masking_module",
        None,
        0,
        false,
        factory,
    );
    runner.run(&[], &[], &[]).await?;
    // Scan buffered events for this test's HTTP call log.
    let mut found_http_event = false;
    while let Ok(event) = rx.try_recv() {
        if event.test != "show_sensitive_test" {
            continue;
        }
        if let EventBody::Call(CallLog::Http(log)) = event.body {
            found_http_event = true;
            let url_str = log.request.url.to_string();
            assert!(
                url_str.contains("access_token=secret_token_123"),
                "access_token should not be masked when show_sensitive is enabled"
            );
            if let Some(auth) = log.request.headers.get("authorization") {
                assert_eq!(
                    auth.to_str().unwrap(),
                    "Bearer secret_bearer_token",
                    "authorization header should not be masked when show_sensitive is enabled"
                );
            }
        }
    }
    assert!(found_http_event, "Should have received HTTP event");
    Ok(())
}
// Factory for a test body that always passes immediately.
fn passing_factory() -> TestCaseFactory {
    Arc::new(|| Box::pin(async { Ok(()) }))
}
// Factory for a test body that always fails immediately.
fn failing_factory() -> TestCaseFactory {
    Arc::new(|| Box::pin(async { eyre::bail!("intentional failure") }))
}
// With concurrency 1 and fail-fast enabled, the first failure should
// cancel the queued tests, so the summary reports skipped tests.
#[tokio::test]
#[serial_test::serial]
async fn runner_fail_fast_skips_remaining_tests() -> eyre::Result<()> {
    let mut rx = subscribe()?;
    let mut runner = Runner::with_config(create_config());
    runner.set_concurrency(1);
    runner.set_fail_fast(true);
    runner.add_test("ff_fail", "module", None, 0, false, failing_factory());
    runner.add_test("ff_pass1", "module", None, 1, false, passing_factory());
    runner.add_test("ff_pass2", "module", None, 2, false, passing_factory());
    let result = runner.run(&[], &[], &[]).await;
    assert!(result.is_err());
    // Find the final Summary event among the buffered events.
    let mut summary = None;
    while let Ok(event) = rx.try_recv() {
        if let EventBody::Summary(s) = event.body {
            summary = Some(s);
        }
    }
    let summary = summary.expect("should have received Summary event");
    assert!(
        summary.failed_tests >= 1,
        "should have at least one failure"
    );
    assert!(
        summary.skipped_tests >= 1,
        "fail-fast should have skipped remaining tests"
    );
    Ok(())
}
// Without fail-fast, all three tests run to completion: exactly one
// failure, two passes, nothing skipped.
#[tokio::test]
#[serial_test::serial]
async fn runner_without_fail_fast_runs_all_tests() -> eyre::Result<()> {
    let mut rx = subscribe()?;
    let mut runner = Runner::with_config(create_config());
    runner.set_concurrency(1);
    runner.add_test("noff_fail", "module", None, 0, false, failing_factory());
    runner.add_test("noff_pass1", "module", None, 1, false, passing_factory());
    runner.add_test("noff_pass2", "module", None, 2, false, passing_factory());
    let result = runner.run(&[], &[], &[]).await;
    assert!(result.is_err());
    // Find the final Summary event among the buffered events.
    let mut summary = None;
    while let Ok(event) = rx.try_recv() {
        if let EventBody::Summary(s) = event.body {
            summary = Some(s);
        }
    }
    let summary = summary.expect("should have received Summary event");
    assert_eq!(summary.failed_tests, 1, "should have exactly one failure");
    assert_eq!(summary.passed_tests, 2, "should have two passed tests");
    assert_eq!(summary.skipped_tests, 0, "should have no skipped tests");
    Ok(())
}
// Even with CaptureHttpMode::OnFailure, Call events are published for both
// passing and failing tests (filtering happens downstream of publication).
#[tokio::test]
#[serial_test::serial]
async fn capture_http_events_published_for_all_tests_regardless_of_mode() -> eyre::Result<()> {
    let mut server = mockito::Server::new_async().await;
    let _mock = server
        .mock("GET", mockito::Matcher::Any)
        .with_status(200)
        .create_async()
        .await;
    // Test body that performs one HTTP call and passes.
    let make_http_factory = |url: String| -> TestCaseFactory {
        Arc::new(move || {
            let url = url.clone();
            Box::pin(async move {
                let client = crate::http::Client::new();
                client.get(&url).send().await?;
                Ok(())
            })
        })
    };
    // Test body that performs one HTTP call and then fails.
    let failing_http_factory = |url: String| -> TestCaseFactory {
        Arc::new(move || {
            let url = url.clone();
            Box::pin(async move {
                let client = crate::http::Client::new();
                client.get(&url).send().await?;
                eyre::bail!("intentional failure after http call");
            })
        })
    };
    let url = server.url();
    let mut rx = subscribe()?;
    let mut runner = Runner::with_config(create_config());
    runner.set_capture_http_mode(CaptureHttpMode::OnFailure);
    runner.add_test(
        "ch_pass",
        "ch_module",
        None,
        0,
        false,
        make_http_factory(url.clone()),
    );
    runner.add_test(
        "ch_fail",
        "ch_module",
        None,
        1,
        false,
        failing_http_factory(url.clone()),
    );
    let _ = runner.run(&[], &[], &[]).await;
    let mut pass_has_call = false;
    let mut fail_has_call = false;
    while let Ok(event) = rx.try_recv() {
        if let EventBody::Call(CallLog::Http(_)) = &event.body {
            match event.test.as_str() {
                "ch_pass" => pass_has_call = true,
                "ch_fail" => fail_has_call = true,
                _ => {}
            }
        }
    }
    assert!(
        pass_has_call,
        "passing test should still publish HTTP Call event"
    );
    assert!(
        fail_has_call,
        "failing test should still publish HTTP Call event"
    );
    Ok(())
}
// The capture-mode setters store exactly the mode they are given, and
// `capture_http()` is shorthand for `All`.
#[test]
fn set_capture_http_mode_stores_mode() {
    let mut runner = Runner::new();
    assert_eq!(runner.options.capture_http, CaptureHttpMode::Off);
    runner.capture_http();
    assert_eq!(runner.options.capture_http, CaptureHttpMode::All);
    runner.set_capture_http_mode(CaptureHttpMode::OnFailure);
    assert_eq!(runner.options.capture_http, CaptureHttpMode::OnFailure);
    runner.set_capture_http_mode(CaptureHttpMode::Off);
    assert_eq!(runner.options.capture_http, CaptureHttpMode::Off);
}
}