oxen-cli 0.50.6

Oxen is a fast version control tool for unstructured data, written in Rust, that helps version large machine learning datasets.
use async_trait::async_trait;
use clap::{Arg, Command};
use colored::Colorize;

use liboxen::api;
use liboxen::error::OxenError;
use liboxen::model::LocalRepository;
use liboxen::repositories;

use crate::cli_error::UnknownSubcommand;
use crate::cmd::RunCmd;
use crate::helpers::{
    check_remote_version, check_remote_version_blocking, get_scheme_and_host_from_repo,
};

/// Name of the `branch` subcommand; used as the clap command name and returned by `name()`.
pub const NAME: &str = "branch";

/// CLI handler for `oxen branch`: lists, creates, force-updates, deletes, and renames branches.
pub struct BranchCmd;

#[async_trait]
impl RunCmd for BranchCmd {
    fn name(&self) -> &str {
        NAME
    }

    /// Builds the clap definition for `oxen branch`.
    ///
    /// The positional `name` (and `commit_id`, which requires `--force`) create or
    /// force-update a branch. The remaining options select listing (`--all`, `--remote`),
    /// deletion (`--delete`, `--force-delete`), renaming (`--move`), or printing the
    /// current branch (`--show-current`).
    fn args(&self) -> Command {
        // Sets up the CLI args for the branch command
        Command::new(NAME)
            .about("Manage branches in repository")
            .arg(Arg::new("name").help("Name of the branch"))
            .arg(
                Arg::new("commit_id")
                    .help("Commit ID to point the branch to (used with --force)")
                    .requires("force"),
            )
            .arg(
                Arg::new("force")
                    .long("force")
                    .short('f')
                    .help("Force update an existing branch to point to a specific commit")
                    .requires("name")
                    .action(clap::ArgAction::SetTrue),
            )
            .arg(
                Arg::new("all")
                    .long("all")
                    .short('a')
                    .help("List both local and remote branches")
                    .exclusive(true)
                    .action(clap::ArgAction::SetTrue),
            )
            .arg(
                // Takes the remote name as its value; combined with --delete it deletes
                // a branch on that remote instead of listing.
                Arg::new("remote")
                    .long("remote")
                    .short('r')
                    .help("List all the remote branches")
                    .action(clap::ArgAction::Set),
            )
            .arg(
                Arg::new("force-delete")
                    .long("force-delete")
                    .short('D')
                    .help("Force remove the local branch")
                    .action(clap::ArgAction::Set),
            )
            .arg(
                Arg::new("delete")
                    .long("delete")
                    .short('d')
                    .help("Remove the local branch if it is safe to")
                    .action(clap::ArgAction::Set),
            )
            .arg(
                Arg::new("move")
                    .long("move")
                    .short('m')
                    .help("Rename the current local branch.")
                    .action(clap::ArgAction::Set),
            )
            .arg(
                Arg::new("show-current")
                    .long("show-current")
                    .help("Print the current branch")
                    .exclusive(true)
                    .action(clap::ArgAction::SetTrue),
            )
    }

    /// Dispatches `oxen branch` to the matching [`BranchCmd`] helper.
    ///
    /// Options are checked in a fixed order and the first match wins:
    /// `--all`, `--remote` (deleting a remote branch if `--delete` is also given, otherwise
    /// listing), positional `name` (force-update with `--force`, otherwise create),
    /// `--delete`, `--force-delete`, `--move`, `--show-current`, and finally a plain
    /// listing of local branches. In remote mode, that listing includes each branch's
    /// head commit id.
    ///
    /// # Errors
    /// Returns an error if the repository cannot be loaded from the current directory,
    /// if an unexpected subcommand is present, if `--force` is given without a commit ID,
    /// or if the selected branch operation fails.
    async fn run(&self, args: &clap::ArgMatches) -> Result<(), anyhow::Error> {
        // Find the repository
        let repo = LocalRepository::from_current_dir()?;

        // Parse Args
        if let Some((cmd, _)) = args.subcommand() {
            return Err(UnknownSubcommand {
                parent: "branch",
                name: cmd.to_string(),
            })?;
        } else if args.get_flag("all") {
            self.list_all_branches(&repo).await?;
        } else if let Some(remote_name) = args.get_one::<String>("remote") {
            // `-r <remote> -d <branch>` deletes on the remote; `-r <remote>` alone lists.
            if let Some(branch_name) = args.get_one::<String>("delete") {
                self.delete_remote_branch(&repo, remote_name, branch_name)
                    .await?;
            } else {
                self.list_remote_branches(&repo, remote_name).await?;
            }
        } else if let Some(name) = args.get_one::<String>("name") {
            if args.get_flag("force") {
                let commit_id = args
                    .get_one::<String>("commit_id")
                    .ok_or_else(|| anyhow::anyhow!("Must supply a commit ID with --force"))?;
                self.force_update_branch(&repo, name, commit_id)?;
            } else {
                self.create_branch(&repo, name)?;
            }
        } else if let Some(name) = args.get_one::<String>("delete") {
            self.delete_branch(&repo, name)?;
        } else if let Some(name) = args.get_one::<String>("force-delete") {
            self.force_delete_branch(&repo, name)?;
        } else if let Some(name) = args.get_one::<String>("move") {
            self.rename_current_branch(&repo, name)?;
        } else if args.get_flag("show-current") {
            self.show_current_branch(&repo)?;
        } else {
            // If in remote mode, include the head commit id for each branch
            if repo.is_remote_mode() {
                self.list_branches_with_commits(&repo)?;
            } else {
                self.list_branches(&repo)?;
            }
        }
        Ok(())
    }
}

