use std::any::Any;
use std::fmt::{Debug, Display};
use std::ops::Deref;
use crate::jwe::JweHeader;
use crate::jwk::Jwk;
use crate::{util, JoseError, JoseHeader, Map, Value};
/// A JWE header whose claims are split between a protected part and an
/// unprotected part.
///
/// Each claim is stored as a JSON value keyed by its header parameter name.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct JweHeaderSet {
    /// Claims placed in the protected part of the header.
    protected: Map<String, Value>,
    /// Claims placed in the unprotected part of the header.
    unprotected: Map<String, Value>,
}
impl JweHeaderSet {
    // Convention for every setter taking `protection: bool`: `true` stores the claim
    // in the protected map, `false` in the unprotected map. The key is removed from
    // the other map first, so a claim set this way lives in exactly one of them.

    /// Creates an empty header set with no protected and no unprotected claims.
    pub fn new() -> Self {
        Self {
            protected: Map::new(),
            unprotected: Map::new(),
        }
    }

    /// Stores `value` under `key` in the protected map when `protection` is true,
    /// otherwise in the unprotected map, after removing `key` from the other map.
    fn put_claim_value(&mut self, key: &str, value: Value, protection: bool) {
        let (target, other) = if protection {
            (&mut self.protected, &mut self.unprotected)
        } else {
            (&mut self.unprotected, &mut self.protected)
        };
        other.remove(key);
        target.insert(key.to_string(), value);
    }

    /// Returns claim `key` if it is a JSON string, else `None`.
    ///
    /// Lookup goes through `JoseHeader::claim`, which checks the protected map first.
    fn string_claim(&self, key: &str) -> Option<&str> {
        match self.claim(key) {
            Some(Value::String(val)) => Some(val.as_str()),
            _ => None,
        }
    }

    /// Returns claim `key` decoded with `util::decode_base64_urlsafe_no_pad`.
    ///
    /// Yields `None` when the claim is absent, is not a string, or fails to decode.
    fn base64url_claim(&self, key: &str) -> Option<Vec<u8>> {
        match self.claim(key) {
            Some(Value::String(val)) => util::decode_base64_urlsafe_no_pad(val).ok(),
            _ => None,
        }
    }

    /// Sets the `alg` header claim.
    pub fn set_algorithm(&mut self, value: impl Into<String>, protection: bool) {
        self.put_claim_value("alg", Value::String(value.into()), protection);
    }

    /// Returns the `alg` header claim, if present and a string.
    pub fn algorithm(&self) -> Option<&str> {
        self.string_claim("alg")
    }

    /// Sets the `enc` header claim.
    pub fn set_content_encryption(&mut self, value: impl Into<String>, protection: bool) {
        self.put_claim_value("enc", Value::String(value.into()), protection);
    }

    /// Returns the `enc` header claim, if present and a string.
    pub fn content_encryption(&self) -> Option<&str> {
        self.string_claim("enc")
    }

    /// Sets the `zip` header claim in the protected map.
    ///
    /// Any `zip` entry already in the unprotected map is removed. Otherwise the key
    /// could exist in both maps, with `compression()` reporting one value while
    /// `to_map()` reports the other.
    pub fn set_compression(&mut self, value: impl Into<String>) {
        // RFC 7516 §4.1.3 requires "zip" to be integrity protected. Keep it only in
        // the protected map, mirroring `set_critical`.
        self.put_claim_value("zip", Value::String(value.into()), true);
    }

    /// Returns the `zip` header claim. Only the protected map is consulted.
    pub fn compression(&self) -> Option<&str> {
        match self.protected.get("zip") {
            Some(Value::String(val)) => Some(val.as_str()),
            _ => None,
        }
    }

    /// Sets the `jku` header claim.
    pub fn set_jwk_set_url(&mut self, value: impl Into<String>, protection: bool) {
        self.put_claim_value("jku", Value::String(value.into()), protection);
    }

    /// Returns the `jku` header claim, if present and a string.
    pub fn jwk_set_url(&self) -> Option<&str> {
        self.string_claim("jku")
    }

    /// Sets the `jwk` header claim, stored as the key's JSON object form.
    pub fn set_jwk(&mut self, value: Jwk, protection: bool) {
        let value: Map<String, Value> = value.into();
        self.put_claim_value("jwk", Value::Object(value), protection);
    }

