matrix_http_rendezvous/
handlers.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use axum::{
16    body::HttpBody,
17    extract::{DefaultBodyLimit, Path, State},
18    http::{
19        header::{CONTENT_TYPE, ETAG, IF_MATCH, IF_NONE_MATCH, LOCATION},
20        StatusCode,
21    },
22    response::{IntoResponse, Response},
23    routing::{get, post},
24    Router, TypedHeader,
25};
26use bytes::Bytes;
27use headers::{ContentType, HeaderName, HeaderValue, IfMatch, IfNoneMatch};
28use tower_http::{
29    cors::{Any, CorsLayer},
30    limit::RequestBodyLimitLayer,
31    set_header::SetResponseHeaderLayer,
32};
33use ulid::Ulid;
34
35use crate::Sessions;
36
37async fn new_session(
38    State(sessions): State<Sessions>,
39    content_type: Option<TypedHeader<ContentType>>,
40    payload: Bytes,
41) -> impl IntoResponse {
42    let content_type =
43        content_type.map_or(mime::APPLICATION_OCTET_STREAM, |TypedHeader(c)| c.into());
44    let (id, session) = sessions.new_session(payload, content_type).await;
45    let headers = session.typed_headers();
46
47    let location = id.to_string();
48    let additional_headers = [(LOCATION, location)];
49    (StatusCode::CREATED, headers, additional_headers)
50}
51
52async fn delete_session(State(sessions): State<Sessions>, Path(id): Path<Ulid>) -> StatusCode {
53    if sessions.delete_session(id).await {
54        StatusCode::NO_CONTENT
55    } else {
56        StatusCode::NOT_FOUND
57    }
58}
59
60async fn update_session(
61    State(sessions): State<Sessions>,
62    Path(id): Path<Ulid>,
63    content_type: Option<TypedHeader<ContentType>>,
64    if_match: Option<TypedHeader<IfMatch>>,
65    payload: Bytes,
66) -> Response {
67    if let Some(mut session) = sessions.get_session_mut(id).await {
68        if let Some(TypedHeader(if_match)) = if_match {
69            if !if_match.precondition_passes(&session.etag()) {
70                return (StatusCode::PRECONDITION_FAILED, session.typed_headers()).into_response();
71            }
72        }
73
74        let content_type =
75            content_type.map_or(mime::APPLICATION_OCTET_STREAM, |TypedHeader(c)| c.into());
76
77        session.update(payload, content_type);
78        (StatusCode::ACCEPTED, session.typed_headers()).into_response()
79    } else {
80        StatusCode::NOT_FOUND.into_response()
81    }
82}
83
84async fn get_session(
85    State(sessions): State<Sessions>,
86    Path(id): Path<Ulid>,
87    if_none_match: Option<TypedHeader<IfNoneMatch>>,
88) -> Response {
89    let session = if let Some(session) = sessions.get_session(id).await {
90        session
91    } else {
92        return StatusCode::NOT_FOUND.into_response();
93    };
94
95    if let Some(TypedHeader(if_none_match)) = if_none_match {
96        if !if_none_match.precondition_passes(&session.etag()) {
97            return (StatusCode::NOT_MODIFIED, session.typed_headers()).into_response();
98        }
99    }
100
101    (
102        StatusCode::OK,
103        session.typed_headers(),
104        TypedHeader(session.content_type()),
105        session.data(),
106    )
107        .into_response()
108}
109
110#[must_use]
111pub fn router<B>(prefix: &str, sessions: Sessions, max_bytes: usize) -> Router<(), B>
112where
113    B: HttpBody + Send + 'static,
114    <B as HttpBody>::Data: Send,
115    <B as HttpBody>::Error: std::error::Error + Send + Sync,
116{
117    let router = Router::new()
118        .route("/", post(new_session))
119        .route(
120            "/:id",
121            get(get_session).put(update_session).delete(delete_session),
122        )
123        .layer(DefaultBodyLimit::disable())
124        .layer(RequestBodyLimitLayer::new(max_bytes))
125        .layer(SetResponseHeaderLayer::if_not_present(
126            HeaderName::from_static("x-max-bytes"),
127            HeaderValue::from_str(&max_bytes.to_string())
128                .expect("Could not construct x-max-bytes header value"),
129        ))
130        .layer(
131            CorsLayer::new()
132                .allow_origin(Any)
133                .allow_methods(Any)
134                .allow_headers([CONTENT_TYPE, IF_MATCH, IF_NONE_MATCH])
135                .expose_headers([ETAG, LOCATION, HeaderName::from_static("x-max-bytes")]),
136        );
137
138    Router::new().nest(prefix, router).with_state(sessions)
139}