Skip to main content

cloakrs_adapters/
sql.rs

1//! SQL dump adapter for masking quoted values in `INSERT` statements.
2
3use cloakrs_core::{PiiEntity, Result, Scanner};
4use serde::{Deserialize, Serialize};
5use std::io::{BufRead, BufReader, Read, Write};
6
7/// PII findings for one quoted SQL string value.
8///
9/// Spans are byte offsets relative to the unescaped SQL string value.
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11pub struct SqlValueScanResult {
12    /// One-based `INSERT` statement number.
13    pub statement_number: usize,
14    /// Zero-based quoted string value index within that statement.
15    pub value_index: usize,
16    /// Findings detected in this string value.
17    pub findings: Vec<PiiEntity>,
18    /// Masked value when the scanner has masking enabled.
19    pub masked_value: Option<String>,
20}
21
22/// Result of scanning a SQL dump.
23#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
24pub struct SqlScanResult {
25    /// Findings grouped by SQL quoted string value.
26    pub values: Vec<SqlValueScanResult>,
27    /// SQL output with quoted value contents masked.
28    pub masked_sql: String,
29}
30
31/// Scans SQL dump text and masks PII inside quoted `INSERT ... VALUES` strings.
32///
33/// # Examples
34///
35/// ```
36/// use cloakrs_adapters::scan_sql_str;
37/// use cloakrs_core::{Confidence, EntityType, Locale, PiiEntity, Recognizer, Scanner, Span};
38///
39/// struct Email;
40/// impl Recognizer for Email {
41///     fn id(&self) -> &str { "email_test" }
42///     fn entity_type(&self) -> EntityType { EntityType::Email }
43///     fn supported_locales(&self) -> &[Locale] { &[] }
44///     fn scan(&self, text: &str) -> Vec<PiiEntity> {
45///         text.find('@').map(|_| PiiEntity {
46///             entity_type: EntityType::Email,
47///             span: Span::new(0, text.len()),
48///             text: text.to_string(),
49///             confidence: Confidence::new(0.9).unwrap(),
50///             recognizer_id: self.id().to_string(),
51///         }).into_iter().collect()
52///     }
53/// }
54///
55/// let scanner = Scanner::builder().recognizer(Email).build().unwrap();
56/// let result = scan_sql_str("INSERT INTO users VALUES ('a@test');", &scanner).unwrap();
57/// assert!(result.masked_sql.contains("[EMAIL]"));
58/// ```
59pub fn scan_sql_str(input: &str, scanner: &Scanner) -> Result<SqlScanResult> {
60    let mut output = Vec::new();
61    let values = mask_sql_reader(input.as_bytes(), &mut output, scanner)?;
62    let masked_sql = String::from_utf8(output)
63        .map_err(|error| cloakrs_core::CloakError::ConfigError(error.to_string()))?;
64    Ok(SqlScanResult { values, masked_sql })
65}
66
67/// Streams SQL dump text from a reader to a writer while masking quoted
68/// strings in `INSERT ... VALUES` statements.
69pub fn mask_sql_reader<R, W>(
70    reader: R,
71    mut writer: W,
72    scanner: &Scanner,
73) -> Result<Vec<SqlValueScanResult>>
74where
75    R: Read,
76    W: Write,
77{
78    let mut reader = BufReader::new(reader);
79    let mut state = SqlStreamState::default();
80    let mut values = Vec::new();
81    let mut line = String::new();
82
83    loop {
84        line.clear();
85        let bytes = reader.read_line(&mut line)?;
86        if bytes == 0 {
87            break;
88        }
89        state.process_line(&line, &mut writer, scanner, &mut values)?;
90    }
91    state.finish(&mut writer)?;
92    Ok(values)
93}
94
95#[derive(Default)]
96struct SqlStreamState {
97    in_insert_values: bool,
98    in_string: bool,
99    statement_number: usize,
100    value_index: usize,
101    token: String,
102    raw_string: String,
103}
104
105impl SqlStreamState {
106    fn process_line<W>(
107        &mut self,
108        line: &str,
109        writer: &mut W,
110        scanner: &Scanner,
111        values: &mut Vec<SqlValueScanResult>,
112    ) -> Result<()>
113    where
114        W: Write,
115    {
116        let mut chars = line.chars().peekable();
117        while let Some(ch) = chars.next() {
118            if self.in_string {
119                self.process_string_char(ch, &mut chars, writer, scanner, values)?;
120            } else {
121                self.process_sql_char(ch, writer)?;
122            }
123        }
124        Ok(())
125    }
126
127    fn process_sql_char<W>(&mut self, ch: char, writer: &mut W) -> Result<()>
128    where
129        W: Write,
130    {
131        if ch == '\'' && self.in_insert_values {
132            self.in_string = true;
133            self.raw_string.clear();
134            return Ok(());
135        }
136
137        write_char(writer, ch)?;
138        if ch == ';' {
139            self.in_insert_values = false;
140            self.token.clear();
141            return Ok(());
142        }
143        self.update_sql_state(ch);
144        Ok(())
145    }
146
147    fn process_string_char<I, W>(
148        &mut self,
149        ch: char,
150        chars: &mut std::iter::Peekable<I>,
151        writer: &mut W,
152        scanner: &Scanner,
153        values: &mut Vec<SqlValueScanResult>,
154    ) -> Result<()>
155    where
156        I: Iterator<Item = char>,
157        W: Write,
158    {
159        if ch == '\'' {
160            if matches!(chars.peek(), Some('\'')) {
161                self.raw_string.push('\'');
162                self.raw_string.push('\'');
163                let _ = chars.next();
164                return Ok(());
165            }
166            self.write_masked_string(writer, scanner, values)?;
167            return Ok(());
168        }
169
170        if ch == '\\' {
171            self.raw_string.push(ch);
172            if let Some(escaped) = chars.next() {
173                self.raw_string.push(escaped);
174            }
175            return Ok(());
176        }
177
178        self.raw_string.push(ch);
179        Ok(())
180    }
181
182    fn write_masked_string<W>(
183        &mut self,
184        writer: &mut W,
185        scanner: &Scanner,
186        values: &mut Vec<SqlValueScanResult>,
187    ) -> Result<()>
188    where
189        W: Write,
190    {
191        let unescaped = unescape_sql_string(&self.raw_string);
192        let scan = scanner.scan(&unescaped)?;
193        let masked_value = if scan.findings.is_empty() {
194            None
195        } else {
196            scan.masked_text.clone()
197        };
198        let value_to_write = masked_value
199            .as_deref()
200            .map(escape_sql_string)
201            .unwrap_or_else(|| self.raw_string.clone());
202
203        writer.write_all(b"'")?;
204        writer.write_all(value_to_write.as_bytes())?;
205        writer.write_all(b"'")?;
206
207        if !scan.findings.is_empty() {
208            values.push(SqlValueScanResult {
209                statement_number: self.statement_number,
210                value_index: self.value_index,
211                findings: scan.findings,
212                masked_value,
213            });
214        }
215        self.value_index += 1;
216        self.in_string = false;
217        self.raw_string.clear();
218        self.token.clear();
219        Ok(())
220    }
221
222    fn update_sql_state(&mut self, ch: char) {
223        if ch.is_ascii_alphanumeric() || ch == '_' {
224            self.token.push(ch.to_ascii_uppercase());
225            return;
226        }
227
228        if self.token == "INSERT" {
229            self.statement_number += 1;
230            self.value_index = 0;
231            self.in_insert_values = false;
232        } else if self.token == "VALUES" && self.statement_number > 0 {
233            self.in_insert_values = true;
234        }
235        self.token.clear();
236    }
237
238    fn finish<W>(&mut self, writer: &mut W) -> Result<()>
239    where
240        W: Write,
241    {
242        if self.in_string {
243            writer.write_all(b"'")?;
244            writer.write_all(self.raw_string.as_bytes())?;
245            self.in_string = false;
246            self.raw_string.clear();
247        }
248        Ok(())
249    }
250}
251
252fn write_char<W>(writer: &mut W, ch: char) -> Result<()>
253where
254    W: Write,
255{
256    let mut buffer = [0; 4];
257    writer.write_all(ch.encode_utf8(&mut buffer).as_bytes())?;
258    Ok(())
259}
260
/// Decodes the raw contents of a quoted SQL string into the text it denotes.
///
/// * A doubled quote (`''`) becomes a single quote.
/// * `\n`, `\r` and `\t` become newline, carriage return and tab. Decoding
///   them matters for scanning: dropping only the backslash would turn
///   `name\njane@example.com` into `namenjane@example.com` and glue unrelated
///   words onto a detected entity.
/// * Any other backslash escape resolves to the escaped character itself
///   (`\'` → `'`, `\\` → `\`).
/// * A backslash with nothing after it is kept literally.
fn unescape_sql_string(value: &str) -> String {
    let mut output = String::with_capacity(value.len());
    let mut chars = value.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            '\'' if matches!(chars.peek(), Some('\'')) => {
                output.push('\'');
                let _ = chars.next();
            }
            '\\' => match chars.next() {
                Some('n') => output.push('\n'),
                Some('r') => output.push('\r'),
                Some('t') => output.push('\t'),
                Some(escaped) => output.push(escaped),
                // Dangling backslash at the end of the value.
                None => output.push(ch),
            },
            _ => output.push(ch),
        }
    }
    output
}
280
/// Encodes plain text so it can be placed between single quotes in SQL output.
///
/// Single quotes are doubled and backslashes are doubled. The backslash must
/// be escaped because this file's string parser treats `\` as an escape
/// character: written back bare, a value ending in `\` would escape its own
/// closing quote and corrupt the rest of the statement, and any other bare
/// backslash would be dropped when the output is parsed again.
fn escape_sql_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\'' => escaped.push_str("''"),
            '\\' => escaped.push_str("\\\\"),
            other => escaped.push(other),
        }
    }
    escaped
}
284
#[cfg(test)]
mod tests {
    use super::*;
    use cloakrs_core::Locale;
    use cloakrs_patterns::default_registry;

