1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use http::{header, Request, Response, StatusCode};
6use serde_json::Value;
7
8use crate::context::AuthContext;
9use crate::error::OpenAuthError;
10
11use super::body::parse_request_body;
12use super::error::ApiErrorResponse;
13use super::openapi::OpenApiOperation;
14use super::schema::BodySchema;
15
16pub type Body = Vec<u8>;
17pub type ApiRequest = Request<Body>;
18pub type ApiResponse = Response<Body>;
19pub type EndpointHandler = fn(&AuthContext, ApiRequest) -> Result<ApiResponse, OpenAuthError>;
20pub type EndpointFuture<'a> =
21 Pin<Box<dyn Future<Output = Result<ApiResponse, OpenAuthError>> + Send + 'a>>;
22pub type AsyncEndpointHandler =
23 Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync>;
24pub type EndpointMiddlewareFuture<'a> =
25 Pin<Box<dyn Future<Output = Result<Option<ApiResponse>, OpenAuthError>> + Send + 'a>>;
26pub type EndpointMiddlewareHandler = Arc<
27 dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a> + Send + Sync,
28>;
29
30#[derive(Clone)]
31pub struct EndpointMiddleware {
32 pub handler: EndpointMiddlewareHandler,
33}
34
35impl EndpointMiddleware {
36 pub fn new<F>(handler: F) -> Self
37 where
38 F: for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a>
39 + Send
40 + Sync
41 + 'static,
42 {
43 Self {
44 handler: Arc::new(handler),
45 }
46 }
47}
48
49#[derive(Clone, Default)]
50pub struct AuthEndpointOptions {
51 pub operation_id: Option<String>,
52 pub allowed_media_types: Vec<String>,
53 pub body_schema: Option<BodySchema>,
54 pub middlewares: Vec<EndpointMiddleware>,
55 pub openapi: Option<OpenApiOperation>,
56 pub server_only: bool,
57 pub hide_from_openapi: bool,
58 pub bypass_origin_security: bool,
59}
60
61impl AuthEndpointOptions {
62 pub fn new() -> Self {
63 Self::default()
64 }
65
66 #[must_use]
67 pub fn operation_id(mut self, operation_id: impl Into<String>) -> Self {
68 self.operation_id = Some(operation_id.into());
69 self
70 }
71
72 #[must_use]
73 pub fn allowed_media_types<I, S>(mut self, media_types: I) -> Self
74 where
75 I: IntoIterator<Item = S>,
76 S: Into<String>,
77 {
78 self.allowed_media_types = media_types.into_iter().map(Into::into).collect();
79 self
80 }
81
82 #[must_use]
83 pub fn body_schema(mut self, schema: BodySchema) -> Self {
84 self.body_schema = Some(schema);
85 self
86 }
87
88 #[must_use]
89 pub fn middleware(mut self, middleware: EndpointMiddleware) -> Self {
90 self.middlewares.push(middleware);
91 self
92 }
93
94 #[must_use]
95 pub fn openapi(mut self, operation: OpenApiOperation) -> Self {
96 self.openapi = Some(operation);
97 self
98 }
99
100 #[must_use]
101 pub fn server_only(mut self) -> Self {
102 self.server_only = true;
103 self
104 }
105
106 #[must_use]
107 pub fn hide_from_openapi(mut self) -> Self {
108 self.hide_from_openapi = true;
109 self
110 }
111
112 #[must_use]
113 pub fn bypass_origin_security(mut self) -> Self {
114 self.bypass_origin_security = true;
115 self
116 }
117}
118
/// Flavour of an endpoint's handler.
///
/// Presumably `Sync` corresponds to [`AuthEndpoint`] (plain function handler)
/// and `Async` to [`AsyncAuthEndpoint`] (boxed-future handler) — TODO confirm
/// against the code that fills in [`EndpointInfo::kind`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EndpointKind {
    /// Synchronous handler.
    Sync,
    /// Asynchronous handler.
    Async,
}
124
125#[derive(Debug, Clone, PartialEq, Eq)]
126pub struct EndpointInfo {
127 pub path: String,
128 pub method: http::Method,
129 pub kind: EndpointKind,
130 pub operation_id: Option<String>,
131 pub allowed_media_types: Vec<String>,
132}
133
134#[derive(Clone)]
135pub struct AuthEndpoint {
136 pub path: String,
137 pub method: http::Method,
138 pub handler: EndpointHandler,
139}
140
141#[derive(Clone)]
142pub struct AsyncAuthEndpoint {
143 pub path: String,
144 pub method: http::Method,
145 pub handler: AsyncEndpointHandler,
146 pub options: AuthEndpointOptions,
147}
148
149impl AsyncAuthEndpoint {
150 pub fn new<F>(path: impl Into<String>, method: http::Method, handler: F) -> Self
151 where
152 F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
153 {
154 Self {
155 path: path.into(),
156 method,
157 handler: Arc::new(handler),
158 options: AuthEndpointOptions::default(),
159 }
160 }
161}
162
163pub fn create_auth_endpoint<F>(
164 path: impl Into<String>,
165 method: http::Method,
166 options: AuthEndpointOptions,
167 handler: F,
168) -> AsyncAuthEndpoint
169where
170 F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
171{
172 AsyncAuthEndpoint {
173 path: path.into(),
174 method,
175 handler: Arc::new(handler),
176 options,
177 }
178}
179
180pub(super) fn validate_async_endpoint_request(
181 endpoint: &AsyncAuthEndpoint,
182 request: &ApiRequest,
183) -> Result<Option<ApiResponse>, OpenAuthError> {
184 if endpoint.options.allowed_media_types.is_empty() && endpoint.options.body_schema.is_none() {
185 return Ok(None);
186 }
187
188 let content_type = request
189 .headers()
190 .get(header::CONTENT_TYPE)
191 .and_then(|value| value.to_str().ok())
192 .and_then(|value| value.split(';').next())
193 .map(str::trim)
194 .filter(|value| !value.is_empty());
195
196 if !endpoint.options.allowed_media_types.is_empty() {
197 let Some(content_type) = content_type else {
198 return invalid_request_response(
199 StatusCode::UNSUPPORTED_MEDIA_TYPE,
200 "UNSUPPORTED_MEDIA_TYPE",
201 "Missing Content-Type",
202 )
203 .map(Some);
204 };
205 if !endpoint
206 .options
207 .allowed_media_types
208 .iter()
209 .any(|allowed| allowed.eq_ignore_ascii_case(content_type))
210 {
211 return invalid_request_response(
212 StatusCode::UNSUPPORTED_MEDIA_TYPE,
213 "UNSUPPORTED_MEDIA_TYPE",
214 "Unsupported Content-Type",
215 )
216 .map(Some);
217 }
218 }
219
220 if let Some(schema) = &endpoint.options.body_schema {
221 let body = match parse_request_body::<Value>(request) {
222 Ok(body) => body,
223 Err(error) => {
224 return invalid_request_response(
225 StatusCode::BAD_REQUEST,
226 "INVALID_REQUEST_BODY",
227 &error.to_string(),
228 )
229 .map(Some);
230 }
231 };
232 if let Err(message) = schema.validate(&body) {
233 return invalid_request_response(
234 StatusCode::BAD_REQUEST,
235 "INVALID_REQUEST_BODY",
236 &message,
237 )
238 .map(Some);
239 }
240 }
241
242 Ok(None)
243}
244
245pub(super) async fn run_endpoint_middlewares(
246 context: &AuthContext,
247 endpoint: &AsyncAuthEndpoint,
248 request: &ApiRequest,
249) -> Result<Option<ApiResponse>, OpenAuthError> {
250 for middleware in &endpoint.options.middlewares {
251 if let Some(response) = (middleware.handler)(context, request).await? {
252 return Ok(Some(response));
253 }
254 }
255 Ok(None)
256}
257
258fn invalid_request_response(
259 status: StatusCode,
260 code: &str,
261 message: &str,
262) -> Result<ApiResponse, OpenAuthError> {
263 let body = serde_json::to_vec(&ApiErrorResponse {
264 code: code.to_owned(),
265 message: message.to_owned(),
266 original_message: None,
267 })
268 .map_err(|error| OpenAuthError::Api(error.to_string()))?;
269
270 Response::builder()
271 .status(status)
272 .header(header::CONTENT_TYPE, "application/json")
273 .body(body)
274 .map_err(|error| OpenAuthError::Api(error.to_string()))
275}