Skip to main content

pdk_core/policy_context/
authentication.rs

1// Copyright (c) 2026, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5//! Utils to access and share authentication data between filters.
6
7use std::collections::HashMap;
8use std::convert::Infallible;
9
10use crate::policy_context::AUTHENTICATION_PROPERTY;
11use classy::extract::context::FilterContext;
12use classy::extract::{Extract, FromContext};
13use classy::hl::context::{RequestContext, ResponseContext};
14use classy::hl::StreamProperties;
15use classy::stream::PropertyAccessor;
16use log::warn;
17use pdk_script::{AuthenticationBinding, IntoValue, Value};
18use rmp_serde::Serializer;
19use serde::{Deserialize, Serialize};
20
/// Trait to access and share authentication data between filters.
pub trait AuthenticationHandler {
    /// Get the current data regarding authentication.
    ///
    /// Returns `None` when there is no authentication data to report. The
    /// [`Authentication`] implementation also returns `None` when the stored
    /// bytes cannot be deserialized.
    fn authentication(&self) -> Option<AuthenticationData>;
    /// Replace the authentication data.
    ///
    /// `authentication` is the new data to store. The [`Authentication`]
    /// implementation writes an empty (`None`) property value when `None` is
    /// passed or when serialization fails.
    fn set_authentication(&self, authentication: Option<&AuthenticationData>);
}
28
/// Default implementation of the [`AuthenticationHandler`] trait.
///
/// Authentication data is read from and written to the
/// `AUTHENTICATION_PROPERTY` property through the wrapped accessor.
pub struct Authentication {
    // Boxed so that any `PropertyAccessor` implementation can back the handler.
    property_accessor: Box<dyn PropertyAccessor>,
}
33
34impl FromContext<FilterContext> for Authentication {
35    type Error = Infallible;
36
37    fn from_context(context: &FilterContext) -> Result<Self, Self::Error> {
38        let stream: StreamProperties = context.extract()?;
39        Ok(Authentication::new(stream))
40    }
41}
42
43impl FromContext<RequestContext> for Authentication {
44    type Error = Infallible;
45
46    fn from_context(context: &RequestContext) -> Result<Self, Self::Error> {
47        let stream: StreamProperties = context.extract()?;
48        Ok(Authentication::new(stream))
49    }
50}
51
52impl<C> FromContext<ResponseContext<C>> for Authentication {
53    type Error = Infallible;
54
55    fn from_context(context: &ResponseContext<C>) -> Result<Self, Self::Error> {
56        let stream: StreamProperties = context.extract()?;
57        Ok(Authentication::new(stream))
58    }
59}
60
61impl Authentication {
62    pub fn new<K: PropertyAccessor + 'static>(property_accessor: K) -> Self {
63        Self {
64            property_accessor: Box::new(property_accessor),
65        }
66    }
67}
68
69impl AuthenticationHandler for Authentication {
70    fn authentication(&self) -> Option<AuthenticationData> {
71        let bytes = self
72            .property_accessor
73            .read_property(AUTHENTICATION_PROPERTY)?;
74        AuthenticationStreamSerializer::deserialize(bytes.as_slice())
75    }
76
77    fn set_authentication(&self, authentication: Option<&AuthenticationData>) {
78        let bytes = authentication.and_then(AuthenticationStreamSerializer::serialize);
79        self.property_accessor
80            .set_property(AUTHENTICATION_PROPERTY, bytes.as_deref());
81    }
82}
83
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
/// The data regarding the authentication.
///
/// This is the value that [`AuthenticationHandler`] implementations read and
/// replace; it derives `Serialize`/`Deserialize` so it can be stored as bytes.
pub struct AuthenticationData {
    /// The main value used for authentication.
    pub principal: Option<String>,
    /// The client id associated to the authentication method used.
    pub client_id: Option<String>,
    /// The client name associated to the authentication method used.
    pub client_name: Option<String>,
    /// Additional properties of the authentication.
    pub properties: Value,
}
96
97impl AuthenticationData {
98    pub fn new<K: IntoValue>(
99        principal: Option<String>,
100        client_id: Option<String>,
101        client_name: Option<String>,
102        properties: K,
103    ) -> Self {
104        Self {
105            principal,
106            client_id,
107            client_name,
108            properties: properties.into_value(),
109        }
110    }
111}
112
113impl AuthenticationBinding for AuthenticationData {
114    fn client_id(&self) -> Option<String> {
115        self.client_id.clone()
116    }
117
118    fn client_name(&self) -> Option<String> {
119        self.client_name.clone()
120    }
121
122    fn principal(&self) -> Option<String> {
123        self.principal.clone()
124    }
125
126    fn properties(&self) -> Option<Value> {
127        Some(self.properties.clone())
128    }
129
130    fn properties_map(&self) -> HashMap<String, Value> {
131        self.properties()
132            .as_ref()
133            .and_then(Value::as_object)
134            .cloned()
135            .unwrap_or_default()
136    }
137}
138
/// Serializes and deserializes [`AuthenticationData`] objects so that they can be
/// propagated between policies.
/// The chosen serialization format is MessagePack. Using a cross-language format allows
/// propagating the object between filters that were coded in any language.
struct AuthenticationStreamSerializer;
143
144impl AuthenticationStreamSerializer {
145    pub fn deserialize(bytes: &[u8]) -> Option<AuthenticationData> {
146        match rmp_serde::decode::from_read(bytes) {
147            Ok(authentication) => Some(authentication),
148            Err(err) => {
149                warn!("Unexpected error deserializing Authentication object: {err}");
150                None
151            }
152        }
153    }
154
155    pub fn serialize(authentication: &AuthenticationData) -> Option<Vec<u8>> {
156        let mut buf = Vec::new();
157        let result = authentication.serialize(&mut Serializer::new(&mut buf));
158
159        match result {
160            Ok(_) => Some(buf),
161            Err(err) => {
162                warn!("Unexpected error serializing Authentication object: {err}");
163                None
164            }
165        }
166    }
167}
168
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::collections::HashMap;

    use super::*;
    use classy::proxy_wasm::types::Bytes;

    const KEY_1: &str = "key1";
    const KEY_2: &str = "key2";

    const VALUE: &str = "value2";

    const PRINCIPAL: &str = "principal";
    const CLIENT_ID: &str = "client_id";
    const CLIENT_NAME: &str = "client_name";

    /// In-memory [`PropertyAccessor`] used to exercise [`Authentication`]
    /// without a real stream.
    #[derive(Default)]
    struct MockPropertyAccessor {
        properties: RefCell<HashMap<Vec<String>, Option<Bytes>>>,
    }

    impl MockPropertyAccessor {
        /// Converts a borrowed property path into the owned key used by the map.
        fn key(path: &[&str]) -> Vec<String> {
            path.iter().map(|segment| (*segment).to_owned()).collect()
        }
    }

    impl PropertyAccessor for MockPropertyAccessor {
        fn read_property(&self, path: &[&str]) -> Option<Bytes> {
            let key = Self::key(path);
            // Use a shared borrow: `RefCell::take` would move the map out and
            // leave it empty, wiping every stored property on each read.
            let stored = self.properties.borrow().get(&key).cloned();
            stored.unwrap_or_default()
        }

        fn set_property(&self, path: &[&str], value: Option<&[u8]>) {
            let key = Self::key(path);
            let bytes = value.map(Bytes::from);
            self.properties.borrow_mut().insert(key, bytes);
        }
    }

    #[test]
    fn serialize_and_deserialize_authentication_to_bytes() {
        let auth = create_authentication();
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        auth_handler.set_authentication(Some(&auth));
        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        match auth.unwrap().properties {
            Value::Object(obj) => assert_eq!(obj.len(), 2),
            _ => panic!("properties should be an object"),
        }
    }

    #[test]
    fn handler_get_empty() {
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        let auth = auth_handler.authentication();

        assert!(auth.is_none())
    }

    #[test]
    fn handler_new_authentication_creates_auth_when_no_previous_data() {
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        let new_auth = AuthenticationData::new(
            Some(PRINCIPAL.to_string()),
            Some(CLIENT_ID.to_string()),
            Some(CLIENT_NAME.to_string()),
            HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ]),
        );

        auth_handler.set_authentication(Some(&new_auth));

        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        assert_eq!(new_auth, auth.unwrap());
    }

    /// Regression test: reading the authentication must not consume it.
    #[test]
    fn handler_authentication_can_be_read_repeatedly() {
        let auth = create_authentication();
        let auth_handler = Authentication::new(MockPropertyAccessor::default());

        auth_handler.set_authentication(Some(&auth));

        let first = auth_handler.authentication();
        let second = auth_handler.authentication();

        assert_authentication(first.clone());
        assert_eq!(first, second);
    }

    /// Asserts that `auth` is present and carries the fixture values.
    fn assert_authentication(auth: Option<AuthenticationData>) {
        let unwrapped = auth.expect("authentication should be present");
        assert_eq!(unwrapped.principal, Some(PRINCIPAL.to_string()));
        assert_eq!(unwrapped.client_id, Some(CLIENT_ID.to_string()));
        assert_eq!(unwrapped.client_name, Some(CLIENT_NAME.to_string()));
        assert_eq!(
            unwrapped.properties.as_object().unwrap().get(KEY_1),
            Some(&Value::Bool(true))
        );
        assert_eq!(
            unwrapped.properties.as_object().unwrap().get(KEY_2),
            Some(&Value::String(VALUE.to_string()))
        );
    }

    /// Builds the fixture authentication shared by the tests.
    fn create_authentication() -> AuthenticationData {
        AuthenticationData {
            principal: Some(PRINCIPAL.to_string()),
            client_id: Some(CLIENT_ID.to_string()),
            client_name: Some(CLIENT_NAME.to_string()),
            properties: HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ])
            .into_value(),
        }
    }
}