#[cfg(not(feature = "std"))]
use alloc::{
string::{String, ToString},
vec,
vec::Vec,
};
use sha2::{Digest, Sha256};
use super::{BITS_PER_BYTE, BITS_PER_WORD, BitAccumulator, Count};
use crate::{
error::Error,
language::{AnyLanguage, Language},
};
/// Selects how much work the decoder does besides recovering the entropy.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DecodeMode {
/// Decode and verify the phrase only; `DecodedPhrase::normalized_phrase` is `None`.
ValidateOnly,
/// Additionally rebuild the phrase with its words joined by single spaces
/// and return it in `DecodedPhrase::normalized_phrase`.
BuildNormalizedPhrase,
}
/// Result of successfully decoding a phrase.
pub struct DecodedPhrase {
/// Entropy bytes recovered from the phrase (the checksum bits are not included).
pub entropy: Vec<u8>,
/// The phrase's words joined by single spaces; `Some` only when decoding
/// was requested with `DecodeMode::BuildNormalizedPhrase`.
pub normalized_phrase: Option<String>,
}
/// Decodes `phrase` using the word list of the statically chosen language `L`.
///
/// This is a convenience wrapper: it converts `L` into an [`AnyLanguage`]
/// and delegates to [`decode_phrase_with`].
///
/// # Errors
///
/// Returns whatever error [`decode_phrase_with`] reports.
pub fn decode_phrase<L: Language>(phrase: &str, mode: DecodeMode) -> Result<DecodedPhrase, Error> {
    let language = AnyLanguage::of::<L>();
    decode_phrase_with(language, phrase, mode)
}
pub fn decode_phrase_with(
language: AnyLanguage,
phrase: &str,
mode: DecodeMode,
) -> Result<DecodedPhrase, Error> {
let params = parse_params_from_phrase(phrase)?;
let mut normalized_phrase = match mode {
DecodeMode::ValidateOnly => None,
DecodeMode::BuildNormalizedPhrase => {
let words = params.count.word_count();
let rough_phrase_cap = words * 8 + (words.saturating_sub(1));
Some(String::with_capacity(rough_phrase_cap))
},
};
let mut state = DecodeState::new(params);
for word in phrase.split_whitespace() {
if let Some(out) = normalized_phrase.as_mut() {
if !out.is_empty() {
out.push(' ');
}
out.push_str(word);
}
let index = match language.index_of(word) {
Some(i) => i as u64,
None => return Err(Error::UnknownWord(word.to_string())),
};
state.push_index(index);
}
let entropy = state.finish()?;
Ok(DecodedPhrase { entropy, normalized_phrase })
}
/// Sizes derived from the phrase's word count that drive the decoder.
struct DecodeParams {
/// Number of entropy bytes to extract (`count.entropy_bit_length() / BITS_PER_BYTE`).
entropy_byte_length: usize,
/// Number of checksum bits that follow the entropy (`count.checksum_bit_length()`).
checksum_bit_length: usize,
/// The word count these sizes were derived from.
count: Count,
}
/// Derives the decoding parameters (entropy size, checksum size, word count)
/// from `phrase`.
///
/// # Errors
///
/// Propagates the error returned by `Count::from_phrase` when no count can
/// be determined for the phrase.
fn parse_params_from_phrase(phrase: &str) -> Result<DecodeParams, Error> {
    let count = Count::from_phrase(phrase)?;

    Ok(DecodeParams {
        // The entropy length is reported in bits; the decoder works in whole bytes.
        entropy_byte_length: count.entropy_bit_length() / BITS_PER_BYTE,
        checksum_bit_length: count.checksum_bit_length(),
        count,
    })
}
/// Incremental decoder state: word indices go in, entropy bytes and checksum
/// bits come out.
struct DecodeState {
/// Target sizes for the entropy and the checksum.
params: DecodeParams,
/// Output buffer, created zero-filled with `params.entropy_byte_length` bytes.
entropy: Vec<u8>,
/// Number of bytes of `entropy` written so far.
entropy_out: usize,
/// Bit buffer that word indices are pushed into and bytes/bits are taken from.
accumulator: BitAccumulator,
/// Checksum bits read from the phrase so far; each new bit is shifted in
/// at the least-significant end.
actual_checksum: u8,
/// Number of checksum bits collected into `actual_checksum`.
actual_checksum_filled: usize,
}
impl DecodeState {
fn new(params: DecodeParams) -> Self {
let entropy_byte_length = params.entropy_byte_length;
Self {
params,
entropy: vec![0u8; entropy_byte_length],
entropy_out: 0,
accumulator: BitAccumulator::new(),
actual_checksum: 0,
actual_checksum_filled: 0,
}
}
fn push_index(&mut self, index: u64) {
let entropy_byte_length = self.params.entropy_byte_length;
let checksum_bit_length = self.params.checksum_bit_length;
self.accumulator.push_bits(index, BITS_PER_WORD);
while self.entropy_out < entropy_byte_length && self.accumulator.can_take(BITS_PER_BYTE) {
self.entropy[self.entropy_out] = self.accumulator.take_bits(BITS_PER_BYTE) as u8;
self.entropy_out += 1;
}
while self.entropy_out == entropy_byte_length
&& self.actual_checksum_filled < checksum_bit_length
&& self.accumulator.can_take(1)
{
let bit = self.accumulator.take_bits(1) as u8;
self.actual_checksum = (self.actual_checksum << 1) | bit;
self.actual_checksum_filled += 1;
}
}
fn finish(self) -> Result<Vec<u8>, Error> {
debug_assert_eq!(
self.entropy_out, self.params.entropy_byte_length,
"decoded entropy length mismatch (bytes)"
);
debug_assert_eq!(
self.actual_checksum_filled, self.params.checksum_bit_length,
"decoded checksum length mismatch (bits)"
);
debug_assert_eq!(self.accumulator.bits(), 0, "trailing bits remained after decoding");
const fn checksum(source: u8, bit_length: usize) -> u8 {
source >> (BITS_PER_BYTE - bit_length)
}
let checksum_byte = Sha256::digest(&self.entropy)[0];
let expected_checksum = checksum(checksum_byte, self.params.checksum_bit_length);
if self.actual_checksum != expected_checksum {
return Err(Error::InvalidChecksum);
}
Ok(self.entropy)
}
}