use std::collections::HashMap;
use serde::Deserialize;
use crate::hd::{ChainCode, HdWallet};
use crate::server_client::Client;
use crate::core::asset::WalletAsset;
use crate::core::recovery::{RecoveredOutput, RecoveryError, RecoveryReport};
/// Top-level JSON body returned by `Client::health_check`, parsed with
/// `serde_json::from_str` in `scan_chain`.
#[derive(Debug, Deserialize)]
struct HealthEnvelope {
    /// Deserialized but never read in this file (hence the `dead_code` allow);
    /// `#[serde(default)]` lets a response without a `status` field still parse.
    #[serde(default)]
    #[allow(dead_code)]
    status: String,
    /// Per-token results keyed by the server's response key. `scan_chain`
    /// extracts the secret hash from each key with
    /// `WalletAsset::extract_hash_from_response_key`. No default: a response
    /// missing this field fails to deserialize.
    results: HashMap<String, HealthEntry>,
}
/// The server's answer for a single public token in a health-check response.
#[derive(Debug, Deserialize)]
struct HealthEntry {
    /// `Some(_)` marks the token's depth as used (it counts toward the chain's
    /// last used depth). Only `Some(false)` produces a recovered output.
    /// `None`/absent is treated as unused. Presumably `Some(true)` means the
    /// output was already spent; confirm against the server API.
    spent: Option<bool>,
    /// Decimal amount string (e.g. `"195.3125"`), converted with
    /// `parse_decimal_to_wats`. Only read for `spent == Some(false)` entries,
    /// and only when `A::SERVER_REPORTS_AMOUNT` is true.
    #[serde(default)]
    amount: Option<String>,
}
/// Rebuilds a wallet's outputs by scanning every chain in `ChainCode::ALL`
/// against the server and merging the findings into one [`RecoveryReport`].
///
/// Each chain is scanned in batches of `gap_limit` depths. `reported_depths`
/// gives, per chain, a depth the scan must query before it may stop; chains
/// missing from the map use 0.
///
/// # Errors
/// Returns [`RecoveryError::InvalidGapLimit`] when `gap_limit` is zero, and
/// otherwise propagates the first error produced while scanning a chain
/// (remaining chains are not scanned).
pub fn recover<A: WalletAsset>(
    client: &Client,
    hd: &HdWallet,
    namespace: &A::Namespace,
    gap_limit: u64,
    reported_depths: &HashMap<ChainCode, u64>,
) -> Result<RecoveryReport<A>, RecoveryError> {
    if gap_limit == 0 {
        return Err(RecoveryError::InvalidGapLimit);
    }

    let mut report = RecoveryReport::<A>::empty();
    ChainCode::ALL.iter().try_for_each(|&chain| {
        let min_depth = reported_depths.get(&chain).copied().unwrap_or(0);
        scan_chain::<A>(client, hd, namespace, chain, gap_limit, min_depth, &mut report)
    })?;
    Ok(report)
}
fn scan_chain<A: WalletAsset>(
client: &Client,
hd: &HdWallet,
namespace: &A::Namespace,
chain: ChainCode,
gap_limit: u64,
reported_depth: u64,
report: &mut RecoveryReport<A>,
) -> Result<(), RecoveryError> {
let mut current_depth = 0u64;
let mut consecutive_empty: u64 = 0;
let mut chain_max_used: Option<u64> = None;
loop {
let mut by_hash: HashMap<String, (String, u64)> = HashMap::new();
let mut publics: Vec<String> = Vec::with_capacity(gap_limit as usize);
for offset in 0..gap_limit {
let depth = current_depth + offset;
let secret_hex = hd.derive_secret(chain, depth);
let public = A::public_token_for_lookup(&secret_hex, namespace);
let hash = sha256_hex_of_ascii(&secret_hex);
by_hash.insert(hash, (secret_hex, depth));
publics.push(public);
}
let raw = client
.health_check(&publics)
.map_err(|source| RecoveryError::Server {
chain: chain.as_str(),
depth: current_depth,
source,
})?;
let env: HealthEnvelope = serde_json::from_str(&raw)
.map_err(|e| RecoveryError::Decode(format!("{chain:?}@{current_depth}: {e}")))?;
let mut batch_had_hit = false;
for (resp_key, entry) in &env.results {
let Some(hash) = A::extract_hash_from_response_key(resp_key) else {
continue;
};
let Some((secret_hex, depth)) = by_hash.get(hash) else {
continue;
};
if entry.spent.is_some() {
batch_had_hit = true;
chain_max_used = Some(chain_max_used.map_or(*depth, |d| d.max(*depth)));
}
if entry.spent == Some(false) {
let amount_wats = if A::SERVER_REPORTS_AMOUNT {
let parsed = entry
.amount
.as_deref()
.map(parse_decimal_to_wats)
.transpose()
.map_err(|e| RecoveryError::Decode(format!("amount: {e}")))?
.ok_or_else(|| {
RecoveryError::Decode(format!(
"{chain:?}@{depth}: server reported spent: false without amount"
))
})?;
Some(parsed)
} else {
None
};
report.recovered.push(RecoveredOutput {
secret_hex: secret_hex.clone(),
hash: hash.to_string(),
amount_wats,
chain,
depth: *depth,
namespace: namespace.clone(),
});
}
}
if batch_had_hit {
consecutive_empty = 0;
} else {
consecutive_empty = consecutive_empty.saturating_add(gap_limit);
}
let next_depth = current_depth.saturating_add(gap_limit);
let past_reported = next_depth > reported_depth;
if !batch_had_hit && consecutive_empty >= gap_limit && past_reported {
break;
}
current_depth = next_depth;
}
if let Some(d) = chain_max_used {
report.last_used_depth.insert(chain, d);
}
Ok(())
}
/// Returns the hex-encoded SHA-256 digest of the bytes of `s`.
///
/// `scan_chain` uses this value as the lookup key that ties server response
/// entries back to derived secrets.
fn sha256_hex_of_ascii(s: &str) -> String {
    use sha2::{Digest, Sha256};
    let digest = Sha256::digest(s.as_bytes());
    hex::encode(digest)
}
/// Converts a plain decimal string such as `"195.3125"` into an integer number
/// of wats, where one whole unit is 10^8 wats.
///
/// Accepted form, after trimming surrounding whitespace: an optional leading
/// `-`, one or more ASCII digits, then optionally `.` followed by at most eight
/// ASCII digits. An empty fraction such as `"1."` is accepted. A `+` sign, an
/// exponent, or a leading `.` is rejected.
///
/// # Errors
/// Returns a message quoting the trimmed input when:
/// - the whole or fractional part is malformed;
/// - there are more than eight fractional digits;
/// - the result does not fit in `i64`.
fn parse_decimal_to_wats(s: &str) -> Result<i64, String> {
    const SCALE: i64 = 100_000_000;
    // Number of decimal places represented by SCALE.
    const FRAC_DIGITS: usize = 8;

    let s = s.trim();
    let (negative, unsigned) = s.strip_prefix('-').map_or((false, s), |rest| (true, rest));
    let (int_part, frac_part) = unsigned.split_once('.').unwrap_or((unsigned, ""));

    let is_digits = |part: &str| part.bytes().all(|b| b.is_ascii_digit());
    if int_part.is_empty() || !is_digits(int_part) {
        return Err(format!("malformed whole part: {s:?}"));
    }
    if !is_digits(frac_part) {
        return Err(format!("malformed fractional part: {s:?}"));
    }
    if frac_part.len() > FRAC_DIGITS {
        return Err(format!("more than 8 fractional digits: {s:?}"));
    }

    let whole: i64 = int_part
        .parse()
        .map_err(|_| format!("whole part overflows i64: {s:?}"))?;
    // Read the fraction as if right-padded with zeros to FRAC_DIGITS places.
    // It has at most eight digits, so the value stays below 10^8 and cannot
    // overflow.
    let frac = frac_part
        .bytes()
        .chain(std::iter::repeat(b'0'))
        .take(FRAC_DIGITS)
        .fold(0i64, |acc, b| acc * 10 + i64::from(b - b'0'));

    let magnitude = whole
        .checked_mul(SCALE)
        .and_then(|scaled| scaled.checked_add(frac))
        .ok_or_else(|| format!("amount overflows i64 wats: {s:?}"))?;
    let signed = if negative { -magnitude } else { magnitude };
    Ok(signed)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_decimal_handles_production_shapes() {
        assert_eq!(parse_decimal_to_wats("1").unwrap(), 100_000_000);
        assert_eq!(parse_decimal_to_wats("0.4").unwrap(), 40_000_000);
        assert_eq!(parse_decimal_to_wats("195.3125").unwrap(), 19_531_250_000);
        assert_eq!(parse_decimal_to_wats("0").unwrap(), 0);
        assert_eq!(parse_decimal_to_wats("0.00000001").unwrap(), 1);
    }

    #[test]
    fn parse_decimal_handles_sign_and_whitespace() {
        assert_eq!(parse_decimal_to_wats(" 2.5 ").unwrap(), 250_000_000);
        assert_eq!(parse_decimal_to_wats("-2.5").unwrap(), -250_000_000);
        assert_eq!(parse_decimal_to_wats("-0").unwrap(), 0);
    }

    #[test]
    fn parse_decimal_rejects_garbage() {
        assert!(parse_decimal_to_wats("").is_err());
        assert!(parse_decimal_to_wats("abc").is_err());
        assert!(parse_decimal_to_wats("1.234567890").is_err());
        assert!(parse_decimal_to_wats(".5").is_err());
        assert!(parse_decimal_to_wats("-").is_err());
        assert!(parse_decimal_to_wats("+1").is_err());
        assert!(parse_decimal_to_wats("1.2.3").is_err());
    }

    #[test]
    fn parse_decimal_overflow_boundaries() {
        // Largest representable amount: exactly i64::MAX wats.
        assert_eq!(parse_decimal_to_wats("92233720368.54775807").unwrap(), i64::MAX);
        // Fits in i64 as a whole number, but not once scaled by 10^8.
        assert!(parse_decimal_to_wats("92233720369").is_err());
        // Does not fit in i64 even before scaling.
        assert!(parse_decimal_to_wats("99999999999999999999").is_err());
    }

    #[test]
    fn sha256_matches_known_vectors() {
        // Standard SHA-256 test vectors. Pinning literal digests (rather than
        // re-hashing with the same library) catches encoding changes such as
        // upper-case hex, which would silently break server key matching.
        assert_eq!(
            sha256_hex_of_ascii(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            sha256_hex_of_ascii("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
        let s = "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234";
        let h = sha256_hex_of_ascii(s);
        assert_eq!(h.len(), 64);
        assert!(h.chars().all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }
}