use std::str::FromStr as _;
use ssh_encoding::{self, CheckedSum, Decode, Encode, Reader, Writer};
use ssh_key::public::KeyData;
use ssh_key::{certificate::Certificate, private::KeypairData, Algorithm};
use crate::proto::{Error, PrivateKeyData, Result};
/// Private credential material, either a bare key pair or a certificate
/// accompanied by private key data.
///
/// On the wire both variants start with a key-type string. For `Key` it is
/// the key algorithm name. For `Cert` it is the certificate type name derived
/// from `algorithm`. The type string is followed by the variant's fields in
/// declaration order.
#[derive(Clone, PartialEq, Debug)]
pub enum PrivateCredential {
    /// A plain key pair with a comment.
    Key {
        /// Key pair data; decoded with `KeypairData::decode_as` for the
        /// algorithm named by the leading key-type string.
        privkey: KeypairData,
        /// Comment string, encoded after the key material.
        comment: String,
    },
    /// A certificate plus private key data for the certificate's algorithm.
    Cert {
        /// Key algorithm parsed from the certificate type name via
        /// `Algorithm::new_certificate`; re-encoded on the wire with
        /// `to_certificate_type()`.
        algorithm: Algorithm,
        /// The certificate, carried as a length-prefixed blob on the wire.
        certificate: Box<Certificate>,
        /// Private key fields, decoded for `algorithm`.
        privkey: PrivateKeyData,
        /// Comment string, encoded after the private key fields.
        comment: String,
    },
}
impl Decode for PrivateCredential {
type Error = Error;
fn decode(reader: &mut impl Reader) -> Result<Self> {
let alg = String::decode(reader)?;
let cert_alg = Algorithm::new_certificate(&alg);
if let Ok(algorithm) = cert_alg {
let certificate = reader
.read_prefixed(|reader| {
let cert = Certificate::decode(reader)?;
Ok::<_, Error>(cert)
})?
.into();
let privkey = PrivateKeyData::decode_as(reader, algorithm.clone())?;
let comment = String::decode(reader)?;
Ok(PrivateCredential::Cert {
algorithm,
certificate,
privkey,
comment,
})
} else {
let algorithm = Algorithm::from_str(&alg).map_err(ssh_encoding::Error::from)?;
let privkey = KeypairData::decode_as(reader, algorithm)?;
let comment = String::decode(reader)?;
Ok(PrivateCredential::Key { privkey, comment })
}
}
}
impl Encode for PrivateCredential {
    /// Total encoded size: the sum of each field's encoded length, in wire
    /// order, with overflow reported through `checked_sum`.
    fn encoded_len(&self) -> ssh_encoding::Result<usize> {
        match self {
            Self::Key { privkey, comment } => {
                let key_len = privkey.encoded_len()?;
                let comment_len = comment.encoded_len()?;
                [key_len, comment_len].checked_sum()
            }
            Self::Cert {
                algorithm,
                certificate,
                privkey,
                comment,
            } => {
                let type_len = algorithm.to_certificate_type().encoded_len()?;
                let cert_len = certificate.encoded_len_prefixed()?;
                let key_len = privkey.encoded_len()?;
                let comment_len = comment.encoded_len()?;
                [type_len, cert_len, key_len, comment_len].checked_sum()
            }
        }
    }

    /// Writes the credential in the same field order that `decode` reads it.
    ///
    /// Both variants end with the comment, so each arm writes its
    /// variant-specific fields and hands back the comment to be written last.
    fn encode(&self, writer: &mut impl Writer) -> ssh_encoding::Result<()> {
        let trailing_comment = match self {
            Self::Key { privkey, comment } => {
                privkey.encode(writer)?;
                comment
            }
            Self::Cert {
                algorithm,
                certificate,
                privkey,
                comment,
            } => {
                algorithm.to_certificate_type().encode(writer)?;
                certificate.encode_prefixed(writer)?;
                privkey.encode(writer)?;
                comment
            }
        };
        trailing_comment.encode(writer)
    }
}
/// Public credential: either bare public key data or a certificate.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum PublicCredential {
    /// Plain public key data.
    Key(KeyData),
    /// A certificate; its key data is reachable via `Certificate::public_key`.
    Cert(Box<Certificate>),
}
impl PublicCredential {
pub fn key_data(&self) -> &KeyData {
match self {
Self::Key(key) => key,
Self::Cert(cert) => cert.public_key(),
}
}
}
impl Decode for PublicCredential {
    type Error = Error;

    /// Decodes a public key or certificate blob.
    ///
    /// The algorithm identifier is read first to choose the variant.
    /// `Certificate::decode` and `KeyData::decode` both expect to read that
    /// identifier themselves, so it is re-encoded into a scratch buffer,
    /// followed by every remaining byte of `reader`. The chosen type is then
    /// decoded from that buffer.
    ///
    /// NOTE(review): this consumes *all* bytes left in `reader`. Bytes the
    /// decoded value does not use are silently discarded. This presumably
    /// relies on being called on an already length-delimited blob; confirm
    /// against callers.
    ///
    /// # Errors
    ///
    /// Fails if the identifier, the tail read, or the key/certificate decode
    /// fails.
    fn decode(reader: &mut impl Reader) -> core::result::Result<Self, Self::Error> {
        let alg = String::decode(reader)?;
        let remaining_len = reader.remaining_len();

        // Rebuild the full blob in ONE allocation: the length-prefixed
        // identifier, then the tail read directly into the reserved space.
        // This avoids staging the tail in a temporary Vec and copying it again.
        let mut buf = Vec::with_capacity(4 + alg.len() + remaining_len);
        alg.encode(&mut buf)?;
        let tail_start = buf.len();
        buf.resize(tail_start + remaining_len, 0);
        reader.read(&mut buf[tail_start..])?;

        if Algorithm::new_certificate(&alg).is_ok() {
            let cert = Certificate::decode(&mut &buf[..])?;
            Ok(Self::Cert(Box::new(cert)))
        } else {
            let key = KeyData::decode(&mut &buf[..])?;
            Ok(Self::Key(key))
        }
    }
}
impl Encode for PublicCredential {
fn encoded_len(&self) -> std::result::Result<usize, ssh_encoding::Error> {
match self {
Self::Key(pubkey) => pubkey.encoded_len(),
Self::Cert(certificate) => certificate.encoded_len(),
}
}
fn encode(
&self,
writer: &mut impl ssh_encoding::Writer,
) -> std::result::Result<(), ssh_encoding::Error> {
match self {
Self::Key(pubkey) => pubkey.encode(writer),
Self::Cert(certificate) => certificate.encode(writer),
}
}
}
impl From<KeyData> for PublicCredential {
fn from(value: KeyData) -> Self {
Self::Key(value)
}
}
impl From<Certificate> for PublicCredential {
fn from(value: Certificate) -> Self {
Self::Cert(value.into())
}
}