use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
use sqlx::any::AnyArguments;
use sqlx::Arguments;
use uuid::Uuid;
use crate::{database::Drivers, temporal, Error};
/// Appends textual or already-typed values to a query argument list.
///
/// The textual entry point, [`ValueBinder::bind_value`], chooses a typed
/// representation from an SQL type name; the remaining methods bind a value of a
/// known Rust type. Methods taking a `driver` let implementations pick a
/// driver-specific representation.
pub trait ValueBinder {
    /// Parses `value_str` as the type named by `sql_type` and binds the result.
    ///
    /// # Errors
    ///
    /// Implementations return an [`Error`] when `value_str` cannot be converted to
    /// `sql_type`.
    fn bind_value(&mut self, value_str: &str, sql_type: &str, driver: &Drivers) -> Result<(), Error>;

    /// Binds a 32-bit signed integer.
    fn bind_i32(&mut self, value: i32);

    /// Binds a 64-bit signed integer.
    fn bind_i64(&mut self, value: i64);

    /// Binds a boolean.
    fn bind_bool(&mut self, value: bool);

    /// Binds a 64-bit floating-point number.
    fn bind_f64(&mut self, value: f64);

    /// Binds an owned string.
    fn bind_string(&mut self, value: String);

    /// Binds a UUID for the given `driver`.
    fn bind_uuid(&mut self, value: Uuid, driver: &Drivers);

    /// Binds a UTC timestamp for the given `driver`.
    fn bind_datetime_utc(&mut self, value: DateTime<Utc>, driver: &Drivers);

    /// Binds a timestamp carrying a fixed UTC offset for the given `driver`.
    fn bind_datetime_fixed(&mut self, value: DateTime<chrono::FixedOffset>, driver: &Drivers);

    /// Binds a timezone-less date-time for the given `driver`.
    fn bind_naive_datetime(&mut self, value: NaiveDateTime, driver: &Drivers);

    /// Binds a calendar date for the given `driver`.
    fn bind_naive_date(&mut self, value: NaiveDate, driver: &Drivers);

