use crate::canon::{canonicalize_context, canonicalize_context_string};
use crate::error::AadError;
use crate::parse::{parse_aad, ParsedAad, CURRENT_VERSION};
use crate::types::{ExtensionValue, Extensions, FieldKey, Purpose, Resource, SafeInt, Tenant};
use serde::ser::SerializeMap;
use serde::{Serialize, Serializer};
use std::collections::BTreeMap;
/// An AAD context: a version, tenant, resource and purpose, plus an optional
/// timestamp and a map of extension fields.
///
/// All fields are private. Within this file, values are created through
/// `AadContext::new` / `AadContextBuilder::build`, or converted from a
/// `ParsedAad`, and are read back through the accessor methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AadContext {
/// Format version; `AadContext::new` initialises it from `CURRENT_VERSION`.
version: SafeInt,
/// Tenant identifier, exposed as `&str` by `tenant()`.
tenant: Tenant,
/// Resource identifier, exposed as `&str` by `resource()`.
resource: Resource,
/// Purpose label, exposed as `&str` by `purpose()`.
purpose: Purpose,
/// Optional timestamp; `None` until `with_timestamp` is called.
/// NOTE(review): the unit (seconds vs. milliseconds) is not visible here — confirm.
timestamp: Option<SafeInt>,
/// Extension fields keyed by `FieldKey`; starts out as an empty `BTreeMap`.
extensions: Extensions,
}
impl AadContext {
    /// Creates a context at `CURRENT_VERSION` with no timestamp and no extensions.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while constructing the version, tenant,
    /// resource or purpose values (checked in that order).
    pub fn new(
        tenant: impl Into<String>,
        resource: impl Into<String>,
        purpose: impl Into<String>,
    ) -> Result<Self, AadError> {
        let version = SafeInt::new(CURRENT_VERSION)?;
        let tenant = Tenant::new(tenant)?;
        let resource = Resource::new(resource)?;
        let purpose = Purpose::new(purpose)?;
        Ok(Self {
            version,
            tenant,
            resource,
            purpose,
            timestamp: None,
            extensions: BTreeMap::new(),
        })
    }

    /// Returns an empty [`AadContextBuilder`].
    #[must_use]
    pub fn builder() -> AadContextBuilder {
        AadContextBuilder::new()
    }

    /// Returns this context with its timestamp set to `ts`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `SafeInt::new` if `ts` is rejected.
    pub fn with_timestamp(self, ts: u64) -> Result<Self, AadError> {
        let timestamp = Some(SafeInt::new(ts)?);
        Ok(Self { timestamp, ..self })
    }

    /// Returns this context with the extension `key` set to `value`.
    ///
    /// An existing entry under the same key is replaced.
    ///
    /// # Errors
    ///
    /// Propagates the error if `key` is rejected by `FieldKey::new` or by
    /// `FieldKey::validate_as_extension`.
    pub fn with_extension(
        mut self,
        key: impl Into<String>,
        value: ExtensionValue,
    ) -> Result<Self, AadError> {
        let field_key = FieldKey::new(key.into())?;
        field_key.validate_as_extension()?;
        self.extensions.insert(field_key, value);
        Ok(self)
    }

    /// Convenience wrapper around [`AadContext::with_extension`] for string values.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ExtensionValue::string` or from
    /// [`AadContext::with_extension`].
    pub fn with_string_extension(
        self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Result<Self, AadError> {
        let wrapped = ExtensionValue::string(value)?;
        self.with_extension(key, wrapped)
    }

    /// Convenience wrapper around [`AadContext::with_extension`] for integer values.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ExtensionValue::integer` or from
    /// [`AadContext::with_extension`].
    pub fn with_int_extension(self, key: impl Into<String>, value: u64) -> Result<Self, AadError> {
        let wrapped = ExtensionValue::integer(value)?;
        self.with_extension(key, wrapped)
    }

    /// Version number stored in this context.
    #[must_use]
    pub const fn version(&self) -> u64 {
        self.version.value()
    }

    /// Tenant as a string slice.
    #[must_use]
    pub fn tenant(&self) -> &str {
        self.tenant.as_str()
    }

    /// Resource as a string slice.
    #[must_use]
    pub fn resource(&self) -> &str {
        self.resource.as_str()
    }

    /// Purpose as a string slice.
    #[must_use]
    pub fn purpose(&self) -> &str {
        self.purpose.as_str()
    }

    /// Timestamp, or `None` if none was set.
    #[must_use]
    pub fn timestamp(&self) -> Option<u64> {
        let ts = self.timestamp?;
        Some(ts.value())
    }

    /// Borrowed view of the extension map.
    #[must_use]
    pub const fn extensions(&self) -> &Extensions {
        &self.extensions
    }

    /// Canonical byte form of this context, as produced by `canonicalize_context`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `canonicalize_context`.
    pub fn canonicalize(&self) -> Result<Vec<u8>, AadError> {
        canonicalize_context(self)
    }

    /// Canonical string form of this context, as produced by
    /// `canonicalize_context_string`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `canonicalize_context_string`.
    pub fn canonicalize_string(&self) -> Result<String, AadError> {
        canonicalize_context_string(self)
    }
}
/// Serializes the context as a single map.
///
/// Entries are written in this order: `purpose`, `resource`, `tenant`, `ts`
/// (only when a timestamp is set), `v`, then every extension in the iteration
/// order of the extension map.
impl Serialize for AadContext {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Four mandatory entries, one more when a timestamp exists, plus the extensions.
        let entry_total = 4 + usize::from(self.timestamp.is_some()) + self.extensions.len();
        let mut out = serializer.serialize_map(Some(entry_total))?;

