Skip to main content

prax_query/
db_optimize.rs

1//! Database-specific optimizations.
2//!
3//! This module provides performance optimizations tailored to each database:
4//! - Prepared statement caching (PostgreSQL, MySQL, MSSQL)
5//! - Batch size tuning for bulk operations
6//! - MongoDB pipeline aggregation
7//! - Query plan hints for complex queries
8//!
9//! # Performance Characteristics
10//!
11//! | Database   | Optimization              | Typical Gain |
12//! |------------|---------------------------|--------------|
13//! | PostgreSQL | Prepared statement cache  | 30-50%       |
14//! | MySQL      | Multi-row INSERT batching | 40-60%       |
15//! | MongoDB    | Bulk write batching       | 50-70%       |
16//! | MSSQL      | Table-valued parameters   | 30-40%       |
17//!
18//! # Example
19//!
//! ```rust,ignore
//! use prax_query::db_optimize::{BatchConfig, PreparedStatementCache, QueryHints};
//! use prax_query::sql::DatabaseType;
//!
//! // Prepared statement caching
//! let cache = PreparedStatementCache::new(100);
//! let stmt = cache.get_or_create("find_user", || {
//!     "SELECT * FROM users WHERE id = $1".to_string()
//! });
//!
//! // Auto-tuned batching
//! let config = BatchConfig::auto_tune(DatabaseType::PostgreSQL, avg_row_size, row_count);
//! for batch in data.chunks(config.batch_size) {
//!     execute_batch(batch);
//! }
//!
//! // Query hints
//! let hints = QueryHints::new()
//!     .parallel(4)
//!     .index_hint("users_email_idx");
//! ```
40
41use parking_lot::RwLock;
42use smallvec::SmallVec;
43use std::collections::HashMap;
44use std::sync::atomic::{AtomicU64, Ordering};
45use std::time::Instant;
46
47use crate::sql::DatabaseType;
48
49// ==============================================================================
50// Prepared Statement Cache
51// ==============================================================================
52
/// Snapshot of the counters kept by a [`PreparedStatementCache`].
///
/// Returned by `PreparedStatementCache::stats`.
#[derive(Clone, Debug, Default)]
pub struct PreparedStatementStats {
    /// How many lookups found an existing entry.
    pub hits: u64,
    /// How many lookups had to generate a new entry.
    pub misses: u64,
    /// Entries present in the cache when the snapshot was taken.
    pub cached_count: usize,
    /// Estimated preparation time avoided by hits, in milliseconds.
    pub time_saved_ms: u64,
}

