Skip to main content

d1_orm/
builder.rs

1use crate::error::Error;
2use crate::traits::{FieldMeta, FieldUpdate};
3use std::borrow::Cow;
4use std::fmt::Write;
5
6pub fn build_update_sql<T: FieldMeta + FieldUpdate>(
7    table: &str,
8    key_field: &str,
9    updates: &[T],
10) -> Result<Cow<'static, str>, Error> {
11    let valid = updates.iter().filter(|u| !u.is_primary_key());
12    let count = valid.clone().count();
13    if count == 0 {
14        return Err(Error::Build("Empty update fields".to_string()));
15    }
16
17    let mut sql = String::with_capacity(64 + table.len() + key_field.len() + count * 40);
18    write!(sql, "UPDATE {} SET ", table).unwrap();
19
20    for (i, u) in valid.enumerate() {
21        if i > 0 {
22            sql.push_str(", ");
23        }
24        write!(sql, "{} = ?", u.field()).unwrap();
25    }
26    write!(sql, " WHERE {} = ?", key_field).unwrap();
27
28    Ok(Cow::Owned(sql))
29}
30
/// Per-field override consulted while building the `DO UPDATE SET` clause of an upsert.
///
/// The callback receives the name of a non-primary-key field. Returning `Some(sql)` makes the
/// builder emit `sql` verbatim as that field's assignment (for example
/// `"hits = hits + excluded.hits"`); returning `None` keeps the default
/// `field = excluded.field`. The returned text is inserted without escaping, so it must never
/// contain untrusted input.
pub type ConflictResolution<'a> = dyn for<'f> Fn(&'f str) -> Option<&'static str> + 'a;
32
33pub struct UpsertConfig<'a> {
34    pub table: &'a str,
35    pub primary_keys: &'a [&'a str],
36    pub custom_conflict_resolution: Option<&'a ConflictResolution<'a>>,
37}
38
39pub fn build_upsert_sql<T: FieldMeta + FieldUpdate>(
40    config: &UpsertConfig,
41    updates: &[T],
42) -> Result<Cow<'static, str>, Error> {
43    let valid = updates.iter().filter(|u| !u.is_primary_key());
44    let valid_count = valid.clone().count();
45    let pk_count = config.primary_keys.len();
46
47    // Must have at least one field to insert (PK or update field)
48    if pk_count == 0 && valid_count == 0 {
49        return Err(Error::Build("Empty fields for upsert".to_string()));
50    }
51
52    use std::fmt::Write;
53    let mut sql = String::with_capacity(128 + config.table.len() + (pk_count + valid_count) * 40);
54    write!(sql, "INSERT INTO {} (", config.table).unwrap();
55
56    // Columns: PKs then Update fields
57    let mut first = true;
58    for pk in config.primary_keys.iter() {
59        if !first {
60            sql.push_str(", ");
61        }
62        sql.push_str(pk);
63        first = false;
64    }
65    for u in valid.clone() {
66        if !first {
67            sql.push_str(", ");
68        }
69        write!(sql, "{}", u.field()).unwrap();
70        first = false;
71    }
72
73    sql.push_str(") VALUES (");
74
75    // Placeholders
76    first = true;
77    for _ in 0..pk_count {
78        if !first {
79            sql.push_str(", ");
80        }
81        sql.push('?');
82        first = false;
83    }
84    for _ in 0..valid_count {
85        if !first {
86            sql.push_str(", ");
87        }
88        sql.push('?');
89        first = false;
90    }
91    sql.push(')');
92
93    // ON CONFLICT clause
94    if pk_count > 0 {
95        sql.push_str(" ON CONFLICT(");
96        first = true;
97        for pk in config.primary_keys.iter() {
98            if !first {
99                sql.push_str(", ");
100            }
101            sql.push_str(pk);
102            first = false;
103        }
104        if valid_count > 0 {
105            sql.push_str(") DO UPDATE SET ");
106
107            first = true;
108            for u in valid {
109                if !first {
110                    sql.push_str(", ");
111                }
112                let f = u.field();
113
114                let mut resolved = false;
115                if let Some(custom_sql) = config.custom_conflict_resolution.and_then(|cf| cf(f)) {
116                    sql.push_str(custom_sql);
117                    resolved = true;
118                }
119                if !resolved {
120                    write!(sql, "{} = excluded.{}", f, f).unwrap();
121                }
122                first = false;
123            }
124        } else {
125            sql.push_str(") DO NOTHING");
126        }
127    }
128
129    Ok(Cow::Owned(sql))
130}