1use cloakrs_core::{PiiEntity, Result, Scanner};
4use serde::{Deserialize, Serialize};
5use std::io::{BufRead, BufReader, Read, Write};
6
/// Scan outcome for a single quoted string value inside an `INSERT ... VALUES` list.
///
/// In this file an entry is only produced for values in which the scanner
/// reported at least one finding.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SqlValueScanResult {
    /// 1-based count of `INSERT` keywords seen so far when the value was read.
    pub statement_number: usize,
    /// 0-based position of the value among the quoted strings of its statement
    /// (unquoted values such as numbers are not counted).
    pub value_index: usize,
    /// Findings the scanner reported for the unescaped value.
    pub findings: Vec<PiiEntity>,
    /// Masked (still unescaped) text returned by the scanner, if it provided one.
    pub masked_value: Option<String>,
}
21
/// Result of scanning a complete SQL text with [`scan_sql_str`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SqlScanResult {
    /// One entry per quoted `INSERT` value that produced findings, in input order.
    pub values: Vec<SqlValueScanResult>,
    /// The SQL text with those values replaced by their masked form.
    pub masked_sql: String,
}
30
31pub fn scan_sql_str(input: &str, scanner: &Scanner) -> Result<SqlScanResult> {
60 let mut output = Vec::new();
61 let values = mask_sql_reader(input.as_bytes(), &mut output, scanner)?;
62 let masked_sql = String::from_utf8(output)
63 .map_err(|error| cloakrs_core::CloakError::ConfigError(error.to_string()))?;
64 Ok(SqlScanResult { values, masked_sql })
65}
66
67pub fn mask_sql_reader<R, W>(
70 reader: R,
71 mut writer: W,
72 scanner: &Scanner,
73) -> Result<Vec<SqlValueScanResult>>
74where
75 R: Read,
76 W: Write,
77{
78 let mut reader = BufReader::new(reader);
79 let mut state = SqlStreamState::default();
80 let mut values = Vec::new();
81 let mut line = String::new();
82
83 loop {
84 line.clear();
85 let bytes = reader.read_line(&mut line)?;
86 if bytes == 0 {
87 break;
88 }
89 state.process_line(&line, &mut writer, scanner, &mut values)?;
90 }
91 state.finish(&mut writer)?;
92 Ok(values)
93}
94
/// Parser state carried across lines while streaming SQL text.
#[derive(Default)]
struct SqlStreamState {
    /// True after a `VALUES` keyword of an `INSERT` until the next `;`.
    in_insert_values: bool,
    /// True while inside a quoted string value that is being buffered.
    in_string: bool,
    /// Number of `INSERT` keywords seen so far.
    statement_number: usize,
    /// Number of quoted strings completed in the current statement.
    value_index: usize,
    /// Upper-cased word currently being accumulated outside of strings.
    token: String,
    /// Raw (still escaped) contents of the string literal being read.
    raw_string: String,
}
104
105impl SqlStreamState {
106 fn process_line<W>(
107 &mut self,
108 line: &str,
109 writer: &mut W,
110 scanner: &Scanner,
111 values: &mut Vec<SqlValueScanResult>,
112 ) -> Result<()>
113 where
114 W: Write,
115 {
116 let mut chars = line.chars().peekable();
117 while let Some(ch) = chars.next() {
118 if self.in_string {
119 self.process_string_char(ch, &mut chars, writer, scanner, values)?;
120 } else {
121 self.process_sql_char(ch, writer)?;
122 }
123 }
124 Ok(())
125 }
126
127 fn process_sql_char<W>(&mut self, ch: char, writer: &mut W) -> Result<()>
128 where
129 W: Write,
130 {
131 if ch == '\'' && self.in_insert_values {
132 self.in_string = true;
133 self.raw_string.clear();
134 return Ok(());
135 }
136
137 write_char(writer, ch)?;
138 if ch == ';' {
139 self.in_insert_values = false;
140 self.token.clear();
141 return Ok(());
142 }
143 self.update_sql_state(ch);
144 Ok(())
145 }
146
147 fn process_string_char<I, W>(
148 &mut self,
149 ch: char,
150 chars: &mut std::iter::Peekable<I>,
151 writer: &mut W,
152 scanner: &Scanner,
153 values: &mut Vec<SqlValueScanResult>,
154 ) -> Result<()>
155 where
156 I: Iterator<Item = char>,
157 W: Write,
158 {
159 if ch == '\'' {
160 if matches!(chars.peek(), Some('\'')) {
161 self.raw_string.push('\'');
162 self.raw_string.push('\'');
163 let _ = chars.next();
164 return Ok(());
165 }
166 self.write_masked_string(writer, scanner, values)?;
167 return Ok(());
168 }
169
170 if ch == '\\' {
171 self.raw_string.push(ch);
172 if let Some(escaped) = chars.next() {
173 self.raw_string.push(escaped);
174 }
175 return Ok(());
176 }
177
178 self.raw_string.push(ch);
179 Ok(())
180 }
181
182 fn write_masked_string<W>(
183 &mut self,
184 writer: &mut W,
185 scanner: &Scanner,
186 values: &mut Vec<SqlValueScanResult>,
187 ) -> Result<()>
188 where
189 W: Write,
190 {
191 let unescaped = unescape_sql_string(&self.raw_string);
192 let scan = scanner.scan(&unescaped)?;
193 let masked_value = if scan.findings.is_empty() {
194 None
195 } else {
196 scan.masked_text.clone()
197 };
198 let value_to_write = masked_value
199 .as_deref()
200 .map(escape_sql_string)
201 .unwrap_or_else(|| self.raw_string.clone());
202
203 writer.write_all(b"'")?;
204 writer.write_all(value_to_write.as_bytes())?;
205 writer.write_all(b"'")?;
206
207 if !scan.findings.is_empty() {
208 values.push(SqlValueScanResult {
209 statement_number: self.statement_number,
210 value_index: self.value_index,
211 findings: scan.findings,
212 masked_value,
213 });
214 }
215 self.value_index += 1;
216 self.in_string = false;
217 self.raw_string.clear();
218 self.token.clear();
219 Ok(())
220 }
221
222 fn update_sql_state(&mut self, ch: char) {
223 if ch.is_ascii_alphanumeric() || ch == '_' {
224 self.token.push(ch.to_ascii_uppercase());
225 return;
226 }
227
228 if self.token == "INSERT" {
229 self.statement_number += 1;
230 self.value_index = 0;
231 self.in_insert_values = false;
232 } else if self.token == "VALUES" && self.statement_number > 0 {
233 self.in_insert_values = true;
234 }
235 self.token.clear();
236 }
237
238 fn finish<W>(&mut self, writer: &mut W) -> Result<()>
239 where
240 W: Write,
241 {
242 if self.in_string {
243 writer.write_all(b"'")?;
244 writer.write_all(self.raw_string.as_bytes())?;
245 self.in_string = false;
246 self.raw_string.clear();
247 }
248 Ok(())
249 }
250}
251
252fn write_char<W>(writer: &mut W, ch: char) -> Result<()>
253where
254 W: Write,
255{
256 let mut buffer = [0; 4];
257 writer.write_all(ch.encode_utf8(&mut buffer).as_bytes())?;
258 Ok(())
259}
260
/// Decodes the raw contents of a SQL string literal.
///
/// A doubled quote (`''`) becomes a single quote, and a backslash followed by
/// any character yields that character. A trailing backslash with nothing
/// after it is kept literally.
fn unescape_sql_string(value: &str) -> String {
    let mut decoded = String::with_capacity(value.len());
    let mut rest = value.chars().peekable();

    while let Some(current) = rest.next() {
        match current {
            '\'' if rest.peek() == Some(&'\'') => {
                let _ = rest.next();
                decoded.push('\'');
            }
            '\\' => decoded.push(rest.next().unwrap_or('\\')),
            other => decoded.push(other),
        }
    }
    decoded
}
280
/// Encodes `value` so it can be placed between single quotes in SQL output.
///
/// Single quotes are doubled and backslashes are doubled. Backslashes must be
/// escaped because this file's string parser treats a backslash as an escape
/// introducer: a bare backslash in the output would swallow the character
/// after it (including the closing quote when it is the last character),
/// corrupting the value or the statement.
fn escape_sql_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\'' => escaped.push_str("''"),
            '\\' => escaped.push_str("\\\\"),
            other => escaped.push(other),
        }
    }
    escaped
}
284
#[cfg(test)]
mod tests {
    use super::*;
    use cloakrs_core::Locale;
    use cloakrs_patterns::default_registry;

