cotp 0.1.9

Trustworthy, encrypted, command-line TOTP/HOTP authenticator app with import functionality.
use std::{error, fmt};
use std::convert::TryInto;

use data_encoding::BASE64;
use sodiumoxide::crypto::pwhash;
use sodiumoxide::crypto::secretstream::{KEYBYTES, Stream, Tag};
use sodiumoxide::crypto::secretstream::xchacha20poly1305::{Header, Key};

/// Magic bytes that mark a string produced by `encrypt_string`; they are stored
/// base64-encoded as the first `|`-separated field of the output.
const SIGNATURE: [u8; 4] = [0xc1, 0x0a, 0x4b, 0xed];

/// Error raised by the cryptographic helpers in this module.
///
/// Its `Display` form is the stored message prefixed with `"Error: "`.
#[derive(Debug)]
struct CoreError {
    /// Human-readable description of the failure.
    message: String,
}

impl CoreError {
    /// Creates a `CoreError` that owns a copy of `msg`.
    fn new(msg: &str) -> Self {
        Self {
            message: String::from(msg),
        }
    }
}

impl fmt::Display for CoreError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Error: ")?;
        f.write_str(&self.message)
    }
}

impl error::Error for CoreError {}

/// Derives a secretstream key from `password_bytes` and `salt` with Argon2id, using
/// the `OPSLIMIT_INTERACTIVE` / `MEMLIMIT_INTERACTIVE` cost parameters.
///
/// `key` is the output buffer the derived bytes are written into; the returned
/// [`Key`] is built from a copy of that buffer.
///
/// # Errors
/// Returns `"Failed to derive encryption key"` if the derivation fails.
fn argon_derive_key(key: &mut [u8; 32], password_bytes: &[u8], salt: &pwhash::argon2id13::Salt) -> Result<Key, String> {
    let outcome = pwhash::argon2id13::derive_key(
        key,
        password_bytes,
        salt,
        pwhash::argon2id13::OPSLIMIT_INTERACTIVE,
        pwhash::argon2id13::MEMLIMIT_INTERACTIVE,
    );
    if outcome.is_err() {
        return Err(String::from("Failed to derive encryption key"));
    }
    // `outcome` (which borrows `key`) is no longer used, so `key` can be read again.
    Ok(Key(*key))
}

/// Encrypts `plaintext` with a key derived from `password`.
///
/// A fresh random salt is generated and the key is derived with `argon_derive_key`.
/// The plaintext is then sealed as a single secretstream message. The result has
/// the form `signature|salt|header|ciphertext`, with every field base64-encoded.
///
/// # Panics
/// Panics if key derivation, stream initialisation, or encryption fails.
pub fn encrypt_string(plaintext: String, password: &str) -> String {
    let salt = pwhash::argon2id13::gen_salt();
    let key = argon_derive_key(&mut [0u8; KEYBYTES], password.as_bytes(), &salt).unwrap();
    let (mut push_stream, header) = Stream::init_push(&key).unwrap();
    let ciphertext = push_stream
        .push(plaintext.as_bytes(), None, Tag::Message)
        .expect("Cannot encrypt");

    [
        BASE64.encode(&SIGNATURE),
        BASE64.encode(&salt.0),
        BASE64.encode(&header.0),
        BASE64.encode(&ciphertext),
    ]
    .join("|")
}

pub fn decrypt_string(encrypted_text: &str, password: &str) -> Result<String, String> {
    let split = encrypted_text.split('|');
    let vec: Vec<&str> = split.collect();
    if vec.len() != 4 {
        return Err(String::from("Corrupted database file"));
    }
    let byte_salt = BASE64.decode(vec[1].as_bytes()).unwrap();
    let salt = pwhash::argon2id13::Salt(vec_to_arr(byte_salt));
    let byte_header = BASE64.decode(vec[2].as_bytes()).unwrap();
    let header = Header(vec_to_arr(byte_header));
    let cipher = BASE64.decode(vec[3].as_bytes()).unwrap();

    let mut key = [0u8; KEYBYTES];
    pwhash::argon2id13::derive_key(&mut key, password.as_bytes(), &salt,
                                   pwhash::argon2id13::OPSLIMIT_INTERACTIVE,
                                   pwhash::argon2id13::MEMLIMIT_INTERACTIVE)
        .map_err(|_| CoreError::new("Deriving key failed")).unwrap();
    let key = Key(key);

    let mut stream = Stream::init_pull(&header, &key)
        .map_err(|_| CoreError::new("init_pull failed")).unwrap();

    let (decrypted, _tag) = stream.pull(&cipher, None).unwrap_or((vec![0], Tag::Message));

    if decrypted == vec![0] {
        return Err(String::from("Wrong password"));
    }
    Ok(String::from_utf8(decrypted).unwrap())
}

/// Moves the elements of `v` into a fixed-size array of length `N`.
///
/// # Panics
/// Panics if `v.len() != N`.
fn vec_to_arr<T, const N: usize>(v: Vec<T>) -> [T; N] {
    let converted: Result<[T; N], Vec<T>> = v.try_into();
    match converted {
        Ok(array) => array,
        Err(rejected) => panic!("Expected a Vec of length {} but it was {}", N, rejected.len()),
    }
}

/// Prompts on stdout (with `message`) until the user enters an acceptable password.
///
/// A password is accepted when it has at least `minimum_password_length` characters
/// (counted with `str::chars`). If `verify` is true, the user must also retype it
/// identically at a second prompt. The length is checked before the verification
/// prompt, so the user is never asked to retype a password that would be rejected.
///
/// # Panics
/// Panics if reading a password from the terminal fails.
pub fn prompt_for_passwords(message: &str, minimum_password_length: usize, verify: bool) -> String {
    loop {
        let password = rpassword::prompt_password_stdout(message)
            .expect("Failed to read password from the terminal");

        if password.chars().count() < minimum_password_length {
            println!("Please insert a password with at least {} characters.", minimum_password_length);
            continue;
        }

        if verify {
            let verify_password = rpassword::prompt_password_stdout("Retype the same password: ")
                .expect("Failed to read password from the terminal");
            if password != verify_password {
                println!("Passwords do not match");
                continue;
            }
        }

        break password;
    }
}


#[cfg(test)]
mod tests {
    use super::{decrypt_string, encrypt_string};

    /// Round-trips a non-ASCII string through encryption and decryption.
    #[test]
    fn test_encryption() {
        assert_eq!(Ok(()), sodiumoxide::init());
        let plaintext = "Secret data@#[]ò";
        let encrypted = encrypt_string(String::from(plaintext), "pa$$w0rd");
        assert_eq!(Ok(String::from(plaintext)), decrypt_string(&encrypted, "pa$$w0rd"));
    }

    /// A different password must be reported as an error, not decrypted.
    #[test]
    fn test_wrong_password_is_rejected() {
        assert_eq!(Ok(()), sodiumoxide::init());
        let encrypted = encrypt_string(String::from("Secret data@#[]ò"), "pa$$w0rd");
        assert_eq!(
            Err(String::from("Wrong password")),
            decrypt_string(&encrypted, "not the password")
        );
    }

    /// Input without exactly four `|`-separated fields is rejected as corrupted.
    #[test]
    fn test_wrong_field_count_is_rejected() {
        assert_eq!(
            Err(String::from("Corrupted database file")),
            decrypt_string("a|b|c", "pa$$w0rd")
        );
        assert_eq!(
            Err(String::from("Corrupted database file")),
            decrypt_string("", "pa$$w0rd")
        );
    }
}