#![deny(missing_docs, missing_debug_implementations)]
use std::collections::HashMap;
use std::env::VarError;
use std::fmt::Debug;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, TryLockError};
use std::time::{Duration, Instant};
use std::{env, thread};
/// A cloneable, thread-safe handle to a user-supplied closure.
///
/// Clones share the same underlying closure through the inner `Arc`.
#[derive(Clone)]
struct SyncCallback(Arc<dyn Fn() + Send + Sync>);

impl Debug for SyncCallback {
    /// Closures cannot be printed, so a fixed placeholder is written instead.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "SyncCallback()")
    }
}

impl PartialEq for SyncCallback {
    /// Two callbacks compare equal only when they point at the same allocation.
    #[allow(clippy::vtable_address_comparisons)]
    fn eq(&self, other: &Self) -> bool {
        let Self(lhs) = self;
        let Self(rhs) = other;
        Arc::ptr_eq(lhs, rhs)
    }
}

impl SyncCallback {
    /// Wraps `f` so it can be shared and invoked from any thread.
    fn new(f: impl Fn() + Send + Sync + 'static) -> SyncCallback {
        Self(Arc::new(f))
    }

    /// Invokes the wrapped closure once.
    fn run(&self) {
        (self.0)();
    }
}
/// The kinds of work a fail point can perform when it is evaluated.
#[derive(Clone, Debug, PartialEq)]
enum Task {
/// Do nothing.
Off,
/// Make the fail point evaluation yield the optional string to the caller.
Return(Option<String>),
/// Sleep the current thread for the given number of milliseconds.
Sleep(u64),
/// Panic, using the optional string as the panic message.
Panic(Option<String>),
/// Log the optional string at info level.
Print(Option<String>),
/// Block the current thread until the fail point's actions are replaced.
Pause,
/// Yield the current thread's time slice.
Yield,
/// Busy-wait for the given number of milliseconds.
Delay(u64),
/// Run a user-supplied closure.
Callback(SyncCallback),
}
/// A task together with the conditions under which it fires.
#[derive(Debug)]
struct Action {
// What to do when the action fires.
task: Task,
// Probability of firing on each evaluation; values of 1.0 or more skip the random draw.
freq: f32,
// Remaining number of times the action may fire; `None` means no limit.
count: Option<AtomicUsize>,
}
impl PartialEq for Action {
fn eq(&self, hs: &Action) -> bool {
if self.task != hs.task || self.freq != hs.freq {
return false;
}
if let Some(ref lhs) = self.count {
if let Some(ref rhs) = hs.count {
return lhs.load(Ordering::Relaxed) == rhs.load(Ordering::Relaxed);
}
} else if hs.count.is_none() {
return true;
}
false
}
}
impl Action {
fn new(task: Task, freq: f32, max_cnt: Option<usize>) -> Action {
Action {
task,
freq,
count: max_cnt.map(AtomicUsize::new),
}
}
fn from_callback(f: impl Fn() + Send + Sync + 'static) -> Action {
let task = Task::Callback(SyncCallback::new(f));
Action {
task,
freq: 1.0,
count: None,
}
}
fn get_task(&self) -> Option<Task> {
use rand::Rng;
if let Some(ref cnt) = self.count {
let c = cnt.load(Ordering::Acquire);
if c == 0 {
return None;
}
}
if self.freq < 1f32 && !rand::thread_rng().gen_bool(f64::from(self.freq)) {
return None;
}
if let Some(ref ref_cnt) = self.count {
let mut cnt = ref_cnt.load(Ordering::Acquire);
loop {
if cnt == 0 {
return None;
}
let new_cnt = cnt - 1;
match ref_cnt.compare_exchange_weak(
cnt,
new_cnt,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break,
Err(c) => cnt = c,
}
}
}
Some(self.task.clone())
}
}
/// Splits `s` at the first occurrence of `pattern`.
///
/// Returns the text before the separator and, if the separator was found,
/// the text after it. Without a separator the whole input is returned with `None`.
fn partition(s: &str, pattern: char) -> (&str, Option<&str>) {
    match s.find(pattern) {
        Some(at) => {
            let (head, rest) = s.split_at(at);
            // `rest` still starts with the separator; skip its UTF-8 bytes.
            (head, Some(&rest[pattern.len_utf8()..]))
        }
        None => (s, None),
    }
}
impl FromStr for Action {
    type Err = String;

