Skip to main content

tower_request_guard/
response.rs

1use crate::violation::Violation;
2use http::Response;
3
/// Escapes `s` so it can be placed between double quotes in a JSON document.
///
/// `"` and `\` gain a backslash prefix, and `\n`, `\r` and `\t` use their
/// two-character escape forms. Every other character for which
/// [`char::is_control`] is true becomes a `\uXXXX` escape (lowercase hex).
/// All remaining characters are copied through unchanged.
pub(crate) fn escape_json_string(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut out, c| {
        // Characters that have a dedicated short JSON escape sequence.
        let short_escape = match c {
            '"' => Some(r#"\""#),
            '\\' => Some(r#"\\"#),
            '\n' => Some(r#"\n"#),
            '\r' => Some(r#"\r"#),
            '\t' => Some(r#"\t"#),
            _ => None,
        };
        match short_escape {
            Some(seq) => out.push_str(seq),
            // JSON forbids raw U+0000..=U+001F inside strings; `is_control` also
            // matches U+007F..=U+009F, which get escaped as well (still valid JSON).
            None if c.is_control() => out.push_str(&format!("\\u{:04x}", u32::from(c))),
            None => out.push(c),
        }
        out
    })
}
23
24/// Build an HTTP error response for a given violation.
25pub fn violation_response(violation: &Violation) -> Response<String> {
26    let status = violation.status_code();
27    let body = violation_json_body(violation);
28
29    Response::builder()
30        .status(status)
31        .header("Content-Type", "application/json")
32        .body(body)
33        .unwrap()
34}
35
36fn violation_json_body(violation: &Violation) -> String {
37    match violation {
38        Violation::BodyTooLarge { max, received } => {
39            format!(
40                r#"{{"error":"payload too large","violation":"body_too_large","max":{},"received":{}}}"#,
41                max, received
42            )
43        }
44        Violation::RequestTimeout { timeout_ms } => {
45            format!(
46                r#"{{"error":"request timeout","violation":"request_timeout","timeout_ms":{}}}"#,
47                timeout_ms
48            )
49        }
50        Violation::InvalidContentType { received, allowed } => {
51            let received_escaped = escape_json_string(received);
52            let allowed_json: Vec<String> = allowed
53                .iter()
54                .map(|a| format!(r#""{}""#, escape_json_string(a)))
55                .collect();
56            format!(
57                r#"{{"error":"unsupported content type","violation":"invalid_content_type","received":"{}","allowed":[{}]}}"#,
58                received_escaped,
59                allowed_json.join(",")
60            )
61        }
62        Violation::MissingHeader { header } => {
63            format!(
64                r#"{{"error":"missing required header","violation":"missing_header","header":"{}"}}"#,
65                escape_json_string(header)
66            )
67        }
68        Violation::JsonTooDeep {
69            max_depth,
70            found_depth,
71        } => {
72            format!(
73                r#"{{"error":"json depth exceeded","violation":"json_too_deep","max_depth":{},"found_depth":{}}}"#,
74                max_depth, found_depth
75            )
76        }
77        Violation::InvalidJson { detail } => {
78            format!(
79                r#"{{"error":"invalid json","violation":"invalid_json","detail":"{}"}}"#,
80                escape_json_string(detail)
81            )
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::violation::Violation;
    use http::StatusCode;

    #[test]
    fn escape_json_string_handles_special_chars() {
        assert_eq!(escape_json_string(r#"hello "world""#), r#"hello \"world\""#);
        assert_eq!(escape_json_string("back\\slash"), r#"back\\slash"#);
        assert_eq!(escape_json_string("new\nline"), r#"new\nline"#);
        assert_eq!(escape_json_string("tab\there"), r#"tab\there"#);
    }

    #[test]
    fn escape_json_string_passes_through_clean_input() {
        assert_eq!(escape_json_string("application/json"), "application/json");
        assert_eq!(escape_json_string("Authorization"), "Authorization");
    }

    #[test]
    fn escape_json_string_escapes_other_control_chars() {
        // Control characters without a short escape use the `\uXXXX` form.
        assert_eq!(escape_json_string("\u{0}"), r#"\u0000"#);
        assert_eq!(escape_json_string("a\u{1f}b"), r#"a\u001fb"#);
        assert_eq!(escape_json_string("\u{7f}"), r#"\u007f"#);
        assert_eq!(escape_json_string("cr\rlf"), r#"cr\rlf"#);
    }

    #[test]
    fn escape_json_string_preserves_non_ascii() {
        assert_eq!(escape_json_string("héllo ✓ 🚀"), "héllo ✓ 🚀");
        assert_eq!(escape_json_string(""), "");
    }

    #[test]
    fn body_too_large_response() {
        let v = Violation::BodyTooLarge {
            max: 1024,
            received: 2048,
        };
        let resp = violation_response(&v);
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
        let body = resp.into_body();
        assert!(body.contains(r#""violation":"body_too_large""#));
        assert!(body.contains(r#""max":1024"#));
        assert!(body.contains(r#""received":2048"#));
    }

    #[test]
    fn body_too_large_body_is_exact_json() {
        // Pins the full wire format, not just individual fragments.
        let v = Violation::BodyTooLarge {
            max: 1024,
            received: 2048,
        };
        let body = violation_response(&v).into_body();
        assert_eq!(
            body,
            r#"{"error":"payload too large","violation":"body_too_large","max":1024,"received":2048}"#
        );
    }

    #[test]
    fn response_sets_json_content_type() {
        let v = Violation::RequestTimeout { timeout_ms: 5 };
        let resp = violation_response(&v);
        assert_eq!(
            resp.headers()
                .get(http::header::CONTENT_TYPE)
                .and_then(|value| value.to_str().ok()),
            Some("application/json")
        );
    }

    #[test]
    fn missing_header_response() {
        let v = Violation::MissingHeader {
            header: "Authorization".into(),
        };
        let resp = violation_response(&v);
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = resp.into_body();
        assert!(body.contains(r#""violation":"missing_header""#));
        assert!(body.contains(r#""header":"Authorization""#));
    }

    #[test]
    fn invalid_content_type_response() {
        let v = Violation::InvalidContentType {
            received: "text/xml".into(),
            allowed: vec!["application/json".into(), "multipart/form-data".into()],
        };
        let resp = violation_response(&v);
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
        let body = resp.into_body();
        assert!(body.contains(r#""violation":"invalid_content_type""#));
        assert!(body.contains(r#""received":"text/xml""#));
        assert!(body.contains(r#""allowed":["application/json","multipart/form-data"]"#));
    }

    #[test]
    fn timeout_response() {
        let v = Violation::RequestTimeout { timeout_ms: 30000 };
        let resp = violation_response(&v);
        assert_eq!(resp.status(), StatusCode::GATEWAY_TIMEOUT);
        let body = resp.into_body();
        assert!(body.contains(r#""violation":"request_timeout""#));
        assert!(body.contains(r#""timeout_ms":30000"#));
    }

    #[test]
    fn json_too_deep_response() {
        let v = Violation::JsonTooDeep {
            max_depth: 32,
            found_depth: 128,
        };
        let resp = violation_response(&v);
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = resp.into_body();
        assert!(body.contains(r#""violation":"json_too_deep""#));
        assert!(body.contains(r#""max_depth":32"#));
        assert!(body.contains(r#""found_depth":128"#));
    }

    #[test]
    fn invalid_json_response() {
        let v = Violation::InvalidJson {
            detail: "unexpected EOF".into(),
        };
        let resp = violation_response(&v);
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = resp.into_body();
        assert!(body.contains(r#""violation":"invalid_json""#));
        assert!(body.contains(r#""detail":"unexpected EOF""#));
    }

    #[test]
    fn response_escapes_untrusted_input() {
        let v = Violation::MissingHeader {
            header: r#"X-Bad"Header"#.into(),
        };
        let resp = violation_response(&v);
        let body = resp.into_body();
        assert!(body.contains(r#""header":"X-Bad\"Header""#));
    }

    #[test]
    fn content_type_and_detail_fields_are_escaped() {
        // Both the received value and each allowed entry go through escaping.
        let v = Violation::InvalidContentType {
            received: r#"text/"xml"#.into(),
            allowed: vec![r#"app\json"#.into()],
        };
        let body = violation_response(&v).into_body();
        assert!(body.contains(r#""received":"text/\"xml""#));
        assert!(body.contains(r#""allowed":["app\\json"]"#));

        // Parser error details may contain raw newlines.
        let v = Violation::InvalidJson {
            detail: "line 1\ncol 2".into(),
        };
        let body = violation_response(&v).into_body();
        assert!(body.contains(r#""detail":"line 1\ncol 2""#));
    }
}