use crate::mysql::condition::{Condition, SqlValue};
use crate::mysql::field::{FieldType, JoinClause, OrderClause};
use sqlx::mysql::MySqlPool;
use std::collections::HashMap;
/// Incrementally builds one SQL statement together with its positional bind parameters.
// NOTE(review): the reason for `allow(dead_code)` is not documented — presumably some
// accessors are only exercised from tests; confirm before removing.
#[allow(dead_code)]
pub(crate) struct SqlGenerator {
/// SQL text accumulated so far.
sql: String,
/// Values to bind to the statement, in the order they were added.
params: Vec<SqlValue>,
}
#[allow(dead_code)]
impl SqlGenerator {
/// Creates a generator holding no SQL text and no parameters.
pub(crate) fn new() -> Self {
    Self {
        sql: String::default(),
        params: Vec::default(),
    }
}

/// Returns the SQL text built so far.
pub(crate) fn get_sql(&self) -> &str {
    self.sql.as_str()
}

/// Returns the collected bind parameters, in the order they were added.
pub(crate) fn get_params(&self) -> &[SqlValue] {
    self.params.as_slice()
}

/// Appends a raw SQL fragment verbatim (no escaping).
fn append(&mut self, fragment: &str) {
    self.sql += fragment;
}

/// Queues one value to be bound to the next placeholder.
fn add_param(&mut self, param: SqlValue) {
    self.params.push(param);
}

/// Resets both the SQL text and the parameter list so the generator can be reused.
fn clear(&mut self) {
    self.params.clear();
    self.sql.clear();
}
/// Renders `builder` as a `SELECT` statement into this generator (replacing any prior content).
///
/// Clause order: `SELECT [DISTINCT] cols FROM table [JOINs] [WHERE] [GROUP BY]
/// [ORDER BY] [LIMIT n [OFFSET m]]`. Only the WHERE clause adds bind parameters;
/// column names, the table name and join `ON` expressions are inserted verbatim.
///
/// MySQL has no stand-alone `OFFSET` clause, so an offset without a limit is rendered
/// as `LIMIT 18446744073709551615 OFFSET m` (`u64::MAX`, the MySQL manual's idiom
/// for "all remaining rows") instead of the invalid `... OFFSET m`.
///
/// # Errors
/// Propagates any error from building the WHERE clause.
fn build_select(&mut self, builder: &QueryBuilder) -> Result<(), crate::error::DbError> {
    self.clear();
    self.append("SELECT ");
    if builder.distinct {
        self.append("DISTINCT ");
    }
    if builder.fields.is_empty() {
        self.append("*");
    } else {
        self.append(&builder.fields.join(", "));
    }
    self.append(" FROM ");
    self.append(&builder.table);
    // Each helper emits nothing when its input is empty.
    self.build_joins(&builder.joins);
    self.build_where(&builder.conditions)?;
    self.build_group_by(&builder.group_by);
    self.build_order_by(&builder.order_by);
    self.append_limit_offset(builder.limit, builder.offset);
    Ok(())
}

/// Appends the `LIMIT` / `OFFSET` tail, synthesising an unbounded LIMIT when only
/// an offset is given (MySQL requires OFFSET to follow LIMIT).
fn append_limit_offset(&mut self, limit: Option<u64>, offset: Option<u64>) {
    match (limit, offset) {
        (None, None) => {}
        (Some(limit), None) => self.append(&format!(" LIMIT {}", limit)),
        (limit, Some(offset)) => {
            let limit = limit.unwrap_or(u64::MAX);
            self.append(&format!(" LIMIT {} OFFSET {}", limit, offset));
        }
    }
}
/// Appends ` WHERE <expr>` for `conditions`, pushing their bind values onto `params`.
///
/// Does nothing for an empty slice. A single condition is rendered as-is; several
/// are wrapped in one `Condition::And` and rendered together.
fn build_where(&mut self, conditions: &[Condition]) -> Result<(), crate::error::DbError> {
    let clause = match conditions {
        [] => return Ok(()),
        [single] => crate::mysql::condition::condition_to_sql(single, &mut self.params),
        many => {
            let combined = Condition::And(many.to_vec());
            crate::mysql::condition::condition_to_sql(&combined, &mut self.params)
        }
    };
    self.append(" WHERE ");
    self.append(&clause);
    Ok(())
}
fn build_joins(&mut self, joins: &[JoinClause]) {
use crate::mysql::field::JoinType;
for join in joins {
let join_type_str = match join.join_type {
JoinType::Inner => " INNER JOIN ",
JoinType::Left => " LEFT JOIN ",
JoinType::Right => " RIGHT JOIN ",
};
self.append(join_type_str);
self.append(&join.table);
self.append(" ON ");
self.append(&join.on);
}
}
fn build_order_by(&mut self, orders: &[OrderClause]) {
if orders.is_empty() {
return;
}
self.append(" ORDER BY ");
let order_parts: Vec<String> = orders
.iter()
.map(|order| {
let direction = if order.asc { "ASC" } else { "DESC" };
format!("{} {}", order.field, direction)
})
.collect();
self.append(&order_parts.join(", "));
}
fn build_group_by(&mut self, groups: &[String]) {
if groups.is_empty() {
return;
}
self.append(" GROUP BY ");
self.append(&groups.join(", "));
}
/// Builds `INSERT INTO table (k1, k2, ...) VALUES (?, ?, ...)` from a JSON object.
///
/// One placeholder and one bind parameter are produced per key, in the object's
/// iteration order; each value is converted with `json_value_to_sql_value` using
/// the type declared in `field_types` for that key (if any).
///
/// # Errors
/// `SerializationError` if `data` is not an object or is empty; otherwise
/// propagates value-conversion errors.
pub(crate) fn build_insert(
    &mut self,
    table: &str,
    data: &serde_json::Value,
    field_types: &HashMap<String, FieldType>,
) -> Result<(), crate::error::DbError> {
    self.clear();
    let row = match data.as_object() {
        Some(obj) => obj,
        None => {
            return Err(crate::error::DbError::SerializationError(
                "插入数据必须是 JSON 对象".to_string(),
            ))
        }
    };
    if row.is_empty() {
        return Err(crate::error::DbError::SerializationError(
            "插入数据不能为空".to_string(),
        ));
    }
    let mut columns: Vec<&str> = Vec::with_capacity(row.len());
    for (column, raw) in row {
        let converted = self.json_value_to_sql_value(raw, field_types.get(column))?;
        self.add_param(converted);
        columns.push(column.as_str());
    }
    let placeholders = vec!["?"; columns.len()].join(", ");
    self.append(&format!(
        "INSERT INTO {} ({}) VALUES ({})",
        table,
        columns.join(", "),
        placeholders
    ));
    Ok(())
}
/// Builds a multi-row `INSERT INTO table (cols) VALUES (?, ...), (?, ...), ...`.
///
/// The column list is the union of every row's keys in first-seen order. When all
/// rows share the same keys, that is exactly the first row's key order. A row that
/// lacks one of the columns binds `NULL` for it. Parameters are pushed row by row,
/// column by column, matching the placeholder order.
///
/// # Errors
/// `SerializationError` if `data_list` is empty, the first row is not an object or
/// is empty, or any later row is not an object. Otherwise propagates value-conversion
/// errors.
pub(crate) fn build_insert_batch(
    &mut self,
    table: &str,
    data_list: &[serde_json::Value],
    field_types: &HashMap<String, FieldType>,
) -> Result<(), crate::error::DbError> {
    self.clear();
    if data_list.is_empty() {
        return Err(crate::error::DbError::SerializationError(
            "批量插入数据不能为空".to_string(),
        ));
    }
    let first_obj = data_list[0].as_object().ok_or_else(|| {
        crate::error::DbError::SerializationError("插入数据必须是 JSON 对象".to_string())
    })?;
    if first_obj.is_empty() {
        return Err(crate::error::DbError::SerializationError(
            "插入数据不能为空".to_string(),
        ));
    }
    // Previously only the first row's keys were used, silently discarding values under
    // keys that appear only in later rows (e.g. fields omitted from the first record by
    // `skip_serializing_if`). Non-object rows are skipped here and rejected in the loop
    // below, so error reporting still follows row order.
    let mut seen = std::collections::HashSet::new();
    let fields: Vec<&str> = data_list
        .iter()
        .filter_map(|data| data.as_object())
        .flat_map(|obj| obj.keys())
        .map(|key| key.as_str())
        .filter(|key| seen.insert(*key))
        .collect();
    // Every row has the same shape, so its placeholder group is rendered once.
    let row_placeholders = format!("({})", vec!["?"; fields.len()].join(", "));
    self.append("INSERT INTO ");
    self.append(table);
    self.append(" (");
    self.append(&fields.join(", "));
    self.append(") VALUES ");
    let mut value_clauses: Vec<&str> = Vec::with_capacity(data_list.len());
    for data in data_list {
        let obj = data.as_object().ok_or_else(|| {
            crate::error::DbError::SerializationError("插入数据必须是 JSON 对象".to_string())
        })?;
        for field in &fields {
            let value = obj.get(*field).unwrap_or(&serde_json::Value::Null);
            let sql_value = self.json_value_to_sql_value(value, field_types.get(*field))?;
            self.add_param(sql_value);
        }
        value_clauses.push(row_placeholders.as_str());
    }
    self.append(&value_clauses.join(", "));
    Ok(())
}
/// Builds `UPDATE table SET k1 = ?, k2 = ? ... WHERE <conditions>`.
///
/// SET parameters are pushed first (one per key, in object iteration order) and
/// the WHERE parameters after them, matching placeholder order.
///
/// # Errors
/// `MissingWhereClause` if `conditions` is empty (whole-table updates are refused);
/// `SerializationError` if `data` is not an object or is empty; otherwise
/// propagates value-conversion / WHERE-building errors.
pub(crate) fn build_update(
    &mut self,
    table: &str,
    data: &serde_json::Value,
    field_types: &HashMap<String, FieldType>,
    conditions: &[Condition],
) -> Result<(), crate::error::DbError> {
    self.clear();
    if conditions.is_empty() {
        return Err(crate::error::DbError::MissingWhereClause);
    }
    let changes = match data.as_object() {
        None => {
            return Err(crate::error::DbError::SerializationError(
                "更新数据必须是 JSON 对象".to_string(),
            ))
        }
        Some(obj) if obj.is_empty() => {
            return Err(crate::error::DbError::SerializationError(
                "更新数据不能为空".to_string(),
            ))
        }
        Some(obj) => obj,
    };
    self.append(&format!("UPDATE {} SET ", table));
    let mut assignments = Vec::with_capacity(changes.len());
    for (column, raw) in changes {
        assignments.push(format!("{} = ?", column));
        let converted = self.json_value_to_sql_value(raw, field_types.get(column))?;
        self.add_param(converted);
    }
    self.append(&assignments.join(", "));
    self.build_where(conditions)
}
/// Builds `DELETE FROM table WHERE <conditions>`.
///
/// # Errors
/// `MissingWhereClause` if `conditions` is empty (whole-table deletes are refused);
/// otherwise propagates WHERE-building errors.
pub(crate) fn build_delete(
    &mut self,
    table: &str,
    conditions: &[Condition],
) -> Result<(), crate::error::DbError> {
    self.clear();
    if conditions.is_empty() {
        return Err(crate::error::DbError::MissingWhereClause);
    }
    self.append(&format!("DELETE FROM {}", table));
    self.build_where(conditions)
}
fn json_value_to_sql_value(
&self,
value: &serde_json::Value,
field_type: Option<&FieldType>,
) -> Result<SqlValue, crate::error::DbError> {
use serde_json::Value;
if let Some(ft) = field_type {
match ft {
FieldType::Json => {
return Ok(SqlValue::Json(value.clone()));
}
FieldType::DateTime => {
if let Some(s) = value.as_str() {
let dt = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S")
.map_err(|e| {
crate::error::DbError::TypeConversionError(format!(
"无法解析 DATETIME 字符串: {}",
e
))
})?;
return Ok(SqlValue::DateTime(dt));
}
}
FieldType::Timestamp => {
if let Some(i) = value.as_i64() {
return Ok(SqlValue::Timestamp(i));
}
}
FieldType::Decimal => {
if let Some(f) = value.as_f64() {
return Ok(SqlValue::Float(f));
} else if let Some(i) = value.as_i64() {
return Ok(SqlValue::Float(i as f64));
}
}
FieldType::Blob => {
if let Some(s) = value.as_str() {
use base64::Engine;
if let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(s) {
return Ok(SqlValue::Bytes(bytes));
}
return Ok(SqlValue::Bytes(s.as_bytes().to_vec()));
}
}
FieldType::Text => {
if let Some(s) = value.as_str() {
return Ok(SqlValue::String(s.to_string()));
}
}
FieldType::Standard => {
}
}
}
match value {
Value::Null => Ok(SqlValue::Null),
Value::Bool(b) => Ok(SqlValue::Bool(*b)),
Value::Number(n) => {
if let Some(i) = n.as_i64() {
Ok(SqlValue::Int(i))
} else if let Some(f) = n.as_f64() {
Ok(SqlValue::Float(f))
} else {
Err(crate::error::DbError::TypeConversionError(
"无法转换数字类型".to_string(),
))
}
}
Value::String(s) => Ok(SqlValue::String(s.clone())),
Value::Array(_) | Value::Object(_) => {
Ok(SqlValue::Json(value.clone()))
}
}
}
}
/// Fluent builder for a single-table MySQL statement, bound to a connection pool.
///
/// Configuration methods consume and return `self`. The terminal methods (`find`,
/// `select`, `value`, `count`, `sum`, `insert`, `insert_batch`, `update`, `delete`)
/// render SQL through `SqlGenerator` and execute it against `pool`.
pub struct QueryBuilder<'a> {
// NOTE(review): every field marked `allow(dead_code)` below is read elsewhere in this
// file (by `SqlGenerator::build_select` or the terminal methods), so these attributes
// look stale — confirm before removing.
#[allow(dead_code)]
/// Pool the terminal methods execute against.
pool: &'a MySqlPool,
/// Target table name, inserted verbatim into generated SQL.
table: String,
/// SELECT list; empty means `*`.
fields: Vec<String>,
#[allow(dead_code)]
/// WHERE conditions; several are combined with AND.
conditions: Vec<Condition>,
#[allow(dead_code)]
/// JOIN clauses, emitted in insertion order.
joins: Vec<JoinClause>,
#[allow(dead_code)]
/// ORDER BY terms, in insertion order.
order_by: Vec<OrderClause>,
#[allow(dead_code)]
/// GROUP BY columns, in insertion order.
group_by: Vec<String>,
/// Optional LIMIT value.
limit: Option<u64>,
/// Optional OFFSET value.
offset: Option<u64>,
/// Whether to emit `SELECT DISTINCT`.
distinct: bool,
/// Declared column types that steer value conversion in insert/update.
field_types: HashMap<String, FieldType>,
#[allow(dead_code)]
/// When true, terminal methods debug-log the SQL and parameters (errors are logged regardless).
enable_logging: bool,
}
impl<'a> QueryBuilder<'a> {
/// Creates an empty builder targeting `table_name` on `pool`.
///
/// The new builder has no explicit columns (meaning `*`), no conditions, joins,
/// ordering, grouping, limit or offset, and no declared column types.
/// `enable_logging` turns on debug logging in the terminal methods.
pub(crate) fn new(pool: &'a MySqlPool, table_name: &str, enable_logging: bool) -> Self {
    let table = table_name.to_owned();
    Self {
        pool,
        table,
        enable_logging,
        distinct: false,
        limit: None,
        offset: None,
        fields: Vec::new(),
        conditions: Vec::new(),
        joins: Vec::new(),
        order_by: Vec::new(),
        group_by: Vec::new(),
        field_types: HashMap::default(),
    }
}
pub fn field(mut self, field: &str) -> Self {
self.fields.push(field.to_string());
self
}
pub fn fields(mut self, fields: &[&str]) -> Self {
for field in fields {
self.fields.push(field.to_string());
}
self
}
pub fn json(mut self, field: &str) -> Self {
self.field_types.insert(field.to_string(), FieldType::Json);
self
}
pub fn datetime(mut self, field: &str) -> Self {
self.field_types
.insert(field.to_string(), FieldType::DateTime);
self
}
pub fn timestamp(mut self, field: &str) -> Self {
self.field_types
.insert(field.to_string(), FieldType::Timestamp);
self
}
pub fn decimal(mut self, field: &str) -> Self {
self.field_types
.insert(field.to_string(), FieldType::Decimal);
self
}
pub fn blob(mut self, field: &str) -> Self {
self.field_types.insert(field.to_string(), FieldType::Blob);
self
}
pub fn text(mut self, field: &str) -> Self {
self.field_types.insert(field.to_string(), FieldType::Text);
self
}
pub fn distinct(mut self) -> Self {
self.distinct = true;
self
}
pub fn where_and<V>(mut self, field: &str, op: &str, value: V) -> Self
where
V: Into<crate::mysql::condition::SqlValue>,
{
use crate::mysql::condition::{Condition, SqlValue};
let sql_value = value.into();
let condition = match op {
"=" => Condition::Eq(field.to_string(), sql_value),
"!=" => Condition::Ne(field.to_string(), sql_value),
">" => Condition::Gt(field.to_string(), sql_value),
"<" => Condition::Lt(field.to_string(), sql_value),
">=" => Condition::Gte(field.to_string(), sql_value),
"<=" => Condition::Lte(field.to_string(), sql_value),
"like" | "LIKE" => {
if let SqlValue::String(s) = sql_value {
Condition::Like(field.to_string(), s)
} else {
Condition::Like(field.to_string(), format!("{:?}", sql_value))
}
}
_ => panic!("不支持的操作符: {}", op),
};
self.conditions.push(condition);
self
}
pub fn where_or<V>(mut self, field: &str, op: &str, value: V) -> Self
where
V: Into<crate::mysql::condition::SqlValue>,
{
use crate::mysql::condition::{Condition, SqlValue};
let sql_value = value.into();
let condition = match op {
"=" => Condition::Eq(field.to_string(), sql_value),
"!=" => Condition::Ne(field.to_string(), sql_value),
">" => Condition::Gt(field.to_string(), sql_value),
"<" => Condition::Lt(field.to_string(), sql_value),
">=" => Condition::Gte(field.to_string(), sql_value),
"<=" => Condition::Lte(field.to_string(), sql_value),
"like" | "LIKE" => {
if let SqlValue::String(s) = sql_value {
Condition::Like(field.to_string(), s)
} else {
Condition::Like(field.to_string(), format!("{:?}", sql_value))
}
}
_ => panic!("不支持的操作符: {}", op),
};
if !self.conditions.is_empty() {
let existing = std::mem::take(&mut self.conditions);
self.conditions.push(Condition::Or(vec![
if existing.len() == 1 {
existing.into_iter().next().unwrap()
} else {
Condition::And(existing)
},
condition,
]));
} else {
self.conditions.push(condition);
}
self
}
pub fn where_in<V>(mut self, field: &str, values: Vec<V>) -> Self
where
V: Into<crate::mysql::condition::SqlValue>,
{
use crate::mysql::condition::Condition;
let sql_values: Vec<_> = values.into_iter().map(|v| v.into()).collect();
self.conditions
.push(Condition::In(field.to_string(), sql_values));
self
}
pub fn where_between<V>(mut self, field: &str, start: V, end: V) -> Self
where
V: Into<crate::mysql::condition::SqlValue>,
{
use crate::mysql::condition::Condition;
self.conditions.push(Condition::Between(
field.to_string(),
start.into(),
end.into(),
));
self
}
pub fn join(mut self, table: &str, on: &str) -> Self {
use crate::mysql::field::{JoinClause, JoinType};
self.joins.push(JoinClause {
join_type: JoinType::Inner,
table: table.to_string(),
on: on.to_string(),
});
self
}
pub fn left_join(mut self, table: &str, on: &str) -> Self {
use crate::mysql::field::{JoinClause, JoinType};
self.joins.push(JoinClause {
join_type: JoinType::Left,
table: table.to_string(),
on: on.to_string(),
});
self
}
pub fn right_join(mut self, table: &str, on: &str) -> Self {
use crate::mysql::field::{JoinClause, JoinType};
self.joins.push(JoinClause {
join_type: JoinType::Right,
table: table.to_string(),
on: on.to_string(),
});
self
}
pub fn order(mut self, field: &str, asc: bool) -> Self {
use crate::mysql::field::OrderClause;
self.order_by.push(OrderClause {
field: field.to_string(),
asc,
});
self
}
pub fn group(mut self, field: &str) -> Self {
self.group_by.push(field.to_string());
self
}
pub fn limit(mut self, limit: u64) -> Self {
self.limit = Some(limit);
self
}
pub fn offset(mut self, offset: u64) -> Self {
self.offset = Some(offset);
self
}
/// Renders the SELECT statement this builder would run, for inspection or logging.
///
/// Bind values are collected separately and are not part of the returned string.
/// If SQL generation fails, this falls back to a bare
/// `SELECT [DISTINCT] cols FROM table` that omits all other clauses.
pub fn to_sql(&self) -> String {
    let mut generator = SqlGenerator::new();
    if generator.build_select(self).is_ok() {
        return generator.get_sql().to_string();
    }
    let columns = if self.fields.is_empty() {
        "*".to_string()
    } else {
        self.fields.join(", ")
    };
    let prefix = if self.distinct { "DISTINCT " } else { "" };
    format!("SELECT {}{} FROM {}", prefix, columns, self.table)
}
pub async fn find<T>(mut self) -> Result<Option<T>, crate::error::DbError>
where
T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
{
self.limit = Some(1);
let mut generator = SqlGenerator::new();
generator.build_select(&self)?;
let sql = generator.get_sql();
let params = generator.get_params();
if self.enable_logging {
log::debug!("执行 find() 查询: {}", sql);
log::debug!("参数: {:?}", params);
}
let mut query = sqlx::query_as::<_, T>(sql);
for param in params {
query = bind_param(query, param);
}
let result = query.fetch_optional(self.pool).await;
match result {
Ok(row) => {
if self.enable_logging {
if row.is_some() {
log::debug!("find() 查询成功,返回 1 条记录");
} else {
log::debug!("find() 查询成功,未找到匹配记录");
}
}
Ok(row)
}
Err(e) => {
log::error!("find() 查询失败: {}", e);
Err(crate::error::DbError::from(e))
}
}
}
pub async fn select<T>(self) -> Result<Vec<T>, crate::error::DbError>
where
T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
{
let mut generator = SqlGenerator::new();
generator.build_select(&self)?;
let sql = generator.get_sql();
let params = generator.get_params();
if self.enable_logging {
log::debug!("执行 select() 查询: {}", sql);
log::debug!("参数: {:?}", params);
}
let mut query = sqlx::query_as::<_, T>(sql);
for param in params {
query = bind_param(query, param);
}
let result = query.fetch_all(self.pool).await;
match result {
Ok(rows) => {
if self.enable_logging {
log::debug!("select() 查询成功,返回 {} 条记录", rows.len());
}
Ok(rows)
}
Err(e) => {
log::error!("select() 查询失败: {}", e);
Err(crate::error::DbError::from(e))
}
}
}
pub async fn value<T>(mut self, field: &str) -> Result<Option<T>, crate::error::DbError>
where
T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
{
self.fields.clear();
self.fields.push(field.to_string());
self.limit = Some(1);
let mut generator = SqlGenerator::new();
generator.build_select(&self)?;
let sql = generator.get_sql();
let params = generator.get_params();
if self.enable_logging {
log::debug!("执行 value() 查询: {}", sql);
log::debug!("参数: {:?}", params);
}
let mut query = sqlx::query_scalar::<_, T>(sql);
for param in params {
query = bind_scalar_param(query, param);
}
let result = query.fetch_optional(self.pool).await;
match result {
Ok(value) => {
if self.enable_logging {
if value.is_some() {
log::debug!("value() 查询成功,返回字段值");
} else {
log::debug!("value() 查询成功,未找到匹配记录");
}
}
Ok(value)
}
Err(e) => {
log::error!("value() 查询失败: {}", e);
Err(crate::error::DbError::from(e))
}
}
}
pub async fn count(self) -> Result<i64, crate::error::DbError> {
if self.enable_logging {
log::debug!("执行 count() 查询");
}
let result = self.value::<i64>("COUNT(*)").await?;
Ok(result.unwrap_or(0))
}
/// Sums `field` over the matching rows via `SELECT CAST(SUM(field) AS DOUBLE) ... LIMIT 1`.
///
/// `field` is inserted verbatim. Without GROUP BY the aggregate yields exactly one
/// row, so any configured offset is discarded (otherwise the sum would be skipped
/// and `None` returned). Returns `Ok(None)` when no row comes back or when SUM is
/// NULL (no matching rows, or all values NULL).
///
/// # Errors
/// SQL-generation errors, plus any database or decoding error from sqlx. The latter
/// is logged at error level regardless of `enable_logging`.
pub async fn sum(self, field: &str) -> Result<Option<f64>, crate::error::DbError> {
    if self.enable_logging {
        log::debug!("执行 sum() 查询,字段: {}", field);
    }
    let mut builder = self;
    builder.fields.clear();
    builder
        .fields
        .push(format!("CAST(SUM({}) AS DOUBLE)", field));
    builder.limit = Some(1);
    if builder.group_by.is_empty() {
        builder.offset = None;
    }
    let mut generator = SqlGenerator::new();
    generator.build_select(&builder)?;
    let sql = generator.get_sql();
    let params = generator.get_params();
    if builder.enable_logging {
        log::debug!("执行 sum() 查询: {}", sql);
        log::debug!("参数: {:?}", params);
    }
    let mut query = sqlx::query_scalar::<_, Option<f64>>(sql);
    for param in params {
        query = bind_scalar_param_option(query, param);
    }
    let result = query.fetch_optional(builder.pool).await;
    match result {
        Ok(Some(value)) => {
            if builder.enable_logging {
                if value.is_some() {
                    log::debug!("sum() 查询成功,返回总和");
                } else {
                    log::debug!("sum() 查询成功,返回 None(没有匹配记录或所有值为 NULL)");
                }
            }
            Ok(value)
        }
        Ok(None) => {
            if builder.enable_logging {
                log::debug!("sum() 查询成功,未找到匹配记录");
            }
            Ok(None)
        }
        Err(e) => {
            log::error!("sum() 查询失败: {}", e);
            Err(crate::error::DbError::from(e))
        }
    }
}
pub async fn insert<T>(self, data: &T) -> Result<u64, crate::error::DbError>
where
T: serde::Serialize,
{
if self.enable_logging {
log::debug!("执行 insert() 操作,表: {}", self.table);
}
let json_data = serde_json::to_value(data).map_err(|e| {
crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
})?;
let mut generator = SqlGenerator::new();
generator.build_insert(&self.table, &json_data, &self.field_types)?;
let sql = generator.get_sql();
let params = generator.get_params();
if self.enable_logging {
log::debug!("执行 insert() SQL: {}", sql);
log::debug!("参数: {:?}", params);
}
let mut query = sqlx::query(sql);
for param in params {
query = bind_execute_param(query, param);
}
let result = query.execute(self.pool).await;
match result {
Ok(query_result) => {
let last_insert_id = query_result.last_insert_id();
if self.enable_logging {
log::debug!("insert() 成功,插入 ID: {}", last_insert_id);
}
Ok(last_insert_id)
}
Err(e) => {
log::error!("insert() 失败: {}", e);
Err(crate::error::DbError::from(e))
}
}
}
pub async fn insert_batch<T>(self, data: &[T]) -> Result<u64, crate::error::DbError>
where
T: serde::Serialize,
{
if self.enable_logging {
log::debug!(
"执行 insert_batch() 操作,表: {},记录数: {}",
self.table,
data.len()
);
}
if data.is_empty() {
return Err(crate::error::DbError::SerializationError(
"批量插入数据不能为空".to_string(),
));
}
let json_data_list: Result<Vec<_>, _> = data
.iter()
.map(|item| {
serde_json::to_value(item).map_err(|e| {
crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
})
})
.collect();
let json_data_list = json_data_list?;
let mut generator = SqlGenerator::new();
generator.build_insert_batch(&self.table, &json_data_list, &self.field_types)?;
let sql = generator.get_sql();
let params = generator.get_params();
if self.enable_logging {
log::debug!("执行 insert_batch() SQL: {}", sql);
log::debug!("参数数量: {}", params.len());
}
let mut query = sqlx::query(sql);
for param in params {
query = bind_execute_param(query, param);
}
let result = query.execute(self.pool).await;
match result {
Ok(query_result) => {
let rows_affected = query_result.rows_affected();
if self.enable_logging {
log::debug!("insert_batch() 成功,影响 {} 行", rows_affected);
}
Ok(rows_affected)
}
Err(e) => {
log::error!("insert_batch() 失败: {}", e);
Err(crate::error::DbError::from(e))
}
}
}
pub async fn update<T>(self, data: &T) -> Result<u64, crate::error::DbError>
where
T: serde::Serialize,
{
if self.enable_logging {
log::debug!("执行 update() 操作,表: {}", self.table);
}
if self.conditions.is_empty() {
log::warn!("update() 操作缺少 WHERE 条件,禁止全表更新");
return Err(crate::error::DbError::MissingWhereClause);
}
let json_data = serde_json::to_value(data).map_err(|e| {
crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
})?;
let mut generator = SqlGenerator::new();
generator.build_update(&self.table, &json_data, &self.field_types, &self.conditions)?;
let sql = generator.get_sql();
let params = generator.get_params();
if self.enable_logging {
log::debug!("执行 update() SQL: {}", sql);
log::debug!("参数: {:?}", params);
}
let mut query = sqlx::query(sql);
for param in params {
query = bind_execute_param(query, param);
}
let result = query.execute(self.pool).await;
match result {
Ok(query_result) => {
let rows_affected = query_result.rows_affected();
if self.enable_logging {
log::debug!("update() 成功,影响 {} 行", rows_affected);
}
Ok(rows_affected)
}
Err(e) => {
log::error!("update() 失败: {}", e);
Err(crate::error::DbError::from(e))
}
}
}
pub async fn delete(self) -> Result<u64, crate::error::DbError> {
if self.enable_logging {
log::debug!("执行 delete() 操作,表: {}", self.table);
}
if self.conditions.is_empty() {
log::warn!("delete() 操作缺少 WHERE 条件,禁止全表删除");
return Err(crate::error::DbError::MissingWhereClause);
}
let mut generator = SqlGenerator::new();
generator.build_delete(&self.table, &self.conditions)?;
let sql = generator.get_sql();
let params = generator.get_params();
if self.enable_logging {
log::debug!("执行 delete() SQL: {}", sql);
log::debug!("参数: {:?}", params);
}
let mut query = sqlx::query(sql);
for param in params {
query = bind_execute_param(query, param);
}
let result = query.execute(self.pool).await;
match result {
Ok(query_result) => {
let rows_affected = query_result.rows_affected();
if self.enable_logging {
log::debug!("delete() 成功,影响 {} 行", rows_affected);
}
Ok(rows_affected)
}
Err(e) => {
log::error!("delete() 失败: {}", e);
Err(crate::error::DbError::from(e))
}
}
}
}
fn bind_execute_param<'q>(
query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
param: &SqlValue,
) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
match param {
SqlValue::Null => query.bind(Option::<i32>::None),
SqlValue::Bool(b) => query.bind(*b),
SqlValue::Int(i) => query.bind(*i),
SqlValue::Float(f) => query.bind(*f),
SqlValue::String(s) => query.bind(s.clone()),
SqlValue::Bytes(b) => query.bind(b.clone()),
SqlValue::Json(j) => query.bind(j.to_string()),
SqlValue::DateTime(dt) => query.bind(*dt),
SqlValue::Timestamp(ts) => query.bind(*ts),
}
}
fn bind_param<'q, T>(
query: sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
param: &SqlValue,
) -> sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>
where
T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
{
match param {
SqlValue::Null => query.bind(Option::<i32>::None),
SqlValue::Bool(b) => query.bind(*b),
SqlValue::Int(i) => query.bind(*i),
SqlValue::Float(f) => query.bind(*f),
SqlValue::String(s) => query.bind(s.clone()),
SqlValue::Bytes(b) => query.bind(b.clone()),
SqlValue::Json(j) => query.bind(j.to_string()),
SqlValue::DateTime(dt) => query.bind(*dt),
SqlValue::Timestamp(ts) => query.bind(*ts),
}
}
fn bind_scalar_param<'q, T>(
query: sqlx::query::QueryScalar<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
param: &SqlValue,
) -> sqlx::query::QueryScalar<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>
where
T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
{
match param {
SqlValue::Null => query.bind(Option::<i32>::None),
SqlValue::Bool(b) => query.bind(*b),
SqlValue::Int(i) => query.bind(*i),
SqlValue::Float(f) => query.bind(*f),
SqlValue::String(s) => query.bind(s.clone()),
SqlValue::Bytes(b) => query.bind(b.clone()),
SqlValue::Json(j) => query.bind(j.to_string()),
SqlValue::DateTime(dt) => query.bind(*dt),
SqlValue::Timestamp(ts) => query.bind(*ts),
}
}
fn bind_scalar_param_option<'q, T>(
query: sqlx::query::QueryScalar<'q, sqlx::MySql, Option<T>, sqlx::mysql::MySqlArguments>,
param: &SqlValue,
) -> sqlx::query::QueryScalar<'q, sqlx::MySql, Option<T>, sqlx::mysql::MySqlArguments>
where
T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
{
match param {
SqlValue::Null => query.bind(Option::<i32>::None),
SqlValue::Bool(b) => query.bind(*b),
SqlValue::Int(i) => query.bind(*i),
SqlValue::Float(f) => query.bind(*f),
SqlValue::String(s) => query.bind(s.clone()),
SqlValue::Bytes(b) => query.bind(b.clone()),
SqlValue::Json(j) => query.bind(j.to_string()),
SqlValue::DateTime(dt) => query.bind(*dt),
SqlValue::Timestamp(ts) => query.bind(*ts),
}
}
#[cfg(test)]
mod tests {
use super::*;
use sqlx::mysql::MySqlPoolOptions;
async fn create_test_pool() -> MySqlPool {
MySqlPoolOptions::new()
.max_connections(1)
.connect("mysql://root:111111@localhost:3306/test")
.await
.expect("无法连接到测试数据库")
}
#[tokio::test]
async fn test_table_name_in_sql() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false);
let sql = builder.to_sql();
assert!(sql.contains("FROM users"));
}
#[test]
fn test_sql_generator_new() {
let generator = SqlGenerator::new();
assert_eq!(generator.get_sql(), "");
assert_eq!(generator.get_params().len(), 0);
}
#[test]
fn test_sql_generator_append() {
let mut generator = SqlGenerator::new();
generator.append("SELECT * FROM users");
assert_eq!(generator.get_sql(), "SELECT * FROM users");
}
#[test]
fn test_sql_generator_add_param() {
let mut generator = SqlGenerator::new();
generator.add_param(SqlValue::Int(42));
generator.add_param(SqlValue::String("test".to_string()));
assert_eq!(generator.get_params().len(), 2);
}
#[test]
fn test_sql_generator_clear() {
let mut generator = SqlGenerator::new();
generator.append("SELECT * FROM users");
generator.add_param(SqlValue::Int(1));
generator.clear();
assert_eq!(generator.get_sql(), "");
assert_eq!(generator.get_params().len(), 0);
}
#[test]
fn test_sql_generator_multiple_operations() {
let mut generator = SqlGenerator::new();
generator.append("SELECT * FROM users WHERE id = ?");
generator.add_param(SqlValue::Int(1));
generator.append(" AND name = ?");
generator.add_param(SqlValue::String("test".to_string()));
assert_eq!(
generator.get_sql(),
"SELECT * FROM users WHERE id = ? AND name = ?"
);
assert_eq!(generator.get_params().len(), 2);
}
#[tokio::test]
async fn test_field_selection() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.field("id")
.field("name");
let sql = builder.to_sql();
assert!(sql.contains("id, name"));
}
#[tokio::test]
async fn test_fields_selection() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false).fields(&["id", "name", "email"]);
let sql = builder.to_sql();
assert!(sql.contains("id, name, email"));
}
#[tokio::test]
async fn test_distinct() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.field("name")
.distinct();
let sql = builder.to_sql();
assert!(sql.contains("SELECT DISTINCT"));
}
#[tokio::test]
async fn test_field_type_marking() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.json("data")
.datetime("created_at")
.timestamp("updated_at")
.decimal("price")
.blob("content")
.text("description");
assert_eq!(builder.field_types.get("data"), Some(&FieldType::Json));
assert_eq!(
builder.field_types.get("created_at"),
Some(&FieldType::DateTime)
);
assert_eq!(
builder.field_types.get("updated_at"),
Some(&FieldType::Timestamp)
);
assert_eq!(builder.field_types.get("price"), Some(&FieldType::Decimal));
assert_eq!(builder.field_types.get("content"), Some(&FieldType::Blob));
assert_eq!(
builder.field_types.get("description"),
Some(&FieldType::Text)
);
}
#[tokio::test]
async fn test_where_and() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.where_and("name", "=", "test")
.where_and("age", ">", 18);
assert_eq!(builder.conditions.len(), 2);
}
#[tokio::test]
async fn test_where_or() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.where_or("status", "=", 1)
.where_or("status", "=", 2);
assert_eq!(builder.conditions.len(), 1);
}
#[tokio::test]
async fn test_where_in() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false).where_in("id", vec![1, 2, 3]);
assert_eq!(builder.conditions.len(), 1);
}
#[tokio::test]
async fn test_where_between() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false).where_between("age", 18, 65);
assert_eq!(builder.conditions.len(), 1);
}
#[tokio::test]
async fn test_join() {
let pool = create_test_pool().await;
let builder =
QueryBuilder::new(&pool, "users", false).join("orders", "users.id = orders.user_id");
assert_eq!(builder.joins.len(), 1);
}
#[tokio::test]
async fn test_left_join() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.left_join("orders", "users.id = orders.user_id");
assert_eq!(builder.joins.len(), 1);
}
#[tokio::test]
async fn test_right_join() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.right_join("orders", "users.id = orders.user_id");
assert_eq!(builder.joins.len(), 1);
}
#[tokio::test]
async fn test_order() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.order("name", true)
.order("age", false);
assert_eq!(builder.order_by.len(), 2);
}
#[tokio::test]
async fn test_group() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.group("status")
.group("role");
assert_eq!(builder.group_by.len(), 2);
}
#[tokio::test]
async fn test_select_with_where() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.field("id")
.field("name")
.where_and("status", "=", 1);
let sql = builder.to_sql();
assert!(sql.contains("SELECT id, name FROM users"));
assert!(sql.contains("WHERE"));
}
#[tokio::test]
async fn test_select_with_join() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.field("users.id")
.field("orders.total")
.join("orders", "users.id = orders.user_id");
let sql = builder.to_sql();
assert!(sql.contains("SELECT users.id, orders.total FROM users"));
assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
}
#[tokio::test]
async fn test_select_with_order_by() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.field("name")
.order("name", true)
.order("age", false);
let sql = builder.to_sql();
assert!(sql.contains("ORDER BY name ASC, age DESC"));
}
#[tokio::test]
async fn test_select_with_group_by() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.field("status")
.group("status");
let sql = builder.to_sql();
assert!(sql.contains("GROUP BY status"));
}
#[tokio::test]
async fn test_select_with_limit_offset() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.field("id")
.limit(10)
.offset(20);
let sql = builder.to_sql();
assert!(sql.contains("LIMIT 10"));
assert!(sql.contains("OFFSET 20"));
}
#[tokio::test]
async fn test_select_complex_query() {
let pool = create_test_pool().await;
let builder = QueryBuilder::new(&pool, "users", false)
.field("users.id")
.field("users.name")
.field("orders.total")
.distinct()
.join("orders", "users.id = orders.user_id")
.where_and("users.status", "=", 1)
.where_and("orders.total", ">", 100)
.group("users.id")
.order("orders.total", false)
.limit(50);
let sql = builder.to_sql();
assert!(sql.contains("SELECT DISTINCT"));
assert!(sql.contains("users.id, users.name, orders.total"));
assert!(sql.contains("FROM users"));
assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
assert!(sql.contains("WHERE"));
assert!(sql.contains("GROUP BY users.id"));
assert!(sql.contains("ORDER BY orders.total DESC"));
assert!(sql.contains("LIMIT 50"));
}
#[tokio::test]
async fn test_select_with_multiple_joins() {
    let pool = create_test_pool().await;
    // `join` produces an INNER JOIN, `left_join` a LEFT JOIN; both must be kept.
    let query = QueryBuilder::new(&pool, "users", false)
        .field("users.name")
        .field("orders.total")
        .field("products.name")
        .join("orders", "users.id = orders.user_id")
        .left_join("products", "orders.product_id = products.id");
    let rendered = query.to_sql();
    for join_clause in [
        "INNER JOIN orders ON users.id = orders.user_id",
        "LEFT JOIN products ON orders.product_id = products.id",
    ] {
        assert!(rendered.contains(join_clause));
    }
}
#[tokio::test]
async fn test_select_with_in_condition() {
    let pool = create_test_pool().await;
    let ids = vec![1, 2, 3, 4, 5];
    let query = QueryBuilder::new(&pool, "users", false)
        .field("name")
        .where_in("id", ids);
    let rendered = query.to_sql();
    // A list membership filter must produce a WHERE clause using IN.
    for keyword in ["WHERE", "IN"] {
        assert!(rendered.contains(keyword));
    }
}
#[tokio::test]
async fn test_select_with_between_condition() {
    let pool = create_test_pool().await;
    let query = QueryBuilder::new(&pool, "users", false)
        .field("name")
        .where_between("age", 18, 65);
    let rendered = query.to_sql();
    // A range filter must produce a WHERE clause using BETWEEN.
    for keyword in ["WHERE", "BETWEEN"] {
        assert!(rendered.contains(keyword));
    }
}
#[tokio::test]
async fn test_sql_generator_build_select_basic() {
    let pool = create_test_pool().await;
    let query = QueryBuilder::new(&pool, "users", false)
        .field("id")
        .field("name");
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_select(&query).is_ok());
    // Explicit fields are comma-joined and no optional clause is emitted.
    assert_eq!(sql_gen.get_sql(), "SELECT id, name FROM users");
}
#[tokio::test]
async fn test_sql_generator_build_select_with_distinct() {
    let pool = create_test_pool().await;
    let query = QueryBuilder::new(&pool, "users", false)
        .field("name")
        .distinct();
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_select(&query).is_ok());
    // DISTINCT is placed immediately after SELECT, before the field list.
    assert_eq!(sql_gen.get_sql(), "SELECT DISTINCT name FROM users");
}
#[tokio::test]
async fn test_sql_generator_build_select_all_fields() {
    let pool = create_test_pool().await;
    // No `.field(..)` calls: the generator must fall back to `*`.
    let query = QueryBuilder::new(&pool, "users", false);
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_select(&query).is_ok());
    assert_eq!(sql_gen.get_sql(), "SELECT * FROM users");
}
#[tokio::test]
async fn test_sql_generator_build_where() {
    let pool = create_test_pool().await;
    let query = QueryBuilder::new(&pool, "users", false)
        .where_and("status", "=", 1)
        .where_and("age", ">", 18);
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_select(&query).is_ok());
    let rendered = sql_gen.get_sql();
    // Both filtered columns must end up in the WHERE clause.
    for fragment in ["WHERE", "status", "age"] {
        assert!(rendered.contains(fragment));
    }
}
#[tokio::test]
async fn test_sql_generator_build_joins() {
    let pool = create_test_pool().await;
    let query = QueryBuilder::new(&pool, "users", false)
        .join("orders", "users.id = orders.user_id")
        .left_join("profiles", "users.id = profiles.user_id");
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_select(&query).is_ok());
    let rendered = sql_gen.get_sql();
    // Each join keeps its own kind and ON condition.
    for join_clause in [
        "INNER JOIN orders ON users.id = orders.user_id",
        "LEFT JOIN profiles ON users.id = profiles.user_id",
    ] {
        assert!(rendered.contains(join_clause));
    }
}
#[tokio::test]
async fn test_sql_generator_build_order_by() {
    let pool = create_test_pool().await;
    let query = QueryBuilder::new(&pool, "users", false)
        .order("name", true)
        .order("created_at", false);
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_select(&query).is_ok());
    // Sort keys are emitted in insertion order with their directions.
    assert!(sql_gen
        .get_sql()
        .contains("ORDER BY name ASC, created_at DESC"));
}
#[tokio::test]
async fn test_sql_generator_build_group_by() {
    let pool = create_test_pool().await;
    let query = QueryBuilder::new(&pool, "users", false)
        .group("status")
        .group("role");
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_select(&query).is_ok());
    // Grouping columns are comma-joined in insertion order.
    assert!(sql_gen.get_sql().contains("GROUP BY status, role"));
}
#[tokio::test]
async fn test_sql_generator_build_limit_offset() {
    let pool = create_test_pool().await;
    let query = QueryBuilder::new(&pool, "users", false)
        .limit(10)
        .offset(20);
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_select(&query).is_ok());
    let rendered = sql_gen.get_sql();
    for clause in ["LIMIT 10", "OFFSET 20"] {
        assert!(rendered.contains(clause));
    }
}
#[tokio::test]
async fn test_sql_generator_complex_query() {
    let pool = create_test_pool().await;
    let query = QueryBuilder::new(&pool, "users", false)
        .field("users.id")
        .field("users.name")
        .field("COUNT(orders.id) as order_count")
        .distinct()
        .join("orders", "users.id = orders.user_id")
        .where_and("users.status", "=", 1)
        .where_and("orders.total", ">", 100)
        .group("users.id")
        .group("users.name")
        .order("order_count", false)
        .limit(20)
        .offset(10);
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_select(&query).is_ok());
    let rendered = sql_gen.get_sql();
    // The statement must open with the DISTINCT projection...
    assert!(rendered.starts_with("SELECT DISTINCT"));
    // ...and contain every other configured clause.
    let expected_fragments = [
        "users.id, users.name, COUNT(orders.id) as order_count",
        "FROM users",
        "INNER JOIN orders ON users.id = orders.user_id",
        "WHERE",
        "GROUP BY users.id, users.name",
        "ORDER BY order_count DESC",
        "LIMIT 20",
        "OFFSET 10",
    ];
    for fragment in expected_fragments {
        assert!(rendered.contains(fragment));
    }
}
#[tokio::test]
async fn test_find_adds_limit_one() {
    let pool = create_test_pool().await;
    // A freshly built query carries no LIMIT until one is requested.
    let unlimited = QueryBuilder::new(&pool, "users", false)
        .field("id")
        .field("name")
        .where_and("id", "=", 1);
    assert_eq!(unlimited.limit, None);
    // NOTE(review): this applies `.limit(1)` by hand instead of going through
    // `find()`, so it only checks that a LIMIT of 1 renders into the SQL; it
    // does not exercise `find()` itself.
    let single_row = QueryBuilder::new(&pool, "users", false)
        .field("id")
        .field("name")
        .where_and("id", "=", 1)
        .limit(1);
    let rendered = single_row.to_sql();
    assert!(rendered.contains("LIMIT 1"), "find() 应该自动添加 LIMIT 1");
}
#[test]
fn test_sql_generator_build_insert_basic() {
    // Three plain columns with no special field types: one bound parameter each.
    let payload = serde_json::json!({
        "name": "张三",
        "age": 25,
        "email": "zhangsan@example.com"
    });
    let no_special_types = HashMap::new();
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen
        .build_insert("users", &payload, &no_special_types)
        .is_ok());
    let rendered = sql_gen.get_sql();
    assert!(rendered.starts_with("INSERT INTO users"));
    for fragment in ["name", "age", "email", "VALUES"] {
        assert!(rendered.contains(fragment));
    }
    assert_eq!(sql_gen.get_params().len(), 3);
}
#[test]
fn test_sql_generator_build_insert_with_json_field() {
    // `data` is declared as a JSON column, so its nested object must be bound
    // as a `SqlValue::Json` parameter.
    let payload = serde_json::json!({
        "name": "测试用户",
        "data": {"role": "admin", "permissions": ["read", "write"]}
    });
    let mut declared_types = HashMap::new();
    declared_types.insert("data".to_string(), FieldType::Json);
    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen
        .build_insert("users", &payload, &declared_types)
        .is_ok());
    let rendered = sql_gen.get_sql();
    for fragment in ["INSERT INTO users", "name", "data"] {
        assert!(rendered.contains(fragment));
    }
    let bound = sql_gen.get_params();
    assert_eq!(bound.len(), 2);
    let json_bound = bound.iter().any(|value| matches!(value, SqlValue::Json(_)));
    assert!(json_bound, "应该包含 JSON 类型的参数");
}
#[test]
fn test_sql_generator_build_insert_empty_data() {
    // An object with no keys has nothing to insert and must be rejected.
    let empty_object = serde_json::json!({});
    let no_special_types = HashMap::new();
    let mut sql_gen = SqlGenerator::new();
    let outcome = sql_gen.build_insert("users", &empty_object, &no_special_types);
    assert!(matches!(
        outcome,
        Err(crate::error::DbError::SerializationError(_))
    ));
}
#[test]
fn test_sql_generator_build_insert_not_object() {
    // Only JSON objects map to column/value pairs; an array must be rejected.
    let array_payload = serde_json::json!([1, 2, 3]);
    let no_special_types = HashMap::new();
    let mut sql_gen = SqlGenerator::new();
    let outcome = sql_gen.build_insert("users", &array_payload, &no_special_types);
    assert!(matches!(
        outcome,
        Err(crate::error::DbError::SerializationError(_))
    ));
}
}
#[cfg(test)]
mod property_tests {
use super::*;
use proptest::prelude::*;
use sqlx::mysql::MySqlPoolOptions;
/// Strategy producing table names: one lowercase ASCII letter followed by up
/// to 30 lowercase letters, digits or underscores.
fn table_name_strategy() -> impl Strategy<Value = String> {
    const TABLE_NAME_PATTERN: &str = "[a-z][a-z0-9_]{0,30}";
    TABLE_NAME_PATTERN
}
/// Strategy producing column names: one lowercase ASCII letter followed by up
/// to 30 lowercase letters, digits or underscores.
fn field_name_strategy() -> impl Strategy<Value = String> {
    const FIELD_NAME_PATTERN: &str = "[a-z][a-z0-9_]{0,30}";
    FIELD_NAME_PATTERN
}
/// Builds a MySQL pool for the synchronous property tests.
///
/// The property tests in this module only build queries and render them with
/// `to_sql()`; none of them executes a statement. The pool is therefore created
/// lazily (`connect_lazy`), so no connection is opened and the tests do not
/// need a reachable MySQL server. A lightweight current-thread runtime supplies
/// the Tokio context that sqlx may require while constructing the pool.
///
/// # Panics
/// Panics if the Tokio runtime cannot be built or the connection URL is invalid.
fn create_test_pool_sync() -> MySqlPool {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .expect("无法创建 Tokio 运行时");
    runtime.block_on(async {
        // `connect_lazy` defers connecting until first use, which these tests
        // never trigger, so only the URL itself has to be valid.
        MySqlPoolOptions::new()
            .max_connections(1)
            .connect_lazy("mysql://root:111111@localhost:3306/test")
            .expect("测试数据库 URL 无效")
    })
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // A bare builder must render `FROM <table>` for any generated table name.
    #[test]
    fn prop_table_name_in_sql(table_name in table_name_strategy()) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false);
        let rendered = query.to_sql();
        let from_clause = format!("FROM {}", table_name);
        prop_assert!(rendered.contains(&from_clause));
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Two builders sharing one pool must each render their own table, and the
    // first table must not leak into the second query as a FROM target.
    #[test]
    fn prop_table_name_override(
        table_name1 in table_name_strategy(),
        table_name2 in table_name_strategy()
    ) {
        prop_assume!(table_name1 != table_name2);
        let pool = create_test_pool_sync();

        let first_query = QueryBuilder::new(&pool, &table_name1, false);
        let first_sql = first_query.to_sql();
        prop_assert!(first_sql.contains(&format!("FROM {}", table_name1)));

        let second_query = QueryBuilder::new(&pool, &table_name2, false);
        let second_sql = second_query.to_sql();
        prop_assert!(second_sql.contains(&format!("FROM {}", table_name2)));

        // A trailing space/newline ensures we match the whole first name, not
        // a prefix of the second one.
        let leaked_with_space = format!("FROM {} ", table_name1);
        let leaked_with_newline = format!("FROM {}\n", table_name1);
        prop_assert!(!(second_sql.contains(&leaked_with_space)
            || second_sql.contains(&leaked_with_newline)));
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Every field passed to `.field(..)` must appear in the rendered SQL.
    #[test]
    fn prop_field_selection(
        table_name in table_name_strategy(),
        fields in prop::collection::vec(field_name_strategy(), 1..10)
    ) {
        let pool = create_test_pool_sync();
        let query = fields.iter().fold(
            QueryBuilder::new(&pool, &table_name, false),
            |acc, column| acc.field(column),
        );
        let rendered = query.to_sql();
        prop_assert!(fields.iter().all(|column| rendered.contains(column.as_str())));
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Calling `.distinct()` must put DISTINCT directly after SELECT.
    #[test]
    fn prop_distinct_keyword(
        table_name in table_name_strategy(),
        field in field_name_strategy()
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false)
            .field(&field)
            .distinct();
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("SELECT DISTINCT"));
    }
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_special_field_type_marking(
table_name in table_name_strategy(),
json_field in field_name_strategy(),
datetime_field in field_name_strategy(),
timestamp_field in field_name_strategy(),
decimal_field in field_name_strategy(),
blob_field in field_name_strategy(),
text_field in field_name_strategy()
) {
prop_assume!(json_field != datetime_field);
prop_assume!(json_field != timestamp_field);
prop_assume!(json_field != decimal_field);
prop_assume!(json_field != blob_field);
prop_assume!(json_field != text_field);
prop_assume!(datetime_field != timestamp_field);
prop_assume!(datetime_field != decimal_field);
prop_assume!(datetime_field != blob_field);
prop_assume!(datetime_field != text_field);
prop_assume!(timestamp_field != decimal_field);
prop_assume!(timestamp_field != blob_field);
prop_assume!(timestamp_field != text_field);
prop_assume!(decimal_field != blob_field);
prop_assume!(decimal_field != text_field);
prop_assume!(blob_field != text_field);
let pool = create_test_pool_sync();
let builder = QueryBuilder::new(&pool, &table_name, false)
.json(&json_field)
.datetime(&datetime_field)
.timestamp(×tamp_field)
.decimal(&decimal_field)
.blob(&blob_field)
.text(&text_field);
prop_assert_eq!(builder.field_types.get(&json_field), Some(&FieldType::Json));
prop_assert_eq!(builder.field_types.get(&datetime_field), Some(&FieldType::DateTime));
prop_assert_eq!(builder.field_types.get(×tamp_field), Some(&FieldType::Timestamp));
prop_assert_eq!(builder.field_types.get(&decimal_field), Some(&FieldType::Decimal));
prop_assert_eq!(builder.field_types.get(&blob_field), Some(&FieldType::Blob));
prop_assert_eq!(builder.field_types.get(&text_field), Some(&FieldType::Text));
}
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // A single `where_and` call stores exactly one condition.
    #[test]
    fn prop_where_and_condition_added(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        value in any::<i32>()
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false).where_and(&field, "=", value);
        prop_assert_eq!(query.conditions.len(), 1);
    }
    // Two consecutive `where_or` calls end up as one stored condition entry
    // rather than two separate ones.
    #[test]
    fn prop_where_or_condition_added(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        value1 in any::<i32>(),
        value2 in any::<i32>()
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_or(&field, "=", value1)
            .where_or(&field, "=", value2);
        prop_assert_eq!(query.conditions.len(), 1);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // `where_in` with any non-empty value list stores one condition.
    #[test]
    fn prop_in_operator_array_support(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        values in prop::collection::vec(any::<i32>(), 1..10)
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false).where_in(&field, values);
        let stored = query.conditions.len();
        prop_assert_eq!(stored, 1);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // `where_between` with arbitrary bounds (even start > end) stores one condition.
    #[test]
    fn prop_between_operator_boundary_support(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        start in any::<i32>(),
        end in any::<i32>()
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false).where_between(&field, start, end);
        let stored = query.conditions.len();
        prop_assert_eq!(stored, 1);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Each `where_and` call appends its own condition.
    #[test]
    fn prop_multiple_and_conditions(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        values in prop::collection::vec(any::<i32>(), 2..5)
    ) {
        let pool = create_test_pool_sync();
        let query = values.iter().fold(
            QueryBuilder::new(&pool, &table_name, false),
            |acc, &value| acc.where_and(&field, "=", value),
        );
        prop_assert_eq!(query.conditions.len(), values.len());
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Each join flavour (inner, left, right) records exactly one join.
    #[test]
    fn prop_join_clause_generation(
        table_name in table_name_strategy(),
        join_table in table_name_strategy(),
        on_condition in "[a-z][a-z0-9_]{0,20}\\.[a-z][a-z0-9_]{0,20} = [a-z][a-z0-9_]{0,20}\\.[a-z][a-z0-9_]{0,20}"
    ) {
        let pool = create_test_pool_sync();
        let inner = QueryBuilder::new(&pool, &table_name, false)
            .join(&join_table, &on_condition);
        let left = QueryBuilder::new(&pool, &table_name, false)
            .left_join(&join_table, &on_condition);
        let right = QueryBuilder::new(&pool, &table_name, false)
            .right_join(&join_table, &on_condition);
        for joined in [&inner, &left, &right] {
            prop_assert_eq!(joined.joins.len(), 1);
        }
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Chained `.join(..)` calls accumulate one join entry per call.
    #[test]
    fn prop_multiple_join_support(
        table_name in table_name_strategy(),
        join_tables in prop::collection::vec(table_name_strategy(), 1..5)
    ) {
        let pool = create_test_pool_sync();
        let query = join_tables.iter().fold(
            QueryBuilder::new(&pool, &table_name, false),
            |acc, other| {
                let on_clause = format!("{}.id = {}.id", table_name, other);
                acc.join(other, &on_clause)
            },
        );
        prop_assert_eq!(query.joins.len(), join_tables.len());
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // `table AS alias` strings for the base and the joined table, plus
    // alias-qualified fields and ON conditions, must pass through verbatim.
    #[test]
    fn prop_table_alias_support(
        base_table in table_name_strategy(),
        join_table in table_name_strategy(),
        base_alias in "[a-z][a-z0-9]{0,5}",
        join_alias in "[a-z][a-z0-9]{0,5}"
    ) {
        prop_assume!(base_table != join_table);
        prop_assume!(base_alias != join_alias);
        let pool = create_test_pool_sync();

        let aliased_base = format!("{} AS {}", base_table, base_alias);
        let aliased_join = format!("{} AS {}", join_table, join_alias);
        let on_clause = format!("{}.id = {}.id", base_alias, join_alias);
        let id_column = format!("{}.id", base_alias);
        let name_column = format!("{}.name", base_alias);

        let query = QueryBuilder::new(&pool, &aliased_base, false)
            .field(&id_column)
            .field(&name_column)
            .join(&aliased_join, &on_clause);
        let rendered = query.to_sql();

        prop_assert!(rendered.contains(&format!("FROM {}", aliased_base)),
            "SQL 应该包含带别名的主表: FROM {}", aliased_base);
        prop_assert!(rendered.contains(&aliased_join),
            "SQL 应该包含带别名的 JOIN 表: {}", aliased_join);
        prop_assert!(rendered.contains(&on_clause),
            "SQL 应该包含使用别名的 ON 条件: {}", on_clause);
        prop_assert!(rendered.contains(&id_column),
            "SQL 应该包含使用别名的字段: {}.id", base_alias);
        prop_assert!(rendered.contains(&name_column),
            "SQL 应该包含使用别名的字段: {}.name", base_alias);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // One `.order(..)` call stores one clause with the given field and direction.
    #[test]
    fn prop_order_by_clause_generation(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        asc in any::<bool>()
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false).order(&field, asc);
        prop_assert_eq!(query.order_by.len(), 1);
        let first_clause = &query.order_by[0];
        prop_assert_eq!(&first_clause.field, &field);
        prop_assert_eq!(first_clause.asc, asc);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Every `.order(..)` call appends its own ORDER BY entry.
    #[test]
    fn prop_multiple_order_by_support(
        table_name in table_name_strategy(),
        fields in prop::collection::vec(field_name_strategy(), 1..5)
    ) {
        let pool = create_test_pool_sync();
        let query = fields.iter().fold(
            QueryBuilder::new(&pool, &table_name, false),
            |acc, column| acc.order(column, true),
        );
        prop_assert_eq!(query.order_by.len(), fields.len());
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // One `.group(..)` call stores exactly that column.
    #[test]
    fn prop_group_by_clause_generation(
        table_name in table_name_strategy(),
        field in field_name_strategy()
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false).group(&field);
        prop_assert_eq!(query.group_by.len(), 1);
        let first_group = &query.group_by[0];
        prop_assert_eq!(first_group, &field);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Every `.group(..)` call appends its own GROUP BY column.
    #[test]
    fn prop_multiple_group_by_support(
        table_name in table_name_strategy(),
        fields in prop::collection::vec(field_name_strategy(), 1..5)
    ) {
        let pool = create_test_pool_sync();
        let query = fields.iter().fold(
            QueryBuilder::new(&pool, &table_name, false),
            |acc, column| acc.group(column),
        );
        prop_assert_eq!(query.group_by.len(), fields.len());
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Random combinations of fields / DISTINCT / LIMIT / OFFSET must always
    // render a SELECT ... FROM statement reflecting every option that was set.
    #[test]
    fn prop_to_sql_returns_valid_sql(
        table_name in table_name_strategy(),
        fields in prop::collection::vec(field_name_strategy(), 0..5),
        use_distinct in any::<bool>(),
        limit_opt in prop::option::of(1u64..100),
        offset_opt in prop::option::of(0u64..100)
    ) {
        let pool = create_test_pool_sync();
        let mut builder = QueryBuilder::new(&pool, &table_name, false);
        for field in &fields {
            builder = builder.field(field);
        }
        if use_distinct {
            builder = builder.distinct();
        }
        if let Some(limit) = limit_opt {
            builder = builder.limit(limit);
        }
        if let Some(offset) = offset_opt {
            builder = builder.offset(offset);
        }
        let sql = builder.to_sql();
        prop_assert!(!sql.is_empty(), "SQL 字符串不应为空");
        prop_assert!(sql.contains("SELECT"), "SQL 应包含 SELECT 关键字");
        prop_assert!(sql.contains("FROM"), "SQL 应包含 FROM 关键字");
        prop_assert!(sql.contains(&table_name), "SQL 应包含表名");
        if use_distinct {
            prop_assert!(sql.contains("DISTINCT"), "SQL 应包含 DISTINCT 关键字");
        }
        if let Some(limit) = limit_opt {
            prop_assert!(sql.contains("LIMIT"), "SQL 应包含 LIMIT 关键字");
            // Match keyword and value together: a bare digit search could be
            // satisfied by digits in the generated table or field names.
            prop_assert!(sql.contains(&format!("LIMIT {}", limit)), "SQL 应包含 LIMIT 值");
        }
        if let Some(offset) = offset_opt {
            prop_assert!(sql.contains("OFFSET"), "SQL 应包含 OFFSET 关键字");
            // Same reasoning as for LIMIT above.
            prop_assert!(sql.contains(&format!("OFFSET {}", offset)), "SQL 应包含 OFFSET 值");
        }
        if !fields.is_empty() {
            for field in &fields {
                prop_assert!(sql.contains(field), "SQL 应包含字段 {}", field);
            }
        } else {
            prop_assert!(sql.contains("*"), "SQL 应包含 * 表示所有字段");
        }
    }
    // A single equality filter must produce a WHERE clause.
    #[test]
    fn prop_to_sql_with_conditions(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        value in any::<i32>()
    ) {
        let pool = create_test_pool_sync();
        let builder = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&field, "=", value);
        let sql = builder.to_sql();
        prop_assert!(!sql.is_empty());
        prop_assert!(sql.contains("SELECT"));
        prop_assert!(sql.contains("FROM"));
        prop_assert!(sql.contains(&table_name));
        prop_assert!(sql.contains("WHERE"), "SQL 应包含 WHERE 关键字");
    }
    // A single join must render a JOIN naming the joined table.
    #[test]
    fn prop_to_sql_with_joins(
        table_name in table_name_strategy(),
        join_table in table_name_strategy(),
        on_field1 in field_name_strategy(),
        on_field2 in field_name_strategy()
    ) {
        let pool = create_test_pool_sync();
        let on_condition = format!("{}.{} = {}.{}", table_name, on_field1, join_table, on_field2);
        let builder = QueryBuilder::new(&pool, &table_name, false)
            .join(&join_table, &on_condition);
        let sql = builder.to_sql();
        prop_assert!(!sql.is_empty());
        prop_assert!(sql.contains("SELECT"));
        prop_assert!(sql.contains("FROM"));
        prop_assert!(sql.contains("JOIN"), "SQL 应包含 JOIN 关键字");
        prop_assert!(sql.contains(&join_table), "SQL 应包含连接的表名");
    }
    // ORDER BY and GROUP BY settings must both be rendered with their columns.
    #[test]
    fn prop_to_sql_with_order_and_group(
        table_name in table_name_strategy(),
        order_field in field_name_strategy(),
        group_field in field_name_strategy(),
        asc in any::<bool>()
    ) {
        let pool = create_test_pool_sync();
        let builder = QueryBuilder::new(&pool, &table_name, false)
            .order(&order_field, asc)
            .group(&group_field);
        let sql = builder.to_sql();
        prop_assert!(!sql.is_empty());
        prop_assert!(sql.contains("SELECT"));
        prop_assert!(sql.contains("FROM"));
        prop_assert!(sql.contains("ORDER BY"), "SQL 应包含 ORDER BY 关键字");
        prop_assert!(sql.contains("GROUP BY"), "SQL 应包含 GROUP BY 关键字");
        prop_assert!(sql.contains(&order_field), "SQL 应包含排序字段");
        prop_assert!(sql.contains(&group_field), "SQL 应包含分组字段");
    }
    // A query using every clause must emit them in standard SQL order:
    // SELECT, FROM, JOIN, WHERE, GROUP BY, ORDER BY, LIMIT.
    #[test]
    fn prop_to_sql_complex_query(
        table_name in table_name_strategy(),
        fields in prop::collection::vec(field_name_strategy(), 1..3),
        join_table in table_name_strategy(),
        where_field in field_name_strategy(),
        order_field in field_name_strategy(),
        group_field in field_name_strategy()
    ) {
        let pool = create_test_pool_sync();
        let mut builder = QueryBuilder::new(&pool, &table_name, false);
        for field in &fields {
            builder = builder.field(field);
        }
        let on_condition = format!("{}.id = {}.id", table_name, join_table);
        builder = builder.join(&join_table, &on_condition);
        builder = builder.where_and(&where_field, "=", 1);
        builder = builder.order(&order_field, true);
        builder = builder.group(&group_field);
        builder = builder.limit(10);
        let sql = builder.to_sql();
        prop_assert!(!sql.is_empty());
        prop_assert!(sql.contains("SELECT"));
        prop_assert!(sql.contains("FROM"));
        prop_assert!(sql.contains(&table_name));
        prop_assert!(sql.contains("JOIN"));
        prop_assert!(sql.contains("WHERE"));
        prop_assert!(sql.contains("ORDER BY"));
        prop_assert!(sql.contains("GROUP BY"));
        prop_assert!(sql.contains("LIMIT"));
        // The presence checks above guarantee these `find`s succeed. Generated
        // identifiers are lowercase, so they cannot collide with the keywords.
        let select_pos = sql.find("SELECT").unwrap();
        let from_pos = sql.find("FROM").unwrap();
        let join_pos = sql.find("JOIN").unwrap();
        let where_pos = sql.find("WHERE").unwrap();
        let group_pos = sql.find("GROUP BY").unwrap();
        let order_pos = sql.find("ORDER BY").unwrap();
        let limit_pos = sql.find("LIMIT").unwrap();
        prop_assert!(select_pos < from_pos, "SELECT 应在 FROM 之前");
        prop_assert!(from_pos < join_pos, "FROM 应在 JOIN 之前");
        prop_assert!(join_pos < where_pos, "JOIN 应在 WHERE 之前");
        prop_assert!(where_pos < group_pos, "WHERE 应在 GROUP BY 之前");
        prop_assert!(group_pos < order_pos, "GROUP BY 应在 ORDER BY 之前");
        prop_assert!(order_pos < limit_pos, "ORDER BY 应在 LIMIT 之前");
    }
}
/// Returns the segment of `sql` between the first `WHERE` keyword and the next
/// one (or the end of the string); empty when `sql` has no `WHERE`.
fn injection_where_clause(sql: &str) -> &str {
    sql.split("WHERE").nth(1).unwrap_or("")
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Values containing single quotes must be bound as `?` parameters, never inlined.
    #[test]
    fn prop_sql_injection_prevention_single_quote(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        malicious_input in ".*'.*"
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&field, "=", malicious_input.as_str());
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("?"), "SQL 应该使用参数化查询(? 占位符)");
        let filter = injection_where_clause(&rendered);
        prop_assert!(!filter.contains(&malicious_input),
            "WHERE 子句不应该直接包含用户输入的恶意字符串");
    }
    // Values containing `;` (statement separators) must not be inlined.
    #[test]
    fn prop_sql_injection_prevention_semicolon(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        malicious_input in ".*;.*"
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&field, "=", malicious_input.as_str());
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("?"), "SQL 应该使用参数化查询");
        let filter = injection_where_clause(&rendered);
        prop_assert!(!filter.contains(&malicious_input),
            "WHERE 子句不应该直接包含用户输入的恶意字符串");
    }
    // Values containing `--` (SQL comment markers) must not be inlined.
    #[test]
    fn prop_sql_injection_prevention_comment(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        malicious_input in ".*--.*"
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&field, "=", malicious_input.as_str());
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("?"), "SQL 应该使用参数化查询");
        let filter = injection_where_clause(&rendered);
        prop_assert!(!filter.contains(&malicious_input),
            "WHERE 子句不应该直接包含用户输入的恶意字符串");
    }
    // A classic `'; DROP TABLE ...; --` payload must never reach the SQL text.
    #[test]
    fn prop_sql_injection_prevention_drop_table(
        table_name in table_name_strategy(),
        field in field_name_strategy()
    ) {
        let pool = create_test_pool_sync();
        let payload = "'; DROP TABLE users; --";
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&field, "=", payload);
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("?"), "SQL 应该使用参数化查询");
        prop_assert!(!rendered.to_uppercase().contains("DROP TABLE"),
            "SQL 不应该包含 DROP TABLE 语句");
        let filter = injection_where_clause(&rendered);
        prop_assert!(!filter.contains(payload),
            "WHERE 子句不应该直接包含用户输入的恶意字符串");
    }
    // A `UNION SELECT` payload must not add a UNION to the statement.
    #[test]
    fn prop_sql_injection_prevention_union_select(
        table_name in table_name_strategy(),
        field in field_name_strategy()
    ) {
        let pool = create_test_pool_sync();
        let payload = "' UNION SELECT * FROM passwords --";
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&field, "=", payload);
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("?"), "SQL 应该使用参数化查询");
        let union_count = rendered.to_uppercase().matches("UNION").count();
        prop_assert_eq!(union_count, 0, "SQL 不应该包含 UNION 注入");
        let filter = injection_where_clause(&rendered);
        prop_assert!(!filter.contains(payload),
            "WHERE 子句不应该直接包含用户输入的恶意字符串");
    }
    // An `' OR '1'='1` tautology payload must not introduce an OR condition.
    #[test]
    fn prop_sql_injection_prevention_or_always_true(
        table_name in table_name_strategy(),
        field in field_name_strategy()
    ) {
        let pool = create_test_pool_sync();
        let payload = "' OR '1'='1";
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&field, "=", payload);
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("?"), "SQL 应该使用参数化查询");
        let filter = injection_where_clause(&rendered);
        prop_assert!(!filter.contains(payload),
            "WHERE 子句不应该直接包含用户输入的恶意字符串");
        let or_count = filter.matches(" OR ").count();
        prop_assert_eq!(or_count, 0, "不应该因为用户输入而产生 OR 条件");
    }
    // Inputs mixing several special characters (`'`, `;`, `"`, `-`) must not be inlined.
    #[test]
    fn prop_sql_injection_prevention_multiple_special_chars(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        malicious_input in "[a-z0-9]*[';\"\\-][a-z0-9]*[';\"\\-][a-z0-9]*"
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&field, "=", malicious_input.as_str());
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("?"), "SQL 应该使用参数化查询");
        let filter = injection_where_clause(&rendered);
        prop_assert!(!filter.contains(&malicious_input),
            "WHERE 子句不应该直接包含用户输入的恶意字符串");
    }
    // Every IN-list element must get its own placeholder and none may be inlined.
    #[test]
    fn prop_sql_injection_prevention_in_operator(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        malicious_values in prop::collection::vec(".*[';].*", 1..5)
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_in(&field, malicious_values.clone());
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("IN"), "SQL 应该包含 IN 操作符");
        prop_assert!(rendered.contains("?"), "SQL 应该使用参数化查询");
        let placeholder_count = rendered.matches("?").count();
        prop_assert!(placeholder_count >= malicious_values.len(),
            "每个 IN 值都应该有对应的参数占位符");
        // The WHERE segment does not depend on the value being checked.
        let filter = injection_where_clause(&rendered);
        for malicious_value in &malicious_values {
            prop_assert!(!filter.contains(malicious_value.as_str()),
                "WHERE 子句不应该直接包含用户输入的恶意字符串");
        }
    }
    // A LIKE pattern with special characters must be bound, not inlined.
    #[test]
    fn prop_sql_injection_prevention_like_operator(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        malicious_pattern in ".*[';].*"
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&field, "like", malicious_pattern.as_str());
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("LIKE"), "SQL 应该包含 LIKE 操作符");
        prop_assert!(rendered.contains("?"), "SQL 应该使用参数化查询");
        let filter = injection_where_clause(&rendered);
        prop_assert!(!filter.contains(&malicious_pattern),
            "WHERE 子句不应该直接包含用户输入的恶意字符串");
    }
    // Both BETWEEN bounds must be bound as separate placeholders.
    #[test]
    fn prop_sql_injection_prevention_between_operator(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        malicious_start in ".*[';].*",
        malicious_end in ".*[';].*"
    ) {
        let pool = create_test_pool_sync();
        let query = QueryBuilder::new(&pool, &table_name, false)
            .where_between(&field, malicious_start.as_str(), malicious_end.as_str());
        let rendered = query.to_sql();
        prop_assert!(rendered.contains("BETWEEN"), "SQL 应该包含 BETWEEN 操作符");
        prop_assert!(rendered.contains("?"), "SQL 应该使用参数化查询");
        let filter = injection_where_clause(&rendered);
        let placeholder_count = filter.matches("?").count();
        prop_assert!(placeholder_count >= 2, "BETWEEN 应该有两个参数占位符");
        prop_assert!(!filter.contains(&malicious_start),
            "WHERE 子句不应该直接包含用户输入的恶意字符串");
        prop_assert!(!filter.contains(&malicious_end),
            "WHERE 子句不应该直接包含用户输入的恶意字符串");
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // NOTE(review): despite its name, this test sets `.limit(1)` by hand rather
    // than calling `find()`; it only checks that a LIMIT of 1 renders.
    #[test]
    fn prop_find_adds_limit_one(
        table_name in table_name_strategy(),
        field in field_name_strategy(),
        value in any::<i32>()
    ) {
        let pool = create_test_pool_sync();
        let single_row = QueryBuilder::new(&pool, &table_name, false)
            .field(&field)
            .where_and(&field, "=", value)
            .limit(1);
        let rendered = single_row.to_sql();
        prop_assert!(rendered.contains("LIMIT 1"),
            "find() 方法应该自动添加 LIMIT 1 到查询中");
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Selecting the `COUNT(*)` expression must render that exact aggregate in a
    // SELECT over the requested table.
    #[test]
    fn prop_count_aggregation_function(
        table_name in table_name_strategy()
    ) {
        let pool = create_test_pool_sync();
        let builder = QueryBuilder::new(&pool, &table_name, false)
            .field("COUNT(*)");
        let sql = builder.to_sql();
        // Assert the exact aggregate that was requested rather than any
        // `COUNT(` prefix, which would also accept unrelated aggregates.
        prop_assert!(
            sql.contains("COUNT(*)"),
            "count() 方法应该生成包含 COUNT(*) 或 COUNT(field) 的 SQL 语句,实际 SQL: {}",
            sql
        );
        prop_assert!(
            sql.to_uppercase().contains("SELECT"),
            "count() 方法应该生成 SELECT 语句,实际 SQL: {}",
            sql
        );
        prop_assert!(
            sql.contains(&format!("FROM {}", table_name)),
            "count() 方法应该包含正确的表名,实际 SQL: {}",
            sql
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // A `COUNT(*)` projection combined with a filter must keep both the
    // aggregate and the WHERE clause.
    #[test]
    fn prop_count_with_where_condition(
        table_name in table_name_strategy(),
        field_name in field_name_strategy(),
        field_value in 1i32..1000i32,
    ) {
        let pool = create_test_pool_sync();
        let counted = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&field_name, "=", field_value)
            .field("COUNT(*)");
        let rendered = counted.to_sql();
        let expected_from = format!("FROM {}", table_name);
        prop_assert!(
            rendered.contains("COUNT(*)"),
            "带条件的 count() 查询应该包含 COUNT(*),实际 SQL: {}",
            rendered
        );
        prop_assert!(
            rendered.to_uppercase().contains("WHERE"),
            "带条件的 count() 查询应该包含 WHERE 子句,实际 SQL: {}",
            rendered
        );
        prop_assert!(
            rendered.contains(&expected_from),
            "count() 方法应该包含正确的表名,实际 SQL: {}",
            rendered
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // A `COUNT(<column>)` projection must be rendered verbatim.
    #[test]
    fn prop_count_specific_field(
        table_name in table_name_strategy(),
        field_name in field_name_strategy(),
    ) {
        let pool = create_test_pool_sync();
        let aggregate = format!("COUNT({})", field_name);
        let counted = QueryBuilder::new(&pool, &table_name, false).field(&aggregate);
        let rendered = counted.to_sql();
        let expected_from = format!("FROM {}", table_name);
        prop_assert!(
            rendered.contains(&aggregate),
            "COUNT 特定字段应该包含 COUNT(field_name),实际 SQL: {}",
            rendered
        );
        prop_assert!(
            rendered.to_uppercase().contains("SELECT"),
            "COUNT 查询应该是 SELECT 语句,实际 SQL: {}",
            rendered
        );
        prop_assert!(
            rendered.contains(&expected_from),
            "COUNT 查询应该包含正确的表名,实际 SQL: {}",
            rendered
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Property: the `CAST(SUM(<field>) AS DOUBLE)` expression used for sums is
    // rendered verbatim inside a SELECT over the requested table.
    #[test]
    fn prop_sum_aggregation_function(
        table_name in table_name_strategy(),
        field in field_name_strategy()
    ) {
        let pool = create_test_pool_sync();
        let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", field);
        let builder = QueryBuilder::new(&pool, &table_name, false)
            .field(&sum_expr);
        let sql = builder.to_sql();
        prop_assert!(
            sql.contains("SUM("),
            "sum() 方法应该生成包含 SUM(field) 的 SQL 语句,实际 SQL: {}",
            sql
        );
        // Check the field inside the SUM call. A bare `contains(&field)` can pass
        // vacuously when the generated name is a substring of other SQL text
        // (e.g. "CAST", "FROM", or the table name).
        prop_assert!(
            sql.contains(&format!("SUM({})", field)),
            "sum() 方法生成的 SQL 应该包含指定的字段名 {},实际 SQL: {}",
            field,
            sql
        );
        prop_assert!(
            sql.to_uppercase().contains("SELECT"),
            "sum() 方法应该生成 SELECT 语句,实际 SQL: {}",
            sql
        );
        prop_assert!(
            sql.contains(&format!("FROM {}", table_name)),
            "sum() 方法应该包含正确的表名,实际 SQL: {}",
            sql
        );
        // Require the complete CAST expression rather than any occurrence of
        // "CAST", so the DOUBLE conversion wrapping the SUM is actually verified.
        prop_assert!(
            sql.contains(&sum_expr),
            "sum() 方法应该使用 CAST 转换结果为 DOUBLE,实际 SQL: {}",
            sql
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Property: a SUM query filtered by one equality condition still renders the
    // aggregate over the right field, a WHERE clause, and the requested table.
    #[test]
    fn prop_sum_with_where_condition(
        table_name in table_name_strategy(),
        sum_field in field_name_strategy(),
        where_field in field_name_strategy(),
        where_value in 1i32..1000i32,
    ) {
        // Keep the summed and filtered columns distinct so the assertions below
        // cannot be satisfied by the WHERE clause alone.
        prop_assume!(sum_field != where_field);
        let pool = create_test_pool_sync();
        let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", sum_field);
        let builder = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&where_field, "=", where_value)
            .field(&sum_expr);
        let sql = builder.to_sql();
        prop_assert!(
            sql.contains("SUM("),
            "带条件的 sum() 查询应该包含 SUM(field),实际 SQL: {}",
            sql
        );
        // Match the field inside the SUM call. A bare substring check on the
        // field name can pass vacuously (e.g. when it occurs in the table name).
        prop_assert!(
            sql.contains(&format!("SUM({})", sum_field)),
            "sum() 方法应该包含求和字段名 {},实际 SQL: {}",
            sum_field,
            sql
        );
        prop_assert!(
            sql.to_uppercase().contains("WHERE"),
            "带条件的 sum() 查询应该包含 WHERE 子句,实际 SQL: {}",
            sql
        );
        prop_assert!(
            sql.contains(&format!("FROM {}", table_name)),
            "sum() 方法应该包含正确的表名,实际 SQL: {}",
            sql
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Property: a SUM query with two chained `where_and` conditions renders the
    // aggregate over the right field, a WHERE clause joined with AND, and the
    // requested table.
    #[test]
    fn prop_sum_with_multiple_conditions(
        table_name in table_name_strategy(),
        sum_field in field_name_strategy(),
        where_field1 in field_name_strategy(),
        where_field2 in field_name_strategy(),
        value1 in 1i32..1000i32,
        value2 in 1i32..1000i32,
    ) {
        // All three column names must differ so each assertion checks a
        // distinct part of the generated SQL.
        prop_assume!(sum_field != where_field1);
        prop_assume!(sum_field != where_field2);
        prop_assume!(where_field1 != where_field2);
        let pool = create_test_pool_sync();
        let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", sum_field);
        let builder = QueryBuilder::new(&pool, &table_name, false)
            .where_and(&where_field1, "=", value1)
            .where_and(&where_field2, ">", value2)
            .field(&sum_expr);
        let sql = builder.to_sql();
        prop_assert!(
            sql.contains("SUM("),
            "多条件 sum() 查询应该包含 SUM(field),实际 SQL: {}",
            sql
        );
        // Match the field inside the SUM call. A bare substring check on the
        // field name can pass vacuously (e.g. when it occurs in the table name).
        prop_assert!(
            sql.contains(&format!("SUM({})", sum_field)),
            "sum() 方法应该包含求和字段名 {},实际 SQL: {}",
            sum_field,
            sql
        );
        prop_assert!(
            sql.to_uppercase().contains("WHERE"),
            "多条件查询应该包含 WHERE 子句,实际 SQL: {}",
            sql
        );
        prop_assert!(
            sql.to_uppercase().contains(" AND "),
            "多个 where_and 条件应该用 AND 连接,实际 SQL: {}",
            sql
        );
        // Every sibling SUM/COUNT property checks the target table; do so here too.
        prop_assert!(
            sql.contains(&format!("FROM {}", table_name)),
            "sum() 方法应该包含正确的表名,实际 SQL: {}",
            sql
        );
    }
}
}