openauth_plugins/custom_session/
mod.rs1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use http::{header, StatusCode};
8use openauth_core::api::{ApiRequest, ApiResponse};
9use openauth_core::context::AuthContext;
10use openauth_core::error::OpenAuthError;
11use openauth_core::plugin::{AuthPlugin, PluginAfterHookAction, PluginAfterHookFuture};
12use serde::Serialize;
13use serde_json::Value;
14
/// Identifier this plugin is registered under (passed to `AuthPlugin::new`).
pub const UPSTREAM_PLUGIN_ID: &str = "custom-session";
16
17#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
19pub struct CustomSessionOptions {
20 pub should_mutate_list_device_sessions_endpoint: bool,
21}
22
23#[derive(Debug, Clone, PartialEq)]
25pub struct CustomSessionInput {
26 pub user: Value,
27 pub session: Value,
28}
29
30#[derive(Clone, Copy)]
32pub struct CustomSessionContext<'a> {
33 pub auth_context: &'a AuthContext,
34 pub request: &'a ApiRequest,
35}
36
37pub type CustomSessionFuture<'a> =
38 Pin<Box<dyn Future<Output = Result<Value, OpenAuthError>> + Send + 'a>>;
39
40type CustomSessionHandler = Arc<
41 dyn for<'a> Fn(CustomSessionInput, CustomSessionContext<'a>) -> CustomSessionFuture<'a>
42 + Send
43 + Sync,
44>;
45
46pub fn custom_session<F>(handler: F) -> AuthPlugin
48where
49 F: Fn(CustomSessionInput) -> CustomSessionFuture<'static> + Send + Sync + 'static,
50{
51 custom_session_with_options(handler, CustomSessionOptions::default())
52}
53
54pub fn custom_session_with_options<F>(handler: F, options: CustomSessionOptions) -> AuthPlugin
56where
57 F: Fn(CustomSessionInput) -> CustomSessionFuture<'static> + Send + Sync + 'static,
58{
59 custom_session_with_context_and_options(move |input, _context| handler(input), options)
60}
61
62pub fn custom_session_with_context<F>(handler: F) -> AuthPlugin
64where
65 F: for<'a> Fn(CustomSessionInput, CustomSessionContext<'a>) -> CustomSessionFuture<'a>
66 + Send
67 + Sync
68 + 'static,
69{
70 custom_session_with_context_and_options(handler, CustomSessionOptions::default())
71}
72
73pub fn custom_session_with_context_and_options<F>(
75 handler: F,
76 options: CustomSessionOptions,
77) -> AuthPlugin
78where
79 F: for<'a> Fn(CustomSessionInput, CustomSessionContext<'a>) -> CustomSessionFuture<'a>
80 + Send
81 + Sync
82 + 'static,
83{
84 let handler: CustomSessionHandler = Arc::new(handler);
85 let mut plugin = AuthPlugin::new(UPSTREAM_PLUGIN_ID)
86 .with_version(env!("CARGO_PKG_VERSION"))
87 .with_options(serde_json::to_value(options).unwrap_or(Value::Null))
88 .with_async_after_hook("/get-session", {
89 let handler = Arc::clone(&handler);
90 move |context, request, response| {
91 transform_get_session_response(&handler, context, request, response)
92 }
93 });
94
95 if options.should_mutate_list_device_sessions_endpoint {
96 plugin = plugin.with_async_after_hook("/multi-session/list-device-sessions", {
97 let handler = Arc::clone(&handler);
98 move |context, request, response| {
99 transform_list_device_sessions_response(&handler, context, request, response)
100 }
101 });
102 }
103
104 plugin
105}
106
107fn transform_get_session_response<'a>(
108 handler: &CustomSessionHandler,
109 auth_context: &'a AuthContext,
110 request: &'a ApiRequest,
111 response: ApiResponse,
112) -> PluginAfterHookFuture<'a> {
113 let handler = Arc::clone(handler);
114 Box::pin(async move {
115 if response.status() != StatusCode::OK {
116 return Ok(PluginAfterHookAction::Continue(response));
117 }
118 let (parts, body) = response.into_parts();
119 let value = response_json(&body)?;
120 if value.is_null() {
121 return Ok(PluginAfterHookAction::Continue(ApiResponse::from_parts(
122 parts, body,
123 )));
124 }
125 let input = custom_session_input(value)?;
126 let custom = handler(
127 input,
128 CustomSessionContext {
129 auth_context,
130 request,
131 },
132 )
133 .await?;
134 Ok(PluginAfterHookAction::Continue(json_response(
135 parts, &custom,
136 )?))
137 })
138}
139
140fn transform_list_device_sessions_response<'a>(
141 handler: &CustomSessionHandler,
142 auth_context: &'a AuthContext,
143 request: &'a ApiRequest,
144 response: ApiResponse,
145) -> PluginAfterHookFuture<'a> {
146 let handler = Arc::clone(handler);
147 Box::pin(async move {
148 if response.status() != StatusCode::OK {
149 return Ok(PluginAfterHookAction::Continue(response));
150 }
151 let (parts, body) = response.into_parts();
152 let value = response_json(&body)?;
153 let Some(sessions) = value.as_array() else {
154 return Err(OpenAuthError::Api(
155 "custom-session expected list-device-sessions response to be an array".to_owned(),
156 ));
157 };
158 let mut custom_sessions = Vec::with_capacity(sessions.len());
159 for session in sessions {
160 let input = custom_session_input(session.clone())?;
161 custom_sessions.push(
162 handler(
163 input,
164 CustomSessionContext {
165 auth_context,
166 request,
167 },
168 )
169 .await?,
170 );
171 }
172 Ok(PluginAfterHookAction::Continue(json_response(
173 parts,
174 &Value::Array(custom_sessions),
175 )?))
176 })
177}
178
179fn custom_session_input(value: Value) -> Result<CustomSessionInput, OpenAuthError> {
180 let Value::Object(mut object) = value else {
181 return Err(OpenAuthError::Api(
182 "custom-session expected session response to be an object".to_owned(),
183 ));
184 };
185 let Some(user) = object.remove("user") else {
186 return Err(OpenAuthError::Api(
187 "custom-session expected session response to include user".to_owned(),
188 ));
189 };
190 let Some(session) = object.remove("session") else {
191 return Err(OpenAuthError::Api(
192 "custom-session expected session response to include session".to_owned(),
193 ));
194 };
195 Ok(CustomSessionInput { user, session })
196}
197
198fn response_json(body: &[u8]) -> Result<Value, OpenAuthError> {
199 serde_json::from_slice(body).map_err(|error| OpenAuthError::Api(error.to_string()))
200}
201
202fn json_response(
203 mut parts: http::response::Parts,
204 body: &Value,
205) -> Result<ApiResponse, OpenAuthError> {
206 parts.headers.insert(
207 header::CONTENT_TYPE,
208 http::HeaderValue::from_static("application/json"),
209 );
210 parts.headers.remove(header::CONTENT_LENGTH);
211 let body = serde_json::to_vec(body).map_err(|error| OpenAuthError::Api(error.to_string()))?;
212 Ok(ApiResponse::from_parts(parts, body))
213}