#![allow(clippy::expect_used)]
use bytes::{BufMut, BytesMut};
use crate::error::TypeError;
use crate::value::SqlValue;
/// Serialization of a value into its TDS wire representation.
pub trait TdsEncode {
    /// Appends the encoded form of `self` to `buf`.
    ///
    /// # Errors
    ///
    /// Returns a [`TypeError`] when the value cannot be encoded.
    fn encode(&self, buf: &mut BytesMut) -> Result<(), TypeError>;

    /// Returns the one-byte type identifier associated with `self`.
    fn type_id(&self) -> u8;
}
impl TdsEncode for SqlValue {
    /// Appends the wire bytes for this value to `buf`.
    ///
    /// Fixed-width numeric variants are written little-endian. `String`,
    /// `Json` and `Xml` are written as UTF-16LE preceded by a `u16` byte
    /// length; `Binary` is written as raw bytes preceded by a `u16` length.
    /// `Null` writes nothing.
    ///
    /// # Errors
    ///
    /// * [`TypeError::BufferTooSmall`] when a `Binary`, `String`, `Json` or
    ///   `Xml` payload is longer than `u16::MAX` bytes and therefore cannot be
    ///   described by the 16-bit length prefix.
    /// * [`TypeError::UnsupportedConversion`] for `Tvp`, which this method
    ///   does not encode.
    fn encode(&self, buf: &mut BytesMut) -> Result<(), TypeError> {
        match self {
            // NULL contributes no payload bytes.
            SqlValue::Null => {}
            SqlValue::Bool(v) => buf.put_u8(u8::from(*v)),
            SqlValue::TinyInt(v) => buf.put_u8(*v),
            SqlValue::SmallInt(v) => buf.put_i16_le(*v),
            SqlValue::Int(v) => buf.put_i32_le(*v),
            SqlValue::BigInt(v) => buf.put_i64_le(*v),
            SqlValue::Float(v) => buf.put_f32_le(*v),
            SqlValue::Double(v) => buf.put_f64_le(*v),
            SqlValue::String(s) => encode_checked_utf16_string(s, buf)?,
            SqlValue::Binary(b) => {
                // Reject payloads the 16-bit length prefix cannot describe.
                let len = u16::try_from(b.len()).map_err(|_| TypeError::BufferTooSmall {
                    needed: b.len(),
                    available: u16::MAX as usize,
                })?;
                buf.put_u16_le(len);
                buf.put_slice(b);
            }
            #[cfg(feature = "decimal")]
            SqlValue::Decimal(d) => encode_decimal(*d, buf),
            #[cfg(feature = "uuid")]
            SqlValue::Uuid(u) => encode_uuid(*u, buf),
            #[cfg(feature = "chrono")]
            SqlValue::Date(d) => encode_date(*d, buf),
            #[cfg(feature = "chrono")]
            SqlValue::Time(t) => encode_time(*t, buf),
            #[cfg(feature = "chrono")]
            SqlValue::DateTime(dt) => encode_datetime2(*dt, buf),
            #[cfg(feature = "chrono")]
            SqlValue::DateTimeOffset(dto) => encode_datetimeoffset(*dto, buf),
            #[cfg(feature = "json")]
            SqlValue::Json(j) => {
                // JSON is sent as its textual form, encoded like a string.
                let text = j.to_string();
                encode_checked_utf16_string(&text, buf)?;
            }
            SqlValue::Xml(x) => encode_checked_utf16_string(x, buf)?,
            SqlValue::Tvp(_) => {
                return Err(TypeError::UnsupportedConversion {
                    from: "TvpData".to_string(),
                    to: "raw bytes (use RPC parameter encoding)",
                });
            }
        }
        Ok(())
    }

