1use alloc::format;
18use alloc::string::{String, ToString};
19use alloc::vec::Vec;
20
/// Target of a `COPY ... FROM stdin` statement, as returned by
/// [`parse_copy_from_stdin_head`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CopyFromSpec {
    /// Table name with any schema qualifier (`public.`) removed and
    /// surrounding double quotes stripped; original letter case is kept.
    pub table: String,
    /// Explicit column list in statement order, each name trimmed and with
    /// surrounding double quotes stripped; `None` when no list was given.
    pub columns: Option<Vec<String>>,
}

/// Parses the head of a `COPY <table> [(<columns>)] FROM stdin [options]`
/// statement, such as the lines `pg_dump` emits before inline data.
///
/// Keywords are matched case-insensitively; table and column names keep the
/// case they have in `sql`. Only the part after the last `.` of the table
/// name is kept. Names containing whitespace (even when quoted) are not
/// supported, because the table name ends at the first whitespace or `(`.
///
/// Returns `None` when `sql` is not a `COPY ... FROM stdin` statement, when
/// the column list is never closed, or when the options select a format
/// other than the default text format (`FORMAT csv`, `FORMAT binary`, or the
/// legacy `CSV` / `BINARY` keywords).
#[must_use]
pub fn parse_copy_from_stdin_head(sql: &str) -> Option<CopyFromSpec> {
    let trimmed = sql.trim();
    // ASCII lowercasing never changes byte lengths, so byte offsets found in
    // `lower` are also valid offsets into `trimmed`.
    let lower = trimmed.to_ascii_lowercase();
    let rest = lower.strip_prefix("copy")?;
    if !rest.starts_with(char::is_whitespace) {
        return None;
    }
    let rest_orig = &trimmed[trimmed.len() - rest.len()..];
    let bytes = rest.as_bytes();

    // Table name: everything up to the next ASCII whitespace or `(`.
    let mut i = copy_skip_ascii_whitespace(bytes, 0);
    let t0 = i;
    while i < bytes.len() && !bytes[i].is_ascii_whitespace() && bytes[i] != b'(' {
        i += 1;
    }
    if i == t0 {
        return None;
    }
    let raw_table = &rest_orig[t0..i];
    let table = raw_table
        .rsplit_once('.')
        .map_or(raw_table, |(_, bare)| bare)
        .trim_matches('"')
        .to_string();
    i = copy_skip_ascii_whitespace(bytes, i);

    // Optional parenthesised column list.
    let columns = if bytes.get(i) == Some(&b'(') {
        let cols_start = i + 1;
        let mut depth = 1usize;
        i += 1;
        while i < bytes.len() && depth > 0 {
            match bytes[i] {
                b'(' => depth += 1,
                b')' => depth -= 1,
                _ => {}
            }
            i += 1;
        }
        // An unclosed list is malformed. Bailing out here also prevents the
        // slice below from ending before its start (`COPY t (`) or inside a
        // multi-byte character.
        if depth != 0 {
            return None;
        }
        // `i` is one past the matching `)`.
        let cols_str = &rest_orig[cols_start..i - 1];
        i = copy_skip_ascii_whitespace(bytes, i);
        Some(
            cols_str
                .split(',')
                .map(|c| c.trim().trim_matches('"').to_string())
                .filter(|c| !c.is_empty())
                .collect::<Vec<_>>(),
        )
    } else {
        None
    };

    let tail = rest[i..].trim_start();
    let tail = tail.strip_prefix("from")?;
    if !tail.starts_with(char::is_whitespace) {
        return None;
    }
    let after = tail.trim_start().strip_prefix("stdin")?;
    // `stdin` must be a whole word, not the prefix of a longer identifier.
    if after.starts_with(|c: char| c.is_alphanumeric() || c == '_') {
        return None;
    }
    if !copy_options_select_text_format(after) {
        return None;
    }
    Some(CopyFromSpec { table, columns })
}

/// Returns the index of the first byte at or after `i` that is not ASCII
/// whitespace (or `bytes.len()` if there is none).
fn copy_skip_ascii_whitespace(bytes: &[u8], mut i: usize) -> usize {
    while i < bytes.len() && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    i
}

