use super::{CANCEL_REQUEST_CODE, PgConnection, PgResult};
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
/// Builds the wire bytes of a `CancelRequest` message.
///
/// Layout (all integers big-endian): Int32 total message length (12 + key length),
/// Int32 `CANCEL_REQUEST_CODE`, Int32 `process_id`, then the raw `secret_key` bytes.
///
/// # Errors
/// Returns `PgError::Protocol` when `secret_key` is shorter than 4 or longer than 256 bytes.
/// The length-overflow and `i32` conversion errors are defensive; they cannot trigger once
/// the key length has been range-checked.
fn encode_cancel_request(process_id: i32, secret_key: &[u8]) -> PgResult<Vec<u8>> {
    let key_len = match secret_key.len() {
        len @ 4..=256 => len,
        len => {
            return Err(crate::driver::PgError::Protocol(format!(
                "Invalid cancel key length: {} (expected 4..=256)",
                len
            )));
        }
    };

    let protocol_error = |msg: &str| crate::driver::PgError::Protocol(msg.to_string());

    // 12 = length field + request code + process id, 4 bytes each.
    let frame_len = 12usize
        .checked_add(key_len)
        .ok_or_else(|| protocol_error("CancelRequest length overflow"))?;
    let length_field = i32::try_from(frame_len)
        .map_err(|_| protocol_error("CancelRequest length exceeds i32"))?;

    let frame = [
        &length_field.to_be_bytes()[..],
        &CANCEL_REQUEST_CODE.to_be_bytes()[..],
        &process_id.to_be_bytes()[..],
        secret_key,
    ]
    .concat();
    Ok(frame)
}
/// Self-contained handle for cancelling a query from outside its connection.
///
/// Holds the address the cancel request is sent to plus the backend process ID and
/// the raw cancel secret key.
#[derive(Clone, Debug)]
pub struct CancelToken {
    /// Host the cancel request connects to.
    pub(crate) host: String,
    /// TCP port the cancel request connects to.
    pub(crate) port: u16,
    /// Backend process ID the cancel request targets.
    pub(crate) process_id: i32,
    /// Raw cancel secret key bytes.
    pub(crate) secret_key_bytes: Vec<u8>,
}
impl CancelToken {
    /// Sends a cancel request for this token's backend.
    ///
    /// Delegates to [`PgConnection::cancel_query_bytes`] with the stored host, port,
    /// process ID and secret key bytes, and returns its result unchanged.
    pub async fn cancel_query(&self) -> PgResult<()> {
        let Self {
            host,
            port,
            process_id,
            secret_key_bytes,
        } = self;
        PgConnection::cancel_query_bytes(host, *port, *process_id, secret_key_bytes).await
    }

    /// Returns the backend process ID together with the raw secret key bytes.
    pub fn get_cancel_key_bytes(&self) -> (i32, &[u8]) {
        let Self {
            process_id,
            secret_key_bytes,
            ..
        } = self;
        (*process_id, secret_key_bytes.as_slice())
    }
}
impl PgConnection {
    /// Returns the backend process ID and the raw cancel secret key bytes held by this
    /// connection.
    pub fn get_cancel_key_bytes(&self) -> (i32, &[u8]) {
        (self.process_id, &self.cancel_key_bytes)
    }

    /// Returns the backend process ID and the cancel secret key as an `i32`.
    ///
    /// Legacy accessor. The key can only be represented as an `i32` when it is exactly
    /// 4 bytes long, in which case it is decoded big-endian. For any other length the key
    /// is reported as `0`, which does not correspond to the real key. Prefer
    /// [`PgConnection::get_cancel_key_bytes`].
    pub fn get_cancel_key(&self) -> (i32, i32) {
        let key = <[u8; 4]>::try_from(&self.cancel_key_bytes[..])
            .map(i32::from_be_bytes)
            .unwrap_or(0);
        (self.process_id, key)
    }

