Skip to main content

openauth_core/api/
endpoint.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use http::{header, Request, Response, StatusCode};
6use serde_json::Value;
7
8use crate::context::AuthContext;
9use crate::error::OpenAuthError;
10
11use super::body::parse_request_body;
12use super::error::ApiErrorResponse;
13use super::openapi::OpenApiOperation;
14use super::schema::BodySchema;
15
16pub type Body = Vec<u8>;
17pub type ApiRequest = Request<Body>;
18pub type ApiResponse = Response<Body>;
19pub type EndpointHandler = fn(&AuthContext, ApiRequest) -> Result<ApiResponse, OpenAuthError>;
20pub type EndpointFuture<'a> =
21    Pin<Box<dyn Future<Output = Result<ApiResponse, OpenAuthError>> + Send + 'a>>;
22pub type AsyncEndpointHandler =
23    Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync>;
24pub type EndpointMiddlewareFuture<'a> =
25    Pin<Box<dyn Future<Output = Result<Option<ApiResponse>, OpenAuthError>> + Send + 'a>>;
26pub type EndpointMiddlewareHandler = Arc<
27    dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a> + Send + Sync,
28>;
29
30#[derive(Clone)]
31pub struct EndpointMiddleware {
32    pub handler: EndpointMiddlewareHandler,
33}
34
35impl EndpointMiddleware {
36    pub fn new<F>(handler: F) -> Self
37    where
38        F: for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a>
39            + Send
40            + Sync
41            + 'static,
42    {
43        Self {
44            handler: Arc::new(handler),
45        }
46    }
47}
48
49#[derive(Clone, Default)]
50pub struct AuthEndpointOptions {
51    pub operation_id: Option<String>,
52    pub allowed_media_types: Vec<String>,
53    pub body_schema: Option<BodySchema>,
54    pub middlewares: Vec<EndpointMiddleware>,
55    pub openapi: Option<OpenApiOperation>,
56    pub server_only: bool,
57}
58
59impl AuthEndpointOptions {
60    pub fn new() -> Self {
61        Self::default()
62    }
63
64    #[must_use]
65    pub fn operation_id(mut self, operation_id: impl Into<String>) -> Self {
66        self.operation_id = Some(operation_id.into());
67        self
68    }
69
70    #[must_use]
71    pub fn allowed_media_types<I, S>(mut self, media_types: I) -> Self
72    where
73        I: IntoIterator<Item = S>,
74        S: Into<String>,
75    {
76        self.allowed_media_types = media_types.into_iter().map(Into::into).collect();
77        self
78    }
79
80    #[must_use]
81    pub fn body_schema(mut self, schema: BodySchema) -> Self {
82        self.body_schema = Some(schema);
83        self
84    }
85
86    #[must_use]
87    pub fn middleware(mut self, middleware: EndpointMiddleware) -> Self {
88        self.middlewares.push(middleware);
89        self
90    }
91
92    #[must_use]
93    pub fn openapi(mut self, operation: OpenApiOperation) -> Self {
94        self.openapi = Some(operation);
95        self
96    }
97
98    #[must_use]
99    pub fn server_only(mut self) -> Self {
100        self.server_only = true;
101        self
102    }
103}
104
/// How an endpoint's handler executes, as reported in [`EndpointInfo::kind`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EndpointKind {
    /// Synchronous handler (presumably an [`AuthEndpoint`]; confirm at the construction site).
    Sync,
    /// Asynchronous handler (presumably an [`AsyncAuthEndpoint`]; confirm at the construction
    /// site).
    Async,
}
110
111#[derive(Debug, Clone, PartialEq, Eq)]
112pub struct EndpointInfo {
113    pub path: String,
114    pub method: http::Method,
115    pub kind: EndpointKind,
116    pub operation_id: Option<String>,
117    pub allowed_media_types: Vec<String>,
118}
119
120#[derive(Clone)]
121pub struct AuthEndpoint {
122    pub path: String,
123    pub method: http::Method,
124    pub handler: EndpointHandler,
125}
126
127#[derive(Clone)]
128pub struct AsyncAuthEndpoint {
129    pub path: String,
130    pub method: http::Method,
131    pub handler: AsyncEndpointHandler,
132    pub options: AuthEndpointOptions,
133}
134
135impl AsyncAuthEndpoint {
136    pub fn new<F>(path: impl Into<String>, method: http::Method, handler: F) -> Self
137    where
138        F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
139    {
140        Self {
141            path: path.into(),
142            method,
143            handler: Arc::new(handler),
144            options: AuthEndpointOptions::default(),
145        }
146    }
147}
148
149pub fn create_auth_endpoint<F>(
150    path: impl Into<String>,
151    method: http::Method,
152    options: AuthEndpointOptions,
153    handler: F,
154) -> AsyncAuthEndpoint
155where
156    F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
157{
158    AsyncAuthEndpoint {
159        path: path.into(),
160        method,
161        handler: Arc::new(handler),
162        options,
163    }
164}
165
166pub(super) fn validate_async_endpoint_request(
167    endpoint: &AsyncAuthEndpoint,
168    request: &ApiRequest,
169) -> Result<Option<ApiResponse>, OpenAuthError> {
170    if endpoint.options.allowed_media_types.is_empty() && endpoint.options.body_schema.is_none() {
171        return Ok(None);
172    }
173
174    let content_type = request
175        .headers()
176        .get(header::CONTENT_TYPE)
177        .and_then(|value| value.to_str().ok())
178        .and_then(|value| value.split(';').next())
179        .map(str::trim)
180        .filter(|value| !value.is_empty());
181
182    if !endpoint.options.allowed_media_types.is_empty() {
183        let Some(content_type) = content_type else {
184            return invalid_request_response(
185                StatusCode::UNSUPPORTED_MEDIA_TYPE,
186                "UNSUPPORTED_MEDIA_TYPE",
187                "Missing Content-Type",
188            )
189            .map(Some);
190        };
191        if !endpoint
192            .options
193            .allowed_media_types
194            .iter()
195            .any(|allowed| allowed.eq_ignore_ascii_case(content_type))
196        {
197            return invalid_request_response(
198                StatusCode::UNSUPPORTED_MEDIA_TYPE,
199                "UNSUPPORTED_MEDIA_TYPE",
200                "Unsupported Content-Type",
201            )
202            .map(Some);
203        }
204    }
205
206    if let Some(schema) = &endpoint.options.body_schema {
207        let body = match parse_request_body::<Value>(request) {
208            Ok(body) => body,
209            Err(error) => {
210                return invalid_request_response(
211                    StatusCode::BAD_REQUEST,
212                    "INVALID_REQUEST_BODY",
213                    &error.to_string(),
214                )
215                .map(Some);
216            }
217        };
218        if let Err(message) = schema.validate(&body) {
219            return invalid_request_response(
220                StatusCode::BAD_REQUEST,
221                "INVALID_REQUEST_BODY",
222                &message,
223            )
224            .map(Some);
225        }
226    }
227
228    Ok(None)
229}
230
231pub(super) async fn run_endpoint_middlewares(
232    context: &AuthContext,
233    endpoint: &AsyncAuthEndpoint,
234    request: &ApiRequest,
235) -> Result<Option<ApiResponse>, OpenAuthError> {
236    for middleware in &endpoint.options.middlewares {
237        if let Some(response) = (middleware.handler)(context, request).await? {
238            return Ok(Some(response));
239        }
240    }
241    Ok(None)
242}
243
244fn invalid_request_response(
245    status: StatusCode,
246    code: &str,
247    message: &str,
248) -> Result<ApiResponse, OpenAuthError> {
249    let body = serde_json::to_vec(&ApiErrorResponse {
250        code: code.to_owned(),
251        message: message.to_owned(),
252        original_message: None,
253    })
254    .map_err(|error| OpenAuthError::Api(error.to_string()))?;
255
256    Response::builder()
257        .status(status)
258        .header(header::CONTENT_TYPE, "application/json")
259        .body(body)
260        .map_err(|error| OpenAuthError::Api(error.to_string()))
261}