use std::{fmt::Debug, sync::atomic::AtomicBool};
use crate::dataset::transaction::{Operation, Transaction};
use crate::dataset::{write_manifest_file, ManifestWriteConfig};
use crate::Dataset;
use crate::Result;
use crate::{format::pb, format::Index, format::Manifest};
use futures::future::BoxFuture;
use object_store::path::Path;
use object_store::Error as ObjectStoreError;
use prost::Message;
use super::ObjectStore;
/// Signature of the function a [`CommitHandler`] invokes to write `manifest`
/// (and, when present, `indices`) to `path` in `object_store`.
///
/// This is a plain function pointer with a higher-ranked lifetime: every
/// borrowed argument and the returned boxed future share a single lifetime `'a`,
/// so the future may hold on to all of them until it completes.
pub type ManifestWriter = for<'a> fn(
    object_store: &'a ObjectStore,
    manifest: &'a mut Manifest,
    indices: Option<Vec<Index>>,
    path: &'a Path,
) -> BoxFuture<'a, Result<()>>;
/// Strategy for making a new manifest visible at `path`.
///
/// A handler decides where and how `manifest_writer` is invoked (for example
/// directly at `path`, or at a temporary path that is then renamed into place).
/// Conflict-detecting handlers in this module (`RenameCommitHandler` and the
/// blanket impl for `CommitLock` types) return [`CommitError::CommitConflict`]
/// when a manifest already exists at `path`; all other failures are reported
/// as [`CommitError::OtherError`].
#[async_trait::async_trait]
pub trait CommitHandler: Debug + Send + Sync {
    /// Commits `manifest` (and optional `indices`) so that it ends up at `path`,
    /// using `manifest_writer` to write the manifest contents.
    async fn commit(
        &self,
        manifest: &mut Manifest,
        indices: Option<Vec<Index>>,
        path: &Path,
        object_store: &ObjectStore,
        manifest_writer: ManifestWriter,
    ) -> std::result::Result<(), CommitError>;
}
/// Error returned by [`CommitHandler::commit`] and related commit primitives.
#[derive(Debug)]
pub enum CommitError {
    /// Another writer already committed at the target path / version.
    CommitConflict,
    /// Any other failure encountered while committing.
    OtherError(crate::Error),
}
impl From<crate::Error> for CommitError {
fn from(e: crate::Error) -> Self {
Self::OtherError(e)
}
}
impl From<CommitError> for crate::Error {
fn from(e: CommitError) -> Self {
match e {
CommitError::CommitConflict => Self::Internal {
message: "Commit conflict".to_string(),
},
CommitError::OtherError(e) => e,
}
}
}
/// Process-wide flag ensuring the "unsafe commit handler" warning is logged once.
static WARNED_ON_UNSAFE_COMMIT: AtomicBool = AtomicBool::new(false);

/// A [`CommitHandler`] that writes the manifest straight to its final path
/// without any conflict detection.
///
/// Concurrent writers targeting the same version can overwrite each other's
/// manifest, so a warning is logged the first time this handler is used.
pub struct UnsafeCommitHandler;

#[async_trait::async_trait]
impl CommitHandler for UnsafeCommitHandler {
    async fn commit(
        &self,
        manifest: &mut Manifest,
        indices: Option<Vec<Index>>,
        path: &Path,
        object_store: &ObjectStore,
        manifest_writer: ManifestWriter,
    ) -> std::result::Result<(), CommitError> {
        // `swap` is a single atomic read-modify-write, so exactly one caller
        // observes the previous `false` and logs. A separate load followed by a
        // store would let several racing commits all emit the warning.
        if !WARNED_ON_UNSAFE_COMMIT.swap(true, std::sync::atomic::Ordering::Relaxed) {
            log::warn!(
                "Using unsafe commit handler. Concurrent writes may result in data loss. \
                 Consider providing a commit handler that prevents conflicting writes."
            );
        }
        manifest_writer(object_store, manifest, indices, path).await?;
        Ok(())
    }
}