impl BranchCmd {
    pub async fn list_all_branches(&self, repo: &LocalRepository) -> Result<(), OxenError> {
        self.list_branches(repo)?;

        for remote in repo.remotes().iter() {
            self.list_remote_branches(repo, &remote.name).await?;
        }

        Ok(())
    }

    pub fn list_branches(&self, repo: &LocalRepository) -> Result<(), OxenError> {
        let branches = repositories::branches::list(repo)?;
        let current_branch = repositories::branches::current_branch(repo)?;

        for branch in branches.iter() {
            if current_branch.is_some() && current_branch.as_ref().unwrap().name == branch.name {
                let branch_str = format!("* {}", branch.name).green();
                println!("{branch_str}")
            } else {
                println!("  {}", branch.name)
            }
        }

        Ok(())
    }

    pub fn list_branches_with_commits(&self, repo: &LocalRepository) -> Result<(), OxenError> {
        let branches = repositories::branches::list_with_commits(repo)?;
        let current_branch = repositories::branches::current_branch(repo)?;

        for (branch, commit) in branches.iter() {
            if current_branch.is_some() && current_branch.as_ref().unwrap().name == branch.name {
                let combined_str = format!("* {}: {}", branch.name, commit.id).green();
                println!("{combined_str}")
            } else {
                println!("  {}: {}", branch.name, commit.id);
            }
        }

        Ok(())
    }

    pub fn show_current_branch(&self, repo: &LocalRepository) -> Result<(), OxenError> {
        if let Some(current_branch) = repositories::branches::current_branch(repo)? {
            println!("{}", current_branch.name);
        }
        Ok(())
    }

    pub fn create_branch(&self, repo: &LocalRepository, name: &str) -> Result<(), OxenError> {
        repositories::branches::create_from_head(repo, name)?;
        Ok(())
    }

    pub fn force_update_branch(
        &self,
        repo: &LocalRepository,
        name: &str,
        commit_id: &str,
    ) -> Result<(), OxenError> {
        log::info!("Force updating branch '{name}' to {commit_id}");
        let branch = repositories::branches::update(repo, name, commit_id)?;
        println!("Updated branch '{}' to {}", branch.name, branch.commit_id);
        Ok(())
    }

    pub fn delete_branch(&self, repo: &LocalRepository, name: &str) -> Result<(), OxenError> {
        repositories::branches::delete(repo, name)?;
        Ok(())
    }

    pub fn force_delete_branch(&self, repo: &LocalRepository, name: &str) -> Result<(), OxenError> {
        repositories::branches::force_delete(repo, name)?;
        Ok(())
    }

    pub fn rename_current_branch(
        &self,
        repo: &LocalRepository,
        name: &str,
    ) -> Result<(), OxenError> {
        repositories::branches::rename_current_branch(repo, name)?;
        Ok(())
    }

    pub async fn list_remote_branches(
        &self,
        repo: &LocalRepository,
        remote_name: &str,
    ) -> Result<(), OxenError> {
        let (scheme, host) = get_scheme_and_host_from_repo(repo)?;

        check_remote_version_blocking(scheme.clone(), host.clone()).await?;
        check_remote_version(scheme, host).await?;

        let remote = repo
            .get_remote(remote_name)
            .ok_or_else(|| OxenError::RemoteNotSet(remote_name.to_string()))?;
        let remote_repo = api::client::repositories::get_by_remote(&remote).await?;

        let branches = api::client::branches::list(&remote_repo).await?;
        for branch in branches.iter() {
            println!("{}\t{}", &remote.name, branch.name);
        }
        Ok(())
    }

    pub async fn delete_remote_branch(
        &self,
        repo: &LocalRepository,
        remote_name: &str,
        branch_name: &str,
    ) -> Result<(), OxenError> {
        let (scheme, host) = get_scheme_and_host_from_repo(repo)?;

        check_remote_version(scheme, host).await?;

        api::client::branches::delete_remote(repo, remote_name, branch_name).await?;
        Ok(())
    }
}