use crate::{Error, Result};
use core::{fmt, str};
use encoding::Label;
#[cfg(feature = "alloc")]
use {
alloc::vec::Vec,
sha2::{Digest, Sha256, Sha512},
};
// Plain public-key algorithm identifiers, parsed by `Algorithm::new` and
// produced by `Algorithm::as_str`.
const SSH_DSA: &str = "ssh-dss";
const SSH_ED25519: &str = "ssh-ed25519";
const SSH_RSA: &str = "ssh-rsa";
const RSA_SHA2_256: &str = "rsa-sha2-256";
const RSA_SHA2_512: &str = "rsa-sha2-512";
const ECDSA_SHA2_P256: &str = "ecdsa-sha2-nistp256";
const ECDSA_SHA2_P384: &str = "ecdsa-sha2-nistp384";
const ECDSA_SHA2_P521: &str = "ecdsa-sha2-nistp521";
const SK_ECDSA_SHA2_P256: &str = "sk-ecdsa-sha2-nistp256@openssh.com";
const SK_SSH_ED25519: &str = "sk-ssh-ed25519@openssh.com";

// Certificate algorithm identifiers, parsed by `Algorithm::new_certificate`
// and produced by `Algorithm::as_certificate_str`.
const CERT_DSA: &str = "ssh-dss-cert-v01@openssh.com";
const CERT_ED25519: &str = "ssh-ed25519-cert-v01@openssh.com";
const CERT_RSA: &str = "ssh-rsa-cert-v01@openssh.com";
const CERT_ECDSA_SHA2_P256: &str = "ecdsa-sha2-nistp256-cert-v01@openssh.com";
const CERT_ECDSA_SHA2_P384: &str = "ecdsa-sha2-nistp384-cert-v01@openssh.com";
const CERT_ECDSA_SHA2_P521: &str = "ecdsa-sha2-nistp521-cert-v01@openssh.com";
const CERT_SK_ECDSA_SHA2_P256: &str = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com";
const CERT_SK_SSH_ED25519: &str = "sk-ssh-ed25519-cert-v01@openssh.com";

// Hash algorithm identifiers (`HashAlg`).
const SHA256: &str = "sha256";
const SHA512: &str = "sha512";

