use crate::error::{ImportError, Result};
use crate::types::{ImportMetadata, ImportResult, ImportFormat};
use rustywallet_mnemonic::Mnemonic;
use rustywallet_hd::{ExtendedPrivateKey, DerivationPath, Network};
/// Configuration for importing a key from a BIP-39 mnemonic phrase.
///
/// Build one with [`MnemonicImport::new`] and the `with_*` builder methods,
/// then pass it to `import_mnemonic`.
///
/// `Debug` is implemented by hand rather than derived: the mnemonic (and the
/// optional passphrase) are enough to recreate every key of the wallet, so
/// they must never end up in logs, `dbg!` output or panic messages.
#[derive(Clone)]
pub struct MnemonicImport {
    /// The mnemonic phrase as supplied by the caller (may contain surrounding
    /// whitespace).
    pub mnemonic: String,
    /// Optional BIP-39 passphrase; `None` is treated the same as an empty
    /// passphrase.
    pub passphrase: Option<String>,
    /// Optional derivation path string; `None` selects the default path.
    pub path: Option<String>,
    /// Optional network; `None` selects the default network.
    pub network: Option<Network>,
}

impl std::fmt::Debug for MnemonicImport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a passphrase is present, never its contents.
        f.debug_struct("MnemonicImport")
            .field("mnemonic", &"<redacted>")
            .field("passphrase", &self.passphrase.as_ref().map(|_| "<redacted>"))
            .field("path", &self.path)
            .field("network", &self.network)
            .finish()
    }
}
impl MnemonicImport {
    /// Creates an import configuration for `mnemonic` with no passphrase,
    /// no explicit derivation path and no explicit network.
    pub fn new(mnemonic: impl Into<String>) -> Self {
        let phrase: String = mnemonic.into();
        Self {
            mnemonic: phrase,
            passphrase: None,
            path: None,
            network: None,
        }
    }

    /// Returns this configuration with the BIP-39 passphrase set to
    /// `passphrase`, replacing any previously set one.
    pub fn with_passphrase(self, passphrase: impl Into<String>) -> Self {
        Self {
            passphrase: Some(passphrase.into()),
            ..self
        }
    }

    /// Returns this configuration with the derivation path set to `path`,
    /// replacing any previously set one.
    pub fn with_path(self, path: impl Into<String>) -> Self {
        Self {
            path: Some(path.into()),
            ..self
        }
    }

    /// Returns this configuration with the network set to `network`,
    /// replacing any previously set one.
    pub fn with_network(self, network: Network) -> Self {
        Self {
            network: Some(network),
            ..self
        }
    }
}
/// Ready-made derivation path strings for the first receiving address.
///
/// Every path has the shape `m/purpose'/coin_type'/account'/change/index`,
/// here with coin type `0'`, account `0'`, the external chain `0` and
/// address index `0`.
///
/// NOTE(review): the coin-type level is fixed at `0'`. Testnet wallets
/// conventionally use `1'` (SLIP-44). Confirm before relying on these
/// defaults for a non-mainnet network.
pub mod paths {
    /// Purpose `44'` path: `m/44'/0'/0'/0/0`.
    pub const BIP44: &str = "m/44'/0'/0'/0/0";

    /// Purpose `49'` path: `m/49'/0'/0'/0/0`.
    pub const BIP49: &str = "m/49'/0'/0'/0/0";

