Skip to main content

bittensor_wallet/
keyfile.rs

1use std::collections::HashMap;
2use std::env;
3use std::fs;
4use std::io::{Read, Write};
5use std::os::unix::fs::PermissionsExt;
6use std::path::PathBuf;
7use std::str::from_utf8;
8
9use ansible_vault::{decrypt_vault, encrypt_vault};
10use fernet::Fernet;
11
12use base64::{engine::general_purpose, Engine as _};
13use passwords::analyzer;
14use passwords::scorer;
15use serde_json::json;
16
17use crate::constants::CRYPTO_SR25519;
18use crate::errors::KeyFileError;
19use crate::keypair::Keypair;
20use crate::utils;
21
22use sodiumoxide::crypto::pwhash;
23use sodiumoxide::crypto::secretbox;
24
/// Fixed 16-byte salt for the Argon2i key derivation used by `$NACL` keyfiles.
/// Because derivation is deterministic, changing these bytes would make every existing
/// NaCl keyfile undecryptable.
const NACL_SALT: &[u8] = &[
    0x13, 0x71, 0x83, 0xdf, 0xf1, 0x5a, 0x09, 0xbc, 0x9c, 0x90, 0xb5, 0x51, 0x87, 0x39, 0xe9, 0xb1,
];
/// Fixed salt for the PBKDF2 key derivation used by legacy (`gAAAAA`-prefixed) keyfiles.
const LEGACY_SALT: &[u8] = b"Iguesscyborgslikemyselfhaveatendencytobeparanoidaboutourorigins";
27
28/// Serializes keypair object into keyfile data.
29///
30/// ```text
31///     Arguments:
32///         keypair (Keypair): The keypair object to be serialized.
33///     Returns:
34///         data (bytes): Serialized keypair data.
35/// ```
36pub fn serialized_keypair_to_keyfile_data(keypair: &Keypair) -> Result<Vec<u8>, KeyFileError> {
37    let mut data: HashMap<&str, serde_json::Value> = HashMap::new();
38
39    // publicKey and privateKey fields are optional. If they exist, hex prefix "0x" is added to them.
40    if let Ok(Some(public_key)) = keypair.public_key() {
41        let public_key_str = hex::encode(&public_key);
42        data.insert("accountId", json!(format!("0x{}", public_key_str)));
43        data.insert("publicKey", json!(format!("0x{}", public_key_str)));
44    }
45    if let Ok(Some(private_key)) = keypair.private_key() {
46        let private_key_str = hex::encode(&private_key);
47        data.insert("privateKey", json!(format!("0x{}", private_key_str)));
48    }
49
50    // mnemonic and ss58_address fields are optional.
51    if let Some(mnemonic) = keypair.mnemonic() {
52        data.insert("secretPhrase", json!(mnemonic.to_string()));
53    }
54
55    // the seed_hex field is optional. If it exists, hex prefix "0x" is added to it.
56    if let Some(seed_hex) = keypair.seed_hex() {
57        let seed_hex_str = match from_utf8(&seed_hex) {
58            Ok(s) => s.to_string(),
59            Err(_) => hex::encode(seed_hex),
60        };
61        data.insert("secretSeed", json!(format!("0x{}", seed_hex_str)));
62    }
63
64    if let Some(ss58_address) = keypair.ss58_address() {
65        data.insert("ss58Address", json!(ss58_address.to_string()));
66    }
67
68    data.insert("cryptoType", json!(keypair.crypto_type()));
69
70    // Serialize the data into JSON string and return it as bytes
71    let json_data = serde_json::to_string(&data)
72        .map_err(|e| KeyFileError::SerializationError(format!("Serialization error: {}", e)))?;
73    Ok(json_data.into_bytes())
74}
75
76/// Deserializes Keypair object from passed keyfile data.
77///
78/// ```text
79///     Arguments:
80///         keyfile_data (PyBytes): The keyfile data to be loaded.
81///     Returns:
82///         keypair (Keypair): The Keypair loaded from bytes.
83///     Raises:
84///         KeyFileError: Raised if the passed PyBytes cannot construct a keypair object.
85/// ```
86pub fn deserialize_keypair_from_keyfile_data(keyfile_data: &[u8]) -> Result<Keypair, KeyFileError> {
87    let decoded = from_utf8(keyfile_data).map_err(|_| {
88        KeyFileError::DeserializationError("Failed to decode keyfile data.".to_string())
89    })?;
90
91    let keyfile_dict: serde_json::Value = serde_json::from_str(decoded).map_err(|_| {
92        KeyFileError::DeserializationError("Failed to parse keyfile data.".to_string())
93    })?;
94
95    let crypto_type = keyfile_dict
96        .get("cryptoType")
97        .and_then(|v| v.as_u64())
98        .map(|v| v as u8)
99        .unwrap_or(CRYPTO_SR25519);
100
101    let secret_phrase = keyfile_dict
102        .get("secretPhrase")
103        .and_then(|v| v.as_str())
104        .map(String::from);
105    let secret_seed = keyfile_dict
106        .get("secretSeed")
107        .and_then(|v| v.as_str())
108        .map(String::from);
109    let private_key = keyfile_dict
110        .get("privateKey")
111        .and_then(|v| v.as_str())
112        .map(String::from);
113    let ss58_address = keyfile_dict
114        .get("ss58Address")
115        .and_then(|v| v.as_str())
116        .map(String::from);
117
118    if let Some(secret_phrase) = secret_phrase {
119        Keypair::create_from_mnemonic(&secret_phrase, crypto_type).map_err(KeyFileError::Generic)
120    } else if let Some(seed) = secret_seed {
121        let seed = seed.trim_start_matches("0x");
122        let seed_bytes = hex::decode(seed).map_err(|e| KeyFileError::Generic(e.to_string()))?;
123        Keypair::create_from_seed(seed_bytes, crypto_type).map_err(KeyFileError::Generic)
124    } else if let Some(private_key) = private_key {
125        let key = private_key.trim_start_matches("0x");
126        Keypair::create_from_private_key(key, crypto_type).map_err(KeyFileError::Generic)
127    } else if let Some(ss58) = ss58_address {
128        Keypair::new(Some(ss58), None, None, 42, None, crypto_type).map_err(KeyFileError::Generic)
129    } else {
130        Err(KeyFileError::Generic(
131            "Keypair could not be created from keyfile data.".to_string(),
132        ))
133    }
134}
135
136/// Validates the password against a password policy.
137///
138/// ```text
139///     Arguments:
140///         password (str): The password to verify.
141///     Returns:
142///         valid (bool): ``True`` if the password meets validity requirements.
143/// ```
144pub fn validate_password(password: &str) -> Result<bool, KeyFileError> {
145    // Check for an empty password
146    if password.is_empty() {
147        return Ok(false);
148    }
149
150    // Define the password policy
151    let min_length = 6;
152    let min_score = 20.0; // Adjusted based on the scoring system described in the documentation
153
154    // Analyze the password
155    let analyzed = analyzer::analyze(password);
156    let score = scorer::score(&analyzed);
157
158    // Check conditions
159    if password.len() >= min_length && score >= min_score {
160        // Prompt user to retype the password
161        let password_verification_response =
162            utils::prompt_password("Retype your password: ".to_string())
163                .expect("Failed to read the password.");
164
165        // Remove potential newline or whitespace at the end
166        let password_verification = password_verification_response.trim();
167
168        if password == password_verification {
169            Ok(true)
170        } else {
171            utils::print("Passwords do not match.\n".to_string());
172            Ok(false)
173        }
174    } else {
175        utils::print("Password not strong enough. Try increasing the length of the password or the password complexity.\n".to_string());
176        Ok(false)
177    }
178}
179
180/// Prompts the user to enter a password for key encryption.
181///
182/// ```text
183///     Arguments:
184///         validation_required (bool): If ``True``, validates the password against policy requirements.
185///     Returns:
186///         password (str): The valid password entered by the user.
187/// ```
188pub fn ask_password(validation_required: bool) -> Result<String, KeyFileError> {
189    let mut valid = false;
190    let mut password = utils::prompt_password("Enter your password: ".to_string());
191
192    if validation_required {
193        while !valid {
194            if let Some(ref pwd) = password {
195                valid = validate_password(pwd)?;
196                if !valid {
197                    password = utils::prompt_password("Enter your password again: ".to_string());
198                }
199            } else {
200                valid = true
201            }
202        }
203    }
204
205    Ok(password.unwrap_or("".to_string()).trim().to_string())
206}
207
/// Returns `true` if the keyfile data is NaCl encrypted.
///
/// ```text
///     Arguments:
///         `keyfile_data` - Bytes to validate
///     Returns:
///         `is_nacl` - `true` if the data starts with the `$NACL` marker.
/// ```
pub fn keyfile_data_is_encrypted_nacl(keyfile_data: &[u8]) -> bool {
    keyfile_data.starts_with(b"$NACL")
}
219
/// Reports whether `keyfile_data` carries the Ansible Vault header.
///
/// ```text
///     Arguments:
///         `keyfile_data` - The bytes to inspect.
///     Returns:
///         `is_ansible` - `true` when the data begins with `$ANSIBLE_VAULT`.
/// ```
pub fn keyfile_data_is_encrypted_ansible(keyfile_data: &[u8]) -> bool {
    const HEADER: &[u8] = b"$ANSIBLE_VAULT";
    keyfile_data.get(..HEADER.len()) == Some(HEADER)
}
231
/// Reports whether `keyfile_data` looks like a legacy-encrypted keyfile.
///
/// ```text
///     Arguments:
///         `keyfile_data` - The bytes to inspect.
///     Returns:
///         `is_legacy` - `true` when the data begins with `gAAAAA` (presumably the base64
///             prefix of the Fernet tokens used by the legacy scheme).
/// ```
pub fn keyfile_data_is_encrypted_legacy(keyfile_data: &[u8]) -> bool {
    keyfile_data.strip_prefix(&b"gAAAAA"[..]).is_some()
}
243
244/// Returns `true` if the keyfile data is encrypted.
245///
246/// ```text
247///     Arguments:
248///         keyfile_data (bytes): The bytes to validate.
249///     Returns:
250///         is_encrypted (bool): `true` if the data is encrypted.
251/// ```
252pub fn keyfile_data_is_encrypted(keyfile_data: &[u8]) -> bool {
253    let nacl = keyfile_data_is_encrypted_nacl(keyfile_data);
254    let ansible = keyfile_data_is_encrypted_ansible(keyfile_data);
255    let legacy = keyfile_data_is_encrypted_legacy(keyfile_data);
256    nacl || ansible || legacy
257}
258
259/// Returns type of encryption method as a string.
260///
261/// ```text
262///     Arguments:
263///         keyfile_data (bytes): Bytes to validate.
264///     Returns:
265///         (str): A string representing the name of encryption method.
266/// ```
267pub fn keyfile_data_encryption_method(keyfile_data: &[u8]) -> String {
268    if keyfile_data_is_encrypted_nacl(keyfile_data) {
269        "NaCl"
270    } else if keyfile_data_is_encrypted_ansible(keyfile_data) {
271        "Ansible Vault"
272    } else if keyfile_data_is_encrypted_legacy(keyfile_data) {
273        "legacy"
274    } else {
275        "unknown"
276    }
277    .to_string()
278}
279
280/// legacy_encrypt_keyfile_data.
281///
282/// ```text
283///     Arguments:
284///         keyfile_data (bytes): Bytes of data from the keyfile.
285///         password (str): Optional string that represents the password.
286///     Returns:
287///         encrypted_data (bytes): The encrypted keyfile data in bytes.
288/// ```
289pub fn legacy_encrypt_keyfile_data(
290    keyfile_data: &[u8],
291    password: Option<String>,
292) -> Result<Vec<u8>, KeyFileError> {
293    let password = password.unwrap_or_else(||
294        // function to get password from user
295        ask_password(true).unwrap());
296
297    utils::print(
298        ":exclamation_mark: Encrypting key with legacy encryption method...\n".to_string(),
299    );
300
301    // Encrypting key with legacy encryption method
302    let encrypted_data = encrypt_vault(keyfile_data, password.as_str())
303        .map_err(|err| KeyFileError::EncryptionError(format!("{}", err)))?;
304
305    Ok(encrypted_data.into_bytes())
306}
307
308/// Retrieves the cold key password from the environment variables.
309///
310/// ```text
311///     Arguments:
312///         `coldkey_name` - The name of the cold key.
313///     Returns:
314///         `Option<String>` - The password retrieved from the environment variables, or `None` if not found.
315/// ```
316pub fn get_password_from_environment(env_var_name: String) -> Result<Option<String>, KeyFileError> {
317    match env::var(&env_var_name) {
318        Ok(encrypted_password_base64) => {
319            let encrypted_password = general_purpose::STANDARD
320                .decode(&encrypted_password_base64)
321                .map_err(|_| KeyFileError::Base64DecodeError("Invalid Base64".to_string()))?;
322            let decrypted_password = decrypt_password(&encrypted_password, &env_var_name);
323            Ok(Some(decrypted_password))
324        }
325        Err(_) => Ok(None),
326    }
327}
328
329/// decrypt of keyfile_data with secretbox
330fn derive_key(password: &[u8]) -> secretbox::Key {
331    let nacl_salt = pwhash::argon2i13::Salt::from_slice(NACL_SALT).expect("Invalid NACL_SALT.");
332    let mut key = secretbox::Key([0; secretbox::KEYBYTES]);
333    pwhash::argon2i13::derive_key(
334        &mut key.0,
335        password,
336        &nacl_salt,
337        pwhash::argon2i13::OPSLIMIT_SENSITIVE,
338        pwhash::argon2i13::MEMLIMIT_SENSITIVE,
339    )
340    .expect("Failed to derive key for NaCl decryption.");
341    key
342}
343
344/// Encrypts the passed keyfile data using ansible vault.
345///
346/// ```text
347///     Arguments:
348///         keyfile_data (bytes): The bytes to encrypt.
349///         password (str): The password used to encrypt the data. If `None`, asks for user input.
350///     Returns:
351///         encrypted_data (bytes): The encrypted data.
352/// ```
353pub fn encrypt_keyfile_data(
354    keyfile_data: &[u8],
355    password: Option<String>,
356) -> Result<Vec<u8>, KeyFileError> {
357    // get password or ask user
358    let password = match password {
359        Some(pwd) => pwd,
360        None => ask_password(true)?,
361    };
362
363    utils::print("Encrypting...\n".to_string());
364
365    // crate the key with pwhash Argon2i
366    let key = derive_key(password.as_bytes());
367
368    // encrypt the data using SecretBox
369    let nonce = secretbox::gen_nonce();
370    let encrypted_data = secretbox::seal(keyfile_data, &nonce, &key);
371
372    // concatenate with b"$NACL"
373    let mut result = b"$NACL".to_vec();
374    result.extend_from_slice(&nonce.0);
375    result.extend_from_slice(&encrypted_data);
376
377    Ok(result)
378}
379
380/// Decrypts the passed keyfile data using ansible vault.
381///
382/// ```text
383///     Arguments:
384///         keyfile_data (): The bytes to decrypt.
385///         password (str): The password used to decrypt the data. If `None`, asks for user input.
386///         coldkey_name (str): The name of the cold key. If provided, retrieves the password from environment variables.
387///     Returns:
388///         decrypted_data (bytes): The decrypted data.
389/// ```
390pub fn decrypt_keyfile_data(
391    keyfile_data: &[u8],
392    password: Option<String>,
393    password_env_var: Option<String>,
394) -> Result<Vec<u8>, KeyFileError> {
395    // decrypt of keyfile_data with secretbox
396    fn nacl_decrypt(keyfile_data: &[u8], key: &secretbox::Key) -> Result<Vec<u8>, KeyFileError> {
397        let data = &keyfile_data[5..]; // Remove the $NACL prefix
398        let nonce = secretbox::Nonce::from_slice(&data[0..secretbox::NONCEBYTES]).ok_or(
399            KeyFileError::InvalidEncryption("Invalid nonce.".to_string()),
400        )?;
401        let ciphertext = &data[secretbox::NONCEBYTES..];
402        secretbox::open(ciphertext, &nonce, key).map_err(|_| {
403            KeyFileError::DecryptionError("Wrong password for nacl decryption.".to_string())
404        })
405    }
406    // decrypt of keyfile_data with legacy way
407    fn legacy_decrypt(password: &str, keyfile_data: &[u8]) -> Result<Vec<u8>, KeyFileError> {
408        let kdf = pbkdf2::pbkdf2_hmac::<sha2::Sha256>;
409        let mut key = vec![0; 32];
410        kdf(password.as_bytes(), LEGACY_SALT, 10000000, &mut key);
411
412        let fernet_key = Fernet::generate_key();
413        let fernet = Fernet::new(&fernet_key).unwrap();
414        let keyfile_data_str = from_utf8(keyfile_data)
415            .map_err(|e| KeyFileError::DeserializationError(e.to_string()))?;
416        fernet.decrypt(keyfile_data_str).map_err(|_| {
417            KeyFileError::DecryptionError("Wrong password for legacy decryption.".to_string())
418        })
419    }
420
421    let mut password = password;
422
423    // Retrieve password from environment variable if env_var_name is provided
424    if let Some(env_var_name_) = password_env_var {
425        if password.is_none() {
426            password = get_password_from_environment(env_var_name_)?;
427        }
428    }
429
430    // If password is still None, ask the user for input
431    if password.is_none() {
432        password = Some(ask_password(false)?);
433    }
434
435    let password = password.unwrap();
436
437    utils::print("Decrypting...\n".to_string());
438    // NaCl decryption
439    if keyfile_data_is_encrypted_nacl(keyfile_data) {
440        let key = derive_key(password.as_bytes());
441        let decrypted_data = nacl_decrypt(keyfile_data, &key).map_err(|_| {
442            KeyFileError::DecryptionError("Wrong password for decryption.".to_string())
443        })?;
444        return Ok(decrypted_data);
445    }
446
447    // Ansible Vault decryption
448    if keyfile_data_is_encrypted_ansible(keyfile_data) {
449        let decrypted_data = decrypt_vault(keyfile_data, password.as_str()).map_err(|_| {
450            KeyFileError::DecryptionError("Wrong password for decryption.".to_string())
451        })?;
452        return Ok(decrypted_data);
453    }
454
455    // Legacy decryption
456    if keyfile_data_is_encrypted_legacy(keyfile_data) {
457        let decrypted_data = legacy_decrypt(&password, keyfile_data).map_err(|_| {
458            KeyFileError::DecryptionError("Wrong password for decryption.".to_string())
459        })?;
460        return Ok(decrypted_data);
461    }
462
463    // If none of the methods work, raise error
464    Err(KeyFileError::InvalidEncryption(
465        "Invalid or unknown encryption method.".to_string(),
466    ))
467}
468
469fn confirm_prompt(question: &str) -> bool {
470    let choice = utils::prompt(format!("{} (y/N): ", question)).expect("Failed to read input.");
471    choice.trim().to_lowercase() == "y"
472}
473
474fn expand_tilde(path: &str) -> String {
475    if path.starts_with("~/") {
476        if let Some(home_dir) = dirs::home_dir() {
477            return path.replacen('~', home_dir.to_str().unwrap(), 1);
478        }
479    }
480    path.to_string()
481}
482
// XOR-obfuscates `value` with the repeating bytes of `key`.
// NOTE: this is reversible obfuscation, not encryption; anyone who knows `key` can undo it.
// Panics (modulo by zero) if `key` is empty while `value` is not.
fn encrypt_password(key: &str, value: &str) -> Vec<u8> {
    let key_bytes = key.as_bytes();
    let key_len = key_bytes.len();
    let mut obfuscated = Vec::with_capacity(value.len());
    for (offset, byte) in value.bytes().enumerate() {
        obfuscated.push(byte ^ key_bytes[offset % key_len]);
    }
    obfuscated
}
493
// Reverses `encrypt_password`: XORs `data` with the repeating bytes of `key`.
// Returns an empty string if the result is not valid UTF-8.
// Panics (modulo by zero) if `key` is empty while `data` is not.
fn decrypt_password(data: &[u8], key: &str) -> String {
    let key_bytes = key.as_bytes();
    let mut plain = Vec::with_capacity(data.len());
    for (pos, &byte) in data.iter().enumerate() {
        plain.push(byte ^ key_bytes[pos % key_bytes.len()]);
    }
    match String::from_utf8(plain) {
        Ok(text) => text,
        Err(_) => String::new(),
    }
}
504
/// Handle to a keyfile on disk.
///
/// Stores the path as the caller gave it, its `~`-expanded form, a display name, and whether
/// the encryption password should be saved to an environment variable.
#[derive(Clone)]
pub struct Keyfile {
    /// Path exactly as passed to `Keyfile::new` (may start with `~/`).
    pub path: String,
    /// Display name; `Keyfile::new` defaults it to "Keyfile".
    name: String,
    /// `path` after `~/` expansion; used by filesystem helpers such as `exists_on_device`.
    _path: PathBuf,
    /// Whether `set_keypair` saves the encryption password to an environment variable.
    should_save_to_env: bool,
}
512impl std::fmt::Display for Keyfile {
513    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
514        match self.__str__() {
515            Ok(s) => write!(f, "{}", s),
516            Err(e) => write!(f, "Error displaying keyfile: {}", e),
517        }
518    }
519}
520
521impl Keyfile {
522    /// Creates a new Keyfile instance.
523    ///
524    /// ```text
525    ///     Arguments:
526    ///         path (String): The file system path where the keyfile is stored.
527    ///         name (Option<String>): Optional name for the keyfile. Defaults to "Keyfile" if not provided.
528    ///         should_save_to_env (bool): If ``True``, saves the password to environment variables.
529    ///     Returns:
530    ///         keyfile (Keyfile): A new Keyfile instance.
531    /// ```
532    pub fn new(
533        path: String,
534        name: Option<String>,
535        should_save_to_env: bool,
536    ) -> Result<Self, KeyFileError> {
537        let expanded_path: PathBuf = PathBuf::from(expand_tilde(&path));
538        let name = name.unwrap_or_else(|| "Keyfile".to_string());
539        Ok(Keyfile {
540            path,
541            _path: expanded_path,
542            name,
543            should_save_to_env,
544        })
545    }
546
547    #[allow(clippy::bool_comparison)]
548    fn __str__(&self) -> Result<String, KeyFileError> {
549        if self.exists_on_device()? != true {
550            Ok(format!("keyfile (empty, {})>", self.path))
551        } else if self.is_encrypted()? {
552            let encryption_method = self._read_keyfile_data_from_file()?;
553            Ok(format!(
554                "Keyfile ({:?} encrypted, {})>",
555                encryption_method, self.path
556            ))
557        } else {
558            Ok(format!("keyfile (decrypted, {})>", self.path))
559        }
560    }
561
562    fn __repr__(&self) -> Result<String, KeyFileError> {
563        self.__str__()
564    }
565
566    /// Returns the keypair from path, decrypts data if the file is encrypted.
567    ///
568    /// ```text
569    ///     Arguments:
570    ///         password (Option<String>): The password used to decrypt the data. If ``None``, asks for user input.
571    ///     Returns:
572    ///         keypair (Keypair): The Keypair loaded from the file.
573    /// ```
574    pub fn get_keypair(&self, password: Option<String>) -> Result<Keypair, KeyFileError> {
575        // read file
576        let keyfile_data = self._read_keyfile_data_from_file()?;
577
578        // check if encrypted
579        let decrypted_keyfile_data = if keyfile_data_is_encrypted(&keyfile_data) {
580            decrypt_keyfile_data(&keyfile_data, password, Some(self.env_var_name()?))?
581        } else {
582            keyfile_data
583        };
584
585        // deserialization data into the Keypair
586        deserialize_keypair_from_keyfile_data(&decrypted_keyfile_data)
587    }
588
589    /// Loads the name from keyfile.name or raises an error.
590    pub fn get_name(&self) -> Result<String, KeyFileError> {
591        Ok(self.name.clone())
592    }
593
594    /// Loads the name from keyfile.path or raises an error.
595    pub fn get_path(&self) -> Result<String, KeyFileError> {
596        Ok(self.path.clone())
597    }
598
599    /// Returns the keyfile data under path.
600    pub fn data(&self) -> Result<Vec<u8>, KeyFileError> {
601        self._read_keyfile_data_from_file()
602    }
603
604    /// Returns the keyfile data under path.
605    pub fn keyfile_data(&self) -> Result<Vec<u8>, KeyFileError> {
606        self._read_keyfile_data_from_file()
607    }
608
609    /// Returns local environment variable key name based on Keyfile path.
610    pub fn env_var_name(&self) -> Result<String, KeyFileError> {
611        let path = &self.path.replace([std::path::MAIN_SEPARATOR, '.'], "_");
612        Ok(format!("BT_PW_{}", path.to_uppercase()))
613    }
614
615    /// Writes the keypair to the file and optionally encrypts data.
616    ///
617    /// ```text
618    ///     Arguments:
619    ///         keypair (Keypair): The keypair object to be stored.
620    ///         encrypt (bool): If ``True``, encrypts the keyfile data.
621    ///         overwrite (bool): If ``True``, overwrites existing file without prompting.
622    ///         password (Option<String>): The password used to encrypt the data. If ``None``, asks for user input.
623    /// ```
624    pub fn set_keypair(
625        &self,
626        keypair: Keypair,
627        encrypt: bool,
628        overwrite: bool,
629        password: Option<String>,
630    ) -> Result<(), KeyFileError> {
631        self.make_dirs()?;
632
633        let keyfile_data = serialized_keypair_to_keyfile_data(&keypair)?;
634
635        let final_keyfile_data = if encrypt {
636            let encrypted_data = encrypt_keyfile_data(&keyfile_data, password.clone())?;
637
638            // store password to local env
639            if self.should_save_to_env {
640                self.save_password_to_env(password.clone())?;
641            }
642
643            encrypted_data
644        } else {
645            keyfile_data
646        };
647
648        self._write_keyfile_data_to_file(&final_keyfile_data, overwrite)?;
649
650        Ok(())
651    }
652
653    /// Creates directories for the path if they do not exist.
654    pub fn make_dirs(&self) -> Result<(), KeyFileError> {
655        if let Some(directory) = self._path.parent() {
656            // check if the dir is exit already
657            if !directory.exists() {
658                // create the dir if not
659                fs::create_dir_all(directory)
660                    .map_err(|e| KeyFileError::DirectoryCreation(e.to_string()))?;
661            }
662        }
663        Ok(())
664    }
665
666    /// Returns ``True`` if the file exists on the device.
667    ///
668    /// ```text
669    ///     Returns:
670    ///         readable (bool): ``True`` if the file is readable.
671    /// ```
672    pub fn exists_on_device(&self) -> Result<bool, KeyFileError> {
673        Ok(self._path.exists())
674    }
675
676    /// Returns ``True`` if the file under path is readable.
677    pub fn is_readable(&self) -> Result<bool, KeyFileError> {
678        // check file exist
679        if !self.exists_on_device()? {
680            return Ok(false);
681        }
682
683        // get file metadata
684        let metadata = fs::metadata(&self._path).map_err(|e| {
685            KeyFileError::MetadataError(format!("Failed to get metadata for file: {}.", e))
686        })?;
687
688        // check permissions
689        let permissions = metadata.permissions();
690        let readable = permissions.mode() & 0o444 != 0; // check readability
691
692        Ok(readable)
693    }
694
695    /// Returns ``True`` if the file under path is writable.
696    ///
697    /// ```text
698    ///     Returns:
699    ///         writable (bool): ``True`` if the file is writable.
700    /// ```
701    pub fn is_writable(&self) -> Result<bool, KeyFileError> {
702        // check if file exist
703        if !self.exists_on_device()? {
704            return Ok(false);
705        }
706
707        // get file metadata
708        let metadata = fs::metadata(&self._path).map_err(|e| {
709            KeyFileError::MetadataError(format!("Failed to get metadata for file: {}", e))
710        })?;
711
712        // check the permissions
713        let permissions = metadata.permissions();
714        let writable = permissions.mode() & 0o222 != 0; // check if file is writable
715
716        Ok(writable)
717    }
718
719    /// Returns ``True`` if the file under path is encrypted.
720    ///
721    /// ```text
722    ///     Returns:
723    ///         encrypted (bool): ``True`` if the file is encrypted.
724    /// ```
725    pub fn is_encrypted(&self) -> Result<bool, KeyFileError> {
726        // check if file exist
727        if !self.exists_on_device()? {
728            return Ok(false);
729        }
730
731        // check readable
732        if !self.is_readable()? {
733            return Ok(false);
734        }
735
736        // get the data from file
737        let keyfile_data = self._read_keyfile_data_from_file()?;
738
739        // check if encrypted
740        let is_encrypted = keyfile_data_is_encrypted(&keyfile_data);
741
742        Ok(is_encrypted)
743    }
744
745    /// Asks the user if it is okay to overwrite the file.
746    pub fn _may_overwrite(&self) -> bool {
747        let choice = utils::prompt(format!(
748            "File {} already exists. Overwrite? (y/N) ",
749            self.path
750        ))
751        .expect("Failed to read input.");
752
753        choice.trim().to_lowercase() == "y"
754    }
755
756    /// Check the version of keyfile and update if needed.
757    ///
758    /// ```text
759    ///     Arguments:
760    ///         print_result (bool): If ``True``, prints the result of the encryption check.
761    ///         no_prompt (bool): If ``True``, skips user prompts during the update process.
762    ///     Returns:
763    ///         updated (bool): ``True`` if the keyfile was successfully updated to the latest encryption method.
764    /// ```
765    pub fn check_and_update_encryption(
766        &self,
767        print_result: bool,
768        no_prompt: bool,
769    ) -> Result<bool, KeyFileError> {
770        if !self.exists_on_device()? {
771            if print_result {
772                utils::print(format!("Keyfile '{}' does not exist.\n", self.path));
773            }
774            return Ok(false);
775        }
776
777        if !self.is_readable()? {
778            if print_result {
779                utils::print(format!("Keyfile '{}' is not readable.\n", self.path));
780            }
781            return Ok(false);
782        }
783
784        if !self.is_writable()? {
785            if print_result {
786                utils::print(format!("Keyfile '{}' is not writable.\n", self.path));
787            }
788            return Ok(false);
789        }
790
791        let mut update_keyfile = false;
792        if !no_prompt {
793            // read keyfile
794            let keyfile_data = self._read_keyfile_data_from_file()?;
795
796            // check if file is decrypted
797            if keyfile_data_is_encrypted(&keyfile_data)
798                && !keyfile_data_is_encrypted_nacl(&keyfile_data)
799            {
800                utils::print("You may update the keyfile to improve security...\n".to_string());
801
802                // ask user for the confirmation for updating
803                update_keyfile = confirm_prompt("Update keyfile?");
804                if update_keyfile {
805                    let mut stored_mnemonic = false;
806
807                    // check mnemonic if saved
808                    while !stored_mnemonic {
809                        utils::print(
810                            "Please store your mnemonic in case an error occurs...\n".to_string(),
811                        );
812                        if confirm_prompt("Have you stored the mnemonic?") {
813                            stored_mnemonic = true;
814                        } else if !confirm_prompt("Retry and continue keyfile update?") {
815                            return Ok(false);
816                        }
817                    }
818
819                    // try decrypt data
820                    let mut decrypted_keyfile_data: Option<Vec<u8>> = None;
821                    let mut password: Option<String> = None;
822                    while decrypted_keyfile_data.is_none() {
823                        let pwd = ask_password(false)?;
824                        password = Some(pwd.clone());
825
826                        match decrypt_keyfile_data(
827                            &keyfile_data,
828                            Some(pwd),
829                            Some(self.env_var_name()?),
830                        ) {
831                            Ok(decrypted_data) => {
832                                decrypted_keyfile_data = Some(decrypted_data);
833                            }
834                            Err(_) => {
835                                if !confirm_prompt("Invalid password, retry?") {
836                                    return Ok(false);
837                                }
838                            }
839                        }
840                    }
841
842                    // encryption of updated data
843                    if let Some(password) = password {
844                        if let Some(decrypted_data) = decrypted_keyfile_data {
845                            let encrypted_keyfile_data =
846                                encrypt_keyfile_data(&decrypted_data, Some(password))?;
847                            self._write_keyfile_data_to_file(&encrypted_keyfile_data, true)?;
848                        }
849                    }
850                }
851            }
852        }
853
854        if print_result || update_keyfile {
855            // check and get result
856            let keyfile_data = self._read_keyfile_data_from_file()?;
857
858            return if !keyfile_data_is_encrypted(&keyfile_data) {
859                if print_result {
860                    utils::print("Keyfile is not encrypted.\n".to_string());
861                }
862                Ok(false)
863            } else if keyfile_data_is_encrypted_nacl(&keyfile_data) {
864                if print_result {
865                    utils::print("Keyfile is updated.\n".to_string());
866                }
867                Ok(true)
868            } else {
869                if print_result {
870                    utils::print("Keyfile is outdated, please update using 'btcli'.\n".to_string());
871                }
872                Ok(false)
873            };
874        }
875        Ok(false)
876    }
877
878    /// Encrypts the file under the path.
879    ///
880    /// ```text
881    ///     Arguments:
882    ///         password (Option<String>): The password used to encrypt the data. If ``None``, asks for user input.
883    /// ```
884    pub fn encrypt(&self, mut password: Option<String>) -> Result<(), KeyFileError> {
885        // checkers
886        if !self.exists_on_device()? {
887            return Err(KeyFileError::FileNotFound(format!(
888                "Keyfile at: {} does not exist",
889                self.path
890            )));
891        }
892
893        if !self.is_readable()? {
894            return Err(KeyFileError::NotReadable(format!(
895                "Keyfile at: {} is not readable",
896                self.path
897            )));
898        }
899
900        if !self.is_writable()? {
901            return Err(KeyFileError::NotWritable(format!(
902                "Keyfile at: {} is not writable",
903                self.path
904            )));
905        }
906
907        // read the data
908        let keyfile_data = self._read_keyfile_data_from_file()?;
909
910        let final_data = if !keyfile_data_is_encrypted(&keyfile_data) {
911            let as_keypair = deserialize_keypair_from_keyfile_data(&keyfile_data)?;
912            let serialized_data = serialized_keypair_to_keyfile_data(&as_keypair)?;
913
914            // get password from local env if exist
915            if password.is_none() {
916                password = get_password_from_environment(self.env_var_name()?)?;
917            }
918
919            let encrypted_keyfile_data = encrypt_keyfile_data(&serialized_data, password.clone())?;
920
921            if self.should_save_to_env {
922                self.save_password_to_env(password.clone())?;
923            }
924
925            encrypted_keyfile_data
926        } else {
927            keyfile_data
928        };
929
930        // write back
931        self._write_keyfile_data_to_file(&final_data, true)?;
932
933        Ok(())
934    }
935
936    /// Decrypts the file under the path.
937    ///
938    /// ```text
939    ///     Arguments:
940    ///         password (Option<String>): The password used to decrypt the data. If ``None``, asks for user input.
941    /// ```
942    pub fn decrypt(&self, password: Option<String>) -> Result<(), KeyFileError> {
943        // checkers
944        if !self.exists_on_device()? {
945            return Err(KeyFileError::FileNotFound(format!(
946                "Keyfile at: {} does not exist.",
947                self.path
948            )));
949        }
950        if !self.is_readable()? {
951            return Err(KeyFileError::NotReadable(format!(
952                "Keyfile at: {} is not readable.",
953                self.path
954            )));
955        }
956        if !self.is_writable()? {
957            return Err(KeyFileError::NotWritable(format!(
958                "Keyfile at: {} is not writable.",
959                self.path
960            )));
961        }
962
963        // read data
964        let keyfile_data = self._read_keyfile_data_from_file()?;
965
966        let decrypted_data = if keyfile_data_is_encrypted(&keyfile_data) {
967            decrypt_keyfile_data(&keyfile_data, password, Some(self.env_var_name()?))?
968        } else {
969            keyfile_data
970        };
971
972        let as_keypair = deserialize_keypair_from_keyfile_data(&decrypted_data)?;
973
974        let serialized_data = serialized_keypair_to_keyfile_data(&as_keypair)?;
975        self._write_keyfile_data_to_file(&serialized_data, true)?;
976        Ok(())
977    }
978
979    /// Reads the keyfile data from the file.
980    ///
981    /// ```text
982    ///     Returns:
983    ///         keyfile_data (Vec<u8>): The keyfile data stored under the path.
984    ///     Raises:
985    ///         KeyFileError: Raised if the file does not exist or is not readable.
986    /// ```
987    pub fn _read_keyfile_data_from_file(&self) -> Result<Vec<u8>, KeyFileError> {
988        // Check if the file exists
989        if !self.exists_on_device()? {
990            return Err(KeyFileError::FileNotFound(format!(
991                "Keyfile at: {} does not exist.",
992                self.path
993            )));
994        }
995
996        // Check if the file is readable
997        if !self.is_readable()? {
998            return Err(KeyFileError::NotReadable(format!(
999                "Keyfile at: {} is not readable.",
1000                self.path
1001            )));
1002        }
1003
1004        // Open and read the file
1005        let mut file = fs::File::open(&self._path)
1006            .map_err(|e| KeyFileError::FileOpen(format!("Failed to open file: {}.", e)))?;
1007        let mut data_vec = Vec::new();
1008        file.read_to_end(&mut data_vec)
1009            .map_err(|e| KeyFileError::FileRead(format!("Failed to read file: {}.", e)))?;
1010
1011        Ok(data_vec)
1012    }
1013
1014    /// Writes the keyfile data to the file.
1015    ///
1016    /// ```text
1017    ///     Arguments:
1018    ///         keyfile_data: The byte data to store under the path.
1019    ///         overwrite: If true, overwrites the data without asking for permission from the user. Default is false.
1020    /// ```
1021    pub fn _write_keyfile_data_to_file(
1022        &self,
1023        keyfile_data: &[u8],
1024        overwrite: bool,
1025    ) -> Result<(), KeyFileError> {
1026        // ask user for rewriting
1027        if self.exists_on_device()? && !overwrite && !self._may_overwrite() {
1028            return Err(KeyFileError::NotWritable(format!(
1029                "Keyfile at: {} is not writable",
1030                self.path
1031            )));
1032        }
1033
1034        let mut keyfile = fs::OpenOptions::new()
1035            .write(true)
1036            .create(true)
1037            .truncate(true) // cleanup if rewrite
1038            .open(&self._path)
1039            .map_err(|e| KeyFileError::FileOpen(format!("Failed to open file: {}.", e)))?;
1040
1041        // write data
1042        keyfile
1043            .write_all(keyfile_data)
1044            .map_err(|e| KeyFileError::FileWrite(format!("Failed to write to file: {}.", e)))?;
1045
1046        // set permissions
1047        let mut permissions = fs::metadata(&self._path)
1048            .map_err(|e| {
1049                KeyFileError::MetadataError(format!("Failed to get metadata for file: {}.", e))
1050            })?
1051            .permissions();
1052        permissions.set_mode(0o600); // just for owner
1053        fs::set_permissions(&self._path, permissions).map_err(|e| {
1054            KeyFileError::PermissionError(format!("Failed to set permissions: {}.", e))
1055        })?;
1056        Ok(())
1057    }
1058
1059    /// Saves the key's password to the associated local environment variable.
1060    ///
1061    /// ```text
1062    ///     Arguments:
1063    ///         password (Option<String>): The password to save. If ``None``, asks for user input.
1064    ///     Returns:
1065    ///         encrypted_password_base64 (str): The base64-encoded encrypted password.
1066    /// ```
1067    pub fn save_password_to_env(&self, password: Option<String>) -> Result<String, KeyFileError> {
1068        // checking the password
1069        let password = match password {
1070            Some(pwd) => pwd,
1071            None => match ask_password(true) {
1072                Ok(pwd) => pwd,
1073                Err(e) => {
1074                    utils::print(format!("Error asking password: {:?}.\n", e));
1075                    return Ok("".to_string());
1076                }
1077            },
1078        };
1079        // saving password
1080        let env_var_name = self.env_var_name()?;
1081        // encrypt password
1082        let encrypted_password = encrypt_password(&env_var_name, &password);
1083        let encrypted_password_base64 = general_purpose::STANDARD.encode(&encrypted_password);
1084        // store encrypted password
1085        env::set_var(&env_var_name, &encrypted_password_base64);
1086        Ok(encrypted_password_base64)
1087    }
1088
1089    /// Removes the password associated with the Keyfile from the local environment.
1090    pub fn remove_password_from_env(&self) -> Result<bool, KeyFileError> {
1091        let env_var_name = self.env_var_name()?;
1092
1093        if env::var(&env_var_name).is_ok() {
1094            env::remove_var(&env_var_name);
1095            let message = format!("Environment variable '{}' removed.\n", env_var_name);
1096            utils::print(message);
1097            Ok(true)
1098        } else {
1099            let message = format!("Environment variable '{}' does not exist.\n", env_var_name);
1100            utils::print(message);
1101            Ok(false)
1102        }
1103    }
1104}
1105
#[cfg(test)]
mod tests {
    //! Keyfile (de)serialization tests covering both ed25519 and sr25519 keypairs.