    /// Sends a `CancelRequest` for `process_id` to the server at `host`/`port`.
    ///
    /// Opens a fresh plain TCP connection (no TLS is negotiated here), writes the encoded
    /// request and returns. The stream is closed when it is dropped on return. `host` may be
    /// a hostname, an IPv4 literal, or an IPv6 literal with or without surrounding brackets.
    ///
    /// # Errors
    /// - A protocol error if `secret_key` is not 4..=256 bytes long. This is checked before
    ///   any connection is opened.
    /// - An I/O error if resolving, connecting or writing fails.
    pub async fn cancel_query_bytes(
        host: &str,
        port: u16,
        process_id: i32,
        secret_key: &[u8],
    ) -> PgResult<()> {
        // Encode first so an invalid key never causes a connection to the server.
        let buf = encode_cancel_request(process_id, secret_key)?;

        // The `(host, port)` form parses IP literals (including bare IPv6 such as "::1")
        // directly instead of joining and re-splitting a "host:port" string. Brackets
        // ("[::1]") are stripped so callers that pass URL-style hosts keep working.
        let host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);

        let mut stream = TcpStream::connect((host, port)).await?;
        stream.write_all(&buf).await?;
        Ok(())
    }

    /// Legacy variant of [`PgConnection::cancel_query_bytes`] for 4-byte secret keys.
    ///
    /// The `i32` key is sent in big-endian byte order.
    pub async fn cancel_query(
        host: &str,
        port: u16,
        process_id: i32,
        secret_key: i32,
    ) -> PgResult<()> {
        Self::cancel_query_bytes(host, port, process_id, &secret_key.to_be_bytes()).await
    }
}
#[cfg(test)]
mod tests {
    use super::{CANCEL_REQUEST_CODE, encode_cancel_request};

    /// Classic protocol: a 4-byte key yields a 16-byte message.
    #[test]
    fn encode_cancel_request_with_4_byte_key() {
        let buf = encode_cancel_request(42, &99i32.to_be_bytes()).expect("encode");
        assert_eq!(buf.len(), 16);
        assert_eq!(&buf[0..4], &16i32.to_be_bytes());
        assert_eq!(&buf[4..8], &CANCEL_REQUEST_CODE.to_be_bytes());
        assert_eq!(&buf[8..12], &42i32.to_be_bytes());
        assert_eq!(&buf[12..16], &99i32.to_be_bytes());
    }

    /// Longer keys are appended verbatim and reflected in the length field.
    #[test]
    fn encode_cancel_request_with_extended_key() {
        let key = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let buf = encode_cancel_request(7, &key).expect("encode");
        assert_eq!(buf.len(), 20);
        assert_eq!(&buf[0..4], &20i32.to_be_bytes());
        assert_eq!(&buf[4..8], &CANCEL_REQUEST_CODE.to_be_bytes());
        assert_eq!(&buf[8..12], &7i32.to_be_bytes());
        assert_eq!(&buf[12..], &key);
    }

    /// Both ends of the accepted 4..=256 range must encode successfully.
    #[test]
    fn encode_cancel_request_accepts_boundary_key_lengths() {
        let min = encode_cancel_request(1, &[9u8; 4]).expect("min");
        assert_eq!(min.len(), 16);
        assert_eq!(&min[0..4], &16i32.to_be_bytes());

        let max_key = vec![0xABu8; 256];
        let max = encode_cancel_request(1, &max_key).expect("max");
        assert_eq!(max.len(), 268);
        assert_eq!(&max[0..4], &268i32.to_be_bytes());
        assert_eq!(&max[12..], max_key.as_slice());
    }

    /// Empty, too-short and too-long keys are rejected with a protocol error.
    #[test]
    fn encode_cancel_request_rejects_invalid_key_lengths() {
        let empty = encode_cancel_request(1, &[]).expect_err("empty");
        assert!(empty.to_string().contains("Invalid cancel key length"));
        let short = encode_cancel_request(1, &[1, 2, 3]).expect_err("short");
        assert!(short.to_string().contains("Invalid cancel key length"));
        let long_key = vec![0u8; 257];
        let long = encode_cancel_request(1, &long_key).expect_err("long");
        assert!(long.to_string().contains("Invalid cancel key length"));
    }
}