use crate::mysql::query_builder::SqlGenerator;
use std::collections::HashMap;
/// One record carrying a single non-key column (`name`).
///
/// Checks that the generated statement starts with `UPDATE users SET `,
/// contains the `name = CASE WHEN id=? THEN ?` branch and an `END` keyword,
/// and filters with a one-placeholder `WHERE id IN (?)` clause.
#[test]
fn test_build_update_batch_single_record_single_field() {
    let rows = vec![serde_json::json!({"id": 1, "name": "Alice"})];
    // The fourth argument of `build_update_batch` is passed empty in this test.
    let empty_map = HashMap::new();

    let mut sql_gen = SqlGenerator::new();
    sql_gen
        .build_update_batch("users", &rows, "id", &empty_map)
        .unwrap();
    let sql = sql_gen.get_sql();

    assert!(
        sql.starts_with("UPDATE users SET "),
        "SQL 应以 'UPDATE users SET ' 开头,实际: {}",
        sql
    );
    assert!(
        sql.contains("name = CASE WHEN id=? THEN ?"),
        "SQL 应包含 CASE WHEN 结构,实际: {}",
        sql
    );
    assert!(sql.contains("END"), "SQL 应包含 END 关键字,实际: {}", sql);
    assert!(
        sql.contains("WHERE id IN (?)"),
        "SQL 应包含 WHERE id IN (?) 子句,实际: {}",
        sql
    );
}
/// Three records that each update the same single column (`name`).
///
/// Expects one `WHEN id=?` branch per record (three in total) and a
/// three-placeholder `WHERE id IN (?,?,?)` filter.
#[test]
fn test_build_update_batch_multiple_records_single_field() {
    let rows: Vec<serde_json::Value> = [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
        .iter()
        .map(|&(id, name)| serde_json::json!({"id": id, "name": name}))
        .collect();

    let mut sql_gen = SqlGenerator::new();
    sql_gen
        .build_update_batch("users", &rows, "id", &HashMap::new())
        .unwrap();
    let sql = sql_gen.get_sql();

    assert_eq!(
        sql.matches("WHEN id=?").count(),
        3,
        "3 条记录应生成 3 个 WHEN 分支,实际: {}",
        sql
    );
    assert!(
        sql.contains("WHERE id IN (?,?,?)"),
        "WHERE IN 应包含 3 个占位符,实际: {}",
        sql
    );
}
/// One record with two non-key columns (`name`, `age`).
///
/// Expects two `CASE WHEN` blocks and two `END` keywords (one pair per updated
/// column), with the column assignments separated by `", "`.
#[test]
fn test_build_update_batch_single_record_multiple_fields() {
    let rows = vec![serde_json::json!({"id": 1, "name": "Alice", "age": 25})];

    let mut sql_gen = SqlGenerator::new();
    sql_gen
        .build_update_batch("users", &rows, "id", &HashMap::new())
        .unwrap();
    let sql = sql_gen.get_sql();

    let (case_blocks, end_keywords) = (
        sql.matches("CASE WHEN").count(),
        sql.matches("END").count(),
    );
    assert_eq!(
        case_blocks, 2,
        "2 个更新字段应生成 2 个 CASE WHEN 块,实际: {}",
        sql
    );
    assert_eq!(
        end_keywords, 2,
        "2 个更新字段应生成 2 个 END 关键字,实际: {}",
        sql
    );
    assert!(
        sql.contains(", "),
        "多个字段之间应用 ', ' 分隔,实际: {}",
        sql
    );
}
/// Two records, each updating two non-key columns (`name`, `age`).
///
/// Expects the `UPDATE users SET ` prefix, two `CASE WHEN` blocks (one per
/// column), four `WHEN id=?` branches (columns × records), and a
/// two-placeholder `WHERE id IN (?,?)` filter.
#[test]
fn test_build_update_batch_multiple_records_multiple_fields() {
    let rows: Vec<serde_json::Value> = [(1, "Alice", 25), (2, "Bob", 30)]
        .iter()
        .map(|&(id, name, age)| serde_json::json!({"id": id, "name": name, "age": age}))
        .collect();

    let mut sql_gen = SqlGenerator::new();
    sql_gen
        .build_update_batch("users", &rows, "id", &HashMap::new())
        .unwrap();
    let sql = sql_gen.get_sql();

    assert!(
        sql.starts_with("UPDATE users SET "),
        "SQL 应以 'UPDATE users SET ' 开头,实际: {}",
        sql
    );
    assert_eq!(
        sql.matches("CASE WHEN").count(),
        2,
        "2 个更新字段应生成 2 个 CASE WHEN 块,实际: {}",
        sql
    );
    assert_eq!(
        sql.matches("WHEN id=?").count(),
        4,
        "2 个字段 × 2 条记录 = 4 个 WHEN 分支,实际: {}",
        sql
    );
    assert!(
        sql.contains("WHERE id IN (?,?)"),
        "WHERE IN 应包含 2 个占位符,实际: {}",
        sql
    );
}
/// Two records × two non-key columns must bind exactly 10 parameters.
///
/// The expected total follows the formula stated in the assertion message:
/// `fields × records × 2 + records`.
#[test]
fn test_build_update_batch_param_count() {
    let rows = vec![
        serde_json::json!({"id": 1, "name": "Alice", "age": 25}),
        serde_json::json!({"id": 2, "name": "Bob", "age": 30}),
    ];

    let mut sql_gen = SqlGenerator::new();
    sql_gen
        .build_update_batch("users", &rows, "id", &HashMap::new())
        .unwrap();

    let updated_fields = 2;
    let expected = updated_fields * rows.len() * 2 + rows.len();
    let bound = sql_gen.get_params().len();
    assert_eq!(
        bound,
        expected,
        "2 字段 × 2 记录 × 2 + 2 记录 = 10 个参数,实际: {}",
        bound
    );
}
/// An empty record list must be rejected with `DbError::SerializationError`.
///
/// The `Result` is matched once instead of calling `is_err()` and then
/// `unwrap_err()`, so the error is extracted in a single step and the test no
/// longer requires the success type to implement `Debug`.
#[test]
fn test_build_update_batch_empty_records_returns_error() {
    let records: Vec<serde_json::Value> = vec![];
    let mut generator = SqlGenerator::new();
    let result = generator.build_update_batch("users", &records, "id", &HashMap::new());
    match result {
        Ok(_) => panic!("空记录列表应返回错误"),
        Err(err) => assert!(
            matches!(err, crate::error::DbError::SerializationError(_)),
            "应返回 SerializationError"
        ),
    }
}
/// A record containing only the key column (`id`) has nothing to update and
/// must be rejected with `DbError::SerializationError`.
///
/// The `Result` is matched once instead of calling `is_err()` and then
/// `unwrap_err()`, so the error is extracted in a single step and the test no
/// longer requires the success type to implement `Debug`.
#[test]
fn test_build_update_batch_only_id_field_returns_error() {
    let records = vec![serde_json::json!({"id": 1})];
    let mut generator = SqlGenerator::new();
    let result = generator.build_update_batch("users", &records, "id", &HashMap::new());
    match result {
        Ok(_) => panic!("只有主键字段时应返回错误"),
        Err(err) => assert!(
            matches!(err, crate::error::DbError::SerializationError(_)),
            "应返回 SerializationError"
        ),
    }
}
/// A non-default key column name (`user_id`) must appear in both the `WHEN`
/// branches and the `WHERE ... IN` filter, while the remaining column
/// (`status`) gets its own `CASE` expression.
#[test]
fn test_build_update_batch_custom_id_field() {
    let rows: Vec<serde_json::Value> = [(10, "active"), (20, "inactive")]
        .iter()
        .map(|&(user_id, status)| serde_json::json!({"user_id": user_id, "status": status}))
        .collect();

    let mut sql_gen = SqlGenerator::new();
    sql_gen
        .build_update_batch("users", &rows, "user_id", &HashMap::new())
        .unwrap();
    let sql = sql_gen.get_sql();

    assert!(
        sql.contains("WHEN user_id=?"),
        "应使用自定义主键字段名 user_id,实际: {}",
        sql
    );
    assert!(
        sql.contains("WHERE user_id IN"),
        "WHERE IN 应使用自定义主键字段名,实际: {}",
        sql
    );
    assert!(
        sql.contains("status = CASE"),
        "status 字段应有 CASE WHEN 结构,实际: {}",
        sql
    );
}
/// Checks the overall shape of a two-record, single-column update: the
/// `UPDATE users SET ` head, a `name = CASE WHEN id=? THEN ?` body fragment,
/// and a statement that ends exactly with `WHERE id IN (?,?)`.
#[test]
fn test_build_update_batch_sql_matches_expected_format() {
    let rows = vec![
        serde_json::json!({"id": 1, "name": "Alice"}),
        serde_json::json!({"id": 2, "name": "Bob"}),
    ];

    let mut sql_gen = SqlGenerator::new();
    sql_gen
        .build_update_batch("users", &rows, "id", &HashMap::new())
        .unwrap();
    let sql = sql_gen.get_sql();

    // Head, body and tail of the statement, checked in that order.
    assert!(
        sql.starts_with("UPDATE users SET "),
        "SQL 头部格式不正确,实际: {}",
        sql
    );
    assert!(
        sql.contains("name = CASE WHEN id=? THEN ?"),
        "CASE WHEN 格式不正确,实际: {}",
        sql
    );
    assert!(
        sql.ends_with("WHERE id IN (?,?)"),
        "SQL 尾部格式不正确,实际: {}",
        sql
    );
}
/// A 100-record batch updating two non-key columns (`name`, `score`).
///
/// Expects `fields × records` = 200 `WHEN id=?` branches and
/// `fields × records × 2 + records` = 500 bound parameters. The structural
/// assertions do not print the generated SQL.
#[test]
fn test_build_update_batch_large_batch() {
    let rows: Vec<serde_json::Value> = (1..=100)
        .map(|n| serde_json::json!({"id": n, "name": format!("User{}", n), "score": n * 10}))
        .collect();

    let mut sql_gen = SqlGenerator::new();
    sql_gen
        .build_update_batch("users", &rows, "id", &HashMap::new())
        .unwrap();
    let sql = sql_gen.get_sql();
    let params = sql_gen.get_params();

    assert!(sql.starts_with("UPDATE users SET "), "SQL 头部格式不正确");
    assert!(sql.contains("WHERE id IN ("), "SQL 应包含 WHERE IN 子句");

    let fields = 2;
    let when_total = sql.matches("WHEN id=?").count();
    assert_eq!(
        when_total,
        fields * rows.len(),
        "2 字段 × 100 记录 = 200 个 WHEN 分支,实际: {}",
        when_total
    );
    assert_eq!(
        params.len(),
        fields * rows.len() * 2 + rows.len(),
        "2 字段 × 100 记录 × 2 + 100 记录 = 500 个参数,实际: {}",
        params.len()
    );
}
/// Four records updating three non-key columns (M = 3 fields, N = 4 records).
///
/// Checks for M `CASE WHEN` blocks, M × N `WHEN id=?` branches,
/// M × N × 2 + N bound parameters, and an N-placeholder `WHERE id IN (...)`
/// filter. Every failure message now reports the actual SQL or count, as the
/// other assertions in this file do.
///
/// NOTE(review): despite its name, this test only counts SQL fragments and
/// parameters; it does not measure allocations.
#[test]
fn test_build_update_batch_allocation_level_verification() {
    let records: Vec<serde_json::Value> = (1..=4)
        .map(|i| {
            serde_json::json!({
                "id": i,
                "field_a": format!("a{}", i),
                "field_b": i * 2,
                "field_c": i as f64 * 1.5
            })
        })
        .collect();
    let mut generator = SqlGenerator::new();
    generator
        .build_update_batch("test_table", &records, "id", &HashMap::new())
        .unwrap();
    let sql = generator.get_sql();
    let params = generator.get_params();
    let case_count = sql.matches("CASE WHEN").count();
    assert_eq!(
        case_count, 3,
        "M=3 个字段应生成 3 个 CASE WHEN 块,实际: {}",
        sql
    );
    let when_count = sql.matches("WHEN id=?").count();
    assert_eq!(
        when_count, 12,
        "M=3 × N=4 = 12 个 WHEN 分支,实际: {}",
        sql
    );
    assert_eq!(
        params.len(),
        28,
        "M=3 × N=4 × 2 + N=4 = 28 个参数,实际: {}",
        params.len()
    );
    assert!(
        sql.contains("WHERE id IN (?,?,?,?)"),
        "WHERE IN 应包含 4 个占位符,实际: {}",
        sql
    );
}