Skip to main content

pdk_core/policy_context/
authentication.rs

// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

//! Utilities to access and share authentication data between filters.
6
7use std::collections::HashMap;
8use std::convert::Infallible;
9
10use crate::policy_context::AUTHENTICATION_PROPERTY;
11use classy::extract::context::FilterContext;
12use classy::extract::{extractability, Extract, FromContext};
13use classy::hl::StreamProperties;
14use classy::stream::PropertyAccessor;
15use log::warn;
16use pdk_script::{AuthenticationBinding, IntoValue, Value};
17use rmp_serde::Serializer;
18use serde::{Deserialize, Serialize};
19
20/// Trait to access and share authentication data between filters.
21pub trait AuthenticationHandler {
22    /// Get the current data regarding authentication.
23    fn authentication(&self) -> Option<AuthenticationData>;
24    /// Replace the authentication data.
25    fn set_authentication(&self, authentication: Option<&AuthenticationData>);
26}
27
28/// Default implementation of the [`AuthenticationHandler`] trait.
29pub struct Authentication {
30    property_accessor: Box<dyn PropertyAccessor>,
31}
32
33impl FromContext<FilterContext, extractability::Transitive> for Authentication {
34    type Error = Infallible;
35
36    fn from_context(context: &FilterContext) -> Result<Self, Self::Error> {
37        let stream: StreamProperties = context.extract()?;
38        Ok(Authentication::new(stream))
39    }
40}
41
42impl Authentication {
43    pub fn new<K: PropertyAccessor + 'static>(property_accessor: K) -> Self {
44        Self {
45            property_accessor: Box::new(property_accessor),
46        }
47    }
48}
49
50impl AuthenticationHandler for Authentication {
51    fn authentication(&self) -> Option<AuthenticationData> {
52        let bytes = self
53            .property_accessor
54            .read_property(AUTHENTICATION_PROPERTY)?;
55        AuthenticationStreamSerializer::deserialize(bytes.as_slice())
56    }
57
58    fn set_authentication(&self, authentication: Option<&AuthenticationData>) {
59        let bytes = authentication.and_then(AuthenticationStreamSerializer::serialize);
60        self.property_accessor
61            .set_property(AUTHENTICATION_PROPERTY, bytes.as_deref());
62    }
63}
64
65#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
66/// The data regarding the authentication.
67pub struct AuthenticationData {
68    /// The main value used for authenticated.
69    pub principal: Option<String>,
70    /// The client id associated to the authentication method used.
71    pub client_id: Option<String>,
72    /// The name id associated to the authentication method used.
73    pub client_name: Option<String>,
74    /// Additional properties of the authentication.
75    pub properties: Value,
76}
77
78impl AuthenticationData {
79    pub fn new<K: IntoValue>(
80        principal: Option<String>,
81        client_id: Option<String>,
82        client_name: Option<String>,
83        properties: K,
84    ) -> Self {
85        Self {
86            principal,
87            client_id,
88            client_name,
89            properties: properties.into_value(),
90        }
91    }
92}
93
94impl AuthenticationBinding for AuthenticationData {
95    fn client_id(&self) -> Option<String> {
96        self.client_id.clone()
97    }
98
99    fn client_name(&self) -> Option<String> {
100        self.client_name.clone()
101    }
102
103    fn principal(&self) -> Option<String> {
104        self.principal.clone()
105    }
106
107    fn properties(&self) -> Option<Value> {
108        Some(self.properties.clone())
109    }
110
111    fn properties_map(&self) -> HashMap<String, Value> {
112        self.properties()
113            .as_ref()
114            .and_then(Value::as_object)
115            .cloned()
116            .unwrap_or_default()
117    }
118}
119
/// Encodes and decodes [`AuthenticationData`] so that it can be propagated between policies.
///
/// MessagePack is the chosen wire format: because it is cross-language, filters written in
/// any language can exchange the data.
struct AuthenticationStreamSerializer;
124
125impl AuthenticationStreamSerializer {
126    pub fn deserialize(bytes: &[u8]) -> Option<AuthenticationData> {
127        match rmp_serde::decode::from_read(bytes) {
128            Ok(authentication) => Some(authentication),
129            Err(err) => {
130                warn!("Unexpected error deserializing Authentication object: {err}");
131                None
132            }
133        }
134    }
135
136    pub fn serialize(authentication: &AuthenticationData) -> Option<Vec<u8>> {
137        let mut buf = Vec::new();
138        let result = authentication.serialize(&mut Serializer::new(&mut buf));
139
140        match result {
141            Ok(_) => Some(buf),
142            Err(err) => {
143                warn!("Unexpected error serializing Authentication object: {err}");
144                None
145            }
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::collections::HashMap;

    use super::*;
    use classy::proxy_wasm::types::Bytes;

    const KEY_1: &str = "key1";
    const KEY_2: &str = "key2";

    const VALUE: &str = "value2";

    const PRINCIPAL: &str = "principal";
    const CLIENT_ID: &str = "client_id";
    const CLIENT_NAME: &str = "client_name";

    /// In-memory [`PropertyAccessor`] keyed by the full property path.
    #[derive(Default)]
    struct MockPropertyAccessor {
        properties: RefCell<HashMap<Vec<String>, Option<Bytes>>>,
    }

    impl MockPropertyAccessor {
        /// Converts a borrowed property path into an owned map key.
        fn path_key(path: &[&str]) -> Vec<String> {
            path.iter().map(|segment| segment.to_string()).collect()
        }
    }

    impl PropertyAccessor for MockPropertyAccessor {
        fn read_property(&self, path: &[&str]) -> Option<Bytes> {
            // Borrow rather than `take()`: taking would replace the whole map with an empty
            // one, so every read after the first would see no properties at all.
            self.properties
                .borrow()
                .get(&Self::path_key(path))
                .cloned()
                .flatten()
        }

        fn set_property(&self, path: &[&str], value: Option<&[u8]>) {
            self.properties
                .borrow_mut()
                .insert(Self::path_key(path), value.map(Bytes::from));
        }
    }

    #[test]
    fn serialize_and_deserialize_authentication_to_bytes() {
        let auth = create_authentication();
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        auth_handler.set_authentication(Some(&auth));
        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        match auth.unwrap().properties {
            Value::Object(obj) => assert_eq!(obj.len(), 2),
            _ => panic!(),
        }
    }

    #[test]
    fn handler_get_empty() {
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        let auth = auth_handler.authentication();

        assert!(auth.is_none())
    }

    #[test]
    fn handler_new_authentication_creates_auth_when_no_previous_data() {
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        let new_auth = AuthenticationData::new(
            Some(PRINCIPAL.to_string()),
            Some(CLIENT_ID.to_string()),
            Some(CLIENT_NAME.to_string()),
            HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ]),
        );

        auth_handler.set_authentication(Some(&new_auth));

        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        assert_eq!(new_auth, auth.unwrap());
    }

    #[test]
    fn handler_reads_are_repeatable() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        // Reading must not consume the stored property.
        assert_authentication(auth_handler.authentication());
        assert_authentication(auth_handler.authentication());
    }

    #[test]
    fn handler_set_none_clears_authentication() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        auth_handler.set_authentication(None);

        assert!(auth_handler.authentication().is_none());
    }

    fn assert_authentication(auth: Option<AuthenticationData>) {
        let unwrapped = auth.expect("authentication data should be present");
        assert_eq!(unwrapped.principal, Some(PRINCIPAL.to_string()));
        assert_eq!(unwrapped.client_id, Some(CLIENT_ID.to_string()));
        assert_eq!(unwrapped.client_name, Some(CLIENT_NAME.to_string()));
        assert_eq!(
            unwrapped.properties.as_object().unwrap().get(KEY_1),
            Some(&Value::Bool(true))
        );
        assert_eq!(
            unwrapped.properties.as_object().unwrap().get(KEY_2),
            Some(&Value::String(VALUE.to_string()))
        );
    }

    fn create_authentication() -> AuthenticationData {
        AuthenticationData {
            principal: Some(PRINCIPAL.to_string()),
            client_id: Some(CLIENT_ID.to_string()),
            client_name: Some(CLIENT_NAME.to_string()),
            properties: HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ])
            .into_value(),
        }
    }
}