use core::fmt::{self, Display, Formatter};
use core::str::FromStr;
use std::collections::BTreeSet;
use std::ops::Index;
use amplify::confinement;
use amplify::confinement::Confined;
use crate::{DerivationIndex, Idx, IndexParseError, NormalIndex};
/// Errors produced when parsing a [`DerivationPath`] from a string.
#[derive(Clone, Eq, PartialEq, Debug, Display, Error)]
#[display(doc_comments)]
pub enum DerivationParseError {
    /// invalid index in derivation path '{0}' - {1}
    InvalidIndex(String, IndexParseError),

    /// invalid derivation path format '{0}'
    InvalidFormat(String),
}
/// Derivation path segment holding between one and eight distinct indexes.
///
/// Indexes are stored in a `BTreeSet`, so they are kept in ascending order and duplicates
/// collapse into a single entry.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct DerivationSeg<I: Idx = NormalIndex>(Confined<BTreeSet<I>, 1, 8>);
impl<I: Idx> DerivationSeg<I> {
    /// Constructs a segment containing only `index`.
    pub fn new(index: I) -> Self { Self(confined_bset![index]) }

    /// Constructs a segment from any iterator of indexes.
    ///
    /// # Errors
    ///
    /// Returns the `confinement::Error` reported by `Confined::try_from_iter` when the
    /// collected indexes do not satisfy the 1..=8 size bound.
    pub fn with(iter: impl IntoIterator<Item = I>) -> Result<Self, confinement::Error> {
        let set: Confined<BTreeSet<I>, 1, 8> = Confined::try_from_iter(iter)?;
        Ok(Self(set))
    }

    /// Number of indexes in the segment (always within 1..=8 by construction).
    #[inline]
    pub fn count(&self) -> u8 {
        // Lossless cast: the confinement caps the length at 8.
        self.0.len() as u8
    }

    /// Returns `true` when `self` and `other` have no index in common.
    #[inline]
    pub fn is_distinct(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.0, &other.0);
        mine.is_disjoint(theirs)
    }

    /// Returns the index at ascending-order position `index`, or `None` when out of range.
    #[inline]
    pub fn at(&self, index: u8) -> Option<I> { self.0.iter().copied().nth(usize::from(index)) }
}
impl DerivationSeg<NormalIndex> {
pub fn standard() -> Self { DerivationSeg(confined_bset![NormalIndex::ZERO, NormalIndex::ONE]) }
}
impl<I: Idx> Index<u8> for DerivationSeg<I> {
    type Output = I;

    /// Returns a reference to the index at ascending-order position `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not less than the number of indexes in the segment.
    fn index(&self, index: u8) -> &Self::Output {
        let position = usize::from(index);
        match self.0.iter().nth(position) {
            Some(item) => item,
            None => panic!("requested position in derivation segment exceeds its length"),
        }
    }
}
impl<I: Idx + Display> Display for DerivationSeg<I> {
    /// Writes a single-index segment as the bare index.
    ///
    /// Otherwise writes every index in ascending set order, separated by `;` and enclosed
    /// in `<` and `>` (e.g. `<0;1>`).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        if self.count() != 1 {
            f.write_str("<")?;
            for (position, index) in self.0.iter().enumerate() {
                if position > 0 {
                    f.write_str(";")?;
                }
                write!(f, "{index}")?;
            }
            return f.write_str(">");
        }
        write!(f, "{}", self[0])
    }
}
/// Errors produced when parsing a [`DerivationSeg`] from a string.
#[derive(Clone, Eq, PartialEq, Debug, Display, Error, From)]
#[display(doc_comments)]
pub enum SegParseError {
    /// invalid derivation segment index - {0}
    #[from]
    InvalidFormat(IndexParseError),