    /// Parses an action of the form `[<pct>%][<cnt>*]<task>[(<arg>)]`.
    ///
    /// * `pct` is a percentage stored as a fraction (`25%` becomes `0.25`).
    /// * `cnt` is the maximum number of times the action may fire.
    /// * `task` is one of `off`, `return`, `sleep`, `panic`, `print`,
    ///   `pause`, `yield` or `delay`; `sleep` and `delay` require an integer
    ///   millisecond argument.
    ///
    /// # Errors
    ///
    /// Returns a message when the parentheses are unbalanced, a number cannot
    /// be parsed, the percentage is negative or NaN, a required argument is
    /// missing, or the task name is not recognized.
    fn from_str(s: &str) -> Result<Action, String> {
        let mut remain = s.trim();
        let mut args = None;

        // Split the argument off first so that '%' or '*' inside the
        // parentheses are never mistaken for frequency/count separators.
        let (first, second) = partition(remain, '(');
        if let Some(second) = second {
            remain = first;
            if !second.ends_with(')') {
                return Err("parentheses do not match".to_owned());
            }
            args = Some(&second[..second.len() - 1]);
        }

        let mut frequency = 1f32;
        let (first, second) = partition(remain, '%');
        if let Some(second) = second {
            remain = second;
            let percent = first
                .parse::<f32>()
                .map_err(|e| format!("failed to parse frequency: {}", e))?;
            // A negative probability would panic inside the random draw when
            // the action is evaluated, and NaN would bypass the draw entirely;
            // report both at configuration time instead.
            if percent.is_nan() || percent < 0.0 {
                return Err(format!(
                    "frequency must be a non-negative number, got {:?}",
                    first
                ));
            }
            frequency = percent / 100.0;
        }

        let mut max_cnt = None;
        let (first, second) = partition(remain, '*');
        if let Some(second) = second {
            remain = second;
            let cnt = first
                .parse::<usize>()
                .map_err(|e| format!("failed to parse count: {}", e))?;
            max_cnt = Some(cnt);
        }

        // Shared by `sleep` and `delay`: the argument is mandatory and must
        // be an integer number of milliseconds.
        let parse_timeout = || -> Result<u64, String> {
            let timeout_str = args.ok_or_else(|| "sleep require timeout".to_owned())?;
            timeout_str
                .parse::<u64>()
                .map_err(|e| format!("failed to parse timeout: {}", e))
        };

        let task = match remain {
            "off" => Task::Off,
            "return" => Task::Return(args.map(str::to_owned)),
            "sleep" => Task::Sleep(parse_timeout()?),
            "panic" => Task::Panic(args.map(str::to_owned)),
            "print" => Task::Print(args.map(str::to_owned)),
            "pause" => Task::Pause,
            "yield" => Task::Yield,
            "delay" => Task::Delay(parse_timeout()?),
            _ => return Err(format!("unrecognized command {:?}", remain)),
        };

        Ok(Action::new(task, frequency, max_cnt))
    }
}
/// Runtime state of a single fail point: its configured actions plus the
/// flag and condition variable used to implement the `pause` task.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::mutex_atomic))]
#[derive(Debug)]
struct FailPoint {
// `true` while a thread is parked by a `pause` task; cleared by `set_actions`.
pause: Mutex<bool>,
// Signalled when `pause` is cleared.
pause_notifier: Condvar,
// Parsed actions, tried in order on every evaluation.
actions: RwLock<Vec<Action>>,
// The textual form of `actions` as supplied to `set_actions`.
actions_str: RwLock<String>,
}
#[cfg_attr(feature = "cargo-clippy", allow(clippy::mutex_atomic))]
impl FailPoint {
/// Creates a fail point with no actions and the pause flag cleared.
fn new() -> FailPoint {
FailPoint {
pause: Mutex::new(false),
pause_notifier: Condvar::new(),
actions: RwLock::default(),
actions_str: RwLock::default(),
}
}
/// Replaces the configured actions and their textual form.
///
/// A thread parked by a `pause` task in `eval` still holds the `actions`
/// read lock, so the write lock is only ever tried without blocking. While
/// it is unavailable, the pause flag is cleared and all waiters are
/// notified, then the attempt is repeated.
///
/// Panics if the `actions` lock is poisoned.
fn set_actions(&self, actions_str: &str, actions: Vec<Action>) {
loop {
match self.actions.try_write() {
// Lock is held elsewhere: fall through to wake paused threads, then retry.
Err(TryLockError::WouldBlock) => {}
Ok(mut guard) => {
*guard = actions;
*self.actions_str.write().unwrap() = actions_str.to_string();
return;
}
Err(e) => panic!("unexpected poison: {:?}", e),
}
let mut guard = self.pause.lock().unwrap();
*guard = false;
self.pause_notifier.notify_all();
}
}
/// Evaluates the fail point once.
///
/// The first action whose `get_task` yields a task is executed. Returns
/// `Some(payload)` only for a `Return` task; every other outcome,
/// including "no action fired", returns `None`. `name` is only used in the
/// default panic and print messages.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::option_option))]
fn eval(&self, name: &str) -> Option<Option<String>> {
let task = {
// The read guard lives until the end of this block, so it is still
// held while a `Pause` task waits below.
let actions = self.actions.read().unwrap();
match actions.iter().filter_map(Action::get_task).next() {
Some(Task::Pause) => {
let mut guard = self.pause.lock().unwrap();
*guard = true;
// Re-check the flag after every wakeup; only a cleared flag ends the pause.
loop {
guard = self.pause_notifier.wait(guard).unwrap();
if !*guard {
break;
}
}
return None;
}
Some(t) => t,
None => return None,
}
};
// The read lock has been released; run the selected task.
match task {
Task::Off => {}
Task::Return(s) => return Some(s),
Task::Sleep(t) => thread::sleep(Duration::from_millis(t)),
Task::Panic(msg) => match msg {
Some(ref msg) => panic!("{}", msg),
None => panic!("failpoint {} panic", name),
},
Task::Print(msg) => match msg {
Some(ref msg) => log::info!("{}", msg),
None => log::info!("failpoint {} executed.", name),
},
// `Pause` always returns from the block above.
Task::Pause => unreachable!(),
Task::Yield => thread::yield_now(),
Task::Delay(t) => {
// Spin instead of sleeping until the timeout has elapsed.
let timer = Instant::now();
let timeout = Duration::from_millis(t);
while timer.elapsed() < timeout {}
}
Task::Callback(f) => {
f.run();
}
}
None
}
}
/// Maps fail point names to their shared runtime state.
type Registry = HashMap<String, Arc<FailPoint>>;
/// A lock-protected collection of named fail points.
#[derive(Debug, Default)]
struct FailPointRegistry {
// Name-to-fail-point table behind a read/write lock.
registry: RwLock<Registry>,
}
use once_cell::sync::Lazy;
/// The process-wide fail point registry used by the public API.
static REGISTRY: Lazy<FailPointRegistry> = Lazy::new(FailPointRegistry::default);
/// Mutex around a reference to the global registry; a `FailScenario` holds
/// its guard for as long as the scenario is alive.
static SCENARIO: Lazy<Mutex<&'static FailPointRegistry>> = Lazy::new(|| Mutex::new(&REGISTRY));
/// A scenario in which fail points are configured from the environment.
///
/// The scenario holds the global scenario lock for its whole lifetime and
/// clears the registry both when it is set up and when it is dropped.
#[derive(Debug)]
pub struct FailScenario<'a> {
    scenario_guard: MutexGuard<'a, &'static FailPointRegistry>,
}

