use crate::crypto_algorithm::CryptoAlgorithm;
use crate::crypto_key::CryptoKey;
use crate::dat::DatPayload;
use crate::dat_records::DatRecords;
use crate::error::DatError;
use crate::signature_algorithm::SignatureAlgorithm;
use crate::signature_key::{SignatureKey, SignatureKeyOutOption};
use crate::util::{decode_base64_url_no_pad, encode_base64_url_no_pad, encode_base64_url_no_pad_out, now_unix_timestamp};
use crate::VERSION_DAT;
use std::cmp::PartialEq;
use std::fmt::Display;
use std::fmt::Write;
use std::str::FromStr;
/// Bound required of key identifiers ("kids").
///
/// A kid must be comparable, printable (its `Display` text is what gets written into
/// serialized keys and tokens), parseable back from that text via `FromStr`, and cloneable.
/// Every type that satisfies those bounds is a `Kid` automatically through the blanket impl.
pub trait Kid
where
    Self: PartialEq + Display + FromStr + Clone,
{
}

impl<K> Kid for K where K: PartialEq + Display + FromStr + Clone {}
/// A DAT key: an identifier plus the signing and encryption material used to issue and
/// read tokens, together with its issuance window and token lifetime.
pub struct DatKey<T>
where
    T: Kid,
{
    /// Key identifier; its `Display` text is embedded in serialized keys and tokens.
    pub(crate) kid: T,
    /// Key used to sign tokens and to verify token signatures.
    pub(crate) signature_key: SignatureKey,
    /// Key used to encrypt and decrypt the secure part of a token.
    pub(crate) crypto_key: CryptoKey,
    /// Start of the issuance window (by name; not enforced in this file).
    pub(crate) issue_begin: u64,
    /// End of the issuance window (same unit as `issue_begin`).
    pub(crate) issue_end: u64,
    /// Lifetime added to the current timestamp to compute a token's expiry.
    pub(crate) token_ttl: u64,
}
impl<T: Kid> DatKey<T> {
    /// Creates a key for `kid` with freshly generated signing and encryption material.
    ///
    /// # Errors
    /// Returns [`DatError::InvalidDatKidFormat`] when the kid's `Display` text contains
    /// `.`, `\r` or `\n` (the same validation as [`DatKey::from`]).
    pub fn generate(
        kid: T,
        signature_algorithm: SignatureAlgorithm,
        crypto_algorithm: CryptoAlgorithm,
        issue_begin: u64,
        issue_end: u64,
        token_ttl: u64,
    ) -> Result<Self, DatError> {
        Self::from(
            kid,
            SignatureKey::generate(signature_algorithm),
            CryptoKey::generate(crypto_algorithm),
            issue_begin,
            issue_end,
            token_ttl,
        )
    }

    /// Builds a key from existing signing and encryption material.
    ///
    /// # Errors
    /// Returns [`DatError::InvalidDatKidFormat`] when the kid's `Display` text contains
    /// `.`, `\r` or `\n`.
    pub fn from(
        kid: T,
        signature_key: SignatureKey,
        crypto_key: CryptoKey,
        issue_begin: u64,
        issue_end: u64,
        token_ttl: u64,
    ) -> Result<Self, DatError> {
        // `.` is the field separator of both the serialized key (`format` / `from_str`) and
        // the token (`to_dat`), so a kid containing it could not be split back out.
        // CR/LF are rejected as well, presumably to keep serialized output single-line.
        if kid.to_string().contains(['.', '\r', '\n']) {
            return Err(DatError::InvalidDatKidFormat);
        }
        Ok(DatKey { kid, signature_key, crypto_key, issue_begin, issue_end, token_ttl })
    }

    /// Returns a clone of the key identifier.
    pub fn kid(&self) -> T {
        self.kid.clone()
    }

    /// Returns a clone of the signature key.
    pub fn signature_key(&self) -> SignatureKey {
        self.signature_key.clone()
    }

    /// Returns a clone of the encryption key.
    pub fn crypto_key(&self) -> CryptoKey {
        self.crypto_key.clone()
    }

    /// Start of the issuance window.
    pub fn issue_begin(&self) -> u64 {
        self.issue_begin
    }

    /// End of the issuance window.
    pub fn issue_end(&self) -> u64 {
        self.issue_end
    }

    /// Lifetime added to the issue time to form a token's expiry.
    pub fn token_ttl(&self) -> u64 {
        self.token_ttl
    }

