pdk_core/policy_context/
authentication.rs1use std::collections::HashMap;
8use std::convert::Infallible;
9
10use crate::policy_context::AUTHENTICATION_PROPERTY;
11use classy::extract::context::FilterContext;
12use classy::extract::{Extract, FromContext};
13use classy::hl::context::{RequestContext, ResponseContext};
14use classy::hl::StreamProperties;
15use classy::stream::PropertyAccessor;
16use log::warn;
17use pdk_script::{AuthenticationBinding, IntoValue, Value};
18use rmp_serde::Serializer;
19use serde::{Deserialize, Serialize};
20
21pub trait AuthenticationHandler {
23 fn authentication(&self) -> Option<AuthenticationData>;
25 fn set_authentication(&self, authentication: Option<&AuthenticationData>);
27}
28
/// [`AuthenticationHandler`] that keeps the authentication data in the stream property
/// [`AUTHENTICATION_PROPERTY`], read and written through a [`PropertyAccessor`].
pub struct Authentication {
    /// Type-erased accessor used to read and write the authentication property.
    property_accessor: Box<dyn PropertyAccessor>,
}
33
34impl FromContext<FilterContext> for Authentication {
35 type Error = Infallible;
36
37 fn from_context(context: &FilterContext) -> Result<Self, Self::Error> {
38 let stream: StreamProperties = context.extract()?;
39 Ok(Authentication::new(stream))
40 }
41}
42
43impl FromContext<RequestContext> for Authentication {
44 type Error = Infallible;
45
46 fn from_context(context: &RequestContext) -> Result<Self, Self::Error> {
47 let stream: StreamProperties = context.extract()?;
48 Ok(Authentication::new(stream))
49 }
50}
51
52impl<C> FromContext<ResponseContext<C>> for Authentication {
53 type Error = Infallible;
54
55 fn from_context(context: &ResponseContext<C>) -> Result<Self, Self::Error> {
56 let stream: StreamProperties = context.extract()?;
57 Ok(Authentication::new(stream))
58 }
59}
60
61impl Authentication {
62 pub fn new<K: PropertyAccessor + 'static>(property_accessor: K) -> Self {
63 Self {
64 property_accessor: Box::new(property_accessor),
65 }
66 }
67}
68
69impl AuthenticationHandler for Authentication {
70 fn authentication(&self) -> Option<AuthenticationData> {
71 let bytes = self
72 .property_accessor
73 .read_property(AUTHENTICATION_PROPERTY)?;
74 AuthenticationStreamSerializer::deserialize(bytes.as_slice())
75 }
76
77 fn set_authentication(&self, authentication: Option<&AuthenticationData>) {
78 let bytes = authentication.and_then(AuthenticationStreamSerializer::serialize);
79 self.property_accessor
80 .set_property(AUTHENTICATION_PROPERTY, bytes.as_deref());
81 }
82}
83
/// Authentication information attached to a stream.
///
/// NOTE(review): this struct is encoded with rmp_serde's default `Serializer`, which
/// presumably writes structs positionally. Reordering or inserting fields would then
/// break decoding of data written by other components — confirm before changing the
/// layout.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AuthenticationData {
    /// Authenticated principal, if any.
    pub principal: Option<String>,
    /// Identifier of the authenticated client, if any.
    pub client_id: Option<String>,
    /// Name of the authenticated client, if any.
    pub client_name: Option<String>,
    /// Additional authentication properties. `properties_map` only yields entries
    /// when this value is an object.
    pub properties: Value,
}
96
97impl AuthenticationData {
98 pub fn new<K: IntoValue>(
99 principal: Option<String>,
100 client_id: Option<String>,
101 client_name: Option<String>,
102 properties: K,
103 ) -> Self {
104 Self {
105 principal,
106 client_id,
107 client_name,
108 properties: properties.into_value(),
109 }
110 }
111}
112
113impl AuthenticationBinding for AuthenticationData {
114 fn client_id(&self) -> Option<String> {
115 self.client_id.clone()
116 }
117
118 fn client_name(&self) -> Option<String> {
119 self.client_name.clone()
120 }
121
122 fn principal(&self) -> Option<String> {
123 self.principal.clone()
124 }
125
126 fn properties(&self) -> Option<Value> {
127 Some(self.properties.clone())
128 }
129
130 fn properties_map(&self) -> HashMap<String, Value> {
131 self.properties()
132 .as_ref()
133 .and_then(Value::as_object)
134 .cloned()
135 .unwrap_or_default()
136 }
137}
138
139struct AuthenticationStreamSerializer;
143
144impl AuthenticationStreamSerializer {
145 pub fn deserialize(bytes: &[u8]) -> Option<AuthenticationData> {
146 match rmp_serde::decode::from_read(bytes) {
147 Ok(authentication) => Some(authentication),
148 Err(err) => {
149 warn!("Unexpected error deserializing Authentication object: {err}");
150 None
151 }
152 }
153 }
154
155 pub fn serialize(authentication: &AuthenticationData) -> Option<Vec<u8>> {
156 let mut buf = Vec::new();
157 let result = authentication.serialize(&mut Serializer::new(&mut buf));
158
159 match result {
160 Ok(_) => Some(buf),
161 Err(err) => {
162 warn!("Unexpected error serializing Authentication object: {err}");
163 None
164 }
165 }
166 }
167}
168
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::collections::HashMap;

    use super::*;
    use classy::proxy_wasm::types::Bytes;

    const KEY_1: &str = "key1";
    const KEY_2: &str = "key2";

    const VALUE: &str = "value2";

    const PRINCIPAL: &str = "principal";
    const CLIENT_ID: &str = "client_id";
    const CLIENT_NAME: &str = "client_name";

    /// In-memory `PropertyAccessor` keyed by property path.
    ///
    /// A `None` entry records that the property was explicitly written with no value.
    #[derive(Default)]
    struct MockPropertyAccessor {
        properties: RefCell<HashMap<Vec<String>, Option<Bytes>>>,
    }

    /// Converts a property path into an owned map key.
    fn property_key(path: &[&str]) -> Vec<String> {
        path.iter().map(|segment| segment.to_string()).collect()
    }

    impl PropertyAccessor for MockPropertyAccessor {
        fn read_property(&self, path: &[&str]) -> Option<Bytes> {
            // Borrow instead of `take()`: taking would empty the whole map, so every
            // read after the first would see no properties at all.
            self.properties
                .borrow()
                .get(&property_key(path))
                .cloned()
                .flatten()
        }

        fn set_property(&self, path: &[&str], value: Option<&[u8]>) {
            self.properties
                .borrow_mut()
                .insert(property_key(path), value.map(Bytes::from));
        }
    }

    #[test]
    fn serialize_and_deserialize_authentication_to_bytes() {
        let auth = create_authentication();
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        auth_handler.set_authentication(Some(&auth));
        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        match auth.unwrap().properties {
            Value::Object(obj) => assert_eq!(obj.len(), 2),
            other => panic!("expected object properties, got {other:?}"),
        }
    }

    #[test]
    fn handler_get_empty() {
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        let auth = auth_handler.authentication();

        assert!(auth.is_none())
    }

    #[test]
    fn handler_new_authentication_creates_auth_when_no_previous_data() {
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        let new_auth = AuthenticationData::new(
            Some(PRINCIPAL.to_string()),
            Some(CLIENT_ID.to_string()),
            Some(CLIENT_NAME.to_string()),
            HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ]),
        );

        auth_handler.set_authentication(Some(&new_auth));

        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        assert_eq!(new_auth, auth.unwrap());
    }

    #[test]
    fn handler_reads_authentication_repeatedly() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        // Reading must not consume the stored property.
        assert_authentication(auth_handler.authentication());
        assert_authentication(auth_handler.authentication());
    }

    #[test]
    fn handler_set_none_clears_authentication() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        auth_handler.set_authentication(None);

        assert!(auth_handler.authentication().is_none());
    }

    #[test]
    fn deserialize_invalid_bytes_returns_none() {
        assert!(AuthenticationStreamSerializer::deserialize(&[]).is_none());
    }

    #[test]
    fn properties_map_returns_object_entries() {
        let properties = create_authentication().properties_map();

        assert_eq!(properties.len(), 2);
        assert_eq!(properties.get(KEY_1), Some(&Value::Bool(true)));
        assert_eq!(
            properties.get(KEY_2),
            Some(&Value::String(VALUE.to_string()))
        );
    }

    /// Asserts that `auth` is present and matches the fixture built by
    /// `create_authentication`.
    fn assert_authentication(auth: Option<AuthenticationData>) {
        assert!(auth.is_some());
        let unwrapped = auth.unwrap();
        assert_eq!(unwrapped.principal, Some(PRINCIPAL.to_string()));
        assert_eq!(unwrapped.client_id, Some(CLIENT_ID.to_string()));
        assert_eq!(unwrapped.client_name, Some(CLIENT_NAME.to_string()));
        assert_eq!(
            unwrapped.properties.as_object().unwrap().get(KEY_1),
            Some(&Value::Bool(true))
        );
        assert_eq!(
            unwrapped.properties.as_object().unwrap().get(KEY_2),
            Some(&Value::String(VALUE.to_string()))
        );
    }

    /// Builds the fixture used across tests: all identity fields set and two properties.
    fn create_authentication() -> AuthenticationData {
        AuthenticationData {
            principal: Some(PRINCIPAL.to_string()),
            client_id: Some(CLIENT_ID.to_string()),
            client_name: Some(CLIENT_NAME.to_string()),
            properties: HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ])
            .into_value(),
        }
    }
}