use std::{collections::HashSet, io::Read};
use color_eyre::eyre::Result;
use git2::{Branch, DiffFormat, DiffOptions, Oid, Repository, Revwalk};
use log::trace;
use ring::digest::{Context, Digest, SHA256};
use crate::tasks::git::{branch::get_branch_name, errors::GitError as E};
/// Returns `true` if `head` has at least one commit (since its merge base with
/// `upstream`) whose changes do not appear among `upstream`'s commits since that
/// same merge base.
///
/// Commits are compared by the digest produced by `patch_id` rather than by OID,
/// so a commit re-applied to `upstream` with an identical diff counts as merged.
///
/// # Errors
///
/// Returns an error if either branch name cannot be read, if either branch has no
/// target OID ([`E::NoOidFound`]), if computing the merge base fails, or if walking
/// or diffing any commit fails.
pub(super) fn unmerged_commits(
    repo: &Repository,
    upstream: &Branch,
    head: &Branch,
) -> Result<bool> {
    let head_name = get_branch_name(head)?;
    let upstream_name = get_branch_name(upstream)?;
    let head_oid = head.get().target().ok_or(E::NoOidFound {
        branch_name: head_name,
    })?;
    let upstream_oid = upstream.get().target().ok_or(E::NoOidFound {
        branch_name: upstream_name,
    })?;

    // The merge base depends only on the two tips, so compute it once and use it
    // to bound both walks.
    let merge_base = repo.merge_base(head_oid, upstream_oid)?;

    // Digests of every upstream commit that is not reachable from the merge base.
    let mut upstream_patch_ids = HashSet::new();
    for id in rev_list(repo, upstream_oid, merge_base)? {
        upstream_patch_ids.insert(patch_id(repo, id?)?.as_ref().to_owned());
    }
    trace!("Upstream patch IDs: {upstream_patch_ids:?}");

    let head_ids: Vec<Oid> = rev_list(repo, head_oid, merge_base)?.collect::<Result<_, _>>()?;
    trace!("Found head IDs: {head_ids:?}");
    for id in head_ids {
        let head_patch_id = patch_id(repo, id)?;
        trace!("Head patch ID for '{id:?}': '{head_patch_id:?}'");
        // Any head commit whose diff is absent upstream means head has unmerged work.
        if !upstream_patch_ids.contains(head_patch_id.as_ref()) {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Computes a content-based digest of the changes introduced by commit `id`.
///
/// The commit's tree is diffed against its first parent's tree. A root commit has
/// no parent, so it is diffed against the empty tree. The diff is rendered with
/// [`DiffFormat::PatchId`], and for every rendered line the delta flags (big-endian),
/// the hunk header (when present) and the line content are appended to a buffer.
/// The SHA-256 digest of that buffer is returned.
///
/// NOTE(review): the hunk header is hashed, and it presumably carries line numbers
/// (`@@ -a,b +c,d @@`). The same change applied at shifted line positions would then
/// hash differently (unlike `git patch-id`). Confirm that this is intended.
///
/// # Errors
///
/// Returns [`E::NoCommitFound`] if `id` cannot be resolved to a commit, or an error
/// if reading a tree, producing the diff, or printing it fails.
fn patch_id(repo: &Repository, id: Oid) -> Result<Digest> {
    let commit = repo.find_commit(id).map_err(|e| E::NoCommitFound {
        oid: id.to_string(),
        source: e,
    })?;
    // A root commit (e.g. one reached through a merged unrelated history) has no
    // parent. Diff it against the empty tree instead of failing the whole check.
    let parent_tree = if commit.parent_count() == 0 {
        None
    } else {
        Some(commit.parent(0)?.tree()?)
    };
    let commit_tree = commit.tree()?;
    let mut diff_opts = DiffOptions::new();
    let diff = repo.diff_tree_to_tree(
        parent_tree.as_ref(),
        Some(&commit_tree),
        Some(&mut diff_opts),
    )?;
    let mut trimmed_diff: Vec<u8> = Vec::new();
    diff.print(DiffFormat::PatchId, |delta, hunk_opt, line| -> bool {
        trimmed_diff.extend(&u32_to_u8_array(delta.flags().bits()));
        if let Some(hunk) = hunk_opt {
            trimmed_diff.extend(hunk.header());
        }
        trimmed_diff.extend(line.content());
        // Returning `true` tells libgit2 to keep printing.
        true
    })?;
    sha256_digest(&trimmed_diff[..])
}
/// Encodes `x` as four big-endian bytes (most significant byte first).
///
/// Used to feed a `u32` into a byte-oriented hash in a platform-independent order.
const fn u32_to_u8_array(x: u32) -> [u8; 4] {
    x.to_be_bytes()
}
/// Returns the SHA-256 digest of everything readable from `reader`.
///
/// The input is consumed in 1024-byte chunks until `read` returns `Ok(0)`.
///
/// # Errors
///
/// Returns any I/O error produced by `reader`. The exception is
/// [`std::io::ErrorKind::Interrupted`], which the [`Read::read`] contract
/// documents as non-fatal, so the read is retried.
fn sha256_digest<R: Read>(mut reader: R) -> Result<Digest> {
    let mut context = Context::new(&SHA256);
    let mut buffer = [0; 1024];
    loop {
        let count = match reader.read(&mut buffer) {
            Ok(0) => break,
            Ok(count) => count,
            // Interrupted reads are transient; retry instead of failing the digest.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.into()),
        };
        context.update(&buffer[..count]);
    }
    Ok(context.finish())
}
/// Builds a revision walk that yields commits reachable from `from`, excluding
/// `to` and every commit reachable from it (roughly `git rev-list from ^to`).
///
/// # Errors
///
/// Returns an error if the walk cannot be created, or if either OID cannot be
/// pushed or hidden.
fn rev_list(repo: &Repository, from: Oid, to: Oid) -> Result<Revwalk> {
    let mut walk = repo.revwalk()?;
    // Start at `from`; only if that succeeds, prune `to` and its ancestors.
    walk.push(from).and_then(|()| walk.hide(to))?;
    Ok(walk)
}