use parking_lot::RwLock;
use smallvec::SmallVec;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use crate::sql::DatabaseType;
/// Snapshot of cache effectiveness counters.
#[derive(Debug, Clone, Default)]
pub struct PreparedStatementStats {
    pub hits: u64,
    pub misses: u64,
    pub cached_count: usize,
    pub time_saved_ms: u64,
}

impl PreparedStatementStats {
    /// Hit rate as a percentage in `[0.0, 100.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded at all.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => (self.hits as f64 / total as f64) * 100.0,
        }
    }
}
/// A single cached prepared statement plus its usage bookkeeping.
#[derive(Debug, Clone)]
pub struct CachedStatement {
    // Full SQL text of the statement.
    pub sql: String,
    // Logical statement name; also the cache key.
    pub name: String,
    // Incremented by PreparedStatementCache::record_use().
    pub use_count: u64,
    // Last touch time; drives the least-recently-used eviction.
    pub last_used: Instant,
    // Estimated preparation cost in microseconds (heuristic — assigned from
    // avg_prep_time_us at insertion, not measured).
    pub prep_time_us: u64,
    // Driver-specific statement handle, if registered via set_handle().
    pub handle: Option<u64>,
}

/// Thread-safe, capacity-bounded cache of prepared statements keyed by name,
/// with hit/miss counters and an estimate of preparation time saved.
pub struct PreparedStatementCache {
    statements: RwLock<HashMap<String, CachedStatement>>,
    // Maximum number of entries; inserting beyond this evicts the LRU entry.
    capacity: usize,
    hits: AtomicU64,
    misses: AtomicU64,
    // Accumulated estimated savings, in microseconds.
    time_saved_us: AtomicU64,
    // Assumed cost of preparing one statement (microseconds) — a fixed
    // heuristic, not measured per statement.
    avg_prep_time_us: u64,
}
impl PreparedStatementCache {
    /// Creates a cache holding at most `capacity` statements.
    pub fn new(capacity: usize) -> Self {
        Self {
            statements: RwLock::new(HashMap::with_capacity(capacity)),
            capacity,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            time_saved_us: AtomicU64::new(0),
            // Heuristic cost of one prepare round-trip, in microseconds.
            avg_prep_time_us: 500,
        }
    }

    /// Returns the cached statement for `name`, calling `generator` to build
    /// the SQL only on a miss.
    ///
    /// Uses double-checked locking: the common hit path takes only the read
    /// lock, and `generator` runs with no lock held. Under contention the
    /// generator may therefore run more than once, but only one result is
    /// ever inserted; the losing thread returns the already-inserted entry.
    pub fn get_or_create<F>(&self, name: &str, generator: F) -> CachedStatement
    where
        F: FnOnce() -> String,
    {
        {
            let cache = self.statements.read();
            if let Some(stmt) = cache.get(name) {
                self.hits.fetch_add(1, Ordering::Relaxed);
                self.time_saved_us
                    .fetch_add(stmt.prep_time_us, Ordering::Relaxed);
                return stmt.clone();
            }
        }
        // Build the entry outside the write lock to keep the critical
        // section short.
        let sql = generator();
        let entry = CachedStatement {
            sql,
            name: name.to_string(),
            use_count: 1,
            last_used: Instant::now(),
            prep_time_us: self.avg_prep_time_us,
            handle: None,
        };
        let mut cache = self.statements.write();
        // Re-check under the write lock: another thread may have inserted
        // while we were generating. Count that as a single hit and credit
        // the time saved. (Previously this path recorded both a miss and a
        // hit for one lookup, skewing hit_rate(), and skipped the credit.)
        if let Some(existing) = cache.get(name) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            self.time_saved_us
                .fetch_add(existing.prep_time_us, Ordering::Relaxed);
            return existing.clone();
        }
        // Genuine miss: we are the inserting thread.
        self.misses.fetch_add(1, Ordering::Relaxed);
        if cache.len() >= self.capacity {
            self.evict_oldest(&mut cache);
        }
        cache.insert(name.to_string(), entry.clone());
        entry
    }

    /// Returns whether a statement with this name is currently cached.
    pub fn contains(&self, name: &str) -> bool {
        self.statements.read().contains_key(name)
    }

    /// Snapshot of the current counters.
    pub fn stats(&self) -> PreparedStatementStats {
        let cache = self.statements.read();
        PreparedStatementStats {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            cached_count: cache.len(),
            time_saved_ms: self.time_saved_us.load(Ordering::Relaxed) / 1000,
        }
    }

    /// Drops all cached statements and resets every counter.
    pub fn clear(&self) {
        self.statements.write().clear();
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
        self.time_saved_us.store(0, Ordering::Relaxed);
    }

    /// Number of cached statements.
    pub fn len(&self) -> usize {
        self.statements.read().len()
    }

    /// Whether the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.statements.read().is_empty()
    }

    /// Removes the least-recently-used entry. Called with the write lock
    /// already held. Only the key is cloned; the previous version also
    /// cloned the entire CachedStatement just to throw it away.
    fn evict_oldest(&self, cache: &mut HashMap<String, CachedStatement>) {
        let oldest_key = cache
            .iter()
            .min_by_key(|(_, v)| v.last_used)
            .map(|(k, _)| k.clone());
        if let Some(key) = oldest_key {
            cache.remove(&key);
        }
    }

    /// Bumps the use counter and refreshes the LRU timestamp for `name`.
    /// No-op if the statement is not cached.
    pub fn record_use(&self, name: &str) {
        if let Some(stmt) = self.statements.write().get_mut(name) {
            stmt.use_count += 1;
            stmt.last_used = Instant::now();
        }
    }

    /// Attaches a driver-specific handle to the cached statement, if present.
    pub fn set_handle(&self, name: &str, handle: u64) {
        if let Some(stmt) = self.statements.write().get_mut(name) {
            stmt.handle = Some(handle);
        }
    }
}
impl Default for PreparedStatementCache {
    /// Defaults to a capacity of 256 cached statements.
    fn default() -> Self {
        Self::new(256)
    }
}
/// Process-wide statement cache (capacity 512), lazily initialized on first
/// use and shared by all callers.
pub fn global_statement_cache() -> &'static PreparedStatementCache {
    use std::sync::OnceLock;
    static CACHE: OnceLock<PreparedStatementCache> = OnceLock::new();
    CACHE.get_or_init(|| PreparedStatementCache::new(512))
}
/// Tuning knobs for batched database writes.
#[derive(Debug, Clone, Copy)]
pub struct BatchConfig {
    // Rows per batch.
    pub batch_size: usize,
    // Upper bound on the encoded payload size of a single batch, in bytes.
    pub max_payload_bytes: usize,
    // Whether to emit multi-row INSERT statements.
    pub multi_row_insert: bool,
    // Whether to use COPY-style bulk loading (enabled only for PostgreSQL
    // in for_database()/auto_tune()).
    pub use_copy: bool,
    // How many batches may be executed concurrently.
    pub parallelism: usize,
}
impl BatchConfig {
    /// Conservative defaults usable with any backend.
    pub const fn default_config() -> Self {
        Self {
            batch_size: 1000,
            max_payload_bytes: 16 * 1024 * 1024,
            multi_row_insert: true,
            use_copy: false,
            parallelism: 1,
        }
    }

