use ensemble_derive::Column;
use itertools::Itertools;
use rbs::Value;
use std::{fmt::Display, sync::mpsc};
use super::Schemable;
use crate::{
connection::{self, Database},
value,
};
/// SQL column types understood by the schema builder.
///
/// Each variant is turned into dialect-specific SQL by the `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
    /// Boolean column.
    Boolean,
    /// 64-bit integer column.
    BigInteger,
    /// Variable-length string column with the given maximum length.
    String(u32),
    /// Unbounded text column.
    Text,
    /// JSON document column.
    Json,
    /// UUID column.
    Uuid,
    /// Timestamp column.
    Timestamp,
    /// Enumerated string column: the column name (referenced by the PostgreSQL
    /// `check` constraint) followed by the list of allowed values.
    Enum(String, Vec<String>),
}
impl Display for Type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Json => f.write_str("json"),
Self::Uuid => f.write_str("uuid"),
Self::Text => f.write_str("text"),
Self::Boolean => f.write_str("boolean"),
Self::BigInteger => f.write_str("bigint"),
Self::Timestamp => f.write_str("timestamp"),
Self::String(size) => {
let value = format!("varchar({size})");
f.write_str(&value)
},
Self::Enum(name, values) => {
let value = match connection::which_db() {
Database::MySQL => format!(
"enum({})",
values
.iter()
.map(|v| format!("'{}'", v.replace('\'', "\\'")))
.join(", ")
),
Database::PostgreSQL => format!(
"varchar(255) check({name} in ({}))",
values
.iter()
.map(|v| format!("'{}'", v.replace('\'', "\\'")))
.join(", ")
),
};
f.write_str(&value)
},
}
}
}
/// A single column definition collected by the schema builder.
///
/// Builder methods are generated by `#[derive(Column)]` from `ensemble_derive` and configured
/// through the `#[builder(...)]` field attributes. `Column::to_sql` renders the finished
/// definition. When the value is dropped, the `Drop` impl sends a copy through `tx`.
///
/// NOTE(review): `#[builder(type = ...)]` presumably makes the generated method also set
/// `r#type`, and `needs = [...]` presumably enables the listed flags. Confirm in
/// `ensemble_derive`.
#[derive(Debug, Clone, Column)]
// Each boolean mirrors an independent SQL column modifier, hence `struct_excessive_bools`.
// `dead_code` presumably covers fields that are only read under some feature combinations.
// TODO confirm.
#[allow(clippy::struct_excessive_bools, dead_code)]
pub struct Column {
// Column name, emitted verbatim (unquoted) at the start of `to_sql`.
#[builder(init)]
name: String,
// Column type. Rendered through `Type`'s `Display` impl, except that PostgreSQL uses
// `bigserial` for auto-incrementing `BigInteger` columns.
#[builder(init)]
r#type: Type,
// MySQL only: target column for an `AFTER <column>` clause.
#[cfg(feature = "mysql")]
after: Option<String>,
// Exposed under the builder name `increments`.
// Tied to `Type::BigInteger` and to the `primary`/`unique` flags.
#[builder(rename = "increments", type = Type::BigInteger, needs = [primary, unique])]
auto_increment: bool,
// Generate UUIDs in the database through a `DEFAULT` expression.
// `to_sql` panics if an explicit `default` is also set.
#[builder(type = Type::Uuid)]
uuid: bool,
// Text emitted in a `COMMENT` clause.
comment: Option<String>,
// Set through `Column::default`, which JSON-encodes JSON defaults and validates enum
// defaults, so no builder method is generated for it.
#[builder(skip)]
default: Option<rbs::Value>,
// Emitted as an inline `INDEX <name>` clause.
index: Option<String>,
// `NULL` when true, `NOT NULL` otherwise.
nullable: bool,
// Emits `PRIMARY KEY`.
primary: bool,
// Emits `UNIQUE`.
unique: bool,
// Emits `COLLATE <collation>`.
collation: Option<String>,
// MySQL only: emits `unsigned` right after the type.
#[cfg(feature = "mysql")]
#[builder(type = Type::BigInteger)]
unsigned: bool,
// Default the column to the current timestamp.
#[builder(type = Type::Timestamp)]
use_current: bool,
// MySQL only: emits `ON UPDATE CURRENT_TIMESTAMP`.
#[cfg(feature = "mysql")]
#[builder(type = Type::Timestamp)]
use_current_on_update: bool,
// Channel to the schema collector. `Drop` takes this sender and sends a copy of the column.
#[builder(init)]
tx: Option<mpsc::Sender<Schemable>>,
}
impl Column {
    /// Sets the column's default value.
    ///
    /// JSON columns store the value serialized to a JSON string. Every other type is
    /// converted with `value::for_db`.
    ///
    /// # Panics
    ///
    /// Panics in three cases:
    /// - the value cannot be serialized to JSON (JSON columns);
    /// - the value cannot be converted to a database value (other columns);
    /// - the column is an enum and the value is not one of its allowed values.
    pub fn default<T: serde::Serialize>(mut self, default: T) -> Self {
        let value = if self.r#type == Type::Json {
            Value::String(
                serde_json::to_string(&default)
                    .expect("default value for a json column must serialize to JSON"),
            )
        } else {
            value::for_db(default).expect("default value must be convertible to a database value")
        };