impl PreparedStatementStats {
    /// Share of lookups that were hits, as a percentage (`0.0..=100.0`).
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        let lookups = self.hits + self.misses;
        if lookups == 0 {
            return 0.0;
        }
        let ratio = self.hits as f64 / lookups as f64;
        ratio * 100.0
    }
}
77
/// Metadata recorded for one prepared statement held by a [`PreparedStatementCache`].
#[derive(Clone, Debug)]
pub struct CachedStatement {
    /// SQL text the statement was prepared from.
    pub sql: String,
    /// Name the statement is registered under (also its cache key).
    pub name: String,
    /// Usage counter; a freshly created entry starts at 1.
    pub use_count: u64,
    /// Timestamp of the most recent recorded use.
    pub last_used: Instant,
    /// Estimated cost of preparing this statement, in microseconds.
    pub prep_time_us: u64,
    /// Opaque driver-specific handle, once one has been attached.
    pub handle: Option<u64>,
}
94
95/// A cache for prepared statements.
96///
97/// This cache stores prepared statement metadata and tracks usage patterns
98/// to optimize database interactions. The actual statement handles are
99/// managed by the database driver.
100///
101/// # Features
102///
103/// - LRU eviction when capacity is reached
104/// - Usage statistics for monitoring
105/// - Thread-safe with read-write locking
106/// - Automatic cleanup of stale entries
107///
108/// # Example
109///
110/// ```rust
111/// use prax_query::db_optimize::PreparedStatementCache;
112///
113/// let cache = PreparedStatementCache::new(100);
114///
115/// // Register a prepared statement
116/// let entry = cache.get_or_create("find_user_by_email", || {
117///     "SELECT * FROM users WHERE email = $1".to_string()
118/// });
119///
120/// // Check cache stats
121/// let stats = cache.stats();
122/// println!("Hit rate: {:.1}%", stats.hit_rate());
123/// ```
124pub struct PreparedStatementCache {
125    statements: RwLock<HashMap<String, CachedStatement>>,
126    capacity: usize,
127    hits: AtomicU64,
128    misses: AtomicU64,
129    time_saved_us: AtomicU64,
130    /// Average preparation time in microseconds (for estimation).
131    avg_prep_time_us: u64,
132}
133
134impl PreparedStatementCache {
135    /// Create a new cache with the specified capacity.
136    pub fn new(capacity: usize) -> Self {
137        Self {
138            statements: RwLock::new(HashMap::with_capacity(capacity)),
139            capacity,
140            hits: AtomicU64::new(0),
141            misses: AtomicU64::new(0),
142            time_saved_us: AtomicU64::new(0),
143            avg_prep_time_us: 500, // Default 500µs estimate
144        }
145    }
146
147    /// Get or create a prepared statement entry.
148    ///
149    /// If the statement is cached, returns the cached entry and increments hit count.
150    /// Otherwise, calls the generator function, caches the result, and returns it.
151    pub fn get_or_create<F>(&self, name: &str, generator: F) -> CachedStatement
152    where
153        F: FnOnce() -> String,
154    {
155        // Try read lock first (fast path)
156        {
157            let cache = self.statements.read();
158            if let Some(stmt) = cache.get(name) {
159                self.hits.fetch_add(1, Ordering::Relaxed);
160                self.time_saved_us
161                    .fetch_add(stmt.prep_time_us, Ordering::Relaxed);
162                return stmt.clone();
163            }
164        }
165
166        // Miss - need to create and cache
167        self.misses.fetch_add(1, Ordering::Relaxed);
168
169        let sql = generator();
170        let entry = CachedStatement {
171            sql,
172            name: name.to_string(),
173            use_count: 1,
174            last_used: Instant::now(),
175            prep_time_us: self.avg_prep_time_us,
176            handle: None,
177        };
178
179        // Upgrade to write lock
180        let mut cache = self.statements.write();
181
182        // Double-check after acquiring write lock
183        if let Some(existing) = cache.get(name) {
184            self.hits.fetch_add(1, Ordering::Relaxed);
185            return existing.clone();
186        }
187
188        // Evict if at capacity (simple LRU-like: remove oldest)
189        if cache.len() >= self.capacity {
190            self.evict_oldest(&mut cache);
191        }
192
193        cache.insert(name.to_string(), entry.clone());
194        entry
195    }
196
197    /// Check if a statement is cached.
198    pub fn contains(&self, name: &str) -> bool {
199        self.statements.read().contains_key(name)
200    }
201
202    /// Get cache statistics.
203    pub fn stats(&self) -> PreparedStatementStats {
204        let cache = self.statements.read();
205        PreparedStatementStats {
206            hits: self.hits.load(Ordering::Relaxed),
207            misses: self.misses.load(Ordering::Relaxed),
208            cached_count: cache.len(),
209            time_saved_ms: self.time_saved_us.load(Ordering::Relaxed) / 1000,
210        }
211    }
212
213    /// Clear the cache.
214    pub fn clear(&self) {
215        self.statements.write().clear();
216        self.hits.store(0, Ordering::Relaxed);
217        self.misses.store(0, Ordering::Relaxed);
218        self.time_saved_us.store(0, Ordering::Relaxed);
219    }
220
221    /// Get the number of cached statements.
222    pub fn len(&self) -> usize {
223        self.statements.read().len()
224    }
225
226    /// Check if the cache is empty.
227    pub fn is_empty(&self) -> bool {
228        self.statements.read().is_empty()
229    }
230
231    /// Evict the oldest entry.
232    fn evict_oldest(&self, cache: &mut HashMap<String, CachedStatement>) {
233        if let Some((oldest_key, _)) = cache
234            .iter()
235            .min_by_key(|(_, v)| v.last_used)
236            .map(|(k, v)| (k.clone(), v.clone()))
237        {
238            cache.remove(&oldest_key);
239        }
240    }
241
242    /// Update statement usage (call after executing).
243    pub fn record_use(&self, name: &str) {
244        if let Some(stmt) = self.statements.write().get_mut(name) {
245            stmt.use_count += 1;
246            stmt.last_used = Instant::now();
247        }
248    }
249
250    /// Set a database-specific handle for a statement.
251    pub fn set_handle(&self, name: &str, handle: u64) {
252        if let Some(stmt) = self.statements.write().get_mut(name) {
253            stmt.handle = Some(handle);
254        }
255    }
256}
257
258impl Default for PreparedStatementCache {
259    fn default() -> Self {
260        Self::new(256)
261    }
262}
263
264/// Global prepared statement cache.
265pub fn global_statement_cache() -> &'static PreparedStatementCache {
266    use std::sync::OnceLock;
267    static CACHE: OnceLock<PreparedStatementCache> = OnceLock::new();
268    CACHE.get_or_init(|| PreparedStatementCache::new(512))
269}
270
271// ==============================================================================
272// Batch Size Tuning
273// ==============================================================================
274
275/// Configuration for batch operations.
276#[derive(Debug, Clone, Copy)]
277pub struct BatchConfig {
278    /// Number of rows per batch.
279    pub batch_size: usize,
280    /// Maximum payload size in bytes.
281    pub max_payload_bytes: usize,
282    /// Whether to use multi-row INSERT syntax.
283    pub multi_row_insert: bool,
284    /// Whether to use COPY for bulk inserts (PostgreSQL).
285    pub use_copy: bool,
286    /// Parallelism level for bulk operations.
287    pub parallelism: usize,
288}
289
290impl BatchConfig {
291    /// Default batch configuration.
292    pub const fn default_config() -> Self {
293        Self {
294            batch_size: 1000,
295            max_payload_bytes: 16 * 1024 * 1024, // 16MB
296            multi_row_insert: true,
297            use_copy: false,
298            parallelism: 1,
299        }
300    }
301
302    /// Create configuration optimized for the given database.
303    pub fn for_database(db_type: DatabaseType) -> Self {
304        match db_type {
305            DatabaseType::PostgreSQL => Self {
306                batch_size: 1000,
307                max_payload_bytes: 64 * 1024 * 1024, // 64MB
308                multi_row_insert: true,
309                use_copy: true, // PostgreSQL COPY is very fast
310                parallelism: 4,
311            },
312            DatabaseType::MySQL => Self {
313                batch_size: 500,                     // MySQL has packet size limits
314                max_payload_bytes: 16 * 1024 * 1024, // 16MB (default max_allowed_packet)
315                multi_row_insert: true,
316                use_copy: false,
317                parallelism: 2,
318            },
319            DatabaseType::SQLite => Self {
320                batch_size: 500,
321                max_payload_bytes: 1024 * 1024, // 1MB (SQLite is single-threaded)
322                multi_row_insert: true,
323                use_copy: false,
324                parallelism: 1, // SQLite doesn't benefit from parallelism
325            },
326            DatabaseType::MSSQL => Self {
327                batch_size: 1000,
328                max_payload_bytes: 32 * 1024 * 1024, // 32MB
329                multi_row_insert: true,
330                use_copy: false,
331                parallelism: 4,
332            },
333        }
334    }
335
336    /// Auto-tune batch size based on row size and count.
337    ///
338    /// This calculates an optimal batch size that:
339    /// - Stays within the max payload size
340    /// - Balances memory usage vs round-trip overhead
341    /// - Adapts to row size variations
342    ///
343    /// # Example
344    ///
345    /// ```rust
346    /// use prax_query::db_optimize::BatchConfig;
347    /// use prax_query::sql::DatabaseType;
348    ///
349    /// // Auto-tune for 10,000 rows averaging 500 bytes each
350    /// let config = BatchConfig::auto_tune(
351    ///     DatabaseType::PostgreSQL,
352    ///     500,    // avg row size in bytes
353    ///     10_000, // total row count
354    /// );
355    /// println!("Optimal batch size: {}", config.batch_size);
356    /// ```
357    pub fn auto_tune(db_type: DatabaseType, avg_row_size: usize, total_rows: usize) -> Self {
358        let mut config = Self::for_database(db_type);
359
360        // Calculate batch size based on payload limit
361        let max_rows_by_payload = if avg_row_size > 0 {
362            config.max_payload_bytes / avg_row_size
363        } else {
364            config.batch_size
365        };
366
367        // Balance: smaller batches for small datasets, larger for big ones
368        let optimal_batch = if total_rows < 100 {
369            total_rows // No batching needed for small datasets
370        } else if total_rows < 1000 {
371            (total_rows / 10).max(100)
372        } else {
373            // For large datasets, use ~10 batches or max by payload
374            let by_count = total_rows / 10;
375            by_count.min(max_rows_by_payload).min(10_000).max(100)
376        };
377
378        config.batch_size = optimal_batch;
379
380        // Adjust parallelism based on dataset size
381        if total_rows < 1000 {
382            config.parallelism = 1;
383        } else if total_rows < 10_000 {
384            config.parallelism = config.parallelism.min(2);
385        }
386
387        // Use COPY for large PostgreSQL imports
388        if matches!(db_type, DatabaseType::PostgreSQL) && total_rows > 5000 {
389            config.use_copy = true;
390        }
391
392        config
393    }
394
395    /// Create an iterator that yields batch ranges.
396    ///
397    /// # Example
398    ///
399    /// ```rust
400    /// use prax_query::db_optimize::BatchConfig;
401    ///
402    /// let config = BatchConfig::default_config();
403    /// let total = 2500;
404    ///
405    /// for (start, end) in config.batch_ranges(total) {
406    ///     println!("Processing rows {} to {}", start, end);
407    /// }
408    /// ```
409    pub fn batch_ranges(&self, total: usize) -> impl Iterator<Item = (usize, usize)> {
410        let batch_size = self.batch_size;
411        (0..total)
412            .step_by(batch_size)
413            .map(move |start| (start, (start + batch_size).min(total)))
414    }
415
416    /// Calculate the number of batches for a given total.
417    pub fn batch_count(&self, total: usize) -> usize {
418        (total + self.batch_size - 1) / self.batch_size
419    }
420}
421
422impl Default for BatchConfig {
423    fn default() -> Self {
424        Self::default_config()
425    }
426}
427
428// ==============================================================================
429// MongoDB Pipeline Aggregation
430// ==============================================================================
431
432/// A builder for combining multiple MongoDB operations into a single pipeline.
433///
434/// This reduces round-trips by batching related operations.
435///
436/// # Example
437///
438/// ```rust,ignore
439/// use prax_query::db_optimize::MongoPipelineBuilder;
440///
441/// let pipeline = MongoPipelineBuilder::new()
442///     .match_stage(doc! { "status": "active" })
443///     .lookup("orders", "user_id", "_id", "user_orders")
444///     .unwind("$user_orders")
445///     .group("$user_id", doc! { "total": { "$sum": "$amount" } })
446///     .sort(doc! { "total": -1 })
447///     .limit(10)
448///     .build();
449/// ```
450#[derive(Debug, Clone, Default)]
451pub struct MongoPipelineBuilder {
452    stages: Vec<PipelineStage>,
453    /// Whether to allow disk use for large operations.
454    pub allow_disk_use: bool,
455    /// Batch size for cursor.
456    pub batch_size: Option<u32>,
457    /// Maximum time for operation in milliseconds.
458    pub max_time_ms: Option<u64>,
459    /// Comment for profiling.
460    pub comment: Option<String>,
461}
462
/// One stage of a MongoDB aggregation pipeline.
///
/// String payloads hold JSON text; `PipelineStage::to_json` renders each
/// variant into its `{ "$stage": ... }` document.
#[derive(Clone, Debug)]
pub enum PipelineStage {
    /// `$match` with a JSON filter.
    Match(String),
    /// `$project` with a JSON projection.
    Project(String),
    /// `$group` with an `_id` expression and accumulator definitions.
    Group {
        /// The `_id` expression, as JSON.
        id: String,
        /// Accumulator definitions, as JSON.
        accumulators: String,
    },
    /// `$sort` with a JSON sort specification.
    Sort(String),
    /// `$limit` to this many documents.
    Limit(u64),
    /// `$skip` this many documents.
    Skip(u64),
    /// `$unwind` an array field.
    Unwind {
        /// Field path to unwind.
        path: String,
        /// Renders `preserveNullAndEmptyArrays: true` when set.
        preserve_null: bool,
    },
    /// `$lookup` joining another collection.
    Lookup {
        /// Collection to join with.
        from: String,
        /// Field on the input documents.
        local_field: String,
        /// Field on the `from` collection.
        foreign_field: String,
        /// Output array field name.
        r#as: String,
    },
    /// `$addFields` with a JSON field specification.
    AddFields(String),
    /// `$set` (alias for `$addFields`).
    Set(String),
    /// `$unset` the listed fields.
    Unset(Vec<String>),
    /// `$replaceRoot` with the given `newRoot` expression.
    ReplaceRoot(String),
    /// `$count` into the named output field.
    Count(String),
    /// `$facet` with named sub-pipelines.
    Facet(Vec<(String, Vec<PipelineStage>)>),
    /// `$bucket` grouping.
    Bucket {
        /// `groupBy` expression, as JSON.
        group_by: String,
        /// `boundaries` array, as JSON.
        boundaries: String,
        /// Optional `default` bucket, as JSON.
        default: Option<String>,
        /// Optional `output` specification, as JSON.
        output: Option<String>,
    },
    /// `$sample` of this many documents.
    Sample(u64),
    /// `$merge` output stage.
    Merge {
        /// Target collection.
        into: String,
        /// Optional `on` field.
        on: Option<String>,
        /// Optional `whenMatched` action.
        when_matched: Option<String>,
        /// Optional `whenNotMatched` action.
        when_not_matched: Option<String>,
    },
    /// `$out` to the named collection.
    Out(String),
    /// Stage text emitted verbatim.
    Raw(String),
}
520
521impl MongoPipelineBuilder {
522    /// Create a new empty pipeline builder.
523    pub fn new() -> Self {
524        Self::default()
525    }
526
527    /// Add a $match stage.
528    pub fn match_stage(mut self, filter: impl Into<String>) -> Self {
529        self.stages.push(PipelineStage::Match(filter.into()));
530        self
531    }
532
533    /// Add a $project stage.
534    pub fn project(mut self, projection: impl Into<String>) -> Self {
535        self.stages.push(PipelineStage::Project(projection.into()));
536        self
537    }
538
539    /// Add a $group stage.
540    pub fn group(mut self, id: impl Into<String>, accumulators: impl Into<String>) -> Self {
541        self.stages.push(PipelineStage::Group {
542            id: id.into(),
543            accumulators: accumulators.into(),
544        });
545        self
546    }
547
548    /// Add a $sort stage.
549    pub fn sort(mut self, sort: impl Into<String>) -> Self {
550        self.stages.push(PipelineStage::Sort(sort.into()));
551        self
552    }
553
554    /// Add a $limit stage.
555    pub fn limit(mut self, n: u64) -> Self {
556        self.stages.push(PipelineStage::Limit(n));
557        self
558    }
559
560    /// Add a $skip stage.
561    pub fn skip(mut self, n: u64) -> Self {
562        self.stages.push(PipelineStage::Skip(n));
563        self
564    }
565
566    /// Add a $unwind stage.
567    pub fn unwind(mut self, path: impl Into<String>) -> Self {
568        self.stages.push(PipelineStage::Unwind {
569            path: path.into(),
570            preserve_null: false,
571        });
572        self
573    }
574
575    /// Add a $unwind stage with null preservation.
576    pub fn unwind_preserve_null(mut self, path: impl Into<String>) -> Self {
577        self.stages.push(PipelineStage::Unwind {
578            path: path.into(),
579            preserve_null: true,
580        });
581        self
582    }
583
584    /// Add a $lookup stage.
585    pub fn lookup(
586        mut self,
587        from: impl Into<String>,
588        local_field: impl Into<String>,
589        foreign_field: impl Into<String>,
590        r#as: impl Into<String>,
591    ) -> Self {
592        self.stages.push(PipelineStage::Lookup {
593            from: from.into(),
594            local_field: local_field.into(),
595            foreign_field: foreign_field.into(),
596            r#as: r#as.into(),
597        });
598        self
599    }
600
601    /// Add a $addFields stage.
602    pub fn add_fields(mut self, fields: impl Into<String>) -> Self {
603        self.stages.push(PipelineStage::AddFields(fields.into()));
604        self
605    }
606
607    /// Add a $set stage.
608    pub fn set(mut self, fields: impl Into<String>) -> Self {
609        self.stages.push(PipelineStage::Set(fields.into()));
610        self
611    }
612
613    /// Add a $unset stage.
614    pub fn unset<I, S>(mut self, fields: I) -> Self
615    where
616        I: IntoIterator<Item = S>,
617        S: Into<String>,
618    {
619        self.stages.push(PipelineStage::Unset(
620            fields.into_iter().map(Into::into).collect(),
621        ));
622        self
623    }
624
625    /// Add a $replaceRoot stage.
626    pub fn replace_root(mut self, new_root: impl Into<String>) -> Self {
627        self.stages
628            .push(PipelineStage::ReplaceRoot(new_root.into()));
629        self
630    }
631
632    /// Add a $count stage.
633    pub fn count(mut self, field: impl Into<String>) -> Self {
634        self.stages.push(PipelineStage::Count(field.into()));
635        self
636    }
637
638    /// Add a $sample stage.
639    pub fn sample(mut self, size: u64) -> Self {
640        self.stages.push(PipelineStage::Sample(size));
641        self
642    }
643
644    /// Add a $merge output stage.
645    pub fn merge_into(mut self, collection: impl Into<String>) -> Self {
646        self.stages.push(PipelineStage::Merge {
647            into: collection.into(),
648            on: None,
649            when_matched: None,
650            when_not_matched: None,
651        });
652        self
653    }
654
655    /// Add a $merge output stage with options.
656    pub fn merge(
657        mut self,
658        into: impl Into<String>,
659        on: Option<String>,
660        when_matched: Option<String>,
661        when_not_matched: Option<String>,
662    ) -> Self {
663        self.stages.push(PipelineStage::Merge {
664            into: into.into(),
665            on,
666            when_matched,
667            when_not_matched,
668        });
669        self
670    }
671
672    /// Add a $out stage.
673    pub fn out(mut self, collection: impl Into<String>) -> Self {
674        self.stages.push(PipelineStage::Out(collection.into()));
675        self
676    }
677
678    /// Add a raw BSON stage.
679    pub fn raw(mut self, stage: impl Into<String>) -> Self {
680        self.stages.push(PipelineStage::Raw(stage.into()));
681        self
682    }
683
684    /// Enable disk use for large operations.
685    pub fn with_disk_use(mut self) -> Self {
686        self.allow_disk_use = true;
687        self
688    }
689
690    /// Set cursor batch size.
691    pub fn with_batch_size(mut self, size: u32) -> Self {
692        self.batch_size = Some(size);
693        self
694    }
695
696    /// Set maximum execution time.
697    pub fn with_max_time(mut self, ms: u64) -> Self {
698        self.max_time_ms = Some(ms);
699        self
700    }
701
702    /// Add a comment for profiling.
703    pub fn with_comment(mut self, comment: impl Into<String>) -> Self {
704        self.comment = Some(comment.into());
705        self
706    }
707
708    /// Get the number of stages.
709    pub fn stage_count(&self) -> usize {
710        self.stages.len()
711    }
712
713    /// Build the pipeline as a JSON array string.
714    pub fn build(&self) -> String {
715        let stages: Vec<String> = self.stages.iter().map(|s| s.to_json()).collect();
716        format!("[{}]", stages.join(", "))
717    }
718
719    /// Get the stages.
720    pub fn stages(&self) -> &[PipelineStage] {
721        &self.stages
722    }
723}
724
725impl PipelineStage {
726    /// Convert to JSON representation.
727    pub fn to_json(&self) -> String {
728        match self {
729            Self::Match(filter) => format!(r#"{{ "$match": {} }}"#, filter),
730            Self::Project(proj) => format!(r#"{{ "$project": {} }}"#, proj),
731            Self::Group { id, accumulators } => {
732                format!(r#"{{ "$group": {{ "_id": {}, {} }} }}"#, id, accumulators)
733            }
734            Self::Sort(sort) => format!(r#"{{ "$sort": {} }}"#, sort),
735            Self::Limit(n) => format!(r#"{{ "$limit": {} }}"#, n),
736            Self::Skip(n) => format!(r#"{{ "$skip": {} }}"#, n),
737            Self::Unwind {
738                path,
739                preserve_null,
740            } => {
741                if *preserve_null {
742                    format!(
743                        r#"{{ "$unwind": {{ "path": "{}", "preserveNullAndEmptyArrays": true }} }}"#,
744                        path
745                    )
746                } else {
747                    format!(r#"{{ "$unwind": "{}" }}"#, path)
748                }
749            }
750            Self::Lookup {
751                from,
752                local_field,
753                foreign_field,
754                r#as,
755            } => {
756                format!(
757                    r#"{{ "$lookup": {{ "from": "{}", "localField": "{}", "foreignField": "{}", "as": "{}" }} }}"#,
758                    from, local_field, foreign_field, r#as
759                )
760            }
761            Self::AddFields(fields) => format!(r#"{{ "$addFields": {} }}"#, fields),
762            Self::Set(fields) => format!(r#"{{ "$set": {} }}"#, fields),
763            Self::Unset(fields) => {
764                let quoted: Vec<_> = fields.iter().map(|f| format!(r#""{}""#, f)).collect();
765                format!(r#"{{ "$unset": [{}] }}"#, quoted.join(", "))
766            }
767            Self::ReplaceRoot(root) => {
768                format!(r#"{{ "$replaceRoot": {{ "newRoot": {} }} }}"#, root)
769            }
770            Self::Count(field) => format!(r#"{{ "$count": "{}" }}"#, field),
771            Self::Facet(facets) => {
772                let facet_strs: Vec<_> = facets
773                    .iter()
774                    .map(|(name, stages)| {
775                        let pipeline: Vec<_> = stages.iter().map(|s| s.to_json()).collect();
776                        format!(r#""{}": [{}]"#, name, pipeline.join(", "))
777                    })
778                    .collect();
779                format!(r#"{{ "$facet": {{ {} }} }}"#, facet_strs.join(", "))
780            }
781            Self::Bucket {
782                group_by,
783                boundaries,
784                default,
785                output,
786            } => {
787                let mut parts = vec![
788                    format!(r#""groupBy": {}"#, group_by),
789                    format!(r#""boundaries": {}"#, boundaries),
790                ];
791                if let Some(def) = default {
792                    parts.push(format!(r#""default": {}"#, def));
793                }
794                if let Some(out) = output {
795                    parts.push(format!(r#""output": {}"#, out));
796                }
797                format!(r#"{{ "$bucket": {{ {} }} }}"#, parts.join(", "))
798            }
799            Self::Sample(size) => format!(r#"{{ "$sample": {{ "size": {} }} }}"#, size),
800            Self::Merge {
801                into,
802                on,
803                when_matched,
804                when_not_matched,
805            } => {
806                let mut parts = vec![format!(r#""into": "{}""#, into)];
807                if let Some(on_field) = on {
808                    parts.push(format!(r#""on": "{}""#, on_field));
809                }
810                if let Some(matched) = when_matched {
811                    parts.push(format!(r#""whenMatched": "{}""#, matched));
812                }
813                if let Some(not_matched) = when_not_matched {
814                    parts.push(format!(r#""whenNotMatched": "{}""#, not_matched));
815                }
816                format!(r#"{{ "$merge": {{ {} }} }}"#, parts.join(", "))
817            }
818            Self::Out(collection) => format!(r#"{{ "$out": "{}" }}"#, collection),
819            Self::Raw(stage) => stage.clone(),
820        }
821    }
822}
823
824// ==============================================================================
825// Query Plan Hints
826// ==============================================================================
827
828/// Query plan hints for optimizing complex queries.
829///
830/// These hints are applied to queries to guide the query planner:
831/// - Index hints to force specific index usage
832/// - Parallelism settings
833/// - Join strategies
834/// - Materialization preferences
835///
836/// # Database Support
837///
838/// | Hint Type | PostgreSQL | MySQL | SQLite | MSSQL |
839/// |-----------|------------|-------|--------|-------|
840/// | Index     | ✅ (GUC)   | ✅    | ✅     | ✅    |
841/// | Parallel  | ✅         | ❌    | ❌     | ✅    |
842/// | Join      | ✅         | ✅    | ❌     | ✅    |
843/// | CTE Mat   | ✅         | ❌    | ❌     | ❌    |
844///
845/// # Example
846///
847/// ```rust
848/// use prax_query::db_optimize::QueryHints;
849/// use prax_query::sql::DatabaseType;
850///
851/// let hints = QueryHints::new()
852///     .index_hint("users_email_idx")
853///     .parallel(4)
854///     .no_seq_scan();
855///
856/// let sql = hints.apply_to_query("SELECT * FROM users WHERE email = $1", DatabaseType::PostgreSQL);
857/// ```
858#[derive(Debug, Clone, Default)]
859pub struct QueryHints {
860    /// Index hints.
861    pub indexes: SmallVec<[IndexHint; 4]>,
862    /// Parallelism level (0 = default, >0 = specific workers).
863    pub parallel_workers: Option<u32>,
864    /// Join method hints.
865    pub join_hints: SmallVec<[JoinHint; 4]>,
866    /// Whether to prevent sequential scans.
867    pub no_seq_scan: bool,
868    /// Whether to prevent index scans.
869    pub no_index_scan: bool,
870    /// CTE materialization preference.
871    pub cte_materialized: Option<bool>,
872    /// Query timeout in milliseconds.
873    pub timeout_ms: Option<u64>,
874    /// Custom database-specific hints.
875    pub custom: Vec<String>,
876}
877
/// Hint steering the query planner toward or away from one index.
#[derive(Clone, Debug)]
pub struct IndexHint {
    /// Owning table, or `None` when the hint is not tied to a table.
    pub table: Option<String>,
    /// Name of the hinted index.
    pub index_name: String,
    /// Whether the index should be used, ignored, or preferred.
    pub hint_type: IndexHintType,
}

/// How an [`IndexHint`] should influence index selection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IndexHintType {
    /// Force the index to be used.
    Use,
    /// Force the index to be ignored.
    Ignore,
    /// Use the index when possible.
    Prefer,
}
899
/// Hint asking the planner to join a set of tables with a particular method.
#[derive(Clone, Debug)]
pub struct JoinHint {
    /// Names of the tables taking part in the join.
    pub tables: Vec<String>,
    /// Join algorithm to request.
    pub method: JoinMethod,
}

/// Join algorithms that can be requested through a [`JoinHint`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinMethod {
    /// Nested loop join.
    NestedLoop,
    /// Hash join.
    Hash,
    /// Merge join.
    Merge,
}
919
920impl QueryHints {
921    /// Create new empty hints.
922    pub fn new() -> Self {
923        Self::default()
924    }
925
926    /// Add an index hint.
927    pub fn index_hint(mut self, index_name: impl Into<String>) -> Self {
928        self.indexes.push(IndexHint {
929            table: None,
930            index_name: index_name.into(),
931            hint_type: IndexHintType::Use,
932        });
933        self
934    }
935
936    /// Add an index hint for a specific table.
937    pub fn index_hint_for_table(
938        mut self,
939        table: impl Into<String>,
940        index_name: impl Into<String>,
941    ) -> Self {
942        self.indexes.push(IndexHint {
943            table: Some(table.into()),
944            index_name: index_name.into(),
945            hint_type: IndexHintType::Use,
946        });
947        self
948    }
949
950    /// Ignore a specific index.
951    pub fn ignore_index(mut self, index_name: impl Into<String>) -> Self {
952        self.indexes.push(IndexHint {
953            table: None,
954            index_name: index_name.into(),
955            hint_type: IndexHintType::Ignore,
956        });
957        self
958    }
959
960    /// Set parallelism level.
961    pub fn parallel(mut self, workers: u32) -> Self {
962        self.parallel_workers = Some(workers);
963        self
964    }
965
966    /// Disable parallel execution.
967    pub fn no_parallel(mut self) -> Self {
968        self.parallel_workers = Some(0);
969        self
970    }
971
972    /// Prevent sequential scans.
973    pub fn no_seq_scan(mut self) -> Self {
974        self.no_seq_scan = true;
975        self
976    }
977
978    /// Prevent index scans.
979    pub fn no_index_scan(mut self) -> Self {
980        self.no_index_scan = true;
981        self
982    }
983
984    /// Set CTE materialization preference.
985    pub fn cte_materialized(mut self, materialized: bool) -> Self {
986        self.cte_materialized = Some(materialized);
987        self
988    }
989
990    /// Force nested loop join.
991    pub fn nested_loop_join(mut self, tables: Vec<String>) -> Self {
992        self.join_hints.push(JoinHint {
993            tables,
994            method: JoinMethod::NestedLoop,
995        });
996        self
997    }
998
999    /// Force hash join.
1000    pub fn hash_join(mut self, tables: Vec<String>) -> Self {
1001        self.join_hints.push(JoinHint {
1002            tables,
1003            method: JoinMethod::Hash,
1004        });
1005        self
1006    }
1007
1008    /// Force merge join.
1009    pub fn merge_join(mut self, tables: Vec<String>) -> Self {
1010        self.join_hints.push(JoinHint {
1011            tables,
1012            method: JoinMethod::Merge,
1013        });
1014        self
1015    }
1016
1017    /// Set query timeout.
1018    pub fn timeout(mut self, ms: u64) -> Self {
1019        self.timeout_ms = Some(ms);
1020        self
1021    }
1022
1023    /// Add a custom database-specific hint.
1024    pub fn custom_hint(mut self, hint: impl Into<String>) -> Self {
1025        self.custom.push(hint.into());
1026        self
1027    }
1028
1029    /// Generate hints as SQL prefix for the given database.
1030    pub fn to_sql_prefix(&self, db_type: DatabaseType) -> String {
1031        match db_type {
1032            DatabaseType::PostgreSQL => self.to_postgres_prefix(),
1033            DatabaseType::MySQL => self.to_mysql_prefix(),
1034            DatabaseType::SQLite => self.to_sqlite_prefix(),
1035            DatabaseType::MSSQL => self.to_mssql_prefix(),
1036        }
1037    }
1038
1039    /// Generate hints as SQL suffix (for query options).
1040    pub fn to_sql_suffix(&self, db_type: DatabaseType) -> String {
1041        match db_type {
1042            DatabaseType::MySQL => self.to_mysql_suffix(),
1043            DatabaseType::MSSQL => self.to_mssql_suffix(),
1044            _ => String::new(),
1045        }
1046    }
1047
1048    /// Apply hints to a query.
1049    pub fn apply_to_query(&self, query: &str, db_type: DatabaseType) -> String {
1050        let prefix = self.to_sql_prefix(db_type);
1051        let suffix = self.to_sql_suffix(db_type);
1052
1053        if prefix.is_empty() && suffix.is_empty() {
1054            return query.to_string();
1055        }
1056
1057        let mut result = String::with_capacity(prefix.len() + query.len() + suffix.len() + 2);
1058        if !prefix.is_empty() {
1059            result.push_str(&prefix);
1060            result.push('\n');
1061        }
1062        result.push_str(query);
1063        if !suffix.is_empty() {
1064            result.push(' ');
1065            result.push_str(&suffix);
1066        }
1067        result
1068    }
1069
1070    fn to_postgres_prefix(&self) -> String {
1071        let mut settings: Vec<String> = Vec::new();
1072
1073        if self.no_seq_scan {
1074            settings.push("SET LOCAL enable_seqscan = off;".to_string());
1075        }
1076        if self.no_index_scan {
1077            settings.push("SET LOCAL enable_indexscan = off;".to_string());
1078        }
1079        if let Some(workers) = self.parallel_workers {
1080            settings.push(format!(
1081                "SET LOCAL max_parallel_workers_per_gather = {};",
1082                workers
1083            ));
1084        }
1085        if let Some(ms) = self.timeout_ms {
1086            settings.push(format!("SET LOCAL statement_timeout = {};", ms));
1087        }
1088
1089        // Join hints
1090        for hint in &self.join_hints {
1091            match hint.method {
1092                JoinMethod::NestedLoop => {
1093                    settings.push("SET LOCAL enable_hashjoin = off;".to_string());
1094                    settings.push("SET LOCAL enable_mergejoin = off;".to_string());
1095                }
1096                JoinMethod::Hash => {
1097                    settings.push("SET LOCAL enable_nestloop = off;".to_string());
1098                    settings.push("SET LOCAL enable_mergejoin = off;".to_string());
1099                }
1100                JoinMethod::Merge => {
1101                    settings.push("SET LOCAL enable_nestloop = off;".to_string());
1102                    settings.push("SET LOCAL enable_hashjoin = off;".to_string());
1103                }
1104            }
1105        }
1106
1107        // Custom hints
1108        for hint in &self.custom {
1109            settings.push(hint.clone());
1110        }
1111
1112        settings.join("\n")
1113    }
1114
1115    fn to_mysql_prefix(&self) -> String {
1116        // MySQL uses inline hints, not SET statements
1117        String::new()
1118    }
1119
1120    fn to_mysql_suffix(&self) -> String {
1121        let mut hints: Vec<String> = Vec::new();
1122
1123        // Index hints (applied after table name in actual query, but we return as hint comment)
1124        for hint in &self.indexes {
1125            let hint_type = match hint.hint_type {
1126                IndexHintType::Use => "USE INDEX",
1127                IndexHintType::Ignore => "IGNORE INDEX",
1128                IndexHintType::Prefer => "FORCE INDEX",
1129            };
1130            if let Some(ref table) = hint.table {
1131                hints.push(format!(
1132                    "/* {} FOR {} ({}) */",
1133                    hint_type, table, hint.index_name
1134                ));
1135            } else {
1136                hints.push(format!("/* {} ({}) */", hint_type, hint.index_name));
1137            }
1138        }
1139
1140        // Join hints
1141        for hint in &self.join_hints {
1142            let method = match hint.method {
1143                JoinMethod::NestedLoop => "BNL",
1144                JoinMethod::Hash => "HASH_JOIN",
1145                JoinMethod::Merge => "MERGE",
1146            };
1147            hints.push(format!("/* {}({}) */", method, hint.tables.join(", ")));
1148        }
1149
1150        hints.join(" ")
1151    }
1152
1153    fn to_sqlite_prefix(&self) -> String {
1154        // SQLite has limited hint support
1155        String::new()
1156    }
1157
1158    fn to_mssql_prefix(&self) -> String {
1159        // MSSQL uses inline OPTION hints
1160        String::new()
1161    }
1162
1163    fn to_mssql_suffix(&self) -> String {
1164        let mut options: Vec<String> = Vec::new();
1165
1166        // Index hints
1167        for hint in &self.indexes {
1168            match hint.hint_type {
1169                IndexHintType::Use => {
1170                    if let Some(ref table) = hint.table {
1171                        options.push(format!("TABLE HINT({}, INDEX({}))", table, hint.index_name));
1172                    }
1173                }
1174                IndexHintType::Ignore => {
1175                    // MSSQL doesn't have ignore index, skip
1176                }
1177                IndexHintType::Prefer => {
1178                    if let Some(ref table) = hint.table {
1179                        options.push(format!(
1180                            "TABLE HINT({}, FORCESEEK({}))",
1181                            table, hint.index_name
1182                        ));
1183                    }
1184                }
1185            }
1186        }
1187
1188        // Parallelism
1189        if let Some(workers) = self.parallel_workers {
1190            if workers == 0 {
1191                options.push("MAXDOP 1".to_string());
1192            } else {
1193                options.push(format!("MAXDOP {}", workers));
1194            }
1195        }
1196
1197        // Join hints
1198        for hint in &self.join_hints {
1199            let method = match hint.method {
1200                JoinMethod::NestedLoop => "LOOP JOIN",
1201                JoinMethod::Hash => "HASH JOIN",
1202                JoinMethod::Merge => "MERGE JOIN",
1203            };
1204            options.push(method.to_string());
1205        }
1206
1207        if options.is_empty() {
1208            String::new()
1209        } else {
1210            format!("OPTION ({})", options.join(", "))
1211        }
1212    }
1213
1214    /// Check if any hints are configured.
1215    pub fn has_hints(&self) -> bool {
1216        !self.indexes.is_empty()
1217            || self.parallel_workers.is_some()
1218            || !self.join_hints.is_empty()
1219            || self.no_seq_scan
1220            || self.no_index_scan
1221            || self.cte_materialized.is_some()
1222            || self.timeout_ms.is_some()
1223            || !self.custom.is_empty()
1224    }
1225}
1226
1227// ==============================================================================
1228// Tests
1229// ==============================================================================
1230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prepared_statement_cache() {
        let cache = PreparedStatementCache::new(10);

        // Cold lookup: the factory runs and a miss is recorded.
        let created = cache.get_or_create("test", || "SELECT * FROM users".to_string());
        assert_eq!(created.sql, "SELECT * FROM users");
        let cold = cache.stats();
        assert_eq!(cold.misses, 1);
        assert_eq!(cold.hits, 0);

        // Warm lookup: the cached entry comes back without calling the factory.
        let cached = cache.get_or_create("test", || panic!("Should not be called"));
        assert_eq!(cached.sql, "SELECT * FROM users");
        let warm = cache.stats();
        assert_eq!(warm.misses, 1);
        assert_eq!(warm.hits, 1);
        assert!(warm.hit_rate() > 0.0);
    }

    #[test]
    fn test_batch_config_auto_tune() {
        // Tiny input fits in a single batch.
        let small = BatchConfig::auto_tune(DatabaseType::PostgreSQL, 100, 50);
        assert_eq!(small.batch_size, 50);

        // Mid-sized input lands in a sane batch-size window.
        let medium = BatchConfig::auto_tune(DatabaseType::PostgreSQL, 500, 5000);
        assert!((100..=5000).contains(&medium.batch_size));

        // Large PostgreSQL imports switch to COPY.
        let large = BatchConfig::auto_tune(DatabaseType::PostgreSQL, 200, 100_000);
        assert!(large.use_copy);
        assert!(large.batch_size >= 100);
    }

    #[test]
    fn test_batch_ranges() {
        let config = BatchConfig {
            batch_size: 100,
            ..Default::default()
        };

        let ranges: Vec<_> = config.batch_ranges(250).collect();
        assert_eq!(ranges, vec![(0, 100), (100, 200), (200, 250)]);
    }

    #[test]
    fn test_mongo_pipeline_builder() {
        let pipeline = MongoPipelineBuilder::new()
            .match_stage(r#"{ "status": "active" }"#)
            .lookup("orders", "user_id", "_id", "user_orders")
            .unwind("$user_orders")
            .group(r#""$user_id""#, r#""total": { "$sum": "$amount" }"#)
            .sort(r#"{ "total": -1 }"#)
            .limit(10)
            .build();

        for stage in &["$match", "$lookup", "$unwind", "$group", "$sort", "$limit"] {
            assert!(pipeline.contains(*stage));
        }
    }

    #[test]
    fn test_query_hints_postgres() {
        let prefix = QueryHints::new()
            .no_seq_scan()
            .parallel(4)
            .timeout(5000)
            .to_sql_prefix(DatabaseType::PostgreSQL);

        for fragment in &[
            "enable_seqscan = off",
            "max_parallel_workers_per_gather = 4",
            "statement_timeout = 5000",
        ] {
            assert!(prefix.contains(*fragment));
        }
    }

    #[test]
    fn test_query_hints_mssql() {
        let tables = vec!["users".to_string(), "orders".to_string()];
        let suffix = QueryHints::new()
            .parallel(2)
            .hash_join(tables)
            .to_sql_suffix(DatabaseType::MSSQL);

        assert!(suffix.contains("MAXDOP 2"));
        assert!(suffix.contains("HASH JOIN"));
    }

    #[test]
    fn test_query_hints_apply() {
        let query = "SELECT * FROM users WHERE id = $1";
        let hinted = QueryHints::new()
            .no_seq_scan()
            .apply_to_query(query, DatabaseType::PostgreSQL);

        assert!(hinted.contains("enable_seqscan = off"));
        assert!(hinted.contains("SELECT * FROM users"));
    }
}