pg_upsert/
types.rs

1use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
2use sqlx::Arguments;
3use sqlx::postgres::PgArguments;
4
5#[derive(Debug, Clone)]
6pub enum FieldValue {
7    Int32(i32),
8    Int64(i64),
9    Float32(f32),
10    Float64(f64),
11    Bool(bool),
12    String(String),
13    Numeric(String),
14    Bytes(Vec<u8>),
15    Date(NaiveDate),
16    Time(NaiveTime),
17    DateTime(NaiveDateTime),
18    DateTimeUtc(DateTime<Utc>),
19    Null,
20}
21
22impl FieldValue {
23    pub fn bind_to(&self, args: &mut PgArguments) {
24        match self {
25            FieldValue::Int32(v) => args.add(v).unwrap(),
26            FieldValue::Int64(v) => args.add(v).unwrap(),
27            FieldValue::Float32(v) => args.add(v).unwrap(),
28            FieldValue::Float64(v) => args.add(v).unwrap(),
29            FieldValue::Bool(v) => args.add(v).unwrap(),
30            FieldValue::String(v) => args.add(v).unwrap(),
31            FieldValue::Numeric(v) => args.add(v).unwrap(),
32            FieldValue::Bytes(v) => args.add(v).unwrap(),
33            FieldValue::Date(v) => args.add(v).unwrap(),
34            FieldValue::Time(v) => args.add(v).unwrap(),
35            FieldValue::DateTime(v) => args.add(v).unwrap(),
36            FieldValue::DateTimeUtc(v) => args.add(v).unwrap(),
37            FieldValue::Null => args.add(None::<i32>).unwrap(),
38        }
39    }
40}
41
42impl From<i32> for FieldValue {
43    fn from(v: i32) -> Self {
44        FieldValue::Int32(v)
45    }
46}
47
48impl From<i64> for FieldValue {
49    fn from(v: i64) -> Self {
50        FieldValue::Int64(v)
51    }
52}
53
54impl From<f32> for FieldValue {
55    fn from(v: f32) -> Self {
56        FieldValue::Float32(v)
57    }
58}
59
60impl From<f64> for FieldValue {
61    fn from(v: f64) -> Self {
62        FieldValue::Float64(v)
63    }
64}
65
66impl From<bool> for FieldValue {
67    fn from(v: bool) -> Self {
68        FieldValue::Bool(v)
69    }
70}
71
72impl From<String> for FieldValue {
73    fn from(v: String) -> Self {
74        FieldValue::String(v)
75    }
76}
77
78impl From<&str> for FieldValue {
79    fn from(v: &str) -> Self {
80        FieldValue::String(v.to_owned())
81    }
82}
83
84impl From<Vec<u8>> for FieldValue {
85    fn from(v: Vec<u8>) -> Self {
86        FieldValue::Bytes(v)
87    }
88}
89
90impl From<NaiveDate> for FieldValue {
91    fn from(v: NaiveDate) -> Self {
92        FieldValue::Date(v)
93    }
94}
95
96impl From<NaiveTime> for FieldValue {
97    fn from(v: NaiveTime) -> Self {
98        FieldValue::Time(v)
99    }
100}
101
102impl From<NaiveDateTime> for FieldValue {
103    fn from(v: NaiveDateTime) -> Self {
104        FieldValue::DateTime(v)
105    }
106}
107
108impl From<DateTime<Utc>> for FieldValue {
109    fn from(v: DateTime<Utc>) -> Self {
110        FieldValue::DateTimeUtc(v)
111    }
112}
113
114impl<T: Into<FieldValue>> From<Option<T>> for FieldValue {
115    fn from(v: Option<T>) -> Self {
116        match v {
117            Some(val) => val.into(),
118            None => FieldValue::Null,
119        }
120    }
121}
122
123#[derive(Debug, Clone)]
124pub struct Field {
125    pub name: String,
126    pub value: FieldValue,
127}
128
129impl Field {
130    pub fn new(name: impl Into<String>, value: impl Into<FieldValue>) -> Self {
131        Self {
132            name: name.into(),
133            value: value.into(),
134        }
135    }
136}
137
/// Options for an upsert, built fluently via the `with_*` methods.
#[derive(Clone, Debug, Default)]
pub struct UpsertOptions {
    /// Optional name of a version field; `None` by default.
    // NOTE(review): how the version field is consumed is not visible in this
    // file — confirm against the query builder.
    pub version_field: Option<String>,
    /// Conflict-handling flag; `false` by default.
    // NOTE(review): presumably selects "do nothing" over "update" when a
    // conflict occurs — confirm against the query builder.
    pub do_nothing_on_conflict: bool,
}

