wp_data_fmt/
sql.rs

1use crate::formatter::DataFormat;
2use wp_model_core::model::fmt_def::TextFmt;
3use wp_model_core::model::{DataField, DataRecord, DataType, Value, types::value::ObjectValue};
4
5pub struct SqlInsert {
6    pub table_name: String,
7    pub quote_identifiers: bool,
8    pub obj_formatter: crate::SqlFormat,
9}
10
11impl Default for SqlInsert {
12    fn default() -> Self {
13        Self {
14            table_name: String::new(),
15            quote_identifiers: true,
16            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
17        }
18    }
19}
20
21impl SqlInsert {
22    pub fn new_with_json<T: Into<String>>(table: T) -> Self {
23        Self {
24            table_name: table.into(),
25            quote_identifiers: true,
26            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
27        }
28    }
29    fn quote_identifier(&self, name: &str) -> String {
30        if self.quote_identifiers {
31            let escaped = name.replace('"', "\"\"");
32            format!("\"{}\"", escaped)
33        } else {
34            name.to_string()
35        }
36    }
37    fn escape_string(&self, value: &str) -> String {
38        value.replace('\'', "''")
39    }
40}
41
42impl DataFormat for SqlInsert {
43    type Output = String;
44    fn format_null(&self) -> String {
45        "NULL".to_string()
46    }
47    fn format_bool(&self, value: &bool) -> String {
48        if *value { "TRUE" } else { "FALSE" }.to_string()
49    }
50    fn format_string(&self, value: &str) -> String {
51        format!("'{}'", self.escape_string(value))
52    }
53    fn format_i64(&self, value: &i64) -> String {
54        value.to_string()
55    }
56    fn format_f64(&self, value: &f64) -> String {
57        if value.is_nan() {
58            "NULL".into()
59        } else if value.is_infinite() {
60            if value.is_sign_positive() {
61                "'Infinity'".into()
62            } else {
63                "'-Infinity'".into()
64            }
65        } else {
66            value.to_string()
67        }
68    }
69    fn format_ip(&self, value: &std::net::IpAddr) -> String {
70        self.format_string(&value.to_string())
71    }
72    fn format_datetime(&self, value: &chrono::NaiveDateTime) -> String {
73        self.format_string(&value.to_string())
74    }
75    fn format_object(&self, value: &ObjectValue) -> String {
76        let inner = match &self.obj_formatter {
77            crate::SqlFormat::Json(f) => f.format_object(value),
78            crate::SqlFormat::Kv(f) => f.format_object(value),
79            crate::SqlFormat::Raw(f) => f.format_object(value),
80            crate::SqlFormat::ProtoText(f) => f.format_object(value),
81        };
82        format!("'{}'", self.escape_string(&inner))
83    }
84    fn format_array(&self, value: &[DataField]) -> String {
85        let inner = match &self.obj_formatter {
86            crate::SqlFormat::Json(f) => f.format_array(value),
87            crate::SqlFormat::Kv(f) => f.format_array(value),
88            crate::SqlFormat::Raw(f) => f.format_array(value),
89            crate::SqlFormat::ProtoText(f) => f.format_array(value),
90        };
91        format!("'{}'", self.escape_string(&inner))
92    }
93    fn format_record(&self, record: &DataRecord) -> String {
94        let columns: Vec<String> = record
95            .items
96            .iter()
97            .filter(|f| *f.get_meta() != DataType::Ignore)
98            .map(|f| self.quote_identifier(f.get_name()))
99            .collect();
100        let values: Vec<String> = record
101            .items
102            .iter()
103            .filter(|f| *f.get_meta() != DataType::Ignore)
104            .map(|f| self.format_field(f))
105            .collect();
106        format!(
107            "INSERT INTO {} ({}) VALUES ({});",
108            self.quote_identifier(&self.table_name),
109            columns.join(", "),
110            values.join(", ")
111        )
112    }
113    fn format_field(&self, field: &DataField) -> String {
114        if *field.get_meta() == DataType::Ignore {
115            String::new()
116        } else {
117            self.fmt_value(field.get_value())
118        }
119    }
120}
121
122impl SqlInsert {
123    pub fn format_batch(&self, records: &[DataRecord]) -> String {
124        if records.is_empty() {
125            return String::new();
126        }
127        let mut output = String::new();
128        let columns: Vec<String> = records[0]
129            .items
130            .iter()
131            .filter(|f| *f.get_meta() != DataType::Ignore)
132            .map(|f| self.quote_identifier(f.get_name()))
133            .collect();
134        use std::fmt::Write;
135        writeln!(
136            output,
137            "INSERT INTO {} ({}) VALUES",
138            self.quote_identifier(&self.table_name),
139            columns.join(", ")
140        )
141        .unwrap();
142        for (i, record) in records.iter().enumerate() {
143            if i > 0 {
144                output.push_str(",\n");
145            }
146            let values: Vec<String> = record
147                .items
148                .iter()
149                .filter(|f| *f.get_meta() != DataType::Ignore)
150                .map(|f| self.format_field(f))
151                .collect();
152            write!(output, "  ({})", values.join(", ")).unwrap();
153        }
154        output.push(';');
155        output
156    }
157    pub fn generate_create_table(&self, records: &[DataRecord]) -> String {
158        if records.is_empty() {
159            return String::new();
160        }
161        let mut columns = Vec::new();
162        for field in &records[0].items {
163            if *field.get_meta() == DataType::Ignore {
164                continue;
165            }
166            let sql_type = &match field.get_value() {
167                Value::Bool(_) => "BOOLEAN",
168                Value::Chars(_) => "TEXT",
169                Value::Digit(_) => "BIGINT",
170                Value::Float(_) => "DOUBLE PRECISION",
171                Value::Time(_) => "TIMESTAMP",
172                Value::IpAddr(_) => "INET",
173                Value::Obj(_) | Value::Array(_) => "JSONB",
174                _ => "TEXT",
175            };
176            columns.push(format!(
177                "  {} {}",
178                self.quote_identifier(field.get_name()),
179                sql_type
180            ));
181        }
182        format!(
183            "CREATE TABLE IF NOT EXISTS {} (\n{}\n);",
184            self.quote_identifier(&self.table_name),
185            columns.join(",\n")
186        )
187    }
188    pub fn format_upsert(&self, record: &DataRecord, conflict_columns: &[&str]) -> String {
189        let insert = self.format_record(record);
190        let mut update_parts = Vec::new();
191        for field in record
192            .items
193            .iter()
194            .filter(|f| *f.get_meta() != DataType::Ignore)
195        {
196            let name = field.get_name();
197            if !conflict_columns.contains(&name) {
198                let col = self.quote_identifier(name);
199                update_parts.push(format!("{} = EXCLUDED.{}", &col, &col));
200            }
201        }
202        if update_parts.is_empty() {
203            insert
204        } else {
205            let quoted_conflicts: Vec<String> = conflict_columns
206                .iter()
207                .map(|c| self.quote_identifier(c))
208                .collect();
209            format!(
210                "{} ON CONFLICT ({}) DO UPDATE SET {};",
211                insert.trim_end_matches(';'),
212                quoted_conflicts.join(", "),
213                update_parts.join(", ")
214            )
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::formatter::DataFormat;
    use wp_model_core::model::{DataField, DataRecord};

    /// Formatter for the `users` table with default JSON object rendering.
    fn users_table() -> SqlInsert {
        SqlInsert::new_with_json("users")
    }

    /// Wraps a list of fields into a record.
    fn make_record(items: Vec<DataField>) -> DataRecord {
        DataRecord { items }
    }

    #[test]
    fn test_sql_basic() {
        let fmt = SqlInsert {
            table_name: "t".into(),
            quote_identifiers: true,
            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
        };
        let rec = make_record(vec![
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
        ]);
        let out = fmt.format_record(&rec);
        assert!(out.contains("INSERT INTO \"t\" (\"name\", \"age\") VALUES"));
    }

    #[test]
    fn test_sql_default() {
        let fmt = SqlInsert::default();
        assert!(fmt.table_name.is_empty());
        assert!(fmt.quote_identifiers);
    }

    #[test]
    fn test_sql_new_with_json() {
        let fmt = users_table();
        assert_eq!(fmt.table_name, "users");
        assert!(fmt.quote_identifiers);
    }

    #[test]
    fn test_format_null() {
        assert_eq!(SqlInsert::default().format_null(), "NULL");
    }

    #[test]
    fn test_format_bool() {
        let fmt = SqlInsert::default();
        for &(input, expected) in &[(true, "TRUE"), (false, "FALSE")] {
            assert_eq!(fmt.format_bool(&input), expected);
        }
    }

    #[test]
    fn test_format_string() {
        let fmt = SqlInsert::default();
        for &(input, expected) in &[("hello", "'hello'"), ("", "''")] {
            assert_eq!(fmt.format_string(input), expected);
        }
    }

    #[test]
    fn test_format_string_escape() {
        let fmt = SqlInsert::default();
        // Embedded single quotes must come out doubled.
        for &(input, expected) in &[("it's", "'it''s'"), ("say 'hi'", "'say ''hi'''")] {
            assert_eq!(fmt.format_string(input), expected);
        }
    }

    #[test]
    fn test_format_i64() {
        let fmt = SqlInsert::default();
        for &(input, expected) in &[(0i64, "0"), (42, "42"), (-100, "-100")] {
            assert_eq!(fmt.format_i64(&input), expected);
        }
    }

    #[test]
    fn test_format_f64_normal() {
        let fmt = SqlInsert::default();
        assert_eq!(fmt.format_f64(&3.24), "3.24");
        assert_eq!(fmt.format_f64(&0.0), "0");
    }

    #[test]
    fn test_format_f64_special() {
        let fmt = SqlInsert::default();
        let cases = [
            (f64::NAN, "NULL"),
            (f64::INFINITY, "'Infinity'"),
            (f64::NEG_INFINITY, "'-Infinity'"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(fmt.format_f64(&input), expected);
        }
    }

    #[test]
    fn test_format_ip() {
        let fmt = SqlInsert::default();
        let addr: std::net::IpAddr = "192.168.1.1".parse().expect("valid IPv4 literal");
        assert_eq!(fmt.format_ip(&addr), "'192.168.1.1'");
    }

    #[test]
    fn test_format_datetime() {
        let fmt = SqlInsert::default();
        let stamp =
            chrono::NaiveDateTime::parse_from_str("2024-01-15 10:30:45", "%Y-%m-%d %H:%M:%S")
                .expect("valid timestamp literal");
        let rendered = fmt.format_datetime(&stamp);
        assert!(rendered.starts_with('\''));
        assert!(rendered.ends_with('\''));
        assert!(rendered.contains("2024"));
    }

    #[test]
    fn test_quote_identifier() {
        let fmt = SqlInsert::new_with_json("t");
        assert_eq!(fmt.quote_identifier("name"), "\"name\"");
        assert_eq!(fmt.quote_identifier("user_id"), "\"user_id\"");
    }

    #[test]
    fn test_quote_identifier_escape() {
        let fmt = SqlInsert::new_with_json("t");
        // An embedded double quote must come out doubled inside the quotes.
        assert_eq!(fmt.quote_identifier("col\"name"), "\"col\"\"name\"");
    }

    #[test]
    fn test_quote_identifier_disabled() {
        let fmt = SqlInsert {
            table_name: "t".into(),
            quote_identifiers: false,
            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
        };
        assert_eq!(fmt.quote_identifier("name"), "name");
    }

    #[test]
    fn test_format_record() {
        let rec = make_record(vec![
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
            DataField::from_bool("active", true),
        ]);
        let out = users_table().format_record(&rec);
        assert!(out.starts_with("INSERT INTO \"users\""));
        assert!(out.contains("(\"name\", \"age\", \"active\")"));
        assert!(out.contains("VALUES ('Alice', 30, TRUE)"));
        assert!(out.ends_with(';'));
    }

    #[test]
    fn test_format_batch_empty() {
        let none: Vec<DataRecord> = Vec::new();
        assert_eq!(users_table().format_batch(&none), "");
    }

    #[test]
    fn test_format_batch() {
        let batch = vec![
            make_record(vec![
                DataField::from_chars("name", "Alice"),
                DataField::from_digit("age", 30),
            ]),
            make_record(vec![
                DataField::from_chars("name", "Bob"),
                DataField::from_digit("age", 25),
            ]),
        ];
        let out = users_table().format_batch(&batch);
        assert!(out.contains("INSERT INTO \"users\""));
        assert!(out.contains("('Alice', 30)"));
        assert!(out.contains("('Bob', 25)"));
        assert!(out.ends_with(';'));
    }

    #[test]
    fn test_generate_create_table_empty() {
        let none: Vec<DataRecord> = Vec::new();
        assert_eq!(users_table().generate_create_table(&none), "");
    }

    #[test]
    fn test_generate_create_table() {
        let batch = vec![make_record(vec![
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
            DataField::from_bool("active", true),
            DataField::from_float("score", 95.5),
        ])];
        let ddl = users_table().generate_create_table(&batch);
        let expected_fragments = [
            "CREATE TABLE IF NOT EXISTS \"users\"",
            "\"name\" TEXT",
            "\"age\" BIGINT",
            "\"active\" BOOLEAN",
            "\"score\" DOUBLE PRECISION",
        ];
        for fragment in expected_fragments.iter() {
            assert!(ddl.contains(fragment));
        }
    }

    #[test]
    fn test_format_upsert() {
        let rec = make_record(vec![
            DataField::from_chars("id", "u1"),
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
        ]);
        let out = users_table().format_upsert(&rec, &["id"]);
        let expected_fragments = [
            "INSERT INTO \"users\"",
            "ON CONFLICT (\"id\")",
            "DO UPDATE SET",
            "\"name\" = EXCLUDED.\"name\"",
            "\"age\" = EXCLUDED.\"age\"",
        ];
        for fragment in expected_fragments.iter() {
            assert!(out.contains(fragment));
        }
    }

    #[test]
    fn test_format_upsert_no_update_columns() {
        let rec = make_record(vec![DataField::from_chars("id", "u1")]);
        // Every column is a conflict column, so there is nothing to update and
        // the output should be the plain insert.
        let out = users_table().format_upsert(&rec, &["id"]);
        assert!(out.contains("INSERT INTO"));
        assert!(!out.contains("ON CONFLICT"));
    }
}