sql_splitter/analyzer/
mod.rs1use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
2use crate::splitter::Compression;
3use ahash::AHashMap;
4use std::fs::File;
5use std::io::Read;
6use std::path::PathBuf;
7
/// Per-table statistics accumulated while scanning a SQL dump.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TableStats {
    /// Name of the table these statistics belong to.
    pub table_name: String,
    /// Number of `INSERT` / `COPY` statements attributed to the table.
    pub insert_count: u64,
    /// Number of `CREATE TABLE` statements attributed to the table.
    pub create_count: u64,
    /// Total size in bytes of all statements attributed to the table.
    pub total_bytes: u64,
    /// Total number of statements attributed to the table (any type).
    pub statement_count: u64,
}

impl TableStats {
    /// Creates a record for `table_name` with every counter at zero.
    fn new(table_name: String) -> Self {
        // Struct-update from the derived `Default` instead of spelling out
        // each zeroed counter by hand.
        Self {
            table_name,
            ..Self::default()
        }
    }
}
28
/// Streams a SQL dump file and aggregates per-table statistics.
///
/// Construct with [`Analyzer::new`], optionally set a dialect with
/// `with_dialect`, then consume with `analyze` or `analyze_with_progress`.
pub struct Analyzer {
    // Path to the SQL dump; compression is detected from this path.
    input_file: PathBuf,
    // Dialect used when parsing statements; starts at `SqlDialect::default()`.
    dialect: SqlDialect,
    // Accumulated statistics keyed by table name.
    stats: AHashMap<String, TableStats>,
}
34
35impl Analyzer {
36 pub fn new(input_file: PathBuf) -> Self {
37 Self {
38 input_file,
39 dialect: SqlDialect::default(),
40 stats: AHashMap::new(),
41 }
42 }
43
44 pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
45 self.dialect = dialect;
46 self
47 }
48
49 pub fn analyze(mut self) -> anyhow::Result<Vec<TableStats>> {
50 let file = File::open(&self.input_file)?;
51 let file_size = file.metadata()?.len();
52 let buffer_size = determine_buffer_size(file_size);
53 let dialect = self.dialect;
54
55 let compression = Compression::from_path(&self.input_file);
57 let reader: Box<dyn Read> = compression.wrap_reader(Box::new(file));
58
59 let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
60
61 while let Some(stmt) = parser.read_statement()? {
62 let (stmt_type, table_name) =
63 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
64
65 if stmt_type == StatementType::Unknown || table_name.is_empty() {
66 continue;
67 }
68
69 self.update_stats(&table_name, stmt_type, stmt.len() as u64);
70 }
71
72 Ok(self.get_sorted_stats())
73 }
74
75 pub fn analyze_with_progress<F: Fn(u64) + 'static>(
76 mut self,
77 progress_fn: F,
78 ) -> anyhow::Result<Vec<TableStats>> {
79 let file = File::open(&self.input_file)?;
80 let file_size = file.metadata()?.len();
81 let buffer_size = determine_buffer_size(file_size);
82 let dialect = self.dialect;
83
84 let compression = Compression::from_path(&self.input_file);
86 let progress_reader = ProgressReader::new(file, progress_fn);
87 let reader: Box<dyn Read> = compression.wrap_reader(Box::new(progress_reader));
88
89 let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
90
91 while let Some(stmt) = parser.read_statement()? {
92 let (stmt_type, table_name) =
93 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
94
95 if stmt_type == StatementType::Unknown || table_name.is_empty() {
96 continue;
97 }
98
99 self.update_stats(&table_name, stmt_type, stmt.len() as u64);
100 }
101
102 Ok(self.get_sorted_stats())
103 }
104
105 fn update_stats(&mut self, table_name: &str, stmt_type: StatementType, bytes: u64) {
106 let stats = self
107 .stats
108 .entry(table_name.to_string())
109 .or_insert_with(|| TableStats::new(table_name.to_string()));
110
111 stats.statement_count += 1;
112 stats.total_bytes += bytes;
113
114 match stmt_type {
115 StatementType::CreateTable => stats.create_count += 1,
116 StatementType::Insert | StatementType::Copy => stats.insert_count += 1,
117 _ => {}
118 }
119 }
120
121 fn get_sorted_stats(&self) -> Vec<TableStats> {
122 let mut result: Vec<TableStats> = self.stats.values().cloned().collect();
123 result.sort_by(|a, b| b.insert_count.cmp(&a.insert_count));
124 result
125 }
126}
127
/// A [`Read`] adapter that counts the cumulative bytes read from the inner
/// reader and reports the running total through a callback after every read.
///
/// The callback receives the total bytes read since construction, not the
/// size of the individual read.
//
// Trait bounds live on the `Read` impl, not the struct definition, per the
// usual Rust guideline (bounds on the type force them onto every mention).
struct ProgressReader<R, F> {
    reader: R,
    callback: F,
    bytes_read: u64,
}

impl<R, F> ProgressReader<R, F> {
    /// Wraps `reader`; `callback` will be invoked with the cumulative count.
    fn new(reader: R, callback: F) -> Self {
        Self {
            reader,
            callback,
            bytes_read: 0,
        }
    }
}

impl<R: Read, F: Fn(u64)> Read for ProgressReader<R, F> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let n = self.reader.read(buf)?;
        self.bytes_read += n as u64;
        // Report the running total (fires even on the final 0-byte EOF read).
        (self.callback)(self.bytes_read);
        Ok(n)
    }
}
151}