    /// Returns the `jwk` header claim parsed with `Jwk::from_map`.
    ///
    /// Yields `None` if the claim is absent, is not an object, or fails to parse.
    pub fn jwk(&self) -> Option<Jwk> {
        match self.claim("jwk") {
            Some(Value::Object(vals)) => Jwk::from_map(vals.clone()).ok(),
            _ => None,
        }
    }

    /// Sets the `x5u` header claim.
    pub fn set_x509_url(&mut self, value: impl Into<String>, protection: bool) {
        self.put_claim_value("x5u", Value::String(value.into()), protection);
    }

    /// Returns the `x5u` header claim, if present and a string.
    pub fn x509_url(&self) -> Option<&str> {
        self.string_claim("x5u")
    }

    /// Sets the `x5c` header claim.
    ///
    /// Each certificate is encoded with `util::encode_base64_standard`.
    pub fn set_x509_certificate_chain(&mut self, values: &Vec<impl AsRef<[u8]>>, protection: bool) {
        let chain = values
            .iter()
            .map(|v| Value::String(util::encode_base64_standard(v.as_ref())))
            .collect();
        self.put_claim_value("x5c", Value::Array(chain), protection);
    }

    /// Returns the decoded `x5c` header claim.
    ///
    /// Yields `None` if the claim is absent or not an array, or if any element is not
    /// a string or fails `util::decode_base64_standard`.
    pub fn x509_certificate_chain(&self) -> Option<Vec<Vec<u8>>> {
        match self.claim("x5c") {
            Some(Value::Array(vals)) => vals
                .iter()
                .map(|val| match val {
                    Value::String(encoded) => util::decode_base64_standard(encoded).ok(),
                    _ => None,
                })
                .collect(),
            _ => None,
        }
    }

    /// Sets the `x5t` header claim, encoded with `util::encode_base64_urlsafe_nopad`.
    pub fn set_x509_certificate_sha1_thumbprint(
        &mut self,
        value: impl AsRef<[u8]>,
        protection: bool,
    ) {
        let value = util::encode_base64_urlsafe_nopad(value);
        self.put_claim_value("x5t", Value::String(value), protection);
    }

    /// Returns the decoded `x5t` header claim.
    pub fn x509_certificate_sha1_thumbprint(&self) -> Option<Vec<u8>> {
        self.base64url_claim("x5t")
    }

    /// Sets the `x5t#S256` header claim, encoded with `util::encode_base64_urlsafe_nopad`.
    pub fn set_x509_certificate_sha256_thumbprint(
        &mut self,
        value: impl AsRef<[u8]>,
        protection: bool,
    ) {
        let value = util::encode_base64_urlsafe_nopad(value);
        self.put_claim_value("x5t#S256", Value::String(value), protection);
    }

    /// Returns the decoded `x5t#S256` header claim.
    pub fn x509_certificate_sha256_thumbprint(&self) -> Option<Vec<u8>> {
        self.base64url_claim("x5t#S256")
    }

    /// Sets the `kid` header claim.
    pub fn set_key_id(&mut self, value: impl Into<String>, protection: bool) {
        self.put_claim_value("kid", Value::String(value.into()), protection);
    }

    /// Returns the `kid` header claim, if present and a string.
    pub fn key_id(&self) -> Option<&str> {
        self.string_claim("kid")
    }

    /// Sets the `typ` header claim.
    pub fn set_token_type(&mut self, value: impl Into<String>, protection: bool) {
        self.put_claim_value("typ", Value::String(value.into()), protection);
    }

    /// Returns the `typ` header claim, if present and a string.
    pub fn token_type(&self) -> Option<&str> {
        self.string_claim("typ")
    }

    /// Sets the `cty` header claim.
    pub fn set_content_type(&mut self, value: impl Into<String>, protection: bool) {
        self.put_claim_value("cty", Value::String(value.into()), protection);
    }

    /// Returns the `cty` header claim, if present and a string.
    pub fn content_type(&self) -> Option<&str> {
        self.string_claim("cty")
    }

    /// Sets the `crit` header claim. It is always stored in the protected map.
    pub fn set_critical(&mut self, values: &Vec<impl AsRef<str>>) {
        let names = values
            .iter()
            .map(|v| Value::String(v.as_ref().to_string()))
            .collect();
        self.put_claim_value("crit", Value::Array(names), true);
    }