    /// Returns the one-byte type identifier for this value's variant.
    ///
    /// `Json` reports the same identifier as `String`, matching `encode`,
    /// which writes JSON as a UTF-16 string.
    fn type_id(&self) -> u8 {
        match self {
            SqlValue::Null => 0x1F,
            SqlValue::Bool(_) => 0x32,
            SqlValue::TinyInt(_) => 0x30,
            SqlValue::SmallInt(_) => 0x34,
            SqlValue::Int(_) => 0x38,
            SqlValue::BigInt(_) => 0x7F,
            SqlValue::Float(_) => 0x3B,
            SqlValue::Double(_) => 0x3E,
            SqlValue::String(_) => 0xE7,
            SqlValue::Binary(_) => 0xA5,
            #[cfg(feature = "decimal")]
            SqlValue::Decimal(_) => 0x6C,
            #[cfg(feature = "uuid")]
            SqlValue::Uuid(_) => 0x24,
            #[cfg(feature = "chrono")]
            SqlValue::Date(_) => 0x28,
            #[cfg(feature = "chrono")]
            SqlValue::Time(_) => 0x29,
            #[cfg(feature = "chrono")]
            SqlValue::DateTime(_) => 0x2A,
            #[cfg(feature = "chrono")]
            SqlValue::DateTimeOffset(_) => 0x2B,
            #[cfg(feature = "json")]
            SqlValue::Json(_) => 0xE7,
            SqlValue::Xml(_) => 0xF1,
            SqlValue::Tvp(_) => 0xF3,
        }
    }
}