        if let Type::Enum(_, values) = &self.r#type {
            assert!(
                values.contains(&value.as_str().unwrap_or_default().to_string()),
                "default value must be one of the enum values"
            );
        }

        self.default = Some(value);
        self
    }

    /// Renders this column as a SQL column definition: name, type, then modifiers.
    ///
    /// Dialect-specific clauses are chosen at runtime from `connection::which_db()`. This keeps
    /// the output to a single dialect even when both the `mysql` and `postgres` features are
    /// compiled in.
    ///
    /// # Panics
    ///
    /// Panics if `uuid` is combined with an explicit default value.
    pub(crate) fn to_sql(&self) -> String {
        /// Wraps `raw` in single quotes as a SQL string literal.
        ///
        /// Embedded single quotes are doubled, which is valid in both dialects. When
        /// `escape_backslashes` is set (MySQL), backslashes are doubled too, because MySQL
        /// treats them as escape characters by default.
        fn sql_literal(raw: &str, escape_backslashes: bool) -> String {
            let escaped = if escape_backslashes {
                raw.replace('\\', "\\\\")
            } else {
                raw.to_owned()
            };
            format!("'{}'", escaped.replace('\'', "''"))
        }

        let is_postgres = connection::which_db().is_postgres();

        assert!(
            !(self.uuid && self.default.is_some()),
            "cannot set a default value and automatically generate UUIDs at the same time"
        );

        let db_type = if is_postgres && self.r#type == Type::BigInteger && self.auto_increment {
            "bigserial".to_string()
        } else {
            self.r#type.to_string()
        };
        let mut sql = format!("{} {db_type}", self.name);

        // Type modifiers go directly after the type. PostgreSQL only accepts `COLLATE` before
        // column constraints such as `NOT NULL`, and MySQL accepts it in this position too.
        #[cfg(feature = "mysql")]
        if self.unsigned {
            sql.push_str(" unsigned");
        }
        if let Some(collation) = &self.collation {
            sql.push_str(&format!(" COLLATE {collation}"));
        }

        sql.push_str(if self.nullable { " NULL" } else { " NOT NULL" });

        if let Some(comment) = &self.comment {
            // NOTE(review): PostgreSQL has no inline column `COMMENT` clause; it needs a
            // separate `COMMENT ON COLUMN` statement.
            sql.push_str(&format!(" COMMENT {}", sql_literal(comment, !is_postgres)));
        }

        if let Some(default) = &self.default {
            if self.r#type == Type::Json {
                let json = default
                    .as_str()
                    .expect("json defaults are stored as strings by `Column::default`");
                // MySQL only allows defaults on JSON columns in expression form. The
                // parentheses do not change the meaning in PostgreSQL.
                sql.push_str(&format!(" DEFAULT ({})", sql_literal(json, !is_postgres)));
            } else {
                sql.push_str(&format!(" DEFAULT {default}"));
            }
        }

        if self.uuid {
            sql.push_str(if is_postgres {
                " DEFAULT (gen_random_uuid())"
            } else {
                " DEFAULT (UUID())"
            });
        }

        // PostgreSQL auto-increments through the `bigserial` type chosen above.
        if self.auto_increment && !is_postgres {
            sql.push_str(" AUTO_INCREMENT");
        }

        if let Some(index) = &self.index {
            // NOTE(review): neither MySQL nor PostgreSQL appears to accept `INDEX <name>`
            // inside a column definition. This probably needs a separate index clause or
            // statement. Confirm.
            sql.push_str(&format!(" INDEX {index}"));
        }
        if self.primary {
            sql.push_str(" PRIMARY KEY");
        }
        if self.unique {
            sql.push_str(" UNIQUE");
        }

        if self.use_current {
            sql.push_str(if is_postgres {
                " DEFAULT now()"
            } else {
                " DEFAULT CURRENT_TIMESTAMP"
            });
        }

        #[cfg(feature = "mysql")]
        if self.use_current_on_update {
            sql.push_str(" ON UPDATE CURRENT_TIMESTAMP");
        }

        // `AFTER` positions the column in `ALTER TABLE ... ADD` and must follow the complete
        // column definition, so it is emitted last.
        #[cfg(feature = "mysql")]
        if let Some(after) = &self.after {
            sql.push_str(&format!(" AFTER {after}"));
        }

        sql
    }
}
impl Drop for Column {
    /// Sends the finished column definition to the schema collector, if a sender is attached.
    ///
    /// # Panics
    ///
    /// Panics if the receiving end of the channel has been dropped. This only happens when the
    /// current thread is not already unwinding: panicking inside `drop` during unwinding would
    /// abort the whole process.
    fn drop(&mut self) {
        // Taking the sender leaves `tx` as `None`. The clone sent below therefore carries no
        // sender, and its own drop does not send again.
        if let Some(tx) = self.tx.take() {
            let sent = tx.send(Schemable::Column(self.clone()));
            if sent.is_err() && !std::thread::panicking() {
                panic!("failed to send column definition: the schema receiver was dropped");
            }
        }
    }
}