use indicatif::{ProgressBar, ProgressStyle};
use crate::core::v_latest::fetch;
use crate::core::v_latest::index::restore::{self, FileToRestore};
use crate::error::OxenError;
use crate::model::merkle_tree::node::{EMerkleTreeNode, MerkleTreeNode};
use crate::model::{Commit, CommitEntry, LocalRepository, MerkleHash, PartialNode};
use crate::repositories;
use crate::util;
use filetime::FileTime;
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::time::Duration;
/// Terminal spinner that reports how many files a checkout has restored, modified and
/// removed so far for a given revision.
struct CheckoutProgressBar {
    /// Revision label shown in the spinner message.
    revision: String,
    /// Underlying `indicatif` spinner.
    progress: ProgressBar,
    /// Files written because they were missing from the working directory.
    num_restored: usize,
    /// Files counted by the restore pass as modified (overwritten or conflicting).
    num_modified: usize,
    /// Paths handled by the cleanup pass.
    num_removed: usize,
}

impl CheckoutProgressBar {
    /// Creates a spinner for `revision` that redraws itself every 100ms.
    pub fn new(revision: String) -> Self {
        let spinner = ProgressBar::new_spinner();
        spinner.set_style(ProgressStyle::default_spinner());
        spinner.enable_steady_tick(Duration::from_millis(100));
        Self {
            num_restored: 0,
            num_modified: 0,
            num_removed: 0,
            progress: spinner,
            revision,
        }
    }

    /// Records one restored file and refreshes the message.
    pub fn increment_restored(&mut self) {
        self.num_restored += 1;
        self.update_message();
    }

    /// Records one modified file and refreshes the message.
    pub fn increment_modified(&mut self) {
        self.num_modified += 1;
        self.update_message();
    }

    /// Records one removed path and refreshes the message.
    pub fn increment_removed(&mut self) {
        self.num_removed += 1;
        self.update_message();
    }

    /// Rebuilds the spinner text from the current counters.
    fn update_message(&mut self) {
        let Self {
            revision,
            num_restored,
            num_modified,
            num_removed,
            ..
        } = self;
        let message = format!(
            "🐂 checkout '{}' restored {}, modified {}, removed {}",
            revision, num_restored, num_modified, num_removed
        );
        self.progress.set_message(message);
    }
}
/// Outcome of the restore pass: what must be written, and what blocks the checkout.
#[derive(Default)]
struct CheckoutResult {
    /// Files (with their target node) to be written into the working directory.
    pub files_to_restore: Vec<FileToRestore>,
    /// Repo-relative paths whose local changes would be lost by the checkout.
    pub cannot_overwrite_entries: Vec<PathBuf>,
}

impl CheckoutResult {
    /// Creates an empty result with no files to restore and no conflicts.
    pub fn new() -> Self {
        Self::default()
    }
}
/// Bookkeeping shared between the restore and cleanup passes of a checkout.
struct CheckoutHashes {
    /// Repo-relative file paths visited in the target tree during the restore pass;
    /// the cleanup pass never deletes these.
    pub seen_paths: HashSet<PathBuf>,
    /// Hashes of nodes present in both the source and the target tree, used to skip
    /// subtrees that did not change.
    pub common_nodes: HashSet<MerkleHash>,
}

