use crate::error::{ImportError, Result};
use rustywallet_keys::prelude::{Network, PrivateKey};
use sha2::{Sha256, Digest};
/// Decodes a Wallet Import Format (WIF) string into a private key.
///
/// Surrounding whitespace is trimmed before decoding. After Base58 decoding the
/// bytes must be laid out as `version (1) || key (32) [|| 0x01] || checksum (4)`,
/// where `checksum` is the first four bytes of SHA-256(SHA-256(everything before it)).
///
/// Returns `(private_key, network, compressed)`:
/// - `network` comes from the version byte (`0x80` → [`Network::Mainnet`],
///   `0xEF` → [`Network::Testnet`]).
/// - `compressed` is `true` exactly when the optional `0x01` suffix is present.
///
/// # Errors
///
/// * [`ImportError::InvalidWif`] if any of the following holds:
///   - the input is not valid Base58;
///   - it decodes to a length other than 37 or 38 bytes;
///   - the version byte is unknown;
///   - a 38-byte encoding has a suffix byte other than `0x01`;
///   - `PrivateKey::from_bytes` rejects the key.
/// * [`ImportError::InvalidChecksum`] if the trailing checksum does not match.
pub fn import_wif(wif: &str) -> Result<(PrivateKey, Network, bool)> {
    const KEY_LEN: usize = 32;
    const CHECKSUM_LEN: usize = 4;
    // version byte + key + checksum
    const UNCOMPRESSED_LEN: usize = 1 + KEY_LEN + CHECKSUM_LEN;
    // ... plus the one-byte compression marker
    const COMPRESSED_LEN: usize = UNCOMPRESSED_LEN + 1;
    const COMPRESSION_FLAG: u8 = 0x01;

    let wif = wif.trim();
    // NOTE(review): `decoded` holds raw secret-key bytes and is dropped without
    // being zeroized; consider a zeroizing buffer if that matters for callers.
    let decoded = bs58::decode(wif)
        .into_vec()
        .map_err(|e| ImportError::InvalidWif(format!("Base58 decode failed: {}", e)))?;

    if decoded.len() != UNCOMPRESSED_LEN && decoded.len() != COMPRESSED_LEN {
        return Err(ImportError::InvalidWif(format!(
            "Invalid length: expected 37-38 bytes, got {}",
            decoded.len()
        )));
    }

    // Verify integrity before interpreting any of the payload bytes.
    let (payload, checksum) = decoded.split_at(decoded.len() - CHECKSUM_LEN);
    let hash = Sha256::digest(Sha256::digest(payload));
    if &hash[..CHECKSUM_LEN] != checksum {
        return Err(ImportError::InvalidChecksum);
    }

    let network = match payload[0] {
        0x80 => Network::Mainnet,
        0xEF => Network::Testnet,
        version => {
            return Err(ImportError::InvalidWif(format!(
                "Unknown version byte: 0x{:02x}",
                version
            )));
        }
    };

    // The long form is only valid with the 0x01 marker. Any other suffix byte is
    // malformed input, not an "uncompressed" key, so reject it instead of guessing.
    let compressed = match payload.get(1 + KEY_LEN).copied() {
        None => false,
        Some(COMPRESSION_FLAG) => true,
        Some(flag) => {
            return Err(ImportError::InvalidWif(format!(
                "Invalid compression flag: 0x{:02x}",
                flag
            )));
        }
    };

    // Cannot fail after the length check above; mapped rather than unwrapped so
    // this function never panics on user input.
    let key_bytes: [u8; KEY_LEN] = payload[1..1 + KEY_LEN]
        .try_into()
        .map_err(|_| ImportError::InvalidWif("Invalid key length".to_string()))?;
    let private_key = PrivateKey::from_bytes(key_bytes)
        .map_err(|e| ImportError::InvalidWif(format!("Invalid key: {}", e)))?;

    Ok((private_key, network, compressed))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Private key behind the well-known Bitcoin wiki WIF examples used below.
    const KEY_HEX: &str = "0c28fca386c7a227600b2fe50b7cae11ec86d3bf1fbe471be89827e19d72aa1d";

    /// Builds a WIF string from a raw payload (`version || key [|| flag]`).
    /// It appends the double-SHA-256 checksum and Base58-encodes the result.
    fn encode_wif(payload: &[u8]) -> String {
        let checksum = Sha256::digest(Sha256::digest(payload));
        let mut bytes = payload.to_vec();
        bytes.extend_from_slice(&checksum[..4]);
        bs58::encode(bytes).into_string()
    }

    /// Concatenates a version byte, the test key and an optional suffix.
    fn payload(version: u8, suffix: &[u8]) -> Vec<u8> {
        let mut bytes = vec![version];
        bytes.extend(hex::decode(KEY_HEX).unwrap());
        bytes.extend_from_slice(suffix);
        bytes
    }

    #[test]
    fn test_import_uncompressed_mainnet() {
        let wif = "5HueCGU8rMjxEXxiPuD5BDku4MkFqeZyd4dZ1jvhTVqvbTLvyTJ";
        let (key, network, compressed) = import_wif(wif).unwrap();
        assert_eq!(network, Network::Mainnet);
        assert!(!compressed);
        assert_eq!(hex::encode(key.to_bytes()), KEY_HEX);
    }

    #[test]
    fn test_import_compressed_mainnet() {
        let wif = "KwdMAjGmerYanjeui5SHS7JkmpZvVipYvB2LJGU1ZxJwYvP98617";
        let (key, network, compressed) = import_wif(wif).unwrap();
        assert_eq!(network, Network::Mainnet);
        assert!(compressed);
        assert_eq!(hex::encode(key.to_bytes()), KEY_HEX);
    }

    #[test]
    fn test_import_compressed_testnet() {
        let wif = encode_wif(&payload(0xEF, &[0x01]));
        let (key, network, compressed) = import_wif(&wif).unwrap();
        assert_eq!(network, Network::Testnet);
        assert!(compressed);
        assert_eq!(hex::encode(key.to_bytes()), KEY_HEX);
    }

    #[test]
    fn test_surrounding_whitespace_is_ignored() {
        let wif = "  5HueCGU8rMjxEXxiPuD5BDku4MkFqeZyd4dZ1jvhTVqvbTLvyTJ\n";
        assert!(import_wif(wif).is_ok());
    }

    #[test]
    fn test_unknown_version_byte() {
        let wif = encode_wif(&payload(0x00, &[]));
        assert!(matches!(import_wif(&wif), Err(ImportError::InvalidWif(_))));
    }

    #[test]
    fn test_invalid_length() {
        let wif = encode_wif(&[0x80; 10]);
        assert!(matches!(import_wif(&wif), Err(ImportError::InvalidWif(_))));
    }

    #[test]
    fn test_invalid_checksum() {
        let wif = "5HueCGU8rMjxEXxiPuD5BDku4MkFqeZyd4dZ1jvhTVqvbTLvyTK";
        let result = import_wif(wif);
        assert!(matches!(result, Err(ImportError::InvalidChecksum)));
    }

    #[test]
    fn test_invalid_base58() {
        // '0' and 'O' are not in the Base58 alphabet.
        let wif = "5HueCGU8rMjxEXxiPuD5BDku4MkFqeZyd4dZ1jvhTVqvbTLvy0O";
        let result = import_wif(wif);
        assert!(result.is_err());
    }
}