use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Buf, Bytes, BytesMut};
use futures::{Stream, StreamExt};
use crate::{Error, Result};
// Frame type identifiers. They occupy the low 24 bits of a frame's first
// header word; the high byte carries the version and is masked off when
// decoding.

/// Frame carrying a chunk of result data after its offset.
const FRAME_DATA: u32 = 0x80_0001;
/// Frame carrying only an offset and no data.
const FRAME_CONTINUOUS: u32 = 0x80_0004;
/// Terminal frame: bytes scanned, status, and an optional error message.
const FRAME_END: u32 = 0x80_0005;
/// Terminal frame with end-frame fields plus split, row and column counts.
const FRAME_META_END_CSV: u32 = 0x80_0006;
/// Terminal frame with end-frame fields plus split and row counts.
const FRAME_META_END_JSON: u32 = 0x80_0007;

/// Smallest valid payload: the big-endian `u64` offset every frame starts with.
const MIN_PAYLOAD_LEN: usize = std::mem::size_of::<u64>();
/// A single decoded frame from a select response body.
///
/// Every frame carries the big-endian `u64` offset that leads its payload.
/// The `End`, `MetaEndCsv` and `MetaEndJson` variants also carry a status,
/// exposed uniformly through [`SelectFrame::terminal_status`].
#[derive(Debug, Clone)]
pub enum SelectFrame {
    /// A chunk of result data.
    Data {
        /// Offset value read from the start of the payload.
        offset: u64,
        /// Payload bytes that follow the offset, copied out of the read buffer.
        data: Bytes,
    },
    /// A frame with an offset and nothing else (presumably a progress /
    /// keep-alive marker — confirm against the service documentation).
    Continuous {
        /// Offset value read from the start of the payload.
        offset: u64,
    },
    /// End-of-response frame.
    End {
        /// Offset value read from the start of the payload.
        offset: u64,
        /// Total-scanned counter reported by the server.
        total_scanned: u64,
        /// Status code reported by the server (tests use 200 for success).
        status: u32,
        /// Trailing payload bytes decoded as lossy UTF-8; empty if absent.
        error_message: String,
    },
    /// End frame that additionally reports CSV split, row and column counts.
    MetaEndCsv {
        /// Offset value read from the start of the payload.
        offset: u64,
        /// Total-scanned counter reported by the server.
        total_scanned: u64,
        /// Status code reported by the server.
        status: u32,
        /// Number of splits reported by the server.
        splits_count: u32,
        /// Number of rows reported by the server.
        rows_count: u64,
        /// Number of columns reported by the server.
        cols_count: u32,
        /// Trailing payload bytes decoded as lossy UTF-8; empty if absent.
        error_message: String,
    },
    /// End frame that additionally reports JSON split and row counts.
    MetaEndJson {
        /// Offset value read from the start of the payload.
        offset: u64,
        /// Total-scanned counter reported by the server.
        total_scanned: u64,
        /// Status code reported by the server.
        status: u32,
        /// Number of splits reported by the server.
        splits_count: u32,
        /// Number of rows reported by the server.
        rows_count: u64,
        /// Trailing payload bytes decoded as lossy UTF-8; empty if absent.
        error_message: String,
    },
}
impl SelectFrame {
pub fn decode(buf: &[u8], verify_payload_crc: bool) -> Result<Option<(Self, usize)>> {
if buf.len() < 20 {
return Ok(None);
}
let type_bytes = [0, buf[1], buf[2], buf[3]];
let frame_type = u32::from_be_bytes(type_bytes);
let payload_len = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize;
if payload_len < MIN_PAYLOAD_LEN {
return Err(Error::Other(format!(
"select frame payload too short: {payload_len} < {MIN_PAYLOAD_LEN}"
)));
}
let total_len = 12 + payload_len + 4; if buf.len() < total_len {
return Ok(None);
}
let payload = &buf[12..12 + payload_len];
let server_crc = u32::from_be_bytes([
buf[12 + payload_len],
buf[12 + payload_len + 1],
buf[12 + payload_len + 2],
buf[12 + payload_len + 3],
]);
if verify_payload_crc && server_crc != 0 {
let mut hasher = crc32fast::Hasher::new();
hasher.update(payload);
let client_crc = hasher.finalize();
if client_crc != server_crc {
return Err(Error::Other(format!(
"select payload CRC mismatch: server={server_crc:#010x}, \
client={client_crc:#010x}"
)));
}
}
let frame = decode_payload(frame_type, payload)?;
Ok(Some((frame, total_len)))
}
pub fn terminal_status(&self) -> Option<u32> {
match self {
SelectFrame::Data { .. } | SelectFrame::Continuous { .. } => None,
SelectFrame::End { status, .. }
| SelectFrame::MetaEndCsv { status, .. }
| SelectFrame::MetaEndJson { status, .. } => Some(*status),
}
}
}
fn decode_payload(frame_type: u32, payload: &[u8]) -> Result<SelectFrame> {
let offset = u64::from_be_bytes(payload[0..8].try_into().unwrap());
let rest = &payload[8..];
match frame_type {
FRAME_DATA => Ok(SelectFrame::Data {
offset,
data: Bytes::copy_from_slice(rest),
}),
FRAME_CONTINUOUS => {
if !rest.is_empty() {
return Err(Error::Other("select continuous frame has trailing bytes".into()));
}
Ok(SelectFrame::Continuous { offset })
},
FRAME_END => {
if rest.len() < 12 {
return Err(Error::Other("select end frame truncated".into()));
}
let total_scanned = u64::from_be_bytes(rest[0..8].try_into().unwrap());
let status = u32::from_be_bytes(rest[8..12].try_into().unwrap());
let error_message = String::from_utf8_lossy(&rest[12..]).into_owned();
Ok(SelectFrame::End {
offset,
total_scanned,
status,
error_message,
})
},
FRAME_META_END_CSV => {
if rest.len() < 28 {
return Err(Error::Other("select meta-end CSV frame truncated".into()));
}
let total_scanned = u64::from_be_bytes(rest[0..8].try_into().unwrap());
let status = u32::from_be_bytes(rest[8..12].try_into().unwrap());
let splits_count = u32::from_be_bytes(rest[12..16].try_into().unwrap());
let rows_count = u64::from_be_bytes(rest[16..24].try_into().unwrap());
let cols_count = u32::from_be_bytes(rest[24..28].try_into().unwrap());
let error_message = String::from_utf8_lossy(&rest[28..]).into_owned();
Ok(SelectFrame::MetaEndCsv {
offset,
total_scanned,
status,
splits_count,
rows_count,
cols_count,
error_message,
})
},
FRAME_META_END_JSON => {
if rest.len() < 24 {
return Err(Error::Other("select meta-end JSON frame truncated".into()));
}
let total_scanned = u64::from_be_bytes(rest[0..8].try_into().unwrap());
let status = u32::from_be_bytes(rest[8..12].try_into().unwrap());
let splits_count = u32::from_be_bytes(rest[12..16].try_into().unwrap());
let rows_count = u64::from_be_bytes(rest[16..24].try_into().unwrap());
let error_message = String::from_utf8_lossy(&rest[24..]).into_owned();
Ok(SelectFrame::MetaEndJson {
offset,
total_scanned,
status,
splits_count,
rows_count,
error_message,
})
},
other => Err(Error::Other(format!("unknown select frame type: {other:#010x}"))),
}
}
/// Adapts a stream of raw body chunks into a stream of decoded
/// [`SelectFrame`]s.
///
/// Chunks are appended to an internal buffer, and frames are cut from its
/// front once they are complete. Frame boundaries therefore do not need to
/// line up with chunk boundaries.
pub struct SelectFrameStream<S> {
    /// Source of raw body bytes.
    inner: S,
    /// Received bytes that do not yet form a complete frame.
    buffer: BytesMut,
    /// Set once no more bytes will be read from `inner`.
    exhausted: bool,
    /// Whether payload CRC-32 checksums are verified while decoding.
    verify_payload_crc: bool,
}