/// Writes `s` as a `u16`-length-prefixed UTF-16LE string, refusing inputs
/// whose encoded size does not fit in the prefix.
///
/// `encode_utf16_string` casts the byte length to `u16` unconditionally, so an
/// oversized string would get a wrapped length followed by the full payload.
/// Validating first keeps the output well-formed.
fn encode_checked_utf16_string(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
    let byte_len = s.encode_utf16().count() * 2;
    if byte_len > u16::MAX as usize {
        return Err(TypeError::BufferTooSmall {
            needed: byte_len,
            available: u16::MAX as usize,
        });
    }
    encode_utf16_string(s, buf);
    Ok(())
}
/// Writes `s` to `buf` as a little-endian `u16` byte length followed by the
/// string's UTF-16 code units, each little-endian.
///
/// The length prefix is the code-unit count times two, cast to `u16`. The
/// cast is unchecked, so for inputs longer than 32,767 code units the prefix
/// wraps while the full payload is still written; callers that may see such
/// inputs must validate the length beforehand.
pub fn encode_utf16_string(s: &str, buf: &mut BytesMut) {
    let unit_count = s.encode_utf16().count();
    buf.put_u16_le((unit_count * 2) as u16);
    s.encode_utf16().for_each(|unit| buf.put_u16_le(unit));
}
/// Writes the UTF-16 code units of `s` to `buf`, each little-endian, with no
/// length prefix.
pub fn encode_utf16_string_no_len(s: &str, buf: &mut BytesMut) {
    s.encode_utf16().for_each(|unit| buf.put_u16_le(unit));
}
/// Writes the 16 bytes of `uuid` to `buf` in mixed-endian order.
///
/// The first three groups (4, 2 and 2 bytes) are emitted byte-reversed; the
/// trailing 8 bytes are emitted unchanged.
#[cfg(feature = "uuid")]
pub fn encode_uuid(uuid: uuid::Uuid, buf: &mut BytesMut) {
    // Source index for each of the first eight output bytes.
    const SWAPPED_PREFIX: [usize; 8] = [3, 2, 1, 0, 5, 4, 7, 6];

    let raw = uuid.as_bytes();
    for &src in &SWAPPED_PREFIX {
        buf.put_u8(raw[src]);
    }
    buf.put_slice(&raw[8..16]);
}
/// Writes `decimal` to `buf` as a sign byte followed by the absolute value of
/// its mantissa as a 16-byte little-endian integer.
///
/// The sign byte is `0` for negative values and `1` otherwise. The scale is
/// not written by this function.
#[cfg(feature = "decimal")]
pub fn encode_decimal(decimal: rust_decimal::Decimal, buf: &mut BytesMut) {
    let sign_byte = u8::from(!decimal.is_sign_negative());
    let magnitude: u128 = decimal.mantissa().unsigned_abs();

    buf.put_u8(sign_byte);
    buf.put_u128_le(magnitude);
}
/// Writes `date` to `buf` as a 3-byte little-endian count of days since
/// 0001-01-01.
///
/// The day count is cast to `u32` and only its low 24 bits are written, so
/// dates before 0001-01-01 (negative counts) wrap and are not representable.
#[cfg(feature = "chrono")]
pub fn encode_date(date: chrono::NaiveDate, buf: &mut BytesMut) {
    let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("valid date");
    let day_count = date.signed_duration_since(epoch).num_days() as u32;

    // Low three bytes, least significant first.
    buf.put_slice(&day_count.to_le_bytes()[..3]);
}
/// Writes `time` to `buf` as a 5-byte little-endian count of 100-nanosecond
/// intervals since midnight.
#[cfg(feature = "chrono")]
pub fn encode_time(time: chrono::NaiveTime, buf: &mut BytesMut) {
    use chrono::Timelike;

    const NANOS_PER_SECOND: u64 = 1_000_000_000;
    const NANOS_PER_INTERVAL: u64 = 100;

    let whole_seconds = time.num_seconds_from_midnight() as u64;
    let sub_second_nanos = time.nanosecond() as u64;
    let intervals = (whole_seconds * NANOS_PER_SECOND + sub_second_nanos) / NANOS_PER_INTERVAL;

    // Low five bytes, least significant first.
    buf.put_slice(&intervals.to_le_bytes()[..5]);
}
/// Writes `datetime` to `buf` as its time portion (see `encode_time`)
/// followed by its date portion (see `encode_date`).
#[cfg(feature = "chrono")]
pub fn encode_datetime2(datetime: chrono::NaiveDateTime, buf: &mut BytesMut) {
    let (time_part, date_part) = (datetime.time(), datetime.date());
    encode_time(time_part, buf);
    encode_date(date_part, buf);
}
/// Writes `datetime` to `buf` as time, date, and a little-endian `i16` UTC
/// offset in minutes.
///
/// The time and date portions are taken from the instant expressed in UTC,
/// not from the local wall-clock reading: the TDS DATETIMEOFFSET wire value
/// carries the UTC date/time and the receiver applies the offset to recover
/// local time. Encoding the local reading would shift every value with a
/// non-zero offset by that offset.
#[cfg(feature = "chrono")]
pub fn encode_datetimeoffset(datetime: chrono::DateTime<chrono::FixedOffset>, buf: &mut BytesMut) {
    use chrono::Offset;

    let utc = datetime.naive_utc();
    encode_time(utc.time(), buf);
    encode_date(utc.date(), buf);

    // Offset east of UTC, in whole minutes.
    let offset_seconds = datetime.offset().fix().local_minus_utc();
    let offset_minutes = (offset_seconds / 60) as i16;
    buf.put_i16_le(offset_minutes);
}
/// Unit tests pinning the exact byte output of the encoders.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A 32-bit integer is written as four little-endian bytes.
    #[test]
    fn test_encode_int() {
        let mut buf = BytesMut::new();
        SqlValue::Int(42).encode(&mut buf).unwrap();
        assert_eq!(&buf[..], &[42, 0, 0, 0]);
    }

    /// A 64-bit integer is written as eight little-endian bytes.
    #[test]
    fn test_encode_bigint() {
        let mut buf = BytesMut::new();
        SqlValue::BigInt(0x0102030405060708)
            .encode(&mut buf)
            .unwrap();
        assert_eq!(&buf[..], &[0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]);
    }

    /// Strings carry a two-byte byte-length prefix followed by UTF-16LE units.
    #[test]
    fn test_encode_utf16_string() {
        let mut buf = BytesMut::new();
        encode_utf16_string("AB", &mut buf);
        assert_eq!(&buf[..], &[4, 0, 0x41, 0, 0x42, 0]);
    }

    /// The first three UUID groups are byte-reversed; the rest is unchanged.
    #[cfg(feature = "uuid")]
    #[test]
    fn test_encode_uuid() {
        let mut buf = BytesMut::new();
        let uuid = uuid::Uuid::parse_str("12345678-1234-5678-1234-567812345678").unwrap();
        encode_uuid(uuid, &mut buf);
        assert_eq!(
            &buf[..],
            &[
                0x78, 0x56, 0x34, 0x12, 0x34, 0x12, 0x78, 0x56, 0x12, 0x34, 0x56, 0x78, 0x12,
                0x34, 0x56, 0x78
            ]
        );
    }

    /// Dates are a three-byte little-endian day count from 0001-01-01.
    #[cfg(feature = "chrono")]
    #[test]
    fn test_encode_date() {
        let mut buf = BytesMut::new();
        let date = chrono::NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        encode_date(date, &mut buf);
        assert_eq!(buf.len(), 3);
        // 2024-01-15 is 738_899 (0x0B4653) days after 0001-01-01.
        assert_eq!(&buf[..], &[0x53, 0x46, 0x0B]);
    }
}