impl<'a> FailScenario<'a> {
    /// Starts a scenario.
    ///
    /// Takes the scenario lock (recovering it if poisoned), clears the
    /// registry, then configures fail points from the `FAILPOINTS`
    /// environment variable: a `;`-separated list of `name=actions` entries.
    /// Empty entries are skipped; an unset variable configures nothing.
    ///
    /// # Panics
    ///
    /// Panics if the variable cannot be read as Unicode, an entry has no
    /// `=`, or an entry's actions fail to parse.
    pub fn setup() -> Self {
        let scenario_guard = SCENARIO.lock().unwrap_or_else(|e| e.into_inner());
        let mut registry = scenario_guard.registry.write().unwrap();
        Self::cleanup(&mut registry);

        let failpoints = match env::var("FAILPOINTS") {
            Err(VarError::NotPresent) => return Self { scenario_guard },
            Err(e) => panic!("invalid failpoints: {:?}", e),
            Ok(s) => s,
        };

        let entries = failpoints
            .trim()
            .split(';')
            .map(str::trim)
            .filter(|entry| !entry.is_empty());
        for entry in entries {
            let (name, order) = match partition(entry, '=') {
                (name, Some(order)) => (name, order),
                (_, None) => panic!("invalid failpoint: {:?}", entry),
            };
            if let Err(e) = set(&mut registry, name.to_owned(), order) {
                panic!("unable to configure failpoint \"{}\": {}", name, e);
            }
        }
        Self { scenario_guard }
    }

