Skip to main content

wafrift_encoding/
contextual.rs

1use crate::encoding::Strategy;
2use wafrift_types::injection_context::{ContextualEncodeError, InjectionContext};
3
4pub fn encode_in_context(
5    payload: &[u8],
6    strategy: Strategy,
7    context: InjectionContext,
8) -> Result<String, ContextualEncodeError> {
9    let max_size = match context {
10        InjectionContext::JsonString => 4 * 1024 * 1024,
11        InjectionContext::JsonNumber => 1024,
12        InjectionContext::XmlAttribute => 1024 * 1024,
13        InjectionContext::XmlCdata => 8 * 1024 * 1024,
14        InjectionContext::HeaderValue => 8 * 1024,
15        InjectionContext::CookieValue => 4 * 1024,
16        InjectionContext::MultipartFileName => 256,
17        _ => 8 * 1024 * 1024,
18    };
19
20    if payload.len() > max_size {
21        return Err(ContextualEncodeError::PayloadTooLarge {
22            context,
23            size: payload.len(),
24            max: max_size,
25        });
26    }
27
28    let base = match crate::encoding::encode(payload, strategy) {
29        Ok(s) => s,
30        Err(e) => {
31            return Err(match e {
32                crate::error::EncodeError::InvalidUtf8 => {
33                    ContextualEncodeError::InvalidUtf8 { offset: 0 }
34                }
35                crate::error::EncodeError::PayloadTooLarge { max, actual } => {
36                    ContextualEncodeError::PayloadTooLarge {
37                        context,
38                        size: actual,
39                        max,
40                    }
41                }
42                crate::error::EncodeError::LayeredOutputTooLarge { max, actual } => {
43                    ContextualEncodeError::PayloadTooLarge {
44                        context,
45                        size: actual,
46                        max,
47                    }
48                }
49                crate::error::EncodeError::InvalidContext {
50                    strategy: s,
51                    context: _,
52                } => ContextualEncodeError::ContextIncompatible {
53                    strategy: s.into(),
54                    context,
55                    reason: "strategy invalid for context".into(),
56                },
57                crate::error::EncodeError::InvalidConfig(msg) => {
58                    ContextualEncodeError::ContextIncompatible {
59                        strategy: "config".into(),
60                        context,
61                        reason: msg,
62                    }
63                }
64            });
65        }
66    };
67
68    escape_for_context(&base, context)
69}
70
71pub fn escape_for_context(
72    input: &str,
73    context: InjectionContext,
74) -> Result<String, ContextualEncodeError> {
75    let escaped = match context {
76        InjectionContext::JsonString => {
77            let mut s = String::with_capacity(input.len() + 10);
78            for c in input.chars() {
79                match c {
80                    '\\' => s.push_str("\\\\"),
81                    '"' => s.push_str("\\\""),
82                    '\n' => s.push_str("\\n"),
83                    '\r' => s.push_str("\\r"),
84                    '\t' => s.push_str("\\t"),
85                    '\x00'..='\x1f' => s.push_str(&format!("\\u{:04x}", c as u32)),
86                    // U+2028 LINE SEPARATOR and U+2029 PARAGRAPH SEPARATOR
87                    // are valid in JSON strings per RFC 8259 but are line
88                    // terminators in legacy ECMAScript / JSONP / eval
89                    // contexts. Pre-fix a payload-controlled value with
90                    // U+2028 inlined into <script>JSON</script> would
91                    // close the string literal and inject script. Escape
92                    // both for defence-in-depth even when shipping pure
93                    // JSON over the wire.
94                    '\u{2028}' => s.push_str("\\u2028"),
95                    '\u{2029}' => s.push_str("\\u2029"),
96                    _ => s.push(c),
97                }
98            }
99            s
100        }
101        InjectionContext::JsonNumber => {
102            if input.chars().any(|c| {
103                !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+'
104            }) {
105                return Err(ContextualEncodeError::ContextIncompatible {
106                    strategy: "escape".into(),
107                    context,
108                    reason: "not a valid JSON number".into(),
109                });
110            }
111            input.to_string()
112        }
113        InjectionContext::XmlAttribute => {
114            if input.contains('\x00') {
115                return Err(ContextualEncodeError::ContextIncompatible {
116                    strategy: "escape".into(),
117                    context,
118                    reason: "null byte in xml attribute".into(),
119                });
120            }
121            // XML allows single-quoted attributes; pre-fix only escaped
122            // `&"<>` and a payload with `'` would break out of an
123            // `<elem attr='...'>` form. Add &apos; escape.
124            input
125                .replace('&', "&amp;")
126                .replace('"', "&quot;")
127                .replace('\'', "&apos;")
128                .replace('<', "&lt;")
129                .replace('>', "&gt;")
130        }
131        InjectionContext::XmlCdata => {
132            if input.contains("]]>") {
133                return Err(ContextualEncodeError::ContextIncompatible {
134                    strategy: "escape".into(),
135                    context,
136                    reason: "CDATA cannot contain ]]>".into(),
137                });
138            }
139            input.to_string()
140        }
141        InjectionContext::XmlText => input
142            .replace('&', "&amp;")
143            .replace('<', "&lt;")
144            .replace('>', "&gt;"),
145        InjectionContext::HtmlAttribute => input
146            .replace('&', "&amp;")
147            .replace('"', "&quot;")
148            .replace('\'', "&#x27;")
149            .replace('<', "&lt;"),
150        InjectionContext::HtmlText => input.replace('&', "&amp;").replace('<', "&lt;"),
151        InjectionContext::UrlQuery => urlencoding::encode(input).to_string(),
152        InjectionContext::UrlPath => urlencoding::encode(input).to_string().replace("%2F", "/"),
153        InjectionContext::UrlFragment => urlencoding::encode(input).to_string(),
154        InjectionContext::HeaderValue => {
155            if input.contains('\r') || input.contains('\n') {
156                return Err(ContextualEncodeError::ContextIncompatible {
157                    strategy: "escape".into(),
158                    context,
159                    reason: "CR/LF in header value".into(),
160                });
161            }
162            if input.contains('\x00') {
163                return Err(ContextualEncodeError::ContextIncompatible {
164                    strategy: "escape".into(),
165                    context,
166                    reason: "null byte in header value".into(),
167                });
168            }
169            input.to_string()
170        }
171        InjectionContext::CookieValue => input
172            // RFC 6265 §4.1.1 cookie-octet excludes space, ",", '"', `\\`
173            // in addition to ; = CTLs. Pre-fix the missing chars caused
174            // Chrome / Firefox / curl to truncate the cookie at the
175            // offending byte — making bypass probes silently lie about
176            // the value that actually reached the server.
177            .replace(';', "%3B")
178            .replace('=', "%3D")
179            .replace(' ', "%20")
180            .replace(',', "%2C")
181            .replace('"', "%22")
182            .replace('\\', "%5C")
183            .replace('\x00', "%00")
184            .replace('\r', "%0D")
185            .replace('\n', "%0A"),
186        InjectionContext::MultipartField => {
187            if input.contains('\r') || input.contains('\n') {
188                return Err(ContextualEncodeError::ContextIncompatible {
189                    strategy: "escape".into(),
190                    context,
191                    reason: "CR/LF would break multipart structure".into(),
192                });
193            }
194            input.to_string()
195        }
196        InjectionContext::MultipartFileName => {
197            if input.contains('"') {
198                return Err(ContextualEncodeError::ContextIncompatible {
199                    strategy: "escape".into(),
200                    context,
201                    reason: "quote in filename".into(),
202                });
203            }
204            if input.contains('\r') || input.contains('\n') {
205                return Err(ContextualEncodeError::ContextIncompatible {
206                    strategy: "escape".into(),
207                    context,
208                    reason: "CR/LF in filename".into(),
209                });
210            }
211            input.to_string()
212        }
213        InjectionContext::PlainBody => input.to_string(),
214        _ => input.to_string(),
215    };
216    validate_in_context(&escaped, context)?;
217    Ok(escaped)
218}
219
220pub fn validate_in_context(
221    payload: &str,
222    context: InjectionContext,
223) -> Result<(), ContextualEncodeError> {
224    match context {
225        InjectionContext::JsonString => {
226            let mut chars = payload.chars().peekable();
227            while let Some(c) = chars.next() {
228                if c == '"' {
229                    return Err(ContextualEncodeError::ContextIncompatible {
230                        strategy: "validate".into(),
231                        context,
232                        reason: "unescaped double quote in JSON string".into(),
233                    });
234                }
235                if c == '\\' {
236                    let escaped = chars.next();
237                    match escaped {
238                        Some('\\') | Some('"') | Some('n') | Some('r') | Some('t') | Some('b')
239                        | Some('f') | Some('/') => {}
240                        Some('u') => {
241                            // Validate exactly 4 hex digits after \u
242                            for _ in 0..4 {
243                                match chars.next() {
244                                    Some(c) if c.is_ascii_hexdigit() => {}
245                                    _ => {
246                                        return Err(ContextualEncodeError::ContextIncompatible {
247                                            strategy: "validate".into(),
248                                            context,
249                                            reason: "invalid Unicode escape in JSON string".into(),
250                                        });
251                                    }
252                                }
253                            }
254                        }
255                        Some(other) => {
256                            return Err(ContextualEncodeError::ContextIncompatible {
257                                strategy: "validate".into(),
258                                context,
259                                reason: format!("invalid JSON escape sequence: \\{other}"),
260                            });
261                        }
262                        None => {
263                            return Err(ContextualEncodeError::ContextIncompatible {
264                                strategy: "validate".into(),
265                                context,
266                                reason: "trailing backslash in JSON string".into(),
267                            });
268                        }
269                    }
270                }
271            }
272        }
273        InjectionContext::XmlAttribute => {
274            let mut chars = payload.chars();
275            while let Some(c) = chars.next() {
276                if c == '"' {
277                    return Err(ContextualEncodeError::ContextIncompatible {
278                        strategy: "validate".into(),
279                        context,
280                        reason: "unescaped double quote in XML attribute".into(),
281                    });
282                }
283                if c == '&' {
284                    // Allow known entity references; anything else starting with & is suspicious
285                    let remainder: String = chars.by_ref().take(6).collect();
286                    if !remainder.starts_with("quot;")
287                        && !remainder.starts_with("amp;")
288                        && !remainder.starts_with("lt;")
289                        && !remainder.starts_with("gt;")
290                    {
291                        // Not a known entity — could be an unescaped &
292                        // (We keep scanning rather than erroring, since & alone
293                        // is technically valid XML text if followed by whitespace.)
294                    }
295                }
296            }
297        }
298        // Contexts below have no validation rules yet. Adding an explicit
299        // arm for each ensures the compiler warns us when a new variant is
300        // added so we can decide whether it needs validation.
301        InjectionContext::PlainBody => {
302            // Plain body accepts any byte sequence; nothing to validate.
303        }
304        InjectionContext::XmlCdata
305            if payload.contains("]]>") => {
306                return Err(ContextualEncodeError::ContextIncompatible {
307                    strategy: "validate".into(),
308                    context,
309                    reason: "CDATA payload contains `]]>` (unterminated section)".into(),
310                });
311            }
312        InjectionContext::XmlText => {
313            if payload.contains('<') {
314                return Err(ContextualEncodeError::ContextIncompatible {
315                    strategy: "validate".into(),
316                    context,
317                    reason: "XML text payload contains unescaped `<`".into(),
318                });
319            }
320            reject_unescaped_ampersand(payload, context)?;
321        }
322        InjectionContext::HtmlAttribute => {
323            if payload.contains('<') {
324                return Err(ContextualEncodeError::ContextIncompatible {
325                    strategy: "validate".into(),
326                    context,
327                    reason: "HTML attribute contains unescaped `<` — would close the attribute"
328                        .into(),
329                });
330            }
331            if payload.contains('"') {
332                return Err(ContextualEncodeError::ContextIncompatible {
333                    strategy: "validate".into(),
334                    context,
335                    reason: "HTML attribute contains unescaped `\"` — attribute breakout".into(),
336                });
337            }
338            if payload.contains('\'') {
339                return Err(ContextualEncodeError::ContextIncompatible {
340                    strategy: "validate".into(),
341                    context,
342                    reason: "HTML attribute contains unescaped `'` — single-quoted attr breakout"
343                        .into(),
344                });
345            }
346            reject_unescaped_ampersand(payload, context)?;
347        }
348        InjectionContext::HtmlText => {
349            if payload.contains('<') {
350                return Err(ContextualEncodeError::ContextIncompatible {
351                    strategy: "validate".into(),
352                    context,
353                    reason: "HTML text contains unescaped `<` — would start a tag".into(),
354                });
355            }
356            reject_unescaped_ampersand(payload, context)?;
357        }
358        InjectionContext::UrlQuery | InjectionContext::UrlPath | InjectionContext::UrlFragment => {
359            // URL components are validated by percent-encoding step later;
360            // raw payload can contain any bytes here.
361        }
362        InjectionContext::HeaderValue => {
363            // Header values are validated by the header obfuscation layer;
364            // CRLF injection is guarded at the transport level.
365        }
366        InjectionContext::CookieValue => {
367            // Cookie values accept most printable ASCII; validation is
368            // handled by the cookie encoding layer.
369        }
370        InjectionContext::MultipartField | InjectionContext::MultipartFileName => {
371            // Multipart boundaries are managed by the form encoder;
372            // individual field values have no additional constraints.
373        }
374        // InjectionContext is #[non_exhaustive]; future variants default to
375        // no validation until explicit rules are added.
376        _ => {}
377    }
378    Ok(())
379}
380
381/// Returns Err if `payload` contains an `&` that is NOT the start of a
382/// well-formed entity reference (`&name;`, `&#nnn;`, or `&#xHHH;`).
383///
384/// This is the cheap cousin of an HTML5 entity validator — it doesn't
385/// know which named entities are real (`&copy;` vs `&xyz;`), but it
386/// does enforce the lexical shape so a stray `&` cannot ride through
387/// `validate_in_context` for HTML/XML contexts.
388fn reject_unescaped_ampersand(
389    payload: &str,
390    context: InjectionContext,
391) -> Result<(), ContextualEncodeError> {
392    let bytes = payload.as_bytes();
393    let mut i = 0;
394    while i < bytes.len() {
395        if bytes[i] != b'&' {
396            i += 1;
397            continue;
398        }
399        // Walk forward to find the terminating `;` within a bounded
400        // window — real entities are short (max ~12 chars including
401        // the `;`). If we don't find one, the `&` is unescaped.
402        let mut j = i + 1;
403        let max = (i + 12).min(bytes.len());
404        let mut saw_semicolon = false;
405        let mut valid_shape = true;
406        let first = bytes.get(j).copied();
407        if first == Some(b'#') {
408            j += 1;
409            let hex = bytes.get(j).copied() == Some(b'x') || bytes.get(j).copied() == Some(b'X');
410            if hex {
411                j += 1;
412            }
413            let mut digit_count = 0;
414            while j < max {
415                let b = bytes[j];
416                if b == b';' {
417                    saw_semicolon = true;
418                    j += 1;
419                    break;
420                }
421                let ok = if hex { b.is_ascii_hexdigit() } else { b.is_ascii_digit() };
422                if !ok {
423                    valid_shape = false;
424                    break;
425                }
426                digit_count += 1;
427                j += 1;
428            }
429            if digit_count == 0 {
430                valid_shape = false;
431            }
432        } else if let Some(b) = first {
433            if !b.is_ascii_alphabetic() {
434                valid_shape = false;
435            } else {
436                while j < max {
437                    let b = bytes[j];
438                    if b == b';' {
439                        saw_semicolon = true;
440                        j += 1;
441                        break;
442                    }
443                    if !b.is_ascii_alphanumeric() {
444                        valid_shape = false;
445                        break;
446                    }
447                    j += 1;
448                }
449            }
450        } else {
451            valid_shape = false;
452        }
453        if !valid_shape || !saw_semicolon {
454            return Err(ContextualEncodeError::ContextIncompatible {
455                strategy: "validate".into(),
456                context,
457                reason: format!("unescaped `&` at byte {i} (no entity reference follows)"),
458            });
459        }
460        i = j;
461    }
462    Ok(())
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::Strategy;

    /// Runs `encode_in_context` with `Strategy::CaseAlternation`, the strategy
    /// most cases below use.
    fn encode_case_alternated(
        payload: &[u8],
        context: InjectionContext,
    ) -> Result<String, ContextualEncodeError> {
        encode_in_context(payload, Strategy::CaseAlternation, context)
    }

    /// Unwraps the error of `result` (panicking on `Ok`) and returns its
    /// `Display` text.
    fn error_text<T: std::fmt::Debug>(result: Result<T, ContextualEncodeError>) -> String {
        result.unwrap_err().to_string()
    }

    #[test]
    fn encode_error_mapping_invalid_utf8() {
        // A lone 0x80 is a continuation byte without a lead byte (invalid
        // UTF-8), so encode() fails and the mapped error must say so.
        let message = error_text(encode_case_alternated(b"\x80", InjectionContext::PlainBody));
        assert!(
            message.contains("invalid") || message.contains("UTF-8"),
            "error should mention invalid UTF-8, got: {}",
            message
        );
    }

    #[test]
    fn json_string_validates_unescaped_quote() {
        let message = error_text(validate_in_context(
            "hello\"world",
            InjectionContext::JsonString,
        ));
        assert!(message.contains("unescaped double quote"));
    }

    #[test]
    fn json_string_validates_valid_escapes() {
        for escaped in ["hello\\nworld", "hello\\tworld", "hello\\\\world", "hello\\\"world"] {
            assert!(
                validate_in_context(escaped, InjectionContext::JsonString).is_ok(),
                "{escaped:?} should validate"
            );
        }
    }

    #[test]
    fn json_string_validates_unicode_escape() {
        assert!(validate_in_context("\\u00e4", InjectionContext::JsonString).is_ok());
        // A non-hex digit, then a sequence with too few digits.
        for bad in ["\\u00g4", "\\u00"] {
            let message = error_text(validate_in_context(bad, InjectionContext::JsonString));
            assert!(message.contains("invalid Unicode escape"));
        }
    }

    #[test]
    fn json_string_validates_invalid_escape() {
        let message = error_text(validate_in_context("\\x", InjectionContext::JsonString));
        assert!(message.contains("invalid JSON escape"));
    }

    #[test]
    fn json_string_validates_trailing_backslash() {
        let message = error_text(validate_in_context("hello\\", InjectionContext::JsonString));
        assert!(message.contains("trailing backslash"));
    }

    #[test]
    fn xml_attribute_validates_unescaped_quote() {
        let message = error_text(validate_in_context(
            "hello\"world",
            InjectionContext::XmlAttribute,
        ));
        assert!(message.contains("unescaped double quote"));
    }

    #[test]
    fn xml_attribute_allows_escaped_quote() {
        // A well-formed entity reference must not be reported as an error.
        assert!(validate_in_context("hello&quot;world", InjectionContext::XmlAttribute).is_ok());
    }

    #[test]
    fn header_value_validates_crlf() {
        let message = error_text(encode_case_alternated(
            b"hello\r\nworld",
            InjectionContext::HeaderValue,
        ));
        assert!(message.contains("CR/LF"));
    }

    #[test]
    fn cookie_value_escapes_crlf() {
        let encoded =
            encode_case_alternated(b"hello\r\nworld", InjectionContext::CookieValue).unwrap();
        assert!(encoded.contains("%0D") && encoded.contains("%0A"));
    }

    #[test]
    fn multipart_field_validates_crlf() {
        let message = error_text(encode_case_alternated(
            b"hello\r\nworld",
            InjectionContext::MultipartField,
        ));
        assert!(message.contains("CR/LF"));
    }

    #[test]
    fn html_attribute_escapes_ampersand() {
        let encoded = encode_case_alternated(b"a&b", InjectionContext::HtmlAttribute).unwrap();
        assert!(encoded.contains("&amp;"));
    }

    #[test]
    fn url_query_escapes_space() {
        let encoded = encode_case_alternated(b"hello world", InjectionContext::UrlQuery).unwrap();
        assert!(!encoded.contains(' '));
    }

    #[test]
    fn url_path_preserves_slash() {
        let encoded = encode_case_alternated(b"/api/v1", InjectionContext::UrlPath).unwrap();
        assert!(encoded.contains('/'));
    }

    #[test]
    fn plain_body_no_structural_escaping() {
        // No structural escaping is added, yet the strategy still mutates.
        let encoded = encode_case_alternated(b"<script>", InjectionContext::PlainBody).unwrap();
        assert_eq!(encoded, "<ScRiPt>");
    }

    #[test]
    fn max_size_enforced() {
        let oversized = vec![b'a'; 8 * 1024 * 1024 + 1];
        let message = error_text(encode_case_alternated(&oversized, InjectionContext::PlainBody));
        assert!(message.contains("too large"));
    }

    #[test]
    fn xml_cdata_rejects_termination_sequence() {
        let message = error_text(encode_case_alternated(
            b"hello]]>world",
            InjectionContext::XmlCdata,
        ));
        assert!(message.contains("CDATA"));
    }

    #[test]
    fn multipart_filename_rejects_quote() {
        let message = error_text(encode_case_alternated(
            b"file\"name.txt",
            InjectionContext::MultipartFileName,
        ));
        assert!(message.contains("quote"));
    }

    #[test]
    fn json_number_rejects_non_numeric() {
        let message = error_text(encode_case_alternated(b"abc", InjectionContext::JsonNumber));
        assert!(message.contains("not a valid JSON number"));
    }

    #[test]
    fn empty_payload_valid_in_all_contexts() {
        let contexts = [
            InjectionContext::PlainBody,
            InjectionContext::JsonString,
            InjectionContext::XmlAttribute,
            InjectionContext::HeaderValue,
            InjectionContext::CookieValue,
        ];
        for ctx in contexts {
            assert!(
                encode_in_context(b"", Strategy::UrlEncode, ctx).is_ok(),
                "empty payload should be valid in {ctx:?}"
            );
        }
    }
}