Skip to main content

openauth_plugins/custom_session/
mod.rs

1//! Custom session plugin.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use http::{header, StatusCode};
8use openauth_core::api::{ApiRequest, ApiResponse};
9use openauth_core::context::AuthContext;
10use openauth_core::error::OpenAuthError;
11use openauth_core::plugin::{AuthPlugin, PluginAfterHookAction, PluginAfterHookFuture};
12use serde::Serialize;
13use serde_json::Value;
14
/// Plugin identifier registered with [`AuthPlugin::new`]; mirrors the upstream plugin id.
pub const UPSTREAM_PLUGIN_ID: &str = "custom-session";
16
17/// Options for the custom session plugin.
18#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
19pub struct CustomSessionOptions {
20    pub should_mutate_list_device_sessions_endpoint: bool,
21}
22
23/// Session payload passed to the custom session handler.
24#[derive(Debug, Clone, PartialEq)]
25pub struct CustomSessionInput {
26    pub user: Value,
27    pub session: Value,
28}
29
30/// Request context available to custom session handlers.
31#[derive(Clone, Copy)]
32pub struct CustomSessionContext<'a> {
33    pub auth_context: &'a AuthContext,
34    pub request: &'a ApiRequest,
35}
36
37pub type CustomSessionFuture<'a> =
38    Pin<Box<dyn Future<Output = Result<Value, OpenAuthError>> + Send + 'a>>;
39
40type CustomSessionHandler = Arc<
41    dyn for<'a> Fn(CustomSessionInput, CustomSessionContext<'a>) -> CustomSessionFuture<'a>
42        + Send
43        + Sync,
44>;
45
46/// Create a custom session plugin with default options.
47pub fn custom_session<F>(handler: F) -> AuthPlugin
48where
49    F: Fn(CustomSessionInput) -> CustomSessionFuture<'static> + Send + Sync + 'static,
50{
51    custom_session_with_options(handler, CustomSessionOptions::default())
52}
53
54/// Create a custom session plugin.
55pub fn custom_session_with_options<F>(handler: F, options: CustomSessionOptions) -> AuthPlugin
56where
57    F: Fn(CustomSessionInput) -> CustomSessionFuture<'static> + Send + Sync + 'static,
58{
59    custom_session_with_context_and_options(move |input, _context| handler(input), options)
60}
61
62/// Create a custom session plugin whose handler can inspect request context.
63pub fn custom_session_with_context<F>(handler: F) -> AuthPlugin
64where
65    F: for<'a> Fn(CustomSessionInput, CustomSessionContext<'a>) -> CustomSessionFuture<'a>
66        + Send
67        + Sync
68        + 'static,
69{
70    custom_session_with_context_and_options(handler, CustomSessionOptions::default())
71}
72
73/// Create a custom session plugin with options and request-aware handler.
74pub fn custom_session_with_context_and_options<F>(
75    handler: F,
76    options: CustomSessionOptions,
77) -> AuthPlugin
78where
79    F: for<'a> Fn(CustomSessionInput, CustomSessionContext<'a>) -> CustomSessionFuture<'a>
80        + Send
81        + Sync
82        + 'static,
83{
84    let handler: CustomSessionHandler = Arc::new(handler);
85    let mut plugin = AuthPlugin::new(UPSTREAM_PLUGIN_ID)
86        .with_version(env!("CARGO_PKG_VERSION"))
87        .with_options(serde_json::to_value(options).unwrap_or(Value::Null))
88        .with_async_after_hook("/get-session", {
89            let handler = Arc::clone(&handler);
90            move |context, request, response| {
91                transform_get_session_response(&handler, context, request, response)
92            }
93        });
94
95    if options.should_mutate_list_device_sessions_endpoint {
96        plugin = plugin.with_async_after_hook("/multi-session/list-device-sessions", {
97            let handler = Arc::clone(&handler);
98            move |context, request, response| {
99                transform_list_device_sessions_response(&handler, context, request, response)
100            }
101        });
102    }
103
104    plugin
105}
106
107fn transform_get_session_response<'a>(
108    handler: &CustomSessionHandler,
109    auth_context: &'a AuthContext,
110    request: &'a ApiRequest,
111    response: ApiResponse,
112) -> PluginAfterHookFuture<'a> {
113    let handler = Arc::clone(handler);
114    Box::pin(async move {
115        if response.status() != StatusCode::OK {
116            return Ok(PluginAfterHookAction::Continue(response));
117        }
118        let (parts, body) = response.into_parts();
119        let value = response_json(&body)?;
120        if value.is_null() {
121            return Ok(PluginAfterHookAction::Continue(ApiResponse::from_parts(
122                parts, body,
123            )));
124        }
125        let input = custom_session_input(value)?;
126        let custom = handler(
127            input,
128            CustomSessionContext {
129                auth_context,
130                request,
131            },
132        )
133        .await?;
134        Ok(PluginAfterHookAction::Continue(json_response(
135            parts, &custom,
136        )?))
137    })
138}
139
140fn transform_list_device_sessions_response<'a>(
141    handler: &CustomSessionHandler,
142    auth_context: &'a AuthContext,
143    request: &'a ApiRequest,
144    response: ApiResponse,
145) -> PluginAfterHookFuture<'a> {
146    let handler = Arc::clone(handler);
147    Box::pin(async move {
148        if response.status() != StatusCode::OK {
149            return Ok(PluginAfterHookAction::Continue(response));
150        }
151        let (parts, body) = response.into_parts();
152        let value = response_json(&body)?;
153        let Some(sessions) = value.as_array() else {
154            return Err(OpenAuthError::Api(
155                "custom-session expected list-device-sessions response to be an array".to_owned(),
156            ));
157        };
158        let mut custom_sessions = Vec::with_capacity(sessions.len());
159        for session in sessions {
160            let input = custom_session_input(session.clone())?;
161            custom_sessions.push(
162                handler(
163                    input,
164                    CustomSessionContext {
165                        auth_context,
166                        request,
167                    },
168                )
169                .await?,
170            );
171        }
172        Ok(PluginAfterHookAction::Continue(json_response(
173            parts,
174            &Value::Array(custom_sessions),
175        )?))
176    })
177}
178
179fn custom_session_input(value: Value) -> Result<CustomSessionInput, OpenAuthError> {
180    let Value::Object(mut object) = value else {
181        return Err(OpenAuthError::Api(
182            "custom-session expected session response to be an object".to_owned(),
183        ));
184    };
185    let Some(user) = object.remove("user") else {
186        return Err(OpenAuthError::Api(
187            "custom-session expected session response to include user".to_owned(),
188        ));
189    };
190    let Some(session) = object.remove("session") else {
191        return Err(OpenAuthError::Api(
192            "custom-session expected session response to include session".to_owned(),
193        ));
194    };
195    Ok(CustomSessionInput { user, session })
196}
197
198fn response_json(body: &[u8]) -> Result<Value, OpenAuthError> {
199    serde_json::from_slice(body).map_err(|error| OpenAuthError::Api(error.to_string()))
200}
201
202fn json_response(
203    mut parts: http::response::Parts,
204    body: &Value,
205) -> Result<ApiResponse, OpenAuthError> {
206    parts.headers.insert(
207        header::CONTENT_TYPE,
208        http::HeaderValue::from_static("application/json"),
209    );
210    parts.headers.remove(header::CONTENT_LENGTH);
211    let body = serde_json::to_vec(body).map_err(|error| OpenAuthError::Api(error.to_string()))?;
212    Ok(ApiResponse::from_parts(parts, body))
213}