Skip to main content

wafrift_encoding/
contextual.rs

1use crate::encoding::Strategy;
2use wafrift_types::injection_context::{ContextualEncodeError, InjectionContext};
3
4pub fn encode_in_context(
5    payload: &[u8],
6    strategy: Strategy,
7    context: InjectionContext,
8) -> Result<String, ContextualEncodeError> {
9    let max_size = match context {
10        InjectionContext::JsonString => 4 * 1024 * 1024,
11        InjectionContext::JsonNumber => 1024,
12        InjectionContext::XmlAttribute => 1024 * 1024,
13        InjectionContext::XmlCdata => 8 * 1024 * 1024,
14        InjectionContext::HeaderValue => 8 * 1024,
15        InjectionContext::CookieValue => 4 * 1024,
16        InjectionContext::MultipartFileName => 256,
17        _ => 8 * 1024 * 1024,
18    };
19
20    if payload.len() > max_size {
21        return Err(ContextualEncodeError::PayloadTooLarge {
22            context,
23            size: payload.len(),
24            max: max_size,
25        });
26    }
27
28    let base = match crate::encoding::encode(payload, strategy) {
29        Ok(s) => s,
30        Err(e) => {
31            return Err(match e {
32                crate::error::EncodeError::InvalidUtf8 => {
33                    ContextualEncodeError::InvalidUtf8 { offset: 0 }
34                }
35                crate::error::EncodeError::PayloadTooLarge { max, actual } => {
36                    ContextualEncodeError::PayloadTooLarge {
37                        context,
38                        size: actual,
39                        max,
40                    }
41                }
42                crate::error::EncodeError::LayeredOutputTooLarge { max, actual } => {
43                    ContextualEncodeError::PayloadTooLarge {
44                        context,
45                        size: actual,
46                        max,
47                    }
48                }
49                crate::error::EncodeError::InvalidContext {
50                    strategy: s,
51                    context: _,
52                } => ContextualEncodeError::ContextIncompatible {
53                    strategy: s.into(),
54                    context,
55                    reason: "strategy invalid for context".into(),
56                },
57                crate::error::EncodeError::InvalidConfig(msg) => {
58                    ContextualEncodeError::ContextIncompatible {
59                        strategy: "config".into(),
60                        context,
61                        reason: msg,
62                    }
63                }
64            });
65        }
66    };
67
68    escape_for_context(&base, context)
69}
70
71pub fn escape_for_context(
72    input: &str,
73    context: InjectionContext,
74) -> Result<String, ContextualEncodeError> {
75    let escaped = match context {
76        InjectionContext::JsonString => {
77            let mut s = String::with_capacity(input.len() + 10);
78            for c in input.chars() {
79                match c {
80                    '\\' => s.push_str("\\\\"),
81                    '"' => s.push_str("\\\""),
82                    '\n' => s.push_str("\\n"),
83                    '\r' => s.push_str("\\r"),
84                    '\t' => s.push_str("\\t"),
85                    '\x00'..='\x1f' => s.push_str(&format!("\\u{:04x}", c as u32)),
86                    // U+2028 LINE SEPARATOR and U+2029 PARAGRAPH SEPARATOR
87                    // are valid in JSON strings per RFC 8259 but are line
88                    // terminators in legacy ECMAScript / JSONP / eval
89                    // contexts. Pre-fix a payload-controlled value with
90                    // U+2028 inlined into <script>JSON</script> would
91                    // close the string literal and inject script. Escape
92                    // both for defence-in-depth even when shipping pure
93                    // JSON over the wire.
94                    '\u{2028}' => s.push_str("\\u2028"),
95                    '\u{2029}' => s.push_str("\\u2029"),
96                    _ => s.push(c),
97                }
98            }
99            s
100        }
101        InjectionContext::JsonNumber => {
102            if input.chars().any(|c| {
103                !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+'
104            }) {
105                return Err(ContextualEncodeError::ContextIncompatible {
106                    strategy: "escape".into(),
107                    context,
108                    reason: "not a valid JSON number".into(),
109                });
110            }
111            input.to_string()
112        }
113        InjectionContext::XmlAttribute => {
114            if input.contains('\x00') {
115                return Err(ContextualEncodeError::ContextIncompatible {
116                    strategy: "escape".into(),
117                    context,
118                    reason: "null byte in xml attribute".into(),
119                });
120            }
121            // XML allows single-quoted attributes; pre-fix only escaped
122            // `&"<>` and a payload with `'` would break out of an
123            // `<elem attr='...'>` form. Add &apos; escape.
124            input
125                .replace('&', "&amp;")
126                .replace('"', "&quot;")
127                .replace('\'', "&apos;")
128                .replace('<', "&lt;")
129                .replace('>', "&gt;")
130        }
131        InjectionContext::XmlCdata => {
132            if input.contains("]]>") {
133                return Err(ContextualEncodeError::ContextIncompatible {
134                    strategy: "escape".into(),
135                    context,
136                    reason: "CDATA cannot contain ]]>".into(),
137                });
138            }
139            input.to_string()
140        }
141        InjectionContext::XmlText => input
142            .replace('&', "&amp;")
143            .replace('<', "&lt;")
144            .replace('>', "&gt;"),
145        InjectionContext::HtmlAttribute => input
146            .replace('&', "&amp;")
147            .replace('"', "&quot;")
148            .replace('\'', "&#x27;")
149            .replace('<', "&lt;"),
150        InjectionContext::HtmlText => input.replace('&', "&amp;").replace('<', "&lt;"),
151        InjectionContext::UrlQuery => urlencoding::encode(input).to_string(),
152        InjectionContext::UrlPath => urlencoding::encode(input).to_string().replace("%2F", "/"),
153        InjectionContext::UrlFragment => urlencoding::encode(input).to_string(),
154        InjectionContext::HeaderValue => {
155            if input.contains('\r') || input.contains('\n') {
156                return Err(ContextualEncodeError::ContextIncompatible {
157                    strategy: "escape".into(),
158                    context,
159                    reason: "CR/LF in header value".into(),
160                });
161            }
162            if input.contains('\x00') {
163                return Err(ContextualEncodeError::ContextIncompatible {
164                    strategy: "escape".into(),
165                    context,
166                    reason: "null byte in header value".into(),
167                });
168            }
169            input.to_string()
170        }
171        InjectionContext::CookieValue => input
172            // RFC 6265 §4.1.1 cookie-octet excludes space, ",", '"', `\\`
173            // in addition to ; = CTLs. Pre-fix the missing chars caused
174            // Chrome / Firefox / curl to truncate the cookie at the
175            // offending byte — making bypass probes silently lie about
176            // the value that actually reached the server.
177            .replace(';', "%3B")
178            .replace('=', "%3D")
179            .replace(' ', "%20")
180            .replace(',', "%2C")
181            .replace('"', "%22")
182            .replace('\\', "%5C")
183            .replace('\x00', "%00")
184            .replace('\r', "%0D")
185            .replace('\n', "%0A"),
186        InjectionContext::MultipartField => {
187            if input.contains('\r') || input.contains('\n') {
188                return Err(ContextualEncodeError::ContextIncompatible {
189                    strategy: "escape".into(),
190                    context,
191                    reason: "CR/LF would break multipart structure".into(),
192                });
193            }
194            input.to_string()
195        }
196        InjectionContext::MultipartFileName => {
197            if input.contains('"') {
198                return Err(ContextualEncodeError::ContextIncompatible {
199                    strategy: "escape".into(),
200                    context,
201                    reason: "quote in filename".into(),
202                });
203            }
204            if input.contains('\r') || input.contains('\n') {
205                return Err(ContextualEncodeError::ContextIncompatible {
206                    strategy: "escape".into(),
207                    context,
208                    reason: "CR/LF in filename".into(),
209                });
210            }
211            input.to_string()
212        }
213        InjectionContext::PlainBody => input.to_string(),
214        _ => input.to_string(),
215    };
216    validate_in_context(&escaped, context)?;
217    Ok(escaped)
218}
219
220pub fn validate_in_context(
221    payload: &str,
222    context: InjectionContext,
223) -> Result<(), ContextualEncodeError> {
224    match context {
225        InjectionContext::JsonString => {
226            let mut chars = payload.chars().peekable();
227            while let Some(c) = chars.next() {
228                if c == '"' {
229                    return Err(ContextualEncodeError::ContextIncompatible {
230                        strategy: "validate".into(),
231                        context,
232                        reason: "unescaped double quote in JSON string".into(),
233                    });
234                }
235                if c == '\\' {
236                    let escaped = chars.next();
237                    match escaped {
238                        Some('\\' | '"' | 'n' | 'r' | 't' | 'b' | 'f' | '/') => {}
239                        Some('u') => {
240                            // Validate exactly 4 hex digits after \u
241                            for _ in 0..4 {
242                                match chars.next() {
243                                    Some(c) if c.is_ascii_hexdigit() => {}
244                                    _ => {
245                                        return Err(ContextualEncodeError::ContextIncompatible {
246                                            strategy: "validate".into(),
247                                            context,
248                                            reason: "invalid Unicode escape in JSON string".into(),
249                                        });
250                                    }
251                                }
252                            }
253                        }
254                        Some(other) => {
255                            return Err(ContextualEncodeError::ContextIncompatible {
256                                strategy: "validate".into(),
257                                context,
258                                reason: format!("invalid JSON escape sequence: \\{other}"),
259                            });
260                        }
261                        None => {
262                            return Err(ContextualEncodeError::ContextIncompatible {
263                                strategy: "validate".into(),
264                                context,
265                                reason: "trailing backslash in JSON string".into(),
266                            });
267                        }
268                    }
269                }
270            }
271        }
272        InjectionContext::XmlAttribute => {
273            let mut chars = payload.chars();
274            while let Some(c) = chars.next() {
275                if c == '"' {
276                    return Err(ContextualEncodeError::ContextIncompatible {
277                        strategy: "validate".into(),
278                        context,
279                        reason: "unescaped double quote in XML attribute".into(),
280                    });
281                }
282                if c == '&' {
283                    // Allow known entity references; anything else starting with & is suspicious
284                    let remainder: String = chars.by_ref().take(6).collect();
285                    if !remainder.starts_with("quot;")
286                        && !remainder.starts_with("amp;")
287                        && !remainder.starts_with("lt;")
288                        && !remainder.starts_with("gt;")
289                    {
290                        // Not a known entity — could be an unescaped &
291                        // (We keep scanning rather than erroring, since & alone
292                        // is technically valid XML text if followed by whitespace.)
293                    }
294                }
295            }
296        }
297        // Contexts below have no validation rules yet. Adding an explicit
298        // arm for each ensures the compiler warns us when a new variant is
299        // added so we can decide whether it needs validation.
300        InjectionContext::PlainBody => {
301            // Plain body accepts any byte sequence; nothing to validate.
302        }
303        InjectionContext::XmlCdata
304            if payload.contains("]]>") => {
305                return Err(ContextualEncodeError::ContextIncompatible {
306                    strategy: "validate".into(),
307                    context,
308                    reason: "CDATA payload contains `]]>` (unterminated section)".into(),
309                });
310            }
311        InjectionContext::XmlText => {
312            if payload.contains('<') {
313                return Err(ContextualEncodeError::ContextIncompatible {
314                    strategy: "validate".into(),
315                    context,
316                    reason: "XML text payload contains unescaped `<`".into(),
317                });
318            }
319            reject_unescaped_ampersand(payload, context)?;
320        }
321        InjectionContext::HtmlAttribute => {
322            if payload.contains('<') {
323                return Err(ContextualEncodeError::ContextIncompatible {
324                    strategy: "validate".into(),
325                    context,
326                    reason: "HTML attribute contains unescaped `<` — would close the attribute"
327                        .into(),
328                });
329            }
330            if payload.contains('"') {
331                return Err(ContextualEncodeError::ContextIncompatible {
332                    strategy: "validate".into(),
333                    context,
334                    reason: "HTML attribute contains unescaped `\"` — attribute breakout".into(),
335                });
336            }
337            if payload.contains('\'') {
338                return Err(ContextualEncodeError::ContextIncompatible {
339                    strategy: "validate".into(),
340                    context,
341                    reason: "HTML attribute contains unescaped `'` — single-quoted attr breakout"
342                        .into(),
343                });
344            }
345            reject_unescaped_ampersand(payload, context)?;
346        }
347        InjectionContext::HtmlText => {
348            if payload.contains('<') {
349                return Err(ContextualEncodeError::ContextIncompatible {
350                    strategy: "validate".into(),
351                    context,
352                    reason: "HTML text contains unescaped `<` — would start a tag".into(),
353                });
354            }
355            reject_unescaped_ampersand(payload, context)?;
356        }
357        InjectionContext::UrlQuery | InjectionContext::UrlPath | InjectionContext::UrlFragment => {
358            // URL components are validated by percent-encoding step later;
359            // raw payload can contain any bytes here.
360        }
361        InjectionContext::HeaderValue => {
362            // Header values are validated by the header obfuscation layer;
363            // CRLF injection is guarded at the transport level.
364        }
365        InjectionContext::CookieValue => {
366            // Cookie values accept most printable ASCII; validation is
367            // handled by the cookie encoding layer.
368        }
369        InjectionContext::MultipartField | InjectionContext::MultipartFileName => {
370            // Multipart boundaries are managed by the form encoder;
371            // individual field values have no additional constraints.
372        }
373        // InjectionContext is #[non_exhaustive]; future variants default to
374        // no validation until explicit rules are added.
375        _ => {}
376    }
377    Ok(())
378}
379
380/// Returns Err if `payload` contains an `&` that is NOT the start of a
381/// well-formed entity reference (`&name;`, `&#nnn;`, or `&#xHHH;`).
382///
383/// This is the cheap cousin of an HTML5 entity validator — it doesn't
384/// know which named entities are real (`&copy;` vs `&xyz;`), but it
385/// does enforce the lexical shape so a stray `&` cannot ride through
386/// `validate_in_context` for HTML/XML contexts.
387fn reject_unescaped_ampersand(
388    payload: &str,
389    context: InjectionContext,
390) -> Result<(), ContextualEncodeError> {
391    let bytes = payload.as_bytes();
392    let mut i = 0;
393    while i < bytes.len() {
394        if bytes[i] != b'&' {
395            i += 1;
396            continue;
397        }
398        // Walk forward to find the terminating `;` within a bounded
399        // window — real entities are short (max ~12 chars including
400        // the `;`). If we don't find one, the `&` is unescaped.
401        let mut j = i + 1;
402        let max = (i + 12).min(bytes.len());
403        let mut saw_semicolon = false;
404        let mut valid_shape = true;
405        let first = bytes.get(j).copied();
406        if first == Some(b'#') {
407            j += 1;
408            let hex = bytes.get(j).copied() == Some(b'x') || bytes.get(j).copied() == Some(b'X');
409            if hex {
410                j += 1;
411            }
412            let mut digit_count = 0;
413            while j < max {
414                let b = bytes[j];
415                if b == b';' {
416                    saw_semicolon = true;
417                    j += 1;
418                    break;
419                }
420                let ok = if hex { b.is_ascii_hexdigit() } else { b.is_ascii_digit() };
421                if !ok {
422                    valid_shape = false;
423                    break;
424                }
425                digit_count += 1;
426                j += 1;
427            }
428            if digit_count == 0 {
429                valid_shape = false;
430            }
431        } else if let Some(b) = first {
432            if b.is_ascii_alphabetic() {
433                while j < max {
434                    let b = bytes[j];
435                    if b == b';' {
436                        saw_semicolon = true;
437                        j += 1;
438                        break;
439                    }
440                    if !b.is_ascii_alphanumeric() {
441                        valid_shape = false;
442                        break;
443                    }
444                    j += 1;
445                }
446            } else {
447                valid_shape = false;
448            }
449        } else {
450            valid_shape = false;
451        }
452        if !valid_shape || !saw_semicolon {
453            return Err(ContextualEncodeError::ContextIncompatible {
454                strategy: "validate".into(),
455                context,
456                reason: format!("unescaped `&` at byte {i} (no entity reference follows)"),
457            });
458        }
459        i = j;
460    }
461    Ok(())
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::Strategy;

    /// Runs `encode_in_context` with `Strategy::CaseAlternation`, the
    /// strategy most tests in this module use.
    fn encode_alt(
        payload: &[u8],
        context: InjectionContext,
    ) -> Result<String, ContextualEncodeError> {
        encode_in_context(payload, Strategy::CaseAlternation, context)
    }

    // --- error mapping -------------------------------------------------

    #[test]
    fn invalid_utf8_payload_maps_to_invalid_utf8_error() {
        // A lone 0x80 continuation byte is not valid UTF-8, so the encoder
        // is expected to fail and the failure should surface as the
        // contextual invalid-UTF-8 error.
        let err = encode_alt(b"\x80", InjectionContext::PlainBody)
            .expect_err("invalid UTF-8 payload must be rejected");
        let message = err.to_string();
        assert!(
            message.contains("invalid") || message.contains("UTF-8"),
            "error should mention invalid UTF-8, got: {err}"
        );
    }

    #[test]
    fn max_size_enforced() {
        // One byte over the 8 MiB fallback cap.
        let big = vec![b'a'; 8 * 1024 * 1024 + 1];
        let err = encode_alt(&big, InjectionContext::PlainBody).unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    // --- JSON string validation ---------------------------------------

    #[test]
    fn json_string_validates_unescaped_quote() {
        let err = validate_in_context("hello\"world", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("unescaped double quote"));
    }

    #[test]
    fn json_string_validates_valid_escapes() {
        for escaped in [
            "hello\\nworld",
            "hello\\tworld",
            "hello\\\\world",
            "hello\\\"world",
        ] {
            assert!(
                validate_in_context(escaped, InjectionContext::JsonString).is_ok(),
                "{escaped:?} should be accepted"
            );
        }
    }

    #[test]
    fn json_string_validates_unicode_escape() {
        // Well-formed: \u00e4
        assert!(validate_in_context("\\u00e4", InjectionContext::JsonString).is_ok());
        // Non-hex digit (\u00g4) and truncated (\u00) escapes are rejected.
        for malformed in ["\\u00g4", "\\u00"] {
            let err = validate_in_context(malformed, InjectionContext::JsonString).unwrap_err();
            assert!(err.to_string().contains("invalid Unicode escape"));
        }
    }

    #[test]
    fn json_string_validates_invalid_escape() {
        let err = validate_in_context("\\x", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid JSON escape"));
    }

    #[test]
    fn json_string_validates_trailing_backslash() {
        let err = validate_in_context("hello\\", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("trailing backslash"));
    }

    // --- XML attribute validation -------------------------------------

    #[test]
    fn xml_attribute_validates_unescaped_quote() {
        let err = validate_in_context("hello\"world", InjectionContext::XmlAttribute).unwrap_err();
        assert!(err.to_string().contains("unescaped double quote"));
    }

    #[test]
    fn xml_attribute_allows_escaped_quote() {
        // A well-formed `&quot;` reference must not be reported as a quote.
        assert!(validate_in_context("hello&quot;world", InjectionContext::XmlAttribute).is_ok());
    }

    // --- full pipeline: rejections ------------------------------------

    #[test]
    fn header_value_validates_crlf() {
        let err = encode_alt(b"hello\r\nworld", InjectionContext::HeaderValue).unwrap_err();
        assert!(err.to_string().contains("CR/LF"));
    }

    #[test]
    fn multipart_field_validates_crlf() {
        let err = encode_alt(b"hello\r\nworld", InjectionContext::MultipartField).unwrap_err();
        assert!(err.to_string().contains("CR/LF"));
    }

    #[test]
    fn xml_cdata_rejects_termination_sequence() {
        let err = encode_alt(b"hello]]>world", InjectionContext::XmlCdata).unwrap_err();
        assert!(err.to_string().contains("CDATA"));
    }

    #[test]
    fn multipart_filename_rejects_quote() {
        let err = encode_alt(b"file\"name.txt", InjectionContext::MultipartFileName).unwrap_err();
        assert!(err.to_string().contains("quote"));
    }

    #[test]
    fn json_number_rejects_non_numeric() {
        let err = encode_alt(b"abc", InjectionContext::JsonNumber).unwrap_err();
        assert!(err.to_string().contains("not a valid JSON number"));
    }

    // --- full pipeline: escaping --------------------------------------

    #[test]
    fn cookie_value_escapes_crlf() {
        let out = encode_alt(b"hello\r\nworld", InjectionContext::CookieValue).unwrap();
        assert!(out.contains("%0D") && out.contains("%0A"));
    }

    #[test]
    fn html_attribute_escapes_ampersand() {
        let out = encode_alt(b"a&b", InjectionContext::HtmlAttribute).unwrap();
        assert!(out.contains("&amp;"));
    }

    #[test]
    fn url_query_escapes_space() {
        let out = encode_alt(b"hello world", InjectionContext::UrlQuery).unwrap();
        assert!(!out.contains(' '));
    }

    #[test]
    fn url_path_preserves_slash() {
        let out = encode_alt(b"/api/v1", InjectionContext::UrlPath).unwrap();
        assert!(out.contains('/'));
    }

    #[test]
    fn plain_body_no_structural_escaping() {
        // PlainBody adds no structural escaping, but the strategy still
        // mutates the payload.
        let out = encode_alt(b"<script>", InjectionContext::PlainBody).unwrap();
        assert_eq!(out, "<ScRiPt>");
    }

    // --- edge cases ----------------------------------------------------

    #[test]
    fn empty_payload_valid_in_all_contexts() {
        for ctx in [
            InjectionContext::PlainBody,
            InjectionContext::JsonString,
            InjectionContext::XmlAttribute,
            InjectionContext::HeaderValue,
            InjectionContext::CookieValue,
        ] {
            assert!(
                encode_in_context(b"", Strategy::UrlEncode, ctx).is_ok(),
                "empty payload should be valid in {ctx:?}"
            );
        }
    }
}