    /// invalid number of indexes in derivation segment - {0}
    #[from]
    Confinement(confinement::Error),
}
impl<I: Idx> FromStr for DerivationSeg<I>
where
    I: FromStr,
    SegParseError: From<I::Err>,
{
    type Err = SegParseError;

    /// Parses either a bare index (`"1"`) or a bracketed, `;`-separated list of indexes
    /// (`"<0;1>"`).
    ///
    /// Duplicate indexes inside brackets are merged, since they are collected into a set.
    ///
    /// # Errors
    ///
    /// - Any index rejected by `I::from_str` is converted into a [`SegParseError`].
    /// - A bracketed list whose distinct indexes violate the segment's 1..=8 bound yields
    ///   [`SegParseError::Confinement`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Require exactly one leading `<` AND one trailing `>`. Unbalanced input such as
        // `<<0;1` or `0;1>` falls through to single-index parsing and is rejected there.
        // This also avoids length arithmetic that could underflow on inputs shorter than
        // two bytes.
        match s.strip_prefix('<').and_then(|rest| rest.strip_suffix('>')) {
            Some(inner) => {
                let set = inner
                    .split(';')
                    .map(I::from_str)
                    .collect::<Result<BTreeSet<_>, _>>()?;
                Ok(Self(Confined::try_from_iter(set)?))
            }
            None => Ok(Self(I::from_str(s).map(Confined::with)?)),
        }
    }
}
/// Ordered sequence of derivation indexes (by default [`DerivationIndex`]) stored in a
/// `Vec`.
///
/// NOTE(review): the `Wrapper`/`WrapperMut` derives with `#[wrapper(Deref)]` and
/// `#[wrapper_mut(DerefMut)]` presumably expose the inner `Vec<I>` through `Deref` and
/// `DerefMut` — confirm against the `amplify` derive documentation.
/// With the `serde` feature enabled, `Serialize`/`Deserialize` are derived using the
/// `serde_crate` crate alias.
#[derive(Wrapper, WrapperMut, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Debug, From)]
#[wrapper(Deref)]
#[wrapper_mut(DerefMut)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate", rename_all = "camelCase")
)]
pub struct DerivationPath<I = DerivationIndex>(Vec<I>);
impl<I: Clone> From<&[I]> for DerivationPath<I> {
    /// Builds a path by cloning every index of `path`, preserving order.
    fn from(path: &[I]) -> Self {
        let indexes: Vec<I> = path.iter().cloned().collect();
        Self(indexes)
    }
}
impl<I: Display> Display for DerivationPath<I> {
    /// Writes each index preceded by `/` (e.g. `/0/1`); an empty path writes nothing.
    ///
    /// The formatter, including any flags, is passed through to every index.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.0.iter().try_for_each(|segment| {
            f.write_str("/")?;
            Display::fmt(segment, f)
        })
    }
}
impl<I: FromStr> FromStr for DerivationPath<I>
where IndexParseError: From<<I as FromStr>::Err>
{
    type Err = DerivationParseError;

    /// Parses a `/`-separated list of indexes, optionally preceded by a single `/`
    /// (e.g. `"0/1"` or `"/0/1"`).
    ///
    /// # Errors
    ///
    /// - [`DerivationParseError::InvalidFormat`] (carrying the original input) when the
    ///   path has no segments, i.e. the input is `""` or `"/"`.
    /// - [`DerivationParseError::InvalidIndex`] (carrying the input without its leading
    ///   `/`) when any segment fails to parse as `I`. This includes empty segments
    ///   produced by `//` or a trailing `/`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let path = s.strip_prefix('/').unwrap_or(s);
        // `str::split` always yields at least one item, so emptiness must be checked
        // before splitting. Otherwise an empty path would surface as an index error on "".
        if path.is_empty() {
            return Err(DerivationParseError::InvalidFormat(s.to_owned()));
        }
        let inner = path
            .split('/')
            .map(I::from_str)
            .collect::<Result<Vec<_>, I::Err>>()
            .map_err(|err| DerivationParseError::InvalidIndex(path.to_owned(), err.into()))?;
        Ok(Self(inner))
    }
}
impl<I> IntoIterator for DerivationPath<I> {
type Item = I;
type IntoIter = std::vec::IntoIter<I>;
fn into_iter(self) -> Self::IntoIter { self.0.into_iter() }
}
impl<'path, I: Copy> IntoIterator for &'path DerivationPath<I> {
    type Item = I;
    type IntoIter = std::iter::Copied<std::slice::Iter<'path, I>>;

    /// Iterates over the path's indexes in order, yielding each one by value.
    fn into_iter(self) -> Self::IntoIter {
        let indexes: &'path [I] = &self.0;
        indexes.iter().copied()
    }
}
impl<I> FromIterator<I> for DerivationPath<I> {
    /// Collects indexes into a path, preserving iteration order.
    fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self {
        let mut indexes = Vec::new();
        indexes.extend(iter);
        Self(indexes)
    }
}
impl<I> DerivationPath<I> {
    /// Creates an empty derivation path; no `I: Default` bound is required.
    pub fn new() -> Self { Self(Vec::new()) }
}