1use crate::middleware::v2::{Middleware, Next, NextFuture};
7use crate::request::ElifRequest;
8use crate::response::headers::ElifHeaderValue;
9use crate::response::ElifResponse;
10use std::collections::HashMap;
11
/// A media type the negotiation middleware recognises.
///
/// Well-known types have dedicated variants. Anything else is carried in
/// [`ContentType::Custom`]; values produced by [`ContentType::from_mime_type`]
/// are lowercased and stripped of parameters, so `Accept` wildcards such as
/// `*/*` or `text/*` also end up there.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContentType {
    /// `application/json`
    Json,
    /// `application/xml` or `text/xml`
    Xml,
    /// `text/html`
    Html,
    /// `text/plain`
    PlainText,
    /// `text/csv`
    Csv,
    /// `application/msgpack` or `application/x-msgpack`
    MessagePack,
    /// `application/yaml`, `application/x-yaml` or `text/yaml`
    Yaml,
    /// Any other media type (or media range), stored as written.
    Custom(String),
}

impl ContentType {
    /// Parses a MIME type such as `"text/html; charset=utf-8"`.
    ///
    /// Parameters after the first `;` are ignored and matching is
    /// case-insensitive. Unknown but well-formed types become
    /// [`ContentType::Custom`] holding the lowercased `type/subtype`.
    ///
    /// Returns `None` when the input has no `type/subtype` form at all, for
    /// example the empty entry left by a trailing comma in an `Accept` header.
    pub fn from_mime_type(mime_type: &str) -> Option<Self> {
        let mime_lower = mime_type.split(';').next()?.trim().to_lowercase();

        // A media type is always `type/subtype`; anything else is not a
        // content type and must not be smuggled through as `Custom`.
        if mime_lower.is_empty() || !mime_lower.contains('/') {
            return None;
        }

        match mime_lower.as_str() {
            "application/json" => Some(Self::Json),
            "application/xml" | "text/xml" => Some(Self::Xml),
            "text/html" => Some(Self::Html),
            "text/plain" => Some(Self::PlainText),
            "text/csv" => Some(Self::Csv),
            "application/msgpack" | "application/x-msgpack" => Some(Self::MessagePack),
            "application/yaml" | "application/x-yaml" | "text/yaml" => Some(Self::Yaml),
            _ => Some(Self::Custom(mime_lower)),
        }
    }

