heliosdb_proxy/pool/
reset.rs1use crate::{ProxyError, Result};
6
/// Describes the SQL statements used to reset a pooled connection's session
/// state before it is reused.
///
/// Either a single `reset_query` is used, or — when `custom_commands` is
/// non-empty — the ordered list of custom commands is used instead.
#[derive(Debug, Clone)]
pub struct ConnectionResetExecutor {
    /// Single reset statement used when no custom commands are configured.
    /// May be empty, meaning "no reset" (see `ResetLevel::None`).
    reset_query: String,
    /// Whether the configured reset sequence includes `DISCARD ALL`.
    use_discard_all: bool,
    /// Ordered reset statements; when non-empty they take precedence over
    /// `reset_query`.
    custom_commands: Vec<String>,
}
18
19impl Default for ConnectionResetExecutor {
20 fn default() -> Self {
21 Self::new("DISCARD ALL")
22 }
23}
24
25impl ConnectionResetExecutor {
26 pub fn new(reset_query: impl Into<String>) -> Self {
28 let query = reset_query.into();
29 let use_discard_all = query.to_uppercase().contains("DISCARD ALL");
30
31 Self {
32 reset_query: query,
33 use_discard_all,
34 custom_commands: Vec::new(),
35 }
36 }
37
38 pub fn with_commands(commands: Vec<String>) -> Self {
40 Self {
41 reset_query: String::new(),
42 use_discard_all: false,
43 custom_commands: commands,
44 }
45 }
46
47 pub fn add_command(&mut self, command: impl Into<String>) {
49 self.custom_commands.push(command.into());
50 }
51
52 pub fn reset_queries(&self) -> Vec<&str> {
54 if !self.custom_commands.is_empty() {
55 self.custom_commands.iter().map(|s| s.as_str()).collect()
56 } else {
57 vec![&self.reset_query]
58 }
59 }
60
61 pub fn uses_discard_all(&self) -> bool {
63 self.use_discard_all
64 }
65
66 pub fn build_reset_sql(&self) -> String {
68 if !self.custom_commands.is_empty() {
69 self.custom_commands.join("; ")
70 } else {
71 self.reset_query.clone()
72 }
73 }
74
75 pub fn validate(&self) -> Result<()> {
79 let queries = self.reset_queries();
80
81 for query in queries {
82 let upper = query.to_uppercase();
83
84 if upper.contains("INSERT")
86 || upper.contains("UPDATE")
87 || upper.contains("DELETE")
88 || upper.contains("DROP")
89 || upper.contains("CREATE")
90 || upper.contains("ALTER")
91 || upper.contains("TRUNCATE")
92 {
93 return Err(ProxyError::Configuration(format!(
94 "Reset query cannot contain data modification: {}",
95 query
96 )));
97 }
98
99 if upper.contains("BEGIN") || upper.contains("COMMIT") || upper.contains("ROLLBACK") {
101 return Err(ProxyError::Configuration(format!(
102 "Reset query cannot contain transaction control: {}",
103 query
104 )));
105 }
106 }
107
108 Ok(())
109 }
110}
111
/// How thoroughly a pooled connection is reset before it is reused.
///
/// Each level maps to at most one SQL statement via `ResetLevel::sql`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResetLevel {
    /// Runs `DISCARD ALL`.
    Full,
    /// Runs `DEALLOCATE ALL`.
    PreparedStatements,
    /// Runs `RESET ALL`.
    SessionVariables,
    /// Runs `SELECT pg_advisory_unlock_all()`.
    Minimal,
    /// Runs no reset statement.
    None,
}
144
145impl ResetLevel {
146 pub fn sql(&self) -> Option<&'static str> {
148 match self {
149 ResetLevel::Full => Some("DISCARD ALL"),
150 ResetLevel::PreparedStatements => Some("DEALLOCATE ALL"),
151 ResetLevel::SessionVariables => Some("RESET ALL"),
152 ResetLevel::Minimal => Some("SELECT pg_advisory_unlock_all()"),
153 ResetLevel::None => None,
154 }
155 }
156
157 pub fn executor(&self) -> ConnectionResetExecutor {
159 match self.sql() {
160 Some(sql) => ConnectionResetExecutor::new(sql),
161 None => ConnectionResetExecutor {
162 reset_query: String::new(),
163 use_discard_all: false,
164 custom_commands: Vec::new(),
165 },
166 }
167 }
168}
169
/// Fluent builder that assembles an ordered list of reset statements into a
/// `ConnectionResetExecutor`.
#[derive(Debug, Clone)]
pub struct ResetBuilder {
    /// Reset statements in the order they were added.
    commands: Vec<String>,
}
174
175impl Default for ResetBuilder {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
181impl ResetBuilder {
182 pub fn new() -> Self {
184 Self {
185 commands: Vec::new(),
186 }
187 }
188
189 pub fn deallocate_all(mut self) -> Self {
191 self.commands.push("DEALLOCATE ALL".to_string());
192 self
193 }
194
195 pub fn close_cursors(mut self) -> Self {
197 self.commands.push("CLOSE ALL".to_string());
198 self
199 }
200
201 pub fn unlisten_all(mut self) -> Self {
203 self.commands.push("UNLISTEN *".to_string());
204 self
205 }
206
207 pub fn reset_all(mut self) -> Self {
209 self.commands.push("RESET ALL".to_string());
210 self
211 }
212
213 pub fn release_advisory_locks(mut self) -> Self {
215 self.commands
216 .push("SELECT pg_advisory_unlock_all()".to_string());
217 self
218 }
219
220 pub fn discard_plans(mut self) -> Self {
222 self.commands.push("DISCARD PLANS".to_string());
223 self
224 }
225
226 pub fn discard_temp(mut self) -> Self {
228 self.commands.push("DISCARD TEMP".to_string());
229 self
230 }
231
232 pub fn custom(mut self, command: impl Into<String>) -> Self {
234 self.commands.push(command.into());
235 self
236 }
237
238 pub fn build(self) -> ConnectionResetExecutor {
240 if self.commands.is_empty() {
241 ConnectionResetExecutor::default()
242 } else {
243 ConnectionResetExecutor::with_commands(self.commands)
244 }
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// The default executor is a single `DISCARD ALL`.
    #[test]
    fn test_default_reset() {
        let default_executor = ConnectionResetExecutor::default();
        assert!(default_executor.uses_discard_all());
        assert_eq!(default_executor.reset_queries(), ["DISCARD ALL"]);
    }

    /// A non-DISCARD reset query is used verbatim.
    #[test]
    fn test_custom_reset() {
        let reset_all = ConnectionResetExecutor::new("RESET ALL");
        assert!(!reset_all.uses_discard_all());
        assert_eq!(reset_all.reset_queries(), ["RESET ALL"]);
    }

    /// Custom commands keep their order and join with "; ".
    #[test]
    fn test_multiple_commands() {
        let commands = vec![String::from("DEALLOCATE ALL"), String::from("RESET ALL")];
        let multi = ConnectionResetExecutor::with_commands(commands);
        assert_eq!(multi.reset_queries(), ["DEALLOCATE ALL", "RESET ALL"]);
        assert_eq!(multi.build_reset_sql(), "DEALLOCATE ALL; RESET ALL");
    }

    /// Harmless session resets pass validation.
    #[test]
    fn test_validation_success() {
        let accepted = [
            ConnectionResetExecutor::default(),
            ConnectionResetExecutor::new("RESET ALL"),
        ];
        for executor in accepted.iter() {
            assert!(executor.validate().is_ok());
        }
    }

    /// Data modification and transaction control are rejected.
    #[test]
    fn test_validation_failure() {
        let rejected = [
            "DROP TABLE users",
            "INSERT INTO log VALUES (1)",
            "BEGIN; RESET ALL; COMMIT",
        ];
        for sql in rejected.iter() {
            assert!(
                ConnectionResetExecutor::new(*sql).validate().is_err(),
                "expected {:?} to be rejected",
                sql
            );
        }
    }

    /// Each reset level maps to its SQL statement.
    #[test]
    fn test_reset_level() {
        let expectations = [
            (ResetLevel::Full, Some("DISCARD ALL")),
            (ResetLevel::PreparedStatements, Some("DEALLOCATE ALL")),
            (ResetLevel::None, None),
        ];
        for (level, expected) in expectations.iter() {
            assert_eq!(level.sql(), *expected);
        }
    }

    /// The builder collects every queued statement.
    #[test]
    fn test_reset_builder() {
        let built = ResetBuilder::new()
            .deallocate_all()
            .close_cursors()
            .reset_all()
            .build();

        let queries = built.reset_queries();
        assert_eq!(queries.len(), 3);
        for expected in ["DEALLOCATE ALL", "CLOSE ALL", "RESET ALL"].iter() {
            assert!(queries.contains(expected));
        }
    }
}