use async_trait::async_trait;
use bytes::Bytes;
use futures::{StreamExt, stream::BoxStream};
use object_store::{path::Path, *};
use std::sync::{
Arc, Mutex,
atomic::{AtomicBool, AtomicU64, Ordering},
};
/// Kind of object-store operation that a fault rule can target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FaultOp {
    /// `put_opts` and `put_multipart_opts`.
    Put,
    /// `get_opts` and `get_ranges`.
    Get,
    /// Each location flowing through `delete_stream`.
    Delete,
    /// `list`, `list_with_offset` and `list_with_delimiter`.
    List,
    /// `copy_opts`.
    Copy,
    /// `rename_opts`.
    Rename,
}
impl FaultOp {
    /// Returns `true` for operations that change stored data.
    ///
    /// Only mutations advance the crash counter and are recorded in the mutation log; the
    /// read-only operations (`Get`, `List`) never are. The match is exhaustive on purpose so
    /// that adding a variant forces an explicit classification here.
    fn is_mutation(self) -> bool {
        match self {
            FaultOp::Put | FaultOp::Delete | FaultOp::Copy | FaultOp::Rename => true,
            FaultOp::Get | FaultOp::List => false,
        }
    }
}
/// The failure injected when a fault rule fires.
#[derive(Debug, Clone)]
pub enum FaultKind {
    /// The operation fails with an injected error and the wrapped store is not called.
    Error,
    /// Simulated power failure: this operation fails and every later operation keeps
    /// failing until the handle is reset.
    Crash,
    /// For `put_opts`: the first `keep_bytes` bytes of the payload are written to the wrapped
    /// store, then the call reports failure. On non-`Put` operations it fails like `Error`.
    TornWrite { keep_bytes: usize },
}
/// A declarative fault: "when `op` hits a matching path, fail with `kind`".
#[derive(Debug, Clone)]
pub struct FaultRule {
    /// Operation the rule applies to.
    pub op: FaultOp,
    /// When `Some`, only paths containing this substring match; `None` matches every path.
    pub path_contains: Option<String>,
    /// Number of matching operations to let through before the rule may fire.
    /// Matches are only counted when the rule scan actually reaches this rule.
    pub skip: u64,
    /// Maximum number of times the rule fires.
    pub times: u64,
    /// Fault injected each time the rule fires.
    pub kind: FaultKind,
}
impl FaultRule {
    /// Builds a rule that fails the first `op` whose path contains `path` with
    /// [`FaultKind::Error`], exactly once; afterwards the rule never fires again.
    pub fn fail_once(op: FaultOp, path: impl Into<String>) -> Self {
        let path_contains = Some(path.into());
        Self {
            kind: FaultKind::Error,
            skip: 0,
            times: 1,
            path_contains,
            op,
        }
    }
}
/// A fault rule paired with its evaluation counters.
#[derive(Debug)]
struct RuleState {
    /// The user-supplied rule.
    rule: FaultRule,
    /// How many operations matched the rule's op/path filter while the scan reached it.
    matched: u64,
    /// How many times the rule has fired; it stops firing once this reaches `rule.times`.
    fired: u64,
}
/// Fault-injection state shared (via `Arc`) by a `FaultStore` and its `FaultHandle`s.
///
/// NOTE: the derived `Default` leaves `crash_at` at 0, which would power off on the very
/// first mutation; `FaultStore::wrap` overrides it with `u64::MAX`.
#[derive(Debug, Default)]
struct FaultState {
    /// Set once a simulated power failure happens; every intercepted call then fails.
    powered_off: AtomicBool,
    /// Mutating operations counted by the interceptor since creation or the last reset.
    mutations: AtomicU64,
    /// Mutation index at which the store powers off (`u64::MAX` effectively disables it).
    crash_at: AtomicU64,
    /// Fault rules, evaluated in insertion order.
    rules: Mutex<Vec<RuleState>>,
    /// Mutations that passed every fault check, as `(op, path)` pairs in arrival order.
    log: Mutex<Vec<(FaultOp, String)>>,
}
impl FaultState {
fn injected(&self, op: FaultOp, path: &Path, reason: &str) -> Error {
Error::Generic {
store: "FaultStore",
source: format!("injected fault: {reason} ({op:?} {path})").into(),
}
}
fn intercept(&self, op: FaultOp, path: &Path) -> Result<Option<FaultKind>> {
if self.powered_off.load(Ordering::Acquire) {
return Err(self.injected(op, path, "power failure"));
}
if op.is_mutation() {
let n = self.mutations.fetch_add(1, Ordering::AcqRel);
if n >= self.crash_at.load(Ordering::Acquire) {
self.powered_off.store(true, Ordering::Release);
return Err(self.injected(op, path, "power failure"));
}
}
let mut rules = self.rules.lock().expect("FaultStore rules lock poisoned");
for rs in rules.iter_mut() {
if rs.rule.op != op {
continue;
}
if let Some(substr) = &rs.rule.path_contains
&& !path.as_ref().contains(substr.as_str())
{
continue;
}
rs.matched += 1;
if rs.matched > rs.rule.skip && rs.fired < rs.rule.times {
rs.fired += 1;
match rs.rule.kind.clone() {
FaultKind::Error => return Err(self.injected(op, path, "error")),
FaultKind::Crash => {
self.powered_off.store(true, Ordering::Release);
return Err(self.injected(op, path, "power failure"));
}
kind @ FaultKind::TornWrite { .. } => {
if op == FaultOp::Put {
return Ok(Some(kind));
}
return Err(self.injected(op, path, "error"));
}
}
}
}
drop(rules);
if op.is_mutation() {
self.log
.lock()
.expect("FaultStore log lock poisoned")
.push((op, path.to_string()));
}
Ok(None)
}
}
/// Test-side control handle for a `FaultStore`.
///
/// Clones are cheap and share the same fault state as the store returned alongside the
/// handle by `FaultStore::wrap`.
#[derive(Clone, Debug)]
pub struct FaultHandle {
    state: Arc<FaultState>,
}
impl FaultHandle {
    /// Appends a fault rule with fresh (zero) `matched` / `fired` counters.
    ///
    /// Rules are evaluated in insertion order, and the first rule that fires ends the scan.
    pub fn push_rule(&self, rule: FaultRule) {
        self.state
            .rules
            .lock()
            .expect("FaultStore rules lock poisoned")
            .push(RuleState {
                rule,
                matched: 0,
                fired: 0,
            });
    }

