Skip to main content

openauth_core/api/
endpoint.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use http::{header, Request, Response, StatusCode};
6use serde_json::Value;
7
8use crate::context::AuthContext;
9use crate::error::OpenAuthError;
10
11use super::body::parse_request_body;
12use super::error::ApiErrorResponse;
13use super::openapi::OpenApiOperation;
14use super::schema::BodySchema;
15
16pub type Body = Vec<u8>;
17pub type ApiRequest = Request<Body>;
18pub type ApiResponse = Response<Body>;
19pub type EndpointHandler = fn(&AuthContext, ApiRequest) -> Result<ApiResponse, OpenAuthError>;
20pub type EndpointFuture<'a> =
21    Pin<Box<dyn Future<Output = Result<ApiResponse, OpenAuthError>> + Send + 'a>>;
22pub type AsyncEndpointHandler =
23    Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync>;
24pub type EndpointMiddlewareFuture<'a> =
25    Pin<Box<dyn Future<Output = Result<Option<ApiResponse>, OpenAuthError>> + Send + 'a>>;
26pub type EndpointMiddlewareHandler = Arc<
27    dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a> + Send + Sync,
28>;
29
30#[derive(Clone)]
31pub struct EndpointMiddleware {
32    pub handler: EndpointMiddlewareHandler,
33}
34
35impl EndpointMiddleware {
36    pub fn new<F>(handler: F) -> Self
37    where
38        F: for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a>
39            + Send
40            + Sync
41            + 'static,
42    {
43        Self {
44            handler: Arc::new(handler),
45        }
46    }
47}
48
49#[derive(Clone, Default)]
50pub struct AuthEndpointOptions {
51    pub operation_id: Option<String>,
52    pub allowed_media_types: Vec<String>,
53    pub body_schema: Option<BodySchema>,
54    pub middlewares: Vec<EndpointMiddleware>,
55    pub openapi: Option<OpenApiOperation>,
56    pub server_only: bool,
57    pub hide_from_openapi: bool,
58    pub bypass_origin_security: bool,
59}
60
61impl AuthEndpointOptions {
62    pub fn new() -> Self {
63        Self::default()
64    }
65
66    #[must_use]
67    pub fn operation_id(mut self, operation_id: impl Into<String>) -> Self {
68        self.operation_id = Some(operation_id.into());
69        self
70    }
71
72    #[must_use]
73    pub fn allowed_media_types<I, S>(mut self, media_types: I) -> Self
74    where
75        I: IntoIterator<Item = S>,
76        S: Into<String>,
77    {
78        self.allowed_media_types = media_types.into_iter().map(Into::into).collect();
79        self
80    }
81
82    #[must_use]
83    pub fn body_schema(mut self, schema: BodySchema) -> Self {
84        self.body_schema = Some(schema);
85        self
86    }
87
88    #[must_use]
89    pub fn middleware(mut self, middleware: EndpointMiddleware) -> Self {
90        self.middlewares.push(middleware);
91        self
92    }
93
94    #[must_use]
95    pub fn openapi(mut self, operation: OpenApiOperation) -> Self {
96        self.openapi = Some(operation);
97        self
98    }
99
100    #[must_use]
101    pub fn server_only(mut self) -> Self {
102        self.server_only = true;
103        self
104    }
105
106    #[must_use]
107    pub fn hide_from_openapi(mut self) -> Self {
108        self.hide_from_openapi = true;
109        self
110    }
111
112    #[must_use]
113    pub fn bypass_origin_security(mut self) -> Self {
114        self.bypass_origin_security = true;
115        self
116    }
117}
118
/// Distinguishes the two handler flavours an endpoint can have.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EndpointKind {
    /// Handler that produces its response without awaiting.
    Sync,
    /// Handler that produces its response through a future.
    Async,
}
124
125#[derive(Debug, Clone, PartialEq, Eq)]
126pub struct EndpointInfo {
127    pub path: String,
128    pub method: http::Method,
129    pub kind: EndpointKind,
130    pub operation_id: Option<String>,
131    pub allowed_media_types: Vec<String>,
132}
133
/// A synchronous endpoint: a route path, an HTTP method, and a plain function-pointer handler.
///
/// Unlike [`AsyncAuthEndpoint`], it carries no per-endpoint options.
#[derive(Clone)]
pub struct AuthEndpoint {
    /// Route path of the endpoint.
    pub path: String,
    /// HTTP method of the endpoint.
    pub method: http::Method,
    /// Function pointer called with the auth context and the owned request.
    pub handler: EndpointHandler,
}
140
141#[derive(Clone)]
142pub struct AsyncAuthEndpoint {
143    pub path: String,
144    pub method: http::Method,
145    pub handler: AsyncEndpointHandler,
146    pub options: AuthEndpointOptions,
147}
148
149impl AsyncAuthEndpoint {
150    pub fn new<F>(path: impl Into<String>, method: http::Method, handler: F) -> Self
151    where
152        F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
153    {
154        Self {
155            path: path.into(),
156            method,
157            handler: Arc::new(handler),
158            options: AuthEndpointOptions::default(),
159        }
160    }
161}
162
163pub fn create_auth_endpoint<F>(
164    path: impl Into<String>,
165    method: http::Method,
166    options: AuthEndpointOptions,
167    handler: F,
168) -> AsyncAuthEndpoint
169where
170    F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
171{
172    AsyncAuthEndpoint {
173        path: path.into(),
174        method,
175        handler: Arc::new(handler),
176        options,
177    }
178}
179
180pub(super) fn validate_async_endpoint_request(
181    endpoint: &AsyncAuthEndpoint,
182    request: &ApiRequest,
183) -> Result<Option<ApiResponse>, OpenAuthError> {
184    if endpoint.options.allowed_media_types.is_empty() && endpoint.options.body_schema.is_none() {
185        return Ok(None);
186    }
187
188    let content_type = request
189        .headers()
190        .get(header::CONTENT_TYPE)
191        .and_then(|value| value.to_str().ok())
192        .and_then(|value| value.split(';').next())
193        .map(str::trim)
194        .filter(|value| !value.is_empty());
195
196    if !endpoint.options.allowed_media_types.is_empty() {
197        let Some(content_type) = content_type else {
198            return invalid_request_response(
199                StatusCode::UNSUPPORTED_MEDIA_TYPE,
200                "UNSUPPORTED_MEDIA_TYPE",
201                "Missing Content-Type",
202            )
203            .map(Some);
204        };
205        if !endpoint
206            .options
207            .allowed_media_types
208            .iter()
209            .any(|allowed| allowed.eq_ignore_ascii_case(content_type))
210        {
211            return invalid_request_response(
212                StatusCode::UNSUPPORTED_MEDIA_TYPE,
213                "UNSUPPORTED_MEDIA_TYPE",
214                "Unsupported Content-Type",
215            )
216            .map(Some);
217        }
218    }
219
220    if let Some(schema) = &endpoint.options.body_schema {
221        let body = match parse_request_body::<Value>(request) {
222            Ok(body) => body,
223            Err(error) => {
224                return invalid_request_response(
225                    StatusCode::BAD_REQUEST,
226                    "INVALID_REQUEST_BODY",
227                    &error.to_string(),
228                )
229                .map(Some);
230            }
231        };
232        if let Err(message) = schema.validate(&body) {
233            return invalid_request_response(
234                StatusCode::BAD_REQUEST,
235                "INVALID_REQUEST_BODY",
236                &message,
237            )
238            .map(Some);
239        }
240    }
241
242    Ok(None)
243}
244
245pub(super) async fn run_endpoint_middlewares(
246    context: &AuthContext,
247    endpoint: &AsyncAuthEndpoint,
248    request: &ApiRequest,
249) -> Result<Option<ApiResponse>, OpenAuthError> {
250    for middleware in &endpoint.options.middlewares {
251        if let Some(response) = (middleware.handler)(context, request).await? {
252            return Ok(Some(response));
253        }
254    }
255    Ok(None)
256}
257
258fn invalid_request_response(
259    status: StatusCode,
260    code: &str,
261    message: &str,
262) -> Result<ApiResponse, OpenAuthError> {
263    let body = serde_json::to_vec(&ApiErrorResponse {
264        code: code.to_owned(),
265        message: message.to_owned(),
266        original_message: None,
267    })
268    .map_err(|error| OpenAuthError::Api(error.to_string()))?;
269
270    Response::builder()
271        .status(status)
272        .header(header::CONTENT_TYPE, "application/json")
273        .body(body)
274        .map_err(|error| OpenAuthError::Api(error.to_string()))
275}