Skip to main content

wp_data_fmt/
sql.rs

1#[allow(deprecated)]
2use crate::formatter::DataFormat;
3use crate::formatter::{RecordFormatter, ValueFormatter};
4use wp_model_core::model::fmt_def::TextFmt;
5use wp_model_core::model::{
6    DataRecord, DataType, FieldStorage, Value, data::record::RecordItem, types::value::ObjectValue,
7};
8
9pub struct SqlInsert {
10    pub table_name: String,
11    pub quote_identifiers: bool,
12    pub obj_formatter: crate::SqlFormat,
13}
14
15impl Default for SqlInsert {
16    fn default() -> Self {
17        Self {
18            table_name: String::new(),
19            quote_identifiers: true,
20            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
21        }
22    }
23}
24
25impl SqlInsert {
26    pub fn new_with_json<T: Into<String>>(table: T) -> Self {
27        Self {
28            table_name: table.into(),
29            quote_identifiers: true,
30            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
31        }
32    }
33    fn quote_identifier(&self, name: &str) -> String {
34        if self.quote_identifiers {
35            let escaped = name.replace('"', "\"\"");
36            format!("\"{}\"", escaped)
37        } else {
38            name.to_string()
39        }
40    }
41    fn escape_string(&self, value: &str) -> String {
42        value.replace('\'', "''")
43    }
44}
45
46#[allow(deprecated)]
47impl DataFormat for SqlInsert {
48    type Output = String;
49    fn format_null(&self) -> String {
50        "NULL".to_string()
51    }
52    fn format_bool(&self, value: &bool) -> String {
53        if *value { "TRUE" } else { "FALSE" }.to_string()
54    }
55    fn format_string(&self, value: &str) -> String {
56        format!("'{}'", self.escape_string(value))
57    }
58    fn format_i64(&self, value: &i64) -> String {
59        value.to_string()
60    }
61    fn format_f64(&self, value: &f64) -> String {
62        if value.is_nan() {
63            "NULL".into()
64        } else if value.is_infinite() {
65            if value.is_sign_positive() {
66                "'Infinity'".into()
67            } else {
68                "'-Infinity'".into()
69            }
70        } else {
71            value.to_string()
72        }
73    }
74    fn format_ip(&self, value: &std::net::IpAddr) -> String {
75        self.format_string(&value.to_string())
76    }
77    fn format_datetime(&self, value: &chrono::NaiveDateTime) -> String {
78        self.format_string(&value.to_string())
79    }
80    fn format_object(&self, value: &ObjectValue) -> String {
81        let inner = match &self.obj_formatter {
82            crate::SqlFormat::Json(f) => f.format_object(value),
83            crate::SqlFormat::Kv(f) => f.format_object(value),
84            crate::SqlFormat::Raw(f) => f.format_object(value),
85            crate::SqlFormat::ProtoText(f) => f.format_object(value),
86        };
87        format!("'{}'", self.escape_string(&inner))
88    }
89    fn format_array(&self, value: &[FieldStorage]) -> String {
90        let inner = match &self.obj_formatter {
91            crate::SqlFormat::Json(f) => f.format_array(value),
92            crate::SqlFormat::Kv(f) => f.format_array(value),
93            crate::SqlFormat::Raw(f) => f.format_array(value),
94            crate::SqlFormat::ProtoText(f) => f.format_array(value),
95        };
96        format!("'{}'", self.escape_string(&inner))
97    }
98    fn format_record(&self, record: &DataRecord) -> String {
99        let columns: Vec<String> = record
100            .items
101            .iter()
102            .filter(|f| *f.get_meta() != DataType::Ignore)
103            .map(|f| self.quote_identifier(f.get_name()))
104            .collect();
105        let values: Vec<String> = record
106            .items
107            .iter()
108            .filter(|f| *f.get_meta() != DataType::Ignore)
109            .map(|f| self.format_field(f))
110            .collect();
111        format!(
112            "INSERT INTO {} ({}) VALUES ({});",
113            self.quote_identifier(&self.table_name),
114            columns.join(", "),
115            values.join(", ")
116        )
117    }
118    fn format_field(&self, field: &FieldStorage) -> String {
119        if *field.get_meta() == DataType::Ignore {
120            String::new()
121        } else {
122            self.fmt_value(field.get_value())
123        }
124    }
125}
126
127impl SqlInsert {
128    #[allow(deprecated)]
129    pub fn format_batch(&self, records: &[DataRecord]) -> String {
130        if records.is_empty() {
131            return String::new();
132        }
133        let mut output = String::new();
134        let columns: Vec<String> = records[0]
135            .items
136            .iter()
137            .filter(|f| *f.get_meta() != DataType::Ignore)
138            .map(|f| self.quote_identifier(f.get_name()))
139            .collect();
140        use std::fmt::Write;
141        writeln!(
142            output,
143            "INSERT INTO {} ({}) VALUES",
144            self.quote_identifier(&self.table_name),
145            columns.join(", ")
146        )
147        .unwrap();
148        for (i, record) in records.iter().enumerate() {
149            if i > 0 {
150                output.push_str(",\n");
151            }
152            let values: Vec<String> = record
153                .items
154                .iter()
155                .filter(|f| *f.get_meta() != DataType::Ignore)
156                .map(|f| self.format_field(f))
157                .collect();
158            write!(output, "  ({})", values.join(", ")).unwrap();
159        }
160        output.push(';');
161        output
162    }
163    pub fn generate_create_table(&self, records: &[DataRecord]) -> String {
164        if records.is_empty() {
165            return String::new();
166        }
167        let mut columns = Vec::new();
168        for field in &records[0].items {
169            if *field.get_meta() == DataType::Ignore {
170                continue;
171            }
172            let sql_type = &match field.get_value() {
173                Value::Bool(_) => "BOOLEAN",
174                Value::Chars(_) => "TEXT",
175                Value::Digit(_) => "BIGINT",
176                Value::Float(_) => "DOUBLE PRECISION",
177                Value::Time(_) => "TIMESTAMP",
178                Value::IpAddr(_) => "INET",
179                Value::Obj(_) | Value::Array(_) => "JSONB",
180                _ => "TEXT",
181            };
182            columns.push(format!(
183                "  {} {}",
184                self.quote_identifier(field.get_name()),
185                sql_type
186            ));
187        }
188        format!(
189            "CREATE TABLE IF NOT EXISTS {} (\n{}\n);",
190            self.quote_identifier(&self.table_name),
191            columns.join(",\n")
192        )
193    }
194    #[allow(deprecated)]
195    pub fn format_upsert(&self, record: &DataRecord, conflict_columns: &[&str]) -> String {
196        let insert = self.format_record(record);
197        let mut update_parts = Vec::new();
198        for field in record
199            .items
200            .iter()
201            .filter(|f| *f.get_meta() != DataType::Ignore)
202        {
203            let name = field.get_name();
204            if !conflict_columns.contains(&name) {
205                let col = self.quote_identifier(name);
206                update_parts.push(format!("{} = EXCLUDED.{}", &col, &col));
207            }
208        }
209        if update_parts.is_empty() {
210            insert
211        } else {
212            let quoted_conflicts: Vec<String> = conflict_columns
213                .iter()
214                .map(|c| self.quote_identifier(c))
215                .collect();
216            format!(
217                "{} ON CONFLICT ({}) DO UPDATE SET {};",
218                insert.trim_end_matches(';'),
219                quoted_conflicts.join(", "),
220                update_parts.join(", ")
221            )
222        }
223    }
224}
225
#[cfg(test)]
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::formatter::DataFormat;
    use wp_model_core::model::{DataField, DataRecord};

    /// Wraps owned fields into a record with a default id.
    fn record_of(fields: Vec<DataField>) -> DataRecord {
        DataRecord {
            id: Default::default(),
            items: fields.into_iter().map(FieldStorage::Owned).collect(),
        }
    }

    /// Formatter targeting the `users` table with default settings.
    fn users() -> SqlInsert {
        SqlInsert::new_with_json("users")
    }

    #[test]
    fn test_sql_basic() {
        let formatter = SqlInsert {
            table_name: "t".into(),
            quote_identifiers: true,
            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
        };
        let record = record_of(vec![
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
        ]);
        let statement = formatter.format_record(&record);
        assert!(statement.contains("INSERT INTO \"t\" (\"name\", \"age\") VALUES"));
    }

    #[test]
    fn test_sql_default() {
        let formatter = SqlInsert::default();
        assert_eq!(formatter.table_name, "");
        assert!(formatter.quote_identifiers);
    }

    #[test]
    fn test_sql_new_with_json() {
        let formatter = users();
        assert_eq!(formatter.table_name, "users");
        assert!(formatter.quote_identifiers);
    }

    #[test]
    fn test_format_null() {
        let formatter = SqlInsert::default();
        assert_eq!(formatter.format_value(&Value::Null), "NULL");
    }

    #[test]
    fn test_format_bool() {
        let formatter = SqlInsert::default();
        for (input, expected) in [(true, "TRUE"), (false, "FALSE")] {
            assert_eq!(formatter.format_value(&Value::Bool(input)), expected);
        }
    }

    #[test]
    fn test_format_string() {
        let formatter = SqlInsert::default();
        for (input, expected) in [("hello", "'hello'"), ("", "''")] {
            assert_eq!(formatter.format_value(&Value::Chars(input.into())), expected);
        }
    }

    #[test]
    fn test_format_string_escape() {
        let formatter = SqlInsert::default();
        // Embedded single quotes must come out doubled.
        for (input, expected) in [("it's", "'it''s'"), ("say 'hi'", "'say ''hi'''")] {
            assert_eq!(formatter.format_value(&Value::Chars(input.into())), expected);
        }
    }

    #[test]
    fn test_format_i64() {
        let formatter = SqlInsert::default();
        for (input, expected) in [(0, "0"), (42, "42"), (-100, "-100")] {
            assert_eq!(formatter.format_value(&Value::Digit(input)), expected);
        }
    }

    #[test]
    fn test_format_f64_normal() {
        let formatter = SqlInsert::default();
        for (input, expected) in [(3.24, "3.24"), (0.0, "0")] {
            assert_eq!(formatter.format_value(&Value::Float(input)), expected);
        }
    }

    #[test]
    fn test_format_f64_special() {
        let formatter = SqlInsert::default();
        let cases = [
            (f64::NAN, "NULL"),
            (f64::INFINITY, "'Infinity'"),
            (f64::NEG_INFINITY, "'-Infinity'"),
        ];
        for (input, expected) in cases {
            assert_eq!(formatter.format_value(&Value::Float(input)), expected);
        }
    }

    #[test]
    fn test_format_ip() {
        use std::net::IpAddr;
        use std::str::FromStr;
        let formatter = SqlInsert::default();
        let addr = IpAddr::from_str("192.168.1.1").unwrap();
        assert_eq!(formatter.format_value(&Value::IpAddr(addr)), "'192.168.1.1'");
    }

    #[test]
    fn test_format_datetime() {
        let formatter = SqlInsert::default();
        let timestamp =
            chrono::NaiveDateTime::parse_from_str("2024-01-15 10:30:45", "%Y-%m-%d %H:%M:%S")
                .unwrap();
        let rendered = formatter.format_value(&Value::Time(timestamp));
        assert!(rendered.starts_with('\''));
        assert!(rendered.ends_with('\''));
        assert!(rendered.contains("2024"));
    }

    #[test]
    fn test_quote_identifier() {
        let formatter = SqlInsert::new_with_json("t");
        assert_eq!(formatter.quote_identifier("name"), "\"name\"");
        assert_eq!(formatter.quote_identifier("user_id"), "\"user_id\"");
    }

    #[test]
    fn test_quote_identifier_escape() {
        let formatter = SqlInsert::new_with_json("t");
        // Embedded double quotes must come out doubled.
        assert_eq!(formatter.quote_identifier("col\"name"), "\"col\"\"name\"");
    }

    #[test]
    fn test_quote_identifier_disabled() {
        let formatter = SqlInsert {
            table_name: "t".into(),
            quote_identifiers: false,
            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
        };
        assert_eq!(formatter.quote_identifier("name"), "name");
    }

    #[test]
    fn test_format_record() {
        let record = record_of(vec![
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
            DataField::from_bool("active", true),
        ]);
        let statement = users().fmt_record(&record);
        assert!(statement.starts_with("INSERT INTO \"users\""));
        assert!(statement.contains("(\"name\", \"age\", \"active\")"));
        assert!(statement.contains("VALUES ('Alice', 30, TRUE)"));
        assert!(statement.ends_with(';'));
    }

    #[test]
    fn test_format_batch_empty() {
        let no_records: Vec<DataRecord> = Vec::new();
        assert_eq!(users().format_batch(&no_records), "");
    }

    #[test]
    fn test_format_batch() {
        let records = vec![
            record_of(vec![
                DataField::from_chars("name", "Alice"),
                DataField::from_digit("age", 30),
            ]),
            record_of(vec![
                DataField::from_chars("name", "Bob"),
                DataField::from_digit("age", 25),
            ]),
        ];
        let statement = users().format_batch(&records);
        assert!(statement.contains("INSERT INTO \"users\""));
        assert!(statement.contains("('Alice', 30)"));
        assert!(statement.contains("('Bob', 25)"));
        assert!(statement.ends_with(';'));
    }

    #[test]
    fn test_generate_create_table_empty() {
        let no_records: Vec<DataRecord> = Vec::new();
        assert_eq!(users().generate_create_table(&no_records), "");
    }

    #[test]
    fn test_generate_create_table() {
        let records = vec![record_of(vec![
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
            DataField::from_bool("active", true),
            DataField::from_float("score", 95.5),
        ])];
        let ddl = users().generate_create_table(&records);
        assert!(ddl.contains("CREATE TABLE IF NOT EXISTS \"users\""));
        assert!(ddl.contains("\"name\" TEXT"));
        assert!(ddl.contains("\"age\" BIGINT"));
        assert!(ddl.contains("\"active\" BOOLEAN"));
        assert!(ddl.contains("\"score\" DOUBLE PRECISION"));
    }

    #[test]
    fn test_format_upsert() {
        let record = record_of(vec![
            DataField::from_chars("id", "u1"),
            DataField::from_chars("name", "Alice"),
            DataField::from_digit("age", 30),
        ]);
        let statement = users().format_upsert(&record, &["id"]);
        assert!(statement.contains("INSERT INTO \"users\""));
        assert!(statement.contains("ON CONFLICT (\"id\")"));
        assert!(statement.contains("DO UPDATE SET"));
        assert!(statement.contains("\"name\" = EXCLUDED.\"name\""));
        assert!(statement.contains("\"age\" = EXCLUDED.\"age\""));
    }

    #[test]
    fn test_format_upsert_no_update_columns() {
        let record = record_of(vec![DataField::from_chars("id", "u1")]);
        // Every column is a conflict column, so there is nothing to update
        // and a plain insert is expected.
        let statement = users().format_upsert(&record, &["id"]);
        assert!(statement.contains("INSERT INTO"));
        assert!(!statement.contains("ON CONFLICT"));
    }
}
476
477// ============================================================================
// New trait implementations: ValueFormatter + RecordFormatter
479// ============================================================================
480
481#[allow(clippy::items_after_test_module)]
482impl ValueFormatter for SqlInsert {
483    type Output = String;
484
485    fn format_value(&self, value: &Value) -> String {
486        match value {
487            Value::Null => "NULL".to_string(),
488            Value::Bool(v) => if *v { "TRUE" } else { "FALSE" }.to_string(),
489            Value::Chars(v) => format!("'{}'", self.escape_string(v)),
490            Value::Digit(v) => v.to_string(),
491            Value::Float(v) => {
492                if v.is_nan() {
493                    "NULL".into()
494                } else if v.is_infinite() {
495                    if v.is_sign_positive() {
496                        "'Infinity'".into()
497                    } else {
498                        "'-Infinity'".into()
499                    }
500                } else {
501                    v.to_string()
502                }
503            }
504            Value::IpAddr(v) => format!("'{}'", self.escape_string(&v.to_string())),
505            Value::Time(v) => format!("'{}'", self.escape_string(&v.to_string())),
506            Value::Obj(_obj) => {
507                let inner = match &self.obj_formatter {
508                    crate::SqlFormat::Json(f) => f.format_value(value),
509                    crate::SqlFormat::Kv(f) => f.format_value(value),
510                    crate::SqlFormat::Raw(f) => f.format_value(value),
511                    crate::SqlFormat::ProtoText(f) => f.format_value(value),
512                };
513                format!("'{}'", self.escape_string(&inner))
514            }
515            Value::Array(arr) => {
516                let inner = match &self.obj_formatter {
517                    crate::SqlFormat::Json(f) => {
518                        let items: Vec<String> = arr
519                            .iter()
520                            .map(|field| f.format_value(field.get_value()))
521                            .collect();
522                        format!("[{}]", items.join(","))
523                    }
524                    crate::SqlFormat::Kv(f) => {
525                        let mut output = String::new();
526                        output.push('[');
527                        for (i, field) in arr.iter().enumerate() {
528                            if i > 0 {
529                                output.push_str(", ");
530                            }
531                            output.push_str(&f.format_value(field.get_value()));
532                        }
533                        output.push(']');
534                        output
535                    }
536                    crate::SqlFormat::Raw(f) => {
537                        if arr.is_empty() {
538                            "[]".to_string()
539                        } else {
540                            let content: Vec<String> = arr
541                                .iter()
542                                .map(|field| f.format_value(field.get_value()))
543                                .collect();
544                            format!("[{}]", content.join(", "))
545                        }
546                    }
547                    crate::SqlFormat::ProtoText(f) => {
548                        let items: Vec<String> = arr
549                            .iter()
550                            .map(|field| f.format_value(field.get_value()))
551                            .collect();
552                        format!("[{}]", items.join(", "))
553                    }
554                };
555                format!("'{}'", self.escape_string(&inner))
556            }
557            _ => format!("'{}'", self.escape_string(&value.to_string())),
558        }
559    }
560}
561
562impl RecordFormatter for SqlInsert {
563    fn fmt_field(&self, field: &FieldStorage) -> String {
564        if *field.get_meta() == DataType::Ignore {
565            String::new()
566        } else {
567            self.format_value(field.get_value())
568        }
569    }
570
571    fn fmt_record(&self, record: &DataRecord) -> String {
572        let columns: Vec<String> = record
573            .items
574            .iter()
575            .filter(|f| *f.get_meta() != DataType::Ignore)
576            .map(|f| self.quote_identifier(f.get_name()))
577            .collect();
578        let values: Vec<String> = record
579            .items
580            .iter()
581            .filter(|f| *f.get_meta() != DataType::Ignore)
582            .map(|f| self.fmt_field(f))
583            .collect();
584        format!(
585            "INSERT INTO {} ({}) VALUES ({});",
586            self.quote_identifier(&self.table_name),
587            columns.join(", "),
588            values.join(", ")
589        )
590    }
591}