    /// Returns `issue_end + token_ttl`: the time after which no token from this key should
    /// still be valid, assuming tokens are only issued up to `issue_end`.
    ///
    /// Saturates at `u64::MAX` instead of overflowing. Plain `+` would panic in debug builds
    /// and wrap to a tiny, already-past value in release builds (e.g. `issue_end = u64::MAX`).
    pub fn key_expire(&self) -> u64 {
        self.issue_end.saturating_add(self.token_ttl)
    }

    /// Serializes the key as
    /// `VERSION_DAT.kid.signature_algorithm.signature_key.crypto_algorithm.crypto_key.issue_begin.issue_end.token_ttl`,
    /// with all binary material base64url-encoded without padding.
    ///
    /// `signature_key_out_option` selects the signature-key field:
    /// `FULL` → `sk~vk`, `SIGNING` → `sk`, `VERIFYING` → `~vk`.
    ///
    /// # Errors
    /// Returns [`DatError::VerifyOnlyKey`] when this key holds no signing half and any output
    /// other than `VERIFYING` is requested.
    pub fn format(&self, signature_key_out_option: SignatureKeyOutOption) -> Result<String, DatError> {
        let kid = self.kid.to_string();
        let signature_algorithm = self.signature_key.algorithm();
        let (sk, vk) = self.signature_key.to_bytes();
        if sk.is_empty() && signature_key_out_option != SignatureKeyOutOption::VERIFYING {
            return Err(DatError::VerifyOnlyKey);
        }
        let signature_key = match signature_key_out_option {
            SignatureKeyOutOption::FULL => {
                format!("{}~{}", encode_base64_url_no_pad(sk), encode_base64_url_no_pad(vk))
            }
            SignatureKeyOutOption::SIGNING => encode_base64_url_no_pad(sk),
            SignatureKeyOutOption::VERIFYING => format!("~{}", encode_base64_url_no_pad(vk)),
        };
        let crypto_algorithm = self.crypto_key.algorithm();
        let crypto_key = encode_base64_url_no_pad(self.crypto_key.to_bytes());
        let issue_begin = self.issue_begin;
        let issue_end = self.issue_end;
        let token_ttl = self.token_ttl;
        Ok(format!("{VERSION_DAT}.{kid}.{signature_algorithm}.{signature_key}.{crypto_algorithm}.{crypto_key}.{issue_begin}.{issue_end}.{token_ttl}"))
    }

    /// Verifies the token's signature and then decodes its payload.
    ///
    /// The signature (`dat_records.sign_base64()`) is checked over `dat_records.body_str()`.
    ///
    /// # Errors
    /// Propagates base64 decoding errors for the signature. Returns
    /// [`DatError::InvalidDatFormat`] when verification fails, and otherwise returns the
    /// errors of [`DatKey::to_payload_without_verify`].
    pub fn to_payload(&self, dat_records: &DatRecords<T>) -> Result<DatPayload, DatError> {
        let signature = decode_base64_url_no_pad(dat_records.sign_base64())?;
        if self.signature_key.verify(dat_records.body_str().as_bytes(), &*signature).is_err() {
            return Err(DatError::InvalidDatFormat);
        }
        self.to_payload_without_verify(dat_records)
    }

    /// Decodes the payload WITHOUT checking the signature.
    ///
    /// The plain part is base64url-decoded and the secure part is decoded and then decrypted
    /// with this key's crypto key. The expiry is copied as-is; it is not checked here.
    /// Only use this on records whose authenticity has been established some other way.
    ///
    /// # Errors
    /// Propagates base64 decoding and decryption errors.
    pub fn to_payload_without_verify(&self, dat_records: &DatRecords<T>) -> Result<DatPayload, DatError> {
        Ok(DatPayload {
            expire: dat_records.expire(),
            plain_bytes: decode_base64_url_no_pad(dat_records.plain_base64())?,
            secure_bytes: self
                .crypto_key
                .decrypt(&*decode_base64_url_no_pad(dat_records.secure_base64())?)?,
        })
    }