    /// Arms a simulated power failure after `n` more counted mutations.
    ///
    /// The threshold is relative to the current mutation counter (saturating at
    /// `u64::MAX`): the next `n` counted mutations may proceed, and the one after them
    /// powers the store off. Mutations rejected by a rule still count.
    ///
    /// NOTE(review): reading the counter and storing the threshold are not one atomic step.
    /// Mutations racing with this call can shift the crash point.
    pub fn crash_after_mutations(&self, n: u64) {
        let base = self.state.mutations.load(Ordering::Acquire);
        self.state
            .crash_at
            .store(base.saturating_add(n), Ordering::Release);
    }

    /// Number of mutations counted since creation or the last [`reset`](Self::reset).
    ///
    /// This includes mutations that failed because of a fault. It excludes calls rejected
    /// because the store was already powered off.
    pub fn mutation_count(&self) -> u64 {
        self.state.mutations.load(Ordering::Acquire)
    }

    /// Snapshot of the mutations that passed every fault check, in arrival order.
    ///
    /// Failed and torn mutations are not recorded.
    pub fn mutation_log(&self) -> Vec<(FaultOp, String)> {
        self.state
            .log
            .lock()
            .expect("FaultStore log lock poisoned")
            .clone()
    }

    /// Restores power and clears all fault configuration: the crash threshold, the mutation
    /// counter, every rule and the mutation log.
    ///
    /// Data already in the wrapped store is untouched.
    ///
    /// Power is restored *last*. Clearing `powered_off` first would let a concurrent
    /// mutation observe the stale counter/threshold pair (or a stale `Crash` rule) and power
    /// the store off again, leaving it dead after `reset` returned.
    pub fn reset(&self) {
        self.state.crash_at.store(u64::MAX, Ordering::Release);
        self.state.mutations.store(0, Ordering::Release);
        self.state
            .rules
            .lock()
            .expect("FaultStore rules lock poisoned")
            .clear();
        self.state
            .log
            .lock()
            .expect("FaultStore log lock poisoned")
            .clear();
        self.state.powered_off.store(false, Ordering::Release);
    }
}
#[derive(Debug)]
pub struct FaultStore<T: ObjectStore> {
inner: Arc<T>,
state: Arc<FaultState>,
}
impl<T: ObjectStore> FaultStore<T> {
pub fn wrap(inner: T) -> (Self, FaultHandle) {
let state = Arc::new(FaultState {
crash_at: AtomicU64::new(u64::MAX),
..Default::default()
});
let handle = FaultHandle {
state: state.clone(),
};
(
Self {
inner: Arc::new(inner),
state,
},
handle,
)
}
pub fn inner(&self) -> &T {
&self.inner
}
}
impl<T: ObjectStore> std::fmt::Display for FaultStore<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "FaultStore({})", self.inner)
}
}
#[async_trait]
impl<T: ObjectStore> ObjectStore for FaultStore<T> {
    /// Forwards a single-shot put, or performs a torn write when a `TornWrite` rule fires.
    ///
    /// For a torn write, the first `keep_bytes` bytes of `payload` (or all of it, if
    /// shorter) are written to the wrapped store on a best-effort basis. The call then fails
    /// with a "torn write" error whether or not that partial write succeeded.
    async fn put_opts(
        &self,
        location: &Path,
        payload: PutPayload,
        opts: PutOptions,
    ) -> Result<PutResult> {
        match self.state.intercept(FaultOp::Put, location)? {
            None => self.inner.put_opts(location, payload, opts).await,
            Some(FaultKind::TornWrite { keep_bytes }) => {
                let mut prefix = Vec::with_capacity(keep_bytes.min(payload.content_length()));
                for segment in payload.iter() {
                    // Invariant: prefix.len() <= keep_bytes, so this cannot underflow.
                    let remaining = keep_bytes - prefix.len();
                    if remaining == 0 {
                        break;
                    }
                    prefix.extend_from_slice(&segment[..remaining.min(segment.len())]);
                }
                // Best effort by design: the caller must observe a failure either way,
                // mirroring a crash part-way through the write.
                let _ = self
                    .inner
                    .put_opts(location, Bytes::from(prefix).into(), opts)
                    .await;
                Err(self.state.injected(FaultOp::Put, location, "torn write"))
            }
            Some(FaultKind::Error | FaultKind::Crash) => {
                unreachable!("intercept only returns TornWrite")
            }
        }
    }

