Skip to main content

heliosdb_proxy/pool/
prepared.rs

1//! Prepared Statement Tracker
2//!
3//! Tracks prepared statements for recreation when switching connections
4//! in transaction or statement pooling modes.
5
6use std::collections::HashMap;
7
8/// Prepared statement information
9#[derive(Debug, Clone)]
10pub struct PreparedStatement {
11    /// Statement name
12    pub name: String,
13    /// SQL query
14    pub query: String,
15    /// Parameter types (OIDs)
16    pub param_types: Vec<u32>,
17    /// When the statement was prepared
18    pub prepared_at: chrono::DateTime<chrono::Utc>,
19    /// Number of times executed
20    pub execution_count: u64,
21}
22
23/// Tracker for prepared statements
24///
25/// Maintains a registry of prepared statements so they can be
26/// recreated on new backend connections.
27#[derive(Debug, Default)]
28pub struct PreparedStatementTracker {
29    /// Statements by name
30    statements: HashMap<String, PreparedStatement>,
31    /// Maximum statements to track (to prevent memory bloat)
32    max_statements: usize,
33    /// Total statements prepared
34    total_prepared: u64,
35    /// Total statements deallocated
36    total_deallocated: u64,
37}
38
39impl PreparedStatementTracker {
40    /// Create a new tracker with default capacity
41    pub fn new() -> Self {
42        Self::with_capacity(1000)
43    }
44
45    /// Create a new tracker with specified capacity
46    pub fn with_capacity(max_statements: usize) -> Self {
47        Self {
48            statements: HashMap::with_capacity(max_statements.min(100)),
49            max_statements,
50            total_prepared: 0,
51            total_deallocated: 0,
52        }
53    }
54
55    /// Register a prepared statement
56    ///
57    /// # Arguments
58    /// * `name` - Statement name (empty for unnamed)
59    /// * `query` - The SQL query
60    /// * `param_types` - Parameter type OIDs
61    pub fn register(&mut self, name: String, query: String, param_types: Vec<u32>) {
62        // Don't track unnamed statements
63        if name.is_empty() {
64            return;
65        }
66
67        // Check capacity
68        if self.statements.len() >= self.max_statements {
69            // Remove least recently used (oldest)
70            if let Some(oldest) = self
71                .statements
72                .iter()
73                .min_by_key(|(_, s)| s.prepared_at)
74                .map(|(k, _)| k.clone())
75            {
76                self.statements.remove(&oldest);
77                self.total_deallocated += 1;
78            }
79        }
80
81        self.statements.insert(
82            name.clone(),
83            PreparedStatement {
84                name,
85                query,
86                param_types,
87                prepared_at: chrono::Utc::now(),
88                execution_count: 0,
89            },
90        );
91
92        self.total_prepared += 1;
93    }
94
95    /// Remove a prepared statement
96    pub fn unregister(&mut self, name: &str) -> Option<PreparedStatement> {
97        let stmt = self.statements.remove(name);
98        if stmt.is_some() {
99            self.total_deallocated += 1;
100        }
101        stmt
102    }
103
104    /// Clear all statements (DEALLOCATE ALL)
105    pub fn clear(&mut self) {
106        self.total_deallocated += self.statements.len() as u64;
107        self.statements.clear();
108    }
109
110    /// Get a prepared statement by name
111    pub fn get(&self, name: &str) -> Option<&PreparedStatement> {
112        self.statements.get(name)
113    }
114
115    /// Record an execution of a statement
116    pub fn record_execution(&mut self, name: &str) {
117        if let Some(stmt) = self.statements.get_mut(name) {
118            stmt.execution_count += 1;
119        }
120    }
121
122    /// Check if a statement exists
123    pub fn contains(&self, name: &str) -> bool {
124        self.statements.contains_key(name)
125    }
126
127    /// Get all statements (for recreation on new connection)
128    pub fn all_statements(&self) -> impl Iterator<Item = &PreparedStatement> {
129        self.statements.values()
130    }
131
132    /// Get statement count
133    pub fn len(&self) -> usize {
134        self.statements.len()
135    }
136
137    /// Check if empty
138    pub fn is_empty(&self) -> bool {
139        self.statements.is_empty()
140    }
141
142    /// Generate PREPARE statements for all tracked statements
143    ///
144    /// Returns SQL to recreate all statements on a new connection.
145    pub fn generate_prepare_sql(&self) -> Vec<String> {
146        self.statements
147            .values()
148            .map(|stmt| {
149                if stmt.param_types.is_empty() {
150                    format!("PREPARE {} AS {}", stmt.name, stmt.query)
151                } else {
152                    let types: Vec<String> = stmt
153                        .param_types
154                        .iter()
155                        .map(|t| oid_to_type_name(*t))
156                        .collect();
157                    format!(
158                        "PREPARE {} ({}) AS {}",
159                        stmt.name,
160                        types.join(", "),
161                        stmt.query
162                    )
163                }
164            })
165            .collect()
166    }
167
168    /// Get statistics
169    pub fn stats(&self) -> TrackerStats {
170        TrackerStats {
171            active_statements: self.statements.len(),
172            total_prepared: self.total_prepared,
173            total_deallocated: self.total_deallocated,
174            max_capacity: self.max_statements,
175        }
176    }
177}
178
/// Tracker statistics
///
/// A plain snapshot of the tracker's counters. Two snapshots compare equal
/// when all four fields match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackerStats {
    /// Currently tracked statements
    pub active_statements: usize,
    /// Total statements ever prepared
    pub total_prepared: u64,
    /// Total statements deallocated
    pub total_deallocated: u64,
    /// Maximum capacity
    pub max_capacity: usize,
}
191
/// Convert PostgreSQL OID to a type name usable in a `PREPARE` parameter list
///
/// This is a simplified mapping for common built-in types. Any OID that is
/// not in the table maps to `unknown`, which PostgreSQL accepts as a parameter
/// type and resolves from the context where the parameter is first used; this
/// keeps the generated `PREPARE` syntactically valid.
fn oid_to_type_name(oid: u32) -> String {
    let type_name = match oid {
        16 => "boolean",
        17 => "bytea",
        // OID 18 is the internal single-byte type, which must be written with
        // quotes; an unquoted `char` means `character(1)` instead.
        18 => "\"char\"",
        19 => "name",
        20 => "bigint",
        21 => "smallint",
        23 => "integer",
        25 => "text",
        26 => "oid",
        700 => "real",
        701 => "double precision",
        790 => "money",
        1042 => "char",
        1043 => "varchar",
        1082 => "date",
        1083 => "time",
        1114 => "timestamp",
        1184 => "timestamptz",
        1186 => "interval",
        1700 => "numeric",
        2950 => "uuid",
        3802 => "jsonb",
        // Unmapped (or unspecified, OID 0) types are left for the server to infer.
        _ => "unknown",
    };
    type_name.to_string()
}
222
/// Parse PREPARE statement to extract components
///
/// Accepts `PREPARE name [(type, ...)] AS query`. Keywords are matched
/// case-insensitively and may be followed by any whitespace (spaces, tabs,
/// newlines). Parameter types may themselves contain parentheses, e.g.
/// `numeric(10,2)`.
///
/// Returns `(name, param_types, query)` if successful, or `None` when the
/// input is not a `PREPARE ... AS ...` statement, has no statement name, or
/// has an unterminated parameter list. The name and query are returned as
/// written; a trailing semicolon on the query is not removed.
pub fn parse_prepare_statement(sql: &str) -> Option<(String, Vec<String>, String)> {
    /// Strips a leading ASCII `keyword` (case-insensitive) together with the
    /// whitespace that must follow it, returning the remainder.
    fn strip_keyword<'a>(input: &'a str, keyword: &str) -> Option<&'a str> {
        // `get` yields None if the input is too short or the cut would fall
        // inside a multi-byte character, so this never panics.
        let head = input.get(..keyword.len())?;
        if !head.eq_ignore_ascii_case(keyword) {
            return None;
        }
        let tail = &input[keyword.len()..];
        if tail.starts_with(char::is_whitespace) {
            Some(tail.trim_start())
        } else {
            None
        }
    }

    let rest = strip_keyword(sql.trim(), "PREPARE")?;

    // The name runs until whitespace or the opening paren of the type list.
    let name_end = rest
        .find(|c: char| c.is_whitespace() || c == '(')
        .unwrap_or(rest.len());
    if name_end == 0 {
        // e.g. `PREPARE (int) AS ...` - there is no statement name.
        return None;
    }
    let name = rest[..name_end].to_string();
    let rest = rest[name_end..].trim_start();

    // Optional parenthesised parameter type list.
    let (param_types, rest) = match rest.strip_prefix('(') {
        Some(inner) => {
            let mut pieces: Vec<&str> = Vec::new();
            let mut depth = 0usize; // nesting inside the list, e.g. numeric(10,2)
            let mut piece_start = 0usize;
            let mut close = None;
            for (idx, ch) in inner.char_indices() {
                match ch {
                    '(' => depth += 1,
                    ')' if depth == 0 => {
                        close = Some(idx);
                        break;
                    }
                    ')' => depth -= 1,
                    // Only commas at the top level separate parameter types.
                    ',' if depth == 0 => {
                        pieces.push(&inner[piece_start..idx]);
                        piece_start = idx + 1;
                    }
                    _ => {}
                }
            }
            // An unterminated list cannot be followed by a valid `AS` clause.
            let close = close?;
            pieces.push(&inner[piece_start..close]);

            let parsed: Vec<String> = pieces
                .into_iter()
                .map(str::trim)
                .filter(|piece| !piece.is_empty())
                .map(str::to_string)
                .collect();
            (parsed, inner[close + 1..].trim_start())
        }
        None => (Vec::new(), rest),
    };

    let query = strip_keyword(rest, "AS")?.to_string();

    Some((name, param_types, query))
}
272
/// Parse DEALLOCATE statement
///
/// Accepts `DEALLOCATE [PREPARE] { name | ALL }`, with keywords matched
/// case-insensitively and separated by any whitespace. Trailing semicolons
/// are ignored.
///
/// Returns:
/// * `None` - the input is not a DEALLOCATE statement, or names nothing
/// * `Some(None)` - `DEALLOCATE ALL` / `DEALLOCATE PREPARE ALL`
/// * `Some(Some(name))` - a single statement is being deallocated
pub fn parse_deallocate_statement(sql: &str) -> Option<Option<String>> {
    /// Strips a leading ASCII `keyword` (case-insensitive) together with the
    /// whitespace that must follow it, returning the remainder.
    fn strip_keyword<'a>(input: &'a str, keyword: &str) -> Option<&'a str> {
        // `get` yields None if the input is too short or the cut would fall
        // inside a multi-byte character, so this never panics.
        let head = input.get(..keyword.len())?;
        if !head.eq_ignore_ascii_case(keyword) {
            return None;
        }
        let tail = &input[keyword.len()..];
        if tail.starts_with(char::is_whitespace) {
            Some(tail.trim_start())
        } else {
            None
        }
    }

    let rest = strip_keyword(sql.trim(), "DEALLOCATE")?;

    // The PREPARE noise word is optional and is dropped before looking at the
    // target, so `DEALLOCATE PREPARE ALL` is recognised as deallocate-all.
    let target = strip_keyword(rest, "PREPARE").unwrap_or(rest);

    // `ALL` counts only as a whole word: alone, or followed by whitespace or `;`.
    let is_all = match target.get(..3) {
        Some(head) if head.eq_ignore_ascii_case("ALL") => {
            let tail = &target[3..];
            tail.is_empty() || tail.starts_with(|c: char| c.is_whitespace() || c == ';')
        }
        _ => false,
    };
    if is_all {
        return Some(None);
    }

    // Remove trailing semicolon if present
    let name = target.trim_end_matches(';').trim();
    if name.is_empty() {
        // e.g. `DEALLOCATE ;` - nothing to deallocate.
        return None;
    }
    Some(Some(name.to_string()))
}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers a statement that takes no parameters.
    fn register_plain(tracker: &mut PreparedStatementTracker, name: &str, query: &str) {
        tracker.register(name.to_string(), query.to_string(), Vec::new());
    }

    #[test]
    fn test_register_and_get() {
        let sql = "SELECT * FROM users WHERE id = $1";
        let mut tracker = PreparedStatementTracker::new();
        tracker.register("stmt1".to_string(), sql.to_string(), vec![23]);

        assert!(tracker.contains("stmt1"));

        let stored = tracker.get("stmt1").unwrap();
        assert_eq!(stored.query, sql);
        assert_eq!(stored.param_types, vec![23]);
    }

    #[test]
    fn test_unregister() {
        let mut tracker = PreparedStatementTracker::new();
        register_plain(&mut tracker, "stmt1", "SELECT 1");
        assert!(tracker.contains("stmt1"));

        tracker.unregister("stmt1");
        assert!(!tracker.contains("stmt1"));
    }

    #[test]
    fn test_clear() {
        let mut tracker = PreparedStatementTracker::new();
        register_plain(&mut tracker, "stmt1", "SELECT 1");
        register_plain(&mut tracker, "stmt2", "SELECT 2");
        assert_eq!(tracker.len(), 2);

        tracker.clear();
        assert!(tracker.is_empty());
    }

    #[test]
    fn test_capacity_limit() {
        let mut tracker = PreparedStatementTracker::with_capacity(3);
        for (name, query) in [
            ("stmt1", "SELECT 1"),
            ("stmt2", "SELECT 2"),
            ("stmt3", "SELECT 3"),
        ] {
            register_plain(&mut tracker, name, query);
        }

        // One more than the limit: an older entry has to make room.
        register_plain(&mut tracker, "stmt4", "SELECT 4");

        assert_eq!(tracker.len(), 3);
        assert!(tracker.contains("stmt4"));
    }

    #[test]
    fn test_generate_prepare_sql() {
        let mut tracker = PreparedStatementTracker::new();
        tracker.register(
            "get_user".to_string(),
            "SELECT * FROM users WHERE id = $1".to_string(),
            vec![23],
        );

        let generated = tracker.generate_prepare_sql();
        assert_eq!(generated.len(), 1);

        let first = &generated[0];
        assert!(first.contains("PREPARE get_user"));
        assert!(first.contains("integer"));
    }

    #[test]
    fn test_parse_prepare_statement() {
        // Without a parameter list.
        let parsed = parse_prepare_statement("PREPARE stmt1 AS SELECT 1");
        assert!(parsed.is_some());
        let (name, params, query) = parsed.unwrap();
        assert_eq!(name, "stmt1");
        assert!(params.is_empty());
        assert_eq!(query, "SELECT 1");

        // With a parameter list.
        let parsed = parse_prepare_statement(
            "PREPARE stmt2 (integer, text) AS SELECT * FROM t WHERE id = $1 AND name = $2",
        );
        assert!(parsed.is_some());
        let (name, params, query) = parsed.unwrap();
        assert_eq!(name, "stmt2");
        assert_eq!(params, vec!["integer", "text"]);
        assert!(query.starts_with("SELECT"));
    }

    #[test]
    fn test_parse_deallocate_statement() {
        let all = parse_deallocate_statement("DEALLOCATE ALL");
        assert_eq!(all, Some(None));

        let plain = parse_deallocate_statement("DEALLOCATE stmt1");
        assert_eq!(plain, Some(Some("stmt1".to_string())));

        let with_keyword = parse_deallocate_statement("DEALLOCATE PREPARE stmt2");
        assert_eq!(with_keyword, Some(Some("stmt2".to_string())));

        let unrelated = parse_deallocate_statement("SELECT 1");
        assert_eq!(unrelated, None);
    }

    #[test]
    fn test_execution_tracking() {
        let mut tracker = PreparedStatementTracker::new();
        register_plain(&mut tracker, "stmt1", "SELECT 1");

        for _ in 0..2 {
            tracker.record_execution("stmt1");
        }

        let stored = tracker.get("stmt1").unwrap();
        assert_eq!(stored.execution_count, 2);
    }

    #[test]
    fn test_unnamed_statements_ignored() {
        let mut tracker = PreparedStatementTracker::new();
        register_plain(&mut tracker, "", "SELECT 1");

        assert!(tracker.is_empty());
    }
}