Skip to main content

qail_pg/driver/
cancel.rs

1//! Query cancellation methods for PostgreSQL connection.
2
3use super::{CANCEL_REQUEST_CODE, PgConnection, PgResult};
4use tokio::io::AsyncWriteExt;
5use tokio::net::TcpStream;
6
7fn encode_cancel_request(process_id: i32, secret_key: &[u8]) -> PgResult<Vec<u8>> {
8    if !(4..=256).contains(&secret_key.len()) {
9        return Err(crate::driver::PgError::Protocol(format!(
10            "Invalid cancel key length: {} (expected 4..=256)",
11            secret_key.len()
12        )));
13    }
14
15    let total_len = 12usize.checked_add(secret_key.len()).ok_or_else(|| {
16        crate::driver::PgError::Protocol("CancelRequest length overflow".to_string())
17    })?;
18    let total_len = i32::try_from(total_len).map_err(|_| {
19        crate::driver::PgError::Protocol("CancelRequest length exceeds i32".to_string())
20    })?;
21
22    let mut buf = Vec::with_capacity(total_len as usize);
23    buf.extend_from_slice(&total_len.to_be_bytes());
24    buf.extend_from_slice(&CANCEL_REQUEST_CODE.to_be_bytes());
25    buf.extend_from_slice(&process_id.to_be_bytes());
26    buf.extend_from_slice(secret_key);
27    Ok(buf)
28}
29
/// A token that can be used to cancel a running query.
///
/// This token is safe to send across threads and does not borrow the connection.
///
/// Its `Debug` output redacts the secret key, so tokens can be logged without
/// leaking the credential that authorizes cancellation.
#[derive(Clone)]
pub struct CancelToken {
    /// Host the cancel connection is opened to.
    pub(crate) host: String,
    /// TCP port the cancel connection is opened to.
    pub(crate) port: u16,
    /// Backend process id sent in the CancelRequest.
    pub(crate) process_id: i32,
    /// Full cancel secret key bytes (`4..=256`).
    pub(crate) secret_key_bytes: Vec<u8>,
}

impl std::fmt::Debug for CancelToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the key itself: anyone who knows it (and can reach the
        // server) can cancel queries on this backend. The length is harmless
        // and useful when diagnosing protocol 3.0 vs 3.2 keys.
        f.debug_struct("CancelToken")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("process_id", &self.process_id)
            .field(
                "secret_key_bytes",
                &format_args!("<redacted {} bytes>", self.secret_key_bytes.len()),
            )
            .finish()
    }
}
40
41impl CancelToken {
42    /// Attempt to cancel the ongoing query.
43    /// This opens a new TCP connection and sends a CancelRequest message.
44    pub async fn cancel_query(&self) -> PgResult<()> {
45        PgConnection::cancel_query_bytes(
46            &self.host,
47            self.port,
48            self.process_id,
49            &self.secret_key_bytes,
50        )
51        .await
52    }
53
54    /// Get the full cancel key bytes (`process_id`, `secret_key_bytes`).
55    pub fn get_cancel_key_bytes(&self) -> (i32, &[u8]) {
56        (self.process_id, &self.secret_key_bytes)
57    }
58}
59
60impl PgConnection {
61    /// Get the full cancel key bytes for this connection.
62    pub fn get_cancel_key_bytes(&self) -> (i32, &[u8]) {
63        (self.process_id, &self.cancel_key_bytes)
64    }
65
66    /// Legacy cancel key accessor (`process_id`, `secret_key_i32`).
67    ///
68    /// Compatibility-only: valid for protocol 3.0 4-byte cancel keys.
69    /// For protocol 3.2 extended keys, this returns `(process_id, 0)`.
70    pub fn get_cancel_key(&self) -> (i32, i32) {
71        if self.cancel_key_bytes.len() == 4 {
72            (
73                self.process_id,
74                i32::from_be_bytes([
75                    self.cancel_key_bytes[0],
76                    self.cancel_key_bytes[1],
77                    self.cancel_key_bytes[2],
78                    self.cancel_key_bytes[3],
79                ]),
80            )
81        } else {
82            (self.process_id, 0)
83        }
84    }
85
86    /// Cancel a running query using bytes-native cancel key.
87    pub async fn cancel_query_bytes(
88        host: &str,
89        port: u16,
90        process_id: i32,
91        secret_key: &[u8],
92    ) -> PgResult<()> {
93        // Open new connection just for cancel
94        let addr = format!("{}:{}", host, port);
95        let mut stream = TcpStream::connect(&addr).await?;
96
97        // Send CancelRequest message:
98        // Length + CancelRequest code + process_id + secret_key bytes
99        let buf = encode_cancel_request(process_id, secret_key)?;
100
101        stream.write_all(&buf).await?;
102
103        // Server will close connection after receiving cancel request
104        Ok(())
105    }
106
107    /// Legacy i32 cancel API wrapper (protocol 3.0-style 4-byte key).
108    pub async fn cancel_query(
109        host: &str,
110        port: u16,
111        process_id: i32,
112        secret_key: i32,
113    ) -> PgResult<()> {
114        Self::cancel_query_bytes(host, port, process_id, &secret_key.to_be_bytes()).await
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::{CANCEL_REQUEST_CODE, encode_cancel_request};

    /// Check the fixed 12-byte prefix: total length, request code, process id.
    fn assert_header(buf: &[u8], expected_len: i32, process_id: i32) {
        assert_eq!(&buf[..4], &expected_len.to_be_bytes());
        assert_eq!(&buf[4..8], &CANCEL_REQUEST_CODE.to_be_bytes());
        assert_eq!(&buf[8..12], &process_id.to_be_bytes());
    }

    #[test]
    fn encode_cancel_request_with_4_byte_key() {
        let secret = 99i32.to_be_bytes();
        let encoded = encode_cancel_request(42, &secret).expect("encode");

        assert_eq!(encoded.len(), 16);
        assert_header(&encoded, 16, 42);
        assert_eq!(&encoded[12..16], &secret);
    }

    #[test]
    fn encode_cancel_request_with_extended_key() {
        let secret = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let encoded = encode_cancel_request(7, &secret).expect("encode");

        assert_header(&encoded, 20, 7);
        assert_eq!(&encoded[12..], &secret);
    }

    #[test]
    fn encode_cancel_request_rejects_invalid_key_lengths() {
        let too_short = [1u8, 2, 3];
        let too_long = vec![0u8; 257];

        for (key, label) in [(&too_short[..], "short"), (&too_long[..], "long")] {
            let err = encode_cancel_request(1, key).expect_err(label);
            assert!(err.to_string().contains("Invalid cancel key length"));
        }
    }
}