use bytes::{Bytes, BytesMut};
use tds_protocol::ProtocolError;
use tds_protocol::token::{ColMetaData, Token, TokenParser};
/// Outcome of a single [`RowSource::pull`] call.
#[derive(Debug)]
pub(crate) enum Pull {
/// A fully decoded token; its bytes have already been removed from the
/// source's buffer.
Token(Token),
/// No token could be decoded from the buffered bytes and no end-of-message
/// packet has been pushed yet; push another packet and pull again.
NeedMore,
/// An end-of-message packet has been pushed and the parser yielded no
/// further token.
End,
}
/// Incremental token source for a response that arrives as a sequence of
/// packets.
///
/// Packet payloads are appended with [`RowSource::push_packet`] and decoded one
/// token at a time with [`RowSource::pull`].
pub(crate) struct RowSource {
// Bytes received but not yet decoded into tokens. Kept contiguous because each
// parse attempt hands a single `Bytes` to `TokenParser::new`.
buf: Bytes,
// Most recent COLMETADATA token seen; passed to the parser so row tokens can
// be decoded.
metadata: Option<ColMetaData>,
// Forwarded to `TokenParser::with_encryption` on every parse attempt.
encryption_enabled: bool,
// Sticky flag: becomes true once any pushed packet is marked end-of-message.
eom: bool,
}
impl RowSource {
    /// Creates an empty source: no buffered bytes, no column metadata, and no
    /// end-of-message packet seen yet.
    ///
    /// `encryption_enabled` is forwarded to every [`TokenParser`] built by
    /// [`RowSource::pull`] via `with_encryption`.
    pub(crate) fn new(encryption_enabled: bool) -> Self {
        Self {
            buf: Bytes::new(),
            metadata: None,
            encryption_enabled,
            eom: false,
        }
    }

    /// Appends one packet payload to the pending (not yet decoded) bytes.
    ///
    /// `is_eom` is OR-ed into the sticky end-of-message flag: once any pushed
    /// packet is marked as the last one, the source stays at end-of-message.
    ///
    /// When nothing is pending, `payload` is adopted as-is without copying.
    /// Otherwise the pending bytes and `payload` are copied into one contiguous
    /// buffer, because the parser needs a single `Bytes`.
    pub(crate) fn push_packet(&mut self, payload: Bytes, is_eom: bool) {
        self.eom |= is_eom;
        if payload.is_empty() {
            // Nothing to append (e.g. an empty final packet): skip re-copying
            // the whole pending buffer.
            return;
        }
        if self.buf.is_empty() {
            self.buf = payload;
            return;
        }
        // NOTE(review): every push copies all still-pending bytes, so a single
        // token spanning many packets costs O(total^2) copying.
        let mut joined = BytesMut::with_capacity(self.buf.len() + payload.len());
        joined.extend_from_slice(&self.buf);
        joined.extend_from_slice(&payload);
        self.buf = joined.freeze();
    }

    /// Test-only accessor for the sticky end-of-message flag.
    #[cfg(test)]
    pub(crate) fn is_eom(&self) -> bool {
        self.eom
    }

    /// Consumes the source and returns the not-yet-decoded bytes together with
    /// the end-of-message flag.
    pub(crate) fn into_parts(self) -> (Bytes, bool) {
        (self.buf, self.eom)
    }

    /// Tries to decode the next token from the buffered bytes.
    ///
    /// Returns:
    /// - [`Pull::Token`] with the decoded token. Its bytes are dropped from the
    ///   buffer, and a `ColMetaData` token becomes the metadata used to decode
    ///   later rows.
    /// - [`Pull::NeedMore`] when end-of-message has not been seen and either
    ///   the buffer is empty, the parser yields no token, or the parser reports
    ///   a truncated token (`UnexpectedEof` / `IncompletePacket`).
    /// - [`Pull::End`] when end-of-message has been seen and the buffer is
    ///   empty or the parser yields no token.
    ///
    /// # Errors
    ///
    /// Returns any parser error other than the two truncation errors. Once
    /// end-of-message has been seen, the truncation errors are returned too,
    /// because the missing bytes can no longer arrive.
    pub(crate) fn pull(&mut self) -> Result<Pull, ProtocolError> {
        if self.buf.is_empty() {
            return Ok(self.idle());
        }
        // Parse from a clone so a failed or truncated attempt leaves `self.buf`
        // untouched; it is retried in full after the next packet is pushed.
        let mut parser =
            TokenParser::new(self.buf.clone()).with_encryption(self.encryption_enabled);
        match parser.next_token_with_metadata(self.metadata.as_ref()) {
            Ok(Some(token)) => {
                let consumed = self.buf.len() - parser.remaining();
                self.buf = self.buf.slice(consumed..);
                if let Token::ColMetaData(meta) = &token {
                    // The token itself goes to the caller, so keep a copy for
                    // decoding subsequent rows.
                    self.metadata = Some(meta.clone());
                }
                Ok(Pull::Token(token))
            }
            Ok(None) => Ok(self.idle()),
            Err(ProtocolError::UnexpectedEof | ProtocolError::IncompletePacket { .. })
                if !self.eom =>
            {
                Ok(Pull::NeedMore)
            }
            Err(e) => Err(e),
        }
    }

