pdk_core/policy_context/
authentication.rs

1// Copyright (c) 2025, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5//! Utils to access and share authentication data between filters.
6
7use std::convert::Infallible;
8
9use crate::policy_context::AUTHENTICATION_PROPERTY;
10use classy::extract::context::FilterContext;
11use classy::extract::{Extract, FromContext};
12use classy::hl::context::{RequestContext, ResponseContext};
13use classy::hl::StreamProperties;
14use classy::stream::PropertyAccessor;
15use log::warn;
16use pdk_script::{AuthenticationBinding, IntoValue, Value};
17use rmp_serde::Serializer;
18use serde::{Deserialize, Serialize};
19
20/// Trait to access and share authentication data between filters.
21pub trait AuthenticationHandler {
22    /// Get the current data regarding authentication.
23    fn authentication(&self) -> Option<AuthenticationData>;
24    /// Replace the authentication data.
25    fn set_authentication(&self, authentication: Option<&AuthenticationData>);
26}
27
28/// Default implementation of the [`AuthenticationHandler`] trait.
29pub struct Authentication {
30    property_accessor: Box<dyn PropertyAccessor>,
31}
32
33impl FromContext<FilterContext> for Authentication {
34    type Error = Infallible;
35
36    fn from_context(context: &FilterContext) -> Result<Self, Self::Error> {
37        let stream: StreamProperties = context.extract()?;
38        Ok(Authentication::new(stream))
39    }
40}
41
42impl FromContext<RequestContext> for Authentication {
43    type Error = Infallible;
44
45    fn from_context(context: &RequestContext) -> Result<Self, Self::Error> {
46        let stream: StreamProperties = context.extract()?;
47        Ok(Authentication::new(stream))
48    }
49}
50
51impl<C> FromContext<ResponseContext<C>> for Authentication {
52    type Error = Infallible;
53
54    fn from_context(context: &ResponseContext<C>) -> Result<Self, Self::Error> {
55        let stream: StreamProperties = context.extract()?;
56        Ok(Authentication::new(stream))
57    }
58}
59
60impl Authentication {
61    pub fn new<K: PropertyAccessor + 'static>(property_accessor: K) -> Self {
62        Self {
63            property_accessor: Box::new(property_accessor),
64        }
65    }
66}
67
68impl AuthenticationHandler for Authentication {
69    fn authentication(&self) -> Option<AuthenticationData> {
70        let bytes = self
71            .property_accessor
72            .read_property(AUTHENTICATION_PROPERTY)?;
73        AuthenticationStreamSerializer::deserialize(bytes.as_slice())
74    }
75
76    fn set_authentication(&self, authentication: Option<&AuthenticationData>) {
77        let bytes = authentication.and_then(AuthenticationStreamSerializer::serialize);
78        self.property_accessor
79            .set_property(AUTHENTICATION_PROPERTY, bytes.as_deref());
80    }
81}
82
83#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
84/// The data regarding the authentication.
85pub struct AuthenticationData {
86    /// The main value used for authenticated.
87    pub principal: Option<String>,
88    /// The client id associated to the authentication method used.
89    pub client_id: Option<String>,
90    /// The name id associated to the authentication method used.
91    pub client_name: Option<String>,
92    /// Additional properties of the authentication.
93    pub properties: Value,
94}
95
96impl AuthenticationData {
97    pub fn new<K: IntoValue>(
98        principal: Option<String>,
99        client_id: Option<String>,
100        client_name: Option<String>,
101        properties: K,
102    ) -> Self {
103        Self {
104            principal,
105            client_id,
106            client_name,
107            properties: properties.into_value(),
108        }
109    }
110}
111
112impl AuthenticationBinding for AuthenticationData {
113    fn client_id(&self) -> Option<String> {
114        self.client_id.clone()
115    }
116
117    fn client_name(&self) -> Option<String> {
118        self.client_name.clone()
119    }
120
121    fn principal(&self) -> Option<String> {
122        self.principal.clone()
123    }
124
125    fn properties(&self) -> Option<Value> {
126        Some(self.properties.clone())
127    }
128}
129
/// Serializes and deserializes [`AuthenticationData`] objects so that they can be propagated
/// between policies.
///
/// The chosen serialization format is MessagePack. Using a cross-language format allows the
/// object to be propagated between filters written in any language.
struct AuthenticationStreamSerializer;
134
135impl AuthenticationStreamSerializer {
136    pub fn deserialize(bytes: &[u8]) -> Option<AuthenticationData> {
137        match rmp_serde::decode::from_read(bytes) {
138            Ok(authentication) => Some(authentication),
139            Err(err) => {
140                warn!("Unexpected error deserializing Authentication object: {err}");
141                None
142            }
143        }
144    }
145
146    pub fn serialize(authentication: &AuthenticationData) -> Option<Vec<u8>> {
147        let mut buf = Vec::new();
148        let result = authentication.serialize(&mut Serializer::new(&mut buf));
149
150        match result {
151            Ok(_) => Some(buf),
152            Err(err) => {
153                warn!("Unexpected error serializing Authentication object: {err}");
154                None
155            }
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::collections::HashMap;

    use super::*;
    use classy::proxy_wasm::types::Bytes;

    const KEY_1: &str = "key1";
    const KEY_2: &str = "key2";

    const VALUE: &str = "value2";

    const PRINCIPAL: &str = "principal";
    const CLIENT_ID: &str = "client_id";
    const CLIENT_NAME: &str = "client_name";

    /// In-memory [`PropertyAccessor`] keyed by the full property path.
    ///
    /// A stored `None` models a property that was explicitly set to "no value".
    #[derive(Default)]
    struct MockPropertyAccessor {
        properties: RefCell<HashMap<Vec<String>, Option<Bytes>>>,
    }

    impl MockPropertyAccessor {
        /// Converts a borrowed property path into the owned key used by the map.
        fn key(path: &[&str]) -> Vec<String> {
            path.iter().map(|segment| segment.to_string()).collect()
        }
    }

    impl PropertyAccessor for MockPropertyAccessor {
        fn read_property(&self, path: &[&str]) -> Option<Bytes> {
            // Borrow instead of `take()`: taking would move the whole map out of the
            // `RefCell`, silently dropping every stored property after the first read.
            self.properties
                .borrow()
                .get(&Self::key(path))
                .cloned()
                .flatten()
        }

        fn set_property(&self, path: &[&str], value: Option<&[u8]>) {
            self.properties
                .borrow_mut()
                .insert(Self::key(path), value.map(Bytes::from));
        }
    }

    #[test]
    fn serialize_and_deserialize_authentication_to_bytes() {
        let auth = create_authentication();
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        auth_handler.set_authentication(Some(&auth));
        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        match auth.unwrap().properties {
            Value::Object(obj) => assert_eq!(obj.len(), 2),
            _ => panic!("expected the properties to deserialize as an object"),
        }
    }

    #[test]
    fn handler_get_empty() {
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        let auth = auth_handler.authentication();

        assert!(auth.is_none())
    }

    #[test]
    fn handler_new_authentication_creates_auth_when_no_previous_data() {
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        let new_auth = AuthenticationData::new(
            Some(PRINCIPAL.to_string()),
            Some(CLIENT_ID.to_string()),
            Some(CLIENT_NAME.to_string()),
            HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ]),
        );

        auth_handler.set_authentication(Some(&new_auth));

        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        assert_eq!(new_auth, auth.unwrap());
    }

    #[test]
    fn handler_reads_are_repeatable() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        // Reading must not consume the stored property.
        let first = auth_handler.authentication();
        let second = auth_handler.authentication();

        assert_authentication(first.clone());
        assert_eq!(first, second);
    }

    #[test]
    fn handler_set_none_clears_authentication() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        auth_handler.set_authentication(None);

        assert!(auth_handler.authentication().is_none());
    }

    #[test]
    fn deserialize_invalid_bytes_returns_none() {
        // An empty payload cannot hold a MessagePack value.
        assert!(AuthenticationStreamSerializer::deserialize(&[]).is_none());
        // 0xc1 is the marker MessagePack reserves as "never used".
        assert!(AuthenticationStreamSerializer::deserialize(&[0xc1]).is_none());
    }

    fn assert_authentication(auth: Option<AuthenticationData>) {
        assert!(auth.is_some());
        let unwrapped = auth.unwrap();
        assert_eq!(unwrapped.principal, Some(PRINCIPAL.to_string()));
        assert_eq!(unwrapped.client_id, Some(CLIENT_ID.to_string()));
        assert_eq!(unwrapped.client_name, Some(CLIENT_NAME.to_string()));
        assert_eq!(
            unwrapped.properties.as_object().unwrap().get(KEY_1),
            Some(&Value::Bool(true))
        );
        assert_eq!(
            unwrapped.properties.as_object().unwrap().get(KEY_2),
            Some(&Value::String(VALUE.to_string()))
        );
    }

    fn create_authentication() -> AuthenticationData {
        AuthenticationData {
            principal: Some(PRINCIPAL.to_string()),
            client_id: Some(CLIENT_ID.to_string()),
            client_name: Some(CLIENT_NAME.to_string()),
            properties: HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ])
            .into_value(),
        }
    }
}