impl Debug for UnsafeCommitHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UnsafeCommitHandler").finish()
    }
}
pub struct RenameCommitHandler;
#[async_trait::async_trait]
impl CommitHandler for RenameCommitHandler {
async fn commit(
&self,
manifest: &mut Manifest,
indices: Option<Vec<Index>>,
path: &Path,
object_store: &ObjectStore,
manifest_writer: ManifestWriter,
) -> std::result::Result<(), CommitError> {
let mut parts: Vec<_> = path.parts().collect();
let uuid = uuid::Uuid::new_v4();
let new_name = format!(
".tmp_{}_{}",
parts.last().unwrap().as_ref(),
uuid.as_hyphenated()
);
let _ = std::mem::replace(parts.last_mut().unwrap(), new_name.into());
let tmp_path: Path = parts.into_iter().collect();
manifest_writer(object_store, manifest, indices, &tmp_path).await?;
match object_store
.inner
.rename_if_not_exists(&tmp_path, path)
.await
{
Ok(_) => Ok(()),
Err(ObjectStoreError::AlreadyExists { .. }) => {
let _ = object_store.inner.delete(&tmp_path).await;
return Err(CommitError::CommitConflict);
}
Err(e) => {
return Err(CommitError::OtherError(e.into()));
}
}
}
}
impl Debug for RenameCommitHandler {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RenameCommitHandler").finish()
}
}
/// A lock used to serialize commits of a given version.
///
/// Any type implementing `CommitLock` (plus `Send + Sync + Debug`) also
/// implements [`CommitHandler`] through a blanket impl in this module. That impl
/// locks the manifest's version, refuses to overwrite an existing manifest,
/// writes, and then releases the lease.
#[async_trait::async_trait]
pub trait CommitLock {
    /// Lease handed back by a successful [`CommitLock::lock`].
    type Lease: CommitLease;
    /// Acquires the lock for `version`. Any error returned here (for example
    /// `CommitError::CommitConflict`) is propagated unchanged to the committer.
    async fn lock(&self, version: u64) -> std::result::Result<Self::Lease, CommitError>;
}
/// A held commit lock, obtained from [`CommitLock::lock`].
#[async_trait::async_trait]
pub trait CommitLease: Send + Sync {
    /// Releases the lock. The blanket `CommitHandler` impl passes `success =
    /// true` only when the manifest write succeeded, and `false` when the commit
    /// was aborted (manifest already present, store error) or the write failed.
    /// An error returned here is propagated to the committer.
    async fn release(&self, success: bool) -> std::result::Result<(), CommitError>;
}
/// Turns any [`CommitLock`] into a [`CommitHandler`].
///
/// Commit protocol:
/// 1. Lock the manifest's version.
/// 2. Probe `path`. If a manifest already exists, the commit conflicts; if the
///    probe fails for another reason, the commit aborts with that error. In
///    both cases the lease is released with `success = false`.
/// 3. Otherwise write the manifest, release the lease with the write's
///    outcome, and return the write's result.
#[async_trait::async_trait]
impl<T: CommitLock + Send + Sync + Debug> CommitHandler for T {
    async fn commit(
        &self,
        manifest: &mut Manifest,
        indices: Option<Vec<Index>>,
        path: &Path,
        object_store: &ObjectStore,
        manifest_writer: ManifestWriter,
    ) -> std::result::Result<(), CommitError> {
        let lease = self.lock(manifest.version).await?;

        // `None` means the target path is free and we may proceed.
        let abort_reason = match object_store.inner.head(path).await {
            Ok(_) => Some(CommitError::CommitConflict),
            Err(ObjectStoreError::NotFound { .. }) => None,
            Err(store_err) => Some(CommitError::OtherError(store_err.into())),
        };
        if let Some(reason) = abort_reason {
            lease.release(false).await?;
            return Err(reason);
        }

        let write_outcome = manifest_writer(object_store, manifest, indices, path).await;
        lease.release(write_outcome.is_ok()).await?;
        write_outcome.map_err(CommitError::from)
    }
}
/// Options controlling how a transaction is committed.
#[derive(Debug, Clone)]
pub struct CommitConfig {
    /// Maximum number of attempts to write the manifest when concurrent commits
    /// conflict with this one.
    pub num_retries: u32,
}

