mssql_types/
encode.rs

1//! TDS binary encoding for SQL values.
2//!
3//! This module provides encoding of Rust values into TDS wire format
4//! for transmission to SQL Server.
5
6// Allow expect() for chrono date construction with known-valid constant dates
7#![allow(clippy::expect_used)]
8
9use bytes::{BufMut, BytesMut};
10
11use crate::error::TypeError;
12use crate::value::SqlValue;
13
14/// Trait for encoding values to TDS binary format.
15pub trait TdsEncode {
16    /// Encode this value into the buffer in TDS format.
17    fn encode(&self, buf: &mut BytesMut) -> Result<(), TypeError>;
18
19    /// Get the TDS type ID for this value.
20    fn type_id(&self) -> u8;
21}
22
23impl TdsEncode for SqlValue {
24    fn encode(&self, buf: &mut BytesMut) -> Result<(), TypeError> {
25        match self {
26            SqlValue::Null => {
27                // NULL is represented by length indicator in most contexts
28                // For INTNTYPE, length 0 means NULL
29                Ok(())
30            }
31            SqlValue::Bool(v) => {
32                buf.put_u8(if *v { 1 } else { 0 });
33                Ok(())
34            }
35            SqlValue::TinyInt(v) => {
36                buf.put_u8(*v);
37                Ok(())
38            }
39            SqlValue::SmallInt(v) => {
40                buf.put_i16_le(*v);
41                Ok(())
42            }
43            SqlValue::Int(v) => {
44                buf.put_i32_le(*v);
45                Ok(())
46            }
47            SqlValue::BigInt(v) => {
48                buf.put_i64_le(*v);
49                Ok(())
50            }
51            SqlValue::Float(v) => {
52                buf.put_f32_le(*v);
53                Ok(())
54            }
55            SqlValue::Double(v) => {
56                buf.put_f64_le(*v);
57                Ok(())
58            }
59            SqlValue::String(s) => {
60                // Encode as UTF-16LE for NVARCHAR
61                encode_utf16_string(s, buf);
62                Ok(())
63            }
64            SqlValue::Binary(b) => {
65                // Length-prefixed binary data
66                if b.len() > u16::MAX as usize {
67                    return Err(TypeError::BufferTooSmall {
68                        needed: b.len(),
69                        available: u16::MAX as usize,
70                    });
71                }
72                buf.put_u16_le(b.len() as u16);
73                buf.put_slice(b);
74                Ok(())
75            }
76            #[cfg(feature = "decimal")]
77            SqlValue::Decimal(d) => {
78                encode_decimal(*d, buf);
79                Ok(())
80            }
81            #[cfg(feature = "uuid")]
82            SqlValue::Uuid(u) => {
83                encode_uuid(*u, buf);
84                Ok(())
85            }
86            #[cfg(feature = "chrono")]
87            SqlValue::Date(d) => {
88                encode_date(*d, buf);
89                Ok(())
90            }
91            #[cfg(feature = "chrono")]
92            SqlValue::Time(t) => {
93                encode_time(*t, buf);
94                Ok(())
95            }
96            #[cfg(feature = "chrono")]
97            SqlValue::DateTime(dt) => {
98                encode_datetime2(*dt, buf);
99                Ok(())
100            }
101            #[cfg(feature = "chrono")]
102            SqlValue::DateTimeOffset(dto) => {
103                encode_datetimeoffset(*dto, buf);
104                Ok(())
105            }
106            #[cfg(feature = "json")]
107            SqlValue::Json(j) => {
108                // JSON is sent as NVARCHAR string
109                let s = j.to_string();
110                encode_utf16_string(&s, buf);
111                Ok(())
112            }
113            SqlValue::Xml(x) => {
114                // XML is sent as UTF-16LE string
115                encode_utf16_string(x, buf);
116                Ok(())
117            }
118        }
119    }
120
121    fn type_id(&self) -> u8 {
122        match self {
123            SqlValue::Null => 0x1F,        // NULLTYPE
124            SqlValue::Bool(_) => 0x32,     // BITTYPE
125            SqlValue::TinyInt(_) => 0x30,  // INT1TYPE
126            SqlValue::SmallInt(_) => 0x34, // INT2TYPE
127            SqlValue::Int(_) => 0x38,      // INT4TYPE
128            SqlValue::BigInt(_) => 0x7F,   // INT8TYPE
129            SqlValue::Float(_) => 0x3B,    // FLT4TYPE
130            SqlValue::Double(_) => 0x3E,   // FLT8TYPE
131            SqlValue::String(_) => 0xE7,   // NVARCHARTYPE
132            SqlValue::Binary(_) => 0xA5,   // BIGVARBINTYPE
133            #[cfg(feature = "decimal")]
134            SqlValue::Decimal(_) => 0x6C, // DECIMALTYPE
135            #[cfg(feature = "uuid")]
136            SqlValue::Uuid(_) => 0x24, // GUIDTYPE
137            #[cfg(feature = "chrono")]
138            SqlValue::Date(_) => 0x28, // DATETYPE
139            #[cfg(feature = "chrono")]
140            SqlValue::Time(_) => 0x29, // TIMETYPE
141            #[cfg(feature = "chrono")]
142            SqlValue::DateTime(_) => 0x2A, // DATETIME2TYPE
143            #[cfg(feature = "chrono")]
144            SqlValue::DateTimeOffset(_) => 0x2B, // DATETIMEOFFSETTYPE
145            #[cfg(feature = "json")]
146            SqlValue::Json(_) => 0xE7, // NVARCHARTYPE (JSON as string)
147            SqlValue::Xml(_) => 0xF1,      // XMLTYPE
148        }
149    }
150}
151
152/// Encode a string as UTF-16LE with length prefix.
153pub fn encode_utf16_string(s: &str, buf: &mut BytesMut) {
154    let utf16: Vec<u16> = s.encode_utf16().collect();
155    let byte_len = utf16.len() * 2;
156
157    // Write byte length (not char length)
158    buf.put_u16_le(byte_len as u16);
159
160    // Write UTF-16LE bytes
161    for code_unit in utf16 {
162        buf.put_u16_le(code_unit);
163    }
164}
165
166/// Encode a string as UTF-16LE without length prefix (for fixed-length fields).
167pub fn encode_utf16_string_no_len(s: &str, buf: &mut BytesMut) {
168    for code_unit in s.encode_utf16() {
169        buf.put_u16_le(code_unit);
170    }
171}
172
/// Write a UUID to `buf` in SQL Server's mixed-endian layout (16 bytes).
///
/// Relative to the UUID's canonical byte order:
/// - bytes 0..4 are emitted reversed (little-endian 32-bit group)
/// - bytes 4..6 are emitted reversed (little-endian 16-bit group)
/// - bytes 6..8 are emitted reversed (little-endian 16-bit group)
/// - bytes 8..16 are emitted unchanged
#[cfg(feature = "uuid")]
pub fn encode_uuid(uuid: uuid::Uuid, buf: &mut BytesMut) {
    let src = uuid.as_bytes();

    // Assemble the permuted layout once, then append it in a single write.
    let wire: [u8; 16] = [
        src[3], src[2], src[1], src[0], // first group, byte-swapped
        src[5], src[4], // second group, byte-swapped
        src[7], src[6], // third group, byte-swapped
        src[8], src[9], src[10], src[11], src[12], src[13], src[14], src[15], // tail as-is
    ];
    buf.put_slice(&wire);
}
201
/// Write a decimal value to `buf` as a sign byte plus a 128-bit magnitude.
///
/// Layout (17 bytes total):
/// - 1 byte: sign (`0` = negative, `1` = positive)
/// - 16 bytes: absolute value of the mantissa, little-endian
///
/// The decimal's scale is not written by this function.
#[cfg(feature = "decimal")]
pub fn encode_decimal(decimal: rust_decimal::Decimal, buf: &mut BytesMut) {
    let sign_byte = u8::from(!decimal.is_sign_negative());
    let magnitude: u128 = decimal.mantissa().unsigned_abs();

    buf.put_u8(sign_byte);
    buf.put_slice(&magnitude.to_le_bytes());
}
216
/// Write a DATE value to `buf` as a 3-byte little-endian day count.
///
/// The count is the number of days between 0001-01-01 and `date`. It is
/// narrowed to `u32` with a plain cast and only the low three bytes are
/// emitted, so dates before 0001-01-01 (or far enough in the future to exceed
/// 24 bits) wrap rather than being rejected.
#[cfg(feature = "chrono")]
pub fn encode_date(date: chrono::NaiveDate, buf: &mut BytesMut) {
    // Day zero of the TDS DATE scale.
    let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("valid date");
    let day_number = date.signed_duration_since(epoch).num_days() as u32;

    // Only the three least-significant bytes go on the wire.
    buf.put_slice(&day_number.to_le_bytes()[..3]);
}
231
/// Write a TIME value to `buf` as a 5-byte little-endian tick count.
///
/// One tick is 100 nanoseconds (scale 7); the count is measured from midnight.
/// Sub-tick nanoseconds are discarded by integer division.
#[cfg(feature = "chrono")]
pub fn encode_time(time: chrono::NaiveTime, buf: &mut BytesMut) {
    use chrono::Timelike;

    const NANOS_PER_SECOND: u64 = 1_000_000_000;
    const NANOS_PER_TICK: u64 = 100;

    let whole_seconds = u64::from(time.num_seconds_from_midnight());
    let sub_second_nanos = u64::from(time.nanosecond());
    let ticks = (whole_seconds * NANOS_PER_SECOND + sub_second_nanos) / NANOS_PER_TICK;

    // TIME with scale 7 occupies the five least-significant bytes.
    buf.put_slice(&ticks.to_le_bytes()[..5]);
}
251
/// Write a DATETIME2 value to `buf`: the TIME encoding of the time-of-day
/// followed by the DATE encoding of the calendar date.
#[cfg(feature = "chrono")]
pub fn encode_datetime2(datetime: chrono::NaiveDateTime, buf: &mut BytesMut) {
    let (date_part, time_part) = (datetime.date(), datetime.time());

    // Wire order is time first, then date.
    encode_time(time_part, buf);
    encode_date(date_part, buf);
}
260
/// Write a DATETIMEOFFSET value to `buf`.
///
/// Layout: TIME encoding, then DATE encoding, then the timezone offset as a
/// signed 16-bit little-endian number of minutes east of UTC.
///
/// The time and date components are taken from the UTC instant, not from the
/// local wall-clock reading: the offset travels separately, and the receiver
/// applies it to the UTC value to recover the local time. Encoding the local
/// reading together with the offset would shift the stored instant by the
/// offset.
///
/// Offsets that are not a whole number of minutes are truncated toward zero.
#[cfg(feature = "chrono")]
pub fn encode_datetimeoffset(datetime: chrono::DateTime<chrono::FixedOffset>, buf: &mut BytesMut) {
    use chrono::Offset;

    // Encode the UTC time and date components.
    let utc = datetime.naive_utc();
    encode_time(utc.time(), buf);
    encode_date(utc.date(), buf);

    // Encode timezone offset in minutes (signed 16-bit).
    let offset_seconds = datetime.offset().fix().local_minus_utc();
    let offset_minutes = (offset_seconds / 60) as i16;
    buf.put_i16_le(offset_minutes);
}
277
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// INT is written as four little-endian bytes.
    #[test]
    fn test_encode_int() {
        let mut buf = BytesMut::new();
        SqlValue::Int(42).encode(&mut buf).unwrap();
        assert_eq!(&buf[..], &[42, 0, 0, 0]);
    }

    /// BIGINT is written as eight little-endian bytes.
    #[test]
    fn test_encode_bigint() {
        let mut buf = BytesMut::new();
        SqlValue::BigInt(0x0102030405060708)
            .encode(&mut buf)
            .unwrap();
        assert_eq!(&buf[..], &[0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]);
    }

    /// Strings carry a byte-length prefix followed by UTF-16LE code units.
    #[test]
    fn test_encode_utf16_string() {
        let mut buf = BytesMut::new();
        encode_utf16_string("AB", &mut buf);
        // Length (4 bytes for 2 UTF-16 code units) + "AB" in UTF-16LE
        assert_eq!(&buf[..], &[4, 0, 0x41, 0, 0x42, 0]);
    }

    /// UUIDs use the mixed-endian layout: first three groups byte-swapped.
    #[cfg(feature = "uuid")]
    #[test]
    fn test_encode_uuid() {
        let mut buf = BytesMut::new();
        let uuid = uuid::Uuid::parse_str("12345678-1234-5678-1234-567812345678").unwrap();
        encode_uuid(uuid, &mut buf);
        // SQL Server mixed-endian format
        assert_eq!(
            &buf[..],
            &[
                0x78, 0x56, 0x34, 0x12, // First group reversed
                0x34, 0x12, // Second group reversed
                0x78, 0x56, // Third group reversed
                0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78 // Last 8 bytes as-is
            ]
        );
    }

    /// DATE is a 3-byte little-endian day count from 0001-01-01.
    #[cfg(feature = "chrono")]
    #[test]
    fn test_encode_date() {
        // The epoch itself is day zero.
        let mut buf = BytesMut::new();
        let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
        encode_date(epoch, &mut buf);
        assert_eq!(&buf[..], &[0x00, 0x00, 0x00]);

        // 2024-01-15 is 738_899 days after 0001-01-01 (0x0B4653).
        let mut buf = BytesMut::new();
        let date = chrono::NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        encode_date(date, &mut buf);
        assert_eq!(&buf[..], &[0x53, 0x46, 0x0B]);
    }

    /// TIME is a 5-byte little-endian count of 100 ns ticks since midnight.
    #[cfg(feature = "chrono")]
    #[test]
    fn test_encode_time() {
        // Midnight is zero ticks.
        let mut buf = BytesMut::new();
        let midnight = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap();
        encode_time(midnight, &mut buf);
        assert_eq!(&buf[..], &[0x00, 0x00, 0x00, 0x00, 0x00]);

        // One second is 10_000_000 ticks (0x989680).
        let mut buf = BytesMut::new();
        let one_second = chrono::NaiveTime::from_hms_opt(0, 0, 1).unwrap();
        encode_time(one_second, &mut buf);
        assert_eq!(&buf[..], &[0x80, 0x96, 0x98, 0x00, 0x00]);
    }
}