use std::convert::Infallible;
use crate::policy_context::AUTHENTICATION_PROPERTY;
use classy::extract::context::FilterContext;
use classy::extract::{Extract, FromContext};
use classy::hl::context::{RequestContext, ResponseContext};
use classy::hl::StreamProperties;
use classy::stream::PropertyAccessor;
use log::warn;
use pdk_script::{AuthenticationBinding, IntoValue, Value};
use rmp_serde::Serializer;
use serde::{Deserialize, Serialize};
/// Abstraction for reading and writing the [`AuthenticationData`] associated
/// with the current processing context.
pub trait AuthenticationHandler {
/// Returns the currently available authentication data, or `None` when
/// there is none to return.
fn authentication(&self) -> Option<AuthenticationData>;
/// Stores `authentication` as the current authentication data.
///
/// NOTE(review): passing `None` is presumably meant to clear any stored
/// value; the exact effect depends on the implementation.
fn set_authentication(&self, authentication: Option<&AuthenticationData>);
}
/// [`AuthenticationHandler`] implementation that keeps the authentication
/// data in a property reachable through a [`PropertyAccessor`].
pub struct Authentication {
// Type-erased accessor used to read and write the authentication property.
property_accessor: Box<dyn PropertyAccessor>,
}
impl FromContext<FilterContext> for Authentication {
    /// Construction declares no failure mode of its own.
    type Error = Infallible;

    /// Builds an [`Authentication`] on top of the [`StreamProperties`]
    /// extracted from the given filter context.
    fn from_context(context: &FilterContext) -> Result<Self, Self::Error> {
        let properties: StreamProperties = context.extract()?;
        Ok(Self::new(properties))
    }
}
impl FromContext<RequestContext> for Authentication {
    /// Construction declares no failure mode of its own.
    type Error = Infallible;

    /// Builds an [`Authentication`] on top of the [`StreamProperties`]
    /// extracted from the given request context.
    fn from_context(context: &RequestContext) -> Result<Self, Self::Error> {
        let properties: StreamProperties = context.extract()?;
        Ok(Self::new(properties))
    }
}
impl<C> FromContext<ResponseContext<C>> for Authentication {
    /// Construction declares no failure mode of its own.
    type Error = Infallible;

    /// Builds an [`Authentication`] on top of the [`StreamProperties`]
    /// extracted from the given response context.
    fn from_context(context: &ResponseContext<C>) -> Result<Self, Self::Error> {
        let properties: StreamProperties = context.extract()?;
        Ok(Self::new(properties))
    }
}
impl Authentication {
    /// Creates a handler that reads and writes the authentication property
    /// through `property_accessor`.
    ///
    /// The accessor is boxed as a trait object, hence the `'static` bound.
    pub fn new<K>(property_accessor: K) -> Self
    where
        K: PropertyAccessor + 'static,
    {
        let property_accessor: Box<dyn PropertyAccessor> = Box::new(property_accessor);
        Self { property_accessor }
    }
}
impl AuthenticationHandler for Authentication {
    /// Reads the authentication property and decodes it.
    ///
    /// Returns `None` when the property is absent or when its bytes cannot be
    /// decoded (the decoder logs a warning in that case).
    fn authentication(&self) -> Option<AuthenticationData> {
        self.property_accessor
            .read_property(AUTHENTICATION_PROPERTY)
            .and_then(|raw| AuthenticationStreamSerializer::deserialize(raw.as_slice()))
    }

    /// Encodes `authentication` and writes it to the authentication property.
    ///
    /// When `authentication` is `None`, or when encoding fails, the property
    /// is written with `None`.
    fn set_authentication(&self, authentication: Option<&AuthenticationData>) {
        let encoded = authentication.and_then(AuthenticationStreamSerializer::serialize);
        let payload: Option<&[u8]> = encoded.as_deref();
        self.property_accessor
            .set_property(AUTHENTICATION_PROPERTY, payload);
    }
}
/// Authentication information exchanged through the authentication property.
///
/// Derives `Serialize`/`Deserialize` so it can be encoded to and decoded from
/// bytes by the serializer in this file.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AuthenticationData {
/// Optional principal identifier.
pub principal: Option<String>,
/// Optional client identifier.
pub client_id: Option<String>,
/// Optional client name.
pub client_name: Option<String>,
/// Additional properties, stored as a `pdk_script` [`Value`].
// NOTE(review): the tests in this file populate this with an object value;
// confirm whether other shapes are expected.
pub properties: Value,
}
impl AuthenticationData {
    /// Creates a new [`AuthenticationData`].
    ///
    /// `properties` accepts anything implementing [`IntoValue`]; it is
    /// converted into a [`Value`] before being stored.
    pub fn new<K>(
        principal: Option<String>,
        client_id: Option<String>,
        client_name: Option<String>,
        properties: K,
    ) -> Self
    where
        K: IntoValue,
    {
        let properties = properties.into_value();
        Self {
            principal,
            client_id,
            client_name,
            properties,
        }
    }
}
impl AuthenticationBinding for AuthenticationData {
    /// Returns an owned copy of the client identifier, if set.
    fn client_id(&self) -> Option<String> {
        self.client_id.clone()
    }

    /// Returns an owned copy of the client name, if set.
    fn client_name(&self) -> Option<String> {
        self.client_name.clone()
    }

    /// Returns an owned copy of the principal, if set.
    fn principal(&self) -> Option<String> {
        self.principal.clone()
    }

