1use std::cell::RefCell;
2use std::collections::HashMap;
3
4#[derive(Debug, Clone)]
5pub struct StatsCache {
6 table_stats: HashMap<String, TableStats>,
7 column_stats: HashMap<(String, String), ColumnStats>,
8 table_hits: RefCell<HashMap<String, usize>>,
9 column_hits: RefCell<HashMap<(String, String), usize>>,
10 max_table_entries: usize,
11 max_column_entries: usize,
12}
13
14impl StatsCache {
15 pub fn new() -> Self {
16 Self {
17 table_stats: HashMap::new(),
18 column_stats: HashMap::new(),
19 table_hits: RefCell::new(HashMap::new()),
20 column_hits: RefCell::new(HashMap::new()),
21 max_table_entries: 128,
22 max_column_entries: 1024,
23 }
24 }
25
26 pub fn new_with_capacity(max_table_entries: usize, max_column_entries: usize) -> Self {
27 Self {
28 table_stats: HashMap::new(),
29 column_stats: HashMap::new(),
30 table_hits: RefCell::new(HashMap::new()),
31 column_hits: RefCell::new(HashMap::new()),
32 max_table_entries,
33 max_column_entries,
34 }
35 }
36
37 pub fn set_capacity(&mut self, max_table_entries: usize, max_column_entries: usize) {
38 self.max_table_entries = max_table_entries;
39 self.max_column_entries = max_column_entries;
40 self.evict_tables_if_needed();
41 self.evict_columns_if_needed();
42 }
43
44 pub fn is_empty(&self) -> bool {
45 self.table_stats.is_empty() && self.column_stats.is_empty()
46 }
47
48 pub fn insert_table_stats(&mut self, table: impl Into<String>, stats: TableStats) {
49 let table = table.into();
50 self.table_stats.insert(table.clone(), stats);
51 self.table_hits.borrow_mut().entry(table).or_insert(0);
52 self.evict_tables_if_needed();
53 }
54
55 pub fn table_stats(&self, table: &str) -> Option<&TableStats> {
56 if self.table_stats.contains_key(table) {
57 let mut hits = self.table_hits.borrow_mut();
58 let entry = hits.entry(table.to_string()).or_insert(0);
59 *entry += 1;
60 }
61 self.table_stats.get(table)
62 }
63
64 pub fn insert_column_stats(
65 &mut self,
66 table: impl Into<String>,
67 column: impl Into<String>,
68 stats: ColumnStats,
69 ) {
70 let key = (table.into(), column.into());
71 self.column_stats.insert(key.clone(), stats);
72 self.column_hits.borrow_mut().entry(key).or_insert(0);
73 self.evict_columns_if_needed();
74 }
75
76 pub fn column_stats(&self, table: &str, column: &str) -> Option<&ColumnStats> {
77 let key = (table.to_string(), column.to_string());
78 if self.column_stats.contains_key(&key) {
79 let mut hits = self.column_hits.borrow_mut();
80 let entry = hits.entry(key.clone()).or_insert(0);
81 *entry += 1;
82 }
83 self.column_stats.get(&key)
84 }
85
86 fn evict_tables_if_needed(&mut self) {
87 while self.table_stats.len() > self.max_table_entries {
88 let key = {
89 let hits = self.table_hits.borrow();
90 hits.iter()
91 .min_by_key(|(_, hits)| *hits)
92 .map(|(key, _)| key.clone())
93 };
94 if let Some(key) = key {
95 self.table_stats.remove(&key);
96 self.table_hits.borrow_mut().remove(&key);
97 } else {
98 break;
99 }
100 }
101 }
102
103 fn evict_columns_if_needed(&mut self) {
104 while self.column_stats.len() > self.max_column_entries {
105 let key = {
106 let hits = self.column_hits.borrow();
107 hits.iter()
108 .min_by_key(|(_, hits)| *hits)
109 .map(|(key, _)| key.clone())
110 };
111 if let Some(key) = key {
112 self.column_stats.remove(&key);
113 self.column_hits.borrow_mut().remove(&key);
114 } else {
115 break;
116 }
117 }
118 }
119}
120
121impl Default for StatsCache {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
/// Table-level statistics stored in a `StatsCache` and serialized in a `StatsSnapshot`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TableStats {
    /// Number of rows in the table. Stored as `f64`; presumably an estimate rather than an
    /// exact count — TODO confirm with the producers of these statistics.
    pub row_count: f64,
}
131
/// Column-level statistics stored in a `StatsCache` and serialized in a `StatsSnapshot`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ColumnStats {
    /// Number of distinct values in the column (stored as `f64`; presumably an estimate).
    pub distinct_count: f64,
    /// Fraction of the column's values that are NULL.
    /// NOTE(review): presumably in `0.0..=1.0`; nothing here enforces that — confirm.
    pub null_fraction: f64,
}
137
/// Serializable copy of the statistics held by a `StatsCache`.
///
/// Only the statistics themselves are captured; hit counts and capacities are not part of
/// the snapshot. `StatsSnapshot::from_cache` fills both vectors sorted by key.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct StatsSnapshot {
    /// `(table name, statistics)` pairs.
    pub tables: Vec<(String, TableStats)>,
    /// `(table name, column name, statistics)` triples.
    pub columns: Vec<(String, String, ColumnStats)>,
}
143
144impl StatsSnapshot {
145 pub fn to_cache(&self) -> StatsCache {
146 let mut cache = StatsCache::new_with_capacity(self.tables.len(), self.columns.len());
147 for (table, stats) in &self.tables {
148 cache.insert_table_stats(table.clone(), stats.clone());
149 }
150 for (table, column, stats) in &self.columns {
151 cache.insert_column_stats(table.clone(), column.clone(), stats.clone());
152 }
153 cache
154 }
155
156 pub fn from_cache(cache: &StatsCache) -> Self {
157 let mut tables = cache
158 .table_stats
159 .iter()
160 .map(|(name, stats)| (name.clone(), stats.clone()))
161 .collect::<Vec<_>>();
162 let mut columns = cache
163 .column_stats
164 .iter()
165 .map(|((table, column), stats)| (table.clone(), column.clone(), stats.clone()))
166 .collect::<Vec<_>>();
167 tables.sort_by(|(a, _), (b, _)| a.cmp(b));
168 columns.sort_by(|(ta, ca, _), (tb, cb, _)| (ta, ca).cmp(&(tb, cb)));
169 Self { tables, columns }
170 }
171
172 pub fn load_json(path: impl AsRef<std::path::Path>) -> chryso_core::error::ChrysoResult<Self> {
173 let content = std::fs::read_to_string(path.as_ref()).map_err(|err| {
174 chryso_core::error::ChrysoError::new(format!("read stats snapshot failed: {err}"))
175 })?;
176 let snapshot = serde_json::from_str(&content).map_err(|err| {
177 chryso_core::error::ChrysoError::new(format!("parse stats snapshot failed: {err}"))
178 })?;
179 Ok(snapshot)
180 }
181
182 pub fn write_json(
183 &self,
184 path: impl AsRef<std::path::Path>,
185 ) -> chryso_core::error::ChrysoResult<()> {
186 let content = serde_json::to_string_pretty(self).map_err(|err| {
187 chryso_core::error::ChrysoError::new(format!("serialize stats snapshot failed: {err}"))
188 })?;
189 std::fs::write(path.as_ref(), format!("{content}\n")).map_err(|err| {
190 chryso_core::error::ChrysoError::new(format!("write stats snapshot failed: {err}"))
191 })?;
192 Ok(())
193 }
194}
195
/// A source of statistics that can populate a `StatsCache`.
pub trait StatsProvider {
    /// Loads statistics for the requested `tables` and `(table, column)` pairs into `cache`.
    ///
    /// NOTE(review): the contract is not visible here. Confirm against implementors whether
    /// every requested entry must be filled, or whether unknown tables/columns may be
    /// skipped.
    ///
    /// # Errors
    /// Returns an error when the statistics cannot be loaded.
    fn load_stats(
        &self,
        tables: &[String],
        columns: &[(String, String)],
        cache: &mut StatsCache,
    ) -> chryso_core::ChrysoResult<()>;
}
204
205pub mod analyze;
206pub mod catalog;
207pub mod functions;
208pub mod type_coercion;
209pub mod type_inference;
210pub mod types;
211
212#[cfg(test)]
213mod catalog_tests;
214
#[cfg(test)]
mod tests {
    use super::{ColumnStats, StatsCache, StatsSnapshot, TableStats};

