use crate::cipher::Cipher;
use crate::error::*;
use crate::keys::{dsa::*, ecdsa::*, ed25519::*, rsa::*, KeyPair, PublicParts};
use crate::sshbuf::{SshBuf, SshReadExt, SshWriteExt};
use bcrypt_pbkdf::bcrypt_pbkdf;
use byteorder::WriteBytesExt;
use cryptovec::CryptoVec;
use openssl::bn::BigNum;
use openssl::dsa::Dsa;
use openssl::rsa::Rsa;
use rand::prelude::*;
use rand::rngs::StdRng;
use std::io::{Cursor, Read, Write};
use std::str::FromStr;
use zeroize::Zeroizing;
/// Magic prefix (including the trailing NUL byte) that starts every
/// `openssh-key-v1` private key blob.
const KEY_MAGIC: &[u8] = b"openssh-key-v1\x00";
/// KDF name written for (and accepted from) keys protected by a passphrase.
const KDF_BCRYPT: &str = "bcrypt";
/// KDF name written for unencrypted keys.
const KDF_NONE: &str = "none";
/// bcrypt-pbkdf round count used when the caller asks for 0 rounds.
const DEFAULT_ROUNDS: u32 = 16;
/// Length in bytes of the random salt generated when encrypting a key.
const SALT_LEN: usize = 16;
/// Parses a binary `openssh-key-v1` private key blob into a [`KeyPair`].
///
/// `passphrase` is forwarded to [`decrypt_ossh_priv`] and is only needed for
/// encrypted keys. Exactly one key per blob is supported.
///
/// # Errors
/// * `InvalidKeyFormat` - missing/short magic, a key count other than 1, or
///   malformed trailing padding
/// * `IncorrectPass` - the two check words of the private section differ
/// * any error from header parsing, decryption or key decoding
pub fn decode_ossh_priv(keydata: &[u8], passphrase: Option<&str>) -> OsshResult<KeyPair> {
    if keydata.len() < 16 || !keydata.starts_with(KEY_MAGIC) {
        return Err(ErrorKind::InvalidKeyFormat.into());
    }

    let mut header = Cursor::new(keydata);
    header.set_position(KEY_MAGIC.len() as u64);
    let ciphername = header.read_utf8()?;
    let kdfname = header.read_utf8()?;
    let kdf = header.read_string()?;
    if header.read_uint32()? != 1 {
        return Err(ErrorKind::InvalidKeyFormat.into());
    }
    // The public key blob is skipped; the key is rebuilt from the private section.
    header.read_string()?;
    let encrypted = header.read_string()?;

    let mut secret = decrypt_ossh_priv(&encrypted, passphrase, &ciphername, &kdfname, &kdf)?;
    // The two check words are equal only if decryption produced the right plaintext.
    let check_a = Zeroizing::new(secret.read_uint32()?);
    let check_b = Zeroizing::new(secret.read_uint32()?);
    if *check_a != *check_b {
        return Err(ErrorKind::IncorrectPass.into());
    }

    let mut keypair = decode_key(&mut secret)?;
    *keypair.comment_mut() = secret.read_utf8()?;

    // Whatever remains must be the deterministic padding 1, 2, 3, ...
    for (expected, byte) in (1usize..).zip(secret.bytes()) {
        if (expected & 0xff) as u8 != byte? {
            return Err(ErrorKind::InvalidKeyFormat.into());
        }
    }
    Ok(keypair)
}
pub fn decrypt_ossh_priv(
privkey_data: &[u8],
passphrase: Option<&str>,
ciphername: &str,
kdfname: &str,
kdf: &[u8],
) -> OsshResult<SshBuf> {
let cipher = Cipher::from_str(ciphername)?;
if (!passphrase.map_or(false, |pass| !pass.is_empty())) && !cipher.is_null() {
return Err(ErrorKind::IncorrectPass.into());
}
if kdfname != "none" && kdfname != "bcrypt" {
return Err(ErrorKind::UnsupportCipher.into());
}
if kdfname == "none" && !cipher.is_null() {
return Err(ErrorKind::InvalidKeyFormat.into());
}
let blocksize = cipher.block_size();
if privkey_data.len() < blocksize || privkey_data.len() % blocksize != 0 {
return Err(ErrorKind::InvalidKeyFormat.into());
}
if !cipher.is_null() {
let keyder = match kdfname {
"bcrypt" => {
if let Some(pass) = passphrase {
let mut kdfreader = Cursor::new(kdf);
let salt = kdfreader.read_string()?;
let round = kdfreader.read_uint32()?;
let mut output = Zeroizing::new(vec![0u8; cipher.key_len() + cipher.iv_len()]);
bcrypt_pbkdf(pass, &salt, round, &mut output)?;
output
} else {
return Err(ErrorKind::Unknown.into());
}
}
_ => {
return Err(ErrorKind::UnsupportCipher.into());
}
};
let key = &keyder[..cipher.key_len()];
let iv = &keyder[cipher.key_len()..];
let mut cvec = CryptoVec::new();
cvec.resize(cipher.calc_buffer_len(privkey_data.len()));
let n = cipher.decrypt_to(&mut cvec, privkey_data, key, iv)?;
cvec.resize(n);
Ok(SshBuf::with_vec(cvec))
} else {
let cvec = CryptoVec::from_slice(privkey_data);
Ok(SshBuf::with_vec(cvec))
}
}
#[allow(clippy::many_single_char_names)]
fn decode_key(reader: &mut SshBuf) -> OsshResult<KeyPair> {
let keystring = Zeroizing::new(reader.read_utf8()?);
let keyname: &str = keystring.as_str();
let key = match keyname {
RSA_NAME | RSA_SHA256_NAME | RSA_SHA512_NAME => {
let n = reader.read_mpint()?;
let e = reader.read_mpint()?;
let d = reader.read_mpint()?;
let iqmp = reader.read_mpint()?;
let p = reader.read_mpint()?;
let q = reader.read_mpint()?;
let one = BigNum::from_u32(1)?;
let dmp1 = &d % &(&p - &one);
let dmq1 = &d % &(&q - &one);
let rsa = Rsa::from_private_components(n, e, d, p, q, dmp1, dmq1, iqmp)?;
match keyname {
RSA_NAME => RsaKeyPair::from_ossl_rsa(rsa, RsaSignature::SHA1),
RSA_SHA256_NAME => RsaKeyPair::from_ossl_rsa(rsa, RsaSignature::SHA2_256),
RSA_SHA512_NAME => RsaKeyPair::from_ossl_rsa(rsa, RsaSignature::SHA2_512),
_ => unreachable!(),
}?
.into()
}
DSA_NAME => {
let p = reader.read_mpint()?;
let q = reader.read_mpint()?;
let g = reader.read_mpint()?;
let pubkey = reader.read_mpint()?;
let privkey = reader.read_mpint()?;
let dsa = Dsa::from_private_components(p, q, g, privkey, pubkey)?;
DsaKeyPair::from_ossl_dsa(dsa).into()
}
NIST_P256_NAME | NIST_P384_NAME | NIST_P521_NAME => {
let curvename = Zeroizing::new(reader.read_utf8()?);
let curvehint = EcCurve::from_name(keyname)?;
let curve = EcCurve::from_str(&curvename)?;
if curve != curvehint {
return Err(ErrorKind::TypeNotMatch.into());
}
let pubkey = Zeroizing::new(reader.read_string()?);
let mut privkey = reader.read_mpint()?;
let keypair = EcDsaKeyPair::from_bytes(curve, &pubkey, &privkey)?.into();
privkey.clear(); keypair
}
ED25519_NAME => {
let pk = Zeroizing::new(reader.read_string()?);
let sk = Zeroizing::new(reader.read_string()?); Ed25519KeyPair::from_bytes(&pk, &sk)?.into()
}
_ => return Err(ErrorKind::UnsupportType.into()),
};
Ok(key)
}
/// Serializes `key` to armored text.
///
/// The binary container produced by [`encode_ossh_priv`] is base64 encoded,
/// wrapped at 70 columns and framed by the
/// `-----BEGIN/END OPENSSH PRIVATE KEY-----` lines.
///
/// All parameters are forwarded unchanged to [`encode_ossh_priv`], and its
/// errors are propagated.
pub fn serialize_ossh_privkey(
    key: &KeyPair,
    passphrase: &str,
    cipher: Cipher,
    kdf_rounds: u32,
) -> OsshResult<String> {
    // Width of each base64 line in the armored body.
    const LINE_WIDTH: usize = 70;

    let encoded = base64::encode(&encode_ossh_priv(key, passphrase, cipher, kdf_rounds)?);
    let mut armored = String::from("-----BEGIN OPENSSH PRIVATE KEY-----\n");
    let mut column = 0;
    for ch in encoded.chars() {
        if column == LINE_WIDTH {
            armored.push('\n');
            column = 0;
        }
        armored.push(ch);
        column += 1;
    }
    armored.push_str("\n-----END OPENSSH PRIVATE KEY-----\n");
    Ok(armored)
}
/// Builds the binary `openssh-key-v1` container for `key`.
///
/// For a non-null `cipher`, the private section is encrypted via
/// [`encrypt_ossh_priv`]. The key comes from `passphrase` through bcrypt-pbkdf,
/// using a fresh random salt of `SALT_LEN` bytes and `kdf_rounds` rounds
/// (`DEFAULT_ROUNDS` when `kdf_rounds` is 0). For the null cipher, the private
/// section is stored as-is and `passphrase` is ignored.
///
/// # Errors
/// * `IncorrectPass` - a cipher is requested but `passphrase` is empty
/// * any write, key-blob, encoding or encryption error
pub fn encode_ossh_priv(
    key: &KeyPair,
    passphrase: &str,
    cipher: Cipher,
    kdf_rounds: u32,
) -> OsshResult<Vec<u8>> {
    let encrypting = cipher.is_some();
    if encrypting && passphrase.is_empty() {
        return Err(ErrorKind::IncorrectPass.into());
    }
    let rounds = match kdf_rounds {
        0 => DEFAULT_ROUNDS,
        explicit => explicit,
    };

    // Header: magic, cipher name, KDF name + options, key count, public key blob.
    let mut salt = Zeroizing::from([0u8; SALT_LEN]);
    let mut out = Vec::new();
    out.write_all(KEY_MAGIC)?;
    out.write_utf8(cipher.name())?;
    if encrypting {
        out.write_utf8(KDF_BCRYPT)?;
        StdRng::from_entropy().fill_bytes(&mut *salt);
        let mut kdf_options = Vec::with_capacity(SALT_LEN + 8);
        kdf_options.write_string(&*salt)?;
        kdf_options.write_uint32(rounds)?;
        out.write_string(&kdf_options)?;
    } else {
        out.write_utf8(KDF_NONE)?;
        out.write_string(&[0u8; 0])?;
    }
    out.write_uint32(1)?;
    out.write_string(&key.blob()?)?;

    // Private section: duplicated random check word, key data, comment, padding.
    let mut private = SshBuf::new();
    let check: u32 = StdRng::from_entropy().gen();
    private.write_uint32(check)?;
    private.write_uint32(check)?;
    encode_key(key, &mut private)?;
    private.write_utf8(key.comment())?;
    let block_size = cipher.block_size();
    let pad_len = (block_size - private.len() % block_size) % block_size;
    for pad in 1..=pad_len {
        private.write_u8((pad & 0xff) as u8)?;
    }

    if encrypting {
        let encrypted = encrypt_ossh_priv(private.as_slice(), passphrase, cipher, rounds, &*salt)?;
        out.write_string(&encrypted)?;
    } else {
        out.write_string(private.as_slice())?;
    }
    Ok(out)
}
/// Encrypts an already-padded private section with `cipher`.
///
/// The key and IV come from bcrypt-pbkdf over `passphrase` and `salt` with
/// `kdf_rounds` rounds. The derived bytes are split into the cipher key
/// followed by the IV, and are zeroized when dropped.
///
/// # Errors
/// * `IncorrectPass` - `passphrase` is empty
/// * any bcrypt-pbkdf or encryption error
pub fn encrypt_ossh_priv(
    privkey: &[u8],
    passphrase: &str,
    cipher: Cipher,
    kdf_rounds: u32,
    salt: &[u8],
) -> OsshResult<Vec<u8>> {
    if passphrase.is_empty() {
        return Err(ErrorKind::IncorrectPass.into());
    }
    let key_len = cipher.key_len();
    let mut derived = Zeroizing::new(vec![0u8; key_len + cipher.iv_len()]);
    bcrypt_pbkdf(passphrase, salt, kdf_rounds, &mut derived)?;
    let (key, iv) = derived.split_at(key_len);
    let ciphertext = cipher.encrypt(privkey, key, iv)?;
    Ok(ciphertext)
}
#[allow(clippy::many_single_char_names)]
fn encode_key<W: Write + ?Sized>(key: &KeyPair, buf: &mut W) -> OsshResult<()> {
use crate::keys::Key;
use crate::keys::KeyPairType;
use openssl::bn::BigNumContext;
use openssl::ec::PointConversionForm;
buf.write_utf8(key.keyname())?;
match &key.key {
KeyPairType::RSA(rsa) => {
let inner = rsa.ossl_rsa();
buf.write_mpint(inner.n())?;
buf.write_mpint(inner.e())?;
buf.write_mpint(inner.d())?;
buf.write_mpint(inner.iqmp().unwrap())?;
buf.write_mpint(inner.p().unwrap())?;
buf.write_mpint(inner.q().unwrap())?;
}
KeyPairType::DSA(dsa) => {
let inner = dsa.ossl_dsa();
buf.write_mpint(inner.p())?;
buf.write_mpint(inner.q())?;
buf.write_mpint(inner.g())?;
buf.write_mpint(inner.pub_key())?;
buf.write_mpint(inner.priv_key())?;
}
KeyPairType::ECDSA(ecdsa) => {
buf.write_utf8(ecdsa.curve().ident())?;
let inner = ecdsa.ossl_ec();
let mut bn_ctx = BigNumContext::new()?;
buf.write_string(&inner.public_key().to_bytes(
inner.group(),
PointConversionForm::UNCOMPRESSED,
&mut bn_ctx,
)?)?;
buf.write_mpint(inner.private_key())?;
}
KeyPairType::ED25519(ed25519) => {
buf.write_string(&ed25519.key.public.to_bytes())?;
buf.write_string(&ed25519.key.to_bytes())?; }
}
Ok(())
}