1use crate::encoding::Strategy;
2use wafrift_types::injection_context::{ContextualEncodeError, InjectionContext};
3
4pub fn encode_in_context(
5 payload: &[u8],
6 strategy: Strategy,
7 context: InjectionContext,
8) -> Result<String, ContextualEncodeError> {
9 let max_size = match context {
10 InjectionContext::JsonString => 4 * 1024 * 1024,
11 InjectionContext::JsonNumber => 1024,
12 InjectionContext::XmlAttribute => 1024 * 1024,
13 InjectionContext::XmlCdata => 8 * 1024 * 1024,
14 InjectionContext::HeaderValue => 8 * 1024,
15 InjectionContext::CookieValue => 4 * 1024,
16 InjectionContext::MultipartFileName => 256,
17 _ => 8 * 1024 * 1024,
18 };
19
20 if payload.len() > max_size {
21 return Err(ContextualEncodeError::PayloadTooLarge {
22 context,
23 size: payload.len(),
24 max: max_size,
25 });
26 }
27
28 let base = match crate::encoding::encode(payload, strategy) {
29 Ok(s) => s,
30 Err(e) => {
31 return Err(match e {
32 crate::error::EncodeError::InvalidUtf8 => ContextualEncodeError::InvalidUtf8 { offset: 0 },
33 crate::error::EncodeError::PayloadTooLarge { max, actual } => {
34 ContextualEncodeError::PayloadTooLarge {
35 context,
36 size: actual,
37 max,
38 }
39 }
40 crate::error::EncodeError::LayeredOutputTooLarge { max, actual } => {
41 ContextualEncodeError::PayloadTooLarge {
42 context,
43 size: actual,
44 max,
45 }
46 }
47 crate::error::EncodeError::InvalidContext { strategy: s, context: _ } => {
48 ContextualEncodeError::ContextIncompatible {
49 strategy: s.into(),
50 context,
51 reason: "strategy invalid for context".into(),
52 }
53 }
54 crate::error::EncodeError::InvalidConfig(msg) => {
55 ContextualEncodeError::ContextIncompatible {
56 strategy: "config".into(),
57 context,
58 reason: msg,
59 }
60 }
61 });
62 }
63 };
64
65 escape_for_context(&base, context)
66}
67
68pub fn escape_for_context(
69 input: &str,
70 context: InjectionContext,
71) -> Result<String, ContextualEncodeError> {
72 let escaped = match context {
73 InjectionContext::JsonString => {
74 let mut s = String::with_capacity(input.len() + 10);
75 for c in input.chars() {
76 match c {
77 '\\' => s.push_str("\\\\"),
78 '"' => s.push_str("\\\""),
79 '\n' => s.push_str("\\n"),
80 '\r' => s.push_str("\\r"),
81 '\t' => s.push_str("\\t"),
82 '\x00'..='\x1f' => s.push_str(&format!("\\u{:04x}", c as u32)),
83 _ => s.push(c),
84 }
85 }
86 s
87 }
88 InjectionContext::JsonNumber => {
89 if input.chars().any(|c| !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+') {
90 return Err(ContextualEncodeError::ContextIncompatible {
91 strategy: "escape".into(),
92 context,
93 reason: "not a valid JSON number".into(),
94 });
95 }
96 input.to_string()
97 }
98 InjectionContext::XmlAttribute => {
99 if input.contains('\x00') {
100 return Err(ContextualEncodeError::ContextIncompatible {
101 strategy: "escape".into(),
102 context,
103 reason: "null byte in xml attribute".into(),
104 });
105 }
106 input.replace('&', "&")
107 .replace('"', """)
108 .replace('<', "<")
109 .replace('>', ">")
110 }
111 InjectionContext::XmlCdata => {
112 if input.contains("]]>") {
113 return Err(ContextualEncodeError::ContextIncompatible {
114 strategy: "escape".into(),
115 context,
116 reason: "CDATA cannot contain ]]>".into(),
117 });
118 }
119 input.to_string()
120 }
121 InjectionContext::XmlText => {
122 input.replace('&', "&")
123 .replace('<', "<")
124 .replace('>', ">")
125 }
126 InjectionContext::HtmlAttribute => {
127 input.replace('&', "&")
128 .replace('"', """)
129 .replace('\'', "'")
130 .replace('<', "<")
131 }
132 InjectionContext::HtmlText => {
133 input.replace('&', "&")
134 .replace('<', "<")
135 }
136 InjectionContext::UrlQuery => {
137 urlencoding::encode(input).to_string()
138 }
139 InjectionContext::UrlPath => {
140 urlencoding::encode(input).to_string().replace("%2F", "/")
141 }
142 InjectionContext::UrlFragment => {
143 urlencoding::encode(input).to_string()
144 }
145 InjectionContext::HeaderValue => {
146 if input.contains('\r') || input.contains('\n') {
147 return Err(ContextualEncodeError::ContextIncompatible {
148 strategy: "escape".into(),
149 context,
150 reason: "CR/LF in header value".into(),
151 });
152 }
153 if input.contains('\x00') {
154 return Err(ContextualEncodeError::ContextIncompatible {
155 strategy: "escape".into(),
156 context,
157 reason: "null byte in header value".into(),
158 });
159 }
160 input.to_string()
161 }
162 InjectionContext::CookieValue => {
163 input.replace(';', "%3B")
164 .replace('=', "%3D")
165 .replace('\x00', "%00")
166 .replace('\r', "%0D")
167 .replace('\n', "%0A")
168 }
169 InjectionContext::MultipartField => {
170 if input.contains('\r') || input.contains('\n') {
171 return Err(ContextualEncodeError::ContextIncompatible {
172 strategy: "escape".into(),
173 context,
174 reason: "CR/LF would break multipart structure".into(),
175 });
176 }
177 input.to_string()
178 }
179 InjectionContext::MultipartFileName => {
180 if input.contains('"') {
181 return Err(ContextualEncodeError::ContextIncompatible {
182 strategy: "escape".into(),
183 context,
184 reason: "quote in filename".into(),
185 });
186 }
187 if input.contains('\r') || input.contains('\n') {
188 return Err(ContextualEncodeError::ContextIncompatible {
189 strategy: "escape".into(),
190 context,
191 reason: "CR/LF in filename".into(),
192 });
193 }
194 input.to_string()
195 }
196 InjectionContext::PlainBody => {
197 input.to_string()
198 }
199 _ => input.to_string(),
200 };
201 validate_in_context(&escaped, context)?;
202 Ok(escaped)
203}
204
205pub fn validate_in_context(
206 payload: &str,
207 context: InjectionContext,
208) -> Result<(), ContextualEncodeError> {
209 match context {
210 InjectionContext::JsonString => {
211 let mut chars = payload.chars().peekable();
212 while let Some(c) = chars.next() {
213 if c == '"' {
214 return Err(ContextualEncodeError::ContextIncompatible {
215 strategy: "validate".into(),
216 context,
217 reason: "unescaped double quote in JSON string".into(),
218 });
219 }
220 if c == '\\' {
221 let escaped = chars.next();
222 match escaped {
223 Some('\\') | Some('"') | Some('n') | Some('r') | Some('t')
224 | Some('b') | Some('f') | Some('/') => {}
225 Some('u') => {
226 for _ in 0..4 {
228 match chars.next() {
229 Some(c) if c.is_ascii_hexdigit() => {}
230 _ => {
231 return Err(ContextualEncodeError::ContextIncompatible {
232 strategy: "validate".into(),
233 context,
234 reason: "invalid Unicode escape in JSON string".into(),
235 });
236 }
237 }
238 }
239 }
240 Some(other) => {
241 return Err(ContextualEncodeError::ContextIncompatible {
242 strategy: "validate".into(),
243 context,
244 reason: format!("invalid JSON escape sequence: \\{other}"),
245 });
246 }
247 None => {
248 return Err(ContextualEncodeError::ContextIncompatible {
249 strategy: "validate".into(),
250 context,
251 reason: "trailing backslash in JSON string".into(),
252 });
253 }
254 }
255 }
256 }
257 }
258 InjectionContext::XmlAttribute => {
259 let mut chars = payload.chars();
260 while let Some(c) = chars.next() {
261 if c == '"' {
262 return Err(ContextualEncodeError::ContextIncompatible {
263 strategy: "validate".into(),
264 context,
265 reason: "unescaped double quote in XML attribute".into(),
266 });
267 }
268 if c == '&' {
269 let remainder: String = chars.by_ref().take(6).collect();
271 if !remainder.starts_with("quot;")
272 && !remainder.starts_with("amp;")
273 && !remainder.starts_with("lt;")
274 && !remainder.starts_with("gt;")
275 {
276 }
280 }
281 }
282 }
283 _ => {}
284 }
285 Ok(())
286}