    /// Ends the scenario; equivalent to dropping it.
    pub fn teardown(self) {
        std::mem::drop(self);
    }

    /// Resets every registered fail point to "no actions" and then empties the registry.
    fn cleanup(registry: &mut std::sync::RwLockWriteGuard<'a, Registry>) {
        registry
            .values()
            .for_each(|point| point.set_actions("", Vec::new()));
        registry.clear();
    }
}

impl<'a> Drop for FailScenario<'a> {
    /// Clears the registry when the scenario goes away.
    fn drop(&mut self) {
        let mut guard = self.scenario_guard.registry.write().unwrap();
        Self::cleanup(&mut guard);
    }
}
/// Returns whether the crate was compiled with the `failpoints` feature.
pub const fn has_failpoints() -> bool {
    const ENABLED: bool = cfg!(feature = "failpoints");
    ENABLED
}
/// Returns every registered fail point as a `(name, actions)` pair, where
/// `actions` is the textual configuration last applied to it.
pub fn list() -> Vec<(String, String)> {
    let registry = REGISTRY.registry.read().unwrap();
    let mut entries = Vec::with_capacity(registry.len());
    for (name, fp) in registry.iter() {
        let owned_name = name.to_string();
        let actions = fp.actions_str.read().unwrap().clone();
        entries.push((owned_name, actions));
    }
    entries
}
/// Evaluates the fail point `name`, mapping a `return` payload through `f`.
///
/// Returns `None` when the fail point is not registered or did not produce
/// a return value.
#[doc(hidden)]
pub fn eval<R, F: FnOnce(Option<String>) -> R>(name: &str, f: F) -> Option<R> {
    // The read guard is a temporary of this statement, so the registry lock
    // is released before the fail point itself runs.
    let point = REGISTRY.registry.read().unwrap().get(name).cloned()?;
    point.eval(name).map(f)
}
pub fn cfg<S: Into<String>>(name: S, actions: &str) -> Result<(), String> {
let mut registry = REGISTRY.registry.write().unwrap();
set(&mut registry, name.into(), actions)
}
/// Configures the fail point `name` to run the closure `f` each time it is
/// evaluated, registering it if needed.
///
/// The fail point's textual configuration is reported as `"callback"`.
/// This function always returns `Ok(())`.
pub fn cfg_callback<S, F>(name: S, f: F) -> Result<(), String>
where
    S: Into<String>,
    F: Fn() + Send + Sync + 'static,
{
    let mut registry = REGISTRY.registry.write().unwrap();
    let point = registry
        .entry(name.into())
        .or_insert_with(|| Arc::new(FailPoint::new()));
    point.set_actions("callback", vec![Action::from_callback(f)]);
    Ok(())
}
/// Removes the fail point `name` from the registry, if present.
///
/// The removed fail point's actions are cleared.
pub fn remove<S: AsRef<str>>(name: S) {
    // Keep the registry write lock held while the removed point is reset.
    let mut registry = REGISTRY.registry.write().unwrap();
    let removed = registry.remove(name.as_ref());
    if let Some(point) = removed {
        point.set_actions("", Vec::new());
    }
}
/// Configures a fail point on creation and removes it again when dropped.
#[derive(Debug)]
pub struct FailGuard(String);

