use std::{fmt::Debug, sync::atomic::AtomicBool};
use crate::{format::Index, format::Manifest};
use futures::future::BoxFuture;
use object_store::path::Path;
use object_store::Error as ObjectStoreError;
use super::ObjectStore;
/// Signature of the function a commit handler calls to write `manifest`
/// (and, optionally, `indices`) to `path` in `object_store`.
///
/// The commit handlers in this file receive one of these and decide which
/// path to pass and what checks to perform around the call. The returned
/// future borrows all reference arguments for `'a` and resolves to
/// `crate::Result<()>`.
pub type ManifestWriter = for<'a> fn(
object_store: &'a ObjectStore,
manifest: &'a mut Manifest,
indices: Option<Vec<Index>>,
path: &'a Path,
) -> BoxFuture<'a, crate::Result<()>>;
/// Strategy for committing a manifest to an object store.
///
/// The implementations in this file differ in how they guard against
/// concurrent writers: none at all ([`UnsafeCommitHandler`]), write to a
/// temporary path followed by a rename-if-not-exists
/// ([`RenameCommitHandler`]), or an external lock (any [`CommitLock`]).
#[async_trait::async_trait]
pub trait CommitHandler: Debug + Send + Sync {
/// Commits `manifest` (and, optionally, `indices`) at `path`.
///
/// * `manifest` - manifest to commit; passed mutably through to `manifest_writer`.
/// * `indices` - optional indices handed to `manifest_writer`.
/// * `path` - final location of the manifest.
/// * `object_store` - store the manifest is written to.
/// * `manifest_writer` - function that performs the actual write.
///
/// Returns [`CommitError::CommitConflict`] when the implementation detects a
/// conflicting commit, and [`CommitError::OtherError`] for any other failure.
async fn commit(
&self,
manifest: &mut Manifest,
indices: Option<Vec<Index>>,
path: &Path,
object_store: &ObjectStore,
manifest_writer: ManifestWriter,
) -> Result<(), CommitError>;
}
/// Errors that can occur while committing a manifest.
///
/// Derives `Debug` so that `Result<_, CommitError>` can be unwrapped, logged
/// and used in assertions like any other error type.
#[derive(Debug)]
pub enum CommitError {
    /// The commit could not be made because a conflicting commit was
    /// detected (for example, the destination path already exists).
    CommitConflict,
    /// Any other failure, carrying the underlying crate error.
    OtherError(crate::Error),
}
impl From<crate::Error> for CommitError {
fn from(e: crate::Error) -> Self {
Self::OtherError(e)
}
}
impl From<CommitError> for crate::Error {
    /// Converts a commit failure back into a crate-level error.
    ///
    /// A wrapped error is returned unchanged; a conflict becomes an
    /// `Internal` error whose message is `"Commit conflict"`.
    fn from(e: CommitError) -> Self {
        match e {
            CommitError::OtherError(inner) => inner,
            CommitError::CommitConflict => Self::Internal {
                message: String::from("Commit conflict"),
            },
        }
    }
}
/// Process-wide flag recording whether the "unsafe commit handler" warning
/// has already been logged, so that it is not repeated on every commit.
static WARNED_ON_UNSAFE_COMMIT: AtomicBool = AtomicBool::new(false);
/// Commit handler that writes the manifest straight to its final path with
/// no conflict detection. Its `commit` logs a warning that concurrent writes
/// may result in data loss.
pub struct UnsafeCommitHandler;
#[async_trait::async_trait]
impl CommitHandler for UnsafeCommitHandler {
    /// Writes the manifest directly to `path` without any conflict check.
    ///
    /// The first call in the process logs a warning that concurrent writes
    /// may lose data; later calls stay silent.
    ///
    /// # Errors
    ///
    /// Returns [`CommitError::OtherError`] if `manifest_writer` fails. This
    /// handler never reports [`CommitError::CommitConflict`].
    async fn commit(
        &self,
        manifest: &mut Manifest,
        indices: Option<Vec<Index>>,
        path: &Path,
        object_store: &ObjectStore,
        manifest_writer: ManifestWriter,
    ) -> Result<(), CommitError> {
        // `swap` reads the old value and sets the flag in a single atomic
        // step, so exactly one caller observes `false` and emits the warning
        // even when several commits start concurrently. `Relaxed` suffices:
        // the flag guards nothing but the log line.
        if !WARNED_ON_UNSAFE_COMMIT.swap(true, std::sync::atomic::Ordering::Relaxed) {
            log::warn!(
                "Using unsafe commit handler. Concurrent writes may result in data loss. \
                 Consider providing a commit handler that prevents conflicting writes."
            );
        }
        manifest_writer(object_store, manifest, indices, path).await?;
        Ok(())
    }
}
impl Debug for UnsafeCommitHandler {
    /// Formats the handler as its bare type name; the struct has no fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("UnsafeCommitHandler")
    }
}
/// Commit handler that writes the manifest to a temporary `.tmp_…` sibling
/// path and then moves it into place with `rename_if_not_exists`, reporting
/// a commit conflict when the destination already exists.
pub struct RenameCommitHandler;
#[async_trait::async_trait]
impl CommitHandler for RenameCommitHandler {
async fn commit(
&self,
manifest: &mut Manifest,
indices: Option<Vec<Index>>,
path: &Path,
object_store: &ObjectStore,
manifest_writer: ManifestWriter,
) -> Result<(), CommitError> {
let mut parts: Vec<_> = path.parts().collect();
let uuid = uuid::Uuid::new_v4();
let new_name = format!(
".tmp_{}_{}",
parts.last().unwrap().as_ref(),
uuid.as_hyphenated()
);
let _ = std::mem::replace(parts.last_mut().unwrap(), new_name.into());
let tmp_path: Path = parts.into_iter().collect();
manifest_writer(object_store, manifest, indices, &tmp_path).await?;
match object_store
.inner
.rename_if_not_exists(&tmp_path, path)
.await
{
Ok(_) => Ok(()),
Err(ObjectStoreError::AlreadyExists { .. }) => {
let _ = object_store.inner.delete(&tmp_path).await;
return Err(CommitError::CommitConflict);
}
Err(e) => {
return Err(CommitError::OtherError(e.into()));
}
}
}
}
impl Debug for RenameCommitHandler {
    /// Formats the handler as its bare type name; the struct has no fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("RenameCommitHandler")
    }
}
/// A lock that can be taken on a manifest version while it is being committed.
///
/// Every type implementing `CommitLock + Send + Sync + Debug` automatically
/// gets a [`CommitHandler`] implementation from this file. That
/// implementation calls [`lock`](CommitLock::lock) with the manifest's
/// version before checking the destination path and writing the manifest.
#[async_trait::async_trait]
pub trait CommitLock {
/// Lease handed out by a successful [`lock`](CommitLock::lock) call.
type Lease: CommitLease;
/// Attempts to lock `version`.
///
/// On success returns a lease that the caller is expected to release once
/// the commit attempt has finished. Any error is propagated unchanged by the
/// blanket [`CommitHandler`] implementation.
async fn lock(&self, version: u64) -> Result<Self::Lease, CommitError>;
}
/// Handle representing a lock obtained through [`CommitLock::lock`].
#[async_trait::async_trait]
pub trait CommitLease: Send + Sync {
/// Gives the lease back.
///
/// The blanket [`CommitHandler`] implementation in this file passes
/// `success = true` when the manifest was written successfully and `false`
/// when the commit was abandoned or the write failed.
async fn release(&self, success: bool) -> Result<(), CommitError>;
}
#[async_trait::async_trait]
impl<T: CommitLock + Send + Sync + Debug> CommitHandler for T {
    /// Commits the manifest while holding a lease on `manifest.version`.
    ///
    /// Steps, in order:
    /// 1. Acquire the lease; a locking failure is returned as-is.
    /// 2. Probe `path`. If it already exists the lease is released with
    ///    `success = false` and [`CommitError::CommitConflict`] is returned;
    ///    any probe error other than "not found" is handled the same way but
    ///    returned as [`CommitError::OtherError`].
    /// 3. Write the manifest to `path`, release the lease with the outcome
    ///    of the write, and return that outcome.
    ///
    /// If releasing the lease fails, the release error is returned instead
    /// of the error (or success) that preceded it.
    async fn commit(
        &self,
        manifest: &mut Manifest,
        indices: Option<Vec<Index>>,
        path: &Path,
        object_store: &ObjectStore,
        manifest_writer: ManifestWriter,
    ) -> Result<(), CommitError> {
        let lease = self.lock(manifest.version).await?;

        // Decide whether anything prevents us from writing to `path`.
        let blocker = match object_store.inner.head(path).await {
            Err(ObjectStoreError::NotFound { .. }) => None,
            Ok(_) => Some(CommitError::CommitConflict),
            Err(other) => Some(CommitError::OtherError(other.into())),
        };
        if let Some(reason) = blocker {
            lease.release(false).await?;
            return Err(reason);
        }

        // Keep the write outcome around so the lease learns whether the
        // commit went through before the outcome is handed to the caller.
        let written = manifest_writer(object_store, manifest, indices, path).await;
        lease.release(written.is_ok()).await?;
        written.map_err(CommitError::from)
    }
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use arrow_array::{Int64Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use futures::future::join_all;
use super::*;
use crate::dataset::WriteParams;
use crate::io::object_store::ObjectStoreParams;
use crate::Dataset;
/// Runs ten concurrent `delete` operations against a fresh in-memory
/// dataset that uses `handler` for its commits.
///
/// When `should_succeed` is true, every successful task must have produced a
/// distinct manifest version. Otherwise at least two successful tasks are
/// expected to have ended up with the same version.
async fn test_commit_handler(handler: Arc<dyn CommitHandler>, should_succeed: bool) {
    // Build a single-column dataset to operate on.
    let schema = Arc::new(ArrowSchema::new(vec![Field::new(
        "x",
        DataType::Int64,
        false,
    )]));
    let values = Int64Array::from(vec![1, 2, 3]);
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(values)]).unwrap();
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);

    let store_params = ObjectStoreParams {
        commit_handler: Some(handler),
        ..Default::default()
    };
    let options = WriteParams {
        store_params: Some(store_params),
        ..Default::default()
    };
    let dataset = Dataset::write(reader, "memory://test", Some(options))
        .await
        .unwrap();

    // Launch the competing writers; each starts from the same dataset state.
    let mut handles = Vec::with_capacity(10);
    for _ in 0..10 {
        let mut ds = dataset.clone();
        handles.push(tokio::task::spawn(async move {
            ds.delete("x = 2").await.map(|_| ds.manifest.version)
        }));
    }

    // `Some(version)` for every task that both joined and committed.
    let task_results: Vec<Option<u64>> = join_all(handles)
        .await
        .into_iter()
        .map(|joined| joined.ok().and_then(|committed| committed.ok()))
        .collect();

    let num_successes = task_results.iter().flatten().count();
    let distinct_results: HashSet<u64> = task_results.iter().flatten().copied().collect();

    if should_succeed {
        assert_eq!(
            num_successes,
            distinct_results.len(),
            "Expected no two tasks to succeed for the same version. Got {:?}",
            task_results
        );
    } else {
        assert!(
            num_successes > distinct_results.len(),
            "Expected some conflicts. Got {:?}",
            task_results
        );
    }
}
/// The rename-based handler must never let two writers commit the same version.
#[tokio::test]
async fn test_rename_commit_handler() {
    test_commit_handler(Arc::new(RenameCommitHandler), true).await;
}
/// A user-supplied `CommitLock` must be enough to prevent conflicting commits.
#[tokio::test]
async fn test_custom_commit() {
    /// Lock that allows a single version to be locked at any time.
    #[derive(Debug)]
    struct CustomCommitHandler {
        locked_version: Arc<Mutex<Option<u64>>>,
    }

    /// Lease remembering which version it holds in the shared slot.
    struct CustomCommitLease {
        version: u64,
        locked_version: Arc<Mutex<Option<u64>>>,
    }

    #[async_trait::async_trait]
    impl CommitLock for CustomCommitHandler {
        type Lease = CustomCommitLease;

        async fn lock(&self, version: u64) -> Result<Self::Lease, CommitError> {
            let mut slot = self.locked_version.lock().unwrap();
            if slot.is_none() {
                *slot = Some(version);
                Ok(CustomCommitLease {
                    version,
                    locked_version: Arc::clone(&self.locked_version),
                })
            } else {
                // Somebody else is in the middle of a commit.
                Err(CommitError::CommitConflict)
            }
        }
    }

    #[async_trait::async_trait]
    impl CommitLease for CustomCommitLease {
        async fn release(&self, _success: bool) -> Result<(), CommitError> {
            let mut slot = self.locked_version.lock().unwrap();
            if *slot == Some(self.version) {
                *slot = None;
                Ok(())
            } else {
                // The slot no longer holds our version.
                Err(CommitError::CommitConflict)
            }
        }
    }

    let handler = Arc::new(CustomCommitHandler {
        locked_version: Arc::new(Mutex::new(None)),
    });
    test_commit_handler(handler, true).await;
}
/// Without any protection, concurrent writers are expected to collide.
#[tokio::test]
async fn test_unsafe_commit_handler() {
    test_commit_handler(Arc::new(UnsafeCommitHandler), false).await;
}
}