    /// Builds a scanner from the default recognizer registry for the US locale.
    fn scanner() -> Scanner {
        let builder = default_registry()
            .into_scanner_builder()
            .locale(Locale::US);
        builder.build().unwrap()
    }

    /// A single quoted e-mail in an `INSERT` is replaced and reported once.
    #[test]
    fn test_scan_sql_str_insert_masks_quoted_string() {
        let pii_scanner = scanner();
        let sql = "INSERT INTO users (email) VALUES ('jane@example.com');";

        let result = scan_sql_str(sql, &pii_scanner).unwrap();

        assert!(result.masked_sql.contains("'[EMAIL]'"));
        assert_eq!(result.values.len(), 1);
    }

    /// Every row of a multi-row `INSERT` is masked independently.
    #[test]
    fn test_scan_sql_str_multi_row_insert_masks_each_row() {
        let pii_scanner = scanner();
        let sql = "INSERT INTO users VALUES (1,'jane@example.com'),(2,'ops@example.com');";

        let result = scan_sql_str(sql, &pii_scanner).unwrap();

        assert_eq!(result.values.len(), 2);
        let masked_count = result.masked_sql.matches("'[EMAIL]'").count();
        assert_eq!(masked_count, 2);
    }

    /// A doubled quote inside the value is still doubled after masking.
    #[test]
    fn test_scan_sql_str_escaped_quote_preserves_valid_sql() {
        let pii_scanner = scanner();
        let sql = "INSERT INTO users VALUES ('Jane O''Neil <jane@example.com>');";

        let result = scan_sql_str(sql, &pii_scanner).unwrap();

        assert!(result.masked_sql.contains("'Jane O''Neil <[EMAIL]>'"));
    }

    /// Statements other than `INSERT` pass through byte-for-byte.
    #[test]
    fn test_scan_sql_str_non_insert_is_not_masked() {
        let pii_scanner = scanner();
        let sql = "UPDATE users SET email='jane@example.com';";

        let result = scan_sql_str(sql, &pii_scanner).unwrap();

        assert!(result.values.is_empty());
        assert_eq!(result.masked_sql, sql);
    }
}
329}