impl Drop for FailGuard {
    /// Removes the guarded fail point from the registry.
    fn drop(&mut self) {
        let Self(name) = self;
        remove(name.as_str());
    }
}

impl FailGuard {
    /// Configures the fail point `name` with `actions` and returns a guard
    /// that removes it on drop.
    ///
    /// # Errors
    ///
    /// Returns a message when `actions` fails to parse; no guard is created.
    pub fn new<S: Into<String>>(name: S, actions: &str) -> Result<FailGuard, String> {
        let name: String = name.into();
        cfg(name.as_str(), actions).map(|()| FailGuard(name))
    }

    /// Configures the fail point `name` to run `f` and returns a guard that
    /// removes it on drop.
    pub fn with_callback<S, F>(name: S, f: F) -> Result<FailGuard, String>
    where
        S: Into<String>,
        F: Fn() + Send + Sync + 'static,
    {
        let name: String = name.into();
        cfg_callback(name.as_str(), f).map(|()| FailGuard(name))
    }
}
fn set(
registry: &mut HashMap<String, Arc<FailPoint>>,
name: String,
actions: &str,
) -> Result<(), String> {
let actions_str = actions;
let actions = actions
.split("->")
.map(Action::from_str)
.collect::<Result<_, _>>()?;
let p = registry
.entry(name)
.or_insert_with(|| Arc::new(FailPoint::new()));
p.set_actions(actions_str, actions);
Ok(())
}
/// Defines a fail point (active form, compiled with the `failpoints` feature).
///
/// * `fail_point!(name)` evaluates the fail point and panics if it is
///   configured with a `return` action.
/// * `fail_point!(name, closure)` evaluates the fail point; when it yields a
///   return payload, the enclosing function returns `closure(payload)`.
/// * `fail_point!(name, cond, closure)` behaves like the two-argument form,
///   but only when `cond` is true.
#[macro_export]
#[cfg(feature = "failpoints")]
macro_rules! fail_point {
($name:expr) => {{
$crate::eval($name, |_| {
panic!("Return is not supported for the fail point \"{}\"", $name);
});
}};
($name:expr, $e:expr) => {{
if let Some(res) = $crate::eval($name, $e) {
return res;
}
}};
($name:expr, $cond:expr, $e:expr) => {{
if $cond {
$crate::fail_point!($name, $e);
}
}};
}
/// Defines a fail point (inactive form, compiled without the `failpoints`
/// feature).
///
/// Every form expands to an empty block; none of the arguments are evaluated.
#[macro_export]
#[cfg(not(feature = "failpoints"))]
macro_rules! fail_point {
($name:expr, $e:expr) => {{}};
($name:expr) => {{}};
($name:expr, $cond:expr, $e:expr) => {{}};
}
// Unit tests for the task implementations, the action parser and the
// environment-driven scenario setup.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::*;
// `has_failpoints` must mirror the `failpoints` feature flag.
#[test]
fn test_has_failpoints() {
assert_eq!(cfg!(feature = "failpoints"), has_failpoints());
}
// `Off` does nothing and yields no return value.
#[test]
fn test_off() {
let point = FailPoint::new();
point.set_actions("", vec![Action::new(Task::Off, 1.0, None)]);
assert!(point.eval("test_fail_point_off").is_none());
}
// `Return` hands its optional payload back to the caller.
#[test]
fn test_return() {
let point = FailPoint::new();
point.set_actions("", vec![Action::new(Task::Return(None), 1.0, None)]);
let res = point.eval("test_fail_point_return");
assert_eq!(res, Some(None));
let ret = Some("test".to_owned());
point.set_actions("", vec![Action::new(Task::Return(ret.clone()), 1.0, None)]);
let res = point.eval("test_fail_point_return");
assert_eq!(res, Some(ret));
}
// `Sleep` blocks for at least the requested duration.
#[test]
fn test_sleep() {
let point = FailPoint::new();
let timer = Instant::now();
point.set_actions("", vec![Action::new(Task::Sleep(1000), 1.0, None)]);
assert!(point.eval("test_fail_point_sleep").is_none());
assert!(timer.elapsed() > Duration::from_millis(1000));
}
// `Panic` without a message still panics.
#[should_panic]
#[test]
fn test_panic() {
let point = FailPoint::new();
point.set_actions("", vec![Action::new(Task::Panic(None), 1.0, None)]);
point.eval("test_fail_point_panic");
}
// `Print` without a message logs the default text at info level.
#[test]
fn test_print() {
// Logger that records every formatted message into a shared buffer.
struct LogCollector(Arc<Mutex<Vec<String>>>);
impl log::Log for LogCollector {
fn enabled(&self, _: &log::Metadata) -> bool {
true
}
fn log(&self, record: &log::Record) {
let mut buf = self.0.lock().unwrap();
buf.push(format!("{}", record.args()));
}
fn flush(&self) {}
}
let buffer = Arc::new(Mutex::new(vec![]));
let collector = LogCollector(buffer.clone());
log::set_max_level(log::LevelFilter::Info);
// Installs the process-wide logger; the `unwrap` fails if one is already set.
log::set_boxed_logger(Box::new(collector)).unwrap();
let point = FailPoint::new();
point.set_actions("", vec![Action::new(Task::Print(None), 1.0, None)]);
assert!(point.eval("test_fail_point_print").is_none());
let msg = buffer.lock().unwrap().pop().unwrap();
assert_eq!(msg, "failpoint test_fail_point_print executed.");
}
// `Pause` blocks the evaluating thread until the actions are replaced.
#[test]
fn test_pause() {
let point = Arc::new(FailPoint::new());
point.set_actions("", vec![Action::new(Task::Pause, 1.0, None)]);
let p = point.clone();
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
assert_eq!(p.eval("test_fail_point_pause"), None);
tx.send(()).unwrap();
});
// Nothing may arrive while the spawned thread is paused.
assert!(rx.recv_timeout(Duration::from_secs(1)).is_err());
// Replacing the actions resumes the paused thread.
point.set_actions("", vec![Action::new(Task::Off, 1.0, None)]);
rx.recv_timeout(Duration::from_secs(1)).unwrap();
}
// `Yield` returns without producing a value.
#[test]
fn test_yield() {
let point = FailPoint::new();
point.set_actions("", vec![Action::new(Task::Yield, 1.0, None)]);
assert!(point.eval("test_fail_point_yield").is_none());
}
// `Delay` busy-waits for at least the requested duration.
#[test]
fn test_delay() {
let point = FailPoint::new();
let timer = Instant::now();
point.set_actions("", vec![Action::new(Task::Delay(1000), 1.0, None)]);
assert!(point.eval("test_fail_point_delay").is_none());
assert!(timer.elapsed() > Duration::from_millis(1000));
}
// An action with frequency 0.8 and count 100 fires exactly 100 times, and
// the number of attempts needed stays within the asserted bounds.
#[test]
fn test_frequency_and_count() {
let point = FailPoint::new();
point.set_actions("", vec![Action::new(Task::Return(None), 0.8, Some(100))]);
let mut count = 0;
let mut times = 0f64;
while count < 100 {
if point.eval("test_fail_point_frequency").is_some() {
count += 1;
}
times += 1f64;
}
assert!(100.0 / 0.9 < times && times < 100.0 / 0.7, "{}", times);
// Once the count is used up the action must never fire again.
for _ in 0..times as u64 {
assert!(point.eval("test_fail_point_frequency").is_none());
}
}
// The action parser accepts every supported form and rejects malformed input.
#[test]
fn test_parse() {
let cases = vec![
("return", Action::new(Task::Return(None), 1.0, None)),
(
"return(64)",
Action::new(Task::Return(Some("64".to_owned())), 1.0, None),
),
("5*return", Action::new(Task::Return(None), 1.0, Some(5))),
("25%return", Action::new(Task::Return(None), 0.25, None)),
(
"125%2*return",
Action::new(Task::Return(None), 1.25, Some(2)),
),
(
"return(2%5)",
Action::new(Task::Return(Some("2%5".to_owned())), 1.0, None),
),
("125%2*off", Action::new(Task::Off, 1.25, Some(2))),
(
"125%2*sleep(100)",
Action::new(Task::Sleep(100), 1.25, Some(2)),
),
(" 125%2*off ", Action::new(Task::Off, 1.25, Some(2))),
("125%2*panic", Action::new(Task::Panic(None), 1.25, Some(2))),
(
"125%2*panic(msg)",
Action::new(Task::Panic(Some("msg".to_owned())), 1.25, Some(2)),
),
("125%2*print", Action::new(Task::Print(None), 1.25, Some(2))),
(
"125%2*print(msg)",
Action::new(Task::Print(Some("msg".to_owned())), 1.25, Some(2)),
),
("125%2*pause", Action::new(Task::Pause, 1.25, Some(2))),
("125%2*yield", Action::new(Task::Yield, 1.25, Some(2))),
("125%2*delay(2)", Action::new(Task::Delay(2), 1.25, Some(2))),
];
for (expr, exp) in cases {
let res: Action = expr.parse().unwrap();
assert_eq!(res, exp);
}
// Missing timeouts, wrong case, non-numeric prefixes, unbalanced
// parentheses and unknown task names must all be rejected.
let fail_cases = vec![
"delay",
"sleep",
"Return",
"ab%return",
"ab*return",
"return(msg",
"unknown",
];
for case in fail_cases {
assert!(case.parse::<Action>().is_err());
}
}
// Fail points configured through `FAILPOINTS` are active during the scenario
// and cleared, with paused threads resumed, on teardown.
#[test]
#[cfg_attr(not(feature = "failpoints"), ignore)]
fn test_setup_and_teardown() {
let f1 = || {
fail_point!("setup_and_teardown1", |_| 1);
0
};
let f2 = || {
fail_point!("setup_and_teardown2", |_| 2);
0
};
env::set_var(
"FAILPOINTS",
"setup_and_teardown1=return;setup_and_teardown2=pause;",
);
let scenario = FailScenario::setup();
assert_eq!(f1(), 1);
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
tx.send(f2()).unwrap();
});
// `f2` is paused, so nothing arrives yet.
assert!(rx.recv_timeout(Duration::from_millis(500)).is_err());
scenario.teardown();
// Teardown resumes `f2`, which then falls through to its normal result.
assert_eq!(rx.recv_timeout(Duration::from_millis(500)).unwrap(), 0);
assert_eq!(f1(), 0);
}
}