impl CheckoutHashes {
    /// Starts with the given shared node hashes and no visited paths.
    pub fn from_hashes(common_nodes: HashSet<MerkleHash>) -> Self {
        Self {
            common_nodes,
            seen_paths: HashSet::new(),
        }
    }
}
/// Lists the distinct versions of the file or directory at `path` in the history
/// reachable from `commit_id`, newest first.
///
/// Commits are visited oldest-to-newest (by timestamp), and the first commit at which
/// each distinct node hash appears is recorded. A hash that shows up again later (for
/// example after a revert) is not listed a second time. Nodes that are neither files
/// nor directories are ignored.
///
/// # Errors
/// Propagates failures from listing commits or resolving the node at `path`.
pub fn list_entry_versions_for_commit(
    repo: &LocalRepository,
    commit_id: &str,
    path: &Path,
) -> Result<Vec<(Commit, CommitEntry)>, OxenError> {
    log::debug!("list_entry_versions_for_commit {commit_id} for file: {path:?}");

    let mut commits = repositories::commits::list_from(repo, commit_id)?;
    commits.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));

    let mut known_hashes: HashSet<String> = HashSet::new();
    let mut versions: Vec<(Commit, CommitEntry)> = Vec::new();

    for commit in commits {
        log::debug!("list_entry_versions_for_commit {commit}");
        let Some(node) = repositories::tree::get_node_by_path(repo, &commit, path)? else {
            continue;
        };

        // `insert` returns false when this hash was already recorded.
        if !known_hashes.insert(node.node.hash().to_string()) {
            log::debug!("list_entry_versions_for_commit already seen {node}");
            continue;
        }
        log::debug!("list_entry_versions_for_commit adding {commit} -> {node}");

        let entry = match node.node {
            EMerkleTreeNode::File(file_node) => CommitEntry::from_file_node(&file_node),
            EMerkleTreeNode::Directory(dir_node) => CommitEntry::from_dir_node(&dir_node),
            _ => continue,
        };
        versions.push((commit, entry));
    }

    // Collected oldest-first; callers get newest-first.
    versions.reverse();
    Ok(versions)
}
/// Checks out the head commit of the local branch `branch_name`.
///
/// `from_commit` is the commit the working directory currently reflects (if known); it
/// is forwarded to [`checkout_commit`].
///
/// # Errors
/// * `OxenError::local_branch_not_found` if no such local branch exists.
/// * `OxenError::commit_id_does_not_exist` if the branch points to an unknown commit.
/// * Any error from [`checkout_commit`].
pub async fn checkout(
    repo: &LocalRepository,
    branch_name: &str,
    from_commit: &Option<Commit>,
) -> Result<(), OxenError> {
    log::debug!("checkout {branch_name}");
    // `ok_or_else` so the error values are only constructed on the failure path.
    let branch = repositories::branches::get_by_name(repo, branch_name)?
        .ok_or_else(|| OxenError::local_branch_not_found(branch_name))?;
    let commit = repositories::commits::get_by_id(repo, &branch.commit_id)?
        .ok_or_else(|| OxenError::commit_id_does_not_exist(&branch.commit_id))?;
    checkout_commit(repo, &commit, from_commit).await
}
pub async fn checkout_subtrees(
repo: &LocalRepository,
to_commit: &Commit,
subtree_paths: &[PathBuf],
depth: i32,
) -> Result<(), OxenError> {
for subtree_path in subtree_paths {
let mut progress = CheckoutProgressBar::new(to_commit.id.clone());
let mut target_hashes = HashSet::new();
let target_root = if let Some(target_root) =
repositories::tree::get_subtree_by_depth_with_unique_children(
repo,
to_commit,
subtree_path.clone(),
None,
Some(&mut target_hashes),
None,
depth,
)? {
target_root
} else {
log::error!("Cannot get subtree for commit: {to_commit}");
continue;
};
let mut shared_hashes = HashSet::new();
let mut partial_nodes = HashMap::new();
let maybe_from_commit = repositories::commits::head_commit_maybe(repo)?;
let from_root = if let Some(from_commit) = &maybe_from_commit {
log::debug!("from id: {:?}", from_commit.id);
log::debug!("to id: {:?}", to_commit.id);
repositories::tree::get_root_with_children_and_partial_nodes(
repo,
from_commit,
Some(&target_hashes),
None,
Some(&mut shared_hashes),
&mut partial_nodes,
)
.map_err(|e| {
OxenError::basic_str(format!("Cannot get root node for base commit: {e:?}"))
})?
} else {
log::warn!("head commit missing, might be a clone");
None
};
let parent_path = subtree_path.parent().unwrap_or(Path::new(""));
let mut results = CheckoutResult::new();
let mut hashes = CheckoutHashes::from_hashes(shared_hashes);
let version_store = repo.version_store()?;
r_restore_missing_or_modified_files(
repo,
&target_root,
parent_path,
&mut results,
&mut progress,
&mut partial_nodes,
&mut hashes,
depth,
)?;
if !results.cannot_overwrite_entries.is_empty() {
return Err(OxenError::cannot_overwrite_files(
&results.cannot_overwrite_entries,
));
}
if let Some(root) = from_root {
log::debug!("Cleanup_removed_files");
cleanup_removed_files(repo, &root, &mut progress, &mut hashes).await?;
} else {
log::debug!("head commit missing, no cleanup");
}
if repo.is_remote_mode() {
for file_to_restore in results.files_to_restore {
log::debug!("file_to_restore: {:?}", file_to_restore.file_node);
let file_hash = format!("{}", &file_to_restore.file_node.hash());
if version_store.version_exists(&file_hash).await? {
restore::restore_file(
repo,
&file_to_restore.file_node,
&file_to_restore.path,
&version_store,
)
.await?;
}
}
} else {
for file_to_restore in results.files_to_restore {
restore::restore_file(
repo,
&file_to_restore.file_node,
&file_to_restore.path,
&version_store,
)
.await?;
}
}
}
Ok(())
}
/// Moves the working directory to `to_commit`.
///
/// This is a no-op when `from_commit` is the same commit. Otherwise any entries missing
/// locally are fetched first, then the working directory is rewritten via
/// [`set_working_repo_to_commit`].
///
/// # Errors
/// Propagates failures from fetching entries or updating the working directory.
pub async fn checkout_commit(
    repo: &LocalRepository,
    to_commit: &Commit,
    from_commit: &Option<Commit>,
) -> Result<(), OxenError> {
    log::debug!("checkout_commit to {to_commit} from {from_commit:?}");

    let already_there = from_commit
        .as_ref()
        .is_some_and(|current| current.id == to_commit.id);
    if already_there {
        return Ok(());
    }

    fetch::maybe_fetch_missing_entries(repo, to_commit).await?;
    set_working_repo_to_commit(repo, to_commit, from_commit).await
}
pub async fn set_working_repo_to_commit(
repo: &LocalRepository,
to_commit: &Commit,
maybe_from_commit: &Option<Commit>,
) -> Result<(), OxenError> {
let mut progress = CheckoutProgressBar::new(to_commit.id.clone());
let mut target_hashes = HashSet::new();
let Some(target_tree) = repositories::tree::get_root_with_children_and_node_hashes(
repo,
to_commit,
None,
Some(&mut target_hashes),
None,
)?
else {
return Err(OxenError::basic_str(
"Cannot get root node for target commit",
));
};
let mut shared_hashes = HashSet::new();
let mut partial_nodes = HashMap::new();
let from_tree = if let Some(from_commit) = maybe_from_commit {
if from_commit.id == to_commit.id {
return Ok(());
}
log::debug!("from id: {:?}", from_commit.id);
log::debug!("to id: {:?}", to_commit.id);
repositories::tree::get_root_with_children_and_partial_nodes(
repo,
from_commit,
Some(&target_hashes),
None,
Some(&mut shared_hashes),
&mut partial_nodes,
)
.map_err(|_| OxenError::basic_str("Cannot get root node for base commit"))?
} else {
None
};
let mut results = CheckoutResult::new();
let mut hashes = CheckoutHashes::from_hashes(shared_hashes);
let version_store = repo.version_store()?;
log::debug!("restore_missing_or_modified_files");
r_restore_missing_or_modified_files(
repo,
&target_tree,
Path::new(""),
&mut results,
&mut progress,
&mut partial_nodes,
&mut hashes,
i32::MAX,
)?;
if !results.cannot_overwrite_entries.is_empty() {
return Err(OxenError::cannot_overwrite_files(
&results.cannot_overwrite_entries,
));
}
if let Some(from_tree) = from_tree {
log::debug!("Cleanup_removed_files");
cleanup_removed_files(repo, &from_tree, &mut progress, &mut hashes).await?;
}
for file_to_restore in results.files_to_restore {
restore::restore_file(
repo,
&file_to_restore.file_node,
&file_to_restore.path,
&version_store,
)
.await?;
}
Ok(())
}
/// Deletes working-directory files that exist in `from_node`'s tree but were not visited
/// in the target tree (`hashes.seen_paths`). It also removes source-tree directories that
/// end up empty.
///
/// In remote mode, the current contents of affected files are first stored in the
/// version store under their source-tree hash.
///
/// # Errors
/// * `OxenError::cannot_overwrite_files` if a file to be deleted has local modifications;
///   nothing is deleted in that case.
/// * Any filesystem or version-store error.
async fn cleanup_removed_files(
    repo: &LocalRepository,
    from_node: &MerkleTreeNode,
    progress: &mut CheckoutProgressBar,
    hashes: &mut CheckoutHashes,
) -> Result<(), OxenError> {
    let mut paths_to_remove: Vec<PathBuf> = vec![];
    let mut files_to_store: Vec<(MerkleHash, PathBuf)> = vec![];
    let mut cannot_overwrite_entries: Vec<PathBuf> = vec![];
    r_remove_if_not_in_target(
        repo,
        from_node,
        Path::new(""),
        &mut paths_to_remove,
        &mut files_to_store,
        &mut cannot_overwrite_entries,
        hashes,
    )?;
    if !cannot_overwrite_entries.is_empty() {
        return Err(OxenError::cannot_overwrite_files(&cannot_overwrite_entries));
    }

    if repo.is_remote_mode() {
        let version_store = repo.version_store()?;
        for (hash, full_path) in files_to_store {
            log::debug!("Storing hash {hash:?} and path {full_path:?}");
            version_store
                .store_version_from_path(&hash.to_string(), &full_path)
                .await?;
        }
    }

    // Children are listed before their parent directory, so a directory whose files
    // were all removed is empty by the time it is checked here.
    for full_path in paths_to_remove {
        let removed = if full_path.is_dir() && full_path.read_dir()?.next().is_none() {
            log::debug!("Removing dir: {full_path:?}");
            util::fs::remove_dir_all(&full_path)?;
            true
        } else if full_path.is_file() {
            log::debug!("Removing file: {full_path:?}");
            util::fs::remove_file(&full_path)?;
            true
        } else {
            // Non-empty directory or already gone: nothing was removed.
            false
        };
        if removed {
            progress.increment_removed();
        }
    }
    Ok(())
}
/// Recursively walks the source tree `from_node` and collects cleanup work.
///
/// * `current_path` - repo-relative directory containing `from_node`.
/// * `paths_to_remove` - absolute paths to delete, each directory listed after its
///   children.
/// * `files_to_store` - remote mode only: `(hash, absolute path)` pairs whose on-disk
///   contents should be saved to the version store.
/// * `cannot_overwrite_entries` - repo-relative paths of files absent from the target
///   that have local modifications.
/// * `hashes` - `seen_paths` marks files present in the target; subtrees whose hash is in
///   `common_nodes` are skipped entirely.
///
/// Nothing is deleted here; the caller acts on the collected lists.
fn r_remove_if_not_in_target(
    repo: &LocalRepository,
    from_node: &MerkleTreeNode,
    current_path: &Path,
    paths_to_remove: &mut Vec<PathBuf>,
    files_to_store: &mut Vec<(MerkleHash, PathBuf)>,
    cannot_overwrite_entries: &mut Vec<PathBuf>,
    hashes: &mut CheckoutHashes,
) -> Result<(), OxenError> {
    match &from_node.node {
        EMerkleTreeNode::File(file_node) => {
            let file_path = current_path.join(file_node.name());
            let full_path = repo.path.join(&file_path);
            // Files no longer on disk need no action either way.
            if full_path.exists() {
                if hashes.seen_paths.contains(&file_path) {
                    // Still present in the target; in remote mode keep its current
                    // contents in the version store.
                    if repo.is_remote_mode() {
                        files_to_store.push((from_node.hash, full_path));
                    }
                } else if util::fs::is_modified_from_node(&full_path, file_node)? {
                    // Removing it would discard local edits.
                    cannot_overwrite_entries.push(file_path);
                } else {
                    if repo.is_remote_mode() {
                        files_to_store.push((from_node.hash, full_path.clone()));
                    }
                    paths_to_remove.push(full_path);
                }
            }
        }
        EMerkleTreeNode::Directory(dir_node) => {
            // Whole directory is identical in both trees: nothing to clean up.
            if hashes.common_nodes.contains(&from_node.hash) {
                return Ok(());
            }
            let dir_path = current_path.join(dir_node.name());
            // Only descend into vnodes not shared with the target. Borrow the children
            // rather than cloning entire subtrees.
            let children: Vec<&MerkleTreeNode> = from_node
                .children
                .iter()
                .filter(|vnode| !hashes.common_nodes.contains(&vnode.hash))
                .flat_map(|vnode| vnode.children.iter())
                .collect();
            let num_children = children.len();
            for child in children {
                r_remove_if_not_in_target(
                    repo,
                    child,
                    &dir_path,
                    paths_to_remove,
                    files_to_store,
                    cannot_overwrite_entries,
                    hashes,
                )?;
            }
            log::debug!("r_remove_if_not_in_target checked {:?} paths", num_children);
            // Candidate for removal; the caller only deletes it if it ends up empty.
            let full_dir_path = repo.path.join(&dir_path);
            if full_dir_path.exists() {
                paths_to_remove.push(full_dir_path);
            }
        }
        EMerkleTreeNode::Commit(_) => {
            let root_dir = repositories::tree::get_root_dir(from_node)?;
            r_remove_if_not_in_target(
                repo,
                root_dir,
                current_path,
                paths_to_remove,
                files_to_store,
                cannot_overwrite_entries,
                hashes,
            )?;
        }
        _ => {}
    }
    Ok(())
}
/// Recursively walks the target tree and decides, for each file, whether the working
/// copy must be rewritten to match it.
///
/// * `path` - repo-relative directory containing `target_node`.
/// * `results` - receives the files to write (`files_to_restore`) and the repo-relative
///   paths that cannot be overwritten without losing local changes
///   (`cannot_overwrite_entries`).
/// * `partial_nodes` - per-path hash / mtime / size of the files in the source commit,
///   keyed by repo-relative path.
/// * `hashes` - every visited file path is added to `seen_paths`. Subtrees whose hash is
///   in `common_nodes` are skipped when the directory already exists on disk.
/// * `depth` - remaining recursion depth; nothing is visited once it drops below zero.
///
/// File contents are not written here; the caller applies `results`. The one exception
/// is a non-directory entry sitting where the target expects a directory, which is
/// deleted immediately.
///
/// # Errors
/// Returns an error on filesystem or hashing failures, and for node types other than
/// commit, directory and file.
// Recursive worker that threads several accumulators through every call.
#[allow(clippy::too_many_arguments)]
fn r_restore_missing_or_modified_files(
    repo: &LocalRepository,
    target_node: &MerkleTreeNode,
    path: &Path,
    results: &mut CheckoutResult,
    progress: &mut CheckoutProgressBar,
    partial_nodes: &mut HashMap<PathBuf, PartialNode>,
    hashes: &mut CheckoutHashes,
    depth: i32,
) -> Result<(), OxenError> {
    if depth < 0 {
        return Ok(());
    }
    match &target_node.node {
        EMerkleTreeNode::File(file_node) => {
            let file_path = path.join(file_node.name());
            let full_path = repo.path.join(&file_path);
            // Mark the path as present in the target so the cleanup pass keeps it.
            hashes.seen_paths.insert(file_path.clone());
            if !full_path.exists() {
                // The file is absent on disk. If it existed in the source commit, the
                // user deleted it without committing.
                if let Some(from_node) = partial_nodes.get(&file_path) {
                    if from_node.hash == target_node.hash {
                        log::debug!("Preserving uncommitted deletion of file: {file_path:?}");
                        return Ok(());
                    } else {
                        log::debug!(
                            "Conflict: uncommitted deletion of modified file: {file_path:?}"
                        );
                        results.cannot_overwrite_entries.push(file_path.clone());
                        return Ok(());
                    }
                }
                log::debug!("Restoring missing file: {file_path:?}");
                results.files_to_restore.push(FileToRestore {
                    file_node: file_node.clone(),
                    path: file_path.clone(),
                });
                progress.increment_restored();
            } else {
                // Cheap check first: matching mtime and size with the target means the
                // working copy is already up to date.
                let meta = util::fs::metadata(&full_path)?;
                let last_modified = Some(FileTime::from_last_modification_time(&meta));
                let size = Some(meta.len());
                let target_last_modified = util::fs::last_modified_time(
                    file_node.last_modified_seconds(),
                    file_node.last_modified_nanoseconds(),
                );
                let target_size = file_node.num_bytes();
                if last_modified == Some(target_last_modified) && size == Some(target_size) {
                    return Ok(());
                }
                let (from_node, from_last_modified, from_size) =
                    if let Some(from_node) = partial_nodes.get(&file_path) {
                        (
                            Some(from_node),
                            Some(from_node.last_modified),
                            Some(from_node.size),
                        )
                    } else {
                        (None, None, None)
                    };
                // Matches the source commit's metadata: untouched by the user, so it is
                // safe to overwrite with the target version.
                if last_modified == from_last_modified && size == from_size {
                    results.files_to_restore.push(FileToRestore {
                        file_node: file_node.clone(),
                        path: file_path.clone(),
                    });
                    progress.increment_modified();
                    return Ok(());
                }
                // Metadata is inconclusive; fall back to hashing the contents.
                let working_hash = Some(util::hasher::get_hash_given_metadata(&full_path, &meta)?);
                let target_hash = target_node.hash.to_u128();
                if working_hash == Some(target_hash) {
                    return Ok(());
                }
                let from_hash = from_node.map(|from_node| from_node.hash.to_u128());
                if working_hash == from_hash {
                    results.files_to_restore.push(FileToRestore {
                        file_node: file_node.clone(),
                        path: file_path.clone(),
                    });
                    progress.increment_modified();
                    return Ok(());
                }
                // Content differs from both source and target: local edits would be lost.
                results.cannot_overwrite_entries.push(file_path.clone());
                progress.increment_modified();
            }
        }
        EMerkleTreeNode::Directory(dir_node) => {
            let dir_path = path.join(dir_node.name());
            let full_dir_path = repo.path.join(&dir_path);
            // A non-directory entry occupies the path where the target has a directory;
            // delete it so the directory can be created.
            // NOTE(review): this deletes immediately, without checking the entry for
            // local changes - confirm this is intended.
            if full_dir_path.exists() && !full_dir_path.is_dir() {
                util::fs::remove_file(&full_dir_path)?;
            }
            // Identical in both commits and already on disk: nothing to restore.
            if hashes.common_nodes.contains(&target_node.hash) && full_dir_path.is_dir() {
                return Ok(());
            };
            // A missing directory must be populated completely, including vnodes that
            // are shared with the source commit.
            let walk_all = !full_dir_path.is_dir();
            // Borrow the children instead of cloning whole subtrees at every level.
            let children: Vec<&MerkleTreeNode> = target_node
                .children
                .iter()
                .filter(|vnode| walk_all || !hashes.common_nodes.contains(&vnode.hash))
                .flat_map(|vnode| vnode.children.iter())
                .collect();
            for child_node in children {
                r_restore_missing_or_modified_files(
                    repo,
                    child_node,
                    &dir_path,
                    results,
                    progress,
                    partial_nodes,
                    hashes,
                    depth - 1,
                )?;
            }
        }
        EMerkleTreeNode::Commit(_) => {
            let root_dir = repositories::tree::get_root_dir(target_node)?;
            r_restore_missing_or_modified_files(
                repo,
                root_dir,
                path,
                results,
                progress,
                partial_nodes,
                hashes,
                depth - 1,
            )?;
        }
        _ => {
            return Err(OxenError::basic_str(
                "Got an unexpected node type during checkout",
            ));
        }
    }
    Ok(())
}