    /// Per-backend defaults reflecting each database's batching limits.
    pub fn for_database(db_type: DatabaseType) -> Self {
        match db_type {
            DatabaseType::PostgreSQL => Self {
                batch_size: 1000,
                max_payload_bytes: 64 * 1024 * 1024,
                multi_row_insert: true,
                use_copy: true,
                parallelism: 4,
            },
            DatabaseType::MySQL => Self {
                batch_size: 500,
                max_payload_bytes: 16 * 1024 * 1024,
                multi_row_insert: true,
                use_copy: false,
                parallelism: 2,
            },
            DatabaseType::SQLite => Self {
                batch_size: 500,
                max_payload_bytes: 1024 * 1024,
                multi_row_insert: true,
                use_copy: false,
                parallelism: 1,
            },
            DatabaseType::MSSQL => Self {
                batch_size: 1000,
                max_payload_bytes: 32 * 1024 * 1024,
                multi_row_insert: true,
                use_copy: false,
                parallelism: 4,
            },
        }
    }

    /// Derives a config from the backend defaults, sized by the expected
    /// average row size and the total row count.
    ///
    /// The resulting `batch_size` is always at least 1, so `batch_ranges()`
    /// and `batch_count()` stay well-defined even when `total_rows == 0`.
    pub fn auto_tune(db_type: DatabaseType, avg_row_size: usize, total_rows: usize) -> Self {
        let mut config = Self::for_database(db_type);
        // Rows that fit in one payload; falls back to the default batch size
        // when avg_row_size is 0 (checked_div avoids a divide-by-zero).
        let max_rows_by_payload = config
            .max_payload_bytes
            .checked_div(avg_row_size)
            .unwrap_or(config.batch_size);
        let optimal_batch = if total_rows < 100 {
            total_rows
        } else if total_rows < 1000 {
            (total_rows / 10).max(100)
        } else {
            let by_count = total_rows / 10;
            by_count.min(max_rows_by_payload).clamp(100, 10_000)
        };
        // Guard against a zero batch size (total_rows == 0 lands in the first
        // branch above), which would otherwise panic in batch_ranges().
        config.batch_size = optimal_batch.max(1);
        // Small jobs don't amortize parallel execution overhead.
        if total_rows < 1000 {
            config.parallelism = 1;
        } else if total_rows < 10_000 {
            config.parallelism = config.parallelism.min(2);
        }
        if matches!(db_type, DatabaseType::PostgreSQL) && total_rows > 5000 {
            config.use_copy = true;
        }
        config
    }