    /// Builds column statistics with no NULLs, which is the only shape these tests need.
    fn non_null_column(distinct_count: f64) -> ColumnStats {
        ColumnStats {
            distinct_count,
            null_fraction: 0.0,
        }
    }

    #[test]
    fn stats_cache_roundtrip() {
        let mut cache = StatsCache::new();
        cache.insert_table_stats("users", TableStats { row_count: 42.0 });
        cache.insert_column_stats("users", "id", non_null_column(40.0));

        let rows = cache.table_stats("users").map(|stats| stats.row_count);
        let distinct = cache
            .column_stats("users", "id")
            .map(|stats| stats.distinct_count);
        assert_eq!(rows, Some(42.0));
        assert_eq!(distinct, Some(40.0));
    }

    #[test]
    fn stats_cache_lfu_eviction() {
        let mut cache = StatsCache::new_with_capacity(1, 1);
        cache.insert_table_stats("t1", TableStats { row_count: 1.0 });
        cache.insert_table_stats("t2", TableStats { row_count: 2.0 });
        let surviving_tables = ["t1", "t2"]
            .iter()
            .copied()
            .filter(|name| cache.table_stats(name).is_some())
            .count();
        assert!(surviving_tables <= 1);

        cache.insert_column_stats("t1", "c1", non_null_column(1.0));
        cache.insert_column_stats("t1", "c2", non_null_column(2.0));
        let surviving_columns = ["c1", "c2"]
            .iter()
            .copied()
            .filter(|column| cache.column_stats("t1", column).is_some())
            .count();
        assert_eq!(surviving_columns, 1);
    }

    #[test]
    fn stats_snapshot_is_sorted() {
        let mut cache = StatsCache::new();
        for (name, rows) in [("b", 2.0), ("a", 1.0)] {
            cache.insert_table_stats(name, TableStats { row_count: rows });
        }
        for (table, column, distinct) in [("b", "y", 2.0), ("a", "z", 3.0), ("a", "b", 4.0)] {
            cache.insert_column_stats(table, column, non_null_column(distinct));
        }

        let snapshot = StatsSnapshot::from_cache(&cache);
        let table_names: Vec<&str> = snapshot
            .tables
            .iter()
            .map(|(name, _)| name.as_str())
            .collect();
        let column_keys: Vec<(&str, &str)> = snapshot
            .columns
            .iter()
            .map(|(table, column, _)| (table.as_str(), column.as_str()))
            .collect();

        assert_eq!(table_names, ["a", "b"]);
        assert_eq!(column_keys, [("a", "b"), ("a", "z"), ("b", "y")]);
    }
}