    /// Starts a multipart upload. Its parts and completion fail once the store is powered
    /// off.
    ///
    /// A multipart upload has no single payload to truncate. If a `TornWrite` rule fires
    /// here, the call therefore fails with a "torn write" error. Previously the rule was
    /// consumed without injecting anything.
    async fn put_multipart_opts(
        &self,
        location: &Path,
        opts: PutMultipartOptions,
    ) -> Result<Box<dyn MultipartUpload>> {
        if self.state.intercept(FaultOp::Put, location)?.is_some() {
            return Err(self.state.injected(FaultOp::Put, location, "torn write"));
        }
        let inner = self.inner.put_multipart_opts(location, opts).await?;
        Ok(Box::new(FaultUploader {
            location: location.clone(),
            state: Arc::clone(&self.state),
            inner,
        }))
    }

    async fn get_opts(&self, location: &Path, options: GetOptions) -> Result<GetResult> {
        self.state.intercept(FaultOp::Get, location)?;
        self.inner.get_opts(location, options).await
    }

    async fn get_ranges(
        &self,
        location: &Path,
        ranges: &[std::ops::Range<u64>],
    ) -> Result<Vec<Bytes>> {
        self.state.intercept(FaultOp::Get, location)?;
        self.inner.get_ranges(location, ranges).await
    }

