// rivet/state/shape.rs

//! Data shape drift tracking (Epic 8).
//!
//! Tracks the maximum observed byte length per string/binary column across runs.
//! On each run the current-run maxima are compared against the stored maxima;
//! columns that grew beyond `warn_factor × stored_max` are returned as warnings.
//! Columns with no stored history, or with a stored maximum of 0, never warn.
//! Stored maxima are always updated to `max(stored, current)` — shape drift
//! tracking is advisory and never blocks a run.

9use std::collections::HashMap;
10
11use crate::error::Result;
12
13use super::{StateConn, StateStore, pg_sql};
14
/// One column whose observed max byte length grew beyond the configured threshold.
///
/// Produced by [`StateStore::detect_shape_drift`]; purely advisory.
#[derive(Debug, Clone, PartialEq)]
pub struct ShapeWarning {
    /// Column name, as keyed in the run's per-column statistics.
    pub column: String,
    /// Running maximum byte length recorded before this run.
    pub stored_max_bytes: u64,
    /// Maximum byte length observed in this run.
    pub current_max_bytes: u64,
    /// `current_max_bytes / stored_max_bytes` — always > `warn_factor`.
    pub growth_factor: f64,
}

24impl StateStore {
25    /// Return the stored per-column max byte lengths for `export_name`.
26    pub fn get_shape_stats(&self, export_name: &str) -> Result<HashMap<String, u64>> {
27        let sql = "SELECT column_name, max_byte_len FROM export_shape WHERE export_name = ?1";
28        match &self.conn {
29            StateConn::Sqlite(c) => {
30                let mut stmt = c.prepare(sql)?;
31                let rows = stmt.query_map([export_name], |row| {
32                    Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)? as u64))
33                })?;
34                let mut map = HashMap::new();
35                for r in rows {
36                    let (k, v) = r?;
37                    map.insert(k, v);
38                }
39                Ok(map)
40            }
41            StateConn::Postgres(client) => {
42                let mut c = client.borrow_mut();
43                let rows = c.query(&pg_sql(sql), &[&export_name])?;
44                let mut map = HashMap::new();
45                for row in rows {
46                    let k: String = row.get(0);
47                    let v: i64 = row.get(1);
48                    map.insert(k, v as u64);
49                }
50                Ok(map)
51            }
52        }
53    }
54
55    /// Upsert per-column max byte lengths, keeping the running maximum.
56    pub fn store_shape_stats(&self, export_name: &str, stats: &HashMap<String, u64>) -> Result<()> {
57        let now = chrono::Utc::now().to_rfc3339();
58        let sql = "INSERT INTO export_shape (export_name, column_name, max_byte_len, updated_at)
59                 VALUES (?1, ?2, ?3, ?4)
60                 ON CONFLICT(export_name, column_name) DO UPDATE SET
61                     max_byte_len = MAX(max_byte_len, excluded.max_byte_len),
62                     updated_at   = excluded.updated_at";
63        match &self.conn {
64            StateConn::Sqlite(c) => {
65                for (col, &max_bytes) in stats {
66                    c.execute(
67                        sql,
68                        rusqlite::params![export_name, col, max_bytes as i64, now],
69                    )?;
70                }
71            }
72            StateConn::Postgres(client) => {
73                let mut c = client.borrow_mut();
74                for (col, &max_bytes) in stats {
75                    c.execute(
76                        &pg_sql(sql),
77                        &[&export_name, col, &(max_bytes as i64), &now],
78                    )?;
79                }
80            }
81        }
82        Ok(())
83    }
84
85    /// Compare `current` run's per-column maxima against stored history.
86    ///
87    /// Returns a warning for every column whose `current_max > stored_max * warn_factor`.
88    /// The stored maxima are updated to `max(stored, current)` unconditionally so that
89    /// the running high-water mark is always current.
90    ///
91    /// First-run columns (no stored record) are silently accepted.
92    pub fn detect_shape_drift(
93        &self,
94        export_name: &str,
95        current: &HashMap<String, u64>,
96        warn_factor: f64,
97    ) -> Result<Vec<ShapeWarning>> {
98        let stored = self.get_shape_stats(export_name)?;
99        let mut warnings = Vec::new();
100
101        for (col, &current_max) in current {
102            if let Some(&stored_max) = stored.get(col)
103                && stored_max > 0
104                && (current_max as f64) > stored_max as f64 * warn_factor
105            {
106                warnings.push(ShapeWarning {
107                    column: col.clone(),
108                    stored_max_bytes: stored_max,
109                    current_max_bytes: current_max,
110                    growth_factor: current_max as f64 / stored_max as f64,
111                });
112            }
113        }
114
115        self.store_shape_stats(export_name, current)?;
116        Ok(warnings)
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    fn store() -> StateStore {
        StateStore::open_in_memory().expect("in-memory store")
    }

    /// Build a per-column stats map from `(column, max_bytes)` pairs.
    fn cols(pairs: &[(&str, u64)]) -> HashMap<String, u64> {
        pairs.iter().map(|&(c, n)| (c.to_string(), n)).collect()
    }

    #[test]
    fn first_run_no_warnings() {
        let s = store();
        let stats = cols(&[("notes", 512), ("description", 1024)]);
        let warnings = s.detect_shape_drift("orders", &stats, 2.0).unwrap();
        assert!(warnings.is_empty(), "first run must not warn");
    }

    #[test]
    fn growth_below_threshold_no_warning() {
        let s = store();
        s.detect_shape_drift("t", &cols(&[("body", 1000)]), 2.0).unwrap();

        let warnings = s.detect_shape_drift("t", &cols(&[("body", 1800)]), 2.0).unwrap();
        assert!(warnings.is_empty());
    }

    #[test]
    fn growth_above_threshold_warns() {
        let s = store();
        s.detect_shape_drift("t", &cols(&[("body", 1000)]), 2.0).unwrap();

        let warnings = s.detect_shape_drift("t", &cols(&[("body", 2500)]), 2.0).unwrap();
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].column, "body");
        assert_eq!(warnings[0].stored_max_bytes, 1000);
        assert_eq!(warnings[0].current_max_bytes, 2500);
        assert!((warnings[0].growth_factor - 2.5).abs() < 0.01);
    }

    #[test]
    fn high_water_mark_advances_after_warning() {
        let s = store();
        s.detect_shape_drift("t", &cols(&[("text", 100)]), 2.0).unwrap();
        s.detect_shape_drift("t", &cols(&[("text", 300)]), 2.0).unwrap();

        let warnings = s.detect_shape_drift("t", &cols(&[("text", 450)]), 2.0).unwrap();
        assert!(
            warnings.is_empty(),
            "must not re-warn after high-water mark advanced"
        );
    }

    #[test]
    fn new_column_in_later_run_no_warning() {
        let s = store();
        s.detect_shape_drift("t", &cols(&[("id_str", 36)]), 2.0).unwrap();

        let v2 = cols(&[("id_str", 36), ("new_col", 9999)]);
        let warnings = s.detect_shape_drift("t", &v2, 2.0).unwrap();
        assert!(
            warnings.is_empty(),
            "new columns with no history must not warn"
        );
    }

    #[test]
    fn zero_stored_max_never_warns() {
        let s = store();
        s.detect_shape_drift("t", &cols(&[("memo", 0)]), 2.0).unwrap();

        let warnings = s.detect_shape_drift("t", &cols(&[("memo", 500)]), 2.0).unwrap();
        assert!(warnings.is_empty(), "growth from a stored max of 0 must not warn");

        // The mark advanced to 500, so later growth is measured against it.
        let warnings = s.detect_shape_drift("t", &cols(&[("memo", 1500)]), 2.0).unwrap();
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].stored_max_bytes, 500);
    }

    #[test]
    fn stored_max_never_decreases() {
        let s = store();
        s.detect_shape_drift("t", &cols(&[("body", 1000)]), 2.0).unwrap();
        s.detect_shape_drift("t", &cols(&[("body", 10)]), 2.0).unwrap();

        assert_eq!(s.get_shape_stats("t").unwrap(), cols(&[("body", 1000)]));

        let warnings = s.detect_shape_drift("t", &cols(&[("body", 2500)]), 2.0).unwrap();
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].stored_max_bytes, 1000);
    }

    #[test]
    fn exports_are_tracked_independently() {
        let s = store();
        s.detect_shape_drift("a", &cols(&[("body", 10)]), 2.0).unwrap();
        s.detect_shape_drift("b", &cols(&[("body", 1000)]), 2.0).unwrap();

        let warnings = s.detect_shape_drift("a", &cols(&[("body", 100)]), 2.0).unwrap();
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].stored_max_bytes, 10);
        assert_eq!(s.get_shape_stats("b").unwrap(), cols(&[("body", 1000)]));
    }

    #[test]
    fn unknown_export_has_no_stats() {
        let s = store();
        assert!(s.get_shape_stats("missing").unwrap().is_empty());
    }
}