1use super::{CANCEL_REQUEST_CODE, PgConnection, PgResult};
4use tokio::io::AsyncWriteExt;
5use tokio::net::TcpStream;
6
7fn encode_cancel_request(process_id: i32, secret_key: &[u8]) -> PgResult<Vec<u8>> {
8 if !(4..=256).contains(&secret_key.len()) {
9 return Err(crate::driver::PgError::Protocol(format!(
10 "Invalid cancel key length: {} (expected 4..=256)",
11 secret_key.len()
12 )));
13 }
14
15 let total_len = 12usize.checked_add(secret_key.len()).ok_or_else(|| {
16 crate::driver::PgError::Protocol("CancelRequest length overflow".to_string())
17 })?;
18 let total_len = i32::try_from(total_len).map_err(|_| {
19 crate::driver::PgError::Protocol("CancelRequest length exceeds i32".to_string())
20 })?;
21
22 let mut buf = Vec::with_capacity(total_len as usize);
23 buf.extend_from_slice(&total_len.to_be_bytes());
24 buf.extend_from_slice(&CANCEL_REQUEST_CODE.to_be_bytes());
25 buf.extend_from_slice(&process_id.to_be_bytes());
26 buf.extend_from_slice(secret_key);
27 Ok(buf)
28}
29
/// Everything needed to cancel a query running on one backend from elsewhere:
/// the server address plus the backend's process id and cancel secret key.
///
/// `Debug` output redacts the secret key, because anyone holding it can
/// cancel queries on that backend.
#[derive(Clone)]
pub struct CancelToken {
    /// Host used to open the out-of-band cancel connection.
    pub(crate) host: String,
    /// Port used to open the out-of-band cancel connection.
    pub(crate) port: u16,
    /// Backend process id the cancel request targets.
    pub(crate) process_id: i32,
    /// Raw cancel secret key bytes (sensitive; never printed).
    pub(crate) secret_key_bytes: Vec<u8>,
}

impl std::fmt::Debug for CancelToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CancelToken")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("process_id", &self.process_id)
            // Only the key length is shown; the key itself is a credential.
            .field(
                "secret_key_bytes",
                &format_args!("<{} bytes redacted>", self.secret_key_bytes.len()),
            )
            .finish()
    }
}
40
41impl CancelToken {
42 pub async fn cancel_query(&self) -> PgResult<()> {
45 PgConnection::cancel_query_bytes(
46 &self.host,
47 self.port,
48 self.process_id,
49 &self.secret_key_bytes,
50 )
51 .await
52 }
53
54 pub fn get_cancel_key_bytes(&self) -> (i32, &[u8]) {
56 (self.process_id, &self.secret_key_bytes)
57 }
58}
59
60impl PgConnection {
61 pub fn get_cancel_key_bytes(&self) -> (i32, &[u8]) {
63 (self.process_id, &self.cancel_key_bytes)
64 }
65
66 pub fn get_cancel_key(&self) -> (i32, i32) {
71 if self.cancel_key_bytes.len() == 4 {
72 (
73 self.process_id,
74 i32::from_be_bytes([
75 self.cancel_key_bytes[0],
76 self.cancel_key_bytes[1],
77 self.cancel_key_bytes[2],
78 self.cancel_key_bytes[3],
79 ]),
80 )
81 } else {
82 (self.process_id, 0)
83 }
84 }
85
86 pub async fn cancel_query_bytes(
88 host: &str,
89 port: u16,
90 process_id: i32,
91 secret_key: &[u8],
92 ) -> PgResult<()> {
93 let addr = format!("{}:{}", host, port);
95 let mut stream = TcpStream::connect(&addr).await?;
96
97 let buf = encode_cancel_request(process_id, secret_key)?;
100
101 stream.write_all(&buf).await?;
102
103 Ok(())
105 }
106
107 pub async fn cancel_query(
109 host: &str,
110 port: u16,
111 process_id: i32,
112 secret_key: i32,
113 ) -> PgResult<()> {
114 Self::cancel_query_bytes(host, port, process_id, &secret_key.to_be_bytes()).await
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::{CANCEL_REQUEST_CODE, encode_cancel_request};

    #[test]
    fn encode_cancel_request_with_4_byte_key() {
        let buf = encode_cancel_request(42, &99i32.to_be_bytes()).expect("encode");
        assert_eq!(buf.len(), 16);
        assert_eq!(&buf[0..4], &16i32.to_be_bytes());
        assert_eq!(&buf[4..8], &CANCEL_REQUEST_CODE.to_be_bytes());
        assert_eq!(&buf[8..12], &42i32.to_be_bytes());
        assert_eq!(&buf[12..16], &99i32.to_be_bytes());
    }

    #[test]
    fn encode_cancel_request_with_extended_key() {
        let key = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let buf = encode_cancel_request(7, &key).expect("encode");
        assert_eq!(buf.len(), 20);
        assert_eq!(&buf[0..4], &20i32.to_be_bytes());
        assert_eq!(&buf[4..8], &CANCEL_REQUEST_CODE.to_be_bytes());
        assert_eq!(&buf[8..12], &7i32.to_be_bytes());
        assert_eq!(&buf[12..], &key);
    }

    /// Both ends of the accepted 4..=256 key-length range must encode.
    #[test]
    fn encode_cancel_request_accepts_boundary_key_lengths() {
        let min_key = [9u8; 4];
        let buf = encode_cancel_request(3, &min_key).expect("4-byte key");
        assert_eq!(buf.len(), 16);
        assert_eq!(&buf[12..], &min_key);

        let max_key = vec![0xA5u8; 256];
        let buf = encode_cancel_request(3, &max_key).expect("256-byte key");
        assert_eq!(buf.len(), 268);
        assert_eq!(&buf[0..4], &268i32.to_be_bytes());
        assert_eq!(&buf[4..8], &CANCEL_REQUEST_CODE.to_be_bytes());
        assert_eq!(&buf[8..12], &3i32.to_be_bytes());
        assert_eq!(&buf[12..], max_key.as_slice());
    }

    #[test]
    fn encode_cancel_request_rejects_invalid_key_lengths() {
        let empty = encode_cancel_request(1, &[]).expect_err("empty");
        assert!(empty.to_string().contains("Invalid cancel key length"));

        let short = encode_cancel_request(1, &[1, 2, 3]).expect_err("short");
        assert!(short.to_string().contains("Invalid cancel key length"));

        let long_key = vec![0u8; 257];
        let long = encode_cancel_request(1, &long_key).expect_err("long");
        assert!(long.to_string().contains("Invalid cancel key length"));
    }
}
151}