1#![deny(warnings)]
4#![deny(clippy::pedantic, clippy::unwrap_used)]
5#![deny(missing_docs)]
6use std::{cmp::Ordering, fmt::Display, str::FromStr};
7
8use axum::{
9 http::{HeaderMap, StatusCode, header::ToStrError},
10 response::{IntoResponse, Response},
11};
12use mediatype::{MediaType, MediaTypeError, MediaTypeList, Name, ReadParams, names::_STAR};
13
14#[derive(Debug)]
16pub enum AcceptRejection {
17 InvalidHeader(ToStrError),
19 InvalidMediaType(usize, MediaTypeError),
21 InvalidQ(usize, <f64 as FromStr>::Err),
23 NoSupportedMediaTypeFound,
25}
26
27impl AcceptRejection {
28 #[must_use]
30 pub fn status_and_message(&self) -> (StatusCode, String) {
31 match self {
32 Self::InvalidHeader(e) => (
33 StatusCode::BAD_REQUEST,
34 format!("Invalid accept header: {e}"),
35 ),
36 Self::InvalidMediaType(i, e) => (
37 StatusCode::BAD_REQUEST,
38 format!("Invalid media type in accept header at index {i}: {e}"),
39 ),
40 Self::InvalidQ(i, e) => (
41 StatusCode::BAD_REQUEST,
42 format!("Invalid q parameter in accept header at index {i}: {e}"),
43 ),
44 Self::NoSupportedMediaTypeFound => (
45 StatusCode::NOT_ACCEPTABLE,
46 "Accept header does not contain supported media types".to_string(),
47 ),
48 }
49 }
50}
51
52impl IntoResponse for AcceptRejection {
53 fn into_response(self) -> Response {
54 self.status_and_message().into_response()
55 }
56}
57
58impl Display for AcceptRejection {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 let (_, message) = self.status_and_message();
61 write!(f, "{message}")
62 }
63}
64
65impl std::error::Error for AcceptRejection {}
66
67pub fn parse_mediatypes(headers: &HeaderMap) -> Result<Vec<MediaType<'_>>, AcceptRejection> {
73 let accept_header = headers
74 .get("accept")
75 .map(|header| header.to_str())
76 .transpose()
77 .map_err(AcceptRejection::InvalidHeader)?
78 .unwrap_or_default();
79
80 let Some(q_name) = Name::new("q") else {
81 unreachable!()
82 };
83
84 let mut list = MediaTypeList::new(accept_header)
85 .enumerate()
86 .map(|(i, mt)| match mt {
87 Ok(mt) => Ok(match mt.get_param(q_name) {
89 Some(q_str) => {
90 let q: f64 = q_str
91 .as_str()
92 .parse::<f64>()
93 .map_err(|e| AcceptRejection::InvalidQ(i, e))?
94 .clamp(0.0, 1.0);
95
96 #[allow(clippy::cast_possible_truncation)]
98 #[allow(clippy::cast_sign_loss)]
99 ((q * 1000.0) as u16, mt)
100 }
101 None => (1000, mt),
102 }),
103 Err(e) => Err(AcceptRejection::InvalidMediaType(i, e)),
104 })
105 .collect::<Result<Vec<(u16, MediaType)>, AcceptRejection>>()?;
106
107 list.sort_by(|(a_q, a_mt), (b_q, b_mt)| {
108 let ord = b_q.cmp(a_q);
109 match ord {
110 Ordering::Less | Ordering::Greater => ord,
111 Ordering::Equal => {
112 if (a_mt.ty, a_mt.subty) == (_STAR, _STAR) {
116 return Ordering::Greater;
117 } else if (b_mt.ty, b_mt.subty) == (_STAR, _STAR) {
118 return Ordering::Less;
119 }
120
121 if a_mt.subty != b_mt.subty {
123 if a_mt.subty == _STAR {
124 return Ordering::Greater;
125 } else if b_mt.subty == _STAR {
126 return Ordering::Less;
127 }
128 }
129
130 Ordering::Equal
131 }
132 }
133 });
134
135 Ok(list.into_iter().map(|(_, mt)| mt).collect())
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use mediatype::media_type;

    /// Builds a header map holding a single `Accept` value.
    fn accept_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            "accept",
            value
                .parse()
                .expect("test Accept value should be a valid header value"),
        );
        headers
    }

    #[test]
    fn test_parse_mediatype_invisible_ascii() {
        // U+00A0 (no-break space) is accepted by `HeaderValue` as opaque bytes but is not
        // visible ASCII, so converting the value to `&str` fails. The escape keeps the
        // character from being silently replaced by a plain (visible) space.
        let headers = accept_headers("\u{a0}");
        assert!(
            matches!(
                parse_mediatypes(&headers),
                Err(AcceptRejection::InvalidHeader(_))
            ),
            "expected invalid header rejection"
        );
    }

    #[test]
    fn test_parse_mediatype_invalid_media_type() {
        let headers = accept_headers("lol");
        let Err(AcceptRejection::InvalidMediaType(i, _)) = parse_mediatypes(&headers) else {
            panic!("expected invalid media type rejection");
        };
        assert_eq!(i, 0);
    }

    #[test]
    fn test_parse_mediatype_invalid_q() {
        let headers = accept_headers("text/plain,application/json;q=lol");
        let Err(AcceptRejection::InvalidQ(i, _)) = parse_mediatypes(&headers) else {
            panic!("expected invalid q rejection");
        };
        assert_eq!(i, 1);
    }

    #[test]
    fn test_parse_mediatype_valid_types() {
        let headers = accept_headers("text/plain");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(vec![media_type!(TEXT / PLAIN)], list);

        let headers = accept_headers("text/plain,application/json");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![media_type!(TEXT / PLAIN), media_type!(APPLICATION / JSON)],
            list
        );

        let headers = accept_headers("text/plain,application/json;q=0.9");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(2, list.len());
        assert_eq!(media_type!(TEXT / PLAIN), list[0]);
        assert_eq!(media_type!(APPLICATION / JSON), list[1].essence());
    }

    #[test]
    fn test_parse_mediatype_order() {
        let headers = accept_headers("text/plain;q=0.9,application/json");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(2, list.len());
        assert_eq!(media_type!(APPLICATION / JSON), list[0]);
        assert_eq!(media_type!(TEXT / PLAIN), list[1].essence());

        let headers = accept_headers("text/*,text/plain,application/json");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![
                media_type!(TEXT / PLAIN),
                media_type!(APPLICATION / JSON),
                media_type!(TEXT / _STAR)
            ],
            list
        );

        let headers = accept_headers("*/*,text/*,text/plain,application/json");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![
                media_type!(TEXT / PLAIN),
                media_type!(APPLICATION / JSON),
                media_type!(TEXT / _STAR),
                media_type!(_STAR / _STAR)
            ],
            list
        );
    }

    #[test]
    fn test_parse_mediatype_duplicate_wildcards() {
        // Repeated `*/*` entries must be kept and sorted after concrete types.
        let headers = accept_headers("*/*,text/plain,*/*");
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![
                media_type!(TEXT / PLAIN),
                media_type!(_STAR / _STAR),
                media_type!(_STAR / _STAR)
            ],
            list
        );
    }
}