pgdump_toc_rewrite/
rewrite_sql.rs

1/*
2 * Copyright 2023, WiltonDB Software
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use std::collections::HashMap;
18
19use sqlparser::dialect::GenericDialect;
20use sqlparser::tokenizer::Token;
21use sqlparser::tokenizer::Tokenizer;
22use sqlparser::tokenizer::TokenWithLocation;
23
24use crate::toc_error::TocError;
25
26
27fn location_to_idx(lines: &Vec<&str>, twl: &TokenWithLocation) -> usize {
28    let TokenWithLocation{ token, location } = twl;
29    let mut res = 0usize;
30    for i in 0..location.line - 1 {
31        res += lines[i as usize].chars().count();
32    }
33    res += (location.line - 1) as usize; // EOLs
34    res += (location.column - 1) as usize;
35    if let Token::Word(word) = token {
36        if word.quote_style.is_some() {
37            res += 1;
38        }
39    } else if let Token::SingleQuotedString(_) = token {
40        res += 1;
41    }
42    res
43}
44
45fn rewrite_schema_in_sql_internal(schemas: &HashMap<String, String>,
46                                  sql: &str,
47                                  qualified_only: bool,
48                                  single_quoted_only: bool
49) -> Result<String, TocError> {
50    let dialect = GenericDialect {};
51    let lines: Vec<&str> = sql.split('\n').collect();
52    let tokens = match Tokenizer::new(&dialect, sql).tokenize_with_location() {
53        Ok(tokens) => tokens,
54        Err(e) => return Err(TocError::new(&format!(
55            "Tokenizer error: {}, sql: {}", e, sql)))
56    };
57    let mut to_replace: Vec<(&str, &str, usize)> = Vec::new();
58    for i in 0..tokens.len() {
59        if qualified_only && !single_quoted_only {
60            if i >= tokens.len() - 1 {
61                continue;
62            }
63            let TokenWithLocation{ token, .. } = &tokens[i + 1];
64            if let Token::Period = token {
65                // success
66            } else {
67                continue;
68            }
69        }
70        let twl = &tokens[i];
71        let loc_idx = location_to_idx(&lines, twl);
72        let TokenWithLocation{ token, .. } = twl;
73        if single_quoted_only {
74            if let Token::SingleQuotedString(st) = token {
75                let old_schema = if qualified_only {
76                    let idx = st.find('.').ok_or(TocError::new(&format!(
77                        "Unexpected unqualified single-quoted entry: {}", st)))?;
78                    &st[..idx]
79                } else {
80                    st
81                };
82                if let Some(schema) = schemas.get(old_schema) {
83                    to_replace.push((old_schema, schema, loc_idx));
84                }
85            }
86        } else {
87            if let Token::Word(word) = token {
88                if let Some(schema) = schemas.get(&word.value) {
89                    to_replace.push((&word.value, schema, loc_idx));
90                }
91            }
92        }
93    }
94
95    let orig: Vec<char> = sql.chars().collect();
96    let mut rewritten: Vec<char> = Vec::new();
97    let mut last_idx = 0;
98    for (schema_orig, schema_replaced, start_idx) in to_replace {
99        for i in last_idx..start_idx {
100            rewritten.push(orig[i]);
101        }
102        for ch in schema_replaced.chars() {
103            rewritten.push(ch);
104        }
105        let orig_check: String = orig.iter().skip(start_idx).take(schema_orig.chars().count()).collect();
106        if orig_check != *schema_orig {
107            return Err(TocError::new(&format!(
108                "Replace error, sql: {}, location: {}", sql, start_idx)))
109        }
110        last_idx = start_idx + schema_orig.chars().count();
111    }
112
113    // tail
114    for i in last_idx..orig.len() {
115        rewritten.push(orig[i]);
116    }
117
118    let res: String = rewritten.into_iter().collect();
119    Ok(res)
120}
121
122pub fn rewrite_schema_in_sql(schemas: &HashMap<String, String>, sql: &str) -> Result<String, TocError> {
123    rewrite_schema_in_sql_internal(schemas, sql, true, false)
124}
125
126pub fn rewrite_schema_in_sql_unqualified(schemas: &HashMap<String, String>, sql: &str) -> Result<String, TocError> {
127    rewrite_schema_in_sql_internal(schemas, sql, false, false)
128}
129
130pub fn rewrite_schema_in_sql_single_quoted(schemas: &HashMap<String, String>, sql: &str) -> Result<String, TocError> {
131    rewrite_schema_in_sql_internal(schemas, sql, false, true)
132}
133
134pub fn rewrite_schema_in_sql_qualified_single_quoted(schemas: &HashMap<String, String>, sql: &str) -> Result<String, TocError> {
135    rewrite_schema_in_sql_internal(schemas, sql, true, true)
136}