    /// Yields half-open `(start, end)` row ranges covering `0..total` in
    /// `batch_size` steps; the last range may be shorter.
    ///
    /// A zero `batch_size` is treated as 1 — `Iterator::step_by` panics on a
    /// step of 0, which a caller-constructed config could otherwise trigger.
    pub fn batch_ranges(&self, total: usize) -> impl Iterator<Item = (usize, usize)> {
        let batch_size = self.batch_size.max(1);
        (0..total)
            .step_by(batch_size)
            .map(move |start| (start, (start + batch_size).min(total)))
    }

    /// Number of batches needed to cover `total` rows. A zero `batch_size`
    /// is treated as 1 instead of dividing by zero.
    pub fn batch_count(&self, total: usize) -> usize {
        total.div_ceil(self.batch_size.max(1))
    }
}
impl Default for BatchConfig {
    /// Same values as [`BatchConfig::default_config`].
    fn default() -> Self {
        Self::default_config()
    }
}
/// Fluent builder for a MongoDB aggregation pipeline, rendered to a JSON
/// array string by `build()`.
#[derive(Debug, Clone, Default)]
pub struct MongoPipelineBuilder {
    stages: Vec<PipelineStage>,
    // Aggregation options carried alongside the pipeline. Note: these are
    // NOT included in the build() output, which renders only the stages.
    pub allow_disk_use: bool,
    pub batch_size: Option<u32>,
    pub max_time_ms: Option<u64>,
    pub comment: Option<String>,
}
/// One MongoDB aggregation stage. String payloads are raw JSON fragments
/// that are interpolated verbatim by `to_json()` (no escaping is applied).
#[derive(Debug, Clone)]
pub enum PipelineStage {
    // $match — filter document.
    Match(String),
    // $project — projection document.
    Project(String),
    // $group — `_id` expression plus accumulator fields.
    Group { id: String, accumulators: String },
    // $sort — sort document.
    Sort(String),
    // $limit / $skip — row windowing.
    Limit(u64),
    Skip(u64),
    // $unwind — array path; preserve_null keeps null/empty arrays.
    Unwind { path: String, preserve_null: bool },
    // $lookup — left outer join against another collection.
    Lookup {
        from: String,
        local_field: String,
        foreign_field: String,
        r#as: String,
    },
    // $addFields / $set — add or overwrite fields.
    AddFields(String),
    Set(String),
    // $unset — field names to remove.
    Unset(Vec<String>),
    // $replaceRoot — new root expression.
    ReplaceRoot(String),
    // $count — output field name for the count.
    Count(String),
    // $facet — named sub-pipelines.
    Facet(Vec<(String, Vec<PipelineStage>)>),
    // $bucket — boundary-based grouping; default/output are optional.
    Bucket {
        group_by: String,
        boundaries: String,
        default: Option<String>,
        output: Option<String>,
    },
    // $sample — random sample of the given size.
    Sample(u64),
    // $merge — write results into another collection.
    Merge {
        into: String,
        on: Option<String>,
        when_matched: Option<String>,
        when_not_matched: Option<String>,
    },
    // $out — replace target collection with the results.
    Out(String),
    // Pre-rendered JSON stage, emitted verbatim.
    Raw(String),
}
impl MongoPipelineBuilder {
    /// Creates an empty pipeline with default options.
    pub fn new() -> Self {
        Self::default()
    }
    /// Appends a `$match` stage; `filter` is a raw JSON document fragment.
    pub fn match_stage(mut self, filter: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Match(filter.into()));
        self
    }
    /// Appends a `$project` stage.
    pub fn project(mut self, projection: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Project(projection.into()));
        self
    }
    /// Appends a `$group` stage with the given `_id` expression and accumulators.
    pub fn group(mut self, id: impl Into<String>, accumulators: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Group {
            id: id.into(),
            accumulators: accumulators.into(),
        });
        self
    }
    /// Appends a `$sort` stage.
    pub fn sort(mut self, sort: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Sort(sort.into()));
        self
    }
    /// Appends a `$limit` stage.
    pub fn limit(mut self, n: u64) -> Self {
        self.stages.push(PipelineStage::Limit(n));
        self
    }
    /// Appends a `$skip` stage.
    pub fn skip(mut self, n: u64) -> Self {
        self.stages.push(PipelineStage::Skip(n));
        self
    }
    /// Appends a `$unwind` stage that drops null/empty arrays.
    pub fn unwind(mut self, path: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Unwind {
            path: path.into(),
            preserve_null: false,
        });
        self
    }
    /// Appends a `$unwind` stage that preserves null/empty arrays.
    pub fn unwind_preserve_null(mut self, path: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Unwind {
            path: path.into(),
            preserve_null: true,
        });
        self
    }
    /// Appends a `$lookup` (left outer join) stage.
    pub fn lookup(
        mut self,
        from: impl Into<String>,
        local_field: impl Into<String>,
        foreign_field: impl Into<String>,
        r#as: impl Into<String>,
    ) -> Self {
        self.stages.push(PipelineStage::Lookup {
            from: from.into(),
            local_field: local_field.into(),
            foreign_field: foreign_field.into(),
            r#as: r#as.into(),
        });
        self
    }
    /// Appends an `$addFields` stage.
    pub fn add_fields(mut self, fields: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::AddFields(fields.into()));
        self
    }
    /// Appends a `$set` stage.
    pub fn set(mut self, fields: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Set(fields.into()));
        self
    }
    /// Appends an `$unset` stage removing the given fields.
    pub fn unset<I, S>(mut self, fields: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.stages.push(PipelineStage::Unset(
            fields.into_iter().map(Into::into).collect(),
        ));
        self
    }
    /// Appends a `$replaceRoot` stage.
    pub fn replace_root(mut self, new_root: impl Into<String>) -> Self {
        self.stages
            .push(PipelineStage::ReplaceRoot(new_root.into()));
        self
    }
    /// Appends a `$count` stage writing the count into `field`.
    pub fn count(mut self, field: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Count(field.into()));
        self
    }
    /// Appends a `$facet` stage of named sub-pipelines.
    /// Added so the `Facet` variant is reachable through the builder, like
    /// every other stage.
    pub fn facet(mut self, facets: Vec<(String, Vec<PipelineStage>)>) -> Self {
        self.stages.push(PipelineStage::Facet(facets));
        self
    }
    /// Appends a `$bucket` stage. `group_by` and `boundaries` are raw JSON
    /// fragments; `default` and `output` are optional. Added so the `Bucket`
    /// variant is reachable through the builder, like every other stage.
    pub fn bucket(
        mut self,
        group_by: impl Into<String>,
        boundaries: impl Into<String>,
        default: Option<String>,
        output: Option<String>,
    ) -> Self {
        self.stages.push(PipelineStage::Bucket {
            group_by: group_by.into(),
            boundaries: boundaries.into(),
            default,
            output,
        });
        self
    }
    /// Appends a `$sample` stage.
    pub fn sample(mut self, size: u64) -> Self {
        self.stages.push(PipelineStage::Sample(size));
        self
    }
    /// Appends a `$merge` stage targeting `collection` with default behavior.
    pub fn merge_into(mut self, collection: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Merge {
            into: collection.into(),
            on: None,
            when_matched: None,
            when_not_matched: None,
        });
        self
    }
    /// Appends a fully-specified `$merge` stage.
    pub fn merge(
        mut self,
        into: impl Into<String>,
        on: Option<String>,
        when_matched: Option<String>,
        when_not_matched: Option<String>,
    ) -> Self {
        self.stages.push(PipelineStage::Merge {
            into: into.into(),
            on,
            when_matched,
            when_not_matched,
        });
        self
    }
    /// Appends an `$out` stage.
    pub fn out(mut self, collection: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Out(collection.into()));
        self
    }
    /// Appends a pre-rendered JSON stage verbatim.
    pub fn raw(mut self, stage: impl Into<String>) -> Self {
        self.stages.push(PipelineStage::Raw(stage.into()));
        self
    }
    /// Enables the allow-disk-use option.
    pub fn with_disk_size(self) -> Self {
        self.with_disk_use()
    }
    /// Enables the allow-disk-use option.
    pub fn with_disk_use(mut self) -> Self {
        self.allow_disk_use = true;
        self
    }
    /// Sets the cursor batch size option.
    pub fn with_batch_size(mut self, size: u32) -> Self {
        self.batch_size = Some(size);
        self
    }
    /// Sets the maximum execution time option, in milliseconds.
    pub fn with_max_time(mut self, ms: u64) -> Self {
        self.max_time_ms = Some(ms);
        self
    }
    /// Sets the comment option.
    pub fn with_comment(mut self, comment: impl Into<String>) -> Self {
        self.comment = Some(comment.into());
        self
    }
    /// Number of stages added so far.
    pub fn stage_count(&self) -> usize {
        self.stages.len()
    }
    /// Renders the pipeline as a JSON array string. The options
    /// (`allow_disk_use`, `batch_size`, …) are not part of this output.
    pub fn build(&self) -> String {
        let stages: Vec<String> = self.stages.iter().map(|s| s.to_json()).collect();
        format!("[{}]", stages.join(", "))
    }
    /// Read-only view of the accumulated stages.
    pub fn stages(&self) -> &[PipelineStage] {
        &self.stages
    }
}
impl PipelineStage {
    /// Renders this stage as a JSON object string.
    ///
    /// NOTE(review): string payloads (filters, paths, field names) are
    /// interpolated verbatim with no JSON escaping — callers must pass
    /// well-formed, trusted fragments; confirm inputs are never
    /// attacker-controlled.
    pub fn to_json(&self) -> String {
        match self {
            Self::Match(filter) => format!(r#"{{ "$match": {} }}"#, filter),
            Self::Project(proj) => format!(r#"{{ "$project": {} }}"#, proj),
            Self::Group { id, accumulators } => {
                format!(r#"{{ "$group": {{ "_id": {}, {} }} }}"#, id, accumulators)
            }
            Self::Sort(sort) => format!(r#"{{ "$sort": {} }}"#, sort),
            Self::Limit(n) => format!(r#"{{ "$limit": {} }}"#, n),
            Self::Skip(n) => format!(r#"{{ "$skip": {} }}"#, n),
            Self::Unwind {
                path,
                preserve_null,
            } => {
                // Long form only when the option is needed; the short string
                // form is equivalent when preserve_null is false.
                if *preserve_null {
                    format!(
                        r#"{{ "$unwind": {{ "path": "{}", "preserveNullAndEmptyArrays": true }} }}"#,
                        path
                    )
                } else {
                    format!(r#"{{ "$unwind": "{}" }}"#, path)
                }
            }
            Self::Lookup {
                from,
                local_field,
                foreign_field,
                r#as,
            } => {
                format!(
                    r#"{{ "$lookup": {{ "from": "{}", "localField": "{}", "foreignField": "{}", "as": "{}" }} }}"#,
                    from, local_field, foreign_field, r#as
                )
            }
            Self::AddFields(fields) => format!(r#"{{ "$addFields": {} }}"#, fields),
            Self::Set(fields) => format!(r#"{{ "$set": {} }}"#, fields),
            Self::Unset(fields) => {
                // Field names are quoted individually into a JSON array.
                let quoted: Vec<_> = fields.iter().map(|f| format!(r#""{}""#, f)).collect();
                format!(r#"{{ "$unset": [{}] }}"#, quoted.join(", "))
            }
            Self::ReplaceRoot(root) => {
                format!(r#"{{ "$replaceRoot": {{ "newRoot": {} }} }}"#, root)
            }
            Self::Count(field) => format!(r#"{{ "$count": "{}" }}"#, field),
            Self::Facet(facets) => {
                // Each facet renders its sub-pipeline recursively.
                let facet_strs: Vec<_> = facets
                    .iter()
                    .map(|(name, stages)| {
                        let pipeline: Vec<_> = stages.iter().map(|s| s.to_json()).collect();
                        format!(r#""{}": [{}]"#, name, pipeline.join(", "))
                    })
                    .collect();
                format!(r#"{{ "$facet": {{ {} }} }}"#, facet_strs.join(", "))
            }
            Self::Bucket {
                group_by,
                boundaries,
                default,
                output,
            } => {
                // Optional keys are appended only when present.
                let mut parts = vec![
                    format!(r#""groupBy": {}"#, group_by),
                    format!(r#""boundaries": {}"#, boundaries),
                ];
                if let Some(def) = default {
                    parts.push(format!(r#""default": {}"#, def));
                }
                if let Some(out) = output {
                    parts.push(format!(r#""output": {}"#, out));
                }
                format!(r#"{{ "$bucket": {{ {} }} }}"#, parts.join(", "))
            }
            Self::Sample(size) => format!(r#"{{ "$sample": {{ "size": {} }} }}"#, size),
            Self::Merge {
                into,
                on,
                when_matched,
                when_not_matched,
            } => {
                // "into" is mandatory; the rest are emitted only when set.
                let mut parts = vec![format!(r#""into": "{}""#, into)];
                if let Some(on_field) = on {
                    parts.push(format!(r#""on": "{}""#, on_field));
                }
                if let Some(matched) = when_matched {
                    parts.push(format!(r#""whenMatched": "{}""#, matched));
                }
                if let Some(not_matched) = when_not_matched {
                    parts.push(format!(r#""whenNotMatched": "{}""#, not_matched));
                }
                format!(r#"{{ "$merge": {{ {} }} }}"#, parts.join(", "))
            }
            Self::Out(collection) => format!(r#"{{ "$out": "{}" }}"#, collection),
            Self::Raw(stage) => stage.clone(),
        }
    }
}
/// Backend-portable query hints, rendered to SQL text via
/// `to_sql_prefix()` / `to_sql_suffix()`.
#[derive(Debug, Clone, Default)]
pub struct QueryHints {
    pub indexes: SmallVec<[IndexHint; 4]>,
    // Some(0) means "disable parallelism" (set by no_parallel()).
    pub parallel_workers: Option<u32>,
    pub join_hints: SmallVec<[JoinHint; 4]>,
    pub no_seq_scan: bool,
    pub no_index_scan: bool,
    // NOTE(review): stored and counted by has_hints(), but not consumed by
    // any of the SQL generators — confirm whether rendering is still pending.
    pub cte_materialized: Option<bool>,
    pub timeout_ms: Option<u64>,
    // Raw hint strings, emitted verbatim (PostgreSQL prefix only).
    pub custom: Vec<String>,
}
/// A hint naming an index to use/ignore/prefer, optionally scoped to a table.
#[derive(Debug, Clone)]
pub struct IndexHint {
    pub table: Option<String>,
    pub index_name: String,
    pub hint_type: IndexHintType,
}
/// How an IndexHint should be applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexHintType {
    Use,
    Ignore,
    Prefer,
}
/// A hint requesting a join method for a set of tables.
#[derive(Debug, Clone)]
pub struct JoinHint {
    pub tables: Vec<String>,
    pub method: JoinMethod,
}
/// Join algorithms a JoinHint can request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinMethod {
    NestedLoop,
    Hash,
    Merge,
}
impl QueryHints {
    /// Creates an empty hint set.
    pub fn new() -> Self {
        Self::default()
    }
    /// Requests use of an index, not scoped to a particular table.
    pub fn index_hint(mut self, index_name: impl Into<String>) -> Self {
        self.indexes.push(IndexHint {
            table: None,
            index_name: index_name.into(),
            hint_type: IndexHintType::Use,
        });
        self
    }
    /// Requests use of an index for a specific table.
    pub fn index_hint_for_table(
        mut self,
        table: impl Into<String>,
        index_name: impl Into<String>,
    ) -> Self {
        self.indexes.push(IndexHint {
            table: Some(table.into()),
            index_name: index_name.into(),
            hint_type: IndexHintType::Use,
        });
        self
    }
    /// Requests that an index be ignored.
    pub fn ignore_index(mut self, index_name: impl Into<String>) -> Self {
        self.indexes.push(IndexHint {
            table: None,
            index_name: index_name.into(),
            hint_type: IndexHintType::Ignore,
        });
        self
    }
    /// Requests `workers` parallel workers.
    pub fn parallel(mut self, workers: u32) -> Self {
        self.parallel_workers = Some(workers);
        self
    }
    /// Disables parallelism (encoded as Some(0); MSSQL renders MAXDOP 1).
    pub fn no_parallel(mut self) -> Self {
        self.parallel_workers = Some(0);
        self
    }
    /// Discourages sequential scans (PostgreSQL prefix only).
    pub fn no_seq_scan(mut self) -> Self {
        self.no_seq_scan = true;
        self
    }
    /// Discourages index scans (PostgreSQL prefix only).
    pub fn no_index_scan(mut self) -> Self {
        self.no_index_scan = true;
        self
    }
    /// Requests CTE materialization behavior.
    /// NOTE(review): currently not rendered by any generator — see the
    /// field comment on `cte_materialized`.
    pub fn cte_materialized(mut self, materialized: bool) -> Self {
        self.cte_materialized = Some(materialized);
        self
    }
    /// Requests a nested-loop join for the given tables.
    pub fn nested_loop_join(mut self, tables: Vec<String>) -> Self {
        self.join_hints.push(JoinHint {
            tables,
            method: JoinMethod::NestedLoop,
        });
        self
    }
    /// Requests a hash join for the given tables.
    pub fn hash_join(mut self, tables: Vec<String>) -> Self {
        self.join_hints.push(JoinHint {
            tables,
            method: JoinMethod::Hash,
        });
        self
    }
    /// Requests a merge join for the given tables.
    pub fn merge_join(mut self, tables: Vec<String>) -> Self {
        self.join_hints.push(JoinHint {
            tables,
            method: JoinMethod::Merge,
        });
        self
    }
    /// Sets a statement timeout in milliseconds (PostgreSQL prefix only).
    pub fn timeout(mut self, ms: u64) -> Self {
        self.timeout_ms = Some(ms);
        self
    }
    /// Adds a raw hint string emitted verbatim (PostgreSQL prefix only).
    pub fn custom_hint(mut self, hint: impl Into<String>) -> Self {
        self.custom.push(hint.into());
        self
    }
    /// Text to place BEFORE the query for the given backend (may be empty).
    pub fn to_sql_prefix(&self, db_type: DatabaseType) -> String {
        match db_type {
            DatabaseType::PostgreSQL => self.to_postgres_prefix(),
            DatabaseType::MySQL => self.to_mysql_prefix(),
            DatabaseType::SQLite => self.to_sqlite_prefix(),
            DatabaseType::MSSQL => self.to_mssql_prefix(),
        }
    }
    /// Text to place AFTER the query for the given backend (may be empty).
    pub fn to_sql_suffix(&self, db_type: DatabaseType) -> String {
        match db_type {
            DatabaseType::MySQL => self.to_mysql_suffix(),
            DatabaseType::MSSQL => self.to_mssql_suffix(),
            _ => String::new(),
        }
    }
    /// Wraps `query` with the backend's prefix (newline-separated) and
    /// suffix (space-separated). Returns the query unchanged when no hints
    /// render for this backend.
    pub fn apply_to_query(&self, query: &str, db_type: DatabaseType) -> String {
        let prefix = self.to_sql_prefix(db_type);
        let suffix = self.to_sql_suffix(db_type);
        if prefix.is_empty() && suffix.is_empty() {
            return query.to_string();
        }
        let mut result = String::with_capacity(prefix.len() + query.len() + suffix.len() + 2);
        if !prefix.is_empty() {
            result.push_str(&prefix);
            result.push('\n');
        }
        result.push_str(query);
        if !suffix.is_empty() {
            result.push(' ');
            result.push_str(&suffix);
        }
        result
    }
    /// PostgreSQL: a series of `SET LOCAL` statements, one per line.
    /// NOTE(review): SET LOCAL only takes effect inside a transaction —
    /// confirm callers execute the combined text within one.
    fn to_postgres_prefix(&self) -> String {
        let mut settings: Vec<String> = Vec::new();
        if self.no_seq_scan {
            settings.push("SET LOCAL enable_seqscan = off;".to_string());
        }
        if self.no_index_scan {
            settings.push("SET LOCAL enable_indexscan = off;".to_string());
        }
        if let Some(workers) = self.parallel_workers {
            settings.push(format!(
                "SET LOCAL max_parallel_workers_per_gather = {};",
                workers
            ));
        }
        if let Some(ms) = self.timeout_ms {
            settings.push(format!("SET LOCAL statement_timeout = {};", ms));
        }
        // Postgres has no per-join hints; approximate by disabling the other
        // two join methods for each requested one.
        for hint in &self.join_hints {
            match hint.method {
                JoinMethod::NestedLoop => {
                    settings.push("SET LOCAL enable_hashjoin = off;".to_string());
                    settings.push("SET LOCAL enable_mergejoin = off;".to_string());
                }
                JoinMethod::Hash => {
                    settings.push("SET LOCAL enable_nestloop = off;".to_string());
                    settings.push("SET LOCAL enable_mergejoin = off;".to_string());
                }
                JoinMethod::Merge => {
                    settings.push("SET LOCAL enable_nestloop = off;".to_string());
                    settings.push("SET LOCAL enable_hashjoin = off;".to_string());
                }
            }
        }
        for hint in &self.custom {
            settings.push(hint.clone());
        }
        settings.join("\n")
    }
    /// MySQL renders nothing before the query.
    fn to_mysql_prefix(&self) -> String {
        String::new()
    }
    /// MySQL: hints rendered as plain `/* ... */` comments after the query.
    /// NOTE(review): these are ordinary comments, not `/*+ ... */` optimizer
    /// hints, so MySQL will not act on them — confirm this advisory-only
    /// behavior is intended.
    fn to_mysql_suffix(&self) -> String {
        let mut hints: Vec<String> = Vec::new();
        for hint in &self.indexes {
            let hint_type = match hint.hint_type {
                IndexHintType::Use => "USE INDEX",
                IndexHintType::Ignore => "IGNORE INDEX",
                IndexHintType::Prefer => "FORCE INDEX",
            };
            if let Some(ref table) = hint.table {
                hints.push(format!(
                    "/* {} FOR {} ({}) */",
                    hint_type, table, hint.index_name
                ));
            } else {
                hints.push(format!("/* {} ({}) */", hint_type, hint.index_name));
            }
        }
        for hint in &self.join_hints {
            let method = match hint.method {
                JoinMethod::NestedLoop => "BNL",
                JoinMethod::Hash => "HASH_JOIN",
                JoinMethod::Merge => "MERGE",
            };
            hints.push(format!("/* {}({}) */", method, hint.tables.join(", ")));
        }
        hints.join(" ")
    }
    /// SQLite renders nothing before the query.
    fn to_sqlite_prefix(&self) -> String {
        String::new()
    }
    /// MSSQL renders nothing before the query.
    fn to_mssql_prefix(&self) -> String {
        String::new()
    }
    /// MSSQL: an `OPTION (...)` clause after the query; empty when no
    /// renderable hints are set.
    fn to_mssql_suffix(&self) -> String {
        let mut options: Vec<String> = Vec::new();
        for hint in &self.indexes {
            match hint.hint_type {
                IndexHintType::Use => {
                    // Table-scoped hints only; an un-scoped Use has no MSSQL form here.
                    if let Some(ref table) = hint.table {
                        options.push(format!("TABLE HINT({}, INDEX({}))", table, hint.index_name));
                    }
                }
                IndexHintType::Ignore => {
                    // No MSSQL equivalent is emitted for Ignore.
                }
                IndexHintType::Prefer => {
                    if let Some(ref table) = hint.table {
                        options.push(format!(
                            "TABLE HINT({}, FORCESEEK({}))",
                            table, hint.index_name
                        ));
                    }
                }
            }
        }
        if let Some(workers) = self.parallel_workers {
            // Some(0) ("no parallelism") maps to MAXDOP 1.
            if workers == 0 {
                options.push("MAXDOP 1".to_string());
            } else {
                options.push(format!("MAXDOP {}", workers));
            }
        }
        for hint in &self.join_hints {
            let method = match hint.method {
                JoinMethod::NestedLoop => "LOOP JOIN",
                JoinMethod::Hash => "HASH JOIN",
                JoinMethod::Merge => "MERGE JOIN",
            };
            options.push(method.to_string());
        }
        if options.is_empty() {
            String::new()
        } else {
            format!("OPTION ({})", options.join(", "))
        }
    }
    /// True when any hint field is set (including ones not yet rendered,
    /// such as cte_materialized).
    pub fn has_hints(&self) -> bool {
        !self.indexes.is_empty()
            || self.parallel_workers.is_some()
            || !self.join_hints.is_empty()
            || self.no_seq_scan
            || self.no_index_scan
            || self.cte_materialized.is_some()
            || self.timeout_ms.is_some()
            || !self.custom.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Miss on first lookup, hit (generator not invoked) on second.
    #[test]
    fn test_prepared_statement_cache() {
        let cache = PreparedStatementCache::new(10);
        let stmt1 = cache.get_or_create("test", || "SELECT * FROM users".to_string());
        assert_eq!(stmt1.sql, "SELECT * FROM users");
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
        // The closure must not run on a cache hit.
        let stmt2 = cache.get_or_create("test", || panic!("Should not be called"));
        assert_eq!(stmt2.sql, "SELECT * FROM users");
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }
    // auto_tune: tiny jobs use total_rows as the batch; larger jobs stay in
    // bounds; big Postgres jobs switch on COPY.
    #[test]
    fn test_batch_config_auto_tune() {
        let config = BatchConfig::auto_tune(DatabaseType::PostgreSQL, 100, 50);
        assert_eq!(config.batch_size, 50);
        let config = BatchConfig::auto_tune(DatabaseType::PostgreSQL, 500, 5000);
        assert!(config.batch_size >= 100);
        assert!(config.batch_size <= 5000);
        let config = BatchConfig::auto_tune(DatabaseType::PostgreSQL, 200, 100_000);
        assert!(config.use_copy);
        assert!(config.batch_size >= 100);
    }
    // Ranges are half-open, cover the total, and the last one is truncated.
    #[test]
    fn test_batch_ranges() {
        let config = BatchConfig {
            batch_size: 100,
            ..Default::default()
        };
        let ranges: Vec<_> = config.batch_ranges(250).collect();
        assert_eq!(ranges.len(), 3);
        assert_eq!(ranges[0], (0, 100));
        assert_eq!(ranges[1], (100, 200));
        assert_eq!(ranges[2], (200, 250));
    }
    // Every stage added through the builder appears in the rendered JSON.
    #[test]
    fn test_mongo_pipeline_builder() {
        let pipeline = MongoPipelineBuilder::new()
            .match_stage(r#"{ "status": "active" }"#)
            .lookup("orders", "user_id", "_id", "user_orders")
            .unwind("$user_orders")
            .group(r#""$user_id""#, r#""total": { "$sum": "$amount" }"#)
            .sort(r#"{ "total": -1 }"#)
            .limit(10)
            .build();
        assert!(pipeline.contains("$match"));
        assert!(pipeline.contains("$lookup"));
        assert!(pipeline.contains("$unwind"));
        assert!(pipeline.contains("$group"));
        assert!(pipeline.contains("$sort"));
        assert!(pipeline.contains("$limit"));
    }
    // PostgreSQL hints render as SET LOCAL statements.
    #[test]
    fn test_query_hints_postgres() {
        let hints = QueryHints::new().no_seq_scan().parallel(4).timeout(5000);
        let prefix = hints.to_sql_prefix(DatabaseType::PostgreSQL);
        assert!(prefix.contains("enable_seqscan = off"));
        assert!(prefix.contains("max_parallel_workers_per_gather = 4"));
        assert!(prefix.contains("statement_timeout = 5000"));
    }
    // MSSQL hints render as an OPTION (...) suffix.
    #[test]
    fn test_query_hints_mssql() {
        let hints = QueryHints::new()
            .parallel(2)
            .hash_join(vec!["users".to_string(), "orders".to_string()]);
        let suffix = hints.to_sql_suffix(DatabaseType::MSSQL);
        assert!(suffix.contains("MAXDOP 2"));
        assert!(suffix.contains("HASH JOIN"));
    }
    // apply_to_query prepends the prefix and keeps the original query text.
    #[test]
    fn test_query_hints_apply() {
        let hints = QueryHints::new().no_seq_scan();
        let query = "SELECT * FROM users WHERE id = $1";
        let result = hints.apply_to_query(query, DatabaseType::PostgreSQL);
        assert!(result.contains("enable_seqscan = off"));
        assert!(result.contains("SELECT * FROM users"));
    }
}