    /// Issues a token of the form `expire.kid.plain.secure.signature`.
    ///
    /// - `expire` is `now_unix_timestamp() + token_ttl`, saturating at `u64::MAX`.
    /// - `plain` is `plain` base64url-encoded as-is.
    /// - `secure` is `secure` encrypted with the crypto key, then base64url-encoded.
    /// - `signature` signs every byte before the final `.`.
    ///
    /// # Errors
    /// Propagates encryption errors from the crypto key.
    pub fn to_dat<U: AsRef<[u8]>>(&self, plain: U, secure: U) -> Result<String, DatError> {
        let plain: &[u8] = plain.as_ref();
        let secure: &[u8] = secure.as_ref();
        let sk = &self.signature_key;
        // Capacity is only a hint: room for the header plus ~4/3 base64 expansion.
        let mut dat = String::with_capacity(100 + ((plain.len() + secure.len() + sk.signature_size()) * 4 / 3));
        let expire = now_unix_timestamp().saturating_add(self.token_ttl);
        write!(dat, "{}.{}.", expire, self.kid).expect("writing to a String cannot fail");
        encode_base64_url_no_pad_out(plain, &mut dat);
        dat.push('.');
        encode_base64_url_no_pad_out(self.crypto_key.encrypt(secure)?, &mut dat);
        dat.push('.');
        // Sign the body without the trailing separator that was just pushed.
        let signature = sk.sign(dat[..dat.len() - 1].as_bytes());
        encode_base64_url_no_pad_out(signature, &mut dat);
        Ok(dat)
    }
}
impl<T: Kid> FromStr for DatKey<T> {
    type Err = DatError;

    /// Parses a key from the dot-separated text produced by `DatKey::format`.
    ///
    /// Versions `1` and `2` are accepted; any other leading field yields
    /// `DatError::UnSupportDatKeyVersion`. A supported version with a field count other than
    /// nine yields `DatError::InvalidDatKeyFormat`. The signature-key field may be `sk~vk`,
    /// `~vk` (verifying half only) or `sk` (signing half only).
    fn from_str(format: &str) -> Result<Self, Self::Err> {
        let fields: Vec<&str> = format.split('.').collect();

        // `split` always yields at least one item, so index 0 is always present.
        if !matches!(fields[0], "1" | "2") {
            return Err(DatError::UnSupportDatKeyVersion);
        }

        let (kid_str, sig_alg_str, sig_key_str, crypto_alg_str, crypto_key_str, begin_str, end_str, ttl_str) =
            match fields.as_slice() {
                [_, a, b, c, d, e, f, g, h] => (*a, *b, *c, *d, *e, *f, *g, *h),
                _ => return Err(DatError::InvalidDatKeyFormat),
            };

        let kid = kid_str.parse::<T>().map_err(|_| DatError::InvalidDatKidFormat)?;
        let signature_algorithm = SignatureAlgorithm::from_str(sig_alg_str)?;
        let signature_key = match sig_key_str.split_once('~') {
            // `~vk`: verifying half only.
            Some(("", verifying)) => SignatureKey::from_bytes(
                signature_algorithm,
                &[],
                &*decode_base64_url_no_pad(verifying)?,
            ),
            // `sk~vk`: both halves.
            Some((signing, verifying)) => SignatureKey::from_bytes(
                signature_algorithm,
                &*decode_base64_url_no_pad(signing)?,
                &*decode_base64_url_no_pad(verifying)?,
            ),
            // `sk`: signing half only.
            None => SignatureKey::from_bytes(
                signature_algorithm,
                &*decode_base64_url_no_pad(sig_key_str)?,
                &[],
            ),
        }?;

        let crypto_algorithm = CryptoAlgorithm::from_str(crypto_alg_str)?;
        let crypto_key = CryptoKey::from_bytes(crypto_algorithm, &*decode_base64_url_no_pad(crypto_key_str)?)?;

        let parse_u64 = |field: &str| field.parse::<u64>().map_err(|_| DatError::InvalidDatKeyFormat);
        let issue_begin = parse_u64(begin_str)?;
        let issue_end = parse_u64(end_str)?;
        let token_ttl = parse_u64(ttl_str)?;

        DatKey::from(kid, signature_key, crypto_key, issue_begin, issue_end, token_ttl)
    }
}
/// Keys are equal when their kids are equal; key material and the issuance window are not
/// compared.
impl<T: Kid> PartialEq for DatKey<T> {
    fn eq(&self, other: &Self) -> bool {
        self.kid == other.kid
    }
}
/// Allows a key to be compared directly against a bare kid value.
impl<T: Kid> PartialEq<T> for DatKey<T> {
    fn eq(&self, other: &T) -> bool {
        &self.kid == other
    }
}
/// Clones every field of the key.
impl<T: Kid> Clone for DatKey<T> {
    fn clone(&self) -> Self {
        // Destructuring keeps this exhaustive: adding a field becomes a compile error here.
        let DatKey { kid, signature_key, crypto_key, issue_begin, issue_end, token_ttl } = self;
        Self {
            kid: kid.clone(),
            signature_key: signature_key.clone(),
            crypto_key: crypto_key.clone(),
            issue_begin: *issue_begin,
            issue_end: *issue_end,
            token_ttl: *token_ttl,
        }
    }
}