        out.serialize_entry("purpose", self.purpose.as_str())?;
        out.serialize_entry("resource", self.resource.as_str())?;
        out.serialize_entry("tenant", self.tenant.as_str())?;
        if let Some(ts) = self.timestamp.as_ref() {
            out.serialize_entry("ts", &ts.value())?;
        }
        out.serialize_entry("v", &self.version.value())?;

        for (key, value) in &self.extensions {
            let name = key.as_str();
            match value {
                ExtensionValue::String(text) => out.serialize_entry(name, text)?,
                ExtensionValue::Integer(number) => out.serialize_entry(name, &number.value())?,
            }
        }

        out.end()
    }
}
impl TryFrom<ParsedAad> for AadContext {
type Error = AadError;
fn try_from(parsed: ParsedAad) -> Result<Self, Self::Error> {
Ok(Self {
version: parsed.version,
tenant: parsed.tenant,
resource: parsed.resource,
purpose: parsed.purpose,
timestamp: parsed.timestamp,
extensions: parsed.extensions,
})
}
}
/// Step-by-step builder for [`AadContext`].
///
/// Setters only record their inputs; the tenant, resource, purpose, timestamp
/// and extension keys are validated when [`AadContextBuilder::build`] runs.
/// Extension *values* are validated by the setter itself, and a failure there
/// is remembered and reported by `build` rather than being discarded.
#[derive(Debug, Default)]
pub struct AadContextBuilder {
    /// Tenant supplied so far, if any.
    tenant: Option<String>,
    /// Resource supplied so far, if any.
    resource: Option<String>,
    /// Purpose supplied so far, if any.
    purpose: Option<String>,
    /// Timestamp supplied so far, if any.
    timestamp: Option<u64>,
    /// Extensions in the order they were added; keys are validated in `build`.
    extensions: Vec<(String, ExtensionValue)>,
    /// First extension-value validation error seen by a setter, surfaced by `build`.
    deferred_error: Option<AadError>,
}

impl AadContextBuilder {
    /// Creates an empty builder.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the tenant (required).
    #[must_use]
    pub fn tenant(mut self, tenant: impl Into<String>) -> Self {
        self.tenant = Some(tenant.into());
        self
    }

    /// Sets the resource (required).
    #[must_use]
    pub fn resource(mut self, resource: impl Into<String>) -> Self {
        self.resource = Some(resource.into());
        self
    }

    /// Sets the purpose (required).
    #[must_use]
    pub fn purpose(mut self, purpose: impl Into<String>) -> Self {
        self.purpose = Some(purpose.into());
        self
    }

    /// Sets the optional timestamp.
    #[must_use]
    pub const fn timestamp(mut self, ts: u64) -> Self {
        self.timestamp = Some(ts);
        self
    }

    /// Adds a string-valued extension.
    ///
    /// If `value` is rejected by `ExtensionValue::string`, the extension is not
    /// added and the error is returned later by [`AadContextBuilder::build`].
    #[must_use]
    pub fn extension_string(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        match ExtensionValue::string(value) {
            Ok(v) => self.extensions.push((key.into(), v)),
            Err(err) => self.record_error(err),
        }
        self
    }

    /// Adds an integer-valued extension.
    ///
    /// If `value` is rejected by `ExtensionValue::integer`, the extension is not
    /// added and the error is returned later by [`AadContextBuilder::build`].
    #[must_use]
    pub fn extension_int(mut self, key: impl Into<String>, value: u64) -> Self {
        match ExtensionValue::integer(value) {
            Ok(v) => self.extensions.push((key.into(), v)),
            Err(err) => self.record_error(err),
        }
        self
    }

    /// Remembers the first setter failure so `build` can report it.
    fn record_error(&mut self, err: impl Into<AadError>) {
        if self.deferred_error.is_none() {
            self.deferred_error = Some(err.into());
        }
    }

