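//! `pgdrift_db/sampler.rs`: strategies for sampling one column's JSON values from a
//! PostgreSQL table (full scan, `ORDER BY random()`, primary-key probing, `TABLESAMPLE`).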

1use futures::TryStreamExt;
2use indicatif::{ProgressBar, ProgressStyle};
3use serde_json::Value;
4use sqlx::PgPool;
5
/// How rows are selected when sampling a column.
///
/// See [`SamplingStrategy::auto_select`] for how a strategy is chosen from the table's
/// row count.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingStrategy {
    /// Read every non-NULL value of the column, with no randomization.
    ///
    /// Chosen when `sample_size >= row_count`, so the result does not depend on chance.
    Full,

    /// `ORDER BY random() LIMIT limit`.
    ///
    /// Unbiased, but every candidate row must be read and ranked. Chosen for tables
    /// under 100k rows, and as the fallback for 100k–10M-row tables when no usable
    /// primary key is found.
    Random { limit: usize },

    /// Random primary-key probing for 100k–10M-row tables.
    ///
    /// Random ids are generated and joined to the table through the primary key column
    /// `pk`, returning at most `sample_size` rows. Despite the name this is not
    /// reservoir sampling. Ids that fall into gaps in the key match no row, so fewer
    /// rows than requested may come back.
    ReservoirPK { sample_size: usize, pk: String },

    /// PostgreSQL `TABLESAMPLE BERNOULLI(percentage)` for tables of 10M+ rows, capped at
    /// `limit` rows.
    ///
    /// Each row is kept independently with probability `percentage / 100`. This avoids
    /// a sort, but PostgreSQL still reads the whole table, and the number of rows
    /// returned varies between runs.
    TableSample { percentage: f32, limit: usize },
}
25
26impl SamplingStrategy {
27    /// Auto select the best sampling strat based on table size
28    ///
29    /// # Strat selection
30    /// - < 100k rows: Random sampling
31    /// - 100k - 10M rows: Resevoir sampling with PK
32    /// - 10M rows: TABLESAMPLE
33    pub async fn auto_select(
34        pool: &PgPool,
35        schema: &str,
36        table: &str,
37        estimated_rows: Option<i64>,
38        sample_size: usize,
39    ) -> Result<Self, sqlx::Error> {
40        let row_count = match estimated_rows {
41            Some(count) if count > 0 => count,
42            _ => crate::discovery::get_row_count(pool, schema, table).await?,
43        };
44
45        // If requesting all or more rows than exist, do a full deterministic scan
46        if sample_size >= row_count as usize {
47            return Ok(Self::Full);
48        }
49
50        Ok(match row_count {
51            n if n < 100_000 => Self::Random { limit: sample_size },
52            n if n < 10_000_000 => {
53                // try to find pk for Reservoir sampling
54                match find_primary_key(pool, schema, table).await {
55                    Ok(pk) => Self::ReservoirPK { sample_size, pk },
56                    Err(_) => {
57                        // Fallback to random pk
58                        Self::Random { limit: sample_size }
59                    }
60                }
61            }
62            _ => {
63                // for very large tables
64                // Cap percentage at 100.0 (PostgreSQL limit) and minimum 0.1
65                let pct = (sample_size as f32 / row_count as f32 * 100.0).clamp(0.1, 100.0);
66                Self::TableSample {
67                    percentage: pct,
68                    limit: sample_size,
69                }
70            }
71        })
72    }
73
74    /// Get the max number of samples that this strat should return
75    pub fn max_samples(&self) -> usize {
76        match self {
77            Self::Full => usize::MAX, // Full scan - unknown size
78            Self::Random { limit } => *limit,
79            Self::ReservoirPK { sample_size, .. } => *sample_size,
80            Self::TableSample { limit, .. } => *limit,
81        }
82    }
83
84    fn build_query(&self, schema: &str, table: &str, column: &str) -> String {
85        let schema_quoted = quote_identifier(schema);
86        let table_quoted = quote_identifier(table);
87        let column_quoted = quote_identifier(column);
88
89        match self {
90            Self::Full => {
91                // Full table scan - deterministic, no randomization
92                format!(
93                    "SELECT {} FROM {}.{} WHERE {} IS NOT NULL",
94                    column_quoted, schema_quoted, table_quoted, column_quoted
95                )
96            }
97            Self::Random { limit } => {
98                format!(
99                    "SELECT {} FROM {}.{} WHERE {} IS NOT NULL ORDER BY random() LIMIT {}",
100                    column_quoted, schema_quoted, table_quoted, column_quoted, limit
101                )
102            }
103            Self::ReservoirPK { sample_size, pk } => {
104                let pk_quoted = quote_identifier(pk);
105                // True reservoir sampling: generate random IDs and fetch via index
106                // This is MUCH faster than ORDER BY random() because it uses the PK index
107                format!(
108                    "WITH random_ids AS (
109                        SELECT floor(random() * (SELECT MAX({}) FROM {}.{}))::bigint AS rand_id
110                        FROM generate_series(1, {} * 2)
111                    )
112                    SELECT t.{}
113                    FROM {}.{} t
114                    INNER JOIN random_ids r ON t.{} = r.rand_id
115                    WHERE t.{} IS NOT NULL
116                    LIMIT {}",
117                    pk_quoted,
118                    schema_quoted,
119                    table_quoted,  // MAX(pk)
120                    sample_size,   // Generate 2x samples to account for PK gaps
121                    column_quoted, // SELECT column
122                    schema_quoted,
123                    table_quoted,  // FROM table
124                    pk_quoted,     // JOIN ON pk
125                    column_quoted, // WHERE column IS NOT NULL
126                    sample_size    // LIMIT
127                )
128            }
129            Self::TableSample { percentage, limit } => {
130                format!(
131                    "SELECT {} FROM {}.{} TABLESAMPLE BERNOULLI({}) WHERE {} IS NOT NULL LIMIT {}",
132                    column_quoted, schema_quoted, table_quoted, percentage, column_quoted, limit
133                )
134            }
135        }
136    }
137}
138
139pub struct Sampler {
140    strategy: SamplingStrategy,
141    show_progress: bool,
142}
143
144impl Sampler {
145    /// Create a new sampler with auto select strat
146    pub async fn new(
147        pool: &PgPool,
148        schema: &str,
149        table: &str,
150        estimated_rows: Option<i64>,
151        sample_size: usize,
152    ) -> Result<Self, sqlx::Error> {
153        let strategy =
154            SamplingStrategy::auto_select(pool, schema, table, estimated_rows, sample_size).await?;
155        Ok(Self {
156            strategy,
157            show_progress: true,
158        })
159    }
160
161    /// Create a sampler with a specific strat
162    pub fn with_strategy(strategy: SamplingStrategy) -> Self {
163        Self {
164            strategy,
165            show_progress: true,
166        }
167    }
168
169    /// Enable or disable prog bar
170    pub fn show_progress(mut self, enabled: bool) -> Self {
171        self.show_progress = enabled;
172        self
173    }
174
175    //// Execute the sampling strat and return jsonb valuies
176    ///
177    /// # Production safety
178    /// In prod mode:
179    /// - Max 1% sampling for large tables
180    /// - Requires explicit confirmation (future work)
181    /// - shows estimated query
182    pub async fn sample(
183        &self,
184        pool: &PgPool,
185        schema: &str,
186        table: &str,
187        column: &str,
188    ) -> Result<Vec<Value>, sqlx::Error> {
189        let query = self.strategy.build_query(schema, table, column);
190        let max_samples = self.strategy.max_samples();
191
192        // Create progress bar if enabled
193        let progress = if self.show_progress {
194            let pb = ProgressBar::new(max_samples as u64);
195            pb.set_style(
196                ProgressStyle::default_bar()
197                    .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} samples")
198                    .expect("Invalid progress bar template")
199                    .progress_chars("█▓▒░"),
200            );
201            Some(pb)
202        } else {
203            None
204        };
205
206        // Execute query and collect results
207        let mut samples = Vec::new();
208        let mut rows = sqlx::query_scalar::<_, Value>(&query).fetch(pool);
209
210        // Use sqlx's streaming to handle large result sets
211        while let Some(value) = rows.try_next().await? {
212            samples.push(value);
213
214            if let Some(ref pb) = progress {
215                pb.set_position(samples.len() as u64);
216            }
217        }
218
219        if let Some(pb) = progress {
220            pb.finish_with_message(format!("Collected {} samples", samples.len()));
221        }
222
223        Ok(samples)
224    }
225    /// Get information about the sampling strategy
226    pub fn strategy_info(&self) -> String {
227        match &self.strategy {
228            SamplingStrategy::Full => "Full table scan (all non-NULL rows)".to_string(),
229            SamplingStrategy::Random { limit } => {
230                format!("Random sampling (up to {} rows)", limit)
231            }
232            SamplingStrategy::ReservoirPK { sample_size, pk } => {
233                format!(
234                    "Reservoir sampling using PK '{}' (up to {} rows)",
235                    pk, sample_size
236                )
237            }
238            SamplingStrategy::TableSample { percentage, limit } => {
239                format!("TABLESAMPLE {:.2}% (up to {} rows)", percentage, limit)
240            }
241        }
242    }
243}
244
245async fn find_primary_key(pool: &PgPool, schema: &str, table: &str) -> Result<String, sqlx::Error> {
246    let pk: Option<String> = sqlx::query_scalar(
247        r#"
248          SELECT a.attname
249          FROM pg_index i
250          JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
251          JOIN pg_type t ON t.oid = a.atttypid
252          WHERE i.indrelid = ($1 || '.' || $2)::regclass
253            AND i.indisprimary
254            AND t.typcategory = 'N'
255          LIMIT 1
256          "#,
257    )
258    .bind(schema)
259    .bind(table)
260    .fetch_optional(pool)
261    .await?;
262
263    pk.ok_or_else(|| sqlx::Error::RowNotFound)
264}
265
/// Quote `identifier` as a PostgreSQL delimited identifier.
///
/// The result is wrapped in double quotes, and every embedded `"` is doubled, so
/// the text can be spliced into SQL as a single identifier.
fn quote_identifier(identifier: &str) -> String {
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            // A doubled quote is the escape for a literal quote inside "...".
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strategy_max_samples() {
        let random = SamplingStrategy::Random { limit: 5000 };
        assert_eq!(random.max_samples(), 5000);

        let reservoir = SamplingStrategy::ReservoirPK {
            sample_size: 10000,
            pk: "id".to_string(),
        };
        assert_eq!(reservoir.max_samples(), 10000);

        let tablesample = SamplingStrategy::TableSample {
            percentage: 1.0,
            limit: 15000,
        };
        assert_eq!(tablesample.max_samples(), 15000);
    }

    #[test]
    fn test_strategy_max_samples_full_is_unbounded() {
        assert_eq!(SamplingStrategy::Full.max_samples(), usize::MAX);
    }

    #[test]
    fn test_build_query_full() {
        let query = SamplingStrategy::Full.build_query("public", "users", "metadata");

        assert_eq!(
            query,
            "SELECT \"metadata\" FROM \"public\".\"users\" WHERE \"metadata\" IS NOT NULL"
        );
        assert!(!query.contains("random()"));
        assert!(!query.contains("LIMIT"));
    }

    #[test]
    fn test_build_query_random() {
        let strategy = SamplingStrategy::Random { limit: 1000 };
        let query = strategy.build_query("public", "users", "metadata");

        assert!(query.contains("ORDER BY random()"));
        assert!(query.contains("LIMIT 1000"));
        assert!(query.contains("IS NOT NULL"));
        assert!(query.contains("\"public\""));
        assert!(query.contains("\"users\""));
        assert!(query.contains("\"metadata\""));
    }

    #[test]
    fn test_build_query_reservoir() {
        let strategy = SamplingStrategy::ReservoirPK {
            sample_size: 5000,
            pk: "id".to_string(),
        };
        let query = strategy.build_query("public", "users", "metadata");

        assert!(query.contains("WITH random_ids"));
        assert!(query.contains("generate_series"));
        assert!(query.contains("INNER JOIN"));
        assert!(query.contains("LIMIT 5000"));
        assert!(query.contains("IS NOT NULL"));
    }

    #[test]
    fn test_build_query_reservoir_quotes_pk_and_identifiers() {
        let strategy = SamplingStrategy::ReservoirPK {
            sample_size: 5000,
            pk: "UserId".to_string(),
        };
        let query = strategy.build_query("My Schema", "Users", "Payload");

        assert!(query.contains("\"UserId\""));
        assert!(query.contains("\"My Schema\".\"Users\""));
        assert!(query.contains("\"Payload\""));
    }

    #[test]
    fn test_build_query_tablesample() {
        let strategy = SamplingStrategy::TableSample {
            percentage: 0.5,
            limit: 10000,
        };
        let query = strategy.build_query("public", "users", "metadata");

        assert!(query.contains("TABLESAMPLE BERNOULLI(0.5)"));
        assert!(query.contains("LIMIT 10000"));
        assert!(query.contains("IS NOT NULL"));
    }

    #[test]
    fn test_build_query_escapes_embedded_quotes() {
        let query = SamplingStrategy::Random { limit: 10 }.build_query(
            "public",
            "users",
            "data\"; DROP TABLE users; --",
        );

        assert!(query.contains("\"data\"\"; DROP TABLE users; --\""));
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("simple"), "\"simple\"");
        assert_eq!(quote_identifier("with\"quote"), "\"with\"\"quote\"");
        assert_eq!(quote_identifier("schema.table"), "\"schema.table\"");
    }

    #[test]
    fn test_quote_identifier_empty() {
        assert_eq!(quote_identifier(""), "\"\"");
    }

    #[test]
    fn test_quote_identifier_sql_injection() {
        // Ensure SQL injection attempts are properly escaped
        assert_eq!(
            quote_identifier("table\"; DROP TABLE users; --"),
            "\"table\"\"; DROP TABLE users; --\""
        );
    }

    #[test]
    fn test_sampler_builder() {
        let strategy = SamplingStrategy::Random { limit: 1000 };
        let sampler = Sampler::with_strategy(strategy.clone()).show_progress(false);

        assert_eq!(sampler.strategy, strategy);
        assert!(!sampler.show_progress);
    }

    #[test]
    fn test_sampler_default_settings() {
        let strategy = SamplingStrategy::Random { limit: 5000 };
        let sampler = Sampler::with_strategy(strategy);

        assert!(sampler.show_progress);
    }

    #[test]
    fn test_strategy_info_full() {
        let sampler = Sampler::with_strategy(SamplingStrategy::Full);
        assert_eq!(
            sampler.strategy_info(),
            "Full table scan (all non-NULL rows)"
        );
    }

    #[test]
    fn test_strategy_info_random() {
        let sampler = Sampler::with_strategy(SamplingStrategy::Random { limit: 5000 });
        assert_eq!(sampler.strategy_info(), "Random sampling (up to 5000 rows)");
    }

    #[test]
    fn test_strategy_info_reservoir() {
        let sampler = Sampler::with_strategy(SamplingStrategy::ReservoirPK {
            sample_size: 10000,
            pk: "user_id".to_string(),
        });
        assert_eq!(
            sampler.strategy_info(),
            "Reservoir sampling using PK 'user_id' (up to 10000 rows)"
        );
    }

    #[test]
    fn test_strategy_info_tablesample() {
        let sampler = Sampler::with_strategy(SamplingStrategy::TableSample {
            percentage: 2.5,
            limit: 20000,
        });
        assert_eq!(
            sampler.strategy_info(),
            "TABLESAMPLE 2.50% (up to 20000 rows)"
        );
    }

    #[test]
    fn test_strategy_equality() {
        let strat1 = SamplingStrategy::Random { limit: 1000 };
        let strat2 = SamplingStrategy::Random { limit: 1000 };
        let strat3 = SamplingStrategy::Random { limit: 2000 };

        assert_eq!(strat1, strat2);
        assert_ne!(strat1, strat3);
    }

    // The two percentage tests below repeat the TABLESAMPLE percentage formula used in
    // `SamplingStrategy::auto_select`; keep them in sync with it.

    #[test]
    fn test_tablesample_percentage_capped_at_100() {
        // Simulate requesting more samples than rows in table
        let row_count = 10_000_000_i64;
        let sample_size = 10_000_011_usize;

        let pct = (sample_size as f32 / row_count as f32 * 100.0).clamp(0.1, 100.0);

        assert!(
            pct <= 100.0,
            "Percentage must not exceed 100.0, got {}",
            pct
        );
        assert_eq!(
            pct, 100.0,
            "When sample_size > row_count, percentage should be capped at 100.0"
        );
    }

    #[test]
    fn test_tablesample_percentage_minimum() {
        // Very large table with small sample size should respect minimum
        let row_count = 1_000_000_000_i64;
        let sample_size = 100_usize;

        let pct = (sample_size as f32 / row_count as f32 * 100.0).clamp(0.1, 100.0);

        assert_eq!(pct, 0.1, "Minimum percentage should be 0.1");
    }
}