/// Reports whether the already-lowercased option text following `FROM stdin`
/// leaves the data in the default text format.
///
/// The text is split into identifier-like words; the legacy `csv` / `binary`
/// keywords, or a `format` word not followed by `text`, reject it.
fn copy_options_select_text_format(options: &str) -> bool {
    let mut words = options
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|w| !w.is_empty());
    while let Some(word) = words.next() {
        match word {
            "csv" | "binary" => return false,
            "format" => {
                if words.next() != Some("text") {
                    return false;
                }
            }
            _ => {}
        }
    }
    true
}
116
/// Decodes one line of PostgreSQL `COPY` text-format data into its fields.
///
/// The line is split on tab characters; any line terminator still present
/// stays in the last field. A field consisting of exactly `\N` is SQL `NULL`
/// and becomes `None`; every other field has its backslash escapes decoded.
///
/// Recognised escapes: `\b`, `\f`, `\n`, `\r`, `\t`, `\v`, one to three octal
/// digits (`\101` is `A`), and `\x` followed by one or two hex digits
/// (`\x41` is `A`). A backslash before any other character yields that
/// character (`\\` is `\`, `\x` with no hex digit is `x`), and a trailing
/// lone backslash is kept. If numeric escapes produce bytes that are not
/// valid UTF-8, those bytes are replaced with U+FFFD.
#[must_use]
pub fn decode_copy_text_row(line: &str) -> Vec<Option<String>> {
    line.split('\t').map(decode_copy_text_cell).collect()
}

/// Decodes a single tab-free field; see [`decode_copy_text_row`].
fn decode_copy_text_cell(cell: &str) -> Option<String> {
    // Only the raw two-character field `\N` means NULL; `\\N` is literal text.
    if cell == "\\N" {
        return None;
    }
    if !cell.contains('\\') {
        return Some(cell.to_string());
    }
    // Decode into bytes: numeric escapes name raw bytes, and several of them
    // together may spell one multi-byte UTF-8 character.
    let src = cell.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(src.len());
    let mut pos = 0;
    while pos < src.len() {
        let byte = src[pos];
        pos += 1;
        if byte != b'\\' || pos == src.len() {
            // Ordinary byte, or a trailing backslash with nothing to escape.
            out.push(byte);
            continue;
        }
        let escape = src[pos];
        pos += 1;
        let decoded = match escape {
            b'b' => 0x08,
            b'f' => 0x0c,
            b'n' => b'\n',
            b'r' => b'\r',
            b't' => b'\t',
            b'v' => 0x0b,
            b'0'..=b'7' => {
                // Re-read the first digit together with up to two more.
                pos -= 1;
                copy_escape_number(src, &mut pos, 8, 3).unwrap_or(escape)
            }
            b'x' => copy_escape_number(src, &mut pos, 16, 2).unwrap_or(b'x'),
            // Any other escaped byte (including `\\`) stands for itself. For a
            // multi-byte character only its lead byte is consumed here; its
            // continuation bytes are copied by the following iterations.
            other => other,
        };
        out.push(decoded);
    }
    Some(
        String::from_utf8(out)
            .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned()),
    )
}

/// Parses up to `max_digits` digits in `radix` starting at `src[*pos]` and
/// advances `*pos` past them. Returns the value truncated to its low byte
/// (`\777` becomes 0xFF), or `None` when no digit is present.
fn copy_escape_number(src: &[u8], pos: &mut usize, radix: u32, max_digits: usize) -> Option<u8> {
    let mut value: u32 = 0;
    let mut digits = 0;
    while digits < max_digits {
        let Some(digit) = src.get(*pos).and_then(|&b| char::from(b).to_digit(radix)) else {
            break;
        };
        value = value * radix + digit;
        *pos += 1;
        digits += 1;
    }
    // At most 3 octal or 2 hex digits, so `value` <= 511; keep the low byte.
    (digits > 0).then_some((value & 0xff) as u8)
}
151
/// Builds a single-row `INSERT` statement for one decoded `COPY` row.
///
/// * `table` and the names in `columns` are inserted verbatim — they are not
///   quoted or escaped, so they must already be usable SQL identifiers.
/// * `columns` of `None`, or an empty list, produces a positional insert
///   without a column list.
/// * Each value becomes `NULL` (`None`), a bare token when it looks like a
///   plain decimal number or is `true`/`false`/`TRUE`/`FALSE`, or otherwise a
///   single-quoted string literal with embedded `'` doubled.
#[must_use]
pub fn build_copy_insert(
    table: &str,
    columns: Option<&[String]>,
    values: &[Option<String>],
) -> String {
    let mut sql = format!("INSERT INTO {table} ");
    // `()` is not a valid column list, so an empty list is treated like none.
    if let Some(cols) = columns.filter(|cols| !cols.is_empty()) {
        sql.push('(');
        sql.push_str(&cols.join(", "));
        sql.push_str(") ");
    }
    sql.push_str("VALUES (");
    for (i, value) in values.iter().enumerate() {
        if i > 0 {
            sql.push_str(", ");
        }
        push_copy_sql_value(&mut sql, value.as_deref());
    }
    sql.push(')');
    sql
}

