sql_splitter/analyzer/
mod.rs

1use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
2use crate::splitter::Compression;
3use ahash::AHashMap;
4use std::fs::File;
5use std::io::Read;
6use std::path::PathBuf;
7
/// Per-table counters gathered while scanning a SQL dump.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableStats {
    /// Name of the table these counters belong to.
    pub table_name: String,
    /// Number of row-loading statements attributed to the table.
    pub insert_count: u64,
    /// Number of table-creation statements attributed to the table.
    pub create_count: u64,
    /// Sum of the byte lengths of every statement attributed to the table.
    pub total_bytes: u64,
    /// Total number of statements attributed to the table, of any kind.
    pub statement_count: u64,
}

impl TableStats {
    /// Creates an entry for `table_name` with every counter set to zero.
    fn new(table_name: String) -> Self {
        Self {
            table_name,
            insert_count: 0,
            create_count: 0,
            total_bytes: 0,
            statement_count: 0,
        }
    }
}
28
29pub struct Analyzer {
30    input_file: PathBuf,
31    dialect: SqlDialect,
32    stats: AHashMap<String, TableStats>,
33}
34
35impl Analyzer {
36    pub fn new(input_file: PathBuf) -> Self {
37        Self {
38            input_file,
39            dialect: SqlDialect::default(),
40            stats: AHashMap::new(),
41        }
42    }
43
44    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
45        self.dialect = dialect;
46        self
47    }
48
49    pub fn analyze(mut self) -> anyhow::Result<Vec<TableStats>> {
50        let file = File::open(&self.input_file)?;
51        let file_size = file.metadata()?.len();
52        let buffer_size = determine_buffer_size(file_size);
53        let dialect = self.dialect;
54
55        // Detect and apply decompression
56        let compression = Compression::from_path(&self.input_file);
57        let reader: Box<dyn Read> = compression.wrap_reader(Box::new(file));
58
59        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
60
61        while let Some(stmt) = parser.read_statement()? {
62            let (stmt_type, table_name) =
63                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
64
65            if stmt_type == StatementType::Unknown || table_name.is_empty() {
66                continue;
67            }
68
69            self.update_stats(&table_name, stmt_type, stmt.len() as u64);
70        }
71
72        Ok(self.get_sorted_stats())
73    }
74
75    pub fn analyze_with_progress<F: Fn(u64) + 'static>(
76        mut self,
77        progress_fn: F,
78    ) -> anyhow::Result<Vec<TableStats>> {
79        let file = File::open(&self.input_file)?;
80        let file_size = file.metadata()?.len();
81        let buffer_size = determine_buffer_size(file_size);
82        let dialect = self.dialect;
83
84        // Detect and apply decompression
85        let compression = Compression::from_path(&self.input_file);
86        let progress_reader = ProgressReader::new(file, progress_fn);
87        let reader: Box<dyn Read> = compression.wrap_reader(Box::new(progress_reader));
88
89        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
90
91        while let Some(stmt) = parser.read_statement()? {
92            let (stmt_type, table_name) =
93                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
94
95            if stmt_type == StatementType::Unknown || table_name.is_empty() {
96                continue;
97            }
98
99            self.update_stats(&table_name, stmt_type, stmt.len() as u64);
100        }
101
102        Ok(self.get_sorted_stats())
103    }
104
105    fn update_stats(&mut self, table_name: &str, stmt_type: StatementType, bytes: u64) {
106        let stats = self
107            .stats
108            .entry(table_name.to_string())
109            .or_insert_with(|| TableStats::new(table_name.to_string()));
110
111        stats.statement_count += 1;
112        stats.total_bytes += bytes;
113
114        match stmt_type {
115            StatementType::CreateTable => stats.create_count += 1,
116            StatementType::Insert | StatementType::Copy => stats.insert_count += 1,
117            _ => {}
118        }
119    }
120
121    fn get_sorted_stats(&self) -> Vec<TableStats> {
122        let mut result: Vec<TableStats> = self.stats.values().cloned().collect();
123        result.sort_by(|a, b| b.insert_count.cmp(&a.insert_count));
124        result
125    }
126}
127
/// A [`Read`] adapter that reports how many bytes have been read so far.
///
/// Every call to `read` is forwarded to the wrapped reader; afterwards the
/// callback receives the cumulative byte count.
struct ProgressReader<R, F> {
    /// The wrapped reader that actually supplies the data.
    reader: R,
    /// Invoked after each successful read with the running total.
    callback: F,
    /// Total number of bytes returned by `reader` so far.
    bytes_read: u64,
}

impl<R: Read, F: FnMut(u64)> ProgressReader<R, F> {
    /// Wraps `reader`, starting the byte counter at zero.
    ///
    /// The callback only needs to be `FnMut`, so it may update captured
    /// state; plain `Fn` closures are accepted as before.
    fn new(reader: R, callback: F) -> Self {
        Self {
            reader,
            callback,
            bytes_read: 0,
        }
    }
}

impl<R: Read, F: FnMut(u64)> Read for ProgressReader<R, F> {
    /// Reads from the wrapped reader and reports the updated total.
    ///
    /// The callback also fires for a zero-length read (e.g. end of input),
    /// repeating the final total. It is not called when the inner read fails.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let n = self.reader.read(buf)?;
        self.bytes_read += n as u64;
        (self.callback)(self.bytes_read);
        Ok(n)
    }
}
152
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `sql` to `input.sql` in a fresh temporary directory, analyzes
    /// it with the default dialect, and returns the resulting stats.
    fn analyze_sql(sql: &[u8]) -> Vec<TableStats> {
        let scratch = TempDir::new().unwrap();
        let sql_path = scratch.path().join("input.sql");
        std::fs::write(&sql_path, sql).unwrap();
        Analyzer::new(sql_path).analyze().unwrap()
    }

    /// Returns the entry for `table`, panicking if the table is missing.
    fn stats_for<'a>(all: &'a [TableStats], table: &str) -> &'a TableStats {
        all.iter().find(|entry| entry.table_name == table).unwrap()
    }

    #[test]
    fn test_analyzer_basic() {
        let stats = analyze_sql(
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nINSERT INTO users VALUES (2);\nCREATE TABLE posts (id INT);\nINSERT INTO posts VALUES (1);",
        );

        assert_eq!(stats.len(), 2);

        // (insert_count, create_count, statement_count) per table.
        let users = stats_for(&stats, "users");
        assert_eq!(
            (users.insert_count, users.create_count, users.statement_count),
            (2, 1, 3)
        );

        let posts = stats_for(&stats, "posts");
        assert_eq!(
            (posts.insert_count, posts.create_count, posts.statement_count),
            (1, 1, 2)
        );
    }

    #[test]
    fn test_analyzer_sorted_by_insert_count() {
        let stats = analyze_sql(
            b"CREATE TABLE a (id INT);\nINSERT INTO a VALUES (1);\nCREATE TABLE b (id INT);\nINSERT INTO b VALUES (1);\nINSERT INTO b VALUES (2);\nINSERT INTO b VALUES (3);",
        );

        // The table with the most inserts must come first.
        let ranked: Vec<(&str, u64)> = stats
            .iter()
            .take(2)
            .map(|entry| (entry.table_name.as_str(), entry.insert_count))
            .collect();
        assert_eq!(ranked, [("b", 3), ("a", 1)]);
    }
}