use crate::Error;
use crate::HashAlgorithm;
use crate::ObjectType;
use crate::Result;
use core::cmp::Ordering;
use core::fmt::Debug;
#[cfg(feature = "hex")]
use core::fmt::Display;
use core::fmt::Formatter;
use core::fmt::Result as FmtResult;
use core::hash::Hash;
use core::hash::Hasher;
use core::marker::PhantomData;
use core::ops::Not as _;
use core::result::Result as StdResult;
use core::str::FromStr;
#[cfg(feature = "url")]
use core::str::Split;
#[cfg(feature = "std")]
use digest::block_buffer::generic_array::sequence::GenericSequence;
use digest::block_buffer::generic_array::GenericArray;
use digest::Digest;
use digest::OutputSizeUser;
#[cfg(feature = "std")]
use format_bytes::format_bytes;
#[cfg(feature = "serde")]
use serde::{
de::{Deserializer, Error as DeserializeError, Visitor},
Deserialize, Serialize, Serializer,
};
#[cfg(feature = "std")]
use std::io::{BufRead, BufReader, Read, Seek, SeekFrom};
#[cfg(feature = "async")]
use tokio::io::{
AsyncBufReadExt as _, AsyncRead, AsyncSeek, AsyncSeekExt as _, BufReader as AsyncBufReader,
};
#[cfg(feature = "url")]
use url::Url;
#[repr(C)]
pub struct GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
#[doc(hidden)]
_phantom: PhantomData<O>,
#[doc(hidden)]
value: H::Array,
}
#[cfg(feature = "url")]
const GITOID_URL_SCHEME: &str = "gitoid";
impl<H, O> GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
pub fn id_bytes<B: AsRef<[u8]>>(content: B) -> GitOid<H, O> {
fn inner<H, O>(content: &[u8]) -> GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
gitoid_from_buffer(H::new(), content, content.len()).unwrap()
}
inner(content.as_ref())
}
pub fn id_str<S: AsRef<str>>(s: S) -> GitOid<H, O> {
fn inner<H, O>(s: &str) -> GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
GitOid::id_bytes(s.as_bytes())
}
inner(s.as_ref())
}
#[cfg(feature = "std")]
pub fn id_reader<R: Read + Seek>(mut reader: R) -> Result<GitOid<H, O>> {
let expected_length = stream_len(&mut reader)? as usize;
GitOid::id_reader_with_length(reader, expected_length)
}
#[cfg(feature = "std")]
pub fn id_reader_with_length<R: Read>(
reader: R,
expected_length: usize,
) -> Result<GitOid<H, O>> {
gitoid_from_buffer(H::new(), reader, expected_length)
}
#[cfg(feature = "async")]
pub async fn id_async_reader<R: AsyncRead + AsyncSeek + Unpin>(
mut reader: R,
) -> Result<GitOid<H, O>> {
let expected_length = async_stream_len(&mut reader).await? as usize;
GitOid::id_async_reader_with_length(reader, expected_length).await
}
#[cfg(feature = "async")]
pub async fn id_async_reader_with_length<R: AsyncRead + Unpin>(
reader: R,
expected_length: usize,
) -> Result<GitOid<H, O>> {
gitoid_from_async_buffer(H::new(), reader, expected_length).await
}
#[cfg(feature = "url")]
pub fn try_from_url(url: Url) -> Result<GitOid<H, O>> {
GitOid::try_from(url)
}
#[cfg(feature = "url")]
pub fn url(&self) -> Url {
Url::parse(&self.to_string()).unwrap()
}
pub fn as_bytes(&self) -> &[u8] {
&self.value[..]
}
#[cfg(feature = "hex")]
pub fn as_hex(&self) -> String {
hex::encode(self.as_bytes())
}
pub const fn hash_algorithm(&self) -> &'static str {
H::NAME
}
pub const fn object_type(&self) -> &'static str {
O::NAME
}
pub fn hash_len(&self) -> usize {
<H::Alg as OutputSizeUser>::output_size()
}
}
impl<H, O> FromStr for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
type Err = Error;
fn from_str(s: &str) -> Result<GitOid<H, O>> {
let url = Url::parse(s)?;
GitOid::try_from_url(url)
}
}
impl<H, O> Clone for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
fn clone(&self) -> Self {
*self
}
}
impl<H, O> Copy for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
}
impl<H, O> PartialEq<GitOid<H, O>> for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
fn eq(&self, other: &GitOid<H, O>) -> bool {
self.value == other.value
}
}
impl<H, O> Eq for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
}
impl<H, O> PartialOrd<GitOid<H, O>> for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<H, O> Ord for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
fn cmp(&self, other: &Self) -> Ordering {
self.value.cmp(&other.value)
}
}
impl<H, O> Hash for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
fn hash<H2>(&self, state: &mut H2)
where
H2: Hasher,
{
self.value.hash(state);
}
}
impl<H, O> Debug for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
f.debug_struct("GitOid")
.field("value", &self.value)
.finish()
}
}
#[cfg(feature = "hex")]
impl<H, O> Display for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(
f,
"{}:{}:{}:{}",
GITOID_URL_SCHEME,
O::NAME,
H::NAME,
self.as_hex()
)
}
}
#[cfg(feature = "serde")]
impl<H, O> Serialize for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
fn serialize<S>(&self, serializer: S) -> StdResult<S::Ok, S::Error>
where
S: Serializer,
{
let self_as_url_str = self.url().to_string();
serializer.serialize_str(&self_as_url_str)
}
}
#[cfg(feature = "serde")]
impl<'de, H, O> Deserialize<'de> for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
fn deserialize<D>(deserializer: D) -> StdResult<Self, D::Error>
where
D: Deserializer<'de>,
{
struct GitOidVisitor<H: HashAlgorithm, O: ObjectType>(PhantomData<H>, PhantomData<O>);
impl<'de, H: HashAlgorithm, O: ObjectType> Visitor<'de> for GitOidVisitor<H, O> {
type Value = GitOid<H, O>;
fn expecting(&self, formatter: &mut Formatter<'_>) -> FmtResult {
formatter.write_str("a gitoid-scheme URL")
}
fn visit_str<E>(self, value: &str) -> StdResult<Self::Value, E>
where
E: DeserializeError,
{
let url = Url::parse(value).map_err(E::custom)?;
let id = GitOid::try_from(url).map_err(E::custom)?;
Ok(id)
}
}
deserializer.deserialize_str(GitOidVisitor(PhantomData, PhantomData))
}
}
#[cfg(feature = "url")]
struct GitOidUrlParser<'u, H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
url: &'u Url,
segments: Split<'u, char>,
#[doc(hidden)]
_hash_algorithm: PhantomData<H>,
#[doc(hidden)]
_object_type: PhantomData<O>,
}
#[allow(dead_code)]
fn some_if_not_empty(s: &str) -> Option<&str> {
s.is_empty().not().then_some(s)
}
#[cfg(feature = "url")]
impl<'u, H, O> GitOidUrlParser<'u, H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
fn new(url: &'u Url) -> GitOidUrlParser<'u, H, O> {
GitOidUrlParser {
url,
segments: url.path().split(':'),
_hash_algorithm: PhantomData,
_object_type: PhantomData,
}
}
fn parse(&mut self) -> Result<GitOid<H, O>> {
self.validate_url_scheme()
.and_then(|_| self.validate_object_type())
.and_then(|_| self.validate_hash_algorithm())
.and_then(|_| self.parse_hash())
.map(|hash| GitOid {
_phantom: PhantomData,
value: H::array_from_generic(hash),
})
}
fn validate_url_scheme(&self) -> Result<()> {
if self.url.scheme() != GITOID_URL_SCHEME {
return Err(Error::InvalidScheme(self.url.clone()));
}
Ok(())
}
fn validate_object_type(&mut self) -> Result<()> {
let object_type = self
.segments
.next()
.and_then(some_if_not_empty)
.ok_or_else(|| Error::MissingObjectType(self.url.clone()))?;
if object_type != O::NAME {
return Err(Error::MismatchedObjectType { expected: O::NAME });
}
Ok(())
}
fn validate_hash_algorithm(&mut self) -> Result<()> {
let hash_algorithm = self
.segments
.next()
.and_then(some_if_not_empty)
.ok_or_else(|| Error::MissingHashAlgorithm(self.url.clone()))?;
if hash_algorithm != H::NAME {
return Err(Error::MismatchedHashAlgorithm { expected: H::NAME });
}
Ok(())
}
fn parse_hash(&mut self) -> Result<GenericArray<u8, <H::Alg as OutputSizeUser>::OutputSize>> {
let hex_str = self
.segments
.next()
.and_then(some_if_not_empty)
.ok_or_else(|| Error::MissingHash(self.url.clone()))?;
let mut value = GenericArray::generate(|_| 0);
hex::decode_to_slice(hex_str, &mut value)?;
let expected_size = <H::Alg as OutputSizeUser>::output_size();
if value.len() != expected_size {
return Err(Error::UnexpectedHashLength {
expected: expected_size,
observed: value.len(),
});
}
Ok(value)
}
}
#[cfg(feature = "url")]
impl<H, O> TryFrom<Url> for GitOid<H, O>
where
H: HashAlgorithm,
O: ObjectType,
{
type Error = Error;
fn try_from(url: Url) -> Result<GitOid<H, O>> {
GitOidUrlParser::new(&url).parse()
}
}
#[cfg(feature = "std")]
fn gitoid_from_buffer<H, O, R>(
digester: H::Alg,
reader: R,
expected_read_length: usize,
) -> Result<GitOid<H, O>>
where
H: HashAlgorithm,
O: ObjectType,
R: Read,
{
let expected_hash_length = <H::Alg as OutputSizeUser>::output_size();
let (hash, amount_read) =
hash_from_buffer::<H::Alg, O, R>(digester, reader, expected_read_length)?;
if amount_read != expected_read_length {
return Err(Error::UnexpectedReadLength {
expected: expected_read_length,
observed: amount_read,
});
}
if hash.len() != expected_hash_length {
return Err(Error::UnexpectedHashLength {
expected: expected_hash_length,
observed: hash.len(),
});
}
Ok(GitOid {
_phantom: PhantomData,
value: H::array_from_generic(hash),
})
}
#[cfg(not(feature = "std"))]
fn gitoid_from_buffer<H, O>(
digester: H::Alg,
reader: &[u8],
expected_read_length: usize,
) -> Result<GitOid<H, O>>
where
H: HashAlgorithm,
O: ObjectType,
{
let expected_hash_length = <H::Alg as OutputSizeUser>::output_size();
let (hash, amount_read) =
hash_from_buffer::<H::Alg, O>(digester, reader, expected_read_length)?;
if amount_read != expected_read_length {
return Err(Error::UnexpectedReadLength {
expected: expected_read_length,
observed: amount_read,
});
}
if hash.len() != expected_hash_length {
return Err(Error::UnexpectedHashLength {
expected: expected_hash_length,
observed: hash.len(),
});
}
Ok(GitOid {
_phantom: PhantomData,
value: H::array_from_generic(hash),
})
}
#[cfg(feature = "std")]
trait ForEachChunk: BufRead {
fn for_each_chunk(&mut self, f: impl FnMut(&[u8])) -> Result<usize>;
}
#[cfg(feature = "std")]
impl<R: BufRead> ForEachChunk for R {
fn for_each_chunk(&mut self, mut f: impl FnMut(&[u8])) -> Result<usize> {
let mut total_read = 0;
loop {
let buffer = self.fill_buf()?;
let amount_read = buffer.len();
if amount_read == 0 {
break;
}
f(buffer);
self.consume(amount_read);
total_read += amount_read;
}
Ok(total_read)
}
}
#[cfg(feature = "std")]
fn hash_from_buffer<D, O, R>(
mut digester: D,
reader: R,
expected_read_length: usize,
) -> Result<(GenericArray<u8, D::OutputSize>, usize)>
where
D: Digest,
O: ObjectType,
R: Read,
{
digester.update(format_bytes!(
b"{} {}\0",
O::NAME.as_bytes(),
expected_read_length
));
let amount_read = BufReader::new(reader).for_each_chunk(|b| digester.update(b))?;
let hash = digester.finalize();
Ok((hash, amount_read))
}
#[cfg(not(feature = "std"))]
fn hash_from_buffer<D, O>(
mut digester: D,
reader: &[u8],
expected_read_length: usize,
) -> Result<(GenericArray<u8, D::OutputSize>, usize)>
where
D: Digest,
O: ObjectType,
{
digester.update(O::NAME.as_bytes());
digester.update(b" ");
digester.update(expected_read_length.to_ne_bytes());
digester.update(b"\0");
let amount_read = reader.len();
digester.update(reader);
let hash = digester.finalize();
Ok((hash, amount_read))
}
#[cfg(feature = "async")]
async fn gitoid_from_async_buffer<H, O, R>(
digester: H::Alg,
reader: R,
expected_read_length: usize,
) -> Result<GitOid<H, O>>
where
H: HashAlgorithm,
O: ObjectType,
R: AsyncRead + Unpin,
{
let expected_hash_length = <H::Alg as OutputSizeUser>::output_size();
let (hash, amount_read) =
hash_from_async_buffer::<H::Alg, O, R>(digester, reader, expected_read_length).await?;
if amount_read != expected_read_length {
return Err(Error::UnexpectedHashLength {
expected: expected_read_length,
observed: amount_read,
});
}
if hash.len() != expected_hash_length {
return Err(Error::UnexpectedHashLength {
expected: expected_hash_length,
observed: hash.len(),
});
}
Ok(GitOid {
_phantom: PhantomData,
value: H::array_from_generic(hash),
})
}
#[cfg(feature = "async")]
async fn hash_from_async_buffer<D, O, R>(
mut digester: D,
reader: R,
expected_read_length: usize,
) -> Result<(GenericArray<u8, D::OutputSize>, usize)>
where
D: Digest,
O: ObjectType,
R: AsyncRead + Unpin,
{
digester.update(format_bytes!(
b"{} {}\0",
O::NAME.as_bytes(),
expected_read_length
));
let mut reader = AsyncBufReader::new(reader);
let mut total_read = 0;
loop {
let buffer = reader.fill_buf().await?;
let amount_read = buffer.len();
if amount_read == 0 {
break;
}
digester.update(buffer);
reader.consume(amount_read);
total_read += amount_read;
}
let hash = digester.finalize();
Ok((hash, total_read))
}
#[cfg(feature = "std")]
fn stream_len<R>(mut stream: R) -> Result<u64>
where
R: Seek,
{
let old_pos = stream.stream_position()?;
let len = stream.seek(SeekFrom::End(0))?;
if old_pos != len {
stream.seek(SeekFrom::Start(old_pos))?;
}
Ok(len)
}
#[cfg(feature = "async")]
async fn async_stream_len<R>(mut stream: R) -> Result<u64>
where
R: AsyncSeek + Unpin,
{
let old_pos = stream.stream_position().await?;
let len = stream.seek(SeekFrom::End(0)).await?;
if old_pos != len {
stream.seek(SeekFrom::Start(old_pos)).await?;
}
Ok(len)
}