Skip to main content

wafrift_encoding/
contextual.rs

1use crate::encoding::Strategy;
2use wafrift_types::injection_context::{ContextualEncodeError, InjectionContext};
3
4pub fn encode_in_context(
5    payload: &[u8],
6    strategy: Strategy,
7    context: InjectionContext,
8) -> Result<String, ContextualEncodeError> {
9    let max_size = match context {
10        InjectionContext::JsonString => 4 * 1024 * 1024,
11        InjectionContext::JsonNumber => 1024,
12        InjectionContext::XmlAttribute => 1024 * 1024,
13        InjectionContext::XmlCdata => 8 * 1024 * 1024,
14        InjectionContext::HeaderValue => 8 * 1024,
15        InjectionContext::CookieValue => 4 * 1024,
16        InjectionContext::MultipartFileName => 256,
17        _ => 8 * 1024 * 1024,
18    };
19
20    if payload.len() > max_size {
21        return Err(ContextualEncodeError::PayloadTooLarge {
22            context,
23            size: payload.len(),
24            max: max_size,
25        });
26    }
27
28    let base = match crate::encoding::encode(payload, strategy) {
29        Ok(s) => s,
30        Err(e) => {
31            return Err(match e {
32                crate::error::EncodeError::InvalidUtf8 => {
33                    ContextualEncodeError::InvalidUtf8 { offset: 0 }
34                }
35                crate::error::EncodeError::PayloadTooLarge { max, actual } => {
36                    ContextualEncodeError::PayloadTooLarge {
37                        context,
38                        size: actual,
39                        max,
40                    }
41                }
42                crate::error::EncodeError::LayeredOutputTooLarge { max, actual } => {
43                    ContextualEncodeError::PayloadTooLarge {
44                        context,
45                        size: actual,
46                        max,
47                    }
48                }
49                crate::error::EncodeError::InvalidContext {
50                    strategy: s,
51                    context: _,
52                } => ContextualEncodeError::ContextIncompatible {
53                    strategy: s.into(),
54                    context,
55                    reason: "strategy invalid for context".into(),
56                },
57                crate::error::EncodeError::InvalidConfig(msg) => {
58                    ContextualEncodeError::ContextIncompatible {
59                        strategy: "config".into(),
60                        context,
61                        reason: msg,
62                    }
63                }
64            });
65        }
66    };
67
68    escape_for_context(&base, context)
69}
70
71pub fn escape_for_context(
72    input: &str,
73    context: InjectionContext,
74) -> Result<String, ContextualEncodeError> {
75    let escaped = match context {
76        InjectionContext::JsonString => {
77            let mut s = String::with_capacity(input.len() + 10);
78            for c in input.chars() {
79                match c {
80                    '\\' => s.push_str("\\\\"),
81                    '"' => s.push_str("\\\""),
82                    '\n' => s.push_str("\\n"),
83                    '\r' => s.push_str("\\r"),
84                    '\t' => s.push_str("\\t"),
85                    '\x00'..='\x1f' => s.push_str(&format!("\\u{:04x}", c as u32)),
86                    _ => s.push(c),
87                }
88            }
89            s
90        }
91        InjectionContext::JsonNumber => {
92            if input.chars().any(|c| {
93                !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+'
94            }) {
95                return Err(ContextualEncodeError::ContextIncompatible {
96                    strategy: "escape".into(),
97                    context,
98                    reason: "not a valid JSON number".into(),
99                });
100            }
101            input.to_string()
102        }
103        InjectionContext::XmlAttribute => {
104            if input.contains('\x00') {
105                return Err(ContextualEncodeError::ContextIncompatible {
106                    strategy: "escape".into(),
107                    context,
108                    reason: "null byte in xml attribute".into(),
109                });
110            }
111            input
112                .replace('&', "&amp;")
113                .replace('"', "&quot;")
114                .replace('<', "&lt;")
115                .replace('>', "&gt;")
116        }
117        InjectionContext::XmlCdata => {
118            if input.contains("]]>") {
119                return Err(ContextualEncodeError::ContextIncompatible {
120                    strategy: "escape".into(),
121                    context,
122                    reason: "CDATA cannot contain ]]>".into(),
123                });
124            }
125            input.to_string()
126        }
127        InjectionContext::XmlText => input
128            .replace('&', "&amp;")
129            .replace('<', "&lt;")
130            .replace('>', "&gt;"),
131        InjectionContext::HtmlAttribute => input
132            .replace('&', "&amp;")
133            .replace('"', "&quot;")
134            .replace('\'', "&#x27;")
135            .replace('<', "&lt;"),
136        InjectionContext::HtmlText => input.replace('&', "&amp;").replace('<', "&lt;"),
137        InjectionContext::UrlQuery => urlencoding::encode(input).to_string(),
138        InjectionContext::UrlPath => urlencoding::encode(input).to_string().replace("%2F", "/"),
139        InjectionContext::UrlFragment => urlencoding::encode(input).to_string(),
140        InjectionContext::HeaderValue => {
141            if input.contains('\r') || input.contains('\n') {
142                return Err(ContextualEncodeError::ContextIncompatible {
143                    strategy: "escape".into(),
144                    context,
145                    reason: "CR/LF in header value".into(),
146                });
147            }
148            if input.contains('\x00') {
149                return Err(ContextualEncodeError::ContextIncompatible {
150                    strategy: "escape".into(),
151                    context,
152                    reason: "null byte in header value".into(),
153                });
154            }
155            input.to_string()
156        }
157        InjectionContext::CookieValue => input
158            .replace(';', "%3B")
159            .replace('=', "%3D")
160            .replace('\x00', "%00")
161            .replace('\r', "%0D")
162            .replace('\n', "%0A"),
163        InjectionContext::MultipartField => {
164            if input.contains('\r') || input.contains('\n') {
165                return Err(ContextualEncodeError::ContextIncompatible {
166                    strategy: "escape".into(),
167                    context,
168                    reason: "CR/LF would break multipart structure".into(),
169                });
170            }
171            input.to_string()
172        }
173        InjectionContext::MultipartFileName => {
174            if input.contains('"') {
175                return Err(ContextualEncodeError::ContextIncompatible {
176                    strategy: "escape".into(),
177                    context,
178                    reason: "quote in filename".into(),
179                });
180            }
181            if input.contains('\r') || input.contains('\n') {
182                return Err(ContextualEncodeError::ContextIncompatible {
183                    strategy: "escape".into(),
184                    context,
185                    reason: "CR/LF in filename".into(),
186                });
187            }
188            input.to_string()
189        }
190        InjectionContext::PlainBody => input.to_string(),
191        _ => input.to_string(),
192    };
193    validate_in_context(&escaped, context)?;
194    Ok(escaped)
195}
196
197pub fn validate_in_context(
198    payload: &str,
199    context: InjectionContext,
200) -> Result<(), ContextualEncodeError> {
201    match context {
202        InjectionContext::JsonString => {
203            let mut chars = payload.chars().peekable();
204            while let Some(c) = chars.next() {
205                if c == '"' {
206                    return Err(ContextualEncodeError::ContextIncompatible {
207                        strategy: "validate".into(),
208                        context,
209                        reason: "unescaped double quote in JSON string".into(),
210                    });
211                }
212                if c == '\\' {
213                    let escaped = chars.next();
214                    match escaped {
215                        Some('\\') | Some('"') | Some('n') | Some('r') | Some('t') | Some('b')
216                        | Some('f') | Some('/') => {}
217                        Some('u') => {
218                            // Validate exactly 4 hex digits after \u
219                            for _ in 0..4 {
220                                match chars.next() {
221                                    Some(c) if c.is_ascii_hexdigit() => {}
222                                    _ => {
223                                        return Err(ContextualEncodeError::ContextIncompatible {
224                                            strategy: "validate".into(),
225                                            context,
226                                            reason: "invalid Unicode escape in JSON string".into(),
227                                        });
228                                    }
229                                }
230                            }
231                        }
232                        Some(other) => {
233                            return Err(ContextualEncodeError::ContextIncompatible {
234                                strategy: "validate".into(),
235                                context,
236                                reason: format!("invalid JSON escape sequence: \\{other}"),
237                            });
238                        }
239                        None => {
240                            return Err(ContextualEncodeError::ContextIncompatible {
241                                strategy: "validate".into(),
242                                context,
243                                reason: "trailing backslash in JSON string".into(),
244                            });
245                        }
246                    }
247                }
248            }
249        }
250        InjectionContext::XmlAttribute => {
251            let mut chars = payload.chars();
252            while let Some(c) = chars.next() {
253                if c == '"' {
254                    return Err(ContextualEncodeError::ContextIncompatible {
255                        strategy: "validate".into(),
256                        context,
257                        reason: "unescaped double quote in XML attribute".into(),
258                    });
259                }
260                if c == '&' {
261                    // Allow known entity references; anything else starting with & is suspicious
262                    let remainder: String = chars.by_ref().take(6).collect();
263                    if !remainder.starts_with("quot;")
264                        && !remainder.starts_with("amp;")
265                        && !remainder.starts_with("lt;")
266                        && !remainder.starts_with("gt;")
267                    {
268                        // Not a known entity — could be an unescaped &
269                        // (We keep scanning rather than erroring, since & alone
270                        // is technically valid XML text if followed by whitespace.)
271                    }
272                }
273            }
274        }
275        // Contexts below have no validation rules yet. Adding an explicit
276        // arm for each ensures the compiler warns us when a new variant is
277        // added so we can decide whether it needs validation.
278        InjectionContext::PlainBody => {
279            // Plain body accepts any byte sequence; nothing to validate.
280        }
281        InjectionContext::XmlCdata => {
282            // TODO: validate that payload doesn't contain `]]>` which
283            // would terminate the CDATA section prematurely.
284        }
285        InjectionContext::XmlText => {
286            // TODO: validate that payload doesn't contain `<` or `&`
287            // unless they are proper entities.
288        }
289        InjectionContext::HtmlAttribute => {
290            // TODO: validate that payload doesn't contain unescaped quotes
291            // matching the attribute delimiter.
292        }
293        InjectionContext::HtmlText => {
294            // TODO: validate that payload doesn't contain `<` or `&`
295            // unless they are proper HTML entities.
296        }
297        InjectionContext::UrlQuery | InjectionContext::UrlPath | InjectionContext::UrlFragment => {
298            // URL components are validated by percent-encoding step later;
299            // raw payload can contain any bytes here.
300        }
301        InjectionContext::HeaderValue => {
302            // Header values are validated by the header obfuscation layer;
303            // CRLF injection is guarded at the transport level.
304        }
305        InjectionContext::CookieValue => {
306            // Cookie values accept most printable ASCII; validation is
307            // handled by the cookie encoding layer.
308        }
309        InjectionContext::MultipartField | InjectionContext::MultipartFileName => {
310            // Multipart boundaries are managed by the form encoder;
311            // individual field values have no additional constraints.
312        }
313        // InjectionContext is #[non_exhaustive]; future variants default to
314        // no validation until explicit rules are added.
315        _ => {}
316    }
317    Ok(())
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::Strategy;

    #[test]
    fn encode_error_mapping_invalid_utf8() {
        // \x80 alone is invalid UTF-8, so encode() returns InvalidUtf8, which must surface as
        // the contextual InvalidUtf8 error.
        let result = encode_in_context(
            b"\x80",
            Strategy::CaseAlternation,
            InjectionContext::PlainBody,
        );
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("invalid") || err.to_string().contains("UTF-8"),
            "error should mention invalid UTF-8, got: {}",
            err
        );
    }

    #[test]
    fn json_string_validates_unescaped_quote() {
        let err = validate_in_context("hello\"world", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("unescaped double quote"));
    }

    #[test]
    fn json_string_validates_valid_escapes() {
        assert!(validate_in_context("hello\\nworld", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\tworld", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\\\world", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\\"world", InjectionContext::JsonString).is_ok());
    }

    #[test]
    fn json_string_validates_unicode_escape() {
        // Valid \u00e4
        assert!(validate_in_context("\\u00e4", InjectionContext::JsonString).is_ok());
        // Invalid \u00g4 (non-hex)
        let err = validate_in_context("\\u00g4", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid Unicode escape"));
        // Too short \u00
        let err = validate_in_context("\\u00", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid Unicode escape"));
    }

    #[test]
    fn json_string_validates_invalid_escape() {
        let err = validate_in_context("\\x", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid JSON escape"));
    }

    #[test]
    fn json_string_validates_trailing_backslash() {
        let err = validate_in_context("hello\\", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("trailing backslash"));
    }

    #[test]
    fn json_string_escapes_structural_characters() {
        let out = escape_for_context("a\"b\\c\n\x01", InjectionContext::JsonString).unwrap();
        assert_eq!(out, "a\\\"b\\\\c\\n\\u0001");
    }

    #[test]
    fn xml_attribute_validates_unescaped_quote() {
        let err = validate_in_context("hello\"world", InjectionContext::XmlAttribute).unwrap_err();
        assert!(err.to_string().contains("unescaped double quote"));
    }

    #[test]
    fn xml_attribute_allows_escaped_quote() {
        // &quot; should be allowed (the validator doesn't fully validate entities,
        // but it shouldn't error on well-formed entity references)
        assert!(validate_in_context("hello&quot;world", InjectionContext::XmlAttribute).is_ok());
    }

    #[test]
    fn xml_attribute_escapes_markup() {
        let out = escape_for_context("<a&b>", InjectionContext::XmlAttribute).unwrap();
        assert_eq!(out, "&lt;a&amp;b&gt;");
    }

    #[test]
    fn header_value_validates_crlf() {
        let err = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::HeaderValue,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CR/LF"));
    }

    #[test]
    fn header_value_rejects_nul() {
        let err = escape_for_context("a\0b", InjectionContext::HeaderValue).unwrap_err();
        assert!(err.to_string().contains("null byte"));
    }

    #[test]
    fn cookie_value_escapes_crlf() {
        let out = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::CookieValue,
        )
        .unwrap();
        assert!(out.contains("%0D") && out.contains("%0A"));
    }

    #[test]
    fn cookie_value_escapes_delimiters() {
        let out = escape_for_context("a=b;c", InjectionContext::CookieValue).unwrap();
        assert_eq!(out, "a%3Db%3Bc");
    }

    #[test]
    fn multipart_field_validates_crlf() {
        let err = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::MultipartField,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CR/LF"));
    }

    #[test]
    fn html_attribute_escapes_ampersand() {
        let out = encode_in_context(
            b"a&b",
            Strategy::CaseAlternation,
            InjectionContext::HtmlAttribute,
        )
        .unwrap();
        assert!(out.contains("&amp;"));
    }

    #[test]
    fn url_query_escapes_space() {
        let out = encode_in_context(
            b"hello world",
            Strategy::CaseAlternation,
            InjectionContext::UrlQuery,
        )
        .unwrap();
        assert!(!out.contains(' '));
    }

    #[test]
    fn url_path_preserves_slash() {
        let out = encode_in_context(
            b"/api/v1",
            Strategy::CaseAlternation,
            InjectionContext::UrlPath,
        )
        .unwrap();
        assert!(out.contains('/'));
    }

    #[test]
    fn plain_body_no_structural_escaping() {
        // PlainBody doesn't add structural escaping, but the strategy still mutates
        let out = encode_in_context(
            b"<script>",
            Strategy::CaseAlternation,
            InjectionContext::PlainBody,
        )
        .unwrap();
        assert_eq!(out, "<ScRiPt>");
    }

    #[test]
    fn max_size_enforced() {
        let big = vec![b'a'; 8 * 1024 * 1024 + 1];
        let err = encode_in_context(&big, Strategy::CaseAlternation, InjectionContext::PlainBody)
            .unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[test]
    fn multipart_filename_enforces_size_limit() {
        let err = encode_in_context(
            &[b'a'; 257],
            Strategy::CaseAlternation,
            InjectionContext::MultipartFileName,
        )
        .unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[test]
    fn xml_cdata_rejects_termination_sequence() {
        let err = encode_in_context(
            b"hello]]>world",
            Strategy::CaseAlternation,
            InjectionContext::XmlCdata,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CDATA"));
    }

    #[test]
    fn multipart_filename_rejects_quote() {
        let err = encode_in_context(
            b"file\"name.txt",
            Strategy::CaseAlternation,
            InjectionContext::MultipartFileName,
        )
        .unwrap_err();
        assert!(err.to_string().contains("quote"));
    }

    #[test]
    fn json_number_rejects_non_numeric() {
        let err = encode_in_context(
            b"abc",
            Strategy::CaseAlternation,
            InjectionContext::JsonNumber,
        )
        .unwrap_err();
        assert!(err.to_string().contains("not a valid JSON number"));
    }

    #[test]
    fn json_number_accepts_valid_number() {
        let out = escape_for_context("-1.5e3", InjectionContext::JsonNumber).unwrap();
        assert_eq!(out, "-1.5e3");
    }

    #[test]
    fn empty_payload_valid_in_all_contexts() {
        for ctx in [
            InjectionContext::PlainBody,
            InjectionContext::JsonString,
            InjectionContext::XmlAttribute,
            InjectionContext::HeaderValue,
            InjectionContext::CookieValue,
        ] {
            assert!(
                encode_in_context(b"", Strategy::UrlEncode, ctx).is_ok(),
                "empty payload should be valid in {ctx:?}"
            );
        }
    }
}