use crate::mysql::operation::MysqlOperation;
use async_trait::async_trait;
use ciborium::Value as CborValue;
use indexmap::IndexMap;
use vantage_core::{Result, error};
use vantage_expressions::traits::associated_expressions::AssociatedExpression;
use vantage_expressions::traits::datasource::ExprDataSource;
use vantage_expressions::traits::expressive::ExpressiveEnum;
use vantage_expressions::{Expression, Expressive, Selectable};
use vantage_table::column::core::{Column, ColumnType};
use vantage_table::table::Table;
use vantage_table::traits::table_source::TableSource;
use vantage_types::{Entity, Record};
use crate::mysql::MysqlDB;
use crate::mysql::types::AnyMysqlType;
use crate::primitives::identifier::ident;
use vantage_expressions::expr_any;
/// Converts a string id into an [`AnyMysqlType`] value, used as the id
/// parameter in the statements generated by this module.
fn id_value(id: &str) -> AnyMysqlType {
    let owned: String = id.to_owned();
    AnyMysqlType::from(owned)
}
fn parse_rows(
result: AnyMysqlType,
id_field_name: &str,
) -> Result<IndexMap<String, Record<AnyMysqlType>>> {
let arr = match result.into_value() {
CborValue::Array(arr) => arr,
other => {
return Err(error!(
"expected array result",
details = format!("{:?}", other)
));
}
};
let mut records = IndexMap::new();
for item in arr {
let map = match item {
CborValue::Map(map) => map,
_ => continue,
};
let mut id = None;
let mut record = Record::new();
for (k, v) in map {
let key = match k {
CborValue::Text(s) => s,
_ => continue,
};
if key == id_field_name {
id = Some(match &v {
CborValue::Text(s) => s.clone(),
CborValue::Integer(i) => i128::from(*i).to_string(),
CborValue::Float(f) => f.to_string(),
_ => format!("{:?}", v),
});
}
record.insert(key, AnyMysqlType::untyped(v));
}
let id = id.ok_or_else(|| error!("row missing id field", field = id_field_name))?;
records.insert(id, record);
}
Ok(records)
}
#[async_trait]
impl TableSource for MysqlDB {
type Column<Type>
= Column<Type>
where
Type: ColumnType;
type AnyType = AnyMysqlType;
type Value = AnyMysqlType;
type Id = String;
type Condition = crate::condition::MysqlCondition;
fn create_column<Type: ColumnType>(&self, name: &str) -> Self::Column<Type> {
Column::new(name)
}
fn to_any_column<Type: ColumnType>(
&self,
column: Self::Column<Type>,
) -> Self::Column<Self::AnyType> {
Column::from_column(column)
}
fn convert_any_column<Type: ColumnType>(
&self,
any_column: Self::Column<Self::AnyType>,
) -> Option<Self::Column<Type>> {
Some(Column::from_column(any_column))
}
fn expr(
&self,
template: impl Into<String>,
parameters: Vec<ExpressiveEnum<Self::Value>>,
) -> Expression<Self::Value> {
Expression::new(template, parameters)
}
fn search_table_condition<E>(
&self,
table: &Table<Self, E>,
search_value: &str,
) -> Self::Condition
where
E: Entity<Self::Value>,
{
let escaped = search_value
.replace('$', "$$")
.replace('%', "$%")
.replace('_', "$_");
let pattern = format!("%{}%", escaped);
let conditions: Vec<Expression<AnyMysqlType>> = table
.columns()
.values()
.map(|col| {
let p = pattern.clone();
mysql_expr!("{} LIKE {} ESCAPE '$'", (ident(col.name())), p)
})
.collect();
if conditions.is_empty() {
return mysql_expr!("FALSE").into();
}
Expression::from_vec(conditions, " OR ").into()
}
async fn list_table_values<E>(
&self,
table: &Table<Self, E>,
) -> Result<IndexMap<Self::Id, Record<Self::Value>>>
where
E: Entity<Self::Value>,
{
let id_field_name = table
.id_field()
.map(|c| c.name().to_string())
.unwrap_or_else(|| "id".to_string());
let select = table.select();
let result = self.execute(&select.expr()).await?;
parse_rows(result, &id_field_name)
}
async fn get_table_value<E>(
&self,
table: &Table<Self, E>,
id: &Self::Id,
) -> Result<Option<Record<Self::Value>>>
where
E: Entity<Self::Value>,
{
let id_field_name = table
.id_field()
.map(|c| c.name().to_string())
.unwrap_or_else(|| "id".to_string());
let condition = {
let id_val = id_value(id);
mysql_expr!("{} = {}", (ident(&id_field_name)), id_val)
};
let select = table.select().with_condition(condition);
let result = self.execute(&select.expr()).await?;
let mut rows = parse_rows(result, &id_field_name)?;
Ok(rows.swap_remove(id))
}
async fn get_table_some_value<E>(
&self,
table: &Table<Self, E>,
) -> Result<Option<(Self::Id, Record<Self::Value>)>>
where
E: Entity<Self::Value>,
{
let id_field_name = table
.id_field()
.map(|c| c.name().to_string())
.unwrap_or_else(|| "id".to_string());
let mut select = table.select();
select.set_limit(Some(1), None);
let result = self.execute(&select.expr()).await?;
let mut rows = parse_rows(result, &id_field_name)?;
Ok(rows.swap_remove_index(0))
}
async fn get_table_count<E>(&self, table: &Table<Self, E>) -> Result<i64>
where
E: Entity<Self::Value>,
{
let select = table.select();
let result = self.aggregate(&select, "count", mysql_expr!("*")).await?;
result.try_get::<i64>().ok_or_else(|| {
error!(
"get_table_count: expected i64",
result = format!("{}", result)
)
})
}
async fn get_table_sum<E>(
&self,
table: &Table<Self, E>,
column: &Self::Column<Self::AnyType>,
) -> Result<Self::Value>
where
E: Entity<Self::Value>,
{
self.aggregate(&table.select(), "sum", column.expr()).await
}
async fn get_table_max<E>(
&self,
table: &Table<Self, E>,
column: &Self::Column<Self::AnyType>,
) -> Result<Self::Value>
where
E: Entity<Self::Value>,
{
self.aggregate(&table.select(), "max", column.expr()).await
}
async fn get_table_min<E>(
&self,
table: &Table<Self, E>,
column: &Self::Column<Self::AnyType>,
) -> Result<Self::Value>
where
E: Entity<Self::Value>,
{
self.aggregate(&table.select(), "min", column.expr()).await
}
async fn insert_table_value<E>(
&self,
table: &Table<Self, E>,
id: &Self::Id,
record: &Record<Self::Value>,
) -> Result<Record<Self::Value>>
where
E: Entity<Self::Value>,
{
let id_field_name = table
.id_field()
.map(|c| c.name().to_string())
.unwrap_or_else(|| "id".to_string());
let insert = crate::mysql::statements::MysqlInsert::new(table.table_name())
.with_field(&id_field_name, id_value(id))
.with_record(record);
self.execute(&insert.expr()).await?;
self.get_table_value(table, id)
.await?
.ok_or_else(|| error!("Inserted row disappeared", id = id.clone()))
}
async fn replace_table_value<E>(
&self,
table: &Table<Self, E>,
id: &Self::Id,
record: &Record<Self::Value>,
) -> Result<Record<Self::Value>>
where
E: Entity<Self::Value>,
{
let id_field_name = table
.id_field()
.map(|c| c.name().to_string())
.unwrap_or_else(|| "id".to_string());
let insert = crate::mysql::statements::MysqlInsert::new(table.table_name())
.with_field(&id_field_name, id_value(id))
.with_record(record);
let base = insert.expr();
let set_parts: Vec<Expression<AnyMysqlType>> = if record.is_empty() {
vec![expr_any!(
"{} = {}",
(ident(&id_field_name)),
(ident(&id_field_name))
)]
} else {
record
.keys()
.map(|k| expr_any!("{} = VALUES({})", (ident(k)), (ident(k))))
.collect()
};
let conflict = Expression::from_vec(set_parts, ", ");
let upsert = expr_any!("{} ON DUPLICATE KEY UPDATE {}", (base), (conflict));
self.execute(&upsert).await?;
self.get_table_value(table, id)
.await?
.ok_or_else(|| error!("Row missing after upsert", id = id.clone()))
}
async fn patch_table_value<E>(
&self,
table: &Table<Self, E>,
id: &Self::Id,
partial: &Record<Self::Value>,
) -> Result<Record<Self::Value>>
where
E: Entity<Self::Value>,
{
let id_field_name = table
.id_field()
.map(|c| c.name().to_string())
.unwrap_or_else(|| "id".to_string());
let id_condition = {
let id_val = id_value(id);
mysql_expr!("{} = {}", (ident(&id_field_name)), id_val)
};
let update = crate::mysql::statements::MysqlUpdate::new(table.table_name())
.with_record(partial)
.with_condition(id_condition);
self.execute(&update.expr()).await?;
self.get_table_value(table, id)
.await?
.ok_or_else(|| error!("Row not found after patch", id = id.clone()))
}
async fn delete_table_value<E>(&self, table: &Table<Self, E>, id: &Self::Id) -> Result<()>
where
E: Entity<Self::Value>,
{
let id_field_name = table
.id_field()
.map(|c| c.name().to_string())
.unwrap_or_else(|| "id".to_string());
let id_condition = {
let id_val = id_value(id);
mysql_expr!("{} = {}", (ident(&id_field_name)), id_val)
};
let delete = crate::mysql::statements::MysqlDelete::new(table.table_name())
.with_condition(id_condition);
self.execute(&delete.expr()).await?;
Ok(())
}
async fn delete_table_all_values<E>(&self, table: &Table<Self, E>) -> Result<()>
where
E: Entity<Self::Value>,
{
let delete = crate::mysql::statements::MysqlDelete::new(table.table_name());
self.execute(&delete.expr()).await?;
Ok(())
}
async fn insert_table_return_id_value<E>(
&self,
table: &Table<Self, E>,
record: &Record<Self::Value>,
) -> Result<Self::Id>
where
E: Entity<Self::Value>,
{
let insert =
crate::mysql::statements::MysqlInsert::new(table.table_name()).with_record(record);
use crate::mysql::row::bind_mysql_value;
use vantage_expressions::{ExpressionFlattener, Flatten};
let expr = insert.expr();
let flattener = ExpressionFlattener::new();
let flattened = flattener.flatten(&expr);
let template_parts: Vec<&str> = flattened.template.split("{}").collect();
if template_parts.len() != flattened.parameters.len() + 1 {
return Err(error!(
"MySQL insert expression placeholder mismatch",
placeholders = (template_parts.len() - 1).to_string(),
parameters = flattened.parameters.len().to_string()
));
}
let mut sql = String::new();
let mut params = Vec::new();
sql.push_str(template_parts[0]);
for (i, param) in flattened.parameters.iter().enumerate() {
match param {
ExpressiveEnum::Scalar(value) => {
sql.push('?');
params.push(value.clone());
}
_ => {
return Err(error!(
"MySQL insert expression contains non-scalar parameter",
index = i.to_string()
));
}
}
sql.push_str(template_parts[i + 1]);
}
let mut conn = self
.pool()
.acquire()
.await
.map_err(|e| error!("MySQL acquire connection failed", details = e.to_string()))?;
let mut query = sqlx::query(&sql);
for value in ¶ms {
query = bind_mysql_value(query, value);
}
query
.execute(&mut *conn)
.await
.map_err(|e| error!("MySQL insert failed", details = e.to_string()))?;
let last_id_sql = "SELECT LAST_INSERT_ID() AS id";
let row = sqlx::query(last_id_sql)
.fetch_one(&mut *conn)
.await
.map_err(|e| error!("MySQL LAST_INSERT_ID failed", details = e.to_string()))?;
use sqlx::Row;
let id: u64 = row
.try_get("id")
.map_err(|e| error!("MySQL get id failed", details = e.to_string()))?;
Ok(id.to_string())
}
fn related_in_condition<SourceE: Entity<Self::Value> + 'static>(
&self,
target_field: &str,
source_table: &Table<Self, SourceE>,
source_column: &str,
) -> Self::Condition
where
Self: Sized,
{
let src_col = self.create_column::<Self::AnyType>(source_column);
let fk_values = self.column_table_values_expr(source_table, &src_col);
let tgt_col = self.create_column::<Self::AnyType>(target_field);
tgt_col.in_(fk_values.expr())
}
fn related_correlated_condition(
&self,
target_table: &str,
target_field: &str,
source_table: &str,
source_column: &str,
) -> Self::Condition {
mysql_expr!(
"{} = {}",
(ident(target_field).dot_of(target_table)),
(ident(source_column).dot_of(source_table))
)
.into()
}
fn column_table_values_expr<'a, E, Type: ColumnType>(
&'a self,
table: &Table<Self, E>,
column: &Self::Column<Type>,
) -> AssociatedExpression<'a, Self, Self::Value, Vec<Type>>
where
E: Entity<Self::Value> + 'static,
Self: ExprDataSource<Self::Value> + Sized,
{
let mut select = table.select();
select.clear_fields();
select.clear_order_by();
select.add_field(column.name());
let subquery = select.expr();
AssociatedExpression::new(subquery, self)
}
}