axum_accept_shared/
lib.rs

1//! This crate contains shared types and functions used by both axum-accept
2//! and axum-accept-derive.
3#![deny(warnings)]
4#![deny(clippy::pedantic, clippy::unwrap_used)]
5#![deny(missing_docs)]
6use std::{cmp::Ordering, str::FromStr};
7
8use axum::{
9    http::{HeaderMap, StatusCode, header::ToStrError},
10    response::{IntoResponse, Response},
11};
12use mediatype::{MediaType, MediaTypeError, MediaTypeList, Name, ReadParams, names::_STAR};
13
14/// The error type returned in the `FromRequestParts` implementations.
15#[derive(Debug)]
16pub enum AcceptRejection {
17    /// The header could not be converted to a &str.
18    InvalidHeader(ToStrError),
19    /// The media type at index .0 could not be parsed.
20    InvalidMediaType(usize, MediaTypeError),
21    /// Invalid q parameter
22    InvalidQ(usize, <f64 as FromStr>::Err),
23    /// No supported media type was found.
24    NoSupportedMediaTypeFound,
25}
26
27impl AcceptRejection {
28    /// Get the status and message for an error.
29    #[must_use]
30    pub fn status_and_message(&self) -> (StatusCode, String) {
31        match self {
32            Self::InvalidHeader(e) => (
33                StatusCode::BAD_REQUEST,
34                format!("Invalid accept header: {e}"),
35            ),
36            Self::InvalidMediaType(i, e) => (
37                StatusCode::BAD_REQUEST,
38                format!("Invalid media type in accept header at index {i}: {e}"),
39            ),
40            Self::InvalidQ(i, e) => (
41                StatusCode::BAD_REQUEST,
42                format!("Invalid q parameter in accept header at index {i}: {e}"),
43            ),
44            Self::NoSupportedMediaTypeFound => (
45                StatusCode::NOT_ACCEPTABLE,
46                "Accept header does not contain supported media types".to_string(),
47            ),
48        }
49    }
50}
51
52impl IntoResponse for AcceptRejection {
53    fn into_response(self) -> Response {
54        self.status_and_message().into_response()
55    }
56}
57
58/// Parse and process the media types from the accept header.
59///
60/// # Errors
61///
62/// Returns an error if the accept header is invalid or no match was found.
63pub fn parse_mediatypes(headers: &HeaderMap) -> Result<Vec<MediaType<'_>>, AcceptRejection> {
64    let accept_header = headers
65        .get("accept")
66        .map(|header| header.to_str())
67        .transpose()
68        .map_err(AcceptRejection::InvalidHeader)?
69        .unwrap_or_default();
70
71    let Some(q_name) = Name::new("q") else {
72        unreachable!()
73    };
74
75    let mut list = MediaTypeList::new(accept_header)
76        .enumerate()
77        .map(|(i, mt)| match mt {
78            // validate q parameter and add it as u16 for sorting
79            Ok(mt) => Ok(match mt.get_param(q_name) {
80                Some(q_str) => {
81                    let q: f64 = q_str
82                        .as_str()
83                        .parse::<f64>()
84                        .map_err(|e| AcceptRejection::InvalidQ(i, e))?
85                        .clamp(0.0, 1.0);
86
87                    // q is clamped to 0.0-1.0 so nothing can happen here
88                    #[allow(clippy::cast_possible_truncation)]
89                    #[allow(clippy::cast_sign_loss)]
90                    ((q * 1000.0) as u16, mt)
91                }
92                None => (1000, mt),
93            }),
94            Err(e) => Err(AcceptRejection::InvalidMediaType(i, e)),
95        })
96        .collect::<Result<Vec<(u16, MediaType)>, AcceptRejection>>()?;
97
98    list.sort_by(|(a_q, a_mt), (b_q, b_mt)| {
99        if a_q == b_q {
100            // both have the same q, order by specificity
101
102            // is one of them */*? these come last
103            if (a_mt.ty, a_mt.subty) == (_STAR, _STAR) {
104                return Ordering::Greater;
105            } else if (b_mt.ty, b_mt.subty) == (_STAR, _STAR) {
106                return Ordering::Less;
107            }
108
109            // now check the subtype
110            if a_mt.subty != b_mt.subty {
111                if a_mt.subty == _STAR {
112                    return Ordering::Greater;
113                } else if b_mt.subty == _STAR {
114                    return Ordering::Less;
115                }
116            }
117        }
118
119        b_q.cmp(a_q)
120    });
121
122    Ok(list.into_iter().map(|(_, mt)| mt).collect())
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use mediatype::media_type;

    /// Build a header map whose only entry is an `accept` header with `value`.
    ///
    /// Uses `expect` rather than `unwrap` because the crate denies
    /// `clippy::unwrap_used`.
    fn headers_with_accept(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            "accept",
            value
                .parse()
                .expect("test value should be accepted as a header value"),
        );
        headers
    }

    #[test]
    fn test_parse_mediatype_invisible_ascii() {
        // A left-to-right mark (written as an escape so it is visible in the
        // source) followed by a space. The value is accepted when building
        // the header, but it is not ASCII, so converting it to a `&str` must
        // be rejected.
        let headers = headers_with_accept("\u{200e} ");
        assert!(
            matches!(
                parse_mediatypes(&headers),
                Err(AcceptRejection::InvalidHeader(_))
            ),
            "expected invalid header rejection"
        );
    }

    #[test]
    fn test_parse_mediatype_invalid_media_type() {
        let headers = headers_with_accept("lol");
        assert!(
            matches!(
                parse_mediatypes(&headers),
                Err(AcceptRejection::InvalidMediaType(0, _))
            ),
            "expected invalid media type rejection at index 0"
        );
    }

    #[test]
    fn test_parse_mediatype_invalid_q() {
        let headers = headers_with_accept("text/plain,application/json;q=lol");
        assert!(
            matches!(
                parse_mediatypes(&headers),
                Err(AcceptRejection::InvalidQ(1, _))
            ),
            "expected invalid q rejection at index 1"
        );
    }

    #[test]
    fn test_parse_mediatype_valid_types() {
        let headers = headers_with_accept("text/plain");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(vec![media_type!(TEXT / PLAIN)], list);

        let headers = headers_with_accept("text/plain,application/json");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![media_type!(TEXT / PLAIN), media_type!(APPLICATION / JSON)],
            list
        );

        // The q parameter stays attached to the parsed media type, so only
        // the essence is compared for the entry that carries one.
        let headers = headers_with_accept("text/plain,application/json;q=0.9");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(2, list.len());
        assert_eq!(media_type!(TEXT / PLAIN), list[0]);
        assert_eq!(media_type!(APPLICATION / JSON), list[1].essence());
    }

    #[test]
    fn test_parse_mediatype_order() {
        // A higher q wins regardless of the position in the header.
        let headers = headers_with_accept("text/plain;q=0.9,application/json");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(2, list.len());
        assert_eq!(media_type!(APPLICATION / JSON), list[0]);
        assert_eq!(media_type!(TEXT / PLAIN), list[1].essence());

        // With equal q, a `type/*` range sorts after concrete types.
        let headers = headers_with_accept("text/*,text/plain,application/json");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![
                media_type!(TEXT / PLAIN),
                media_type!(APPLICATION / JSON),
                media_type!(TEXT / _STAR)
            ],
            list
        );

        // `*/*` sorts last of all.
        let headers = headers_with_accept("*/*,text/*,text/plain,application/json");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![
                media_type!(TEXT / PLAIN),
                media_type!(APPLICATION / JSON),
                media_type!(TEXT / _STAR),
                media_type!(_STAR / _STAR)
            ],
            list
        );
    }
}