use anyhow::{bail, Context, Error};
use rpassword::read_password_from_tty;
use serde_derive::{Deserialize, Serialize};
use zeroize::Zeroize;
use std::{env, str::FromStr};
pub const DEFAULT_MASTER_PASS_ENV_VAR: &str = "EXONUM_MASTER_PASS";
/// Secret passphrase protecting the master key.
///
/// The underlying string is zeroized when the value is dropped. `Debug` is
/// implemented by hand so the secret is never printed through `{:?}` —
/// including indirectly, via the derived `Debug` of types that embed a
/// `Passphrase`.
#[derive(Default, Serialize, Deserialize, PartialEq)]
pub struct Passphrase(String);

impl Drop for Passphrase {
    /// Wipes the passphrase contents with `zeroize` before the string is freed.
    fn drop(&mut self) {
        self.0.zeroize()
    }
}

impl std::fmt::Debug for Passphrase {
    /// Redacted representation: the secret itself is never written out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Passphrase(<redacted>)")
    }
}
impl Passphrase {
pub fn new(passphrase: String) -> Self {
Self(passphrase)
}
pub fn read_from_tty(prompt: &str) -> Result<Self, Error> {
Ok(Self(read_password_from_tty(Some(prompt))?))
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn as_bytes(&self) -> &[u8] {
self.0.as_bytes()
}
}
/// Source from which the master key passphrase is obtained.
///
/// NOTE(review): this type is serde-serializable, so variant order/shape is
/// presumably part of the on-disk/wire format — avoid reordering variants.
/// Also note that `CmdLineParameter` serializes the passphrase itself.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub enum PassInputMethod {
    /// Prompt for the passphrase on the terminal.
    Terminal,
    /// Read the passphrase from an environment variable; `None` selects
    /// `DEFAULT_MASTER_PASS_ENV_VAR`.
    EnvVariable(Option<String>),
    /// Use a passphrase supplied inline (parsed from `pass:PASSPHRASE`).
    CmdLineParameter(Passphrase),
}
impl Default for PassInputMethod {
fn default() -> Self {
Self::Terminal
}
}
/// Purpose for which a passphrase is requested; it decides how terminal input
/// is gathered.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PassphraseUsage {
    /// A new passphrase is being chosen: terminal input must be non-empty and
    /// entered twice.
    SettingUp,
    /// An existing passphrase is being supplied: terminal input is read once.
    Using,
}
impl PassInputMethod {
pub fn get_passphrase(self, usage: PassphraseUsage) -> Result<Passphrase, Error> {
match self {
Self::Terminal => {
let prompt = "Enter master key passphrase: ";
match usage {
PassphraseUsage::SettingUp => prompt_passphrase(prompt),
PassphraseUsage::Using => Passphrase::read_from_tty(prompt),
}
}
Self::EnvVariable(name) => {
let variable_name = name.unwrap_or_else(|| DEFAULT_MASTER_PASS_ENV_VAR.to_string());
let passphrase = env::var(&variable_name).with_context(|| {
format!("Failed to get password from env variable {}", variable_name)
})?;
Ok(Passphrase(passphrase))
}
Self::CmdLineParameter(pass) => Ok(pass),
}
}
}
impl FromStr for PassInputMethod {
    type Err = anyhow::Error;

    /// Parses a passphrase input method from its command-line spelling.
    ///
    /// Accepted forms:
    /// - `""` or `"stdin"`: read from the terminal;
    /// - `"env"`: read from [`DEFAULT_MASTER_PASS_ENV_VAR`];
    /// - `"env:NAME"`: read from the environment variable `NAME`;
    /// - `"pass"`: empty passphrase;
    /// - `"pass:PASSPHRASE"`: use `PASSPHRASE` verbatim.
    ///
    /// Everything after the first `:` is kept intact, so passphrases (and
    /// variable names) may themselves contain `:`. Keywords must match
    /// exactly; e.g. `"password"` is rejected rather than being treated as
    /// an empty passphrase.
    ///
    /// # Errors
    /// Returns an error for any other input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "" => Ok(Self::default()),
            "stdin" => Ok(Self::Terminal),
            "env" => Ok(Self::EnvVariable(None)),
            "pass" => Ok(Self::CmdLineParameter(Passphrase(String::new()))),
            _ => {
                if let Some(name) = s.strip_prefix("env:") {
                    Ok(Self::EnvVariable(Some(name.to_owned())))
                } else if let Some(pass) = s.strip_prefix("pass:") {
                    // Keep the full remainder: splitting on every ':' would
                    // silently truncate passphrases that contain a colon.
                    Ok(Self::CmdLineParameter(Passphrase(pass.to_owned())))
                } else {
                    bail!("Failed to parse passphrase input method")
                }
            }
        }
    }
}
/// Interactively asks for a new passphrase until the user enters a non-empty
/// value and then repeats it identically.
///
/// # Errors
/// Propagates any failure from `Passphrase::read_from_tty`.
fn prompt_passphrase(prompt: &str) -> Result<Passphrase, Error> {
    loop {
        let candidate = Passphrase::read_from_tty(prompt)?;

        if candidate.is_empty() {
            eprintln!("Passphrase must not be empty. Try again.");
        } else if candidate == Passphrase::read_from_tty("Enter same passphrase again: ")? {
            return Ok(candidate);
        } else {
            eprintln!("Passphrases do not match. Try again.");
        }
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::{PassInputMethod, Passphrase};

    /// Every supported spelling parses to the expected input method.
    #[test]
    fn test_pass_input_method_parse() {
        let correct_cases = vec![
            ("", <PassInputMethod as Default>::default()),
            ("", PassInputMethod::Terminal),
            ("stdin", PassInputMethod::Terminal),
            ("env", PassInputMethod::EnvVariable(None)),
            (
                "env:VAR",
                PassInputMethod::EnvVariable(Some("VAR".to_owned())),
            ),
            (
                "pass",
                PassInputMethod::CmdLineParameter(Passphrase("".to_owned())),
            ),
            (
                "pass:PASS",
                PassInputMethod::CmdLineParameter(Passphrase("PASS".to_owned())),
            ),
        ];
        for (inp, out) in correct_cases {
            let method = <PassInputMethod as FromStr>::from_str(inp);
            assert!(method.is_ok(), "{:?} should parse", inp);
            assert_eq!(method.unwrap(), out)
        }
    }

    /// Unknown spellings are rejected with an error instead of falling back
    /// to some default input method.
    #[test]
    fn test_pass_input_method_parse_rejects_unknown() {
        for &inp in &["foo", "terminal", ":pass", "STDIN"] {
            assert!(
                <PassInputMethod as FromStr>::from_str(inp).is_err(),
                "{:?} must be rejected",
                inp
            );
        }
    }
}