    /// Builds a scanner from the default pattern registry with the US locale.
    fn scanner() -> Scanner {
        let builder = default_registry()
            .into_scanner_builder()
            .locale(Locale::US);
        builder.build().unwrap()
    }

    /// Scans `sql` with a fresh scanner, panicking if the scan fails.
    fn scan(sql: &str) -> SqlScanResult {
        scan_sql_str(sql, &scanner()).unwrap()
    }

    #[test]
    fn test_scan_sql_str_insert_masks_quoted_string() {
        let outcome = scan("INSERT INTO users (email) VALUES ('jane@example.com');");
        assert!(outcome.masked_sql.contains("'[EMAIL]'"));
        assert_eq!(outcome.values.len(), 1);
    }

    #[test]
    fn test_scan_sql_str_multi_row_insert_masks_each_row() {
        let outcome =
            scan("INSERT INTO users VALUES (1,'jane@example.com'),(2,'ops@example.com');");
        assert_eq!(outcome.values.len(), 2);
        let masked_count = outcome.masked_sql.matches("'[EMAIL]'").count();
        assert_eq!(masked_count, 2);
    }

    #[test]
    fn test_scan_sql_str_escaped_quote_preserves_valid_sql() {
        // The doubled quote must survive the unescape / mask / escape round trip.
        let outcome = scan("INSERT INTO users VALUES ('Jane O''Neil <jane@example.com>');");
        assert!(outcome.masked_sql.contains("'Jane O''Neil <[EMAIL]>'"));
    }

    #[test]
    fn test_scan_sql_str_non_insert_is_not_masked() {
        let original = "UPDATE users SET email='jane@example.com';";
        let outcome = scan(original);
        assert!(outcome.values.is_empty());
        assert_eq!(outcome.masked_sql, original);
    }
}