Skip to main content

qail_pg/driver/
cancel.rs

//! Query cancellation methods for PostgreSQL connections.
2
3use super::{CANCEL_REQUEST_CODE, PgConnection, PgResult};
4use tokio::io::AsyncWriteExt;
5use tokio::net::TcpStream;
6
/// Formats `host` and `port` as a `host:port` string suitable for connecting.
///
/// A bare IPv6 literal (one that contains `:` and is not already wrapped in
/// `[...]`) is bracketed so the trailing `:port` cannot be mistaken for part
/// of the address. All other hosts are used verbatim.
fn socket_addr(host: &str, port: u16) -> String {
    let is_bare_ipv6 = host.contains(':') && !host.starts_with('[');
    let (open, close) = if is_bare_ipv6 { ("[", "]") } else { ("", "") };
    format!("{}{}{}:{}", open, host, close, port)
}
14
15fn encode_cancel_request(process_id: i32, secret_key: &[u8]) -> PgResult<Vec<u8>> {
16    if !(4..=256).contains(&secret_key.len()) {
17        return Err(crate::driver::PgError::Protocol(format!(
18            "Invalid cancel key length: {} (expected 4..=256)",
19            secret_key.len()
20        )));
21    }
22
23    let total_len = 12usize.checked_add(secret_key.len()).ok_or_else(|| {
24        crate::driver::PgError::Protocol("CancelRequest length overflow".to_string())
25    })?;
26    let total_len = i32::try_from(total_len).map_err(|_| {
27        crate::driver::PgError::Protocol("CancelRequest length exceeds i32".to_string())
28    })?;
29
30    let mut buf = Vec::with_capacity(total_len as usize);
31    buf.extend_from_slice(&total_len.to_be_bytes());
32    buf.extend_from_slice(&CANCEL_REQUEST_CODE.to_be_bytes());
33    buf.extend_from_slice(&process_id.to_be_bytes());
34    buf.extend_from_slice(secret_key);
35    Ok(buf)
36}
37
/// Handle for cancelling the query currently running on the connection it was
/// taken from.
///
/// The token owns copies of everything it needs: the server address, the
/// process id, and the cancel secret key. It therefore holds no borrow of the
/// connection, and every field is `Send + Sync`, so it can be cloned or moved
/// into another thread or task.
#[derive(Clone, Debug)]
pub struct CancelToken {
    /// Host the cancel request is sent to.
    pub(crate) host: String,
    /// Port the cancel request is sent to.
    pub(crate) port: u16,
    /// Process id sent in the CancelRequest.
    pub(crate) process_id: i32,
    /// Full cancel secret key bytes (`4..=256`).
    pub(crate) secret_key_bytes: Vec<u8>,
}
48
49impl CancelToken {
50    /// Attempt to cancel the ongoing query.
51    /// This opens a new TCP connection and sends a CancelRequest message.
52    pub async fn cancel_query(&self) -> PgResult<()> {
53        PgConnection::cancel_query_bytes(
54            &self.host,
55            self.port,
56            self.process_id,
57            &self.secret_key_bytes,
58        )
59        .await
60    }
61
62    /// Get the full cancel key bytes (`process_id`, `secret_key_bytes`).
63    pub fn get_cancel_key_bytes(&self) -> (i32, &[u8]) {
64        (self.process_id, &self.secret_key_bytes)
65    }
66}
67
68impl PgConnection {
69    /// Get the full cancel key bytes for this connection.
70    pub fn get_cancel_key_bytes(&self) -> (i32, &[u8]) {
71        (self.process_id, &self.cancel_key_bytes)
72    }
73
74    /// Cancel a running query using bytes-native cancel key.
75    pub async fn cancel_query_bytes(
76        host: &str,
77        port: u16,
78        process_id: i32,
79        secret_key: &[u8],
80    ) -> PgResult<()> {
81        // Open new connection just for cancel
82        let addr = socket_addr(host, port);
83        let mut stream = TcpStream::connect(&addr).await?;
84
85        // Send CancelRequest message:
86        // Length + CancelRequest code + process_id + secret_key bytes
87        let buf = encode_cancel_request(process_id, secret_key)?;
88
89        stream.write_all(&buf).await?;
90
91        // Server will close connection after receiving cancel request
92        Ok(())
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::{CANCEL_REQUEST_CODE, encode_cancel_request, socket_addr};

    #[test]
    fn cancel_socket_addr_brackets_ipv6_hosts() {
        assert_eq!(socket_addr("127.0.0.1", 5432), "127.0.0.1:5432");
        assert_eq!(socket_addr("::1", 5432), "[::1]:5432");
        assert_eq!(socket_addr("[::1]", 5432), "[::1]:5432");
    }

    #[test]
    fn cancel_socket_addr_leaves_hostnames_unbracketed() {
        assert_eq!(socket_addr("localhost", 5432), "localhost:5432");
        assert_eq!(socket_addr("db.example.com", 6543), "db.example.com:6543");
    }

    #[test]
    fn encode_cancel_request_with_4_byte_key() {
        let buf = encode_cancel_request(42, &99i32.to_be_bytes()).expect("encode");
        assert_eq!(buf.len(), 16);
        assert_eq!(&buf[0..4], &16i32.to_be_bytes());
        assert_eq!(&buf[4..8], &CANCEL_REQUEST_CODE.to_be_bytes());
        assert_eq!(&buf[8..12], &42i32.to_be_bytes());
        assert_eq!(&buf[12..16], &99i32.to_be_bytes());
    }

    #[test]
    fn encode_cancel_request_with_extended_key() {
        let key = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let buf = encode_cancel_request(7, &key).expect("encode");
        assert_eq!(buf.len(), 20);
        assert_eq!(&buf[0..4], &20i32.to_be_bytes());
        assert_eq!(&buf[4..8], &CANCEL_REQUEST_CODE.to_be_bytes());
        assert_eq!(&buf[8..12], &7i32.to_be_bytes());
        assert_eq!(&buf[12..], &key);
    }

    /// Both ends of the `4..=256` range must be accepted, with a length field
    /// equal to the full packet size.
    #[test]
    fn encode_cancel_request_accepts_boundary_key_lengths() {
        for len in [4usize, 256] {
            let key = vec![0xABu8; len];
            let buf = encode_cancel_request(3, &key).expect("boundary length must be accepted");
            let expected_len = 12 + len;
            assert_eq!(buf.len(), expected_len);
            assert_eq!(&buf[0..4], &(expected_len as i32).to_be_bytes());
            assert_eq!(&buf[8..12], &3i32.to_be_bytes());
            assert_eq!(&buf[12..], key.as_slice());
        }
    }

    #[test]
    fn encode_cancel_request_rejects_invalid_key_lengths() {
        let empty = encode_cancel_request(1, &[]).expect_err("empty");
        assert!(empty.to_string().contains("Invalid cancel key length"));

        let short = encode_cancel_request(1, &[1, 2, 3]).expect_err("short");
        assert!(short.to_string().contains("Invalid cancel key length"));

        let long_key = vec![0u8; 257];
        let long = encode_cancel_request(1, &long_key).expect_err("long");
        assert!(long.to_string().contains("Invalid cancel key length"));
    }
}