use futures::prelude::*;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::Mutex;
use tokio::time::Duration;
use crate::constants::AVG_CHUNK_SIZE;
use crate::constants::DEFAULT_REMOTE_NAME;
use crate::core::progress::push_progress::PushProgress;
use crate::error::OxenError;
use crate::model::merkle_tree::node::MerkleTreeNode;
use crate::model::{
Branch, Commit, CommitEntry, LocalRepository, MerkleHash, MerkleTreeNodeType, RemoteRepository,
};
use crate::opts::PushOpts;
use crate::util::concurrency;
use crate::{api, repositories};
/// Push the currently checked-out branch to the default remote.
///
/// Builds `PushOpts` targeting `DEFAULT_REMOTE_NAME` and the current branch name
/// (all other options default) and delegates to `push_remote_branch`.
///
/// # Errors
/// Returns `OxenError::must_be_on_valid_branch()` when there is no current branch, and
/// propagates any error from looking up the branch or from the push itself.
pub async fn push(repo: &LocalRepository) -> Result<Branch, OxenError> {
    let current_branch = match repositories::branches::current_branch(repo)? {
        Some(branch) => branch,
        None => {
            log::debug!("Push, no current branch found");
            return Err(OxenError::must_be_on_valid_branch());
        }
    };

    let push_opts = PushOpts {
        branch: current_branch.name,
        remote: DEFAULT_REMOTE_NAME.to_string(),
        ..Default::default()
    };
    push_remote_branch(repo, &push_opts).await
}
/// Push the local branch `opts.branch` to the remote named `opts.remote`.
///
/// Prints a header line before pushing and the elapsed time afterwards. Returns the
/// local branch that was pushed.
///
/// # Errors
/// Fails if the local branch cannot be found, if the remote is not configured
/// (`OxenError::RemoteNotSet`), if the remote repository cannot be fetched, or if the
/// push itself fails.
pub async fn push_remote_branch(
    repo: &LocalRepository,
    opts: &PushOpts,
) -> Result<Branch, OxenError> {
    let timer = std::time::Instant::now();

    let local_branch = repositories::branches::get_by_name(repo, &opts.branch)?;
    println!(
        "🐂 oxen push {} {} -> {}",
        opts.remote, local_branch.name, local_branch.commit_id
    );

    let Some(remote) = repo.get_remote(&opts.remote) else {
        return Err(OxenError::RemoteNotSet(opts.remote.to_string()));
    };
    let remote_repo = api::client::repositories::get_by_remote(&remote).await?;
    push_local_branch_to_remote_repo(repo, &remote_repo, &local_branch, opts).await?;

    // Truncate to whole milliseconds so the printed duration stays readable.
    let elapsed_ms = timer.elapsed().as_millis() as u64;
    println!(
        "🐂 push complete 🎉 took {}",
        humantime::format_duration(std::time::Duration::from_millis(elapsed_ms))
    );
    Ok(local_branch)
}
/// Push `local_branch` to `remote_repo`, wrapped in the server's pre-push and
/// post-push calls.
///
/// If the remote already has a branch with the same name, `push_to_existing_branch`
/// handles it; otherwise `push_to_new_branch` creates it.
///
/// # Errors
/// Returns `OxenError::RevisionNotFound` when the branch's commit is not in the local
/// repository, and propagates any API or push error.
async fn push_local_branch_to_remote_repo(
    repo: &LocalRepository,
    remote_repo: &RemoteRepository,
    local_branch: &Branch,
    opts: &PushOpts,
) -> Result<(), OxenError> {
    let commit = repositories::commits::get_by_id(repo, &local_branch.commit_id)?
        .ok_or_else(|| OxenError::RevisionNotFound(local_branch.commit_id.clone().into()))?;

    api::client::repositories::pre_push(remote_repo, local_branch, &commit.id).await?;

    let existing_remote_branch =
        api::client::branches::get_by_name(remote_repo, &local_branch.name).await?;
    if let Some(remote_branch) = existing_remote_branch {
        push_to_existing_branch(repo, &commit, remote_repo, &remote_branch, opts).await?;
    } else {
        push_to_new_branch(repo, remote_repo, local_branch, &commit, opts).await?;
    }

    api::client::repositories::post_push(remote_repo, local_branch, &commit.id).await?;
    Ok(())
}
/// Push `commit` and its history to a branch that does not exist on the remote yet, then
/// create `branch` on the remote pointing at `commit`.
///
/// The head of the remote's default branch (see `find_latest_remote_commit`), when it is
/// known locally, is used as the diff base for deciding what to upload.
async fn push_to_new_branch(
    repo: &LocalRepository,
    remote_repo: &RemoteRepository,
    branch: &Branch,
    commit: &Commit,
    opts: &PushOpts,
) -> Result<(), OxenError> {
    let commits_to_push = repositories::commits::list_from(repo, &commit.id)?;
    let base_commit = find_latest_remote_commit(repo, remote_repo).await?;

    push_commits(repo, remote_repo, base_commit, commits_to_push, opts).await?;

    api::client::branches::create_from_commit(remote_repo, &branch.name, commit).await?;
    Ok(())
}
/// Push `commit` to a branch that already exists on the remote.
///
/// * If the remote branch already points at `commit`, nothing is pushed unless
///   `opts.missing_files` or `opts.revalidate` requests a repair pass.
/// * Fast-forward: if the remote branch's commit appears in `commit`'s local history,
///   only the commits between the two are pushed.
/// * Otherwise (diverged or behind): with `opts.force`, the full local history is pushed
///   against the remote's default-branch head; without it an error tells the user to
///   pull first.
///
/// On success the remote branch is updated to point at `commit`.
///
/// # Errors
/// Propagates local history/lookup errors and API errors. Returns a descriptive error when
/// the local branch is behind the remote and `opts.force` is not set.
async fn push_to_existing_branch(
    repo: &LocalRepository,
    commit: &Commit,
    remote_repo: &RemoteRepository,
    remote_branch: &Branch,
    opts: &PushOpts,
) -> Result<(), OxenError> {
    if remote_branch.commit_id == commit.id && !opts.missing_files && !opts.revalidate {
        println!("Everything is up to date");
        return Ok(());
    }

    let history = repositories::commits::list_from(repo, &commit.id)?;
    let is_ahead = history.iter().any(|c| c.id == remote_branch.commit_id);

    if is_ahead {
        // Fast-forward: the remote head is an ancestor, so only the newer commits are sent.
        let latest_remote_commit =
            repositories::commits::get_by_id(repo, &remote_branch.commit_id)?.ok_or_else(
                || OxenError::RevisionNotFound(remote_branch.commit_id.clone().into()),
            )?;
        let mut commits =
            repositories::commits::list_between(repo, &latest_remote_commit, commit)?;
        commits.reverse();
        push_commits(repo, remote_repo, Some(latest_remote_commit), commits, opts).await?;
    } else if opts.force {
        log::info!(
            "Force pushing branch '{}' to {}",
            &remote_branch.name,
            commit.id
        );
        // The remote head is not in our history, so diff against the remote's default
        // branch instead and send the full local history (already listed above).
        let latest_remote_commit = find_latest_remote_commit(repo, remote_repo).await?;
        push_commits(repo, remote_repo, latest_remote_commit, history, opts).await?;
    } else {
        return Err(OxenError::basic_str(format!(
            "Branch {} is behind remote commit {}.\nRun `oxen pull` to update your local branch, or use `oxen push --force` to force push.",
            remote_branch.name, remote_branch.commit_id
        )));
    }

    api::client::branches::update(remote_repo, &remote_branch.name, commit).await?;
    Ok(())
}
/// Ask the remote to scan its version store for corrupted files, then re-upload what the
/// remote is missing for `commits`.
///
/// If the scan reports no corrupted files, a summary is printed and nothing is uploaded.
/// Otherwise a summary is printed, plus a warning when the clean may be incomplete, and
/// `push_missing_files` is run.
///
/// # Errors
/// Propagates failures from the clean request and from `push_missing_files`.
async fn revalidate_and_push_missing_files(
    repo: &LocalRepository,
    opts: &PushOpts,
    remote_repo: &RemoteRepository,
    latest_remote_commit: &Option<Commit>,
    commits: &[Commit],
) -> Result<(), OxenError> {
    println!("🐂 revalidating the remote repo...");
    let response = api::client::versions::clean(remote_repo).await?;
    let clean_result = response.result;

    if clean_result.corrupted > 0 {
        println!(
            "🔍 scanned {} files, found {} corrupted files, cleaned {} of them. Scanning took {}",
            clean_result.scanned,
            clean_result.corrupted,
            clean_result.cleaned,
            humantime::format_duration(clean_result.elapsed)
        );
        println!("🐂 pushing missing files...");
        // The clean is incomplete if some corrupted files were left in place, or if the
        // server reported any error at all: an errored file may not have been checked or
        // removed. Comparing the error count against `cleaned` does not capture that.
        if clean_result.corrupted > clean_result.cleaned || clean_result.errors > 0 {
            println!(
                "🚧 This fix is not complete. Some files may still be corrupted. Please try running this command again."
            );
        }
        push_missing_files(repo, opts, remote_repo, latest_remote_commit, commits).await
    } else {
        println!(
            "🔍 scanned {} files, no corrupted files found. Scanning took {}",
            clean_result.scanned,
            humantime::format_duration(clean_result.elapsed)
        );
        Ok(())
    }
}
/// Upload files the remote is missing, choosing which commits to check as follows:
///
/// * `opts.missing_files_commit_id` is set: only that commit is checked, with no base.
/// * The head of `commits` (its last element) equals `latest_remote_commit`: every commit
///   in the head's local history is checked individually, with no base.
/// * Otherwise: the head is checked against `latest_remote_commit` as the base.
///
/// # Errors
/// Fails when `commits` is empty, when the requested commit id does not exist locally, or
/// when listing or uploading missing files fails.
async fn push_missing_files(
    repo: &LocalRepository,
    opts: &PushOpts,
    remote_repo: &RemoteRepository,
    latest_remote_commit: &Option<Commit>,
    commits: &[Commit],
) -> Result<(), OxenError> {
    let head_commit = commits.last().ok_or_else(|| {
        OxenError::basic_str("Cannot push missing files without a head commit")
    })?;

    match &opts.missing_files_commit_id {
        Some(commit_id) => {
            let target = repositories::commits::get_by_id(repo, commit_id)?
                .ok_or_else(|| OxenError::commit_id_does_not_exist(commit_id))?;
            list_and_push_missing_files(repo, remote_repo, None, &target).await?;
        }
        None => {
            let head_is_remote_head =
                matches!(latest_remote_commit, Some(remote_head) if remote_head.id == head_commit.id);
            if head_is_remote_head {
                // Nothing new to diff against, so check each historical commit on its own.
                for past_commit in repositories::commits::list_from(repo, &head_commit.id)? {
                    list_and_push_missing_files(repo, remote_repo, None, &past_commit).await?;
                }
            } else {
                list_and_push_missing_files(
                    repo,
                    remote_repo,
                    latest_remote_commit.clone(),
                    head_commit,
                )
                .await?;
            }
        }
    }
    Ok(())
}
/// Ask the remote which file entries it is missing for `head_commit` (relative to
/// `base_commit`, when given) and upload them with a fresh progress tracker.
///
/// Only the first missing entry is probed in the local version store. If it is absent,
/// the local repo is treated as a shallow clone and `OxenError::CannotPushShallowClone`
/// is returned before anything is uploaded.
async fn list_and_push_missing_files(
    repo: &LocalRepository,
    remote_repo: &RemoteRepository,
    base_commit: Option<Commit>,
    head_commit: &Commit,
) -> Result<(), OxenError> {
    let missing_files =
        api::client::commits::list_missing_files(remote_repo, base_commit, &head_commit.id).await?;

    if let Some(first_entry) = missing_files.first() {
        let store = repo.version_store()?;
        let have_locally = store.version_exists(&first_entry.hash).await?;
        if !have_locally {
            return Err(OxenError::CannotPushShallowClone {
                commit_id: head_commit.id.clone(),
                commit_message: head_commit.message.clone(),
                help: "To repair the remote, re-run this command from a clone that has the full history.".to_string(),
            });
        }
    }

    let byte_count = missing_files.iter().map(|entry| entry.num_bytes).sum();
    let file_count = missing_files.len() as u64;
    let tracker = Arc::new(PushProgress::new_with_totals(file_count, byte_count));
    push_entries(repo, remote_repo, &missing_files, head_commit, &tracker).await?;
    Ok(())
}
/// Compute, for each commit in `commits`, the tree nodes and file entries that must be
/// uploaded for it. The result is keyed by the commit's hash.
///
/// `base_hashes` holds hashes treated as already present on the remote. It is seeded from
/// `latest_remote_commit` (when given) with the node hashes collected while loading its
/// subtree plus every `File`/`FileChunk` node hash. It then grows after each commit is
/// processed, so a file hash is attributed to at most one commit in the result.
///
/// Commits are visited in reverse slice order (`commits.iter().rev()`).
/// NOTE(review): whether that is oldest-first or newest-first depends on the caller's
/// ordering (`list_from` vs. reversed `list_between`); confirm.
///
/// Only `repo.subtree_paths()` are walked when set; otherwise the repository root is used.
/// The walk depth is `repo.depth()`, defaulting to `-1` (presumably "unlimited"; confirm in
/// `get_subtree_by_depth_with_unique_children`).
///
/// # Errors
/// Propagates errors from tree loading, ancestor-node lookup and commit hashing.
async fn get_commit_missing_hashes(
    repo: &LocalRepository,
    latest_remote_commit: Option<Commit>,
    commits: &[Commit],
) -> Result<HashMap<MerkleHash, PushCommitInfo>, OxenError> {
    // Hashes assumed to already exist on the remote; nodes/files in here are not re-sent.
    let mut base_hashes = HashSet::new();
    // Subtree clones only walk their configured paths; otherwise walk from the root.
    let paths = &repo.subtree_paths().unwrap_or(vec![PathBuf::new()]);
    if let Some(ref commit) = latest_remote_commit {
        for path in paths {
            // Out-parameter filled by the tree loader; presumably the unique node hashes it
            // visited (confirm in `get_subtree_by_depth_with_unique_children`).
            let mut starting_node_hashes = HashSet::new();
            let Some(root) = repositories::tree::get_subtree_by_depth_with_unique_children(
                repo,
                commit,
                path,
                None,
                Some(&mut starting_node_hashes),
                None,
                repo.depth().unwrap_or(-1),
            )?
            else {
                log::warn!("Could not get remote tree for path {path:?}");
                continue;
            };
            base_hashes.extend(starting_node_hashes);
            // File contents reachable from the remote head count as already uploaded.
            root.walk_tree(|node: &MerkleTreeNode| {
                let t = node.node.node_type();
                if t == MerkleTreeNodeType::File || t == MerkleTreeNodeType::FileChunk {
                    base_hashes.insert(*node.node.hash());
                }
            });
        }
    }
    log::debug!("starting hashes: {:?}", base_hashes.len());
    let mut result = HashMap::new();
    for commit in commits.iter().rev() {
        // Node hashes the loader reports as new for this commit (relative to base_hashes).
        let mut unique_hashes = HashSet::new();
        // Dedupes file hashes across the (possibly several) subtree paths of this commit.
        let mut file_hashes_seen = HashSet::new();
        let mut files: Vec<CommitEntry> = Vec::new();
        let mut dir_nodes: HashSet<MerkleHash> = HashSet::new();
        for path in paths {
            log::debug!("push_commits adding candidate nodes for commit: {commit}");
            let Some(root) = repositories::tree::get_subtree_by_depth_with_unique_children(
                repo,
                commit,
                path,
                Some(&base_hashes),
                Some(&mut unique_hashes),
                None,
                repo.depth().unwrap_or(-1),
            )?
            else {
                log::error!("push_commits commit node not found for commit: {commit}");
                continue;
            };
            root.walk_tree(|node: &MerkleTreeNode| {
                let t = node.node.node_type();
                if t == MerkleTreeNodeType::File || t == MerkleTreeNodeType::FileChunk {
                    let file_hash = *node.node.hash();
                    // Only files not already on the remote (or claimed by an earlier-visited
                    // commit), and each at most once per commit.
                    if !base_hashes.contains(&file_hash) && file_hashes_seen.insert(file_hash) {
                        files.push(CommitEntry::from_node(&node.node));
                    }
                } else if !node.node.is_leaf() {
                    // Non-file interior nodes (presumably dirs / vnodes) must be created remotely.
                    let hash = node.node.hash();
                    dir_nodes.insert(*hash);
                }
            });
        }
        // Later-visited commits must not re-send what this commit already claims.
        base_hashes.extend(unique_hashes);
        base_hashes.extend(file_hashes_seen);
        // Adds the ancestor nodes of `paths` for this commit into `dir_nodes`
        // (presumably so partial/subtree trees stay connected on the remote; confirm).
        repositories::tree::get_ancestor_nodes(repo, commit, paths, &mut dir_nodes)?;
        dir_nodes.insert(commit.hash()?);
        log::debug!("push_commits dir nodes: {dir_nodes:?}");
        let total_bytes = files.iter().map(|e| e.num_bytes).sum();
        let push_commit_info = PushCommitInfo {
            unique_dir_nodes: dir_nodes,
            unique_file_hashes: files,
            total_bytes,
        };
        result.insert(commit.hash()?, push_commit_info);
    }
    Ok(result)
}
/// What must be uploaded for a single commit, as computed by `get_commit_missing_hashes`.
#[derive(Debug, Clone)]
struct PushCommitInfo {
    /// Hashes of non-leaf, non-file tree nodes for this commit, including the commit node
    /// itself and the ancestor nodes of the walked paths.
    unique_dir_nodes: HashSet<MerkleHash>,
    /// File / file-chunk entries not covered by the remote base or by a commit processed
    /// earlier in the same batch.
    unique_file_hashes: Vec<CommitEntry>,
    /// Sum of `num_bytes` over `unique_file_hashes`.
    total_bytes: u64,
}
/// Upload everything the remote is missing for `commits`, then register each commit.
///
/// With `opts.revalidate` this delegates to `revalidate_and_push_missing_files`. With
/// `opts.missing_files` it delegates to `push_missing_files`. Otherwise it:
/// 1. computes per-commit missing nodes and files relative to `latest_remote_commit`
///    (`get_commit_missing_hashes`);
/// 2. asks the remote which of `commits` it is missing;
/// 3. fails fast with `OxenError::CannotPushShallowClone` if, for any missing commit, the
///    first of its file entries is absent from the local version store;
/// 4. pushes the missing commits concurrently. For each one it uploads the file contents,
///    creates the tree nodes, posts the commit's dir hashes, and marks it as synced.
///
/// A failure in one commit's task does not cancel the others. Failures are collected and
/// reported together after every task finishes.
///
/// # Errors
/// Returns errors from the setup steps, a shallow-clone error, or one combined error
/// listing every commit that failed to push.
async fn push_commits(
    repo: &LocalRepository,
    remote_repo: &RemoteRepository,
    latest_remote_commit: Option<Commit>,
    commits: Vec<Commit>,
    opts: &PushOpts,
) -> Result<(), OxenError> {
    if opts.revalidate {
        return revalidate_and_push_missing_files(
            repo,
            opts,
            remote_repo,
            &latest_remote_commit,
            &commits,
        )
        .await;
    }
    if opts.missing_files {
        return push_missing_files(repo, opts, remote_repo, &latest_remote_commit, &commits).await;
    }
    let commit_info = get_commit_missing_hashes(repo, latest_remote_commit, &commits).await?;
    log::debug!("got commit info {}", commit_info.len());
    let missing_commits = api::client::commits::list_missing_hashes(remote_repo, commits).await?;
    log::debug!("got missing commits {}", missing_commits.len());
    // Pair each commit the remote lacks with its precomputed upload info. A commit with no
    // entry in `commit_info` is treated as an error.
    let commits_with_info = missing_commits
        .into_iter()
        .map(|commit| {
            let commit_hash = commit.hash()?;
            let info = commit_info.get(&commit_hash).cloned().ok_or_else(|| {
                OxenError::basic_str(format!("Commit info not found for commit {commit_hash}"))
            })?;
            Ok((commit, info))
        })
        .collect::<Result<Vec<(Commit, PushCommitInfo)>, OxenError>>()?;
    let version_store = repo.version_store()?;
    // Cheap shallow-clone detection: only the first file entry of each commit is probed.
    for (commit, info) in &commits_with_info {
        if let Some(commit_entry) = info.unique_file_hashes.first()
            && !version_store.version_exists(&commit_entry.hash).await?
        {
            return Err(OxenError::CannotPushShallowClone {
                commit_id: commit.id.clone(),
                commit_message: commit.message.clone(),
                help: "Run `oxen pull --all` to fetch all data, then try again.".to_string(),
            });
        }
    }
    let total_bytes = commits_with_info
        .iter()
        .map(|(_, info)| info.total_bytes)
        .sum();
    let num_files: usize = commits_with_info
        .iter()
        .map(|(_, info)| info.unique_file_hashes.len())
        .sum();
    log::debug!("got commits with info {commits_with_info:?}");
    let num_commits = commits_with_info.len();
    log::debug!("got commit info {num_commits}");
    // Per-commit failures, gathered so one bad commit does not abort the rest.
    let errors = Arc::new(Mutex::new(Vec::new()));
    let progress = Arc::new(PushProgress::new_with_totals(num_files as u64, total_bytes));
    stream::iter(commits_with_info)
        .for_each_concurrent(
            concurrency::num_threads_for_items(num_commits),
            // Note: `commit_info` here is the per-commit PushCommitInfo, shadowing the map above.
            |(commit, commit_info)| {
                let id = commit.id.clone();
                log::debug!("Pushing commit {commit:?}");
                let progress = progress.clone();
                let errors = errors.clone();
                async move {
                    let result = async {
                        let commit_hash = commit.hash()?;
                        log::debug!("Pushing commit {commit_hash}");
                        log::debug!("missing files {}", commit_info.unique_file_hashes.len());
                        // File contents first, then the tree nodes that reference them.
                        push_entries(
                            repo,
                            remote_repo,
                            &commit_info.unique_file_hashes,
                            &commit,
                            &progress,
                        )
                        .await?;
                        log::debug!("pushed entries missing files");
                        // The commit node is normally already in the set from
                        // get_commit_missing_hashes; inserting again is a no-op.
                        let mut nodes = commit_info.unique_dir_nodes;
                        nodes.insert(commit_hash);
                        api::client::tree::create_nodes(repo, remote_repo, nodes, &progress)
                            .await?;
                        log::debug!("created nodes");
                        api::client::commits::post_commits_dir_hashes_to_server(
                            repo,
                            remote_repo,
                            &vec![commit],
                        )
                        .await?;
                        api::client::commits::mark_commits_as_synced(
                            remote_repo,
                            HashSet::from([commit_hash]),
                        )
                        .await?;
                        log::debug!("marked commits as synced {commit_hash}");
                        Ok::<(), OxenError>(())
                    }
                    .await;
                    if let Err(err) = result {
                        let err_str = format!("Error pushing commit {id:?}: {err}");
                        errors.lock().await.push(OxenError::basic_str(err_str));
                    }
                }
            },
        )
        .await;
    let errors = errors.lock().await;
    if !errors.is_empty() {
        let error_messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
        return Err(OxenError::basic_str(format!(
            "Failed to push {} commit(s):\n{}",
            errors.len(),
            error_messages.join("\n")
        )));
    }
    Ok(())
}
/// Upload the contents of `entries` (belonging to `commit`) to `remote_repo`.
///
/// Entries larger than `AVG_CHUNK_SIZE` bytes are uploaded one by one through the chunked
/// large-file path. The rest, including entries of exactly `AVG_CHUNK_SIZE` bytes, are
/// bundled into multipart batches. Both paths run concurrently and report into `progress`.
///
/// # Errors
/// Returns an error naming each upload path that failed. When both fail, both underlying
/// messages are included instead of being discarded.
pub async fn push_entries(
    local_repo: &LocalRepository,
    remote_repo: &RemoteRepository,
    entries: &[CommitEntry],
    commit: &Commit,
    progress: &Arc<PushProgress>,
) -> Result<(), OxenError> {
    log::debug!(
        "PUSH ENTRIES {} -> {} -> '{}'",
        entries.len(),
        commit.id,
        commit.message
    );

    // Single pass over `entries`, with the same `<=` boundary as the small-bundle path.
    let (smaller_entries, larger_entries): (Vec<CommitEntry>, Vec<CommitEntry>) = entries
        .iter()
        .cloned()
        .partition(|e| e.num_bytes <= AVG_CHUNK_SIZE);

    let large_entries_sync =
        chunk_and_send_large_entries(local_repo, remote_repo, larger_entries, progress);
    let small_entries_sync = bundle_and_send_small_entries(
        local_repo,
        remote_repo,
        smaller_entries,
        commit,
        AVG_CHUNK_SIZE,
        progress,
    );

    match tokio::join!(large_entries_sync, small_entries_sync) {
        (Ok(_), Ok(_)) => {
            log::debug!("Moving on to post-push validation");
            Ok(())
        }
        (Err(err), Ok(_)) => Err(OxenError::basic_str(format!(
            "Error syncing large entries: {err}"
        ))),
        (Ok(_), Err(err)) => Err(OxenError::basic_str(format!(
            "Error syncing small entries: {err}"
        ))),
        (Err(large_err), Err(small_err)) => Err(OxenError::basic_str(format!(
            "Error syncing large entries: {large_err}\nError syncing small entries: {small_err}"
        ))),
    }
}
/// Upload each entry with `parallel_large_file_upload`, using a pool of worker tasks that
/// pull from a shared queue.
///
/// Each worker resolves the entry's local version path and uploads it, reporting into
/// `progress`. As soon as any worker fails, `should_stop` makes the workers stop taking new
/// entries. The earliest recorded failure is the one returned.
///
/// # Errors
/// Returns the earliest worker error, or an error if a worker task panicked.
async fn chunk_and_send_large_entries(
    local_repo: &LocalRepository,
    remote_repo: &RemoteRepository,
    entries: Vec<CommitEntry>,
    progress: &Arc<PushProgress>,
) -> Result<(), OxenError> {
    if entries.is_empty() {
        return Ok(());
    }
    use tokio::time::sleep;
    type PieceOfWork = (CommitEntry, RemoteRepository);
    type TaskQueue = deadqueue::limited::Queue<PieceOfWork>;

    let num_entries = entries.len();
    log::debug!("Chunking and sending {num_entries} larger files");
    let queue = Arc::new(TaskQueue::new(num_entries));
    // Move entries straight into the queue. Its capacity equals the entry count, so
    // `try_push` cannot fail here.
    for entry in entries {
        queue.try_push((entry, remote_repo.to_owned())).unwrap();
    }

    let version_store = local_repo.version_store()?;
    let worker_count = concurrency::num_threads_for_items(num_entries);
    log::debug!("worker_count {worker_count} entries len {num_entries}");
    let should_stop = Arc::new(AtomicBool::new(false));
    let first_error = Arc::new(Mutex::new(None::<String>));
    let mut handles = vec![];
    for worker in 0..worker_count {
        let queue = queue.clone();
        let bar = Arc::clone(progress);
        let should_stop = should_stop.clone();
        let first_error = first_error.clone();
        let version_store = Arc::clone(&version_store);
        let handle = tokio::spawn(async move {
            while !should_stop.load(Ordering::Relaxed) {
                let Some((commit_entry, remote_repo)) = queue.try_pop() else {
                    break;
                };
                let version_path = match version_store.get_version_path(&commit_entry.hash).await {
                    Ok(path) => path,
                    Err(e) => {
                        log::error!("Failed to get version path: {e}");
                        should_stop.store(true, Ordering::Relaxed);
                        // Keep the earliest failure; later ones are usually fallout from it.
                        first_error.lock().await.get_or_insert_with(|| e.to_string());
                        break;
                    }
                };
                match api::client::versions::parallel_large_file_upload(
                    &remote_repo,
                    &*version_path,
                    None::<PathBuf>,
                    None,
                    false,
                    Some(commit_entry.clone()),
                    Some(&bar),
                )
                .await
                {
                    Ok(_) => {
                        log::debug!(
                            "worker[{}] successfully uploaded {:?}",
                            worker,
                            commit_entry.path
                        );
                    }
                    Err(err) => {
                        log::error!(
                            "worker[{}] failed to upload {:?}: {}",
                            worker,
                            commit_entry.path,
                            err
                        );
                        should_stop.store(true, Ordering::Relaxed);
                        first_error.lock().await.get_or_insert_with(|| err.to_string());
                        break;
                    }
                }
            }
        });
        handles.push(handle);
    }

    for res in futures::future::join_all(handles).await {
        if let Err(e) = res {
            return Err(OxenError::basic_str(format!("worker task panicked: {e}")));
        }
    }
    if let Some(err) = first_error.lock().await.clone() {
        return Err(OxenError::basic_str(err));
    }
    log::debug!("All large file tasks done. :-)");
    // NOTE(review): kept from the original; the reason for this short pause is not visible
    // here (possibly to let progress output settle).
    sleep(Duration::from_millis(100)).await;
    Ok(())
}
async fn bundle_and_send_small_entries(
local_repo: &LocalRepository,
remote_repo: &RemoteRepository,
entries: Vec<CommitEntry>,
commit: &Commit,
avg_chunk_size: u64,
progress: &Arc<PushProgress>,
) -> Result<(), OxenError> {
if entries.is_empty() {
return Ok(());
}
let total_size = repositories::entries::compute_entries_size(&entries)?;
let num_chunks = ((total_size / avg_chunk_size) + 1) as usize;
let mut chunk_size = entries.len() / num_chunks;
if num_chunks > entries.len() {
chunk_size = entries.len();
}
let client = Arc::new(api::client::builder_for_remote_repo(remote_repo)?.build()?);
use tokio::time::sleep;
type PieceOfWork = (
Vec<CommitEntry>,
LocalRepository,
Commit,
RemoteRepository,
Arc<reqwest::Client>,
);
type TaskQueue = deadqueue::limited::Queue<PieceOfWork>;
type FinishedTaskQueue = deadqueue::limited::Queue<bool>;
log::debug!("Creating {num_chunks} chunks from {total_size} bytes with size {chunk_size}");
let chunks: Vec<PieceOfWork> = entries
.chunks(chunk_size)
.map(|c| {
(
c.to_owned(),
local_repo.to_owned(),
commit.to_owned(),
remote_repo.to_owned(),
client.clone(),
)
})
.collect();
let worker_count = concurrency::num_threads_for_items(chunks.len());
let queue = Arc::new(TaskQueue::new(chunks.len()));
let finished_queue = Arc::new(FinishedTaskQueue::new(chunks.len()));
for chunk in chunks {
queue.try_push(chunk).unwrap();
finished_queue.try_push(false).unwrap();
}
use std::sync::atomic::{AtomicBool, Ordering};
let should_stop = Arc::new(AtomicBool::new(false));
let first_error = Arc::new(Mutex::new(None::<String>));
let mut handles = vec![];
for worker in 0..worker_count {
let queue = queue.clone();
let finished_queue = finished_queue.clone();
let bar = Arc::clone(progress);
let should_stop = should_stop.clone();
let first_error = first_error.clone();
let handle = tokio::spawn(async move {
loop {
log::debug!("worker[{worker}] processing task");
if should_stop.load(Ordering::Relaxed) {
break;
}
let Some((chunk, repo, _commit, remote_repo, client)) = queue.try_pop() else {
break;
};
let chunk_size = match repositories::entries::compute_entries_size(&chunk) {
Ok(size) => size,
Err(e) => {
log::error!("Failed to compute entries size: {e}");
should_stop.store(true, Ordering::Relaxed);
*first_error.lock().await = Some(e.to_string());
finished_queue.pop().await;
break;
}
};
match api::client::versions::multipart_batch_upload_with_retry(
&repo,
&remote_repo,
&chunk,
&client,
)
.await
{
Ok(_) => {
bar.add_bytes(chunk_size);
bar.add_files(chunk.len() as u64);
finished_queue.pop().await;
}
Err(e) => {
should_stop.store(true, Ordering::Relaxed);
*first_error.lock().await = Some(e.to_string());
finished_queue.pop().await;
break;
}
}
}
});
handles.push(handle);
}
let join_results = futures::future::join_all(handles).await;
for res in join_results {
if let Err(e) = res {
return Err(OxenError::basic_str(format!("worker task panicked: {e}")));
}
}
if let Some(err) = first_error.lock().await.clone() {
return Err(OxenError::basic_str(err));
}
sleep(Duration::from_millis(100)).await;
Ok(())
}
/// Return the head commit of the remote's default branch, if it exists locally.
///
/// The branch named `DEFAULT_BRANCH_NAME` is preferred, falling back to the first branch
/// the remote lists. Returns `Ok(None)` when the remote has no branches or when the chosen
/// branch's commit is not present in the local repository.
async fn find_latest_remote_commit(
    repo: &LocalRepository,
    remote_repo: &RemoteRepository,
) -> Result<Option<Commit>, OxenError> {
    let branches = api::client::branches::list(remote_repo).await?;

    let preferred = branches
        .iter()
        .find(|branch| branch.name == crate::constants::DEFAULT_BRANCH_NAME)
        .or_else(|| branches.first());
    let Some(branch) = preferred else {
        return Ok(None);
    };

    let local_commit = repositories::commits::get_by_id(repo, &branch.commit_id)?;
    Ok(local_commit)
}