mssql_types/
encode.rs

1//! TDS binary encoding for SQL values.
2//!
3//! This module provides encoding of Rust values into TDS wire format
4//! for transmission to SQL Server.
5
6// Allow expect() for chrono date construction with known-valid constant dates
7#![allow(clippy::expect_used)]
8
9use bytes::{BufMut, BytesMut};
10
11use crate::error::TypeError;
12use crate::value::SqlValue;
13
14/// Trait for encoding values to TDS binary format.
15pub trait TdsEncode {
16    /// Encode this value into the buffer in TDS format.
17    fn encode(&self, buf: &mut BytesMut) -> Result<(), TypeError>;
18
19    /// Get the TDS type ID for this value.
20    fn type_id(&self) -> u8;
21}
22
23impl TdsEncode for SqlValue {
24    fn encode(&self, buf: &mut BytesMut) -> Result<(), TypeError> {
25        match self {
26            SqlValue::Null => {
27                // NULL is represented by length indicator in most contexts
28                // For INTNTYPE, length 0 means NULL
29                Ok(())
30            }
31            SqlValue::Bool(v) => {
32                buf.put_u8(if *v { 1 } else { 0 });
33                Ok(())
34            }
35            SqlValue::TinyInt(v) => {
36                buf.put_u8(*v);
37                Ok(())
38            }
39            SqlValue::SmallInt(v) => {
40                buf.put_i16_le(*v);
41                Ok(())
42            }
43            SqlValue::Int(v) => {
44                buf.put_i32_le(*v);
45                Ok(())
46            }
47            SqlValue::BigInt(v) => {
48                buf.put_i64_le(*v);
49                Ok(())
50            }
51            SqlValue::Float(v) => {
52                buf.put_f32_le(*v);
53                Ok(())
54            }
55            SqlValue::Double(v) => {
56                buf.put_f64_le(*v);
57                Ok(())
58            }
59            SqlValue::String(s) => {
60                // Encode as UTF-16LE for NVARCHAR
61                encode_utf16_string(s, buf);
62                Ok(())
63            }
64            SqlValue::Binary(b) => {
65                // Length-prefixed binary data
66                if b.len() > u16::MAX as usize {
67                    return Err(TypeError::BufferTooSmall {
68                        needed: b.len(),
69                        available: u16::MAX as usize,
70                    });
71                }
72                buf.put_u16_le(b.len() as u16);
73                buf.put_slice(b);
74                Ok(())
75            }
76            #[cfg(feature = "decimal")]
77            SqlValue::Decimal(d) => {
78                encode_decimal(*d, buf);
79                Ok(())
80            }
81            #[cfg(feature = "uuid")]
82            SqlValue::Uuid(u) => {
83                encode_uuid(*u, buf);
84                Ok(())
85            }
86            #[cfg(feature = "chrono")]
87            SqlValue::Date(d) => {
88                encode_date(*d, buf);
89                Ok(())
90            }
91            #[cfg(feature = "chrono")]
92            SqlValue::Time(t) => {
93                encode_time(*t, buf);
94                Ok(())
95            }
96            #[cfg(feature = "chrono")]
97            SqlValue::DateTime(dt) => {
98                encode_datetime2(*dt, buf);
99                Ok(())
100            }
101            #[cfg(feature = "chrono")]
102            SqlValue::DateTimeOffset(dto) => {
103                encode_datetimeoffset(*dto, buf);
104                Ok(())
105            }
106            #[cfg(feature = "json")]
107            SqlValue::Json(j) => {
108                // JSON is sent as NVARCHAR string
109                let s = j.to_string();
110                encode_utf16_string(&s, buf);
111                Ok(())
112            }
113            SqlValue::Xml(x) => {
114                // XML is sent as UTF-16LE string
115                encode_utf16_string(x, buf);
116                Ok(())
117            }
118            SqlValue::Tvp(_) => {
119                // TVP encoding is handled at the RPC parameter level, not here.
120                // This method is for encoding the value data portion; TVPs have
121                // their own complex encoding structure that includes metadata.
122                // See tds-protocol crate for full TVP encoding.
123                Err(TypeError::UnsupportedConversion {
124                    from: "TvpData".to_string(),
125                    to: "raw bytes (use RPC parameter encoding)",
126                })
127            }
128        }
129    }
130
131    fn type_id(&self) -> u8 {
132        match self {
133            SqlValue::Null => 0x1F,        // NULLTYPE
134            SqlValue::Bool(_) => 0x32,     // BITTYPE
135            SqlValue::TinyInt(_) => 0x30,  // INT1TYPE
136            SqlValue::SmallInt(_) => 0x34, // INT2TYPE
137            SqlValue::Int(_) => 0x38,      // INT4TYPE
138            SqlValue::BigInt(_) => 0x7F,   // INT8TYPE
139            SqlValue::Float(_) => 0x3B,    // FLT4TYPE
140            SqlValue::Double(_) => 0x3E,   // FLT8TYPE
141            SqlValue::String(_) => 0xE7,   // NVARCHARTYPE
142            SqlValue::Binary(_) => 0xA5,   // BIGVARBINTYPE
143            #[cfg(feature = "decimal")]
144            SqlValue::Decimal(_) => 0x6C, // DECIMALTYPE
145            #[cfg(feature = "uuid")]
146            SqlValue::Uuid(_) => 0x24, // GUIDTYPE
147            #[cfg(feature = "chrono")]
148            SqlValue::Date(_) => 0x28, // DATETYPE
149            #[cfg(feature = "chrono")]
150            SqlValue::Time(_) => 0x29, // TIMETYPE
151            #[cfg(feature = "chrono")]
152            SqlValue::DateTime(_) => 0x2A, // DATETIME2TYPE
153            #[cfg(feature = "chrono")]
154            SqlValue::DateTimeOffset(_) => 0x2B, // DATETIMEOFFSETTYPE
155            #[cfg(feature = "json")]
156            SqlValue::Json(_) => 0xE7, // NVARCHARTYPE (JSON as string)
157            SqlValue::Xml(_) => 0xF1,      // XMLTYPE
158            SqlValue::Tvp(_) => 0xF3,      // TVPTYPE
159        }
160    }
161}
162
163/// Encode a string as UTF-16LE with length prefix.
164pub fn encode_utf16_string(s: &str, buf: &mut BytesMut) {
165    let utf16: Vec<u16> = s.encode_utf16().collect();
166    let byte_len = utf16.len() * 2;
167
168    // Write byte length (not char length)
169    buf.put_u16_le(byte_len as u16);
170
171    // Write UTF-16LE bytes
172    for code_unit in utf16 {
173        buf.put_u16_le(code_unit);
174    }
175}
176
177/// Encode a string as UTF-16LE without length prefix (for fixed-length fields).
178pub fn encode_utf16_string_no_len(s: &str, buf: &mut BytesMut) {
179    for code_unit in s.encode_utf16() {
180        buf.put_u16_le(code_unit);
181    }
182}
183
/// Write a UUID using SQL Server's mixed-endian GUID layout.
///
/// The three leading fields are emitted little-endian and the tail is copied verbatim:
/// - 32-bit field: little-endian
/// - 16-bit field: little-endian
/// - 16-bit field: little-endian
/// - final 8 bytes: unchanged (big-endian order)
#[cfg(feature = "uuid")]
pub fn encode_uuid(uuid: uuid::Uuid, buf: &mut BytesMut) {
    let (group1, group2, group3, tail) = uuid.as_fields();

    // Writing the integer fields little-endian reverses each group's bytes.
    buf.put_u32_le(group1);
    buf.put_u16_le(group2);
    buf.put_u16_le(group3);

    // The remaining eight bytes keep their original order.
    buf.put_slice(tail);
}
212
/// Write a decimal as a sign byte followed by its 128-bit magnitude.
///
/// Layout produced (17 bytes in total):
/// - 1 byte: sign (0 = negative, 1 = positive)
/// - 16 bytes: absolute value of the mantissa, little-endian
///
/// The scale is not written by this function.
#[cfg(feature = "decimal")]
pub fn encode_decimal(decimal: rust_decimal::Decimal, buf: &mut BytesMut) {
    let sign_byte = u8::from(!decimal.is_sign_negative());
    let magnitude: u128 = decimal.mantissa().unsigned_abs();

    buf.put_u8(sign_byte);
    buf.put_slice(&magnitude.to_le_bytes());
}
227
/// Write a DATE value as a 3-byte little-endian day count.
///
/// The count is the number of days elapsed since 0001-01-01. Only the low 24 bits of
/// the count are written, and the count is narrowed with `as u32`, so dates before
/// 0001-01-01 or far beyond the 24-bit range do not produce a meaningful value.
#[cfg(feature = "chrono")]
pub fn encode_date(date: chrono::NaiveDate, buf: &mut BytesMut) {
    let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("valid date");
    let day_count = date.signed_duration_since(epoch).num_days() as u32;

    // The three least-significant bytes, lowest first.
    buf.put_slice(&day_count.to_le_bytes()[..3]);
}
242
/// Encode a TIME value.
///
/// TDS TIME at scale 7 is the number of 100-nanosecond intervals since midnight,
/// written as 5 little-endian bytes.
///
/// chrono represents a leap second by letting `nanosecond()` reach
/// 1_000_000_000..=1_999_999_999. For a time in the last second of the day that
/// would push the tick count to 24 hours or more, which is not a valid TIME value,
/// so the result is clamped to the final tick of the day (23:59:59.9999999).
#[cfg(feature = "chrono")]
pub fn encode_time(time: chrono::NaiveTime, buf: &mut BytesMut) {
    use chrono::Timelike;

    // Scale = 7 (100-nanosecond precision).
    const TICKS_PER_SECOND: u64 = 10_000_000;
    const NANOS_PER_TICK: u64 = 100;
    const LAST_TICK_OF_DAY: u64 = 24 * 60 * 60 * TICKS_PER_SECOND - 1;

    let whole_seconds = u64::from(time.num_seconds_from_midnight());
    let sub_second_nanos = u64::from(time.nanosecond());
    let ticks = (whole_seconds * TICKS_PER_SECOND + sub_second_nanos / NANOS_PER_TICK)
        .min(LAST_TICK_OF_DAY);

    // TIME with scale 7 uses 5 bytes: the low 40 bits, least-significant byte first.
    buf.put_slice(&ticks.to_le_bytes()[..5]);
}
262
/// Write a DATETIME2 value.
///
/// The time-of-day portion is emitted first (see `encode_time`), immediately followed
/// by the date portion (see `encode_date`).
#[cfg(feature = "chrono")]
pub fn encode_datetime2(datetime: chrono::NaiveDateTime, buf: &mut BytesMut) {
    let (time_part, date_part) = (datetime.time(), datetime.date());
    encode_time(time_part, buf);
    encode_date(date_part, buf);
}
271
/// Encode a DATETIMEOFFSET value.
///
/// DATETIMEOFFSET is encoded as TIME + DATE + offset (in minutes, signed 16-bit
/// little-endian). The TIME and DATE portions carry the instant in UTC; the offset
/// tells the server how to shift it back to the original local time. Writing the
/// local wall-clock time here would shift every value with a non-zero offset.
#[cfg(feature = "chrono")]
pub fn encode_datetimeoffset(datetime: chrono::DateTime<chrono::FixedOffset>, buf: &mut BytesMut) {
    use chrono::Offset;

    // Encode the UTC time and date components.
    let utc = datetime.naive_utc();
    encode_time(utc.time(), buf);
    encode_date(utc.date(), buf);

    // Encode the timezone offset in whole minutes (signed 16-bit).
    let offset_seconds = datetime.offset().fix().local_minus_utc();
    let offset_minutes = (offset_seconds / 60) as i16;
    buf.put_i16_le(offset_minutes);
}
288
// Unit tests pinning the exact wire bytes produced by the encoders.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_int() {
        let mut buf = BytesMut::new();
        SqlValue::Int(42).encode(&mut buf).unwrap();
        assert_eq!(&buf[..], &[42, 0, 0, 0]);
    }

    #[test]
    fn test_encode_bigint() {
        let mut buf = BytesMut::new();
        SqlValue::BigInt(0x0102030405060708)
            .encode(&mut buf)
            .unwrap();
        assert_eq!(&buf[..], &[0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]);
    }

    #[test]
    fn test_encode_utf16_string() {
        let mut buf = BytesMut::new();
        encode_utf16_string("AB", &mut buf);
        // Length (4 bytes for 2 UTF-16 code units) + "AB" in UTF-16LE
        assert_eq!(&buf[..], &[4, 0, 0x41, 0, 0x42, 0]);
    }

    #[test]
    fn test_encode_utf16_string_no_len() {
        let mut buf = BytesMut::new();
        encode_utf16_string_no_len("AB", &mut buf);
        // Only the code units, no length prefix
        assert_eq!(&buf[..], &[0x41, 0, 0x42, 0]);
    }

    #[cfg(feature = "uuid")]
    #[test]
    fn test_encode_uuid() {
        let mut buf = BytesMut::new();
        let uuid = uuid::Uuid::parse_str("12345678-1234-5678-1234-567812345678").unwrap();
        encode_uuid(uuid, &mut buf);
        // SQL Server mixed-endian format
        assert_eq!(
            &buf[..],
            &[
                0x78, 0x56, 0x34, 0x12, // First group reversed
                0x34, 0x12, // Second group reversed
                0x78, 0x56, // Third group reversed
                0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78 // Last 8 bytes as-is
            ]
        );
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn test_encode_date() {
        let mut buf = BytesMut::new();
        let date = chrono::NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        encode_date(date, &mut buf);
        // 2024-01-15 is 738_899 days after 0001-01-01 (0x0B4653), little-endian
        assert_eq!(&buf[..], &[0x53, 0x46, 0x0B]);
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn test_encode_date_epoch() {
        let mut buf = BytesMut::new();
        let date = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
        encode_date(date, &mut buf);
        // The base date itself is day zero
        assert_eq!(&buf[..], &[0, 0, 0]);
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn test_encode_time() {
        let mut buf = BytesMut::new();
        let time = chrono::NaiveTime::from_hms_opt(0, 0, 1).unwrap();
        encode_time(time, &mut buf);
        // One second is 10_000_000 ticks of 100 ns (0x989680), in 5 bytes little-endian
        assert_eq!(&buf[..], &[0x80, 0x96, 0x98, 0x00, 0x00]);
    }
}