    use super::*;
    use crate::constants::{CRYPTO_ED25519, CRYPTO_SR25519};
    use crate::keypair::Keypair;

    /// Fixed 12-word mnemonic shared by every test in this module.
    fn test_mnemonic() -> String {
        String::from(
            "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about",
        )
    }

    /// Serializes `keypair` into keyfile bytes and parses those bytes back into a keypair.
    fn roundtrip(keypair: &Keypair) -> Keypair {
        let bytes = serialized_keypair_to_keyfile_data(keypair).unwrap();
        deserialize_keypair_from_keyfile_data(&bytes).unwrap()
    }

    /// Serializes `keypair` and returns the resulting keyfile JSON as text.
    fn keyfile_text(keypair: &Keypair) -> String {
        let bytes = serialized_keypair_to_keyfile_data(keypair).unwrap();
        String::from_utf8(bytes).unwrap()
    }

    /// Signs `message` with `keypair` and asserts the same keypair verifies the signature.
    fn assert_self_verifies(keypair: &Keypair, message: &[u8]) {
        let signature = keypair.sign(message.to_vec()).unwrap();
        assert!(keypair.verify(message.to_vec(), signature).unwrap());
    }

    #[test]
    fn test_ed25519_keyfile_roundtrip() {
        let original = Keypair::create_from_mnemonic(&test_mnemonic(), CRYPTO_ED25519).unwrap();
        let restored = roundtrip(&original);

        assert_eq!(restored.crypto_type(), CRYPTO_ED25519);
        assert_eq!(restored.ss58_address(), original.ss58_address());
        assert_self_verifies(&restored, b"test");
    }

