1use std::collections::HashMap;
7
8#[derive(Debug, Clone)]
10pub struct PreparedStatement {
11 pub name: String,
13 pub query: String,
15 pub param_types: Vec<u32>,
17 pub prepared_at: chrono::DateTime<chrono::Utc>,
19 pub execution_count: u64,
21}
22
23#[derive(Debug, Default)]
28pub struct PreparedStatementTracker {
29 statements: HashMap<String, PreparedStatement>,
31 max_statements: usize,
33 total_prepared: u64,
35 total_deallocated: u64,
37}
38
39impl PreparedStatementTracker {
40 pub fn new() -> Self {
42 Self::with_capacity(1000)
43 }
44
45 pub fn with_capacity(max_statements: usize) -> Self {
47 Self {
48 statements: HashMap::with_capacity(max_statements.min(100)),
49 max_statements,
50 total_prepared: 0,
51 total_deallocated: 0,
52 }
53 }
54
55 pub fn register(&mut self, name: String, query: String, param_types: Vec<u32>) {
62 if name.is_empty() {
64 return;
65 }
66
67 if self.statements.len() >= self.max_statements {
69 if let Some(oldest) = self
71 .statements
72 .iter()
73 .min_by_key(|(_, s)| s.prepared_at)
74 .map(|(k, _)| k.clone())
75 {
76 self.statements.remove(&oldest);
77 self.total_deallocated += 1;
78 }
79 }
80
81 self.statements.insert(
82 name.clone(),
83 PreparedStatement {
84 name,
85 query,
86 param_types,
87 prepared_at: chrono::Utc::now(),
88 execution_count: 0,
89 },
90 );
91
92 self.total_prepared += 1;
93 }
94
95 pub fn unregister(&mut self, name: &str) -> Option<PreparedStatement> {
97 let stmt = self.statements.remove(name);
98 if stmt.is_some() {
99 self.total_deallocated += 1;
100 }
101 stmt
102 }
103
104 pub fn clear(&mut self) {
106 self.total_deallocated += self.statements.len() as u64;
107 self.statements.clear();
108 }
109
110 pub fn get(&self, name: &str) -> Option<&PreparedStatement> {
112 self.statements.get(name)
113 }
114
115 pub fn record_execution(&mut self, name: &str) {
117 if let Some(stmt) = self.statements.get_mut(name) {
118 stmt.execution_count += 1;
119 }
120 }
121
122 pub fn contains(&self, name: &str) -> bool {
124 self.statements.contains_key(name)
125 }
126
127 pub fn all_statements(&self) -> impl Iterator<Item = &PreparedStatement> {
129 self.statements.values()
130 }
131
132 pub fn len(&self) -> usize {
134 self.statements.len()
135 }
136
137 pub fn is_empty(&self) -> bool {
139 self.statements.is_empty()
140 }
141
142 pub fn generate_prepare_sql(&self) -> Vec<String> {
146 self.statements
147 .values()
148 .map(|stmt| {
149 if stmt.param_types.is_empty() {
150 format!("PREPARE {} AS {}", stmt.name, stmt.query)
151 } else {
152 let types: Vec<String> = stmt
153 .param_types
154 .iter()
155 .map(|t| oid_to_type_name(*t))
156 .collect();
157 format!(
158 "PREPARE {} ({}) AS {}",
159 stmt.name,
160 types.join(", "),
161 stmt.query
162 )
163 }
164 })
165 .collect()
166 }
167
168 pub fn stats(&self) -> TrackerStats {
170 TrackerStats {
171 active_statements: self.statements.len(),
172 total_prepared: self.total_prepared,
173 total_deallocated: self.total_deallocated,
174 max_capacity: self.max_statements,
175 }
176 }
177}
178
/// Point-in-time snapshot of a [`PreparedStatementTracker`]'s counters.
///
/// Every field is a plain integer, so the snapshot is `Copy` and can be
/// compared directly, for example in assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TrackerStats {
    /// Number of statements tracked when the snapshot was taken.
    pub active_statements: usize,
    /// Cumulative number of statements stored by `register`.
    pub total_prepared: u64,
    /// Cumulative number of statements dropped from the tracker.
    pub total_deallocated: u64,
    /// Maximum number of statements the tracker keeps before evicting.
    pub max_capacity: usize,
}
191
/// Maps a PostgreSQL type OID to a type name usable in a `PREPARE` parameter list.
///
/// OIDs outside the table, including `0` (which the wire protocol uses for
/// "unspecified"), map to `unknown`. PostgreSQL accepts `unknown` as a
/// `PREPARE` parameter type and infers the real type from context, so the
/// generated command stays valid. The previous `unknown(<oid>)` form was a
/// syntax error.
fn oid_to_type_name(oid: u32) -> String {
    let name = match oid {
        16 => "boolean",
        17 => "bytea",
        // OID 18 is the single-byte internal type, which must be written with
        // quotes; a bare `char` means `character(1)` (bpchar, OID 1042).
        18 => "\"char\"",
        19 => "name",
        20 => "bigint",
        21 => "smallint",
        23 => "integer",
        25 => "text",
        26 => "oid",
        114 => "json",
        650 => "cidr",
        700 => "real",
        701 => "double precision",
        790 => "money",
        869 => "inet",
        1000 => "boolean[]",
        1005 => "smallint[]",
        1007 => "integer[]",
        1009 => "text[]",
        1016 => "bigint[]",
        // Blank-padded character type; `bpchar` avoids the implicit length 1 of `char`.
        1042 => "bpchar",
        1043 => "varchar",
        1082 => "date",
        1083 => "time",
        1114 => "timestamp",
        1184 => "timestamptz",
        1186 => "interval",
        1266 => "timetz",
        1700 => "numeric",
        2950 => "uuid",
        3802 => "jsonb",
        _ => "unknown",
    };
    name.to_string()
}
222
/// Parses a SQL-level `PREPARE name [ ( type [, ...] ) ] AS statement`.
///
/// Returns `(name, parameter_type_names, query)`, or `None` if `sql` is not a
/// well-formed `PREPARE` (missing name, unclosed parameter list, missing `AS`
/// or empty body).
///
/// Keywords are matched ASCII-case-insensitively and may be followed by any
/// whitespace, including newlines. Parameter types are split on top-level
/// commas only, so `numeric(10,2)` stays a single entry; double-quoted
/// segments are skipped over. The name and types are returned as written
/// (no case folding or unquoting).
pub fn parse_prepare_statement(sql: &str) -> Option<(String, Vec<String>, String)> {
    /// Strips a leading `keyword` (ASCII case-insensitive) that must be followed
    /// by whitespace, returning the text after the keyword.
    fn strip_keyword<'a>(s: &'a str, keyword: &str) -> Option<&'a str> {
        let head = s.get(..keyword.len())?;
        let tail = &s[keyword.len()..];
        if head.eq_ignore_ascii_case(keyword) && tail.starts_with(char::is_whitespace) {
            Some(tail)
        } else {
            None
        }
    }

    /// Pushes `piece` trimmed, unless it is empty.
    fn push_type(types: &mut Vec<String>, piece: &str) {
        let piece = piece.trim();
        if !piece.is_empty() {
            types.push(piece.to_string());
        }
    }

    /// Splits the text following an opening `(` into type names, returning them
    /// together with the text after the matching `)`, or `None` if unclosed.
    fn split_param_list(s: &str) -> Option<(Vec<String>, &str)> {
        let mut types = Vec::new();
        let mut depth = 0usize;
        let mut quoted = false;
        let mut start = 0;
        for (i, c) in s.char_indices() {
            match c {
                '"' => quoted = !quoted,
                _ if quoted => {}
                '(' => depth += 1,
                ')' if depth > 0 => depth -= 1,
                ')' => {
                    push_type(&mut types, &s[start..i]);
                    return Some((types, &s[i + 1..]));
                }
                ',' if depth == 0 => {
                    push_type(&mut types, &s[start..i]);
                    start = i + 1;
                }
                _ => {}
            }
        }
        None
    }

    let rest = strip_keyword(sql.trim(), "PREPARE")?.trim_start();

    let name_end = rest
        .find(|c: char| c.is_whitespace() || c == '(')
        .unwrap_or(rest.len());
    let (name, rest) = rest.split_at(name_end);
    if name.is_empty() {
        return None;
    }
    let rest = rest.trim_start();

    let (param_types, rest) = match rest.strip_prefix('(') {
        Some(list) => {
            let (types, after) = split_param_list(list)?;
            (types, after.trim_start())
        }
        None => (Vec::new(), rest),
    };

    // `sql` was trimmed, so whitespace after `AS` guarantees a non-empty body.
    let query = strip_keyword(rest, "AS")?.trim_start();

    Some((name.to_string(), param_types, query.to_string()))
}
272
/// Parses `DEALLOCATE [ PREPARE ] { name | ALL }`.
///
/// Returns:
/// * `None` if `sql` is not a `DEALLOCATE` command or names nothing,
/// * `Some(None)` for `DEALLOCATE [PREPARE] ALL`,
/// * `Some(Some(name))` for a single statement.
///
/// Keywords are matched ASCII-case-insensitively and may be followed by any
/// whitespace. Trailing semicolons are ignored, and the name is returned as
/// written.
pub fn parse_deallocate_statement(sql: &str) -> Option<Option<String>> {
    /// Strips a leading `keyword` (ASCII case-insensitive) that must be followed
    /// by whitespace, returning the text after the keyword.
    fn strip_keyword<'a>(s: &'a str, keyword: &str) -> Option<&'a str> {
        let head = s.get(..keyword.len())?;
        let tail = &s[keyword.len()..];
        if head.eq_ignore_ascii_case(keyword) && tail.starts_with(char::is_whitespace) {
            Some(tail)
        } else {
            None
        }
    }

    let rest = strip_keyword(sql.trim(), "DEALLOCATE")?.trim_start();
    // `PREPARE` is an optional noise word and may precede `ALL` as well as a name.
    let rest = strip_keyword(rest, "PREPARE").map_or(rest, str::trim_start);
    let target = rest.trim_end_matches(';').trim_end();

    let first_word_end = target
        .find(|c: char| c.is_whitespace() || c == ';')
        .unwrap_or(target.len());
    if target[..first_word_end].eq_ignore_ascii_case("ALL") {
        return Some(None);
    }
    if target.is_empty() {
        return None;
    }
    Some(Some(target.to_string()))
}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers a parameterless statement for every `(name, query)` pair.
    fn register_all(tracker: &mut PreparedStatementTracker, entries: &[(&str, &str)]) {
        for &(name, query) in entries {
            tracker.register(name.to_string(), query.to_string(), Vec::new());
        }
    }

    #[test]
    fn test_register_and_get() {
        let query = "SELECT * FROM users WHERE id = $1";
        let mut tracker = PreparedStatementTracker::new();
        tracker.register(String::from("stmt1"), String::from(query), vec![23]);

        assert!(tracker.contains("stmt1"));
        let stmt = tracker.get("stmt1").expect("stmt1 should be tracked");
        assert_eq!(stmt.query, query);
        assert_eq!(stmt.param_types, vec![23]);
    }

    #[test]
    fn test_unregister() {
        let mut tracker = PreparedStatementTracker::new();
        register_all(&mut tracker, &[("stmt1", "SELECT 1")]);

        assert!(tracker.contains("stmt1"));
        tracker.unregister("stmt1");
        assert!(!tracker.contains("stmt1"));
    }

    #[test]
    fn test_clear() {
        let mut tracker = PreparedStatementTracker::new();
        register_all(&mut tracker, &[("stmt1", "SELECT 1"), ("stmt2", "SELECT 2")]);

        assert_eq!(tracker.len(), 2);
        tracker.clear();
        assert!(tracker.is_empty());
    }

    #[test]
    fn test_capacity_limit() {
        let mut tracker = PreparedStatementTracker::with_capacity(3);
        register_all(
            &mut tracker,
            &[("stmt1", "SELECT 1"), ("stmt2", "SELECT 2"), ("stmt3", "SELECT 3")],
        );

        // The tracker is full, so a fourth name must push an older one out.
        register_all(&mut tracker, &[("stmt4", "SELECT 4")]);

        assert_eq!(tracker.len(), 3);
        assert!(tracker.contains("stmt4"));
    }

    #[test]
    fn test_generate_prepare_sql() {
        let mut tracker = PreparedStatementTracker::new();
        tracker.register(
            String::from("get_user"),
            String::from("SELECT * FROM users WHERE id = $1"),
            vec![23],
        );

        let sqls = tracker.generate_prepare_sql();
        assert_eq!(sqls.len(), 1);
        let sql = &sqls[0];
        assert!(sql.contains("PREPARE get_user"));
        assert!(sql.contains("integer"));
    }

    #[test]
    fn test_parse_prepare_statement() {
        let (name, params, query) =
            parse_prepare_statement("PREPARE stmt1 AS SELECT 1").expect("simple PREPARE parses");
        assert_eq!(name, "stmt1");
        assert!(params.is_empty());
        assert_eq!(query, "SELECT 1");

        let typed = "PREPARE stmt2 (integer, text) AS SELECT * FROM t WHERE id = $1 AND name = $2";
        let (name, params, query) = parse_prepare_statement(typed).expect("typed PREPARE parses");
        assert_eq!(name, "stmt2");
        assert_eq!(params, vec!["integer", "text"]);
        assert!(query.starts_with("SELECT"));
    }

    #[test]
    fn test_parse_deallocate_statement() {
        let cases = vec![
            ("DEALLOCATE ALL", Some(None)),
            ("DEALLOCATE stmt1", Some(Some("stmt1".to_string()))),
            ("DEALLOCATE PREPARE stmt2", Some(Some("stmt2".to_string()))),
            ("SELECT 1", None),
        ];
        for (sql, expected) in cases {
            assert_eq!(parse_deallocate_statement(sql), expected, "input: {}", sql);
        }
    }

    #[test]
    fn test_execution_tracking() {
        let mut tracker = PreparedStatementTracker::new();
        register_all(&mut tracker, &[("stmt1", "SELECT 1")]);

        for _ in 0..2 {
            tracker.record_execution("stmt1");
        }

        let stmt = tracker.get("stmt1").expect("stmt1 should be tracked");
        assert_eq!(stmt.execution_count, 2);
    }

    #[test]
    fn test_unnamed_statements_ignored() {
        let mut tracker = PreparedStatementTracker::new();
        register_all(&mut tracker, &[("", "SELECT 1")]);

        assert!(tracker.is_empty());
    }
}