#![cfg(feature = "shuttle")]
use bytes::Bytes;
use chrono::{DateTime, Utc};
use futures::TryStreamExt;
use futures::future::try_join_all;
use icechunk::format::manifest::ChunkPayload;
use icechunk::format::repo_info::UpdateType;
use icechunk::format::snapshot::ArrayShape;
use icechunk::format::{ChunkIndices, Path, SnapshotId};
use icechunk::repository::VersionInfo;
use icechunk::{Repository, new_in_memory_storage};
use proptest::collection::vec;
use proptest::prelude::*;
use shuttle::future::{block_on, spawn};
use shuttle::{Config, Runner, scheduler};
use std::collections::HashSet;
use std::error::Error;
use std::sync::Arc;
/// Checks the structural invariants of a repository ops log.
///
/// The log is ordered newest-first, so:
/// - timestamps must be strictly decreasing,
/// - every entry except the last (earliest) must carry a unique backup path,
/// - the last entry must be the `RepoInitializedUpdate` written at repo creation.
///
/// # Panics
/// Panics if any invariant is violated, or if `log` is empty.
fn assert_ops_log_invariants(log: &[(DateTime<Utc>, UpdateType, Option<String>)]) {
    // Guard first: an empty slice would otherwise underflow in the
    // `log.len() - 1` slicing below and die with a confusing
    // "attempt to subtract with overflow" panic instead of a clear message.
    assert!(!log.is_empty(), "ops log must contain at least the initialization entry");
    log.windows(2).for_each(|window| {
        let (time_a, _, _) = &window[0];
        let (time_b, _, _) = &window[1];
        assert!(
            time_a > time_b,
            "ops log timestamps must be strictly decreasing: {time_a} should be > {time_b}"
        );
    });
    // Every non-initial entry must point to a distinct backup file.
    let backup_paths: HashSet<_> =
        log[..log.len() - 1].iter().map(|(_, _, path)| path.as_ref()).collect();
    assert_eq!(backup_paths.len(), log.len() - 1);
    // Non-emptiness was asserted above, so `last()` cannot fail.
    let (_, update_type, _backup_path) =
        log.last().expect("log verified non-empty above");
    assert!(
        matches!(update_type, UpdateType::RepoInitializedUpdate),
        "last ops log entry (earliest) should be RepoInitializedUpdate, got: {update_type:?}"
    );
}
/// Runs `f` repeatedly under shuttle's PCT (probabilistic concurrency
/// testing) scheduler: `iterations` runs, each with up to `depth`
/// forced preemption points.
fn check_pct(f: impl Fn() + Send + Sync + 'static, iterations: usize, depth: usize) {
    let scheduler = scheduler::PctScheduler::new(depth, iterations);
    let mut config = Config::default();
    // Shuttle continuations need a larger stack than the default (8 MiB here).
    config.stack_size = 0x80_0000;
    Runner::new(scheduler, config).run(f);
}
/// Writes one inline chunk at index `c` of the array at `path` and commits
/// it to `branch`, rebasing (up to 3 attempts) if a concurrent commit
/// landed first. Returns the resulting snapshot id.
async fn mk_commit(
    repo: Arc<Repository>,
    path: Path,
    branch: &str,
    c: u32,
) -> Result<SnapshotId, Box<dyn Error + Send + Sync>> {
    let mut session = repo.writable_session(branch).await?;
    let payload = Some(ChunkPayload::Inline("foo".into()));
    session.set_chunk_ref(path, ChunkIndices(vec![c]), payload).await?;
    let snapshot = session
        .commit("write chunk")
        .rebase(&icechunk::conflicts::detector::ConflictDetector, 3)
        .execute()
        .await?;
    Ok(snapshot)
}
/// Spawns two tasks that commit concurrently to two different branches of
/// a fresh in-memory repository, then validates the ops log invariants.
async fn mk_concurrent_commits_same_branch() -> Result<(), Box<dyn Error + Send + Sync>> {
    let storage = new_in_memory_storage().await?;
    let repo = Arc::new(
        Repository::create(None, storage, Default::default(), None, false).await?,
    );

    // Seed the repository with a single array committed to main, and
    // branch "feature" off that initial snapshot.
    let path: Path = "/array".try_into()?;
    let mut session = repo.writable_session("main").await?;
    let shape = ArrayShape::new(vec![(10, 10)]).unwrap();
    session.add_array(path.clone(), shape, None, Bytes::new()).await?;
    let init_snap = session.commit("init array").execute().await?;
    repo.create_branch("feature", &init_snap).await?;

    // Two concurrent writers, one per branch.
    let main_task = {
        let (repo, path) = (repo.clone(), path.clone());
        spawn(async move { mk_commit(repo, path, "main", 2).await })
    };
    let feature_task = {
        let (repo, path) = (repo.clone(), path.clone());
        spawn(async move { mk_commit(repo, path, "feature", 1).await })
    };
    for result in try_join_all([main_task, feature_task]).await? {
        result?;
    }

    let (stream, _, _) = repo.ops_log().await?;
    let log: Vec<_> = stream.try_collect().await?;
    assert_ops_log_invariants(&log);
    Ok(())
}
/// Exercises two concurrent branch commits under the PCT scheduler
/// (100 iterations, preemption depth 3).
#[test]
fn concurrent_commits_same_branch() {
    let body = || block_on(mk_concurrent_commits_same_branch()).unwrap();
    check_pct(body, 100, 3);
}
/// One concurrent repository mutation to exercise against the ops log.
///
/// The enum is fieldless, so `Copy` is free and `Eq`/`Hash` are trivially
/// derivable alongside `PartialEq` (clippy: `derive_partial_eq_without_eq`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Action {
    /// Regular commit on a branch.
    Commit,
    /// Create a new branch.
    AddBranch,
    /// Delete an existing branch.
    DeleteBranch,
    /// Create a tag.
    AddTag,
    /// Delete a pre-existing tag.
    DeleteTag,
    /// Amend the latest commit on a branch.
    Amend,
    /// Reset a branch to another snapshot.
    ResetBranch,
}
/// The successful outcome of an executed [`Action`], retaining enough
/// information to find the matching ops log entry and to check
/// repository postconditions afterwards.
#[derive(Debug, Clone)]
enum ActionResult {
    /// Branch name and the snapshot id the commit produced.
    Commit(String, SnapshotId),
    /// Branch name and the snapshot the new branch points at.
    AddBranch(String, SnapshotId),
    /// `previous_snap`: the branch tip at the moment it was deleted.
    DeleteBranch { branch: String, previous_snap: SnapshotId },
    /// Tag name and the snapshot it was created at.
    AddTag(String, SnapshotId),
    /// `previous_snap`: the snapshot the tag pointed at before deletion.
    DeleteTag { tag: String, previous_snap: SnapshotId },
    /// `new_snap` replaced `previous_snap` as the branch tip.
    Amend { branch: String, new_snap: SnapshotId, previous_snap: SnapshotId },
    /// Branch moved from `previous_snap` to `to_snap`.
    ResetBranch { branch: String, to_snap: SnapshotId, previous_snap: SnapshotId },
}
fn actions(
range: impl Into<proptest::collection::SizeRange>,
) -> impl Strategy<Value = Vec<Action>> {
use Action::*;
vec(
proptest::sample::select(vec![
Commit,
AddBranch,
DeleteBranch,
AddTag,
DeleteTag,
Amend,
ResetBranch,
]),
range,
)
}
/// Prepares one branch per action so that each concurrently executed
/// action later finds the repository state it needs:
/// - `AddBranch`: no branch is created (the action creates it itself),
/// - `Amend`: one extra commit so there is something to amend,
/// - `ResetBranch`: two extra commits so the reset is observable,
/// - all others: a `tag-to-delete-*` tag (consumed by `DeleteTag`).
async fn setup_branches(
    repo: Arc<Repository>,
    actions: &[Action],
    branches: &[String],
) -> Result<(), Box<dyn Error + Send + Sync>> {
    let path: Path = "/array".try_into()?;
    for (action, branch) in actions.iter().zip(branches) {
        // Every action except AddBranch operates on a pre-existing branch.
        if !matches!(action, Action::AddBranch) {
            let main_tip = repo.lookup_branch("main").await?;
            repo.create_branch(branch, &main_tip).await?;
        }
        match action {
            Action::AddBranch => {}
            Action::Amend => {
                mk_commit(repo.clone(), path.clone(), branch, 3).await?;
            }
            Action::ResetBranch => {
                mk_commit(repo.clone(), path.clone(), branch, 2).await?;
                mk_commit(repo.clone(), path.clone(), branch, 3).await?;
            }
            _ => {
                let main_tip = repo.lookup_branch("main").await?;
                repo.create_tag(&format!("tag-to-delete-{branch}"), &main_tip).await?;
            }
        }
    }
    Ok(())
}
async fn execute_action(
repo: Arc<Repository>,
action: Action,
branch: String,
) -> Result<ActionResult, Box<dyn Error + Send + Sync>> {
use Action::*;
let res = match action {
Commit => {
let snap = mk_commit(repo, "/array".try_into()?, &branch, 0).await?;
ActionResult::Commit(branch, snap)
}
AddBranch => {
let snap = repo.lookup_branch("main").await?;
repo.create_branch(&branch, &snap).await?;
ActionResult::AddBranch(branch, snap)
}
DeleteBranch => {
let previous_snap = repo.lookup_branch(&branch).await?;
repo.delete_branch(&branch).await?;
ActionResult::DeleteBranch { branch, previous_snap }
}
AddTag => {
let snap = repo.lookup_branch(&branch).await?;
let tag = format!("tag-to-create-{branch}");
repo.create_tag(&tag, &snap).await?;
ActionResult::AddTag(tag, snap)
}
DeleteTag => {
let tag = format!("tag-to-delete-{branch}");
let previous_snap = repo.lookup_tag(&tag).await?;
repo.delete_tag(&tag).await?;
ActionResult::DeleteTag { tag, previous_snap }
}
Amend => {
let previous_snap = repo.lookup_branch(&branch).await?;
let mut session = repo.writable_session(&branch).await?;
session
.set_chunk_ref(
"/array".try_into()?,
ChunkIndices(vec![4]),
Some(ChunkPayload::Inline("amend".into())),
)
.await?;
let new_snap = session.commit("amend commit").amend().execute().await?;
ActionResult::Amend { branch, new_snap, previous_snap }
}
ResetBranch => {
let previous_snap = repo.lookup_branch(&branch).await?;
let to_snap = repo.lookup_branch("main").await?;
repo.reset_branch(&branch, &to_snap, None).await?;
ActionResult::ResetBranch { branch, to_snap, previous_snap }
}
};
Ok(res)
}
async fn assert_action_postcondition(
repo: Arc<Repository>,
action: ActionResult,
) -> Result<(), Box<dyn Error + Send + Sync>> {
use ActionResult::*;
match action {
Commit(branch, snap) => {
let anc: HashSet<SnapshotId> = repo
.ancestry(&VersionInfo::BranchTipRef(branch.clone()))
.await?
.map_ok(|info| info.id)
.try_collect()
.await?;
assert!(anc.contains(&snap));
}
AddBranch(branch, snap) => {
assert!(repo.list_branches().await?.contains(&branch));
assert_eq!(repo.lookup_branch(&branch).await?, snap);
}
DeleteBranch { branch, .. } => {
assert!(!repo.list_branches().await?.contains(&branch));
}
AddTag(tag, snap) => {
assert!(repo.list_tags().await?.contains(&tag));
assert_eq!(repo.lookup_tag(&tag).await?, snap);
}
DeleteTag { tag, .. } => {
assert!(!repo.list_tags().await?.contains(&tag));
}
Amend { branch, new_snap, previous_snap } => {
let tip = repo.lookup_branch(&branch).await?;
assert_eq!(tip, new_snap, "amend snapshot should be branch tip for {branch}");
let anc: HashSet<SnapshotId> = repo
.ancestry(&VersionInfo::BranchTipRef(branch.clone()))
.await?
.map_ok(|info| info.id)
.try_collect()
.await?;
assert!(anc.contains(&new_snap));
assert!(!anc.contains(&previous_snap));
}
ResetBranch { branch, to_snap, .. } => {
let tip = repo.lookup_branch(&branch).await?;
assert_eq!(tip, to_snap, "branch {branch} should be reset to {to_snap:?}");
}
};
Ok(())
}
/// Asserts that `entries` holds an ops log record corresponding to
/// `result`, matching on ref names and — where the update type records
/// them — snapshot ids.
fn assert_ops_log_contains(
    entries: &[(DateTime<Utc>, UpdateType, Option<String>)],
    result: &ActionResult,
) {
    use ActionResult::*;
    // True when a single log record corresponds to the action result.
    let is_match = |update_type: &UpdateType| match (result, update_type) {
        (Commit(branch, snap), UpdateType::NewCommitUpdate { branch: b, new_snap_id }) => {
            b == branch && new_snap_id == snap
        }
        (AddBranch(branch, _), UpdateType::BranchCreatedUpdate { name }) => name == branch,
        (
            DeleteBranch { branch, previous_snap },
            UpdateType::BranchDeletedUpdate { name, previous_snap_id },
        ) => name == branch && previous_snap_id == previous_snap,
        (AddTag(tag, _), UpdateType::TagCreatedUpdate { name }) => name == tag,
        (
            DeleteTag { tag, previous_snap },
            UpdateType::TagDeletedUpdate { name, previous_snap_id },
        ) => name == tag && previous_snap_id == previous_snap,
        (
            Amend { branch, new_snap, previous_snap },
            UpdateType::CommitAmendedUpdate { branch: b, new_snap_id, previous_snap_id },
        ) => b == branch && new_snap_id == new_snap && previous_snap_id == previous_snap,
        (
            ResetBranch { branch, previous_snap, .. },
            UpdateType::BranchResetUpdate { name, previous_snap_id },
        ) => name == branch && previous_snap_id == previous_snap,
        _ => false,
    };
    let found = entries.iter().any(|(_, update_type, _)| is_match(update_type));
    assert!(found, "no ops log entry found for action result {result:?}");
}
/// Runs every action in `actions` concurrently — each on its own branch —
/// against a freshly created repository, then verifies:
/// 1. exactly one new ops log entry per action,
/// 2. the global ops log invariants,
/// 3. a matching log entry and repository postcondition for each result.
async fn execute_concurrent_actions(
    actions: Vec<Action>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
    let branches: Vec<String> =
        (0..actions.len()).map(|i| format!("branch-{i}")).collect();
    let storage = new_in_memory_storage().await?;
    let repo = Arc::new(
        Repository::create(None, storage, Default::default(), None, false).await?,
    );

    // Seed the repository with one array committed to main.
    let path: Path = "/array".try_into()?;
    let mut session = repo.writable_session("main").await?;
    let shape = ArrayShape::new(vec![(10, 10)]).unwrap();
    session.add_array(path.clone(), shape, None, Bytes::new()).await?;
    session.commit("foo").execute().await?;
    setup_branches(repo.clone(), &actions, &branches).await?;

    // Count entries written so far; anything beyond this is ours.
    let (stream, _, _) = repo.ops_log().await?;
    let ops_count_before = stream.try_collect::<Vec<_>>().await?.len();

    // One concurrent task per (action, branch) pair.
    let mut handles = Vec::with_capacity(actions.len());
    for (action, branch) in actions.iter().zip(&branches) {
        let repo = repo.clone();
        let (action, branch) = (action.clone(), branch.clone());
        handles.push(spawn(async move { execute_action(repo, action, branch).await }));
    }
    let results: Vec<_> =
        try_join_all(handles).await?.into_iter().collect::<Result<Vec<_>, _>>()?;

    // The log is newest-first, so the new entries sit at the front.
    let (stream, _, _) = repo.ops_log().await?;
    let log: Vec<_> = stream.try_collect().await?;
    let new_entries = &log[..log.len() - ops_count_before];
    assert_eq!(new_entries.len(), actions.len());
    assert_ops_log_invariants(&log);
    for result in &results {
        assert_ops_log_contains(new_entries, result);
    }
    for result in results {
        assert_action_postcondition(repo.clone(), result).await?;
    }
    Ok(())
}
proptest! {
    // Property: any small mix of concurrent repo mutations keeps the ops
    // log consistent under every PCT-explored schedule.
    #[test]
    fn concurrent_actions(acts in actions(3..=5)) {
        // `acts` is already owned here, so the `move` closure can capture
        // it directly — the previous `let acts = acts.clone();` outside
        // the closure was a redundant clone. The clone *inside* stays:
        // the closure is `Fn` and runs once per scheduler iteration, so
        // it must not give its captured value away.
        check_pct(move || {
            let acts = acts.clone();
            block_on(execute_concurrent_actions(acts)).unwrap();
        }, 100, 3);
    }
}