impl UpsertOptions {
    /// Returns the default options: no version field and the
    /// do-nothing-on-conflict flag cleared.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the options with `version_field` set to `field`, leaving every
    /// other setting as it was.
    pub fn with_version_field(self, field: impl Into<String>) -> Self {
        Self {
            version_field: Some(field.into()),
            ..self
        }
    }

    /// Returns the options with `do_nothing_on_conflict` set to `do_nothing`,
    /// leaving every other setting as it was.
    pub fn with_do_nothing_on_conflict(self, do_nothing: bool) -> Self {
        Self {
            do_nothing_on_conflict: do_nothing,
            ..self
        }
    }
}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// A `NaiveDate` converts into `FieldValue::Date` carrying the same date.
    #[test]
    fn test_date_field_conversion() {
        let date = NaiveDate::from_ymd_opt(2025, 12, 26).unwrap();
        let field_value: FieldValue = date.into();
        assert!(matches!(field_value, FieldValue::Date(d) if d == date));
    }

    /// A `NaiveTime` converts into `FieldValue::Time` carrying the same time.
    #[test]
    fn test_time_field_conversion() {
        let time = NaiveTime::from_hms_opt(14, 30, 0).unwrap();
        let field_value: FieldValue = time.into();
        assert!(matches!(field_value, FieldValue::Time(t) if t == time));
    }

    /// A `NaiveDateTime` converts into `FieldValue::DateTime` carrying the
    /// same timestamp.
    #[test]
    fn test_datetime_field_conversion() {
        let datetime = NaiveDate::from_ymd_opt(2025, 12, 26)
            .unwrap()
            .and_hms_opt(14, 30, 0)
            .unwrap();
        let field_value: FieldValue = datetime.into();
        assert!(matches!(field_value, FieldValue::DateTime(dt) if dt == datetime));
    }

    /// A `DateTime<Utc>` converts into `FieldValue::DateTimeUtc` carrying the
    /// same instant.
    #[test]
    fn test_datetime_utc_field_conversion() {
        let datetime = DateTime::<Utc>::from_timestamp(1735225800, 0).unwrap();
        let field_value: FieldValue = datetime.into();
        assert!(matches!(field_value, FieldValue::DateTimeUtc(dt) if dt == datetime));
    }

    /// `Some(date)` unwraps into `FieldValue::Date`; `None` becomes
    /// `FieldValue::Null`.
    #[test]
    fn test_optional_date_field() {
        let date = NaiveDate::from_ymd_opt(2025, 12, 26).unwrap();
        let some_date: Option<NaiveDate> = Some(date);
        let field_value: FieldValue = some_date.into();
        assert!(matches!(field_value, FieldValue::Date(d) if d == date));

        let none_date: Option<NaiveDate> = None;
        let field_value: FieldValue = none_date.into();
        assert!(matches!(field_value, FieldValue::Null));
    }

    /// `FieldValue::Numeric` keeps the numeric text unchanged.
    #[test]
    fn test_numeric_field_from_string() {
        let field_value = FieldValue::Numeric(String::from("123.456789"));
        // A `match` with a panicking fallback, so the test cannot pass
        // vacuously when the variant is not `Numeric`.
        match field_value {
            FieldValue::Numeric(v) => assert_eq!(v, "123.456789"),
            other => panic!("expected FieldValue::Numeric, got {other:?}"),
        }
    }

    /// `FieldValue::Numeric` keeps every digit of a high-precision value,
    /// because the value is held as text.
    #[test]
    fn test_numeric_field_high_precision() {
        let high_precision = "99999999999999999999999999.999999999999";
        let field_value = FieldValue::Numeric(high_precision.to_string());
        // Same guard as above: fail loudly on an unexpected variant.
        match field_value {
            FieldValue::Numeric(v) => assert_eq!(v, high_precision),
            other => panic!("expected FieldValue::Numeric, got {other:?}"),
        }
    }
}