    /// Purpose `84'` path: `m/84'/0'/0'/0/0`.
    pub const BIP84: &str = "m/84'/0'/0'/0/0";
}
/// Derives a single private key from a BIP-39 mnemonic phrase.
///
/// The process has four steps:
/// 1. The phrase is split on whitespace and re-joined with single spaces
///    before parsing. Leading/trailing whitespace, tabs, line breaks or
///    doubled spaces in a pasted phrase therefore cannot affect parsing or
///    the resulting seed.
/// 2. The seed is computed with the configured passphrase. A missing
///    passphrase is treated as `""`.
/// 3. A master key is built for the configured network, which defaults to
///    `Network::Mainnet`.
/// 4. The key is derived along the configured path, which defaults to
///    [`paths::BIP44`].
///
/// The returned [`ImportResult`] is tagged [`ImportFormat::Mnemonic`] and
/// marked compressed. Its metadata records the path string used, the number
/// of words, and whether a non-empty passphrase was applied.
///
/// # Errors
///
/// - [`ImportError::InvalidMnemonic`] if `Mnemonic::from_phrase` rejects the
///   phrase.
/// - [`ImportError::KeyDerivationFailed`] if any of these fail: creating the
///   master key, parsing the path, deriving the child key, or extracting the
///   private key.
pub fn import_mnemonic(config: MnemonicImport) -> Result<ImportResult> {
    // Canonicalise whitespace so that the string handed to the mnemonic
    // parser (and therefore the seed) does not depend on how the user
    // separated the words; the word count comes from the same list.
    let words: Vec<&str> = config.mnemonic.split_whitespace().collect();
    let word_count = words.len();
    let phrase = words.join(" ");

    let mnemonic = Mnemonic::from_phrase(phrase.as_str())
        .map_err(|e| ImportError::InvalidMnemonic(e.to_string()))?;

    // The passphrase is deliberately NOT trimmed: in BIP-39 every character
    // of the passphrase, including whitespace, changes the seed.
    let passphrase = config.passphrase.as_deref().unwrap_or("");
    let seed = mnemonic.to_seed(passphrase);

    let network = config.network.unwrap_or(Network::Mainnet);
    let master = ExtendedPrivateKey::from_seed(seed.as_bytes(), network)
        .map_err(|e| ImportError::KeyDerivationFailed(e.to_string()))?;

    let path_str = config.path.as_deref().unwrap_or(paths::BIP44);
    let path: DerivationPath = path_str
        .parse()
        .map_err(|e| ImportError::KeyDerivationFailed(format!("Invalid path: {}", e)))?;

    let derived = master
        .derive_path(&path)
        .map_err(|e| ImportError::KeyDerivationFailed(e.to_string()))?;
    let private_key = derived
        .private_key()
        .map_err(|e| ImportError::KeyDerivationFailed(e.to_string()))?;

    let metadata = ImportMetadata {
        derivation_path: Some(path_str.to_string()),
        word_count: Some(word_count),
        has_passphrase: !passphrase.is_empty(),
    };

    Ok(ImportResult::new(private_key, ImportFormat::Mnemonic)
        .with_network(network)
        .with_compressed(true)
        .with_metadata(metadata))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard 12-word BIP-39 test phrase (all-zero entropy).
    const ABANDON_ABOUT: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    #[test]
    fn test_import_12_word_mnemonic() {
        let config = MnemonicImport::new(ABANDON_ABOUT);
        let result = import_mnemonic(config).unwrap();
        assert_eq!(result.format, ImportFormat::Mnemonic);
        assert_eq!(result.metadata.word_count, Some(12));
        assert!(!result.metadata.has_passphrase);
    }

    #[test]
    fn test_default_path_is_bip44() {
        let result = import_mnemonic(MnemonicImport::new(ABANDON_ABOUT)).unwrap();
        assert_eq!(result.metadata.derivation_path, Some(paths::BIP44.to_string()));
    }

    #[test]
    fn test_import_with_passphrase() {
        let config = MnemonicImport::new(ABANDON_ABOUT).with_passphrase("TREZOR");
        let result = import_mnemonic(config).unwrap();
        assert!(result.metadata.has_passphrase);
    }

    #[test]
    fn test_passphrase_changes_derived_key() {
        // The flag alone is not enough: the passphrase must actually feed
        // into seed derivation.
        let plain = import_mnemonic(MnemonicImport::new(ABANDON_ABOUT)).unwrap();
        let salted =
            import_mnemonic(MnemonicImport::new(ABANDON_ABOUT).with_passphrase("TREZOR")).unwrap();
        assert_ne!(plain.private_key.to_bytes(), salted.private_key.to_bytes());
    }

    #[test]
    fn test_import_with_custom_path() {
        let config = MnemonicImport::new(ABANDON_ABOUT).with_path("m/84'/0'/0'/0/0");
        let result = import_mnemonic(config).unwrap();
        assert_eq!(result.metadata.derivation_path, Some("m/84'/0'/0'/0/0".to_string()));
    }

    #[test]
    fn test_path_changes_derived_key() {
        let bip44 =
            import_mnemonic(MnemonicImport::new(ABANDON_ABOUT).with_path(paths::BIP44)).unwrap();
        let bip84 =
            import_mnemonic(MnemonicImport::new(ABANDON_ABOUT).with_path(paths::BIP84)).unwrap();
        assert_ne!(bip44.private_key.to_bytes(), bip84.private_key.to_bytes());
    }

    #[test]
    fn test_invalid_path() {
        let config = MnemonicImport::new(ABANDON_ABOUT).with_path("not a path");
        let result = import_mnemonic(config);
        assert!(matches!(result, Err(ImportError::KeyDerivationFailed(_))));
    }

    #[test]
    fn test_invalid_mnemonic() {
        let mnemonic = "invalid words that are not a valid mnemonic phrase at all";
        let config = MnemonicImport::new(mnemonic);
        let result = import_mnemonic(config);
        assert!(matches!(result, Err(ImportError::InvalidMnemonic(_))));
    }

    #[test]
    fn test_surrounding_whitespace_is_ignored() {
        let padded = format!("  {}\n", ABANDON_ABOUT);
        let clean = import_mnemonic(MnemonicImport::new(ABANDON_ABOUT)).unwrap();
        let messy = import_mnemonic(MnemonicImport::new(padded)).unwrap();
        assert_eq!(messy.metadata.word_count, Some(12));
        assert_eq!(clean.private_key.to_bytes(), messy.private_key.to_bytes());
    }

    #[test]
    fn test_deterministic() {
        let config1 = MnemonicImport::new(ABANDON_ABOUT).with_path("m/44'/0'/0'/0/0");
        let config2 = MnemonicImport::new(ABANDON_ABOUT).with_path("m/44'/0'/0'/0/0");
        let result1 = import_mnemonic(config1).unwrap();
        let result2 = import_mnemonic(config2).unwrap();
        assert_eq!(result1.private_key.to_bytes(), result2.private_key.to_bytes());
    }
}