axum_accept_shared/
lib.rs

1//! This crate contains shared types and functions used by both axum-accept
2//! and axum-accept-derive.
3#![deny(warnings)]
4#![deny(clippy::pedantic, clippy::unwrap_used)]
5#![deny(missing_docs)]
6use std::{cmp::Ordering, fmt::Display, str::FromStr};
7
8use axum::{
9    http::{HeaderMap, StatusCode, header::ToStrError},
10    response::{IntoResponse, Response},
11};
12use mediatype::{MediaType, MediaTypeError, MediaTypeList, Name, ReadParams, names::_STAR};
13
14/// The error type returned in the `FromRequestParts` implementations.
15#[derive(Debug)]
16pub enum AcceptRejection {
17    /// The header could not be converted to a &str.
18    InvalidHeader(ToStrError),
19    /// The media type at index .0 could not be parsed.
20    InvalidMediaType(usize, MediaTypeError),
21    /// Invalid q parameter
22    InvalidQ(usize, <f64 as FromStr>::Err),
23    /// No supported media type was found.
24    NoSupportedMediaTypeFound,
25}
26
27impl AcceptRejection {
28    /// Get the status and message for an error.
29    #[must_use]
30    pub fn status_and_message(&self) -> (StatusCode, String) {
31        match self {
32            Self::InvalidHeader(e) => (
33                StatusCode::BAD_REQUEST,
34                format!("Invalid accept header: {e}"),
35            ),
36            Self::InvalidMediaType(i, e) => (
37                StatusCode::BAD_REQUEST,
38                format!("Invalid media type in accept header at index {i}: {e}"),
39            ),
40            Self::InvalidQ(i, e) => (
41                StatusCode::BAD_REQUEST,
42                format!("Invalid q parameter in accept header at index {i}: {e}"),
43            ),
44            Self::NoSupportedMediaTypeFound => (
45                StatusCode::NOT_ACCEPTABLE,
46                "Accept header does not contain supported media types".to_string(),
47            ),
48        }
49    }
50}
51
52impl IntoResponse for AcceptRejection {
53    fn into_response(self) -> Response {
54        self.status_and_message().into_response()
55    }
56}
57
58impl Display for AcceptRejection {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        let (_, message) = self.status_and_message();
61        write!(f, "{message}")
62    }
63}
64
65impl std::error::Error for AcceptRejection {}
66
67/// Parse and process the media types from the accept header.
68///
69/// # Errors
70///
71/// Returns an error if the accept header is invalid or no match was found.
72pub fn parse_mediatypes(headers: &HeaderMap) -> Result<Vec<MediaType<'_>>, AcceptRejection> {
73    let accept_header = headers
74        .get("accept")
75        .map(|header| header.to_str())
76        .transpose()
77        .map_err(AcceptRejection::InvalidHeader)?
78        .unwrap_or_default();
79
80    let Some(q_name) = Name::new("q") else {
81        unreachable!()
82    };
83
84    let mut list = MediaTypeList::new(accept_header)
85        .enumerate()
86        .map(|(i, mt)| match mt {
87            // validate q parameter and add it as u16 for sorting
88            Ok(mt) => Ok(match mt.get_param(q_name) {
89                Some(q_str) => {
90                    let q: f64 = q_str
91                        .as_str()
92                        .parse::<f64>()
93                        .map_err(|e| AcceptRejection::InvalidQ(i, e))?
94                        .clamp(0.0, 1.0);
95
96                    // q is clamped to 0.0-1.0 so nothing can happen here
97                    #[allow(clippy::cast_possible_truncation)]
98                    #[allow(clippy::cast_sign_loss)]
99                    ((q * 1000.0) as u16, mt)
100                }
101                None => (1000, mt),
102            }),
103            Err(e) => Err(AcceptRejection::InvalidMediaType(i, e)),
104        })
105        .collect::<Result<Vec<(u16, MediaType)>, AcceptRejection>>()?;
106
107    list.sort_by(|(a_q, a_mt), (b_q, b_mt)| {
108        let ord = b_q.cmp(a_q);
109        match ord {
110            Ordering::Less | Ordering::Greater => ord,
111            Ordering::Equal => {
112                // both have the same q, order by specificity
113
114                // is one of them */*? these come last
115                if (a_mt.ty, a_mt.subty) == (_STAR, _STAR) {
116                    return Ordering::Greater;
117                } else if (b_mt.ty, b_mt.subty) == (_STAR, _STAR) {
118                    return Ordering::Less;
119                }
120
121                // now check the subtype
122                if a_mt.subty != b_mt.subty {
123                    if a_mt.subty == _STAR {
124                        return Ordering::Greater;
125                    } else if b_mt.subty == _STAR {
126                        return Ordering::Less;
127                    }
128                }
129
130                Ordering::Equal
131            }
132        }
133    });
134
135    Ok(list.into_iter().map(|(_, mt)| mt).collect())
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use mediatype::media_type;

    /// Builds a `HeaderMap` holding a single `accept` header set to `value`.
    fn accept_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("accept", value.parse().unwrap());
        headers
    }

    #[test]
    fn test_parse_mediatype_invisible_ascii() {
        // invisible ascii is verboten
        let headers = accept_headers("‎ ");
        assert!(
            matches!(
                parse_mediatypes(&headers),
                Err(AcceptRejection::InvalidHeader(_))
            ),
            "expected invalid header rejection"
        );
    }

    #[test]
    fn test_parse_mediatype_invalid_media_type() {
        let headers = accept_headers("lol");
        let Err(AcceptRejection::InvalidMediaType(index, _)) = parse_mediatypes(&headers) else {
            panic!("expected invalid media type rejection");
        };
        assert_eq!(index, 0);
    }

    #[test]
    fn test_parse_mediatype_invalid_q() {
        let headers = accept_headers("text/plain,application/json;q=lol");
        let Err(AcceptRejection::InvalidQ(index, _)) = parse_mediatypes(&headers) else {
            panic!("expected invalid q rejection");
        };
        assert_eq!(index, 1);
    }

    #[test]
    fn test_parse_mediatype_valid_types() {
        let headers = accept_headers("text/plain");
        let parsed =
            parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(vec![media_type!(TEXT / PLAIN)], parsed);

        let headers = accept_headers("text/plain,application/json");
        let parsed =
            parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![media_type!(TEXT / PLAIN), media_type!(APPLICATION / JSON)],
            parsed
        );

        let headers = accept_headers("text/plain,application/json;q=0.9");
        let parsed =
            parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(2, parsed.len());
        assert_eq!(media_type!(TEXT / PLAIN), parsed[0]);
        assert_eq!(media_type!(APPLICATION / JSON), parsed[1].essence());
    }

    #[test]
    fn test_parse_mediatype_order() {
        // A lower q value moves an entry behind one without q.
        let headers = accept_headers("text/plain;q=0.9,application/json");
        let ordered =
            parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(2, ordered.len());
        assert_eq!(media_type!(APPLICATION / JSON), ordered[0]);
        assert_eq!(media_type!(TEXT / PLAIN), ordered[1].essence());

        // With equal q, a subtype wildcard goes after concrete types.
        let headers = accept_headers("text/*,text/plain,application/json");
        let ordered =
            parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        let expected = vec![
            media_type!(TEXT / PLAIN),
            media_type!(APPLICATION / JSON),
            media_type!(TEXT / _STAR),
        ];
        assert_eq!(expected, ordered);

        // The full wildcard goes last of all.
        let headers = accept_headers("*/*,text/*,text/plain,application/json");
        let ordered =
            parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        let expected = vec![
            media_type!(TEXT / PLAIN),
            media_type!(APPLICATION / JSON),
            media_type!(TEXT / _STAR),
            media_type!(_STAR / _STAR),
        ];
        assert_eq!(expected, ordered);
    }
}