/// Appends one value to `sql` as `NULL`, a bare numeric/boolean token, or a
/// single-quoted string literal with embedded quotes doubled.
fn push_copy_sql_value(sql: &mut String, value: Option<&str>) {
    match value {
        None => sql.push_str("NULL"),
        Some(s)
            if copy_cell_looks_numeric(s)
                || matches!(s, "true" | "false" | "TRUE" | "FALSE") =>
        {
            sql.push_str(s);
        }
        Some(s) => {
            sql.push('\'');
            sql.push_str(&s.replace('\'', "''"));
            sql.push('\'');
        }
    }
}

/// Reports whether `s` can be emitted as a bare SQL numeric literal.
///
/// Accepts an optional leading `+`/`-`, ASCII digits and at most one `.`,
/// with at least one digit overall (`1`, `-2.5`, `.5`, `3.`). An integer
/// part with a leading zero (`0042`, `007.5`) is rejected, so such strings
/// are quoted and keep their leading zeros; a lone `0` integer part (`0`,
/// `0.5`) is fine. Exponents (`1e5`) are not recognised.
fn copy_cell_looks_numeric(s: &str) -> bool {
    let unsigned = s.strip_prefix(['-', '+']).unwrap_or(s);
    let (int_part, frac_part) = unsigned.split_once('.').unwrap_or((unsigned, ""));
    let digits_only = |part: &str| part.bytes().all(|b| b.is_ascii_digit());
    // A second `.` or sign lands in one of the parts and fails this check.
    if !digits_only(int_part) || !digits_only(frac_part) {
        return false;
    }
    let has_digit = !int_part.is_empty() || !frac_part.is_empty();
    let leading_zero = int_part.len() > 1 && int_part.starts_with('0');
    has_digit && !leading_zero
}
237
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use alloc::vec;

    /// Converts borrowed names into the owned form `CopyFromSpec` stores.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn parses_pg_dump_copy_head() {
        let qualified =
            parse_copy_from_stdin_head("COPY public.messages (id, subject, body) FROM stdin")
                .expect("schema-qualified head with a column list should parse");
        assert_eq!(qualified.table, "messages");
        assert_eq!(qualified.columns, Some(owned(&["id", "subject", "body"])));

        let bare = parse_copy_from_stdin_head("copy t from stdin")
            .expect("lowercase head without a column list should parse");
        assert_eq!(bare.table, "t");
        assert_eq!(bare.columns, None);
    }

    #[test]
    fn rejects_unsupported_copy_heads() {
        // Wrong direction, a server-side file, and a non-text format.
        let unsupported = [
            "COPY t TO stdout",
            "COPY t FROM '/tmp/f.csv'",
            "COPY t FROM stdin WITH (FORMAT csv)",
        ];
        for sql in unsupported {
            assert!(parse_copy_from_stdin_head(sql).is_none(), "should reject: {sql}");
        }
    }

    #[test]
    fn decodes_text_rows() {
        let expected = vec![
            Some("1".to_string()),
            Some("hello".to_string()),
            None,
            Some("a\tb".to_string()),
        ];
        assert_eq!(decode_copy_text_row("1\thello\t\\N\ta\\tb"), expected);
    }

    #[test]
    fn builds_inserts_with_column_list() {
        let cols = owned(&["id", "note"]);
        let row = vec![Some("7".to_string()), Some("it's".to_string())];
        assert_eq!(
            build_copy_insert("t", Some(&cols), &row),
            "INSERT INTO t (id, note) VALUES (7, 'it''s')"
        );
    }

    #[test]
    fn builds_positional_inserts_and_quotes_leading_zeros() {
        assert_eq!(
            build_copy_insert("t", None, &[None, Some("0042".to_string())]),
            "INSERT INTO t VALUES (NULL, '0042')"
        );
    }
}