impl Default for CommitConfig {
    /// Defaults to five commit attempts.
    fn default() -> Self {
        const DEFAULT_NUM_RETRIES: u32 = 5;
        CommitConfig {
            num_retries: DEFAULT_NUM_RETRIES,
        }
    }
}
/// Fetches `<base_path>/_transactions/<transaction_file>` from `object_store`
/// and decodes it from its protobuf encoding into a [`Transaction`].
///
/// Errors from the store, from protobuf decoding, or from converting the
/// decoded message are propagated.
async fn read_transaction_file(
    object_store: &ObjectStore,
    base_path: &Path,
    transaction_file: &str,
) -> Result<Transaction> {
    let txn_path = base_path.child("_transactions").child(transaction_file);
    let raw_bytes = object_store.inner.get(&txn_path).await?.bytes().await?;
    let message = pb::Transaction::decode(raw_bytes)?;
    (&message).try_into()
}
/// Serializes `transaction` to protobuf and stores it under
/// `<base_path>/_transactions/<read_version>-<uuid>.txn`.
///
/// Returns the file name (relative to `_transactions`), which callers record
/// in the manifest so the transaction can be read back later.
async fn write_transaction_file(
    object_store: &ObjectStore,
    base_path: &Path,
    transaction: &Transaction,
) -> Result<String> {
    let file_name = format!("{}-{}.txn", transaction.read_version, transaction.uuid);
    let encoded = pb::Transaction::from(transaction).encode_to_vec();
    let txn_path = base_path.child("_transactions").child(file_name.as_str());
    object_store.inner.put(&txn_path, encoded.into()).await?;
    Ok(file_name)
}
/// Verifies that `transaction` can still be applied after `other_transaction`,
/// which was committed concurrently as version `other_version`.
///
/// # Errors
/// - `Error::Internal` if the concurrent version has no recorded transaction,
///   since compatibility cannot then be judged.
/// - `Error::CommitConflict` if the two transactions conflict according to
///   `Transaction::conflicts_with`.
fn check_transaction(
    transaction: &Transaction,
    other_version: u64,
    other_transaction: &Option<Transaction>,
) -> Result<()> {
    // Bind the inner transaction directly rather than `is_none()` + `unwrap()`.
    let other = match other_transaction {
        Some(other) => other,
        None => {
            return Err(crate::Error::Internal {
                message: format!(
                    "There was a conflicting transaction at version {}, \
                     and it was missing transaction metadata.",
                    other_version
                ),
            });
        }
    };

    if transaction.conflicts_with(other) {
        return Err(crate::Error::CommitConflict {
            version: other_version,
            source: format!(
                "There was a concurrent commit that conflicts with this one and it \
                 cannot be automatically resolved. Please rerun the operation off the latest version \
                 of the table.\n Transaction: {:?}\n Conflicting Transaction: {:?}",
                transaction, other_transaction
            )
            .into(),
        });
    }

    Ok(())
}
/// Commits the first version of a brand-new dataset at `base_path`.
///
/// Writes the transaction file, builds a manifest from `transaction` with no
/// previous manifest and no existing indices, and writes that manifest.
/// Returns the written manifest.
pub(crate) async fn commit_new_dataset(
    object_store: &ObjectStore,
    base_path: &Path,
    transaction: &Transaction,
    write_config: &ManifestWriteConfig,
) -> Result<Manifest> {
    let transaction_file = write_transaction_file(object_store, base_path, transaction).await?;

    let (mut manifest, indices) =
        transaction.build_manifest(None, vec![], &transaction_file, write_config)?;

    // `indices` isn't used afterwards, so move it instead of cloning it.
    let indices = if indices.is_empty() {
        None
    } else {
        Some(indices)
    };

    write_manifest_file(object_store, base_path, &mut manifest, indices, write_config).await?;

    Ok(manifest)
}
/// Loads the transaction recorded in `dataset`'s manifest.
///
/// Returns `None` when the manifest references no transaction file.
async fn read_dataset_transaction(
    object_store: &ObjectStore,
    dataset: &Dataset,
) -> Result<Option<Transaction>> {
    match dataset.manifest.transaction_file.as_ref() {
        Some(txn_file) => Ok(Some(
            read_transaction_file(object_store, &dataset.base, txn_file).await?,
        )),
        None => Ok(None),
    }
}

