1use futures::TryStreamExt;
2use indicatif::{ProgressBar, ProgressStyle};
3use serde_json::Value;
4use sqlx::PgPool;
5
/// The way rows are picked from a table when sampling one of its columns.
///
/// Each variant maps to one SQL shape; every shape skips rows where the
/// sampled column is NULL.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingStrategy {
    /// Read every non-NULL value of the column; no row cap is applied.
    Full,

    /// Order the rows with `random()` and keep at most `limit` of them.
    Random {
        /// Upper bound on the number of returned rows.
        limit: usize,
    },

    /// Draw random values of the primary-key column `pk` and join the table
    /// against them, keeping at most `sample_size` rows.
    ReservoirPK {
        /// Upper bound on the number of returned rows.
        sample_size: usize,
        /// Name of the primary-key column (unquoted).
        pk: String,
    },

    /// Use `TABLESAMPLE BERNOULLI(percentage)` and keep at most `limit` rows.
    TableSample {
        /// Percentage handed to `TABLESAMPLE BERNOULLI`.
        percentage: f32,
        /// Upper bound on the number of returned rows.
        limit: usize,
    },
}
25
26impl SamplingStrategy {
27 pub async fn auto_select(
34 pool: &PgPool,
35 schema: &str,
36 table: &str,
37 estimated_rows: Option<i64>,
38 sample_size: usize,
39 ) -> Result<Self, sqlx::Error> {
40 let row_count = match estimated_rows {
41 Some(count) if count > 0 => count,
42 _ => crate::discovery::get_row_count(pool, schema, table).await?,
43 };
44
45 if sample_size >= row_count as usize {
47 return Ok(Self::Full);
48 }
49
50 Ok(match row_count {
51 n if n < 100_000 => Self::Random { limit: sample_size },
52 n if n < 10_000_000 => {
53 match find_primary_key(pool, schema, table).await {
55 Ok(pk) => Self::ReservoirPK { sample_size, pk },
56 Err(_) => {
57 Self::Random { limit: sample_size }
59 }
60 }
61 }
62 _ => {
63 let pct = (sample_size as f32 / row_count as f32 * 100.0).clamp(0.1, 100.0);
66 Self::TableSample {
67 percentage: pct,
68 limit: sample_size,
69 }
70 }
71 })
72 }
73
74 pub fn max_samples(&self) -> usize {
76 match self {
77 Self::Full => usize::MAX, Self::Random { limit } => *limit,
79 Self::ReservoirPK { sample_size, .. } => *sample_size,
80 Self::TableSample { limit, .. } => *limit,
81 }
82 }
83
84 fn build_query(&self, schema: &str, table: &str, column: &str) -> String {
85 let schema_quoted = quote_identifier(schema);
86 let table_quoted = quote_identifier(table);
87 let column_quoted = quote_identifier(column);
88
89 match self {
90 Self::Full => {
91 format!(
93 "SELECT {} FROM {}.{} WHERE {} IS NOT NULL",
94 column_quoted, schema_quoted, table_quoted, column_quoted
95 )
96 }
97 Self::Random { limit } => {
98 format!(
99 "SELECT {} FROM {}.{} WHERE {} IS NOT NULL ORDER BY random() LIMIT {}",
100 column_quoted, schema_quoted, table_quoted, column_quoted, limit
101 )
102 }
103 Self::ReservoirPK { sample_size, pk } => {
104 let pk_quoted = quote_identifier(pk);
105 format!(
108 "WITH random_ids AS (
109 SELECT floor(random() * (SELECT MAX({}) FROM {}.{}))::bigint AS rand_id
110 FROM generate_series(1, {} * 2)
111 )
112 SELECT t.{}
113 FROM {}.{} t
114 INNER JOIN random_ids r ON t.{} = r.rand_id
115 WHERE t.{} IS NOT NULL
116 LIMIT {}",
117 pk_quoted,
118 schema_quoted,
119 table_quoted, sample_size, column_quoted, schema_quoted,
123 table_quoted, pk_quoted, column_quoted, sample_size )
128 }
129 Self::TableSample { percentage, limit } => {
130 format!(
131 "SELECT {} FROM {}.{} TABLESAMPLE BERNOULLI({}) WHERE {} IS NOT NULL LIMIT {}",
132 column_quoted, schema_quoted, table_quoted, percentage, column_quoted, limit
133 )
134 }
135 }
136 }
137}
138
139pub struct Sampler {
140 strategy: SamplingStrategy,
141 show_progress: bool,
142}
143
144impl Sampler {
145 pub async fn new(
147 pool: &PgPool,
148 schema: &str,
149 table: &str,
150 estimated_rows: Option<i64>,
151 sample_size: usize,
152 ) -> Result<Self, sqlx::Error> {
153 let strategy =
154 SamplingStrategy::auto_select(pool, schema, table, estimated_rows, sample_size).await?;
155 Ok(Self {
156 strategy,
157 show_progress: true,
158 })
159 }
160
161 pub fn with_strategy(strategy: SamplingStrategy) -> Self {
163 Self {
164 strategy,
165 show_progress: true,
166 }
167 }
168
169 pub fn show_progress(mut self, enabled: bool) -> Self {
171 self.show_progress = enabled;
172 self
173 }
174
175 pub async fn sample(
183 &self,
184 pool: &PgPool,
185 schema: &str,
186 table: &str,
187 column: &str,
188 ) -> Result<Vec<Value>, sqlx::Error> {
189 let query = self.strategy.build_query(schema, table, column);
190 let max_samples = self.strategy.max_samples();
191
192 let progress = if self.show_progress {
194 let pb = ProgressBar::new(max_samples as u64);
195 pb.set_style(
196 ProgressStyle::default_bar()
197 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} samples")
198 .expect("Invalid progress bar template")
199 .progress_chars("█▓▒░"),
200 );
201 Some(pb)
202 } else {
203 None
204 };
205
206 let mut samples = Vec::new();
208 let mut rows = sqlx::query_scalar::<_, Value>(&query).fetch(pool);
209
210 while let Some(value) = rows.try_next().await? {
212 samples.push(value);
213
214 if let Some(ref pb) = progress {
215 pb.set_position(samples.len() as u64);
216 }
217 }
218
219 if let Some(pb) = progress {
220 pb.finish_with_message(format!("Collected {} samples", samples.len()));
221 }
222
223 Ok(samples)
224 }
225 pub fn strategy_info(&self) -> String {
227 match &self.strategy {
228 SamplingStrategy::Full => "Full table scan (all non-NULL rows)".to_string(),
229 SamplingStrategy::Random { limit } => {
230 format!("Random sampling (up to {} rows)", limit)
231 }
232 SamplingStrategy::ReservoirPK { sample_size, pk } => {
233 format!(
234 "Reservoir sampling using PK '{}' (up to {} rows)",
235 pk, sample_size
236 )
237 }
238 SamplingStrategy::TableSample { percentage, limit } => {
239 format!("TABLESAMPLE {:.2}% (up to {} rows)", percentage, limit)
240 }
241 }
242 }
243}
244
245async fn find_primary_key(pool: &PgPool, schema: &str, table: &str) -> Result<String, sqlx::Error> {
246 let pk: Option<String> = sqlx::query_scalar(
247 r#"
248 SELECT a.attname
249 FROM pg_index i
250 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
251 WHERE i.indrelid = ($1 || '.' || $2)::regclass
252 AND i.indisprimary
253 LIMIT 1
254 "#,
255 )
256 .bind(schema)
257 .bind(table)
258 .fetch_optional(pool)
259 .await?;
260
261 pk.ok_or_else(|| sqlx::Error::RowNotFound)
262}
263
/// Wraps `identifier` in double quotes for use as a SQL identifier, doubling
/// every double quote it contains.
fn quote_identifier(identifier: &str) -> String {
    // Two extra bytes for the surrounding quotes; embedded quotes grow it further.
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
267
/// Unit tests for strategy bookkeeping, query text and identifier quoting.
/// None of them touch a database.
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant reports its own row cap; `Full` is unbounded.
    #[test]
    fn test_strategy_max_samples() {
        assert_eq!(SamplingStrategy::Full.max_samples(), usize::MAX);

        let random = SamplingStrategy::Random { limit: 5000 };
        assert_eq!(random.max_samples(), 5000);

        let reservoir = SamplingStrategy::ReservoirPK {
            sample_size: 10000,
            pk: "id".to_string(),
        };
        assert_eq!(reservoir.max_samples(), 10000);

        let tablesample = SamplingStrategy::TableSample {
            percentage: 1.0,
            limit: 15000,
        };
        assert_eq!(tablesample.max_samples(), 15000);
    }

    /// A full scan selects the column with a NULL filter and nothing else.
    #[test]
    fn test_build_query_full() {
        let query = SamplingStrategy::Full.build_query("public", "users", "metadata");

        assert_eq!(
            query,
            "SELECT \"metadata\" FROM \"public\".\"users\" WHERE \"metadata\" IS NOT NULL"
        );
    }

    #[test]
    fn test_build_query_random() {
        let strategy = SamplingStrategy::Random { limit: 1000 };
        let query = strategy.build_query("public", "users", "metadata");

        assert!(query.contains("ORDER BY random()"));
        assert!(query.contains("LIMIT 1000"));
        assert!(query.contains("IS NOT NULL"));
        assert!(query.contains("\"public\""));
        assert!(query.contains("\"users\""));
        assert!(query.contains("\"metadata\""));
    }

    #[test]
    fn test_build_query_reservoir() {
        let strategy = SamplingStrategy::ReservoirPK {
            sample_size: 5000,
            pk: "id".to_string(),
        };
        let query = strategy.build_query("public", "users", "metadata");

        assert!(query.contains("WITH random_ids"));
        assert!(query.contains("generate_series"));
        assert!(query.contains("INNER JOIN"));
        assert!(query.contains("LIMIT 5000"));
        assert!(query.contains("IS NOT NULL"));
    }

    #[test]
    fn test_build_query_tablesample() {
        let strategy = SamplingStrategy::TableSample {
            percentage: 0.5,
            limit: 10000,
        };
        let query = strategy.build_query("public", "users", "metadata");

        assert!(query.contains("TABLESAMPLE BERNOULLI(0.5)"));
        assert!(query.contains("LIMIT 10000"));
        assert!(query.contains("IS NOT NULL"));
    }

    /// Identifiers reach the query text only in their escaped form.
    #[test]
    fn test_build_query_escapes_identifiers() {
        let strategy = SamplingStrategy::Random { limit: 10 };
        let query = strategy.build_query("my\"schema", "us\"ers", "me\"ta");

        assert!(query.contains("\"my\"\"schema\""));
        assert!(query.contains("\"us\"\"ers\""));
        assert!(query.contains("\"me\"\"ta\""));
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("simple"), "\"simple\"");
        assert_eq!(quote_identifier("with\"quote"), "\"with\"\"quote\"");
        assert_eq!(quote_identifier("schema.table"), "\"schema.table\"");
    }

    #[test]
    fn test_quote_identifier_sql_injection() {
        assert_eq!(
            quote_identifier("table\"; DROP TABLE users; --"),
            "\"table\"\"; DROP TABLE users; --\""
        );
    }

    #[test]
    fn test_sampler_builder() {
        let strategy = SamplingStrategy::Random { limit: 1000 };
        let sampler = Sampler::with_strategy(strategy.clone()).show_progress(false);

        assert_eq!(sampler.strategy, strategy);
        assert!(!sampler.show_progress);
    }

    #[test]
    fn test_sampler_default_settings() {
        let strategy = SamplingStrategy::Random { limit: 5000 };
        let sampler = Sampler::with_strategy(strategy);

        assert!(sampler.show_progress);
    }

    #[test]
    fn test_strategy_info_full() {
        let sampler = Sampler::with_strategy(SamplingStrategy::Full);
        assert_eq!(
            sampler.strategy_info(),
            "Full table scan (all non-NULL rows)"
        );
    }

    #[test]
    fn test_strategy_info_random() {
        let sampler = Sampler::with_strategy(SamplingStrategy::Random { limit: 5000 });
        assert_eq!(sampler.strategy_info(), "Random sampling (up to 5000 rows)");
    }

    #[test]
    fn test_strategy_info_reservoir() {
        let sampler = Sampler::with_strategy(SamplingStrategy::ReservoirPK {
            sample_size: 10000,
            pk: "user_id".to_string(),
        });
        assert_eq!(
            sampler.strategy_info(),
            "Reservoir sampling using PK 'user_id' (up to 10000 rows)"
        );
    }

    #[test]
    fn test_strategy_info_tablesample() {
        let sampler = Sampler::with_strategy(SamplingStrategy::TableSample {
            percentage: 2.5,
            limit: 20000,
        });
        assert_eq!(
            sampler.strategy_info(),
            "TABLESAMPLE 2.50% (up to 20000 rows)"
        );
    }

    #[test]
    fn test_strategy_equality() {
        let strat1 = SamplingStrategy::Random { limit: 1000 };
        let strat2 = SamplingStrategy::Random { limit: 1000 };
        let strat3 = SamplingStrategy::Random { limit: 2000 };

        assert_eq!(strat1, strat2);
        assert_ne!(strat1, strat3);
        assert_ne!(strat1, SamplingStrategy::Full);
    }

    // The two tests below repeat the percentage expression used by
    // `auto_select`, because that function needs a database pool to run.

    #[test]
    fn test_tablesample_percentage_capped_at_100() {
        let row_count = 10_000_000_i64;
        let sample_size = 10_000_011_usize;

        let pct = (sample_size as f32 / row_count as f32 * 100.0).clamp(0.1, 100.0);

        assert!(
            pct <= 100.0,
            "Percentage must not exceed 100.0, got {}",
            pct
        );
        assert_eq!(
            pct, 100.0,
            "When sample_size > row_count, percentage should be capped at 100.0"
        );
    }

    #[test]
    fn test_tablesample_percentage_minimum() {
        let row_count = 1_000_000_000_i64;
        let sample_size = 100_usize;

        let pct = (sample_size as f32 / row_count as f32 * 100.0).clamp(0.1, 100.0);

        assert_eq!(pct, 0.1, "Minimum percentage should be 0.1");
    }
}