    /// Returns an owned copy of the properties; this is always `Some`.
    fn properties(&self) -> Option<Value> {
        let snapshot = self.properties.clone();
        Some(snapshot)
    }
}
/// Stateless namespace for encoding and decoding [`AuthenticationData`] with
/// `rmp_serde` (MessagePack).
struct AuthenticationStreamSerializer;
impl AuthenticationStreamSerializer {
    /// Decodes `bytes` into an [`AuthenticationData`].
    ///
    /// A decoding failure is logged as a warning and reported as `None`.
    pub fn deserialize(bytes: &[u8]) -> Option<AuthenticationData> {
        let decoded: Result<AuthenticationData, _> = rmp_serde::decode::from_read(bytes);
        if let Err(err) = &decoded {
            warn!("Unexpected error deserializing Authentication object: {err}");
        }
        decoded.ok()
    }

    /// Encodes `authentication` into a freshly allocated byte buffer.
    ///
    /// An encoding failure is logged as a warning and reported as `None`.
    pub fn serialize(authentication: &AuthenticationData) -> Option<Vec<u8>> {
        let mut buffer = Vec::new();
        // Bind the outcome first so the serializer's borrow of `buffer` ends
        // before the buffer is handed back to the caller.
        let outcome = authentication.serialize(&mut Serializer::new(&mut buffer));
        if let Err(err) = outcome {
            warn!("Unexpected error serializing Authentication object: {err}");
            return None;
        }
        Some(buffer)
    }
}
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::collections::HashMap;

    use super::*;
    use classy::proxy_wasm::types::Bytes;

    const KEY_1: &str = "key1";
    const KEY_2: &str = "key2";
    const VALUE: &str = "value2";
    const PRINCIPAL: &str = "principal";
    const CLIENT_ID: &str = "client_id";
    const CLIENT_NAME: &str = "client_name";

    /// In-memory [`PropertyAccessor`] so [`Authentication`] can be exercised
    /// without a host environment.
    #[derive(Default)]
    struct MockPropertyAccessor {
        properties: RefCell<HashMap<Vec<String>, Option<Bytes>>>,
    }

    impl MockPropertyAccessor {
        /// Converts a borrowed property path into the owned key used by the map.
        fn key(path: &[&str]) -> Vec<String> {
            path.iter().map(|&segment| segment.to_owned()).collect()
        }
    }

    impl PropertyAccessor for MockPropertyAccessor {
        fn read_property(&self, path: &[&str]) -> Option<Bytes> {
            // Borrow rather than `take()`: `RefCell::take` would move the whole
            // map out and leave an empty one behind, so every read would wipe
            // all stored properties.
            self.properties
                .borrow()
                .get(&Self::key(path))
                .cloned()
                .flatten()
        }

        fn set_property(&self, path: &[&str], value: Option<&[u8]>) {
            self.properties
                .borrow_mut()
                .insert(Self::key(path), value.map(Bytes::from));
        }
    }

    #[test]
    fn serialize_and_deserialize_authentication_to_bytes() {
        let auth = create_authentication();
        let auth_handler = Authentication::new(MockPropertyAccessor::default());

        auth_handler.set_authentication(Some(&auth));
        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        match auth.unwrap().properties {
            Value::Object(obj) => assert_eq!(obj.len(), 2),
            other => panic!("expected object properties, got {other:?}"),
        }
    }

    #[test]
    fn handler_get_empty() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());

        assert!(auth_handler.authentication().is_none());
    }

    #[test]
    fn handler_new_authentication_creates_auth_when_no_previous_data() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        let new_auth = AuthenticationData::new(
            Some(PRINCIPAL.to_string()),
            Some(CLIENT_ID.to_string()),
            Some(CLIENT_NAME.to_string()),
            sample_properties(),
        );

        auth_handler.set_authentication(Some(&new_auth));
        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        assert_eq!(new_auth, auth.unwrap());
    }

    /// Regression test: reading must not consume the stored value.
    #[test]
    fn handler_get_is_repeatable() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        let first = auth_handler.authentication();
        let second = auth_handler.authentication();

        assert_authentication(first.clone());
        assert_eq!(first, second);
    }

    #[test]
    fn handler_set_none_clears_authentication() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        auth_handler.set_authentication(None);

        assert!(auth_handler.authentication().is_none());
    }

    /// Asserts that `auth` is present and carries the sample values.
    fn assert_authentication(auth: Option<AuthenticationData>) {
        let unwrapped = auth.expect("authentication should be present");
        assert_eq!(unwrapped.principal, Some(PRINCIPAL.to_string()));
        assert_eq!(unwrapped.client_id, Some(CLIENT_ID.to_string()));
        assert_eq!(unwrapped.client_name, Some(CLIENT_NAME.to_string()));

        let properties = unwrapped
            .properties
            .as_object()
            .expect("properties should be an object");
        assert_eq!(properties.get(KEY_1), Some(&Value::Bool(true)));
        assert_eq!(
            properties.get(KEY_2),
            Some(&Value::String(VALUE.to_string()))
        );
    }

    /// Property map shared by the sample authentication builders.
    fn sample_properties() -> HashMap<String, Value> {
        HashMap::from([
            (KEY_1.to_string(), Value::Bool(true)),
            (KEY_2.to_string(), Value::String(VALUE.to_string())),
        ])
    }

    /// Builds the sample authentication through the public fields.
    fn create_authentication() -> AuthenticationData {
        AuthenticationData {
            principal: Some(PRINCIPAL.to_string()),
            client_id: Some(CLIENT_ID.to_string()),
            client_name: Some(CLIENT_NAME.to_string()),
            properties: sample_properties().into_value(),
        }
    }
}