pdk_core/policy_context/
authentication.rs1use std::convert::Infallible;
8
9use crate::policy_context::AUTHENTICATION_PROPERTY;
10use classy::extract::context::FilterContext;
11use classy::extract::{Extract, FromContext};
12use classy::hl::context::{RequestContext, ResponseContext};
13use classy::hl::StreamProperties;
14use classy::stream::PropertyAccessor;
15use log::warn;
16use pdk_script::{AuthenticationBinding, IntoValue, Value};
17use rmp_serde::Serializer;
18use serde::{Deserialize, Serialize};
19
20pub trait AuthenticationHandler {
22 fn authentication(&self) -> Option<AuthenticationData>;
24 fn set_authentication(&self, authentication: Option<&AuthenticationData>);
26}
27
28pub struct Authentication {
30 property_accessor: Box<dyn PropertyAccessor>,
31}
32
33impl FromContext<FilterContext> for Authentication {
34 type Error = Infallible;
35
36 fn from_context(context: &FilterContext) -> Result<Self, Self::Error> {
37 let stream: StreamProperties = context.extract()?;
38 Ok(Authentication::new(stream))
39 }
40}
41
42impl FromContext<RequestContext> for Authentication {
43 type Error = Infallible;
44
45 fn from_context(context: &RequestContext) -> Result<Self, Self::Error> {
46 let stream: StreamProperties = context.extract()?;
47 Ok(Authentication::new(stream))
48 }
49}
50
51impl<C> FromContext<ResponseContext<C>> for Authentication {
52 type Error = Infallible;
53
54 fn from_context(context: &ResponseContext<C>) -> Result<Self, Self::Error> {
55 let stream: StreamProperties = context.extract()?;
56 Ok(Authentication::new(stream))
57 }
58}
59
60impl Authentication {
61 pub fn new<K: PropertyAccessor + 'static>(property_accessor: K) -> Self {
62 Self {
63 property_accessor: Box::new(property_accessor),
64 }
65 }
66}
67
68impl AuthenticationHandler for Authentication {
69 fn authentication(&self) -> Option<AuthenticationData> {
70 let bytes = self
71 .property_accessor
72 .read_property(AUTHENTICATION_PROPERTY)?;
73 AuthenticationStreamSerializer::deserialize(bytes.as_slice())
74 }
75
76 fn set_authentication(&self, authentication: Option<&AuthenticationData>) {
77 let bytes = authentication.and_then(AuthenticationStreamSerializer::serialize);
78 self.property_accessor
79 .set_property(AUTHENTICATION_PROPERTY, bytes.as_deref());
80 }
81}
82
83#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub struct AuthenticationData {
86 pub principal: Option<String>,
88 pub client_id: Option<String>,
90 pub client_name: Option<String>,
92 pub properties: Value,
94}
95
96impl AuthenticationData {
97 pub fn new<K: IntoValue>(
98 principal: Option<String>,
99 client_id: Option<String>,
100 client_name: Option<String>,
101 properties: K,
102 ) -> Self {
103 Self {
104 principal,
105 client_id,
106 client_name,
107 properties: properties.into_value(),
108 }
109 }
110}
111
112impl AuthenticationBinding for AuthenticationData {
113 fn client_id(&self) -> Option<String> {
114 self.client_id.clone()
115 }
116
117 fn client_name(&self) -> Option<String> {
118 self.client_name.clone()
119 }
120
121 fn principal(&self) -> Option<String> {
122 self.principal.clone()
123 }
124
125 fn properties(&self) -> Option<Value> {
126 Some(self.properties.clone())
127 }
128}
129
130struct AuthenticationStreamSerializer;
134
135impl AuthenticationStreamSerializer {
136 pub fn deserialize(bytes: &[u8]) -> Option<AuthenticationData> {
137 match rmp_serde::decode::from_read(bytes) {
138 Ok(authentication) => Some(authentication),
139 Err(err) => {
140 warn!("Unexpected error deserializing Authentication object: {err}");
141 None
142 }
143 }
144 }
145
146 pub fn serialize(authentication: &AuthenticationData) -> Option<Vec<u8>> {
147 let mut buf = Vec::new();
148 let result = authentication.serialize(&mut Serializer::new(&mut buf));
149
150 match result {
151 Ok(_) => Some(buf),
152 Err(err) => {
153 warn!("Unexpected error serializing Authentication object: {err}");
154 None
155 }
156 }
157 }
158}
159
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::collections::HashMap;

    use super::*;
    use classy::proxy_wasm::types::Bytes;

    const KEY_1: &str = "key1";
    const KEY_2: &str = "key2";

    const VALUE: &str = "value2";

    const PRINCIPAL: &str = "principal";
    const CLIENT_ID: &str = "client_id";
    const CLIENT_NAME: &str = "client_name";

    /// In-memory [`PropertyAccessor`] keyed by the full property path.
    #[derive(Default)]
    struct MockPropertyAccessor {
        properties: RefCell<HashMap<Vec<String>, Option<Bytes>>>,
    }

    impl MockPropertyAccessor {
        /// Converts a borrowed property path into an owned map key.
        fn key(path: &[&str]) -> Vec<String> {
            path.iter().copied().map(String::from).collect()
        }
    }

    impl PropertyAccessor for MockPropertyAccessor {
        fn read_property(&self, path: &[&str]) -> Option<Bytes> {
            // Borrow rather than `take()`: taking swaps the whole map for an empty one,
            // which would make every read after the first return `None`.
            self.properties
                .borrow()
                .get(&Self::key(path))
                .cloned()
                .flatten()
        }

        fn set_property(&self, path: &[&str], value: Option<&[u8]>) {
            self.properties
                .borrow_mut()
                .insert(Self::key(path), value.map(Bytes::from));
        }
    }

    #[test]
    fn serialize_and_deserialize_authentication_to_bytes() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());

        auth_handler.set_authentication(Some(&create_authentication()));
        let auth = auth_handler.authentication();

        assert_authentication(auth.as_ref());
        match auth.unwrap().properties {
            Value::Object(obj) => assert_eq!(obj.len(), 2),
            other => panic!("expected object properties, got {other:?}"),
        }
    }

    #[test]
    fn handler_get_empty() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());

        let auth = auth_handler.authentication();

        assert!(auth.is_none())
    }

    #[test]
    fn handler_new_authentication_creates_auth_when_no_previous_data() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());

        let new_auth = AuthenticationData::new(
            Some(PRINCIPAL.to_string()),
            Some(CLIENT_ID.to_string()),
            Some(CLIENT_NAME.to_string()),
            HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ]),
        );

        auth_handler.set_authentication(Some(&new_auth));

        let auth = auth_handler.authentication();

        assert_authentication(auth.as_ref());
        assert_eq!(auth, Some(new_auth));
    }

    #[test]
    fn handler_reading_authentication_does_not_consume_it() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        let first = auth_handler.authentication();
        let second = auth_handler.authentication();

        assert_authentication(first.as_ref());
        assert_eq!(first, second);
    }

    #[test]
    fn handler_set_none_clears_authentication() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        auth_handler.set_authentication(None);

        assert!(auth_handler.authentication().is_none());
    }

    /// Asserts that `auth` is present and matches the fixture from `create_authentication`.
    fn assert_authentication(auth: Option<&AuthenticationData>) {
        let auth = auth.expect("authentication should be present");
        assert_eq!(auth.principal.as_deref(), Some(PRINCIPAL));
        assert_eq!(auth.client_id.as_deref(), Some(CLIENT_ID));
        assert_eq!(auth.client_name.as_deref(), Some(CLIENT_NAME));

        let properties = auth
            .properties
            .as_object()
            .expect("properties should be an object");
        assert_eq!(properties.get(KEY_1), Some(&Value::Bool(true)));
        assert_eq!(
            properties.get(KEY_2),
            Some(&Value::String(VALUE.to_string()))
        );
    }

    /// Builds the reference fixture used by the round-trip tests.
    fn create_authentication() -> AuthenticationData {
        AuthenticationData {
            principal: Some(PRINCIPAL.to_string()),
            client_id: Some(CLIENT_ID.to_string()),
            client_name: Some(CLIENT_NAME.to_string()),
            properties: HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ])
            .into_value(),
        }
    }
}