    /// Validates the recorded inputs and assembles the [`AadContext`].
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// - `AadError::MissingRequiredField` if tenant, resource or purpose was never set;
    /// - any error from [`AadContext::new`] or [`AadContext::with_timestamp`];
    /// - the first extension-value error recorded by `extension_string` /
    ///   `extension_int`;
    /// - any error from [`AadContext::with_extension`] for an extension key.
    pub fn build(self) -> Result<AadContext, AadError> {
        let tenant = self.tenant.ok_or(AadError::MissingRequiredField { field: "tenant" })?;
        let resource = self.resource.ok_or(AadError::MissingRequiredField { field: "resource" })?;
        let purpose = self.purpose.ok_or(AadError::MissingRequiredField { field: "purpose" })?;
        let mut ctx = AadContext::new(tenant, resource, purpose)?;
        if let Some(ts) = self.timestamp {
            ctx = ctx.with_timestamp(ts)?;
        }
        // Checked after the core fields so pre-existing error precedence for
        // missing or invalid required fields is unchanged.
        if let Some(err) = self.deferred_error {
            return Err(err);
        }
        for (key, value) in self.extensions {
            ctx = ctx.with_extension(key, value)?;
        }
        Ok(ctx)
    }
}
/// Parses `json` with `parse_aad` and converts the result into an [`AadContext`].
///
/// # Errors
///
/// Propagates any error from `parse_aad` or from the `TryFrom<ParsedAad>` conversion.
pub fn parse(json: &str) -> Result<AadContext, AadError> {
    AadContext::try_from(parse_aad(json)?)
}
/// Validates `json` by parsing it with [`parse`], returning the parsed context on success.
///
/// # Errors
///
/// Returns whatever error [`parse`] returns for `json`.
pub fn validate(json: &str) -> Result<AadContext, AadError> {
// Delegates entirely to `parse`; no additional checks are performed here.
parse(json)
}
/// Parses `json` and returns the canonical byte form of the resulting context.
///
/// # Errors
///
/// Propagates any error from [`parse`] or from [`AadContext::canonicalize`].
pub fn canonicalize(json: &str) -> Result<Vec<u8>, AadError> {
    parse(json)?.canonicalize()
}
/// Parses `json` and returns the canonical string form of the resulting context.
///
/// # Errors
///
/// Propagates any error from [`parse`] or from [`AadContext::canonicalize_string`].
pub fn canonicalize_string(json: &str) -> Result<String, AadError> {
    parse(json)?.canonicalize_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical form of the minimal context used by the canonicalization tests.
    const MINIMAL_CANONICAL: &str =
        r#"{"purpose":"encryption","resource":"secrets/db","tenant":"org_abc","v":1}"#;

    /// `new` stores the three required fields and leaves the optional ones empty.
    #[test]
    fn test_context_new() {
        let context = AadContext::new("org_abc", "secrets/db", "encryption").unwrap();

        assert_eq!(context.version(), 1);
        assert_eq!(context.tenant(), "org_abc");
        assert_eq!(context.resource(), "secrets/db");
        assert_eq!(context.purpose(), "encryption");
        assert_eq!(context.timestamp(), None);
        assert_eq!(context.extensions().len(), 0);
    }

    /// `with_timestamp` makes the value visible through `timestamp()`.
    #[test]
    fn test_context_with_timestamp() {
        let base = AadContext::new("org", "res", "test").unwrap();
        let stamped = base.with_timestamp(1_706_400_000).unwrap();

        assert_eq!(stamped.timestamp(), Some(1_706_400_000));
    }

    /// A single string extension ends up in the extension map.
    #[test]
    fn test_context_with_extension() {
        let base = AadContext::new("org", "res", "test").unwrap();
        let extended = base.with_string_extension("x_vault_cluster", "us-east-1").unwrap();

        assert_eq!(extended.extensions().len(), 1);
    }

    /// The builder carries every supplied field through to the built context.
    #[test]
    fn test_builder() {
        let built = AadContext::builder()
            .tenant("org_abc")
            .resource("secrets/db")
            .purpose("encryption")
            .timestamp(1_706_400_000)
            .extension_string("x_app_field", "value")
            .build()
            .unwrap();

        assert_eq!(built.tenant(), "org_abc");
        assert_eq!(built.timestamp(), Some(1_706_400_000));
        assert_eq!(built.extensions().len(), 1);
    }

    /// Omitting `purpose` is reported as a missing required field.
    #[test]
    fn test_builder_missing_required() {
        let outcome = AadContext::builder().tenant("org").resource("res").build();

        assert!(matches!(outcome, Err(AadError::MissingRequiredField { field: "purpose" })));
    }

    /// Canonical output of a freshly constructed minimal context.
    #[test]
    fn test_canonicalize_order() {
        let context = AadContext::new("org_abc", "secrets/db", "encryption").unwrap();

        assert_eq!(context.canonicalize_string().unwrap(), MINIMAL_CANONICAL);
    }

    /// Input with keys in a different order canonicalizes to the same string.
    #[test]
    fn test_parse_and_canonicalize() {
        let shuffled = r#"{"v":1,"tenant":"org_abc","resource":"secrets/db","purpose":"encryption"}"#;

        assert_eq!(canonicalize_string(shuffled).unwrap(), MINIMAL_CANONICAL);
    }

    /// Parsing already-canonical input and re-canonicalizing yields the input.
    #[test]
    fn test_parse_roundtrip() {
        let context = parse(MINIMAL_CANONICAL).unwrap();

        assert_eq!(context.canonicalize_string().unwrap(), MINIMAL_CANONICAL);
    }
}