    /// Returns the `crit` header claim.
    ///
    /// Yields `None` if the claim is absent or not an array, or if any element is not
    /// a string.
    pub fn critical(&self) -> Option<Vec<&str>> {
        match self.claim("crit") {
            Some(Value::Array(vals)) => vals
                .iter()
                .map(|val| match val {
                    Value::String(name) => Some(name.as_str()),
                    _ => None,
                })
                .collect(),
            _ => None,
        }
    }

    /// Sets the `url` header claim.
    pub fn set_url(&mut self, value: impl Into<String>, protection: bool) {
        self.put_claim_value("url", Value::String(value.into()), protection);
    }

    /// Returns the `url` header claim, if present and a string.
    pub fn url(&self) -> Option<&str> {
        self.string_claim("url")
    }

    /// Sets the `nonce` header claim, encoded with `util::encode_base64_urlsafe_nopad`.
    pub fn set_nonce(&mut self, value: impl AsRef<[u8]>, protection: bool) {
        let value = util::encode_base64_urlsafe_nopad(value);
        self.put_claim_value("nonce", Value::String(value), protection);
    }

    /// Returns the decoded `nonce` header claim.
    pub fn nonce(&self) -> Option<Vec<u8>> {
        self.base64url_claim("nonce")
    }

    /// Sets the `apu` header claim, encoded with `util::encode_base64_urlsafe_nopad`.
    pub fn set_agreement_partyuinfo(&mut self, value: impl AsRef<[u8]>, protection: bool) {
        let value = util::encode_base64_urlsafe_nopad(&value);
        self.put_claim_value("apu", Value::String(value), protection);
    }

    /// Returns the decoded `apu` header claim.
    pub fn agreement_partyuinfo(&self) -> Option<Vec<u8>> {
        self.base64url_claim("apu")
    }

    /// Sets the `apv` header claim, encoded with `util::encode_base64_urlsafe_nopad`.
    pub fn set_agreement_partyvinfo(&mut self, value: impl AsRef<[u8]>, protection: bool) {
        let value = util::encode_base64_urlsafe_nopad(&value);
        self.put_claim_value("apv", Value::String(value), protection);
    }

    /// Returns the decoded `apv` header claim.
    pub fn agreement_partyvinfo(&self) -> Option<Vec<u8>> {
        self.base64url_claim("apv")
    }

    /// Sets the `iss` header claim.
    pub fn set_issuer(&mut self, value: impl Into<String>, protection: bool) {
        self.put_claim_value("iss", Value::String(value.into()), protection);
    }

    /// Returns the `iss` header claim, if present and a string.
    pub fn issuer(&self) -> Option<&str> {
        self.string_claim("iss")
    }

    /// Sets the `sub` header claim.
    pub fn set_subject(&mut self, value: impl Into<String>, protection: bool) {
        self.put_claim_value("sub", Value::String(value.into()), protection);
    }

    /// Returns the `sub` header claim, if present and a string.
    pub fn subject(&self) -> Option<&str> {
        self.string_claim("sub")
    }

    /// Sets the `aud` header claim.
    ///
    /// A single value is stored as a JSON string; two or more are stored as an array
    /// of strings. An empty `values` leaves the header unchanged, including any
    /// existing `aud`.
    pub fn set_audience(&mut self, values: Vec<impl Into<String>>, protection: bool) {
        let mut values: Vec<String> = values.into_iter().map(Into::into).collect();
        let value = match values.len() {
            0 => return,
            1 => Value::String(values.swap_remove(0)),
            _ => Value::Array(values.into_iter().map(Value::String).collect()),
        };
        self.put_claim_value("aud", value, protection);
    }

    /// Returns the `aud` header claim.
    ///
    /// Accepts both the single-string form and the array form. Yields `None` if the
    /// claim is absent or any array element is not a string.
    pub fn audience(&self) -> Option<Vec<&str>> {
        match self.claim("aud") {
            Some(Value::Array(vals)) => vals
                .iter()
                .map(|val| match val {
                    Value::String(aud) => Some(aud.as_str()),
                    _ => None,
                })
                .collect(),
            Some(Value::String(val)) => Some(vec![val.as_str()]),
            _ => None,
        }
    }

