1use super::{CANCEL_REQUEST_CODE, PgConnection, PgResult};
4use tokio::io::AsyncWriteExt;
5use tokio::net::TcpStream;
6
/// Builds the `host:port` address string used to open the cancel connection.
///
/// A host that contains `:` and does not already start with `[` is wrapped
/// in square brackets so the port separator stays unambiguous (for example
/// `::1` becomes `[::1]:5432`). Any other host is joined with the port as-is.
fn socket_addr(host: &str, port: u16) -> String {
    let needs_brackets = !host.starts_with('[') && host.contains(':');
    let (open, close) = if needs_brackets { ("[", "]") } else { ("", "") };
    format!("{}{}{}:{}", open, host, close, port)
}
14
15fn encode_cancel_request(process_id: i32, secret_key: &[u8]) -> PgResult<Vec<u8>> {
16 if !(4..=256).contains(&secret_key.len()) {
17 return Err(crate::driver::PgError::Protocol(format!(
18 "Invalid cancel key length: {} (expected 4..=256)",
19 secret_key.len()
20 )));
21 }
22
23 let total_len = 12usize.checked_add(secret_key.len()).ok_or_else(|| {
24 crate::driver::PgError::Protocol("CancelRequest length overflow".to_string())
25 })?;
26 let total_len = i32::try_from(total_len).map_err(|_| {
27 crate::driver::PgError::Protocol("CancelRequest length exceeds i32".to_string())
28 })?;
29
30 let mut buf = Vec::with_capacity(total_len as usize);
31 buf.extend_from_slice(&total_len.to_be_bytes());
32 buf.extend_from_slice(&CANCEL_REQUEST_CODE.to_be_bytes());
33 buf.extend_from_slice(&process_id.to_be_bytes());
34 buf.extend_from_slice(secret_key);
35 Ok(buf)
36}
37
/// Everything needed to request cancellation of a running query without
/// holding on to the connection that issued it.
///
/// The token stores the server address together with the process id and
/// secret key that identify the session to cancel.
#[derive(Clone, Debug)]
pub struct CancelToken {
    /// Host name or IP literal of the server to contact.
    pub(crate) host: String,
    /// TCP port of the server to contact.
    pub(crate) port: u16,
    /// Process id sent in the cancel request.
    pub(crate) process_id: i32,
    /// Secret key sent in the cancel request, as raw bytes.
    pub(crate) secret_key_bytes: Vec<u8>,
}
48
49impl CancelToken {
50 pub async fn cancel_query(&self) -> PgResult<()> {
53 PgConnection::cancel_query_bytes(
54 &self.host,
55 self.port,
56 self.process_id,
57 &self.secret_key_bytes,
58 )
59 .await
60 }
61
62 pub fn get_cancel_key_bytes(&self) -> (i32, &[u8]) {
64 (self.process_id, &self.secret_key_bytes)
65 }
66}
67
68impl PgConnection {
69 pub fn get_cancel_key_bytes(&self) -> (i32, &[u8]) {
71 (self.process_id, &self.cancel_key_bytes)
72 }
73
74 pub async fn cancel_query_bytes(
76 host: &str,
77 port: u16,
78 process_id: i32,
79 secret_key: &[u8],
80 ) -> PgResult<()> {
81 let addr = socket_addr(host, port);
83 let mut stream = TcpStream::connect(&addr).await?;
84
85 let buf = encode_cancel_request(process_id, secret_key)?;
88
89 stream.write_all(&buf).await?;
90
91 Ok(())
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::{CANCEL_REQUEST_CODE, encode_cancel_request, socket_addr};

    /// IPv4 hosts pass through, bare IPv6 hosts gain brackets, and already
    /// bracketed hosts are left alone.
    #[test]
    fn cancel_socket_addr_brackets_ipv6_hosts() {
        let cases = [
            ("127.0.0.1", "127.0.0.1:5432"),
            ("::1", "[::1]:5432"),
            ("[::1]", "[::1]:5432"),
        ];
        for &(host, expected) in cases.iter() {
            assert_eq!(socket_addr(host, 5432), expected);
        }
    }

    /// A 4-byte key yields a 16-byte packet with every field in place.
    #[test]
    fn encode_cancel_request_with_4_byte_key() {
        let key = 99i32.to_be_bytes();
        let packet = encode_cancel_request(42, &key).expect("encode");
        assert_eq!(packet.len(), 16);
        assert_eq!(&packet[0..4], &16i32.to_be_bytes());
        assert_eq!(&packet[4..8], &CANCEL_REQUEST_CODE.to_be_bytes());
        assert_eq!(&packet[8..12], &42i32.to_be_bytes());
        assert_eq!(&packet[12..16], &key);
    }

    /// A longer key is appended verbatim and reflected in the length field.
    #[test]
    fn encode_cancel_request_with_extended_key() {
        let key = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let packet = encode_cancel_request(7, &key).expect("encode");
        assert_eq!(&packet[0..4], &20i32.to_be_bytes());
        assert_eq!(&packet[4..8], &CANCEL_REQUEST_CODE.to_be_bytes());
        assert_eq!(&packet[8..12], &7i32.to_be_bytes());
        assert_eq!(&packet[12..], &key);
    }

    /// Keys shorter than 4 bytes or longer than 256 bytes are rejected.
    #[test]
    fn encode_cancel_request_rejects_invalid_key_lengths() {
        let too_short = [1u8, 2, 3];
        let short_err = encode_cancel_request(1, &too_short).expect_err("short");
        assert!(short_err.to_string().contains("Invalid cancel key length"));

        let too_long = vec![0u8; 257];
        let long_err = encode_cancel_request(1, &too_long).expect_err("long");
        assert!(long_err.to_string().contains("Invalid cancel key length"));
    }
}
136}