1use crate::error::Result;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use sqlx::PgPool;
10
/// A PostgreSQL lock mode, as reported in the `mode` column of `pg_locks`.
///
/// Recognised mode names (e.g. `"AccessShareLock"`) are converted by the `From<String>`
/// impl; any other name is kept verbatim in [`LockType::Unknown`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum LockType {
    /// Mode name `AccessShareLock`.
    AccessShare,
    /// Mode name `RowShareLock`.
    RowShare,
    /// Mode name `RowExclusiveLock`.
    RowExclusive,
    /// Mode name `ShareUpdateExclusiveLock`.
    ShareUpdateExclusive,
    /// Mode name `ShareLock`.
    Share,
    /// Mode name `ShareRowExclusiveLock`.
    ShareRowExclusive,
    /// Mode name `ExclusiveLock`.
    Exclusive,
    /// Mode name `AccessExclusiveLock`.
    AccessExclusive,
    /// Any mode name not matched by the variants above, stored exactly as received.
    Unknown(String),
}
33
34impl From<String> for LockType {
35 fn from(s: String) -> Self {
36 match s.as_str() {
37 "AccessShareLock" => LockType::AccessShare,
38 "RowShareLock" => LockType::RowShare,
39 "RowExclusiveLock" => LockType::RowExclusive,
40 "ShareUpdateExclusiveLock" => LockType::ShareUpdateExclusive,
41 "ShareLock" => LockType::Share,
42 "ShareRowExclusiveLock" => LockType::ShareRowExclusive,
43 "ExclusiveLock" => LockType::Exclusive,
44 "AccessExclusiveLock" => LockType::AccessExclusive,
45 _ => LockType::Unknown(s),
46 }
47 }
48}
49
/// One row of `pg_locks`: a lock held or awaited by a server backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LockInfo {
    /// Process id of the backend holding or waiting for the lock.
    pub pid: i32,
    /// Lock mode (from `pg_locks.mode`).
    pub lock_type: LockType,
    /// Name of the database the lock belongs to (looked up in `pg_database`).
    pub database: String,
    /// Name of the locked relation (looked up in `pg_class`), if any.
    pub relation: Option<String>,
    /// `true` when the lock is held, `false` when the backend is still waiting for it.
    pub granted: bool,
    /// When the lock was acquired; `get_current_locks` currently always leaves this `None`.
    pub lock_acquired: Option<DateTime<Utc>>,
}
66
/// A single "backend X is blocked by backend Y" relationship between two sessions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockingQuery {
    /// Process id of a backend that prevents `blocked_pid` from obtaining its lock.
    pub blocking_pid: i32,
    /// Process id of the backend waiting for an ungranted lock.
    pub blocked_pid: i32,
    /// Query text reported by `pg_stat_activity` for the blocking backend.
    pub blocking_query: String,
    /// Query text reported by `pg_stat_activity` for the blocked backend.
    pub blocked_query: String,
    /// Mode of the (not yet granted) lock the blocked backend is waiting for.
    pub lock_type: LockType,
    /// Name of the relation that lock targets, when it is a relation found in `pg_class`.
    pub table_name: Option<String>,
    /// Milliseconds since the blocked backend's current query started
    /// (`pg_stat_activity.query_start`). This can exceed the time actually spent waiting for
    /// the lock. `None` when the start time is unknown.
    pub blocked_duration_ms: Option<i64>,
}
85
/// A circular wait between backends found by `detect_deadlocks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeadlockInfo {
    /// Process ids of the backends forming the wait cycle.
    pub involved_pids: Vec<i32>,
    /// Human-readable rendering of the cycle, e.g. `"... 100 -> 200 -> 100"`.
    pub cycle_description: String,
    /// Time (UTC, taken from the application clock) at which the detection ran.
    pub detected_at: DateTime<Utc>,
}
96
97pub async fn get_current_locks(pool: &PgPool) -> Result<Vec<LockInfo>> {
99 let locks = sqlx::query_as::<_, (i32, String, String, Option<String>, bool)>(
100 r#"
101 SELECT
102 l.pid,
103 l.mode as lock_type,
104 d.datname as database,
105 c.relname as relation,
106 l.granted
107 FROM pg_locks l
108 LEFT JOIN pg_database d ON l.database = d.oid
109 LEFT JOIN pg_class c ON l.relation = c.oid
110 WHERE l.pid != pg_backend_pid()
111 ORDER BY l.pid
112 "#,
113 )
114 .fetch_all(pool)
115 .await?;
116
117 Ok(locks
118 .into_iter()
119 .map(|l| LockInfo {
120 pid: l.0,
121 lock_type: LockType::from(l.1),
122 database: l.2,
123 relation: l.3,
124 granted: l.4,
125 lock_acquired: None,
126 })
127 .collect())
128}
129
130pub async fn get_blocking_queries(pool: &PgPool) -> Result<Vec<BlockingQuery>> {
134 let blocking = sqlx::query_as::<
135 _,
136 (
137 i32,
138 i32,
139 String,
140 String,
141 String,
142 Option<String>,
143 Option<DateTime<Utc>>,
144 ),
145 >(
146 r#"
147 SELECT
148 blocking.pid AS blocking_pid,
149 blocked.pid AS blocked_pid,
150 blocking_activity.query AS blocking_query,
151 blocked_activity.query AS blocked_query,
152 blocked_lock.mode AS lock_type,
153 c.relname AS table_name,
154 blocked_activity.query_start
155 FROM pg_locks blocked_lock
156 JOIN pg_stat_activity blocked_activity ON blocked_activity.pid = blocked_lock.pid
157 JOIN pg_locks blocking_lock ON blocking_lock.locktype = blocked_lock.locktype
158 AND blocking_lock.database IS NOT DISTINCT FROM blocked_lock.database
159 AND blocking_lock.relation IS NOT DISTINCT FROM blocked_lock.relation
160 AND blocking_lock.page IS NOT DISTINCT FROM blocked_lock.page
161 AND blocking_lock.tuple IS NOT DISTINCT FROM blocked_lock.tuple
162 AND blocking_lock.virtualxid IS NOT DISTINCT FROM blocked_lock.virtualxid
163 AND blocking_lock.transactionid IS NOT DISTINCT FROM blocked_lock.transactionid
164 AND blocking_lock.classid IS NOT DISTINCT FROM blocked_lock.classid
165 AND blocking_lock.objid IS NOT DISTINCT FROM blocked_lock.objid
166 AND blocking_lock.objsubid IS NOT DISTINCT FROM blocked_lock.objsubid
167 AND blocking_lock.pid != blocked_lock.pid
168 JOIN pg_stat_activity blocking_activity ON blocking_activity.pid = blocking_lock.pid
169 LEFT JOIN pg_class c ON c.oid = blocked_lock.relation
170 WHERE NOT blocked_lock.granted
171 AND blocking_lock.granted
172 "#,
173 )
174 .fetch_all(pool)
175 .await?;
176
177 Ok(blocking
178 .into_iter()
179 .map(|b| {
180 let blocked_duration_ms = b.6.map(|start| {
181 let now = Utc::now();
182 now.signed_duration_since(start).num_milliseconds()
183 });
184
185 BlockingQuery {
186 blocking_pid: b.0,
187 blocked_pid: b.1,
188 blocking_query: b.2,
189 blocked_query: b.3,
190 lock_type: LockType::from(b.4),
191 table_name: b.5,
192 blocked_duration_ms,
193 }
194 })
195 .collect())
196}
197
198pub async fn detect_deadlocks(pool: &PgPool) -> Result<Vec<DeadlockInfo>> {
202 let blocking_queries = get_blocking_queries(pool).await?;
203
204 let mut deadlocks = Vec::new();
205 let mut checked_pids = std::collections::HashSet::new();
206
207 for query in &blocking_queries {
208 if checked_pids.contains(&query.blocked_pid) {
209 continue;
210 }
211
212 let mut chain = vec![query.blocked_pid];
214 let mut current_pid = query.blocking_pid;
215 let mut cycle_found = false;
216
217 while let Some(blocker) = blocking_queries
218 .iter()
219 .find(|q| q.blocked_pid == current_pid)
220 {
221 if chain.contains(&blocker.blocking_pid) {
222 cycle_found = true;
224 chain.push(blocker.blocking_pid);
225 break;
226 }
227
228 chain.push(blocker.blocking_pid);
229 current_pid = blocker.blocking_pid;
230
231 if chain.len() > 100 {
232 break;
234 }
235 }
236
237 if cycle_found {
238 for pid in &chain {
239 checked_pids.insert(*pid);
240 }
241
242 deadlocks.push(DeadlockInfo {
243 involved_pids: chain.clone(),
244 cycle_description: format!(
245 "Deadlock cycle detected: {}",
246 chain
247 .iter()
248 .map(|p| p.to_string())
249 .collect::<Vec<_>>()
250 .join(" -> ")
251 ),
252 detected_at: Utc::now(),
253 });
254 }
255 }
256
257 Ok(deadlocks)
258}
259
260pub async fn kill_blocking_query(pool: &PgPool, pid: i32) -> Result<bool> {
264 let result = sqlx::query_scalar::<_, bool>("SELECT pg_terminate_backend($1)")
265 .bind(pid)
266 .fetch_one(pool)
267 .await?;
268
269 Ok(result)
270}
271
272pub async fn get_lock_wait_stats(pool: &PgPool) -> Result<LockWaitStats> {
276 let blocking_queries = get_blocking_queries(pool).await?;
277
278 let total_blocked = blocking_queries.len();
279 let max_wait_time = blocking_queries
280 .iter()
281 .filter_map(|q| q.blocked_duration_ms)
282 .max()
283 .unwrap_or(0);
284
285 let avg_wait_time = if total_blocked > 0 {
286 blocking_queries
287 .iter()
288 .filter_map(|q| q.blocked_duration_ms)
289 .sum::<i64>()
290 / total_blocked as i64
291 } else {
292 0
293 };
294
295 let mut tables = std::collections::HashSet::new();
297 for query in &blocking_queries {
298 if let Some(ref table) = query.table_name {
299 tables.insert(table.clone());
300 }
301 }
302
303 Ok(LockWaitStats {
304 total_blocked_queries: total_blocked,
305 max_wait_time_ms: max_wait_time,
306 avg_wait_time_ms: avg_wait_time,
307 affected_tables: tables.into_iter().collect(),
308 })
309}
310
/// Aggregate view of current lock waits, produced by `get_lock_wait_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LockWaitStats {
    /// Number of blocked queries counted by `get_lock_wait_stats`.
    pub total_blocked_queries: usize,
    /// Longest known wait, in milliseconds; 0 when no wait duration is known.
    pub max_wait_time_ms: i64,
    /// Average wait in milliseconds (integer division); 0 when there is nothing to average.
    pub avg_wait_time_ms: i64,
    /// Distinct names of the tables involved in the waits.
    pub affected_tables: Vec<String>,
}
323
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lock_type_from_string() {
        // Every mode name handled by the From<String> impl maps to its dedicated variant.
        let cases = [
            ("AccessShareLock", LockType::AccessShare),
            ("RowShareLock", LockType::RowShare),
            ("RowExclusiveLock", LockType::RowExclusive),
            ("ShareUpdateExclusiveLock", LockType::ShareUpdateExclusive),
            ("ShareLock", LockType::Share),
            ("ShareRowExclusiveLock", LockType::ShareRowExclusive),
            ("ExclusiveLock", LockType::Exclusive),
            ("AccessExclusiveLock", LockType::AccessExclusive),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(&LockType::from(name.to_string()), expected, "mode {}", name);
        }
    }

    #[test]
    fn test_lock_type_unknown() {
        match LockType::from("CustomLock".to_string()) {
            LockType::Unknown(s) => assert_eq!(s, "CustomLock"),
            _ => panic!("Expected Unknown variant"),
        }
    }

    #[test]
    fn test_lock_type_matching_is_exact() {
        // Matching is case-sensitive and requires the full "...Lock" name; anything else is
        // preserved verbatim.
        for name in ["accesssharelock", "AccessShare", "SIReadLock", ""].iter() {
            assert_eq!(
                LockType::from(name.to_string()),
                LockType::Unknown(name.to_string())
            );
        }
    }

    #[test]
    fn test_lock_info_structure() {
        let info = LockInfo {
            pid: 12345,
            lock_type: LockType::RowExclusive,
            database: "mydb".to_string(),
            relation: Some("users".to_string()),
            granted: true,
            lock_acquired: None,
        };

        assert_eq!(info.pid, 12345);
        assert!(info.granted);
    }

    #[test]
    fn test_blocking_query_structure() {
        let query = BlockingQuery {
            blocking_pid: 100,
            blocked_pid: 200,
            blocking_query: "UPDATE users SET ...".to_string(),
            blocked_query: "SELECT * FROM users".to_string(),
            lock_type: LockType::RowExclusive,
            table_name: Some("users".to_string()),
            blocked_duration_ms: Some(5000),
        };

        assert_eq!(query.blocking_pid, 100);
        assert_eq!(query.blocked_duration_ms, Some(5000));
    }

    #[test]
    fn test_deadlock_info_structure() {
        let info = DeadlockInfo {
            involved_pids: vec![100, 200, 300],
            cycle_description: "100 -> 200 -> 300 -> 100".to_string(),
            detected_at: Utc::now(),
        };

        assert_eq!(info.involved_pids.len(), 3);
    }

    #[test]
    fn test_lock_wait_stats_structure() {
        let stats = LockWaitStats {
            total_blocked_queries: 5,
            max_wait_time_ms: 10000,
            avg_wait_time_ms: 5000,
            affected_tables: vec!["users".to_string(), "orders".to_string()],
        };

        assert_eq!(stats.total_blocked_queries, 5);
        assert_eq!(stats.affected_tables.len(), 2);
    }

    #[test]
    fn test_lock_type_serialization() {
        let lock_type = LockType::Exclusive;
        let json = serde_json::to_string(&lock_type).unwrap();
        let deserialized: LockType = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized, lock_type);
    }

    #[test]
    fn test_lock_type_unknown_serialization_round_trip() {
        // The payload of Unknown must survive a JSON round trip unchanged.
        let lock_type = LockType::Unknown("SIReadLock".to_string());
        let json = serde_json::to_string(&lock_type).unwrap();
        let deserialized: LockType = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized, lock_type);
    }

    #[test]
    fn test_lock_info_serialization() {
        let info = LockInfo {
            pid: 999,
            lock_type: LockType::Share,
            database: "testdb".to_string(),
            relation: None,
            granted: false,
            lock_acquired: None,
        };

        let json = serde_json::to_string(&info).unwrap();
        let deserialized: LockInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.pid, info.pid);
        assert_eq!(deserialized.granted, info.granted);
    }
}