1use crate::formatter::DataFormat;
2use wp_model_core::model::fmt_def::TextFmt;
3use wp_model_core::model::{DataField, DataRecord, DataType, Value, types::value::ObjectValue};
4
5pub struct SqlInsert {
6 pub table_name: String,
7 pub quote_identifiers: bool,
8 pub obj_formatter: crate::SqlFormat,
9}
10
11impl Default for SqlInsert {
12 fn default() -> Self {
13 Self {
14 table_name: String::new(),
15 quote_identifiers: true,
16 obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
17 }
18 }
19}
20
21impl SqlInsert {
22 pub fn new_with_json<T: Into<String>>(table: T) -> Self {
23 Self {
24 table_name: table.into(),
25 quote_identifiers: true,
26 obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
27 }
28 }
29 fn quote_identifier(&self, name: &str) -> String {
30 if self.quote_identifiers {
31 let escaped = name.replace('"', "\"\"");
32 format!("\"{}\"", escaped)
33 } else {
34 name.to_string()
35 }
36 }
37 fn escape_string(&self, value: &str) -> String {
38 value.replace('\'', "''")
39 }
40}
41
42impl DataFormat for SqlInsert {
43 type Output = String;
44 fn format_null(&self) -> String {
45 "NULL".to_string()
46 }
47 fn format_bool(&self, value: &bool) -> String {
48 if *value { "TRUE" } else { "FALSE" }.to_string()
49 }
50 fn format_string(&self, value: &str) -> String {
51 format!("'{}'", self.escape_string(value))
52 }
53 fn format_i64(&self, value: &i64) -> String {
54 value.to_string()
55 }
56 fn format_f64(&self, value: &f64) -> String {
57 if value.is_nan() {
58 "NULL".into()
59 } else if value.is_infinite() {
60 if value.is_sign_positive() {
61 "'Infinity'".into()
62 } else {
63 "'-Infinity'".into()
64 }
65 } else {
66 value.to_string()
67 }
68 }
69 fn format_ip(&self, value: &std::net::IpAddr) -> String {
70 self.format_string(&value.to_string())
71 }
72 fn format_datetime(&self, value: &chrono::NaiveDateTime) -> String {
73 self.format_string(&value.to_string())
74 }
75 fn format_object(&self, value: &ObjectValue) -> String {
76 let inner = match &self.obj_formatter {
77 crate::SqlFormat::Json(f) => f.format_object(value),
78 crate::SqlFormat::Kv(f) => f.format_object(value),
79 crate::SqlFormat::Raw(f) => f.format_object(value),
80 crate::SqlFormat::ProtoText(f) => f.format_object(value),
81 };
82 format!("'{}'", self.escape_string(&inner))
83 }
84 fn format_array(&self, value: &[DataField]) -> String {
85 let inner = match &self.obj_formatter {
86 crate::SqlFormat::Json(f) => f.format_array(value),
87 crate::SqlFormat::Kv(f) => f.format_array(value),
88 crate::SqlFormat::Raw(f) => f.format_array(value),
89 crate::SqlFormat::ProtoText(f) => f.format_array(value),
90 };
91 format!("'{}'", self.escape_string(&inner))
92 }
93 fn format_record(&self, record: &DataRecord) -> String {
94 let columns: Vec<String> = record
95 .items
96 .iter()
97 .filter(|f| *f.get_meta() != DataType::Ignore)
98 .map(|f| self.quote_identifier(f.get_name()))
99 .collect();
100 let values: Vec<String> = record
101 .items
102 .iter()
103 .filter(|f| *f.get_meta() != DataType::Ignore)
104 .map(|f| self.format_field(f))
105 .collect();
106 format!(
107 "INSERT INTO {} ({}) VALUES ({});",
108 self.quote_identifier(&self.table_name),
109 columns.join(", "),
110 values.join(", ")
111 )
112 }
113 fn format_field(&self, field: &DataField) -> String {
114 if *field.get_meta() == DataType::Ignore {
115 String::new()
116 } else {
117 self.fmt_value(field.get_value())
118 }
119 }
120}
121
122impl SqlInsert {
123 pub fn format_batch(&self, records: &[DataRecord]) -> String {
124 if records.is_empty() {
125 return String::new();
126 }
127 let mut output = String::new();
128 let columns: Vec<String> = records[0]
129 .items
130 .iter()
131 .filter(|f| *f.get_meta() != DataType::Ignore)
132 .map(|f| self.quote_identifier(f.get_name()))
133 .collect();
134 use std::fmt::Write;
135 writeln!(
136 output,
137 "INSERT INTO {} ({}) VALUES",
138 self.quote_identifier(&self.table_name),
139 columns.join(", ")
140 )
141 .unwrap();
142 for (i, record) in records.iter().enumerate() {
143 if i > 0 {
144 output.push_str(",\n");
145 }
146 let values: Vec<String> = record
147 .items
148 .iter()
149 .filter(|f| *f.get_meta() != DataType::Ignore)
150 .map(|f| self.format_field(f))
151 .collect();
152 write!(output, " ({})", values.join(", ")).unwrap();
153 }
154 output.push(';');
155 output
156 }
157 pub fn generate_create_table(&self, records: &[DataRecord]) -> String {
158 if records.is_empty() {
159 return String::new();
160 }
161 let mut columns = Vec::new();
162 for field in &records[0].items {
163 if *field.get_meta() == DataType::Ignore {
164 continue;
165 }
166 let sql_type = &match field.get_value() {
167 Value::Bool(_) => "BOOLEAN",
168 Value::Chars(_) => "TEXT",
169 Value::Digit(_) => "BIGINT",
170 Value::Float(_) => "DOUBLE PRECISION",
171 Value::Time(_) => "TIMESTAMP",
172 Value::IpAddr(_) => "INET",
173 Value::Obj(_) | Value::Array(_) => "JSONB",
174 _ => "TEXT",
175 };
176 columns.push(format!(
177 " {} {}",
178 self.quote_identifier(field.get_name()),
179 sql_type
180 ));
181 }
182 format!(
183 "CREATE TABLE IF NOT EXISTS {} (\n{}\n);",
184 self.quote_identifier(&self.table_name),
185 columns.join(",\n")
186 )
187 }
188 pub fn format_upsert(&self, record: &DataRecord, conflict_columns: &[&str]) -> String {
189 let insert = self.format_record(record);
190 let mut update_parts = Vec::new();
191 for field in record
192 .items
193 .iter()
194 .filter(|f| *f.get_meta() != DataType::Ignore)
195 {
196 let name = field.get_name();
197 if !conflict_columns.contains(&name) {
198 let col = self.quote_identifier(name);
199 update_parts.push(format!("{} = EXCLUDED.{}", &col, &col));
200 }
201 }
202 if update_parts.is_empty() {
203 insert
204 } else {
205 let quoted_conflicts: Vec<String> = conflict_columns
206 .iter()
207 .map(|c| self.quote_identifier(c))
208 .collect();
209 format!(
210 "{} ON CONFLICT ({}) DO UPDATE SET {};",
211 insert.trim_end_matches(';'),
212 quoted_conflicts.join(", "),
213 update_parts.join(", ")
214 )
215 }
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::formatter::DataFormat;
    use wp_model_core::model::{DataField, DataRecord};

    /// Formatter for table `t` with JSON nested rendering and the given
    /// identifier-quoting mode.
    fn table_t(quote_identifiers: bool) -> SqlInsert {
        SqlInsert {
            table_name: "t".into(),
            quote_identifiers,
            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
        }
    }

    /// Builds a record from its fields.
    fn rec(items: Vec<DataField>) -> DataRecord {
        DataRecord { items }
    }

    #[test]
    fn test_sql_basic() {
        let row = rec(vec![
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
        ]);
        let stmt = table_t(true).format_record(&row);
        assert!(stmt.contains("INSERT INTO \"t\" (\"name\", \"age\") VALUES"));
    }

    #[test]
    fn test_sql_default() {
        let fmt = SqlInsert::default();
        assert!(fmt.table_name.is_empty());
        assert!(fmt.quote_identifiers);
    }

    #[test]
    fn test_sql_new_with_json() {
        let fmt = SqlInsert::new_with_json("users");
        assert_eq!(fmt.table_name, "users");
        assert!(fmt.quote_identifiers);
    }

    #[test]
    fn test_format_null() {
        assert_eq!(SqlInsert::default().format_null(), "NULL");
    }

    #[test]
    fn test_format_bool() {
        let fmt = SqlInsert::default();
        for (input, expected) in [(true, "TRUE"), (false, "FALSE")] {
            assert_eq!(fmt.format_bool(&input), expected);
        }
    }

    #[test]
    fn test_format_string() {
        let fmt = SqlInsert::default();
        for (input, expected) in [("hello", "'hello'"), ("", "''")] {
            assert_eq!(fmt.format_string(input), expected);
        }
    }

    #[test]
    fn test_format_string_escape() {
        let fmt = SqlInsert::default();
        for (input, expected) in [("it's", "'it''s'"), ("say 'hi'", "'say ''hi'''")] {
            assert_eq!(fmt.format_string(input), expected);
        }
    }

    #[test]
    fn test_format_i64() {
        let fmt = SqlInsert::default();
        for (input, expected) in [(0_i64, "0"), (42, "42"), (-100, "-100")] {
            assert_eq!(fmt.format_i64(&input), expected);
        }
    }

    #[test]
    fn test_format_f64_normal() {
        let fmt = SqlInsert::default();
        for (input, expected) in [(3.24_f64, "3.24"), (0.0, "0")] {
            assert_eq!(fmt.format_f64(&input), expected);
        }
    }

    #[test]
    fn test_format_f64_special() {
        let fmt = SqlInsert::default();
        let cases = [
            (f64::NAN, "NULL"),
            (f64::INFINITY, "'Infinity'"),
            (f64::NEG_INFINITY, "'-Infinity'"),
        ];
        for (input, expected) in cases {
            assert_eq!(fmt.format_f64(&input), expected);
        }
    }

    #[test]
    fn test_format_ip() {
        let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap();
        assert_eq!(SqlInsert::default().format_ip(&ip), "'192.168.1.1'");
    }

    #[test]
    fn test_format_datetime() {
        let dt =
            chrono::NaiveDateTime::parse_from_str("2024-01-15 10:30:45", "%Y-%m-%d %H:%M:%S")
                .unwrap();
        let rendered = SqlInsert::default().format_datetime(&dt);
        assert!(rendered.starts_with('\''));
        assert!(rendered.ends_with('\''));
        assert!(rendered.contains("2024"));
    }

    #[test]
    fn test_quote_identifier() {
        let fmt = SqlInsert::new_with_json("t");
        for (input, expected) in [("name", "\"name\""), ("user_id", "\"user_id\"")] {
            assert_eq!(fmt.quote_identifier(input), expected);
        }
    }

    #[test]
    fn test_quote_identifier_escape() {
        let fmt = SqlInsert::new_with_json("t");
        assert_eq!(fmt.quote_identifier("col\"name"), "\"col\"\"name\"");
    }

    #[test]
    fn test_quote_identifier_disabled() {
        assert_eq!(table_t(false).quote_identifier("name"), "name");
    }

    #[test]
    fn test_format_record() {
        let row = rec(vec![
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
            DataField::from_bool("active", true),
        ]);
        let stmt = SqlInsert::new_with_json("users").format_record(&row);
        assert!(stmt.starts_with("INSERT INTO \"users\""));
        assert!(stmt.contains("(\"name\", \"age\", \"active\")"));
        assert!(stmt.contains("VALUES ('Alice', 30, TRUE)"));
        assert!(stmt.ends_with(';'));
    }

    #[test]
    fn test_format_batch_empty() {
        let none: Vec<DataRecord> = Vec::new();
        assert_eq!(SqlInsert::new_with_json("users").format_batch(&none), "");
    }

    #[test]
    fn test_format_batch() {
        let rows = vec![
            rec(vec![
                DataField::from_chars("name", "Alice"),
                DataField::from_digit("age", 30),
            ]),
            rec(vec![
                DataField::from_chars("name", "Bob"),
                DataField::from_digit("age", 25),
            ]),
        ];
        let stmt = SqlInsert::new_with_json("users").format_batch(&rows);
        for fragment in ["INSERT INTO \"users\"", "('Alice', 30)", "('Bob', 25)"] {
            assert!(stmt.contains(fragment));
        }
        assert!(stmt.ends_with(';'));
    }

    #[test]
    fn test_generate_create_table_empty() {
        let none: Vec<DataRecord> = Vec::new();
        assert_eq!(
            SqlInsert::new_with_json("users").generate_create_table(&none),
            ""
        );
    }

    #[test]
    fn test_generate_create_table() {
        let rows = vec![rec(vec![
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
            DataField::from_bool("active", true),
            DataField::from_float("score", 95.5),
        ])];
        let ddl = SqlInsert::new_with_json("users").generate_create_table(&rows);
        let expected = [
            "CREATE TABLE IF NOT EXISTS \"users\"",
            "\"name\" TEXT",
            "\"age\" BIGINT",
            "\"active\" BOOLEAN",
            "\"score\" DOUBLE PRECISION",
        ];
        for fragment in expected {
            assert!(ddl.contains(fragment));
        }
    }

    #[test]
    fn test_format_upsert() {
        let row = rec(vec![
            DataField::from_chars("id", "u1"),
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
        ]);
        let stmt = SqlInsert::new_with_json("users").format_upsert(&row, &["id"]);
        let expected = [
            "INSERT INTO \"users\"",
            "ON CONFLICT (\"id\")",
            "DO UPDATE SET",
            "\"name\" = EXCLUDED.\"name\"",
            "\"age\" = EXCLUDED.\"age\"",
        ];
        for fragment in expected {
            assert!(stmt.contains(fragment));
        }
    }

    #[test]
    fn test_format_upsert_no_update_columns() {
        let row = rec(vec![DataField::from_chars("id", "u1")]);
        let stmt = SqlInsert::new_with_json("users").format_upsert(&row, &["id"]);
        assert!(stmt.contains("INSERT INTO"));
        assert!(!stmt.contains("ON CONFLICT"));
    }
}