    #[test]
    fn test_sr25519_keyfile_roundtrip() {
        let original = Keypair::create_from_mnemonic(&test_mnemonic(), CRYPTO_SR25519).unwrap();
        let restored = roundtrip(&original);

        assert_eq!(restored.crypto_type(), CRYPTO_SR25519);
        assert_eq!(restored.ss58_address(), original.ss58_address());
        assert_self_verifies(&restored, b"test");
    }

    #[test]
    fn test_legacy_keyfile_without_crypto_type_defaults_sr25519() {
        // The JSON below has no "cryptoType" key; sr25519 must be assumed.
        let legacy = r#"{"secretPhrase":"abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about","ss58Address":"5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY"}"#;
        let keypair = deserialize_keypair_from_keyfile_data(legacy.as_bytes()).unwrap();
        assert_eq!(keypair.crypto_type(), CRYPTO_SR25519);
    }

    #[test]
    fn test_keyfile_json_contains_crypto_type() {
        let ed = Keypair::create_from_mnemonic(&test_mnemonic(), CRYPTO_ED25519).unwrap();
        assert!(keyfile_text(&ed).contains("\"cryptoType\":0"));

        let sr = Keypair::create_from_mnemonic(&test_mnemonic(), CRYPTO_SR25519).unwrap();
        assert!(keyfile_text(&sr).contains("\"cryptoType\":1"));
    }