    /// Sets or removes an arbitrary header claim.
    ///
    /// `Some(value)` is validated with `JweHeader::check_claim` and then stored per
    /// `protection`. `None` removes `key` from both maps.
    ///
    /// # Errors
    /// Returns the error produced by `JweHeader::check_claim` when validation fails.
    /// In that case the header is left unchanged.
    pub fn set_claim(
        &mut self,
        key: &str,
        value: Option<Value>,
        protection: bool,
    ) -> Result<(), JoseError> {
        match value {
            Some(val) => {
                JweHeader::check_claim(key, &val)?;
                self.put_claim_value(key, val, protection);
            }
            None => {
                self.protected.remove(key);
                self.unprotected.remove(key);
            }
        }
        Ok(())
    }

    /// Returns the protected map when `protection` is true, else the unprotected map.
    pub fn claims_set(&self, protection: bool) -> &Map<String, Value> {
        if protection {
            &self.protected
        } else {
            &self.unprotected
        }
    }

    /// Returns a merged copy of both maps.
    ///
    /// If a key is present in both, the unprotected value overwrites the protected one.
    pub fn to_map(&self) -> Map<String, Value> {
        let mut map = self.protected.clone();
        for (key, value) in &self.unprotected {
            map.insert(key.clone(), value.clone());
        }
        map
    }
}
impl JoseHeader for JweHeaderSet {
    /// Total number of entries across the protected and unprotected maps.
    fn len(&self) -> usize {
        let protected_count = self.protected.len();
        let unprotected_count = self.unprotected.len();
        protected_count + unprotected_count
    }

    /// Looks `key` up in the protected map first, falling back to the unprotected map.
    fn claim(&self, key: &str) -> Option<&Value> {
        self.protected
            .get(key)
            .or_else(|| self.unprotected.get(key))
    }

    /// Clones this header set into a boxed trait object.
    fn box_clone(&self) -> Box<dyn JoseHeader> {
        let copy: JweHeaderSet = Clone::clone(self);
        Box::new(copy)
    }

    /// Shared view as `Any` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Mutable view as `Any` for downcasting.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Converts the boxed header into `Box<dyn Any>` for downcasting by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}
impl Display for JweHeaderSet {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
let protected = serde_json::to_string(&self.protected).map_err(|_e| std::fmt::Error {})?;
let unprotected =
serde_json::to_string(&self.unprotected).map_err(|_e| std::fmt::Error {})?;
fmt.write_str("{\"protected\":")?;
fmt.write_str(&protected)?;
fmt.write_str(",\"unprotected\":")?;
fmt.write_str(&unprotected)?;
fmt.write_str("}")?;
Ok(())
}
}
/// Exposes the header set as a `dyn JoseHeader`.
///
/// This lets callers use trait methods such as `claim` without importing
/// `JoseHeader`, because methods on a trait object resolve without the trait in scope.
impl Deref for JweHeaderSet {
    type Target = dyn JoseHeader;

