Skip to main content

openauth_core/api/
endpoint.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use http::{header, Request, Response, StatusCode};
6use serde_json::Value;
7
8use crate::context::AuthContext;
9use crate::error::OpenAuthError;
10
11use super::body::parse_request_body;
12use super::error::ApiErrorResponse;
13use super::openapi::OpenApiOperation;
14use super::schema::BodySchema;
15
16pub type Body = Vec<u8>;
17pub type ApiRequest = Request<Body>;
18pub type ApiResponse = Response<Body>;
19pub type EndpointHandler = fn(&AuthContext, ApiRequest) -> Result<ApiResponse, OpenAuthError>;
20pub type EndpointFuture<'a> =
21    Pin<Box<dyn Future<Output = Result<ApiResponse, OpenAuthError>> + Send + 'a>>;
22pub type AsyncEndpointHandler =
23    Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync>;
24pub type EndpointMiddlewareFuture<'a> =
25    Pin<Box<dyn Future<Output = Result<Option<ApiResponse>, OpenAuthError>> + Send + 'a>>;
26pub type EndpointMiddlewareHandler = Arc<
27    dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a> + Send + Sync,
28>;
29
30#[derive(Clone)]
31pub struct EndpointMiddleware {
32    pub handler: EndpointMiddlewareHandler,
33}
34
35impl EndpointMiddleware {
36    pub fn new<F>(handler: F) -> Self
37    where
38        F: for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a>
39            + Send
40            + Sync
41            + 'static,
42    {
43        Self {
44            handler: Arc::new(handler),
45        }
46    }
47}
48
49#[derive(Clone, Default)]
50pub struct AuthEndpointOptions {
51    pub operation_id: Option<String>,
52    pub allowed_media_types: Vec<String>,
53    pub body_schema: Option<BodySchema>,
54    pub middlewares: Vec<EndpointMiddleware>,
55    pub openapi: Option<OpenApiOperation>,
56    pub server_only: bool,
57    pub hide_from_openapi: bool,
58}
59
60impl AuthEndpointOptions {
61    pub fn new() -> Self {
62        Self::default()
63    }
64
65    #[must_use]
66    pub fn operation_id(mut self, operation_id: impl Into<String>) -> Self {
67        self.operation_id = Some(operation_id.into());
68        self
69    }
70
71    #[must_use]
72    pub fn allowed_media_types<I, S>(mut self, media_types: I) -> Self
73    where
74        I: IntoIterator<Item = S>,
75        S: Into<String>,
76    {
77        self.allowed_media_types = media_types.into_iter().map(Into::into).collect();
78        self
79    }
80
81    #[must_use]
82    pub fn body_schema(mut self, schema: BodySchema) -> Self {
83        self.body_schema = Some(schema);
84        self
85    }
86
87    #[must_use]
88    pub fn middleware(mut self, middleware: EndpointMiddleware) -> Self {
89        self.middlewares.push(middleware);
90        self
91    }
92
93    #[must_use]
94    pub fn openapi(mut self, operation: OpenApiOperation) -> Self {
95        self.openapi = Some(operation);
96        self
97    }
98
99    #[must_use]
100    pub fn server_only(mut self) -> Self {
101        self.server_only = true;
102        self
103    }
104
105    #[must_use]
106    pub fn hide_from_openapi(mut self) -> Self {
107        self.hide_from_openapi = true;
108        self
109    }
110}
111
/// Distinguishes endpoints with a synchronous handler from those with an async one.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EndpointKind {
    /// Backed by a synchronous handler.
    Sync,
    /// Backed by an asynchronous handler.
    Async,
}
117
118#[derive(Debug, Clone, PartialEq, Eq)]
119pub struct EndpointInfo {
120    pub path: String,
121    pub method: http::Method,
122    pub kind: EndpointKind,
123    pub operation_id: Option<String>,
124    pub allowed_media_types: Vec<String>,
125}
126
127#[derive(Clone)]
128pub struct AuthEndpoint {
129    pub path: String,
130    pub method: http::Method,
131    pub handler: EndpointHandler,
132}
133
134#[derive(Clone)]
135pub struct AsyncAuthEndpoint {
136    pub path: String,
137    pub method: http::Method,
138    pub handler: AsyncEndpointHandler,
139    pub options: AuthEndpointOptions,
140}
141
142impl AsyncAuthEndpoint {
143    pub fn new<F>(path: impl Into<String>, method: http::Method, handler: F) -> Self
144    where
145        F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
146    {
147        Self {
148            path: path.into(),
149            method,
150            handler: Arc::new(handler),
151            options: AuthEndpointOptions::default(),
152        }
153    }
154}
155
156pub fn create_auth_endpoint<F>(
157    path: impl Into<String>,
158    method: http::Method,
159    options: AuthEndpointOptions,
160    handler: F,
161) -> AsyncAuthEndpoint
162where
163    F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
164{
165    AsyncAuthEndpoint {
166        path: path.into(),
167        method,
168        handler: Arc::new(handler),
169        options,
170    }
171}
172
173pub(super) fn validate_async_endpoint_request(
174    endpoint: &AsyncAuthEndpoint,
175    request: &ApiRequest,
176) -> Result<Option<ApiResponse>, OpenAuthError> {
177    if endpoint.options.allowed_media_types.is_empty() && endpoint.options.body_schema.is_none() {
178        return Ok(None);
179    }
180
181    let content_type = request
182        .headers()
183        .get(header::CONTENT_TYPE)
184        .and_then(|value| value.to_str().ok())
185        .and_then(|value| value.split(';').next())
186        .map(str::trim)
187        .filter(|value| !value.is_empty());
188
189    if !endpoint.options.allowed_media_types.is_empty() {
190        let Some(content_type) = content_type else {
191            return invalid_request_response(
192                StatusCode::UNSUPPORTED_MEDIA_TYPE,
193                "UNSUPPORTED_MEDIA_TYPE",
194                "Missing Content-Type",
195            )
196            .map(Some);
197        };
198        if !endpoint
199            .options
200            .allowed_media_types
201            .iter()
202            .any(|allowed| allowed.eq_ignore_ascii_case(content_type))
203        {
204            return invalid_request_response(
205                StatusCode::UNSUPPORTED_MEDIA_TYPE,
206                "UNSUPPORTED_MEDIA_TYPE",
207                "Unsupported Content-Type",
208            )
209            .map(Some);
210        }
211    }
212
213    if let Some(schema) = &endpoint.options.body_schema {
214        let body = match parse_request_body::<Value>(request) {
215            Ok(body) => body,
216            Err(error) => {
217                return invalid_request_response(
218                    StatusCode::BAD_REQUEST,
219                    "INVALID_REQUEST_BODY",
220                    &error.to_string(),
221                )
222                .map(Some);
223            }
224        };
225        if let Err(message) = schema.validate(&body) {
226            return invalid_request_response(
227                StatusCode::BAD_REQUEST,
228                "INVALID_REQUEST_BODY",
229                &message,
230            )
231            .map(Some);
232        }
233    }
234
235    Ok(None)
236}
237
238pub(super) async fn run_endpoint_middlewares(
239    context: &AuthContext,
240    endpoint: &AsyncAuthEndpoint,
241    request: &ApiRequest,
242) -> Result<Option<ApiResponse>, OpenAuthError> {
243    for middleware in &endpoint.options.middlewares {
244        if let Some(response) = (middleware.handler)(context, request).await? {
245            return Ok(Some(response));
246        }
247    }
248    Ok(None)
249}
250
251fn invalid_request_response(
252    status: StatusCode,
253    code: &str,
254    message: &str,
255) -> Result<ApiResponse, OpenAuthError> {
256    let body = serde_json::to_vec(&ApiErrorResponse {
257        code: code.to_owned(),
258        message: message.to_owned(),
259        original_message: None,
260    })
261    .map_err(|error| OpenAuthError::Api(error.to_string()))?;
262
263    Response::builder()
264        .status(status)
265        .header(header::CONTENT_TYPE, "application/json")
266        .body(body)
267        .map_err(|error| OpenAuthError::Api(error.to_string()))
268}