    #[test]
    fn test_ed25519_keyfile_roundtrip_via_seed() {
        let original = Keypair::create_from_seed(vec![0xffu8; 32], CRYPTO_ED25519).unwrap();
        let restored = roundtrip(&original);

        assert_eq!(restored.crypto_type(), CRYPTO_ED25519);
        assert_eq!(restored.ss58_address(), original.ss58_address());
    }

    #[test]
    fn test_keyfile_explicit_crypto_type_in_json() {
        let json = format!(r#"{{"secretPhrase":"{}","cryptoType":0}}"#, test_mnemonic());
        let keypair = deserialize_keypair_from_keyfile_data(json.as_bytes()).unwrap();

        assert_eq!(keypair.crypto_type(), CRYPTO_ED25519);
        assert!(keypair.ss58_address().is_some());
    }

    #[test]
    fn test_ed25519_keyfile_cross_verification_after_roundtrip() {
        let original = Keypair::create_from_mnemonic(&test_mnemonic(), CRYPTO_ED25519).unwrap();
        let restored = roundtrip(&original);

        // A signature made by the restored keypair must verify under the original one.
        let signature = restored.sign(b"cross-check".to_vec()).unwrap();
        assert!(original.verify(b"cross-check".to_vec(), signature).unwrap());
    }
}