impl<S> SelectFrameStream<S> {
    /// Wraps `inner`, starting with an empty buffer.
    ///
    /// With `verify_payload_crc = true`, each frame's payload checksum is
    /// checked. Frames whose checksum field is 0 are skipped.
    pub fn new(inner: S, verify_payload_crc: bool) -> Self {
        let buffer = BytesMut::new();
        Self {
            verify_payload_crc,
            exhausted: false,
            buffer,
            inner,
        }
    }
}
impl<S> Stream for SelectFrameStream<S>
where
S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Unpin,
{
type Item = Result<SelectFrame>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match SelectFrame::decode(&self.buffer, self.verify_payload_crc) {
Ok(Some((frame, consumed))) => {
self.buffer.advance(consumed);
return Poll::Ready(Some(Ok(frame)));
},
Ok(None) => {
if self.exhausted {
return Poll::Ready(None);
}
},
Err(e) => return Poll::Ready(Some(Err(e))),
}
match self.inner.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(chunk))) => {
self.buffer.extend_from_slice(&chunk);
},
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(e.into())));
},
Poll::Ready(None) => {
self.exhausted = true;
if self.buffer.is_empty() {
return Poll::Ready(None);
}
},
Poll::Pending => return Poll::Pending,
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a wire-format frame with version 1 and real CRC-32 values:
    /// `[version|type][payload len][header crc][payload][payload crc]`.
    fn encode_frame(frame_type: u32, payload: &[u8]) -> Vec<u8> {
        let mut out = Vec::with_capacity(12 + payload.len() + 4);
        let type_with_version = frame_type | (1u32 << 24);
        out.extend_from_slice(&type_with_version.to_be_bytes());
        out.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        let mut header_hasher = crc32fast::Hasher::new();
        header_hasher.update(&out[0..8]);
        out.extend_from_slice(&header_hasher.finalize().to_be_bytes());
        out.extend_from_slice(payload);
        let mut payload_hasher = crc32fast::Hasher::new();
        payload_hasher.update(payload);
        out.extend_from_slice(&payload_hasher.finalize().to_be_bytes());
        out
    }

    #[test]
    fn decode_data_frame() {
        let offset: u64 = 1234;
        let mut payload = Vec::new();
        payload.extend_from_slice(&offset.to_be_bytes());
        payload.extend_from_slice(b"hello, world");
        let bytes = encode_frame(FRAME_DATA, &payload);
        let (frame, consumed) = SelectFrame::decode(&bytes, true).unwrap().unwrap();
        assert_eq!(consumed, bytes.len());
        match frame {
            SelectFrame::Data { offset: o, data } => {
                assert_eq!(o, 1234);
                assert_eq!(&data[..], b"hello, world");
            },
            other => panic!("expected Data frame, got {other:?}"),
        }
    }

    #[test]
    fn decode_continuous_frame() {
        let payload = 42u64.to_be_bytes();
        let bytes = encode_frame(FRAME_CONTINUOUS, &payload);
        let (frame, consumed) = SelectFrame::decode(&bytes, false).unwrap().unwrap();
        assert_eq!(consumed, bytes.len());
        // `matches!` alone only produces a bool; it must be asserted.
        assert!(
            matches!(frame, SelectFrame::Continuous { offset: 42 }),
            "unexpected frame: {frame:?}"
        );
    }

    #[test]
    fn decode_end_frame_success() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&100u64.to_be_bytes()); // offset
        payload.extend_from_slice(&200u64.to_be_bytes()); // total scanned
        payload.extend_from_slice(&200u32.to_be_bytes()); // status
        let bytes = encode_frame(FRAME_END, &payload);
        let (frame, _) = SelectFrame::decode(&bytes, false).unwrap().unwrap();
        match frame {
            SelectFrame::End {
                offset,
                total_scanned,
                status,
                error_message,
            } => {
                assert_eq!(offset, 100);
                assert_eq!(total_scanned, 200);
                assert_eq!(status, 200);
                assert!(error_message.is_empty());
            },
            other => panic!("expected End frame, got {other:?}"),
        }
    }

    #[test]
    fn decode_end_frame_with_error_message() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&7u64.to_be_bytes()); // offset
        payload.extend_from_slice(&0u64.to_be_bytes()); // total scanned
        payload.extend_from_slice(&400u32.to_be_bytes()); // status
        payload.extend_from_slice(b"bad request");
        let bytes = encode_frame(FRAME_END, &payload);
        let (frame, _) = SelectFrame::decode(&bytes, true).unwrap().unwrap();
        assert_eq!(frame.terminal_status(), Some(400));
        match frame {
            SelectFrame::End { error_message, .. } => assert_eq!(error_message, "bad request"),
            other => panic!("expected End frame, got {other:?}"),
        }
    }

    #[test]
    fn decode_meta_end_csv() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&0u64.to_be_bytes()); // offset
        payload.extend_from_slice(&1024u64.to_be_bytes()); // total scanned
        payload.extend_from_slice(&200u32.to_be_bytes()); // status
        payload.extend_from_slice(&4u32.to_be_bytes()); // splits
        payload.extend_from_slice(&10000u64.to_be_bytes()); // rows
        payload.extend_from_slice(&5u32.to_be_bytes()); // cols
        let bytes = encode_frame(FRAME_META_END_CSV, &payload);
        let (frame, _) = SelectFrame::decode(&bytes, false).unwrap().unwrap();
        match frame {
            SelectFrame::MetaEndCsv {
                splits_count,
                rows_count,
                cols_count,
                ..
            } => {
                assert_eq!(splits_count, 4);
                assert_eq!(rows_count, 10000);
                assert_eq!(cols_count, 5);
            },
            other => panic!("expected MetaEndCsv, got {other:?}"),
        }
    }

    #[test]
    fn decode_meta_end_json() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&0u64.to_be_bytes()); // offset
        payload.extend_from_slice(&2048u64.to_be_bytes()); // total scanned
        payload.extend_from_slice(&200u32.to_be_bytes()); // status
        payload.extend_from_slice(&8u32.to_be_bytes()); // splits
        payload.extend_from_slice(&500u64.to_be_bytes()); // rows
        let bytes = encode_frame(FRAME_META_END_JSON, &payload);
        let (frame, _) = SelectFrame::decode(&bytes, true).unwrap().unwrap();
        match frame {
            SelectFrame::MetaEndJson {
                splits_count,
                rows_count,
                ..
            } => {
                assert_eq!(splits_count, 8);
                assert_eq!(rows_count, 500);
            },
            other => panic!("expected MetaEndJson, got {other:?}"),
        }
    }

    #[test]
    fn data_frame_has_no_terminal_status() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&0u64.to_be_bytes());
        payload.extend_from_slice(b"x");
        let bytes = encode_frame(FRAME_DATA, &payload);
        let (frame, _) = SelectFrame::decode(&bytes, true).unwrap().unwrap();
        assert_eq!(frame.terminal_status(), None);
    }

    #[test]
    fn decode_returns_none_when_incomplete() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&0u64.to_be_bytes());
        payload.extend_from_slice(b"data");
        let bytes = encode_frame(FRAME_DATA, &payload);
        assert!(SelectFrame::decode(&bytes[..10], false).unwrap().is_none());
        assert!(SelectFrame::decode(&bytes[..20], false).unwrap().is_none());
        assert!(SelectFrame::decode(&bytes, false).unwrap().is_some());
    }

    #[test]
    fn decode_consumes_only_the_first_frame() {
        let first = encode_frame(FRAME_CONTINUOUS, &1u64.to_be_bytes());
        let second = encode_frame(FRAME_CONTINUOUS, &2u64.to_be_bytes());
        let mut bytes = first.clone();
        bytes.extend_from_slice(&second);

        let (frame, consumed) = SelectFrame::decode(&bytes, true).unwrap().unwrap();
        assert_eq!(consumed, first.len());
        assert!(matches!(frame, SelectFrame::Continuous { offset: 1 }));

        let (frame, consumed) = SelectFrame::decode(&bytes[consumed..], true).unwrap().unwrap();
        assert_eq!(consumed, second.len());
        assert!(matches!(frame, SelectFrame::Continuous { offset: 2 }));
    }

    #[test]
    fn decode_rejects_short_payload() {
        let bytes = encode_frame(FRAME_DATA, &[0u8; 4]);
        let err = SelectFrame::decode(&bytes, false).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("payload too short"), "msg={msg}");
    }

    #[test]
    fn decode_rejects_continuous_frame_with_trailing_bytes() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&1u64.to_be_bytes());
        payload.push(0xAB);
        let bytes = encode_frame(FRAME_CONTINUOUS, &payload);
        let err = SelectFrame::decode(&bytes, false).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("trailing bytes"), "msg={msg}");
    }

    #[test]
    fn decode_rejects_unknown_frame_type() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&0u64.to_be_bytes());
        let bytes = encode_frame(0x00AA_BBCC, &payload);
        let err = SelectFrame::decode(&bytes, false).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("unknown select frame type"), "msg={msg}");
    }

    #[test]
    fn payload_crc_mismatch_is_detected() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&0u64.to_be_bytes());
        payload.extend_from_slice(b"data");
        let mut bytes = encode_frame(FRAME_DATA, &payload);
        let len = bytes.len();
        bytes[len - 1] ^= 0xff;
        let err = SelectFrame::decode(&bytes, true).unwrap_err();
        assert!(format!("{err}").contains("CRC mismatch"));
    }

    #[test]
    fn payload_crc_zero_is_accepted() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&0u64.to_be_bytes());
        payload.extend_from_slice(b"data");
        let mut bytes = encode_frame(FRAME_DATA, &payload);
        let len = bytes.len();
        bytes[len - 4..].copy_from_slice(&[0, 0, 0, 0]);
        let result = SelectFrame::decode(&bytes, true).unwrap();
        assert!(result.is_some());
    }
}