// Key-derivation function identifiers (`KdfAlg`).
const NONE: &str = "none";
const BCRYPT: &str = "bcrypt";
/// Public-key algorithm.
///
/// The variant order is significant: the derived `PartialOrd`/`Ord`
/// compare variants by declaration order.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum Algorithm {
    /// DSA (`ssh-dss`).
    Dsa,
    /// ECDSA over the given NIST curve (`ecdsa-sha2-<curve>`).
    Ecdsa { curve: EcdsaCurve },
    /// Ed25519 (`ssh-ed25519`).
    Ed25519,
    /// RSA. `hash: None` corresponds to `ssh-rsa`; `Some(..)` selects `rsa-sha2-*`.
    Rsa { hash: Option<HashAlg> },
    /// Security-key ECDSA on NIST P-256 (`sk-ecdsa-sha2-nistp256@openssh.com`).
    SkEcdsaSha2NistP256,
    /// Security-key Ed25519 (`sk-ssh-ed25519@openssh.com`).
    SkEd25519,
}
impl Algorithm {
pub fn new(id: &str) -> Result<Self> {
match id {
SSH_DSA => Ok(Algorithm::Dsa),
ECDSA_SHA2_P256 => Ok(Algorithm::Ecdsa {
curve: EcdsaCurve::NistP256,
}),
ECDSA_SHA2_P384 => Ok(Algorithm::Ecdsa {
curve: EcdsaCurve::NistP384,
}),
ECDSA_SHA2_P521 => Ok(Algorithm::Ecdsa {
curve: EcdsaCurve::NistP521,
}),
RSA_SHA2_256 => Ok(Algorithm::Rsa {
hash: Some(HashAlg::Sha256),
}),
RSA_SHA2_512 => Ok(Algorithm::Rsa {
hash: Some(HashAlg::Sha512),
}),
SSH_ED25519 => Ok(Algorithm::Ed25519),
SSH_RSA => Ok(Algorithm::Rsa { hash: None }),
SK_ECDSA_SHA2_P256 => Ok(Algorithm::SkEcdsaSha2NistP256),
SK_SSH_ED25519 => Ok(Algorithm::SkEd25519),
_ => Err(Error::Algorithm),
}
}
pub fn new_certificate(id: &str) -> Result<Self> {
match id {
CERT_DSA => Ok(Algorithm::Dsa),
CERT_ECDSA_SHA2_P256 => Ok(Algorithm::Ecdsa {
curve: EcdsaCurve::NistP256,
}),
CERT_ECDSA_SHA2_P384 => Ok(Algorithm::Ecdsa {
curve: EcdsaCurve::NistP384,
}),
CERT_ECDSA_SHA2_P521 => Ok(Algorithm::Ecdsa {
curve: EcdsaCurve::NistP521,
}),
CERT_ED25519 => Ok(Algorithm::Ed25519),
CERT_RSA => Ok(Algorithm::Rsa { hash: None }),
CERT_SK_ECDSA_SHA2_P256 => Ok(Algorithm::SkEcdsaSha2NistP256),
CERT_SK_SSH_ED25519 => Ok(Algorithm::SkEd25519),
_ => Err(Error::Algorithm),
}
}
pub fn as_str(self) -> &'static str {
match self {
Algorithm::Dsa => SSH_DSA,
Algorithm::Ecdsa { curve } => match curve {
EcdsaCurve::NistP256 => ECDSA_SHA2_P256,
EcdsaCurve::NistP384 => ECDSA_SHA2_P384,
EcdsaCurve::NistP521 => ECDSA_SHA2_P521,
},
Algorithm::Ed25519 => SSH_ED25519,
Algorithm::Rsa { hash } => match hash {
None => SSH_RSA,
Some(HashAlg::Sha256) => RSA_SHA2_256,
Some(HashAlg::Sha512) => RSA_SHA2_512,
},
Algorithm::SkEcdsaSha2NistP256 => SK_ECDSA_SHA2_P256,
Algorithm::SkEd25519 => SK_SSH_ED25519,
}
}
pub fn as_certificate_str(self) -> &'static str {
match self {
Algorithm::Dsa => CERT_DSA,
Algorithm::Ecdsa { curve } => match curve {
EcdsaCurve::NistP256 => CERT_ECDSA_SHA2_P256,
EcdsaCurve::NistP384 => CERT_ECDSA_SHA2_P384,
EcdsaCurve::NistP521 => CERT_ECDSA_SHA2_P521,
},
Algorithm::Ed25519 => CERT_ED25519,
Algorithm::Rsa { .. } => CERT_RSA,
Algorithm::SkEcdsaSha2NistP256 => CERT_SK_ECDSA_SHA2_P256,
Algorithm::SkEd25519 => CERT_SK_SSH_ED25519,
}
}
pub fn is_dsa(self) -> bool {
self == Algorithm::Dsa
}
pub fn is_ecdsa(self) -> bool {
matches!(self, Algorithm::Ecdsa { .. })
}
pub fn is_ed25519(self) -> bool {
self == Algorithm::Ed25519
}
pub fn is_rsa(self) -> bool {
matches!(self, Algorithm::Rsa { .. })
}
}
impl AsRef<str> for Algorithm {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl Label for Algorithm {
type Error = Error;
}
impl Default for Algorithm {
fn default() -> Algorithm {
Algorithm::Ed25519
}
}
/// Formats the plain algorithm identifier (e.g. `ssh-ed25519`).
///
/// Uses [`fmt::Formatter::pad`] instead of `write_str`, so width, alignment and
/// precision flags (`{:>24}`, `{:<20}`, `{:.7}`) are honoured the same way they
/// are for a plain `str`. Without flags the output is unchanged.
impl fmt::Display for Algorithm {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
impl str::FromStr for Algorithm {
type Err = Error;
fn from_str(id: &str) -> Result<Self> {
Self::new(id)
}
}
/// NIST elliptic curve used by [`Algorithm::Ecdsa`].
///
/// The derived ordering follows declaration order: P-256 < P-384 < P-521.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
pub enum EcdsaCurve {
    /// NIST P-256 (`nistp256`).
    NistP256,
    /// NIST P-384 (`nistp384`).
    NistP384,
    /// NIST P-521 (`nistp521`).
    NistP521,
}
impl EcdsaCurve {
    // Single source of truth for the curve identifiers, shared by `new` and
    // `as_str`, so parsing and printing cannot drift apart.
    const NISTP256: &'static str = "nistp256";
    const NISTP384: &'static str = "nistp384";
    const NISTP521: &'static str = "nistp521";

    /// Parse a curve identifier: `nistp256`, `nistp384` or `nistp521`.
    ///
    /// # Errors
    /// Returns [`Error::Algorithm`] for any other string.
    pub fn new(id: &str) -> Result<Self> {
        match id {
            EcdsaCurve::NISTP256 => Ok(EcdsaCurve::NistP256),
            EcdsaCurve::NISTP384 => Ok(EcdsaCurve::NistP384),
            EcdsaCurve::NISTP521 => Ok(EcdsaCurve::NistP521),
            _ => Err(Error::Algorithm),
        }
    }

    /// The curve identifier; inverse of [`EcdsaCurve::new`].
    ///
    /// `const`, like `field_size`, so it can be used in constant contexts.
    pub const fn as_str(self) -> &'static str {
        match self {
            EcdsaCurve::NistP256 => EcdsaCurve::NISTP256,
            EcdsaCurve::NistP384 => EcdsaCurve::NISTP384,
            EcdsaCurve::NistP521 => EcdsaCurve::NISTP521,
        }
    }