/// Commits `transaction` on top of the latest version of `dataset`.
///
/// Steps:
/// 1. Persist the transaction file.
/// 2. Walk forward from `transaction.read_version`. Check every version
///    committed since then for conflicts, stopping at the first version that
///    does not exist yet.
/// 3. Build a manifest for that version and try to write it. If another writer
///    claims the version first (`CommitError::CommitConflict`), check the
///    winner's transaction for conflicts and retry at the next version.
///
/// At most `commit_config.num_retries` attempts are made, but always at least
/// one.
///
/// # Errors
/// - `Error::CommitConflict` when a concurrent transaction conflicts or the
///   attempts are exhausted.
/// - Any I/O or manifest-building error encountered along the way.
pub(crate) async fn commit_transaction(
    dataset: &Dataset,
    object_store: &ObjectStore,
    transaction: &Transaction,
    write_config: &ManifestWriteConfig,
    commit_config: &CommitConfig,
) -> Result<Manifest> {
    let transaction_file = write_transaction_file(object_store, &dataset.base, transaction).await?;

    // Catch up to the latest committed version, validating each concurrent
    // transaction against ours as we go.
    let mut dataset = dataset.clone();
    let mut target_version = transaction.read_version + 1;
    loop {
        match dataset.checkout_version(target_version).await {
            Ok(next_dataset) => {
                let other_transaction = read_dataset_transaction(object_store, &next_dataset).await?;
                check_transaction(transaction, target_version, &other_transaction)?;
                dataset = next_dataset;
                target_version += 1;
            }
            Err(crate::Error::NotFound { .. }) | Err(crate::Error::DatasetNotFound { .. }) => {
                break;
            }
            Err(e) => return Err(e),
        }
    }

    // A `num_retries` of 0 would otherwise skip the commit entirely after the
    // transaction file has already been written.
    let num_attempts = commit_config.num_retries.max(1);
    for _ in 0..num_attempts {
        let (mut manifest, indices) = match transaction.operation {
            Operation::Restore {
                version: restore_version,
            } => {
                Transaction::restore_old_manifest(
                    object_store,
                    &dataset.base,
                    restore_version,
                    write_config,
                    &transaction_file,
                )
                .await?
            }
            _ => transaction.build_manifest(
                Some(dataset.manifest.as_ref()),
                dataset.load_indices().await?,
                &transaction_file,
                write_config,
            )?,
        };
        manifest.version = target_version;

        // `indices` is rebuilt on every attempt, so it can be moved, not cloned.
        let indices = if indices.is_empty() {
            None
        } else {
            Some(indices)
        };

        let result = write_manifest_file(
            object_store,
            &dataset.base,
            &mut manifest,
            indices,
            write_config,
        )
        .await;

        match result {
            Ok(()) => return Ok(manifest),
            Err(CommitError::CommitConflict) => {
                // Someone else took `target_version`: make sure their change is
                // compatible with ours, then try the next version.
                dataset = dataset.checkout_version(target_version).await?;
                let other_transaction = read_dataset_transaction(object_store, &dataset).await?;
                check_transaction(transaction, target_version, &other_transaction)?;
                target_version += 1;
            }
            Err(CommitError::OtherError(err)) => return Err(err),
        }
    }

    Err(crate::Error::CommitConflict {
        version: target_version,
        source: format!(
            "Failed to commit the transaction after {} retries.",
            commit_config.num_retries
        )
        .into(),
    })
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};

    use arrow_array::{Int32Array, Int64Array, RecordBatch, RecordBatchIterator};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use futures::future::join_all;

    use super::*;

    use crate::arrow::FixedSizeListArrayExt;
    use crate::dataset::transaction::Operation;
    use crate::dataset::{WriteMode, WriteParams};
    use crate::index::vector::{MetricType, VectorIndexParams};
    use crate::index::{DatasetIndexExt, IndexType};
    use crate::io::object_store::ObjectStoreParams;
    use crate::utils::testing::generate_random_array;
    use crate::Dataset;

    /// Writes a small in-memory dataset using `handler`, then runs ten
    /// concurrent deletes from the same starting version.
    ///
    /// When `should_succeed` is true, no two successful tasks may report the
    /// same resulting version.
    async fn test_commit_handler(handler: Arc<dyn CommitHandler>, should_succeed: bool) {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "x",
            DataType::Int64,
            false,
        )]));
        let data = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
        )
        .unwrap();
        let reader = RecordBatchIterator::new(vec![Ok(data)], schema);

        let options = WriteParams {
            store_params: Some(ObjectStoreParams {
                commit_handler: Some(handler),
                ..Default::default()
            }),
            ..Default::default()
        };
        let dataset = Dataset::write(reader, "memory://test", Some(options))
            .await
            .unwrap();

        // Each task deletes from its own clone of the same starting version.
        let tasks = (0..10).map(|_| {
            let mut dataset = dataset.clone();
            tokio::task::spawn(async move {
                dataset
                    .delete("x = 2")
                    .await
                    .map(|_| dataset.manifest.version)
            })
        });
        let task_results: Vec<Option<u64>> = join_all(tasks)
            .await
            .iter()
            .map(|res| match res {
                Ok(Ok(version)) => Some(*version),
                _ => None,
            })
            .collect();

        let num_successes = task_results.iter().filter(|x| x.is_some()).count();
        let distinct_results: HashSet<_> = task_results.iter().filter_map(|x| x.as_ref()).collect();

        if should_succeed {
            assert_eq!(
                num_successes,
                distinct_results.len(),
                "Expected no two tasks to succeed for the same version. Got {:?}",
                task_results
            );
        } else {
            // Without conflict detection several tasks may claim the same
            // version; this is only a consistency smoke check.
            assert!(num_successes >= distinct_results.len());
        }
    }

    #[tokio::test]
    async fn test_rename_commit_handler() {
        let handler = Arc::new(RenameCommitHandler);
        test_commit_handler(handler, true).await;
    }

    #[tokio::test]
    async fn test_custom_commit() {
        #[derive(Debug)]
        struct CustomCommitHandler {
            locked_version: Arc<Mutex<Option<u64>>>,
        }

        struct CustomCommitLease {
            version: u64,
            locked_version: Arc<Mutex<Option<u64>>>,
        }

        #[async_trait::async_trait]
        impl CommitLock for CustomCommitHandler {
            type Lease = CustomCommitLease;

            async fn lock(&self, version: u64) -> std::result::Result<Self::Lease, CommitError> {
                let mut locked_version = self.locked_version.lock().unwrap();
                if locked_version.is_some() {
                    return Err(CommitError::CommitConflict);
                }
                *locked_version = Some(version);
                Ok(CustomCommitLease {
                    version,
                    locked_version: self.locked_version.clone(),
                })
            }
        }

        #[async_trait::async_trait]
        impl CommitLease for CustomCommitLease {
            async fn release(&self, _success: bool) -> std::result::Result<(), CommitError> {
                let mut locked_version = self.locked_version.lock().unwrap();
                if *locked_version != Some(self.version) {
                    return Err(CommitError::CommitConflict);
                }
                *locked_version = None;
                Ok(())
            }
        }

        let locked_version = Arc::new(Mutex::new(None));
        let handler = Arc::new(CustomCommitHandler { locked_version });
        test_commit_handler(handler, true).await;
    }

    #[tokio::test]
    async fn test_unsafe_commit_handler() {
        let handler = Arc::new(UnsafeCommitHandler);
        test_commit_handler(handler, false).await;
    }

    #[tokio::test]
    async fn test_roundtrip_transaction_file() {
        let object_store = ObjectStore::memory();
        let base_path = Path::from("test");
        let transaction = Transaction::new(
            42,
            Operation::Append { fragments: vec![] },
            Some("hello world".to_string()),
        );

        let file_name = write_transaction_file(&object_store, &base_path, &transaction)
            .await
            .unwrap();
        let read_transaction = read_transaction_file(&object_store, &base_path, &file_name)
            .await
            .unwrap();

        assert_eq!(transaction.read_version, read_transaction.read_version);
        assert_eq!(transaction.uuid, read_transaction.uuid);
        assert!(matches!(
            read_transaction.operation,
            Operation::Append { .. }
        ));
        assert_eq!(transaction.tag, read_transaction.tag);
    }

    #[tokio::test]
    async fn test_concurrent_create_index() {
        let test_dir = tempfile::tempdir().unwrap();
        let test_uri = test_dir.path().to_str().unwrap();

        let dimension = 16;
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new(
                "vector1",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    dimension,
                ),
                false,
            ),
            Field::new(
                "vector2",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    dimension,
                ),
                false,
            ),
        ]));
        let float_arr = generate_random_array(512 * dimension as usize);
        let vectors = Arc::new(
            <arrow_array::FixedSizeListArray as FixedSizeListArrayExt>::try_new_from_values(
                float_arr, dimension,
            )
            .unwrap(),
        );
        let batches =
            vec![
                RecordBatch::try_new(schema.clone(), vec![vectors.clone(), vectors.clone()])
                    .unwrap(),
            ];

        let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
        let dataset = Dataset::write(reader, test_uri, None).await.unwrap();
        dataset.validate().await.unwrap();

        // Build indices concurrently: two on the same column, one on another.
        let params = VectorIndexParams::ivf_pq(10, 8, 2, false, MetricType::L2, 50);
        let futures: Vec<_> = ["vector1", "vector1", "vector2"]
            .iter()
            .map(|col_name| {
                let dataset = dataset.clone();
                let params = params.clone();
                tokio::spawn(async move {
                    dataset
                        .create_index(&[col_name], IndexType::Vector, None, &params, true)
                        .await
                })
            })
            .collect();

        let results = join_all(futures).await;
        for result in results {
            assert!(matches!(result, Ok(Ok(_))), "{:?}", result);
        }

        // Version 1: the initial write, no indices yet.
        let dataset = dataset.checkout_version(1).await.unwrap();
        assert!(dataset.load_indices().await.unwrap().is_empty());

        // Version 2: exactly one index.
        let dataset = dataset.checkout_version(2).await.unwrap();
        assert_eq!(dataset.load_indices().await.unwrap().len(), 1);

        // Version 3: depends on commit order; either one or both columns indexed.
        let dataset = dataset.checkout_version(3).await.unwrap();
        let indices = dataset.load_indices().await.unwrap();
        assert!(!indices.is_empty() && indices.len() <= 2);
        if indices.len() == 2 {
            let mut fields: Vec<i32> = indices.iter().flat_map(|i| i.fields.clone()).collect();
            fields.sort();
            assert_eq!(fields, vec![0, 1]);
        } else {
            assert_eq!(indices[0].fields, vec![0]);
        }

        // Version 4: both columns indexed.
        let dataset = dataset.checkout_version(4).await.unwrap();
        let indices = dataset.load_indices().await.unwrap();
        assert_eq!(indices.len(), 2);
        let mut fields: Vec<i32> = indices.iter().flat_map(|i| i.fields.clone()).collect();
        fields.sort();
        assert_eq!(fields, vec![0, 1]);
    }

    #[tokio::test]
    async fn test_concurrent_writes() {
        for write_mode in [WriteMode::Append, WriteMode::Overwrite] {
            let test_dir = tempfile::tempdir().unwrap();
            let test_uri = test_dir.path().to_str().unwrap();

            let schema = Arc::new(ArrowSchema::new(vec![Field::new(
                "i",
                DataType::Int32,
                false,
            )]));

            // Start with an empty dataset at version 1.
            let dataset = Dataset::write(
                RecordBatchIterator::new(vec![].into_iter().map(Ok), schema.clone()),
                test_uri,
                None,
            )
            .await
            .unwrap();

            let batch = RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
            )
            .unwrap();

            // Five concurrent writers, all in the same mode.
            let futures: Vec<_> = (0..5)
                .map(|_| {
                    let batch = batch.clone();
                    let schema = schema.clone();
                    let uri = test_uri.to_string();
                    tokio::spawn(async move {
                        let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
                        Dataset::write(
                            reader,
                            &uri,
                            Some(WriteParams {
                                mode: write_mode,
                                ..Default::default()
                            }),
                        )
                        .await
                    })
                })
                .collect();
            let results = join_all(futures).await;
            for result in results {
                assert!(matches!(result, Ok(Ok(_))), "{:?}", result);
            }

            // All five writes must have landed, producing version 6.
            let dataset = dataset.checkout_version(6).await.unwrap();

            match write_mode {
                WriteMode::Append => {
                    assert_eq!(dataset.get_fragments().len(), 5);
                }
                WriteMode::Overwrite => {
                    assert_eq!(dataset.get_fragments().len(), 1);
                }
                _ => unreachable!(),
            }

            dataset.validate().await.unwrap()
        }
    }
}