    /// Checks each location lazily, as the wrapped store pulls it from the stream.
    ///
    /// A faulted location is handed to the wrapped store as an `Err` item instead of a path,
    /// so it is not deleted.
    fn delete_stream(
        &self,
        locations: BoxStream<'static, Result<Path>>,
    ) -> BoxStream<'static, Result<Path>> {
        let state = Arc::clone(&self.state);
        let checked = locations
            .map(move |location| {
                let location = location?;
                state.intercept(FaultOp::Delete, &location)?;
                Ok(location)
            })
            .boxed();
        self.inner.delete_stream(checked)
    }

    /// The fault check runs once, eagerly, when the stream is created. An injected fault
    /// yields a single-item error stream.
    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result<ObjectMeta>> {
        if let Err(err) = self
            .state
            .intercept(FaultOp::List, &prefix.cloned().unwrap_or_default())
        {
            return futures::stream::once(async move { Err(err) }).boxed();
        }
        self.inner.list(prefix)
    }

    /// Same eager, single check as [`list`](Self::list); the check matches against `prefix`,
    /// not `offset`.
    fn list_with_offset(
        &self,
        prefix: Option<&Path>,
        offset: &Path,
    ) -> BoxStream<'static, Result<ObjectMeta>> {
        if let Err(err) = self
            .state
            .intercept(FaultOp::List, &prefix.cloned().unwrap_or_default())
        {
            return futures::stream::once(async move { Err(err) }).boxed();
        }
        self.inner.list_with_offset(prefix, offset)
    }

    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
        self.state
            .intercept(FaultOp::List, &prefix.cloned().unwrap_or_default())?;
        self.inner.list_with_delimiter(prefix).await
    }

    /// Rules and the mutation log see the *source* path `from`.
    async fn copy_opts(&self, from: &Path, to: &Path, options: CopyOptions) -> Result<()> {
        self.state.intercept(FaultOp::Copy, from)?;
        self.inner.copy_opts(from, to, options).await
    }

    /// Rules and the mutation log see the *source* path `from`.
    async fn rename_opts(&self, from: &Path, to: &Path, options: RenameOptions) -> Result<()> {
        self.state.intercept(FaultOp::Rename, from)?;
        self.inner.rename_opts(from, to, options).await
    }
}
/// Multipart upload wrapper whose calls all fail while the simulated store is powered off.
///
/// Parts and completion do not go through the interceptor, so they are neither counted as
/// mutations nor matched against rules. Only the `put_multipart_opts` call that created the
/// upload was.
#[derive(Debug)]
struct FaultUploader {
    location: Path,
    state: Arc<FaultState>,
    inner: Box<dyn MultipartUpload>,
}
impl FaultUploader {
    /// Returns the injected power-failure error if the store is powered off.
    fn check_power(&self) -> Result<()> {
        if self.state.powered_off.load(Ordering::Acquire) {
            return Err(self
                .state
                .injected(FaultOp::Put, &self.location, "power failure"));
        }
        Ok(())
    }
}
#[async_trait]
impl MultipartUpload for FaultUploader {
    fn put_part(&mut self, payload: PutPayload) -> UploadPart {
        if let Err(err) = self.check_power() {
            return Box::pin(async move { Err(err) });
        }
        self.inner.put_part(payload)
    }

    async fn complete(&mut self) -> Result<PutResult> {
        self.check_power()?;
        self.inner.complete().await
    }