    /// Byte length of a field element for this curve, i.e. ceil(bits / 8):
    /// 32 for P-256, 48 for P-384, 66 for P-521.
    #[cfg(feature = "alloc")]
    pub(crate) const fn field_size(self) -> usize {
        match self {
            EcdsaCurve::NistP256 => 32,
            EcdsaCurve::NistP384 => 48,
            EcdsaCurve::NistP521 => 66,
        }
    }
}
impl AsRef<str> for EcdsaCurve {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl Label for EcdsaCurve {
type Error = Error;
}
/// Formats the curve identifier (e.g. `nistp256`).
///
/// Uses [`fmt::Formatter::pad`] so width, alignment and precision flags are
/// honoured as for a plain `str`. Output without flags is unchanged.
impl fmt::Display for EcdsaCurve {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
impl str::FromStr for EcdsaCurve {
type Err = Error;
fn from_str(id: &str) -> Result<Self> {
EcdsaCurve::new(id)
}
}
/// Hash algorithm, also used to select the RSA signature variant in
/// [`Algorithm::Rsa`].
///
/// The derived ordering follows declaration order: SHA-256 < SHA-512.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum HashAlg {
    /// SHA-256 (`sha256`).
    Sha256,
    /// SHA-512 (`sha512`).
    Sha512,
}
impl HashAlg {
    /// Parse a hash algorithm identifier: `sha256` or `sha512`.
    ///
    /// # Errors
    /// Returns [`Error::Algorithm`] for any other identifier.
    pub fn new(id: &str) -> Result<Self> {
        match id {
            SHA256 => Ok(HashAlg::Sha256),
            SHA512 => Ok(HashAlg::Sha512),
            _ => Err(Error::Algorithm),
        }
    }

    /// The identifier string for this hash; inverse of [`HashAlg::new`].
    ///
    /// `const`, like [`HashAlg::digest_size`], so it can be used in constant contexts.
    pub const fn as_str(self) -> &'static str {
        match self {
            HashAlg::Sha256 => SHA256,
            HashAlg::Sha512 => SHA512,
        }
    }

    /// Length in bytes of a digest produced by this algorithm
    /// (32 for SHA-256, 64 for SHA-512).
    pub const fn digest_size(self) -> usize {
        match self {
            HashAlg::Sha256 => 32,
            HashAlg::Sha512 => 64,
        }
    }

    /// Hash `msg` with this algorithm and return the digest as an owned
    /// byte vector of length [`HashAlg::digest_size`].
    #[cfg(feature = "alloc")]
    #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
    pub fn digest(self, msg: &[u8]) -> Vec<u8> {
        match self {
            HashAlg::Sha256 => Sha256::digest(msg).to_vec(),
            HashAlg::Sha512 => Sha512::digest(msg).to_vec(),
        }
    }
}
impl Label for HashAlg {
type Error = Error;
}
impl AsRef<str> for HashAlg {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl Default for HashAlg {
fn default() -> Self {
HashAlg::Sha256
}
}
/// Formats the hash identifier (e.g. `sha256`).
///
/// Uses [`fmt::Formatter::pad`] so width, alignment and precision flags are
/// honoured as for a plain `str`. Output without flags is unchanged.
impl fmt::Display for HashAlg {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
impl str::FromStr for HashAlg {
type Err = Error;
fn from_str(id: &str) -> Result<Self> {
HashAlg::new(id)
}
}
/// Key-derivation function identifier.
///
/// The derived ordering follows declaration order: `None` < `Bcrypt`.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum KdfAlg {
    /// No KDF (`none`).
    None,
    /// bcrypt-based KDF (`bcrypt`).
    Bcrypt,
}
impl KdfAlg {
    /// Parse a KDF name: `none` or `bcrypt`.
    ///
    /// # Errors
    /// Returns [`Error::Algorithm`] for any other name.
    pub fn new(kdfname: &str) -> Result<Self> {
        match kdfname {
            NONE => Ok(Self::None),
            BCRYPT => Ok(Self::Bcrypt),
            _ => Err(Error::Algorithm),
        }
    }

    /// The KDF name; inverse of [`KdfAlg::new`]. `const` so it can be used in
    /// constant contexts.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::None => NONE,
            Self::Bcrypt => BCRYPT,
        }
    }

    /// `true` when no KDF is in use ([`KdfAlg::None`]).
    ///
    /// Uses a pattern match rather than `==` so the function can be `const`.
    /// The result is identical to the derived `PartialEq` comparison.
    pub const fn is_none(self) -> bool {
        matches!(self, Self::None)
    }
}
impl Label for KdfAlg {
type Error = Error;
}
impl AsRef<str> for KdfAlg {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl Default for KdfAlg {
fn default() -> KdfAlg {
KdfAlg::Bcrypt
}
}
/// Formats the KDF name (e.g. `bcrypt`).
///
/// Uses [`fmt::Formatter::pad`] so width, alignment and precision flags are
/// honoured as for a plain `str`. Output without flags is unchanged.
impl fmt::Display for KdfAlg {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
impl str::FromStr for KdfAlg {
type Err = Error;
fn from_str(id: &str) -> Result<Self> {
Self::new(id)
}
}