    fn deref(&self) -> &Self::Target {
        // Unsizing coercion to the trait object, not a recursive deref.
        let as_header: &dyn JoseHeader = self;
        as_header
    }
}
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use serde_json::json;

    use crate::jwe::JweHeaderSet;
    use crate::jwk::Jwk;
    use crate::Value;

    /// Round-trips every typed setter/getter pair with `protection = true`.
    #[test]
    fn test_new_jwe_header() -> Result<()> {
        let mut header = JweHeaderSet::new();
        let jwk = Jwk::new("oct");
        header.set_algorithm("alg", true);
        header.set_content_encryption("enc", true);
        header.set_compression("zip");
        header.set_jwk_set_url("jku", true);
        header.set_jwk(jwk.clone(), true);
        header.set_x509_url("x5u", true);
        header.set_x509_certificate_chain(
            &vec![
                b"x5c0".to_vec(),
                b"x5c1".to_vec(),
                "@@~".as_bytes().to_vec(),
            ],
            true,
        );
        header.set_x509_certificate_sha1_thumbprint(b"x5t@@~", true);
        header.set_x509_certificate_sha256_thumbprint(b"x5t#S256 @@~", true);
        header.set_key_id("kid", true);
        header.set_token_type("typ", true);
        header.set_content_type("cty", true);
        header.set_critical(&vec!["crit0", "crit1"]);
        header.set_url("url", true);
        header.set_nonce(b"nonce", true);
        header.set_agreement_partyuinfo(b"apu", true);
        header.set_agreement_partyvinfo(b"apv", true);
        header.set_issuer("iss", true);
        header.set_subject("sub", true);
        header.set_claim("header_claim", Some(json!("header_claim")), true)?;

        assert_eq!(header.algorithm(), Some("alg"));
        assert_eq!(header.content_encryption(), Some("enc"));
        assert_eq!(header.compression(), Some("zip"));
        assert_eq!(header.jwk_set_url(), Some("jku"));
        assert_eq!(header.jwk(), Some(jwk));
        assert_eq!(header.x509_url(), Some("x5u"));
        assert_eq!(
            header.x509_certificate_chain(),
            Some(vec![
                b"x5c0".to_vec(),
                b"x5c1".to_vec(),
                "@@~".as_bytes().to_vec()
            ])
        );
        // x5c uses standard (padded) base64.
        assert_eq!(
            header.claim("x5c"),
            Some(&Value::Array(vec![
                Value::String("eDVjMA==".to_string()),
                Value::String("eDVjMQ==".to_string()),
                Value::String("QEB+".to_string()),
            ]))
        );
        assert_eq!(
            header.x509_certificate_sha1_thumbprint(),
            Some(b"x5t@@~".to_vec())
        );
        // Thumbprints use URL-safe base64 without padding ('-' instead of '+').
        assert_eq!(
            header.claim("x5t"),
            Some(&Value::String("eDV0QEB-".to_string()))
        );
        assert_eq!(
            header.x509_certificate_sha256_thumbprint(),
            Some(b"x5t#S256 @@~".to_vec())
        );
        assert_eq!(
            header.claim("x5t#S256"),
            Some(&Value::String("eDV0I1MyNTYgQEB-".to_string()))
        );
        assert_eq!(header.key_id(), Some("kid"));
        assert_eq!(header.token_type(), Some("typ"));
        assert_eq!(header.content_type(), Some("cty"));
        assert_eq!(header.url(), Some("url"));
        assert_eq!(header.nonce(), Some(b"nonce".to_vec()));
        assert_eq!(header.agreement_partyuinfo(), Some(b"apu".to_vec()));
        assert_eq!(header.agreement_partyvinfo(), Some(b"apv".to_vec()));
        assert_eq!(header.issuer(), Some("iss"));
        assert_eq!(header.subject(), Some("sub"));
        assert_eq!(header.critical(), Some(vec!["crit0", "crit1"]));
        assert_eq!(header.claim("header_claim"), Some(&json!("header_claim")));
        Ok(())
    }

    /// Switching `protection` moves a claim between maps; `None` removes it from both.
    #[test]
    fn test_protection_moves_claim_between_maps() -> Result<()> {
        let mut header = JweHeaderSet::new();

        header.set_key_id("kid", true);
        assert_eq!(header.claims_set(true).get("kid"), Some(&json!("kid")));
        assert_eq!(header.claims_set(false).get("kid"), None);

        header.set_key_id("kid2", false);
        assert_eq!(header.claims_set(true).get("kid"), None);
        assert_eq!(header.claims_set(false).get("kid"), Some(&json!("kid2")));
        assert_eq!(header.key_id(), Some("kid2"));

        header.set_claim("kid", None, true)?;
        assert_eq!(header.key_id(), None);
        assert_eq!(header.claims_set(true).get("kid"), None);
        assert_eq!(header.claims_set(false).get("kid"), None);
        Ok(())
    }

    /// A single audience is stored as a string; several are stored as an array.
    #[test]
    fn test_audience_forms() {
        let mut header = JweHeaderSet::new();

        header.set_audience(vec!["a"], true);
        assert_eq!(header.claim("aud"), Some(&json!("a")));
        assert_eq!(header.audience(), Some(vec!["a"]));

        header.set_audience(vec!["a", "b"], false);
        assert_eq!(header.claims_set(true).get("aud"), None);
        assert_eq!(header.claim("aud"), Some(&json!(["a", "b"])));
        assert_eq!(header.audience(), Some(vec!["a", "b"]));
    }
}