    /// Binds a time of day for the given `driver`.
    fn bind_naive_time(&mut self, value: NaiveTime, driver: &Drivers);
}
impl ValueBinder for AnyArguments<'_> {
fn bind_value(&mut self, value_str: &str, sql_type: &str, driver: &Drivers) -> Result<(), Error> {
match sql_type {
"INTEGER" | "INT" | "SERIAL" | "serial" | "int4" => {
if let Ok(val) = value_str.parse::<i32>() {
self.bind_i32(val);
} else if let Ok(val) = value_str.parse::<u32>() {
self.bind_i64(val as i64); } else {
return Err(Error::Conversion(format!("Failed to parse integer: {}", value_str)));
}
Ok(())
}
"BIGINT" | "INT8" | "int8" | "BIGSERIAL" => {
if let Ok(val) = value_str.parse::<i64>() {
self.bind_i64(val);
} else if let Ok(_val) = value_str.parse::<u64>() {
let val = value_str
.parse::<i64>()
.map_err(|e| Error::Conversion(format!("Failed to parse i64: {}", e)))?;
self.bind_i64(val);
} else {
return Err(Error::Conversion(format!("Failed to parse i64: {}", value_str)));
}
Ok(())
}
"SMALLINT" | "INT2" | "int2" => {
let val: i16 =
value_str.parse().map_err(|e| Error::Conversion(format!("Failed to parse i16: {}", e)))?;
let _ = self.add(val);
Ok(())
}
"BOOLEAN" | "BOOL" | "bool" => {
let val: bool =
value_str.parse().map_err(|e| Error::Conversion(format!("Failed to parse bool: {}", e)))?;
self.bind_bool(val);
Ok(())
}
"DOUBLE PRECISION" | "FLOAT" | "float8" | "NUMERIC" | "DECIMAL" => {
let val: f64 =
value_str.parse().map_err(|e| Error::Conversion(format!("Failed to parse f64: {}", e)))?;
self.bind_f64(val);
Ok(())
}
"REAL" | "float4" => {
let val: f32 =
value_str.parse().map_err(|e| Error::Conversion(format!("Failed to parse f32: {}", e)))?;
let _ = self.add(val);
Ok(())
}
"JSON" | "JSONB" | "json" | "jsonb" => {
match driver {
Drivers::Postgres => {
self.bind_string(value_str.to_string());
}
_ => {
self.bind_string(value_str.to_string());
}
}
Ok(())
}
"UUID" => {
let val =
value_str.parse::<Uuid>().map_err(|e| Error::Conversion(format!("Failed to parse UUID: {}", e)))?;
self.bind_uuid(val, driver);
Ok(())
}
"TIMESTAMPTZ" | "DateTime" => {
if let Ok(val) = temporal::parse_datetime_utc(value_str) {
self.bind_datetime_utc(val, driver);
} else if let Ok(val) = temporal::parse_datetime_fixed(value_str) {
self.bind_datetime_fixed(val, driver);
} else {
return Err(Error::Conversion(format!("Failed to parse DateTime: {}", value_str)));
}
Ok(())
}
"TIMESTAMP" | "NaiveDateTime" => {
let val = temporal::parse_naive_datetime(value_str)?;
self.bind_naive_datetime(val, driver);
Ok(())
}
"DATE" | "NaiveDate" => {
let val = temporal::parse_naive_date(value_str)?;
self.bind_naive_date(val, driver);
Ok(())
}
"TIME" | "NaiveTime" => {
let val = temporal::parse_naive_time(value_str)?;
self.bind_naive_time(val, driver);
Ok(())
}
s if s.ends_with("[]") => {
match driver {
Drivers::Postgres => {
self.bind_string(value_str.to_string());
}
_ => {
self.bind_string(value_str.to_string());
}
}
Ok(())
}
"TEXT" | "VARCHAR" | "CHAR" | "STRING" | _ => {
self.bind_string(value_str.to_string());
Ok(())
}
}
}
fn bind_i32(&mut self, value: i32) {
let _ = self.add(value);
}
fn bind_i64(&mut self, value: i64) {
let _ = self.add(value);
}
fn bind_bool(&mut self, value: bool) {
let _ = self.add(value);
}
fn bind_f64(&mut self, value: f64) {
let _ = self.add(value);
}
fn bind_string(&mut self, value: String) {
let _ = self.add(value);
}
fn bind_uuid(&mut self, value: Uuid, driver: &Drivers) {
match driver {
Drivers::Postgres => {
let _ = self.add(value.hyphenated().to_string());
}
Drivers::MySQL => {
let _ = self.add(value.hyphenated().to_string());
}
Drivers::SQLite => {
let _ = self.add(value.hyphenated().to_string());
}
}
}
fn bind_datetime_utc(&mut self, value: DateTime<Utc>, driver: &Drivers) {
let formatted = temporal::format_datetime_for_driver(&value, driver);
let _ = self.add(formatted);
}
fn bind_datetime_fixed(&mut self, value: chrono::DateTime<chrono::FixedOffset>, driver: &Drivers) {
let formatted = temporal::format_datetime_fixed_for_driver(&value, driver);
let _ = self.add(formatted);
}
fn bind_naive_datetime(&mut self, value: NaiveDateTime, driver: &Drivers) {
let formatted = temporal::format_naive_datetime_for_driver(&value, driver);
let _ = self.add(formatted);
}
fn bind_naive_date(&mut self, value: NaiveDate, _driver: &Drivers) {
let formatted = value.format("%Y-%m-%d").to_string();
let _ = self.add(formatted);
}
fn bind_naive_time(&mut self, value: NaiveTime, _driver: &Drivers) {
let formatted = value.format("%H:%M:%S%.6f").to_string();
let _ = self.add(formatted);
}
}
pub fn bind_typed_value(
args: &mut AnyArguments<'_>,
value_str: &str,
sql_type: &str,
driver: &Drivers,
) -> Result<(), Error> {
args.bind_value(value_str, sql_type, driver)
}
/// Binds `value_str` as `sql_type`, falling back to binding the raw text if it
/// cannot be parsed as that type.
///
/// This function does not report errors: a conversion error from
/// [`ValueBinder::bind_value`] is deliberately discarded and `value_str` is bound
/// as a string instead.
pub fn bind_typed_value_or_string(args: &mut AnyArguments<'_>, value_str: &str, sql_type: &str, driver: &Drivers) {
    if args.bind_value(value_str, sql_type, driver).is_err() {
        args.bind_string(value_str.to_string());
    }
}
/// Reports whether `sql_type` is one of the UUID or date/time type names that
/// `bind_value` converts through a dedicated parser and binder.
///
/// Matching is exact and case-sensitive (e.g. `"uuid"` is not recognized).
pub fn requires_special_binding(sql_type: &str) -> bool {
    const SPECIAL_TYPES: &[&str] = &[
        "UUID",
        "TIMESTAMPTZ",
        "DateTime",
        "TIMESTAMP",
        "NaiveDateTime",
        "DATE",
        "NaiveDate",
        "TIME",
        "NaiveTime",
    ];
    SPECIAL_TYPES.contains(&sql_type)
}
/// Reports whether `sql_type` names a numeric SQL type.
///
/// Recognizes every integer and floating-point type name that `bind_value` parses
/// numerically. That includes the aliases `serial`, `int2`, `int4`, `int8`,
/// `INT2`, `float4` and `float8`, which were previously missing and made this
/// predicate disagree with the binder. Matching is exact and case-sensitive.
pub fn is_numeric_type(sql_type: &str) -> bool {
    matches!(
        sql_type,
        // Integer types
        "INTEGER"
            | "INT"
            | "SERIAL"
            | "serial"
            | "int4"
            | "BIGINT"
            | "INT8"
            | "int8"
            | "BIGSERIAL"
            | "SMALLINT"
            | "INT2"
            | "int2"
            // Floating-point / arbitrary-precision types
            | "DOUBLE PRECISION"
            | "FLOAT"
            | "float8"
            | "NUMERIC"
            | "DECIMAL"
            | "REAL"
            | "float4"
    )
}
/// Reports whether `sql_type` is exactly one of `TEXT`, `VARCHAR`, `CHAR` or
/// `STRING` (case-sensitive; parameterized forms such as `VARCHAR(255)` do not
/// match).
pub fn is_text_type(sql_type: &str) -> bool {
    const TEXT_TYPES: [&str; 4] = ["TEXT", "VARCHAR", "CHAR", "STRING"];
    TEXT_TYPES.iter().any(|&candidate| candidate == sql_type)
}