1use crate::server::Encoding;
17use crate::server::JsonEncoding;
18use crate::server::SmileEncoding;
19use conjure_error::Error;
20use conjure_error::InvalidArgument;
21use http::header::ACCEPT;
22use http::header::CONTENT_TYPE;
23use http::HeaderMap;
24use mediatype::names;
25use mediatype::MediaType;
26use mediatype::MediaTypeList;
27use mediatype::ReadParams;
28
29pub struct ConjureRuntime {
31 encodings: Vec<Box<dyn Encoding + Sync + Send>>,
32}
33
34impl ConjureRuntime {
35 pub fn new() -> Self {
37 Self::builder().build()
38 }
39
40 pub fn builder() -> Builder {
42 Builder { encodings: vec![] }
43 }
44
45 pub fn request_body_encoding(
50 &self,
51 headers: &HeaderMap,
52 ) -> Result<&(dyn Encoding + Sync + Send), Error> {
53 let Some(content_type) = headers.get(CONTENT_TYPE) else {
54 return Err(Error::service_safe(
55 "Content-Type header missing from request",
56 InvalidArgument::new(),
57 ));
58 };
59
60 let content_type = content_type
61 .to_str()
62 .map_err(|e| Error::service_safe(e, InvalidArgument::new()))?;
63 let content_type = MediaType::parse(content_type)
64 .map_err(|e| Error::service_safe(e, InvalidArgument::new()))?;
65
66 self.encodings
67 .iter()
68 .map(|e| &**e)
69 .find(|e| mime_matches(&content_type, *e))
70 .ok_or_else(|| {
71 Error::service_safe(
72 "request Content-Type not accepted by any encoding",
73 InvalidArgument::new(),
74 )
75 })
76 }
77
78 pub fn response_body_encoding(
90 &self,
91 headers: &HeaderMap,
92 ) -> Result<&(dyn Encoding + Sync + Send), Error> {
93 let mut types = headers
94 .get_all(ACCEPT)
95 .iter()
96 .filter_map(|h| h.to_str().ok())
97 .flat_map(|h| MediaTypeList::new(h).filter_map(Result::ok))
98 .enumerate()
99 .map(|(idx, type_)| {
100 let quality = mime_quality(&type_);
101 (type_, quality, idx)
102 })
103 .collect::<Vec<_>>();
104
105 if types.is_empty() {
107 types.push((MediaType::new(names::_STAR, names::_STAR), 1000, 0));
108 }
109
110 types.sort_by(|(a, a_quality, a_idx), (b, b_quality, b_idx)| {
112 mime_specificity(a)
113 .cmp(&mime_specificity(b))
114 .reverse()
115 .then_with(|| a_quality.cmp(b_quality).reverse())
116 .then_with(|| a_idx.cmp(b_idx))
117 });
118
119 self.encodings
120 .iter()
121 .rev()
123 .filter_map(|encoding| {
125 types
126 .iter()
127 .find(|(type_, _, _)| accepts(type_, &**encoding))
128 .filter(|(_, quality, _)| *quality != 0)
130 .map(|(_, quality, idx)| (encoding, quality, idx))
131 })
132 .max_by(|(_, a_quality, a_idx), (_, b_quality, b_idx)| {
134 a_quality
135 .cmp(b_quality)
136 .then_with(|| a_idx.cmp(b_idx).reverse())
137 })
138 .map(|(encoding, _, _)| &**encoding)
139 .ok_or_else(|| {
140 Error::service_safe("request was not acceptable", InvalidArgument::new())
141 })
142 }
143}
144
145impl Default for ConjureRuntime {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
151pub struct Builder {
153 encodings: Vec<Box<dyn Encoding + Sync + Send>>,
154}
155
156impl Builder {
157 pub fn encoding(mut self, encoding: impl Encoding + 'static + Sync + Send) -> Self {
161 self.encodings.push(Box::new(encoding));
162 self
163 }
164
165 pub fn build(mut self) -> ConjureRuntime {
167 if self.encodings.is_empty() {
168 self = self.encoding(JsonEncoding).encoding(SmileEncoding);
169 }
170
171 ConjureRuntime {
172 encodings: self.encodings,
173 }
174 }
175}
176
177fn mime_specificity(mime: &MediaType<'_>) -> impl Ord {
192 (
193 mime.ty != names::_STAR,
194 mime.subty != names::_STAR,
195 mime.params.iter().filter(|(k, _)| *k != names::Q).count(),
196 )
197}
198
199fn mime_quality(mime: &MediaType) -> u32 {
220 mime_quality_inner(mime).unwrap_or(1000)
221}
222
223fn mime_quality_inner(mime: &MediaType) -> Option<u32> {
224 let quality = mime.get_param(names::Q)?;
225
226 let mut value = 0;
227 let mut it = quality.as_str().chars();
228 match it.next() {
229 Some('1') => value = 1000,
230 Some('0') => {}
231 Some(_) | None => return None,
232 }
233 match it.next() {
234 Some('.') => {}
235 Some(_) => return None,
236 None => return Some(value),
237 }
238
239 if it.as_str().len() > 3 {
240 return None;
241 }
242
243 for (idx, ch) in it.enumerate() {
244 value += ch.to_digit(10)? * (10u32.pow(2 - idx as u32))
245 }
246
247 Some(value)
248}
249
250fn mime_matches(target_mime: &MediaType, encoding: &dyn Encoding) -> bool {
251 let encoding_type = encoding.content_type();
252 let Some(encoding_mime) = encoding_type
253 .to_str()
254 .ok()
255 .and_then(|t| MediaType::parse(t).ok())
256 else {
257 return false;
258 };
259
260 target_mime.essence() == encoding_mime.essence()
262}
263
264fn accepts(target_mime: &MediaType, encoding: &dyn Encoding) -> bool {
265 let encoding_type = encoding.content_type();
266 let Some(encoding_mime) = encoding_type
267 .to_str()
268 .ok()
269 .and_then(|t| MediaType::parse(t).ok())
270 else {
271 return false;
272 };
273
274 if target_mime.essence() == MediaType::new(names::_STAR, names::_STAR) {
275 return true;
276 }
277
278 if target_mime.ty == encoding_mime.ty && target_mime.subty == names::_STAR {
279 return true;
280 }
281
282 target_mime.essence() == encoding_mime.essence()
284}
285
#[cfg(test)]
mod test {
    use super::*;
    use http::HeaderValue;
    use mediatype::MediaTypeBuf;

    /// Builds a runtime with JSON registered before Smile.
    fn json_then_smile() -> ConjureRuntime {
        ConjureRuntime::builder()
            .encoding(JsonEncoding)
            .encoding(SmileEncoding)
            .build()
    }

    /// Builds a header map holding `name` only when a value is provided.
    fn single_header(name: http::header::HeaderName, value: Option<&str>) -> HeaderMap {
        let mut headers = HeaderMap::new();
        if let Some(value) = value {
            headers.insert(name, HeaderValue::from_str(value).unwrap());
        }
        headers
    }

    /// Compares the selected encoding (or failure) against the expectation.
    fn check(expected: Result<&str, ()>, actual: Result<&(dyn Encoding + Sync + Send), Error>) {
        match (expected, actual) {
            (Ok(expected), Ok(encoding)) => assert_eq!(expected, encoding.content_type()),
            (Ok(expected), Err(e)) => panic!("expected Ok({expected}), got Err({e:?})"),
            (Err(()), Err(_)) => {}
            (Err(()), Ok(encoding)) => {
                panic!("expected Err(), got Ok({:?})", encoding.content_type())
            }
        }
    }

    #[test]
    fn request_encodings() {
        let runtime = json_then_smile();

        let cases = [
            (Some("application/json"), Ok("application/json")),
            (
                Some("application/json; charset=UTF-8"),
                Ok("application/json"),
            ),
            (
                Some("application/x-jackson-smile"),
                Ok("application/x-jackson-smile"),
            ),
            (Some("text/plain"), Err(())),
            // Wildcards are only meaningful in Accept, never in Content-Type.
            (Some("application/*"), Err(())),
            (Some("*/*"), Err(())),
            (None, Err(())),
        ];

        for (content_type, expected) in cases {
            let headers = single_header(CONTENT_TYPE, content_type);
            check(expected, runtime.request_body_encoding(&headers));
        }
    }

    #[test]
    fn response_encodings() {
        let runtime = json_then_smile();

        let cases = [
            (None, Ok("application/json")),
            (Some("*/*"), Ok("application/json")),
            (
                Some("*/*, application/json; q=0.5"),
                Ok("application/x-jackson-smile"),
            ),
            (
                Some("*/*, application/json; q=0"),
                Ok("application/x-jackson-smile"),
            ),
            (
                Some("application/json; encoding=UTF-8"),
                Ok("application/json"),
            ),
            (
                Some("application/x-jackson-smile"),
                Ok("application/x-jackson-smile"),
            ),
            (
                Some("text/plain, application/json, application/x-jackson-smile"),
                Ok("application/json"),
            ),
            (
                Some("text/plain, application/x-jackson-smile, application/json"),
                Ok("application/x-jackson-smile"),
            ),
            (
                Some("application/json; q=0.5, application/x-jackson-smile"),
                Ok("application/x-jackson-smile"),
            ),
            (
                Some("text/html, image/gif, image/jpeg, */*; q=0.2"),
                Ok("application/json"),
            ),
            (
                Some("text/html, image/gif, image/jpeg, application/*; q=0.2"),
                Ok("application/json"),
            ),
            (Some("text/plain"), Err(())),
            (Some("application/json; q=0, text/plain"), Err(())),
        ];

        for (accept, expected) in cases {
            let headers = single_header(ACCEPT, accept);
            check(expected, runtime.response_body_encoding(&headers));
        }
    }

    // This test shares its name with the function under test, hence the `super::` path.
    #[test]
    fn mime_quality() {
        let cases = [
            ("1", 1000),
            ("0", 0),
            ("1.", 1000),
            ("0.", 0),
            ("1.0", 1000),
            ("0.0", 0),
            ("1.00", 1000),
            ("0.00", 0),
            ("1.000", 1000),
            ("0.000", 0),
            ("0.2", 200),
            ("0.02", 20),
            ("0.002", 2),
            ("0.123", 123),
        ];

        for (input, expected) in cases {
            let header = format!("foo/bar; q={input}");
            let mime: MediaTypeBuf = header.parse().unwrap();
            let actual = super::mime_quality(&mime.to_ref());
            assert_eq!(expected, actual);
        }
    }
}