    /// Canonical MIME string for this type.
    ///
    /// `Custom` has no static MIME string and maps to
    /// `application/octet-stream`.
    pub fn to_mime_type(&self) -> &'static str {
        match self {
            Self::Json => "application/json",
            Self::Xml => "application/xml",
            Self::Html => "text/html",
            Self::PlainText => "text/plain",
            Self::Csv => "text/csv",
            Self::MessagePack => "application/msgpack",
            Self::Yaml => "application/yaml",
            Self::Custom(_) => "application/octet-stream",
        }
    }

    /// Conventional file extension (without the dot); `Custom` maps to `bin`.
    pub fn file_extension(&self) -> &'static str {
        match self {
            Self::Json => "json",
            Self::Xml => "xml",
            Self::Html => "html",
            Self::PlainText => "txt",
            Self::Csv => "csv",
            Self::MessagePack => "msgpack",
            Self::Yaml => "yaml",
            Self::Custom(_) => "bin",
        }
    }
}
69
70#[derive(Debug, Clone)]
72pub struct AcceptValue {
73 pub content_type: ContentType,
74 pub quality: f32,
75 pub params: HashMap<String, String>,
76}
77
78impl AcceptValue {
79 pub fn parse(value: &str) -> Option<Self> {
81 let parts: Vec<&str> = value.split(';').collect();
82 let mime_type = parts.first()?.trim();
83
84 let content_type = ContentType::from_mime_type(mime_type)?;
85 let mut quality = 1.0;
86 let mut params = HashMap::new();
87
88 for param in parts.iter().skip(1) {
90 let param = param.trim();
91 if let Some((key, val)) = param.split_once('=') {
92 let key = key.trim();
93 let val = val.trim();
94
95 if key == "q" {
96 quality = val.parse().unwrap_or(1.0);
97 } else {
98 params.insert(key.to_string(), val.to_string());
99 }
100 }
101 }
102
103 Some(Self {
104 content_type,
105 quality,
106 params,
107 })
108 }
109}
110
111pub struct ContentNegotiationConfig {
113 pub default_content_type: ContentType,
115 pub supported_types: Vec<ContentType>,
117 pub add_vary_header: bool,
119 pub converters: HashMap<
121 ContentType,
122 std::sync::Arc<dyn Fn(&serde_json::Value) -> Result<Vec<u8>, String> + Send + Sync>,
123 >,
124}
125
126impl Default for ContentNegotiationConfig {
127 fn default() -> Self {
128 let mut converters = HashMap::new();
129 converters.insert(
130 ContentType::Json,
131 std::sync::Arc::new(Self::convert_to_json)
132 as std::sync::Arc<
133 dyn Fn(&serde_json::Value) -> Result<Vec<u8>, String> + Send + Sync,
134 >,
135 );
136 converters.insert(
137 ContentType::PlainText,
138 std::sync::Arc::new(Self::convert_to_text)
139 as std::sync::Arc<
140 dyn Fn(&serde_json::Value) -> Result<Vec<u8>, String> + Send + Sync,
141 >,
142 );
143 converters.insert(
144 ContentType::Html,
145 std::sync::Arc::new(Self::convert_to_html)
146 as std::sync::Arc<
147 dyn Fn(&serde_json::Value) -> Result<Vec<u8>, String> + Send + Sync,
148 >,
149 );
150
151 Self {
152 default_content_type: ContentType::Json,
153 supported_types: vec![
154 ContentType::Json,
155 ContentType::PlainText,
156 ContentType::Html,
157 ContentType::Xml,
158 ],
159 add_vary_header: true,
160 converters,
161 }
162 }
163}
164
165impl ContentNegotiationConfig {
166 fn convert_to_json(value: &serde_json::Value) -> Result<Vec<u8>, String> {
168 serde_json::to_vec_pretty(value).map_err(|e| e.to_string())
169 }
170
171 fn convert_to_text(value: &serde_json::Value) -> Result<Vec<u8>, String> {
173 let text = match value {
174 serde_json::Value::String(s) => s.clone(),
175 serde_json::Value::Number(n) => n.to_string(),
176 serde_json::Value::Bool(b) => b.to_string(),
177 serde_json::Value::Null => "null".to_string(),
178 other => serde_json::to_string(other).map_err(|e| e.to_string())?,
179 };
180 Ok(text.into_bytes())
181 }
182
183 fn convert_to_html(value: &serde_json::Value) -> Result<Vec<u8>, String> {
185 let json_str = serde_json::to_string_pretty(value).map_err(|e| e.to_string())?;
186 let html = format!(
187 r#"<!DOCTYPE html>
188<html>
189<head>
190 <title>API Response</title>
191 <style>
192 body {{ font-family: monospace; padding: 20px; }}
193 pre {{ background: #f5f5f5; padding: 15px; border-radius: 5px; }}
194 </style>
195</head>
196<body>
197 <h1>API Response</h1>
198 <pre>{}</pre>
199</body>
200</html>"#,
201 html_escape::encode_text(&json_str)
202 );
203 Ok(html.into_bytes())
204 }
205}
206
/// Middleware that chooses a response representation from the request's
/// `Accept` header and re-encodes JSON response bodies into it using the
/// converters in its configuration.
#[derive(Debug)]
pub struct ContentNegotiationMiddleware {
    /// Negotiation settings (default type, supported types, converters, `Vary` behaviour).
    config: ContentNegotiationConfig,
}
212
213impl std::fmt::Debug for ContentNegotiationConfig {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 f.debug_struct("ContentNegotiationConfig")
216 .field("default_content_type", &self.default_content_type)
217 .field("supported_types", &self.supported_types)
218 .field("add_vary_header", &self.add_vary_header)
219 .field(
220 "converters",
221 &format!("<{} converters>", self.converters.len()),
222 )
223 .finish()
224 }
225}
226
227impl Clone for ContentNegotiationConfig {
228 fn clone(&self) -> Self {
229 Self {
230 default_content_type: self.default_content_type.clone(),
231 supported_types: self.supported_types.clone(),
232 add_vary_header: self.add_vary_header,
233 converters: self.converters.clone(), }
235 }
236}
237
238impl ContentNegotiationMiddleware {
239 pub fn new() -> Self {
241 Self {
242 config: ContentNegotiationConfig::default(),
243 }
244 }
245
246 pub fn with_config(config: ContentNegotiationConfig) -> Self {
248 Self { config }
249 }
250
251 pub fn default_type(mut self, content_type: ContentType) -> Self {
253 self.config.default_content_type = content_type;
254 self
255 }
256
257 pub fn support(mut self, content_type: ContentType) -> Self {
259 if !self.config.supported_types.contains(&content_type) {
260 self.config.supported_types.push(content_type);
261 }
262 self
263 }
264
265 pub fn converter<F>(mut self, content_type: ContentType, converter: F) -> Self
267 where
268 F: Fn(&serde_json::Value) -> Result<Vec<u8>, String> + Send + Sync + 'static,
269 {
270 self.config
271 .converters
272 .insert(content_type.clone(), std::sync::Arc::new(converter));
273 if !self.config.supported_types.contains(&content_type) {
275 self.config.supported_types.push(content_type);
276 }
277 self
278 }
279
280 pub fn no_vary_header(mut self) -> Self {
282 self.config.add_vary_header = false;
283 self
284 }
285
286 fn parse_accept_header(&self, accept_header: &str) -> Vec<AcceptValue> {
288 let mut accept_values = Vec::new();
289
290 for value in accept_header.split(',') {
291 if let Some(accept_value) = AcceptValue::parse(value.trim()) {
292 accept_values.push(accept_value);
293 }
294 }
295
296 accept_values.sort_by(|a, b| {
298 b.quality
299 .partial_cmp(&a.quality)
300 .unwrap_or(std::cmp::Ordering::Equal)
301 });
302
303 accept_values
304 }
305
306 fn negotiate_content_type(&self, accept_header: Option<&ElifHeaderValue>) -> ContentType {
308 let accept_str = match accept_header.and_then(|h| h.to_str().ok()) {
309 Some(s) => s,
310 None => return self.config.default_content_type.clone(),
311 };
312
313 let accept_values = self.parse_accept_header(accept_str);
314
315 for accept_value in &accept_values {
317 if self
318 .config
319 .supported_types
320 .contains(&accept_value.content_type)
321 {
322 return accept_value.content_type.clone();
323 }
324
325 if let ContentType::Custom(mime) = &accept_value.content_type {
327 if mime == "*/*" {
328 return self.config.default_content_type.clone();
329 } else if mime.ends_with("/*") {
330 let category = &mime[..mime.len() - 2];
331 for supported in &self.config.supported_types {
333 if supported.to_mime_type().starts_with(category) {
334 return supported.clone();
335 }
336 }
337 }
338 }
339 }
340
341 self.config.default_content_type.clone()
342 }
343
344 fn extract_json_value(&self, response_body: &[u8]) -> Option<serde_json::Value> {
346 serde_json::from_slice(response_body).ok()
348 }
349
350 async fn convert_response(
352 &self,
353 response: ElifResponse,
354 target_type: ContentType,
355 ) -> ElifResponse {
356 let axum_response = response.into_axum_response();
357 let (parts, body) = axum_response.into_parts();
358
359 let current_content_type = parts
361 .headers
362 .get("content-type")
363 .and_then(|h| h.to_str().ok())
364 .and_then(ContentType::from_mime_type)
365 .unwrap_or(ContentType::Json);
366
367 if current_content_type == target_type {
369 let response = axum::response::Response::from_parts(parts, body);
370 return ElifResponse::from_axum_response(response).await;
371 }
372
373 let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
375 Ok(bytes) => bytes,
376 Err(_) => {
377 let response =
379 axum::response::Response::from_parts(parts, axum::body::Body::empty());
380 return ElifResponse::from_axum_response(response).await;
381 }
382 };
383
384 let json_value = match self.extract_json_value(&body_bytes) {
386 Some(value) => value,
387 None => {
388 let response =
390 axum::response::Response::from_parts(parts, axum::body::Body::from(body_bytes));
391 return ElifResponse::from_axum_response(response).await;
392 }
393 };
394
395 let converted_body =
397 match self.config.converters.get(&target_type) {
398 Some(converter) => match converter(&json_value) {
399 Ok(body) => body,
400 Err(_) => {
401 return ElifResponse::from_axum_response(
403 axum::response::Response::builder()
404 .status(axum::http::StatusCode::NOT_ACCEPTABLE)
405 .header("content-type", "application/json")
406 .body(axum::body::Body::from(
407 serde_json::to_vec(&serde_json::json!({
408 "error": {
409 "code": "not_acceptable",
410 "message": "Cannot convert response to requested format",
411 "hint": "Supported formats: JSON, Plain Text, HTML"
412 }
413 })).unwrap_or_default()
414 ))
415 .unwrap()
416 ).await;
417 }
418 },
419 None => {
420 return ElifResponse::from_axum_response(
422 axum::response::Response::builder()
423 .status(axum::http::StatusCode::NOT_ACCEPTABLE)
424 .header("content-type", "application/json")
425 .body(axum::body::Body::from(
426 serde_json::to_vec(&serde_json::json!({
427 "error": {
428 "code": "not_acceptable",
429 "message": "Requested format is not supported",
430 "hint": "Supported formats: JSON, Plain Text, HTML"
431 }
432 }))
433 .unwrap_or_default(),
434 ))
435 .unwrap(),
436 )
437 .await;
438 }
439 };
440
441 let mut new_parts = parts;
443 new_parts.headers.insert(
444 axum::http::HeaderName::from_static("content-type"),
445 axum::http::HeaderValue::from_static(target_type.to_mime_type()),
446 );
447
448 new_parts.headers.insert(
449 axum::http::HeaderName::from_static("content-length"),
450 axum::http::HeaderValue::try_from(converted_body.len().to_string()).unwrap(),
451 );
452
453 if self.config.add_vary_header {
454 new_parts.headers.insert(
455 axum::http::HeaderName::from_static("vary"),
456 axum::http::HeaderValue::from_static("Accept"),
457 );
458 }
459
460 let response =
461 axum::response::Response::from_parts(new_parts, axum::body::Body::from(converted_body));
462
463 ElifResponse::from_axum_response(response).await
464 }
465}
466
467impl Default for ContentNegotiationMiddleware {
468 fn default() -> Self {
469 Self::new()
470 }
471}
472
473impl Middleware for ContentNegotiationMiddleware {
474 fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
475 let accept_header = request.header("accept").cloned();
476 let target_type = self.negotiate_content_type(accept_header.as_ref());
477 let config = self.config.clone();
478
479 Box::pin(async move {
480 let response = next.run(request).await;
481
482 let middleware = ContentNegotiationMiddleware { config };
483 middleware.convert_response(response, target_type).await
484 })
485 }
486
487 fn name(&self) -> &'static str {
488 "ContentNegotiationMiddleware"
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::ElifRequest;
    use crate::response::headers::ElifHeaderMap;
    use crate::response::headers::ElifHeaderName;
    use crate::response::ElifResponse;

    #[test]
    fn test_content_type_parsing() {
        assert_eq!(
            ContentType::from_mime_type("application/json"),
            Some(ContentType::Json)
        );
        assert_eq!(
            ContentType::from_mime_type("application/xml"),
            Some(ContentType::Xml)
        );
        assert_eq!(
            ContentType::from_mime_type("text/html"),
            Some(ContentType::Html)
        );
        assert_eq!(
            ContentType::from_mime_type("text/plain"),
            Some(ContentType::PlainText)
        );
    }

    #[test]
    fn test_accept_value_parsing() {
        let accept = AcceptValue::parse("application/json;q=0.8").unwrap();
        assert_eq!(accept.content_type, ContentType::Json);
        assert_eq!(accept.quality, 0.8);

        let accept = AcceptValue::parse("text/html").unwrap();
        assert_eq!(accept.content_type, ContentType::Html);
        assert_eq!(accept.quality, 1.0);

        let accept = AcceptValue::parse("text/plain;q=0.5;charset=utf-8").unwrap();
        assert_eq!(accept.content_type, ContentType::PlainText);
        assert_eq!(accept.quality, 0.5);
        assert_eq!(accept.params.get("charset"), Some(&"utf-8".to_string()));
    }

    #[test]
    fn test_accept_header_parsing() {
        let middleware = ContentNegotiationMiddleware::new();
        let values =
            middleware.parse_accept_header("text/html,application/json;q=0.9,text/plain;q=0.8");

        // Entries come back ordered by descending quality.
        assert_eq!(values.len(), 3);
        assert_eq!(values[0].content_type, ContentType::Html);
        assert_eq!(values[1].content_type, ContentType::Json);
        assert_eq!(values[2].content_type, ContentType::PlainText);
    }

    #[test]
    fn test_content_negotiation() {
        let middleware = ContentNegotiationMiddleware::new();

        // Exact match.
        let header = ElifHeaderValue::from_static("application/json");
        assert_eq!(
            middleware.negotiate_content_type(Some(&header)),
            ContentType::Json
        );

        // Highest-quality supported entry wins.
        let header = ElifHeaderValue::from_static("text/html,application/json;q=0.9");
        assert_eq!(
            middleware.negotiate_content_type(Some(&header)),
            ContentType::Html
        );

        // Unsupported type falls back to the default.
        let header = ElifHeaderValue::from_static("application/pdf");
        assert_eq!(
            middleware.negotiate_content_type(Some(&header)),
            ContentType::Json
        );

        // Full wildcard selects the default.
        let header = ElifHeaderValue::from_static("*/*");
        assert_eq!(
            middleware.negotiate_content_type(Some(&header)),
            ContentType::Json
        );
    }

    #[tokio::test]
    async fn test_json_to_text_conversion() {
        let middleware = ContentNegotiationMiddleware::new();

        let mut headers = ElifHeaderMap::new();
        let accept_header = ElifHeaderName::from_str("accept").unwrap();
        let plain_value = ElifHeaderValue::from_str("text/plain").unwrap();
        headers.insert(accept_header, plain_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            headers,
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Hello, World!",
                    "count": 42
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );

        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();
        assert_eq!(parts.headers.get("content-type").unwrap(), "text/plain");
    }

    #[tokio::test]
    async fn test_json_to_html_conversion() {
        let middleware = ContentNegotiationMiddleware::new();

        let mut headers = ElifHeaderMap::new();
        let accept_header = ElifHeaderName::from_str("accept").unwrap();
        let html_value = ElifHeaderValue::from_str("text/html").unwrap();
        headers.insert(accept_header, html_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            headers,
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Hello, World!"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );

        let axum_response = response.into_axum_response();
        let (parts, body) = axum_response.into_parts();
        assert_eq!(parts.headers.get("content-type").unwrap(), "text/html");

        let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
        let html_content = String::from_utf8(body_bytes.to_vec()).unwrap();
        assert!(html_content.contains("<!DOCTYPE html>"));
        assert!(html_content.contains("Hello, World!"));
    }

    /// An `Accept` header naming only an unsupported type does not produce a
    /// 406: negotiation falls back to the default (JSON), so the JSON response
    /// passes through with `200 OK`. (This test was previously named
    /// `test_unsupported_format_406`, which contradicted its assertion.)
    #[tokio::test]
    async fn test_unsupported_format_falls_back_to_default() {
        let middleware = ContentNegotiationMiddleware::new();

        let mut headers = ElifHeaderMap::new();
        let accept_header = ElifHeaderName::from_str("accept").unwrap();
        let pdf_value = ElifHeaderValue::from_str("application/pdf").unwrap();
        headers.insert(accept_header, pdf_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            headers,
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Hello, World!"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );
    }

    #[tokio::test]
    async fn test_builder_pattern() {
        let middleware = ContentNegotiationMiddleware::new()
            .default_type(ContentType::Html)
            .support(ContentType::Csv)
            .no_vary_header();

        assert_eq!(middleware.config.default_content_type, ContentType::Html);
        assert!(middleware
            .config
            .supported_types
            .contains(&ContentType::Csv));
        assert!(!middleware.config.add_vary_header);
    }

    #[test]
    fn test_content_type_mime_types() {
        assert_eq!(ContentType::Json.to_mime_type(), "application/json");
        assert_eq!(ContentType::Xml.to_mime_type(), "application/xml");
        assert_eq!(ContentType::Html.to_mime_type(), "text/html");
        assert_eq!(ContentType::PlainText.to_mime_type(), "text/plain");
        assert_eq!(ContentType::Csv.to_mime_type(), "text/csv");
    }

    #[test]
    fn test_json_conversion_functions() {
        let json_val = serde_json::json!({
            "name": "test",
            "value": 42
        });

        let json_result = ContentNegotiationConfig::convert_to_json(&json_val).unwrap();
        assert!(String::from_utf8(json_result).unwrap().contains("test"));

        // Plain strings are emitted without JSON quoting.
        let text_val = serde_json::json!("Hello World");
        let text_result = ContentNegotiationConfig::convert_to_text(&text_val).unwrap();
        assert_eq!(String::from_utf8(text_result).unwrap(), "Hello World");

        let html_result = ContentNegotiationConfig::convert_to_html(&json_val).unwrap();
        let html_content = String::from_utf8(html_result).unwrap();
        assert!(html_content.contains("<!DOCTYPE html>"));
        assert!(html_content.contains("test"));
    }

    /// A converter registered through the builder must survive the config
    /// clone performed inside `handle`.
    #[tokio::test]
    async fn test_custom_converter_preservation_after_clone() {
        let middleware =
            ContentNegotiationMiddleware::new().converter(ContentType::Csv, |_json_value| {
                Ok(b"custom,csv,data".to_vec())
            });

        let mut headers = ElifHeaderMap::new();
        let accept_header = ElifHeaderName::from_str("accept").unwrap();
        let csv_value = ElifHeaderValue::from_str("text/csv").unwrap();
        headers.insert(accept_header, csv_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            headers,
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "test": "data"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );

        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();
        assert_eq!(parts.headers.get("content-type").unwrap(), "text/csv");
    }
}