    /// Result to report when no token is available: `End` after
    /// end-of-message, `NeedMore` before it.
    fn idle(&self) -> Pull {
        if self.eom {
            Pull::End
        } else {
            Pull::NeedMore
        }
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
use super::*;
/// Encodes a COLMETADATA token (leading byte 0x81) with a little-endian column
/// count of 2: `id` (type byte 0x38) and `name` (type byte 0xE7, followed by a
/// two-byte max length of 0x64 and five more bytes, presumably the collation).
/// Each column carries four zero bytes and flag bytes `0x01 0x00` before its
/// type, and ends with its name as a character count plus UTF-16LE text.
/// NOTE(review): byte meanings assumed from MS-TDS (0x38 = INT4,
/// 0xE7 = NVARCHAR, flag 0x0001 = nullable) — confirm against the spec.
fn colmetadata() -> Vec<u8> {
let mut v = vec![0x81]; v.extend_from_slice(&[0x02, 0x00]); v.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); v.extend_from_slice(&[0x01, 0x00]); v.push(0x38); v.push(0x02); v.extend_from_slice(&[b'i', 0x00, b'd', 0x00]);
v.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); v.extend_from_slice(&[0x01, 0x00]); v.push(0xE7); v.extend_from_slice(&[0x64, 0x00]); v.extend_from_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]); v.push(0x04); v.extend_from_slice(&[b'n', 0x00, b'a', 0x00, b'm', 0x00, b'e', 0x00]);
v
}
/// Encodes a ROW token (leading byte 0xD1) matching `colmetadata`: `id` as a
/// little-endian `i32`, then `name` as a little-endian `u16` byte length
/// followed by UTF-16LE text, or the length marker `0xFFFF` when `name` is
/// `None` (a NULL value).
fn row(id: i32, name: Option<&str>) -> Vec<u8> {
let mut v = vec![0xD1];
v.extend_from_slice(&id.to_le_bytes());
match name {
None => v.extend_from_slice(&[0xFF, 0xFF]), Some(s) => {
let utf16: Vec<u8> = s.encode_utf16().flat_map(u16::to_le_bytes).collect();
v.extend_from_slice(&(utf16.len() as u16).to_le_bytes());
v.extend_from_slice(&utf16);
}
}
v
}
/// Encodes a DONE token (leading byte 0xFD): status bytes `0x10 0x00`,
/// current-command bytes `0xC1 0x00`, and `row_count` as a little-endian `u64`.
/// NOTE(review): presumably status 0x0010 means "row count is valid" — confirm
/// against MS-TDS.
fn done(row_count: u64) -> Vec<u8> {
let mut v = vec![0xFD];
v.extend_from_slice(&[0x10, 0x00]); v.extend_from_slice(&[0xC1, 0x00]); v.extend_from_slice(&row_count.to_le_bytes());
v
}
/// A complete response: metadata, three rows (the second with a NULL `name`),
/// and a DONE token reporting 3 rows.
fn sample_response() -> Vec<u8> {
let mut v = colmetadata();
v.extend_from_slice(&row(1, Some("alpha")));
v.extend_from_slice(&row(2, None));
v.extend_from_slice(&row(3, Some("gamma")));
v.extend_from_slice(&done(3));
v
}
/// `(is_nbc_row, raw row bytes)`: `true` for `NbcRow` tokens, `false` for `Row`.
type CollectedRow = (bool, Bytes);
/// Pulls from `src` until it stops yielding tokens. Row data goes into `rows`;
/// all other tokens are ignored.
/// Returns `true` if the source reported `Pull::End`, `false` on
/// `Pull::NeedMore`. Panics if `pull` returns an error.
fn drain(src: &mut RowSource, rows: &mut Vec<CollectedRow>) -> bool {
loop {
match src.pull().expect("pull must not error on valid input") {
Pull::Token(Token::Row(r)) => rows.push((false, r.data)),
Pull::Token(Token::NbcRow(r)) => rows.push((true, r.data)),
Pull::Token(_) => {} Pull::NeedMore => return false,
Pull::End => return true,
}
}
}
/// Reference decode. Parses all of `bytes` with a single `TokenParser` and
/// threads the latest COLMETADATA through by hand. Rows are collected in the
/// same shape as `drain`, so incremental results can be compared directly.
/// Panics on any parse error.
fn eager_rows(bytes: &[u8]) -> Vec<CollectedRow> {
let mut parser = TokenParser::new(Bytes::copy_from_slice(bytes));
let mut meta: Option<ColMetaData> = None;
let mut rows = Vec::new();
while let Some(token) = parser
.next_token_with_metadata(meta.as_ref())
.expect("eager parse")
{
match token {
Token::ColMetaData(m) => meta = Some(m),
Token::Row(r) => rows.push((false, r.data)),
Token::NbcRow(r) => rows.push((true, r.data)),
_ => {}
}
}
rows
}
/// Baseline: the whole response in one final packet decodes to the same three
/// rows as the eager parser, and the source reports `End`.
#[test]
fn whole_response_in_one_packet() {
let full = sample_response();
let mut src = RowSource::new(false);
src.push_packet(Bytes::copy_from_slice(&full), true);
let mut rows = Vec::new();
assert!(drain(&mut src, &mut rows));
assert_eq!(rows, eager_rows(&full));
assert_eq!(rows.len(), 3);
}
/// Splits the response into two packets at every byte offset, including an
/// empty first or second packet. The source must never end before the final
/// packet and must always produce exactly the eager rows.
#[test]
fn every_two_packet_split_matches_eager() {
let full = sample_response();
let reference = eager_rows(&full);
for split in 0..=full.len() {
let mut src = RowSource::new(false);
let mut rows = Vec::new();
src.push_packet(Bytes::copy_from_slice(&full[..split]), false);
let ended_early = drain(&mut src, &mut rows);
assert!(!ended_early, "must not end before eom (split {split})");
src.push_packet(Bytes::copy_from_slice(&full[split..]), true);
let ended = drain(&mut src, &mut rows);
assert!(ended, "stream must end after eom (split {split})");
assert_eq!(rows, reference, "rows differ at split {split}");
}
}
/// Worst-case fragmentation: one byte per packet, draining after every push,
/// with only the last byte flagged as end-of-message.
/// NOTE(review): `drain`'s return value is ignored, so this test does not check
/// that the source reports `End` after the final byte.
#[test]
fn byte_by_byte_feed_matches_eager() {
let full = sample_response();
let reference = eager_rows(&full);
let mut src = RowSource::new(false);
let mut rows = Vec::new();
for (i, b) in full.iter().enumerate() {
let is_last = i == full.len() - 1;
src.push_packet(Bytes::copy_from_slice(&[*b]), is_last);
drain(&mut src, &mut rows);
}
assert_eq!(rows, reference);
assert_eq!(rows.len(), 3);
}
/// A fresh source reports `NeedMore`. After an empty packet marked
/// end-of-message it reports `End`.
#[test]
fn empty_buffer_reports_need_more_then_end() {
let mut src = RowSource::new(false);
assert!(matches!(src.pull().unwrap(), Pull::NeedMore));
assert!(!src.is_eom());
src.push_packet(Bytes::new(), true);
assert!(src.is_eom());
assert!(matches!(src.pull().unwrap(), Pull::End));
}
/// After the metadata, a final packet holds only the first three bytes of a
/// ROW token (0xD1 plus part of the `i32` id). `pull` must fail rather than
/// wait for bytes that can no longer arrive.
#[test]
fn truncated_token_after_eom_errors() {
let mut src = RowSource::new(false);
src.push_packet(Bytes::copy_from_slice(&colmetadata()), false);
let mut rows = Vec::new();
drain(&mut src, &mut rows);
src.push_packet(Bytes::copy_from_slice(&[0xD1, 0x2A, 0x00]), true);
let err = src.pull();
assert!(
err.is_err(),
"truncated token after eom must error, got {err:?}"
);
}
/// Metadata decoded from one packet must still be used for a row that arrives
/// in a later packet. The row's first four bytes must encode the id 7.
#[test]
fn metadata_persists_across_packets() {
let mut src = RowSource::new(false);
src.push_packet(Bytes::copy_from_slice(&colmetadata()), false);
let mut rows = Vec::new();
assert!(!drain(&mut src, &mut rows)); assert!(rows.is_empty());
let mut tail = row(7, Some("z"));
tail.extend_from_slice(&done(1));
src.push_packet(Bytes::copy_from_slice(&tail), true);
assert!(drain(&mut src, &mut rows));
assert_eq!(rows.len(), 1);
assert_eq!(&rows[0].1[..4], &7i32.to_le_bytes());
}
}