use crate::crypt::{Cipher, CryptAlgorithm, CryptKey};
use crate::dat::DatPayload;
use crate::error::DatError;
use crate::sign::{SignAlgorithm, SignKey, VerifyKey};
use crate::util::{decode_base64_url_no_pad, encode_base64_url_no_pad, now_unix_timestamp, to_kid};
use std::fmt::{Display, Write};
use std::str::FromStr;
/// Borrowed pieces of a DAT string as returned by `DatKeySetActive::split`,
/// in the order `[kid, payload, secure_payload, sign, body]`.
///
/// `body` is the prefix of the input that ends right before `sign`, so it
/// still contains the trailing `.` separator.
pub type DatSplit<'a> = [&'a str; 5];
/// Bound alias for key-identifier types: comparable, printable, parseable
/// from a string and cloneable.
pub trait Kid: PartialEq + Display + FromStr + Clone {}
/// Blanket impl: every type that meets the bounds is automatically a `Kid`.
impl<T> Kid for T where T: PartialEq + Display + FromStr + Clone {}
/// Version tag that `DatKeySet::format` writes as the first dot-separated field.
pub const CONV_VERSION: &str = "1";
/// Serializable description of one key set: raw key bytes, the algorithms
/// they belong to and three `i64` timing values.
///
/// Only `kid` takes part in equality (see the `PartialEq` impl).
#[derive(Debug, Clone)]
pub struct DatKeySet<T: Kid> {
/// Key identifier. `from` rejects values whose `Display` output contains `.`, `\r` or `\n`.
pub(crate) kid: T,
/// Algorithm handed to `SignKey::from_bytes` / `VerifyKey::from_bytes`.
pub(crate) sign_alg: SignAlgorithm,
/// Raw signing key bytes; empty when only a verify key is available.
pub(crate) sign_key: Box<[u8]>,
/// Raw verify key bytes; may be empty, in which case the verify key is derived from `sign_key`.
pub(crate) verify_key: Box<[u8]>,
/// Algorithm handed to `CryptKey::from_bytes`.
pub(crate) crypt_alg: CryptAlgorithm,
/// Raw encryption key bytes.
pub(crate) crypt_key: Box<[u8]>,
// NOTE(review): `issue_begin` / `issue_end` look like the unix-timestamp window in which
// tokens may be issued, but nothing visible here enforces that — confirm against callers.
pub(crate) issue_begin: i64,
/// Added to `token_ttl` by `get_key_expire`.
pub(crate) issue_end: i64,
/// Added to `issue_end` by `get_key_expire`; carried over unchanged by `into_active`.
pub(crate) token_ttl: i64,
}
/// Ready-to-use form of a `DatKeySet`, produced by `DatKeySet::into_active`:
/// the raw bytes have been turned into key and cipher objects.
// NOTE(review): the `dead_code` allowance is presumably for `issue_begin` / `issue_end`,
// which no method visible here reads — confirm before removing it.
#[allow(dead_code)]
pub struct DatKeySetActive<T: Kid> {
/// Identifier compared against the kid embedded in a DAT during verification.
pub(crate) kid: T,
/// Signing key; `None` for verification-only key sets, in which case `to_dat` fails.
pub(crate) sign_key: Option<SignKey>,
/// Key used to check DAT signatures.
pub(crate) verify_key: VerifyKey,
/// Cipher used to encrypt and decrypt the secure payload.
pub(crate) cipher: Cipher,
/// Copied from the source `DatKeySet`.
pub(crate) issue_begin: i64,
/// Copied from the source `DatKeySet`.
pub(crate) issue_end: i64,
/// Added to the current unix timestamp by `to_dat` to obtain a token's expiry.
pub(crate) token_ttl: i64,
}
impl <T: Kid> DatKeySet<T> {
/// Builds a key set around freshly generated signing and encryption keys.
///
/// The verify key is stored empty. All other arguments are forwarded to
/// [`Self::from`], whose validation of `kid` therefore applies here too.
///
/// # Errors
/// Returns whatever [`Self::from`] returns for an invalid `kid`.
pub fn generate(kid: T, sign_alg: SignAlgorithm, crypt_alg: CryptAlgorithm, issue_begin: i64, issue_end: i64, token_ttl: i64) -> Result<Self, DatError> {
    let fresh_sign_key = SignKey::generate(sign_alg).to_bytes();
    let fresh_crypt_key = CryptKey::generate(crypt_alg).to_bytes();
    // A generated set carries no separate verify key.
    let no_verify_key: Box<[u8]> = Box::new([]);
    Self::from(
        kid,
        sign_alg,
        fresh_sign_key,
        no_verify_key,
        crypt_alg,
        fresh_crypt_key,
        issue_begin,
        issue_end,
        token_ttl,
    )
}
pub fn from(kid: T, sign_alg: SignAlgorithm, sign_key: Box<[u8]>, verify_key: Box<[u8]>, crypt_alg: CryptAlgorithm, crypt_key: Box<[u8]>, issue_begin: i64, issue_end: i64, token_ttl: i64) -> Result<Self, DatError> {
if kid.to_string().contains(['.', '\r', '\n']) {
return Err(DatError::KeyError("kid contains invalid characters".to_string()));
}
Ok(DatKeySet {
kid,
sign_alg,
sign_key,
verify_key,
crypt_alg,
crypt_key,
issue_begin,
issue_end,
token_ttl,
})
}
pub fn parse(format: &str) -> Result<Self, DatError> {
let split = format.split(".").collect::<Vec<&str>>();
match split[0] {
"1" => {
if split.len() == 9 {
let kid = to_kid(split[1])?;
let sign_alg = SignAlgorithm::from_str(split[2])?;
let sign_or_verify_key = split[3];
let (sign_key, verify_key) = if sign_or_verify_key.starts_with('~') {
(Box::new([]) as Box<[u8]>, decode_base64_url_no_pad(sign_or_verify_key[1..].as_ref() as &str)?.into_boxed_slice())
} else {
(decode_base64_url_no_pad(sign_or_verify_key)?.into_boxed_slice(), Box::new([]) as Box<[u8]>)
};
let crypt_alg = CryptAlgorithm::from_str(split[4])?;
let crypt_key = decode_base64_url_no_pad(split[5])?.into_boxed_slice();
let issue_begin = split[6].parse::<i64>().map_err(|e| DatError::KeyError(format!("parse error: {e}")))?;
let issue_end = split[7].parse::<i64>().map_err(|e| DatError::KeyError(format!("parse error: {e}")))?;
let token_ttl = split[8].parse::<i64>().map_err(|e| DatError::KeyError(format!("parse error: {e}")))?;
return DatKeySet::from(kid, sign_alg, sign_key, verify_key, crypt_alg, crypt_key, issue_begin, issue_end, token_ttl)
}
},
_ => {}
}
Err(DatError::KeyError("invalid version".to_string()))
}
/// Returns the signing key encoded as unpadded base64url.
///
/// # Errors
/// Returns `DatError::KeyError` when this set holds no signing key.
pub fn get_base64_sign_key(&self) -> Result<String, DatError> {
    if self.sign_key.is_empty() {
        return Err(DatError::KeyError("sign key not available".to_string()));
    }
    Ok(encode_base64_url_no_pad(self.sign_key.as_ref()))
}
/// Returns the verify key encoded as unpadded base64url.
///
/// A stored verify key is used as-is; otherwise the verify key is derived
/// from the stored signing key.
///
/// # Errors
/// Returns `DatError::KeyError` when neither key is stored, and propagates
/// the error from `SignKey::from_bytes` when the signing key cannot be loaded.
pub fn get_base64_verify_key(&self) -> Result<String, DatError> {
    if !self.verify_key.is_empty() {
        return Ok(encode_base64_url_no_pad(self.verify_key.as_ref()));
    }
    if self.sign_key.is_empty() {
        return Err(DatError::KeyError("sign and verify key not available".to_string()));
    }

    let derived = SignKey::from_bytes(self.sign_alg, self.sign_key.as_ref())?
        .to_verify_key()
        .to_bytes();
    Ok(encode_base64_url_no_pad(derived))
}
/// Returns `issue_end + token_ttl`.
///
/// Both values can come straight from parsed text, so the sum saturates at
/// the `i64` bounds instead of panicking (debug) or wrapping (release) on
/// overflow.
pub fn get_key_expire(&self) -> i64 {
    self.issue_end.saturating_add(self.token_ttl)
}
pub fn format(&self, verify_only: bool) -> Result<String, DatError> {
let kid = self.kid.to_string();
let sign_alg = self.sign_alg.to_str();
let sign_key = if verify_only {
format!("~{}", self.get_base64_verify_key()?)
} else {
self.get_base64_sign_key()?
};
let crypt_alg = self.crypt_alg.to_str();
let crypt_key = encode_base64_url_no_pad(self.crypt_key.as_ref());
let issue_begin = self.issue_begin;
let issue_end = self.issue_end;
let token_ttl = self.token_ttl;
Ok(format!("{CONV_VERSION}.{kid}.{sign_alg}.{sign_key}.{crypt_alg}.{crypt_key}.{issue_begin}.{issue_end}.{token_ttl}"))
}
/// Consumes the key set and loads its raw bytes into usable key objects.
///
/// When a signing key is stored, the verify key is derived from it and the
/// stored verify key bytes are not used; otherwise the verify key is loaded
/// from the stored verify key bytes and the result can only verify.
///
/// # Errors
/// Propagates the errors of `SignKey::from_bytes`, `VerifyKey::from_bytes`
/// and `CryptKey::from_bytes`.
pub fn into_active(self) -> Result<DatKeySetActive<T>, DatError> {
    let sign_key = if self.sign_key.is_empty() {
        None
    } else {
        Some(SignKey::from_bytes(self.sign_alg, &*self.sign_key)?)
    };
    // Match on the option directly instead of `is_some()` + `unwrap()`.
    let verify_key = match &sign_key {
        Some(key) => key.to_verify_key(),
        None => VerifyKey::from_bytes(self.sign_alg, &*self.verify_key)?,
    };
    let crypt_key = CryptKey::from_bytes(self.crypt_alg, &*self.crypt_key)?;
    Ok(DatKeySetActive {
        kid: self.kid,
        sign_key,
        verify_key,
        cipher: crypt_key.to_cipher(),
        issue_begin: self.issue_begin,
        issue_end: self.issue_end,
        token_ttl: self.token_ttl,
    })
}
}
/// Two key sets are equal when their `kid`s are equal; key material and
/// timing values are not compared.
impl<T: Kid> PartialEq for DatKeySet<T> {
    fn eq(&self, other: &Self) -> bool {
        self.kid == other.kid
    }
}
impl <T: Kid> DatKeySetActive<T> {
pub fn split(dat: &'_ str) -> Result<DatSplit<'_>, DatError> {
let mut ptr = dat.split(".");
let exp = ptr.next()
.ok_or_else(|| DatError::DatError("format error".to_string()))?
.parse::<i64>()
.map_err(|_| DatError::DatError("format error".to_string()))?;
if exp < now_unix_timestamp()? {
return Err(DatError::DatError("expired".to_string()));
}
let kid = ptr.next().ok_or_else(|| DatError::DatError("format error".to_string()))?;
let payload = ptr.next().ok_or_else(|| DatError::DatError("format error".to_string()))?;
let secure_payload = ptr.next().ok_or_else(|| DatError::DatError("format error".to_string()))?;
let sign = ptr.next().ok_or_else(|| DatError::DatError("format error".to_string()))?;
let sign_pos = sign.as_ptr() as usize - dat.as_ptr() as usize;
if ptr.next().is_some() {
return Err(DatError::DatError("format error".to_string()));
}
let body = &dat[0 ..sign_pos];
Ok([
kid,
payload,
secure_payload,
sign,
body,
])
}
/// Splits `dat` with [`Self::split`] (which also checks expiry) and then
/// verifies and decodes it with [`Self::verify_by_split_unsafe`].
///
/// # Errors
/// Propagates the errors of both steps.
pub fn verify(&self, dat: String) -> Result<DatPayload, DatError> {
    let parts = Self::split(&dat)?;
    self.verify_by_split_unsafe(parts)
}
/// Verifies an already split DAT and returns its `(plain, secure)` payload.
///
/// Checks that the kid matches this key set and that the signature over
/// `body` is valid, then base64url-decodes the plain payload and decrypts
/// the secure payload. This function performs no expiry check of its own;
/// that check lives in [`Self::split`].
///
/// # Errors
/// Returns `DatError::DatError` for a kid mismatch or an invalid signature,
/// and propagates kid parsing, base64, decryption and UTF-8 errors.
pub fn verify_by_split_unsafe(&self, split: DatSplit) -> Result<DatPayload, DatError> {
    let [kid, payload, secure_payload, sign, body] = split;

    if self.kid != to_kid(kid)? {
        return Err(DatError::DatError("kid is not equal".to_string()));
    }

    let signature = decode_base64_url_no_pad(sign)?;
    if self.verify_key.verify(body.as_bytes(), signature.as_slice()).is_err() {
        return Err(DatError::DatError("invalid sign".to_string()));
    }

    let plain = to_utf8(decode_base64_url_no_pad(payload)?)?;
    let encrypted = decode_base64_url_no_pad(secure_payload)?;
    let secure = to_utf8(self.cipher.decrypt(encrypted.as_slice())?)?;
    Ok((plain, secure))
}
/// Issues a DAT string `exp.kid.plain.secure.sign` for `payload`.
///
/// `exp` is the current unix timestamp plus `token_ttl`. `plain` is the
/// base64url of the first payload element, `secure` the base64url of the
/// encrypted second element, and `sign` the base64url signature over
/// everything before it, trailing `.` included.
///
/// # Errors
/// Returns `DatError::KeyError` when this key set has no signing key or when
/// the expiry does not fit in an `i64`, and propagates the errors of
/// `now_unix_timestamp` and encryption.
pub fn to_dat(&self, payload: DatPayload) -> Result<String, DatError> {
    // Fail fast for verification-only keys, before paying for encryption.
    let sign_key = self.sign_key.as_ref()
        .ok_or_else(|| DatError::KeyError("this key is for verification only".to_string()))?;
    // `token_ttl` may come from parsed text; refuse to overflow instead of
    // panicking (debug) or wrapping into an already expired token (release).
    let expire = now_unix_timestamp()?
        .checked_add(self.token_ttl)
        .ok_or_else(|| DatError::KeyError("token expiry overflows i64".to_string()))?;
    let kid = &self.kid;
    let (plain, secure) = payload;
    let plain = encode_base64_url_no_pad(plain.as_bytes());
    let secure = encode_base64_url_no_pad(self.cipher.encrypt(secure.as_bytes())?);
    let mut body = String::with_capacity(100 + plain.len() + secure.len());
    write!(body, "{expire}.{kid}.{plain}.{secure}.")
        .map_err(|e| DatError::KeyError(format!("to_dat io error: {e}")))?;
    // The signature covers the body written so far, separator included.
    let sign = encode_base64_url_no_pad(sign_key.sign(body.as_bytes()));
    body.push_str(&sign);
    Ok(body)
}
}
fn to_utf8(vec: Vec<u8>) -> Result<String, DatError> {
String::from_utf8(vec).map_err(|e| DatError::KeyError(format!("to_string utf8 error: {e}")))
}