//! Content negotiation helpers for axum handlers.
//!
//! [`parse_mediatypes`] turns the `Accept` request header into media types ordered by
//! client preference, and [`AcceptRejection`] describes why a header was rejected and
//! converts into an HTTP error response.
#![deny(warnings)]
#![deny(clippy::pedantic, clippy::unwrap_used)]
#![deny(missing_docs)]
6use std::{cmp::Ordering, str::FromStr};
7
8use axum::{
9 http::{HeaderMap, StatusCode, header::ToStrError},
10 response::{IntoResponse, Response},
11};
12use mediatype::{MediaType, MediaTypeError, MediaTypeList, Name, ReadParams, names::_STAR};
13
14#[derive(Debug)]
16pub enum AcceptRejection {
17 InvalidHeader(ToStrError),
19 InvalidMediaType(usize, MediaTypeError),
21 InvalidQ(usize, <f64 as FromStr>::Err),
23 NoSupportedMediaTypeFound,
25}
26
27impl AcceptRejection {
28 #[must_use]
30 pub fn status_and_message(&self) -> (StatusCode, String) {
31 match self {
32 Self::InvalidHeader(e) => (
33 StatusCode::BAD_REQUEST,
34 format!("Invalid accept header: {e}"),
35 ),
36 Self::InvalidMediaType(i, e) => (
37 StatusCode::BAD_REQUEST,
38 format!("Invalid media type in accept header at index {i}: {e}"),
39 ),
40 Self::InvalidQ(i, e) => (
41 StatusCode::BAD_REQUEST,
42 format!("Invalid q parameter in accept header at index {i}: {e}"),
43 ),
44 Self::NoSupportedMediaTypeFound => (
45 StatusCode::NOT_ACCEPTABLE,
46 "Accept header does not contain supported media types".to_string(),
47 ),
48 }
49 }
50}
51
52impl IntoResponse for AcceptRejection {
53 fn into_response(self) -> Response {
54 self.status_and_message().into_response()
55 }
56}
57
58pub fn parse_mediatypes(headers: &HeaderMap) -> Result<Vec<MediaType<'_>>, AcceptRejection> {
64 let accept_header = headers
65 .get("accept")
66 .map(|header| header.to_str())
67 .transpose()
68 .map_err(AcceptRejection::InvalidHeader)?
69 .unwrap_or_default();
70
71 let Some(q_name) = Name::new("q") else {
72 unreachable!()
73 };
74
75 let mut list = MediaTypeList::new(accept_header)
76 .enumerate()
77 .map(|(i, mt)| match mt {
78 Ok(mt) => Ok(match mt.get_param(q_name) {
80 Some(q_str) => {
81 let q: f64 = q_str
82 .as_str()
83 .parse::<f64>()
84 .map_err(|e| AcceptRejection::InvalidQ(i, e))?
85 .clamp(0.0, 1.0);
86
87 #[allow(clippy::cast_possible_truncation)]
89 #[allow(clippy::cast_sign_loss)]
90 ((q * 1000.0) as u16, mt)
91 }
92 None => (1000, mt),
93 }),
94 Err(e) => Err(AcceptRejection::InvalidMediaType(i, e)),
95 })
96 .collect::<Result<Vec<(u16, MediaType)>, AcceptRejection>>()?;
97
98 list.sort_by(|(a_q, a_mt), (b_q, b_mt)| {
99 if a_q == b_q {
100 if (a_mt.ty, a_mt.subty) == (_STAR, _STAR) {
104 return Ordering::Greater;
105 } else if (b_mt.ty, b_mt.subty) == (_STAR, _STAR) {
106 return Ordering::Less;
107 }
108
109 if a_mt.subty != b_mt.subty {
111 if a_mt.subty == _STAR {
112 return Ordering::Greater;
113 } else if b_mt.subty == _STAR {
114 return Ordering::Less;
115 }
116 }
117 }
118
119 b_q.cmp(a_q)
120 });
121
122 Ok(list.into_iter().map(|(_, mt)| mt).collect())
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use mediatype::media_type;

    #[test]
    fn test_parse_mediatype_invisible_ascii() {
        let mut headers = HeaderMap::new();
        // U+200B (zero-width space) is accepted when the header value is built (the
        // `unwrap` below) but is not visible ASCII, so `to_str` on it fails. It is
        // written as an escape so the input survives copy/paste and reformatting.
        headers.insert("accept", "\u{200b}".parse().unwrap());
        match parse_mediatypes(&headers) {
            Err(AcceptRejection::InvalidHeader(_)) => {}
            _ => panic!("expected invalid header rejection"),
        }
    }

    #[test]
    fn test_parse_mediatype_invalid_media_type() {
        let mut headers = HeaderMap::new();
        headers.insert("accept", "lol".parse().unwrap());
        match parse_mediatypes(&headers) {
            Err(AcceptRejection::InvalidMediaType(i, _)) => assert_eq!(i, 0),
            _ => panic!("expected invalid media type rejection"),
        }
    }

    #[test]
    fn test_parse_mediatype_invalid_q() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "accept",
            "text/plain,application/json;q=lol".parse().unwrap(),
        );
        match parse_mediatypes(&headers) {
            Err(AcceptRejection::InvalidQ(i, _)) => assert_eq!(i, 1),
            _ => panic!("expected invalid q rejection"),
        }
    }

    #[test]
    fn test_parse_mediatype_valid_types() {
        let mut headers = HeaderMap::new();
        headers.insert("accept", "text/plain".parse().unwrap());
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(vec![media_type!(TEXT / PLAIN)], list);

        let mut headers = HeaderMap::new();
        headers.insert("accept", "text/plain,application/json".parse().unwrap());
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![media_type!(TEXT / PLAIN), media_type!(APPLICATION / JSON)],
            list
        );

        let mut headers = HeaderMap::new();
        headers.insert(
            "accept",
            "text/plain,application/json;q=0.9".parse().unwrap(),
        );
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(2, list.len());
        assert_eq!(media_type!(TEXT / PLAIN), list[0]);
        assert_eq!(media_type!(APPLICATION / JSON), list[1].essence());
    }

    #[test]
    fn test_parse_mediatype_order() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "accept",
            "text/plain;q=0.9,application/json".parse().unwrap(),
        );
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(2, list.len());
        assert_eq!(media_type!(APPLICATION / JSON), list[0]);
        assert_eq!(media_type!(TEXT / PLAIN), list[1].essence());

        let mut headers = HeaderMap::new();
        headers.insert(
            "accept",
            "text/*,text/plain,application/json".parse().unwrap(),
        );
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![
                media_type!(TEXT / PLAIN),
                media_type!(APPLICATION / JSON),
                media_type!(TEXT / _STAR)
            ],
            list
        );

        let mut headers = HeaderMap::new();
        headers.insert(
            "accept",
            "*/*,text/*,text/plain,application/json".parse().unwrap(),
        );
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![
                media_type!(TEXT / PLAIN),
                media_type!(APPLICATION / JSON),
                media_type!(TEXT / _STAR),
                media_type!(_STAR / _STAR)
            ],
            list
        );
    }

    #[test]
    fn test_parse_mediatype_repeated_wildcards() {
        // Identical wildcard ranges with equal q must compare as equal. The sort then
        // sees a consistent total order and puts each specificity group in place.
        let mut headers = HeaderMap::new();
        headers.insert(
            "accept",
            "*/*,text/*,*/*,text/plain,*/*,text/*".parse().unwrap(),
        );
        let list = parse_mediatypes(&headers).expect("Accept header should've parsed correctly");
        assert_eq!(
            vec![
                media_type!(TEXT / PLAIN),
                media_type!(TEXT / _STAR),
                media_type!(TEXT / _STAR),
                media_type!(_STAR / _STAR),
                media_type!(_STAR / _STAR),
                media_type!(_STAR / _STAR)
            ],
            list
        );
    }
}