1use futures::TryStreamExt;
2use indicatif::{ProgressBar, ProgressStyle};
3use serde_json::Value;
4use sqlx::PgPool;
5
/// How rows are drawn from a table when collecting values for one column.
///
/// Every strategy's query skips rows whose target column is NULL.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingStrategy {
    /// Read every non-NULL value in the table; no row limit is applied.
    Full,

    /// `ORDER BY random()` over the whole table, keeping at most `limit` rows.
    Random {
        /// Maximum number of rows returned.
        limit: usize,
    },

    /// Generate random candidate values for the numeric primary-key column `pk`,
    /// join them back to the table, and keep at most `sample_size` rows.
    ReservoirPK {
        /// Maximum number of rows returned.
        sample_size: usize,
        /// Unquoted name of the primary-key column used for probing.
        pk: String,
    },

    /// `TABLESAMPLE BERNOULLI(percentage)` followed by `LIMIT limit`.
    TableSample {
        /// Sampling percentage passed to `BERNOULLI` (`auto_select` keeps it in 0.1..=100).
        percentage: f32,
        /// Maximum number of rows returned.
        limit: usize,
    },
}
25
26impl SamplingStrategy {
27 pub async fn auto_select(
34 pool: &PgPool,
35 schema: &str,
36 table: &str,
37 estimated_rows: Option<i64>,
38 sample_size: usize,
39 ) -> Result<Self, sqlx::Error> {
40 let row_count = match estimated_rows {
41 Some(count) if count > 0 => count,
42 _ => crate::discovery::get_row_count(pool, schema, table).await?,
43 };
44
45 if sample_size >= row_count as usize {
47 return Ok(Self::Full);
48 }
49
50 Ok(match row_count {
51 n if n < 100_000 => Self::Random { limit: sample_size },
52 n if n < 10_000_000 => {
53 match find_primary_key(pool, schema, table).await {
55 Ok(pk) => Self::ReservoirPK { sample_size, pk },
56 Err(_) => {
57 Self::Random { limit: sample_size }
59 }
60 }
61 }
62 _ => {
63 let pct = (sample_size as f32 / row_count as f32 * 100.0).clamp(0.1, 100.0);
66 Self::TableSample {
67 percentage: pct,
68 limit: sample_size,
69 }
70 }
71 })
72 }
73
74 pub fn max_samples(&self) -> usize {
76 match self {
77 Self::Full => usize::MAX, Self::Random { limit } => *limit,
79 Self::ReservoirPK { sample_size, .. } => *sample_size,
80 Self::TableSample { limit, .. } => *limit,
81 }
82 }
83
84 fn build_query(&self, schema: &str, table: &str, column: &str) -> String {
85 let schema_quoted = quote_identifier(schema);
86 let table_quoted = quote_identifier(table);
87 let column_quoted = quote_identifier(column);
88
89 match self {
90 Self::Full => {
91 format!(
93 "SELECT {} FROM {}.{} WHERE {} IS NOT NULL",
94 column_quoted, schema_quoted, table_quoted, column_quoted
95 )
96 }
97 Self::Random { limit } => {
98 format!(
99 "SELECT {} FROM {}.{} WHERE {} IS NOT NULL ORDER BY random() LIMIT {}",
100 column_quoted, schema_quoted, table_quoted, column_quoted, limit
101 )
102 }
103 Self::ReservoirPK { sample_size, pk } => {
104 let pk_quoted = quote_identifier(pk);
105 format!(
108 "WITH random_ids AS (
109 SELECT floor(random() * (SELECT MAX({}) FROM {}.{}))::bigint AS rand_id
110 FROM generate_series(1, {} * 2)
111 )
112 SELECT t.{}
113 FROM {}.{} t
114 INNER JOIN random_ids r ON t.{} = r.rand_id
115 WHERE t.{} IS NOT NULL
116 LIMIT {}",
117 pk_quoted,
118 schema_quoted,
119 table_quoted, sample_size, column_quoted, schema_quoted,
123 table_quoted, pk_quoted, column_quoted, sample_size )
128 }
129 Self::TableSample { percentage, limit } => {
130 format!(
131 "SELECT {} FROM {}.{} TABLESAMPLE BERNOULLI({}) WHERE {} IS NOT NULL LIMIT {}",
132 column_quoted, schema_quoted, table_quoted, percentage, column_quoted, limit
133 )
134 }
135 }
136 }
137}
138
139pub struct Sampler {
140 strategy: SamplingStrategy,
141 show_progress: bool,
142}
143
144impl Sampler {
145 pub async fn new(
147 pool: &PgPool,
148 schema: &str,
149 table: &str,
150 estimated_rows: Option<i64>,
151 sample_size: usize,
152 ) -> Result<Self, sqlx::Error> {
153 let strategy =
154 SamplingStrategy::auto_select(pool, schema, table, estimated_rows, sample_size).await?;
155 Ok(Self {
156 strategy,
157 show_progress: true,
158 })
159 }
160
161 pub fn with_strategy(strategy: SamplingStrategy) -> Self {
163 Self {
164 strategy,
165 show_progress: true,
166 }
167 }
168
169 pub fn show_progress(mut self, enabled: bool) -> Self {
171 self.show_progress = enabled;
172 self
173 }
174
175 pub async fn sample(
183 &self,
184 pool: &PgPool,
185 schema: &str,
186 table: &str,
187 column: &str,
188 ) -> Result<Vec<Value>, sqlx::Error> {
189 let query = self.strategy.build_query(schema, table, column);
190 let max_samples = self.strategy.max_samples();
191
192 let progress = if self.show_progress {
194 let pb = ProgressBar::new(max_samples as u64);
195 pb.set_style(
196 ProgressStyle::default_bar()
197 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} samples")
198 .expect("Invalid progress bar template")
199 .progress_chars("█▓▒░"),
200 );
201 Some(pb)
202 } else {
203 None
204 };
205
206 let mut samples = Vec::new();
208 let mut rows = sqlx::query_scalar::<_, Value>(&query).fetch(pool);
209
210 while let Some(value) = rows.try_next().await? {
212 samples.push(value);
213
214 if let Some(ref pb) = progress {
215 pb.set_position(samples.len() as u64);
216 }
217 }
218
219 if let Some(pb) = progress {
220 pb.finish_with_message(format!("Collected {} samples", samples.len()));
221 }
222
223 Ok(samples)
224 }
225 pub fn strategy_info(&self) -> String {
227 match &self.strategy {
228 SamplingStrategy::Full => "Full table scan (all non-NULL rows)".to_string(),
229 SamplingStrategy::Random { limit } => {
230 format!("Random sampling (up to {} rows)", limit)
231 }
232 SamplingStrategy::ReservoirPK { sample_size, pk } => {
233 format!(
234 "Reservoir sampling using PK '{}' (up to {} rows)",
235 pk, sample_size
236 )
237 }
238 SamplingStrategy::TableSample { percentage, limit } => {
239 format!("TABLESAMPLE {:.2}% (up to {} rows)", percentage, limit)
240 }
241 }
242 }
243}
244
245async fn find_primary_key(pool: &PgPool, schema: &str, table: &str) -> Result<String, sqlx::Error> {
246 let pk: Option<String> = sqlx::query_scalar(
247 r#"
248 SELECT a.attname
249 FROM pg_index i
250 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
251 JOIN pg_type t ON t.oid = a.atttypid
252 WHERE i.indrelid = ($1 || '.' || $2)::regclass
253 AND i.indisprimary
254 AND t.typcategory = 'N'
255 LIMIT 1
256 "#,
257 )
258 .bind(schema)
259 .bind(table)
260 .fetch_optional(pool)
261 .await?;
262
263 pk.ok_or_else(|| sqlx::Error::RowNotFound)
264}
265
/// Wraps `identifier` in double quotes for use as a PostgreSQL identifier,
/// doubling any embedded `"` so the value cannot terminate the quoting early.
fn quote_identifier(identifier: &str) -> String {
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            quoted.push_str("\"\"");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strategy_max_samples() {
        let random = SamplingStrategy::Random { limit: 5000 };
        assert_eq!(random.max_samples(), 5000);

        let reservoir = SamplingStrategy::ReservoirPK {
            sample_size: 10000,
            pk: "id".to_string(),
        };
        assert_eq!(reservoir.max_samples(), 10000);

        let tablesample = SamplingStrategy::TableSample {
            percentage: 1.0,
            limit: 15000,
        };
        assert_eq!(tablesample.max_samples(), 15000);

        // Full has no row cap.
        assert_eq!(SamplingStrategy::Full.max_samples(), usize::MAX);
    }

    #[test]
    fn test_build_query_full() {
        let query = SamplingStrategy::Full.build_query("public", "users", "metadata");
        assert_eq!(
            query,
            "SELECT \"metadata\" FROM \"public\".\"users\" WHERE \"metadata\" IS NOT NULL"
        );
    }

    #[test]
    fn test_build_query_random() {
        let strategy = SamplingStrategy::Random { limit: 1000 };
        let query = strategy.build_query("public", "users", "metadata");

        assert!(query.contains("ORDER BY random()"));
        assert!(query.contains("LIMIT 1000"));
        assert!(query.contains("IS NOT NULL"));
        assert!(query.contains("\"public\""));
        assert!(query.contains("\"users\""));
        assert!(query.contains("\"metadata\""));
    }

    #[test]
    fn test_build_query_reservoir() {
        let strategy = SamplingStrategy::ReservoirPK {
            sample_size: 5000,
            pk: "id".to_string(),
        };
        let query = strategy.build_query("public", "users", "metadata");

        assert!(query.contains("WITH random_ids"));
        assert!(query.contains("generate_series"));
        assert!(query.contains("INNER JOIN"));
        assert!(query.contains("LIMIT 5000"));
        assert!(query.contains("IS NOT NULL"));
    }

    #[test]
    fn test_build_query_reservoir_quotes_identifiers() {
        let strategy = SamplingStrategy::ReservoirPK {
            sample_size: 10,
            pk: "User\"Id".to_string(),
        };
        let query = strategy.build_query("My Schema", "users", "metadata");

        assert!(query.contains("\"User\"\"Id\""));
        assert!(query.contains("\"My Schema\".\"users\""));
    }

    #[test]
    fn test_build_query_tablesample() {
        let strategy = SamplingStrategy::TableSample {
            percentage: 0.5,
            limit: 10000,
        };
        let query = strategy.build_query("public", "users", "metadata");

        assert!(query.contains("TABLESAMPLE BERNOULLI(0.5)"));
        assert!(query.contains("LIMIT 10000"));
        assert!(query.contains("IS NOT NULL"));
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("simple"), "\"simple\"");
        assert_eq!(quote_identifier("with\"quote"), "\"with\"\"quote\"");
        assert_eq!(quote_identifier("schema.table"), "\"schema.table\"");
    }

    #[test]
    fn test_quote_identifier_sql_injection() {
        assert_eq!(
            quote_identifier("table\"; DROP TABLE users; --"),
            "\"table\"\"; DROP TABLE users; --\""
        );
    }

    #[test]
    fn test_sampler_builder() {
        let strategy = SamplingStrategy::Random { limit: 1000 };
        let sampler = Sampler::with_strategy(strategy.clone()).show_progress(false);

        assert_eq!(sampler.strategy, strategy);
        assert!(!sampler.show_progress);
    }

    #[test]
    fn test_sampler_default_settings() {
        let strategy = SamplingStrategy::Random { limit: 5000 };
        let sampler = Sampler::with_strategy(strategy);

        assert!(sampler.show_progress);
    }

    #[test]
    fn test_strategy_info_full() {
        let sampler = Sampler::with_strategy(SamplingStrategy::Full);
        assert_eq!(sampler.strategy_info(), "Full table scan (all non-NULL rows)");
    }

    #[test]
    fn test_strategy_info_random() {
        let sampler = Sampler::with_strategy(SamplingStrategy::Random { limit: 5000 });
        assert_eq!(sampler.strategy_info(), "Random sampling (up to 5000 rows)");
    }

    #[test]
    fn test_strategy_info_reservoir() {
        let sampler = Sampler::with_strategy(SamplingStrategy::ReservoirPK {
            sample_size: 10000,
            pk: "user_id".to_string(),
        });
        assert_eq!(
            sampler.strategy_info(),
            "Reservoir sampling using PK 'user_id' (up to 10000 rows)"
        );
    }

    #[test]
    fn test_strategy_info_tablesample() {
        let sampler = Sampler::with_strategy(SamplingStrategy::TableSample {
            percentage: 2.5,
            limit: 20000,
        });
        assert_eq!(
            sampler.strategy_info(),
            "TABLESAMPLE 2.50% (up to 20000 rows)"
        );
    }

    #[test]
    fn test_strategy_equality() {
        let strat1 = SamplingStrategy::Random { limit: 1000 };
        let strat2 = SamplingStrategy::Random { limit: 1000 };
        let strat3 = SamplingStrategy::Random { limit: 2000 };

        assert_eq!(strat1, strat2);
        assert_ne!(strat1, strat3);
    }

    // The two percentage tests below mirror the TABLESAMPLE formula in
    // `SamplingStrategy::auto_select` rather than calling it (that needs a database);
    // keep them in sync if the formula or its clamp bounds change.

    #[test]
    fn test_tablesample_percentage_capped_at_100() {
        let row_count = 10_000_000_i64;
        let sample_size = 10_000_011_usize;

        let pct = (sample_size as f32 / row_count as f32 * 100.0).clamp(0.1, 100.0);

        assert!(
            pct <= 100.0,
            "Percentage must not exceed 100.0, got {}",
            pct
        );
        assert_eq!(
            pct, 100.0,
            "When sample_size > row_count, percentage should be capped at 100.0"
        );
    }

    #[test]
    fn test_tablesample_percentage_minimum() {
        let row_count = 1_000_000_000_i64;
        let sample_size = 100_usize;

        let pct = (sample_size as f32 / row_count as f32 * 100.0).clamp(0.1, 100.0);

        assert_eq!(pct, 0.1, "Minimum percentage should be 0.1");
    }
}