    /// Also blocked while powered off, like every other call. Letting an abort reach the
    /// wrapped store after a simulated crash would hide the orphaned upload a real crash
    /// leaves behind.
    async fn abort(&mut self) -> Result<()> {
        self.check_power()?;
        self.inner.abort().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use object_store::memory::InMemory;

    /// Wraps a fresh, empty in-memory store.
    fn fresh_store() -> (FaultStore<InMemory>, FaultHandle) {
        FaultStore::wrap(InMemory::new())
    }

    /// Turns a static byte string into a put payload.
    fn payload(data: &'static [u8]) -> PutPayload {
        PutPayload::from(Bytes::from_static(data))
    }

    /// Reads `location` back in full, panicking on any error.
    async fn read_back(store: &FaultStore<InMemory>, location: &str) -> Bytes {
        let fetched = store.get(&Path::from(location)).await.unwrap();
        fetched.bytes().await.unwrap()
    }

    #[tokio::test]
    async fn forwards_when_no_faults() {
        let (store, handle) = fresh_store();
        store
            .put(&Path::from("a/b"), payload(b"hello"))
            .await
            .unwrap();
        assert_eq!(read_back(&store, "a/b").await, Bytes::from_static(b"hello"));
        assert_eq!(handle.mutation_count(), 1);
        let expected_log = vec![(FaultOp::Put, "a/b".to_string())];
        assert_eq!(handle.mutation_log(), expected_log);
    }

    #[tokio::test]
    async fn crash_after_mutations_powers_off_everything() {
        let (store, handle) = fresh_store();
        handle.crash_after_mutations(2);
        for name in ["1", "2"] {
            store.put(&Path::from(name), payload(b"x")).await.unwrap();
        }
        // The third mutation trips the power failure...
        assert!(store.put(&Path::from("3"), payload(b"x")).await.is_err());
        // ...after which reads and deletes fail as well.
        assert!(store.get(&Path::from("1")).await.is_err());
        assert!(store.delete(&Path::from("1")).await.is_err());
        handle.reset();
        assert_eq!(read_back(&store, "1").await, Bytes::from_static(b"x"));
        let lost = store.get(&Path::from("3")).await;
        assert!(matches!(lost, Err(Error::NotFound { .. })));
    }

    #[tokio::test]
    async fn crash_after_is_relative_to_current_count() {
        let (store, handle) = fresh_store();
        store.put(&Path::from("1"), payload(b"x")).await.unwrap();
        handle.crash_after_mutations(1);
        store.put(&Path::from("2"), payload(b"x")).await.unwrap();
        let tripped = store.put(&Path::from("3"), payload(b"x")).await;
        assert!(tripped.is_err());
    }

    #[tokio::test]
    async fn targeted_rule_fails_nth_matching_put() {
        let (store, handle) = fresh_store();
        handle.push_rule(FaultRule {
            op: FaultOp::Put,
            path_contains: Some("meta".to_string()),
            skip: 1,
            times: 1,
            kind: FaultKind::Error,
        });
        let first_meta = Path::from("x/meta");
        let unrelated = Path::from("x/data");
        let second_meta = Path::from("y/meta");
        // The first match is skipped; a non-matching path never touches the rule.
        store.put(&first_meta, payload(b"a")).await.unwrap();
        store.put(&unrelated, payload(b"b")).await.unwrap();
        // The second match fires once...
        assert!(store.put(&second_meta, payload(b"c")).await.is_err());
        // ...and the exhausted rule lets the retry through.
        store.put(&second_meta, payload(b"d")).await.unwrap();
    }

    #[tokio::test]
    async fn torn_write_persists_prefix_and_reports_failure() {
        let (store, handle) = fresh_store();
        handle.push_rule(FaultRule {
            op: FaultOp::Put,
            path_contains: Some("torn".to_string()),
            skip: 0,
            times: 1,
            kind: FaultKind::TornWrite { keep_bytes: 3 },
        });
        let outcome = store
            .put(&Path::from("torn"), payload(b"hello world"))
            .await;
        assert!(outcome.is_err());
        assert_eq!(read_back(&store, "torn").await, Bytes::from_static(b"hel"));
    }

    #[tokio::test]
    async fn delete_stream_and_list_respect_power_failure() {
        let (store, handle) = fresh_store();
        store.put(&Path::from("a"), payload(b"1")).await.unwrap();
        handle.crash_after_mutations(0);
        assert!(store.delete(&Path::from("a")).await.is_err());
        let listed: Vec<_> = store.list(None).collect().await;
        assert!(listed.iter().any(|entry| entry.is_err()));
        assert!(store.list_with_delimiter(None).await.is_err());
        handle.reset();
        assert!(store.get(&Path::from("a")).await.is_ok());
    }

    #[tokio::test]
    async fn crash_rule_kind_powers_off() {
        let (store, handle) = fresh_store();
        handle.push_rule(FaultRule {
            op: FaultOp::Put,
            path_contains: Some("ids".to_string()),
            skip: 0,
            times: 1,
            kind: FaultKind::Crash,
        });
        store.put(&Path::from("meta"), payload(b"m")).await.unwrap();
        let crashed = store.put(&Path::from("col/ids"), payload(b"i")).await;
        assert!(crashed.is_err());
        let after_crash = store.put(&Path::from("other"), payload(b"o")).await;
        assert!(after_crash.is_err());
        assert!(store.get(&Path::from("meta")).await.is_err());
    }
}