1#![allow(clippy::expect_used)]
8
9use bytes::{Buf, Bytes};
10
11use crate::error::TypeError;
12use crate::value::SqlValue;
13
/// Types that can be decoded directly from a TDS value buffer.
pub trait TdsDecode: Sized {
    /// Decodes `Self` from the front of `buf` using the column metadata in `type_info`.
    ///
    /// # Errors
    ///
    /// Returns a [`TypeError`] when the bytes cannot be decoded as `Self`.
    /// NOTE(review): no implementation is visible here — whether `buf` must be
    /// advanced exactly past the value is left to implementors; confirm the contract.
    fn decode(buf: &mut Bytes, type_info: &TypeInfo) -> Result<Self, TypeError>;
}
19
/// Type metadata for a single value; [`decode_value`] dispatches on it.
#[derive(Debug, Clone)]
pub struct TypeInfo {
    /// TDS type identifier byte that selects the decoder in [`decode_value`].
    pub type_id: u8,
    /// Declared length (set by `TypeInfo::varchar`).
    /// NOTE(review): not read by any decoder in this file — confirm its unit.
    pub length: Option<u32>,
    /// Scale: decimal places for DECIMAL/NUMERIC, fractional-second digits for
    /// TIME/DATETIME2/DATETIMEOFFSET (the time decoders default to 7 when `None`).
    pub scale: Option<u8>,
    /// Precision for DECIMAL/NUMERIC (set by `TypeInfo::decimal`; not read by
    /// the decoders in this file).
    pub precision: Option<u8>,
    /// Collation of non-Unicode character data; consulted by the VARCHAR
    /// decoder when the `encoding` feature is enabled.
    pub collation: Option<Collation>,
}
34
/// SQL Server collation attached to character data.
#[derive(Debug, Clone, Copy)]
pub struct Collation {
    /// Locale identifier word. `is_utf8` tests bit 26 of it as a UTF-8 flag, so
    /// this appears to hold the raw collation info (LCID plus flag bits) rather
    /// than a bare LCID — confirm against the code that fills it.
    pub lcid: u32,
    /// Additional collation flags.
    /// NOTE(review): not read anywhere in this file; meaning not visible here.
    pub flags: u8,
}
43
44impl Collation {
45 #[must_use]
50 pub fn is_utf8(&self) -> bool {
51 (self.lcid & 0x0400_0000) != 0
52 }
53
54 #[cfg(feature = "encoding")]
59 #[must_use]
60 pub fn encoding(&self) -> Option<&'static encoding_rs::Encoding> {
61 encoding_for_lcid(self.lcid)
62 }
63}
64
/// Bit 26 of the collation word: set when the collation stores UTF-8 data.
#[cfg(feature = "encoding")]
const UTF8_COLLATION_FLAG: u32 = 0x0400_0000;

/// Picks the `encoding_rs` encoder for a collation word.
///
/// A set UTF-8 flag wins outright; otherwise the locale is mapped to a Windows
/// code page and that code page to an encoder. Returns `None` when the code
/// page has no entry below.
#[cfg(feature = "encoding")]
fn encoding_for_lcid(lcid: u32) -> Option<&'static encoding_rs::Encoding> {
    if (lcid & UTF8_COLLATION_FLAG) == UTF8_COLLATION_FLAG {
        return Some(encoding_rs::UTF_8);
    }

    let encoding: &'static encoding_rs::Encoding = match code_page_for_lcid(lcid)? {
        874 => encoding_rs::WINDOWS_874,
        932 => encoding_rs::SHIFT_JIS,
        936 => encoding_rs::GB18030,
        949 => encoding_rs::EUC_KR,
        950 => encoding_rs::BIG5,
        1250 => encoding_rs::WINDOWS_1250,
        1251 => encoding_rs::WINDOWS_1251,
        1252 => encoding_rs::WINDOWS_1252,
        1253 => encoding_rs::WINDOWS_1253,
        1254 => encoding_rs::WINDOWS_1254,
        1255 => encoding_rs::WINDOWS_1255,
        1256 => encoding_rs::WINDOWS_1256,
        1257 => encoding_rs::WINDOWS_1257,
        1258 => encoding_rs::WINDOWS_1258,
        _ => return None,
    };
    Some(encoding)
}
100
/// Maps a collation's locale id to the Windows ANSI code page used for its
/// non-Unicode (CHAR/VARCHAR) data.
///
/// Only the low 16 bits — the Windows language identifier (primary language
/// plus sublanguage) — take part in the lookup; sort-id and flag bits above
/// them are ignored. Unlisted languages fall back to code page 1252, so this
/// currently always returns `Some`.
#[cfg(feature = "encoding")]
fn code_page_for_lcid(lcid: u32) -> Option<u16> {
    // The arms below are full 16-bit language ids (e.g. 0x0411 = Japanese,
    // 0x0804 = Chinese (PRC)). A 10-bit primary-language mask (0x3FF) can never
    // produce any of them and would send every collation to the 1252 fallback.
    const LANGUAGE_ID_MASK: u32 = 0xFFFF;
    let language_id = lcid & LANGUAGE_ID_MASK;

    match language_id {
        0x0411 => Some(932),
        0x0804 | 0x1004 => Some(936),
        0x0404 | 0x0C04 | 0x1404 => Some(950),
        0x0412 => Some(949),
        0x041E => Some(874),
        0x042A => Some(1258),

        // Central European
        0x0405 | 0x0415 | 0x040E | 0x041A | 0x081A | 0x141A | 0x101A | 0x041B | 0x0424 | 0x0418
        | 0x041C => Some(1250),

        // Cyrillic
        0x0419 | 0x0422 | 0x0423 | 0x0402 | 0x042F | 0x0C1A | 0x201A | 0x0440 | 0x0843 | 0x0444
        | 0x0450 | 0x0485 => Some(1251),

        0x0408 => Some(1253),
        0x041F | 0x042C => Some(1254),
        0x040D => Some(1255),

        // Arabic script
        0x0401 | 0x0801 | 0x0C01 | 0x1001 | 0x1401 | 0x1801 | 0x1C01 | 0x2001 | 0x2401 | 0x2801
        | 0x2C01 | 0x3001 | 0x3401 | 0x3801 | 0x3C01 | 0x4001 | 0x0429 | 0x0420 | 0x048C
        | 0x0463 => Some(1256),

        // Baltic
        0x0425..=0x0427 => Some(1257),

        // Western European; listed explicitly for documentation even though
        // the fallback below yields the same code page.
        0x0409 | 0x0809 | 0x0C09 | 0x1009 | 0x1409 | 0x1809 | 0x1C09 | 0x2009 | 0x2409 | 0x2809
        | 0x2C09 | 0x3009 | 0x3409 | 0x0407 | 0x0807 | 0x0C07 | 0x1007 | 0x1407 | 0x040C
        | 0x080C | 0x0C0C | 0x100C | 0x140C | 0x180C | 0x0410 | 0x0810 | 0x0413 | 0x0813
        | 0x0416 | 0x0816 | 0x040A | 0x080A | 0x0C0A | 0x100A | 0x140A | 0x180A | 0x1C0A
        | 0x200A | 0x240A | 0x280A | 0x2C0A | 0x300A | 0x340A | 0x380A | 0x3C0A | 0x400A
        | 0x440A | 0x480A | 0x4C0A | 0x500A => Some(1252),

        _ => Some(1252),
    }
}
147
148impl TypeInfo {
149 #[must_use]
151 pub fn int(type_id: u8) -> Self {
152 Self {
153 type_id,
154 length: None,
155 scale: None,
156 precision: None,
157 collation: None,
158 }
159 }
160
161 #[must_use]
163 pub fn varchar(length: u32) -> Self {
164 Self {
165 type_id: 0xE7, length: Some(length),
167 scale: None,
168 precision: None,
169 collation: None,
170 }
171 }
172
173 #[must_use]
175 pub fn decimal(precision: u8, scale: u8) -> Self {
176 Self {
177 type_id: 0x6C,
178 length: None,
179 scale: Some(scale),
180 precision: Some(precision),
181 collation: None,
182 }
183 }
184
185 #[must_use]
187 pub fn datetime_with_scale(type_id: u8, scale: u8) -> Self {
188 Self {
189 type_id,
190 length: None,
191 scale: Some(scale),
192 precision: None,
193 collation: None,
194 }
195 }
196}
197
198pub fn decode_value(buf: &mut Bytes, type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
200 match type_info.type_id {
201 0x1F => Ok(SqlValue::Null), 0x32 => decode_bit(buf), 0x30 => decode_tinyint(buf), 0x34 => decode_smallint(buf), 0x38 => decode_int(buf), 0x7F => decode_bigint(buf), 0x3B => decode_float(buf), 0x3E => decode_double(buf), 0x26 => decode_intn(buf, type_info),
213
214 0xE7 => decode_nvarchar(buf, type_info), 0xAF => decode_varchar(buf, type_info), 0xA7 => decode_varchar(buf, type_info), 0xA5 => decode_varbinary(buf, type_info), 0xAD => decode_varbinary(buf, type_info), 0x24 => decode_guid(buf),
225
226 0x6C | 0x6A => decode_decimal(buf, type_info),
228
229 0x28 => decode_date(buf), 0x29 => decode_time(buf, type_info), 0x2A => decode_datetime2(buf, type_info), 0x2B => decode_datetimeoffset(buf, type_info), 0x3D => decode_datetime(buf), 0x3F => decode_smalldatetime(buf), 0xF1 => decode_xml(buf),
239
240 _ => Err(TypeError::UnsupportedConversion {
241 from: format!("TDS type 0x{:02X}", type_info.type_id),
242 to: "SqlValue",
243 }),
244 }
245}
246
247fn decode_bit(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
248 if buf.remaining() < 1 {
249 return Err(TypeError::BufferTooSmall {
250 needed: 1,
251 available: buf.remaining(),
252 });
253 }
254 Ok(SqlValue::Bool(buf.get_u8() != 0))
255}
256
257fn decode_tinyint(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
258 if buf.remaining() < 1 {
259 return Err(TypeError::BufferTooSmall {
260 needed: 1,
261 available: buf.remaining(),
262 });
263 }
264 Ok(SqlValue::TinyInt(buf.get_u8()))
265}
266
267fn decode_smallint(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
268 if buf.remaining() < 2 {
269 return Err(TypeError::BufferTooSmall {
270 needed: 2,
271 available: buf.remaining(),
272 });
273 }
274 Ok(SqlValue::SmallInt(buf.get_i16_le()))
275}
276
277fn decode_int(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
278 if buf.remaining() < 4 {
279 return Err(TypeError::BufferTooSmall {
280 needed: 4,
281 available: buf.remaining(),
282 });
283 }
284 Ok(SqlValue::Int(buf.get_i32_le()))
285}
286
287fn decode_bigint(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
288 if buf.remaining() < 8 {
289 return Err(TypeError::BufferTooSmall {
290 needed: 8,
291 available: buf.remaining(),
292 });
293 }
294 Ok(SqlValue::BigInt(buf.get_i64_le()))
295}
296
297fn decode_float(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
298 if buf.remaining() < 4 {
299 return Err(TypeError::BufferTooSmall {
300 needed: 4,
301 available: buf.remaining(),
302 });
303 }
304 Ok(SqlValue::Float(buf.get_f32_le()))
305}
306
307fn decode_double(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
308 if buf.remaining() < 8 {
309 return Err(TypeError::BufferTooSmall {
310 needed: 8,
311 available: buf.remaining(),
312 });
313 }
314 Ok(SqlValue::Double(buf.get_f64_le()))
315}
316
317fn decode_intn(buf: &mut Bytes, _type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
318 if buf.remaining() < 1 {
319 return Err(TypeError::BufferTooSmall {
320 needed: 1,
321 available: buf.remaining(),
322 });
323 }
324
325 let actual_len = buf.get_u8() as usize;
326 if actual_len == 0 {
327 return Ok(SqlValue::Null);
328 }
329
330 if buf.remaining() < actual_len {
331 return Err(TypeError::BufferTooSmall {
332 needed: actual_len,
333 available: buf.remaining(),
334 });
335 }
336
337 match actual_len {
338 1 => Ok(SqlValue::TinyInt(buf.get_u8())),
339 2 => Ok(SqlValue::SmallInt(buf.get_i16_le())),
340 4 => Ok(SqlValue::Int(buf.get_i32_le())),
341 8 => Ok(SqlValue::BigInt(buf.get_i64_le())),
342 _ => Err(TypeError::InvalidBinary(format!(
343 "invalid INTN length: {actual_len}"
344 ))),
345 }
346}
347
348fn decode_nvarchar(buf: &mut Bytes, _type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
349 if buf.remaining() < 2 {
350 return Err(TypeError::BufferTooSmall {
351 needed: 2,
352 available: buf.remaining(),
353 });
354 }
355
356 let byte_len = buf.get_u16_le() as usize;
357
358 if byte_len == 0xFFFF {
360 return Ok(SqlValue::Null);
361 }
362
363 if buf.remaining() < byte_len {
364 return Err(TypeError::BufferTooSmall {
365 needed: byte_len,
366 available: buf.remaining(),
367 });
368 }
369
370 let utf16_data = buf.copy_to_bytes(byte_len);
371 let s = decode_utf16_string(&utf16_data)?;
372 Ok(SqlValue::String(s))
373}
374
375fn decode_varchar(buf: &mut Bytes, type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
376 if buf.remaining() < 2 {
377 return Err(TypeError::BufferTooSmall {
378 needed: 2,
379 available: buf.remaining(),
380 });
381 }
382
383 let byte_len = buf.get_u16_le() as usize;
384
385 if byte_len == 0xFFFF {
387 return Ok(SqlValue::Null);
388 }
389
390 if buf.remaining() < byte_len {
391 return Err(TypeError::BufferTooSmall {
392 needed: byte_len,
393 available: buf.remaining(),
394 });
395 }
396
397 let data = buf.copy_to_bytes(byte_len);
398
399 if let Ok(s) = String::from_utf8(data.to_vec()) {
401 return Ok(SqlValue::String(s));
402 }
403
404 #[cfg(feature = "encoding")]
406 if let Some(ref collation) = type_info.collation {
407 if let Some(encoding) = collation.encoding() {
408 let (decoded, _, had_errors) = encoding.decode(&data);
409 if !had_errors {
410 return Ok(SqlValue::String(decoded.into_owned()));
411 }
412 }
413 }
414
415 #[cfg(not(feature = "encoding"))]
417 let _ = type_info;
418
419 Ok(SqlValue::String(
421 String::from_utf8_lossy(&data).into_owned(),
422 ))
423}
424
425fn decode_varbinary(buf: &mut Bytes, _type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
426 if buf.remaining() < 2 {
427 return Err(TypeError::BufferTooSmall {
428 needed: 2,
429 available: buf.remaining(),
430 });
431 }
432
433 let byte_len = buf.get_u16_le() as usize;
434
435 if byte_len == 0xFFFF {
437 return Ok(SqlValue::Null);
438 }
439
440 if buf.remaining() < byte_len {
441 return Err(TypeError::BufferTooSmall {
442 needed: byte_len,
443 available: buf.remaining(),
444 });
445 }
446
447 let data = buf.copy_to_bytes(byte_len);
448 Ok(SqlValue::Binary(data))
449}
450
/// Decodes a UNIQUEIDENTIFIER as a [`uuid::Uuid`].
///
/// The value is a 1-byte length (0 = NULL, otherwise exactly 16). On the wire
/// the first three GUID fields (4, 2 and 2 bytes) are little-endian, so they
/// are byte-reversed into the big-endian order `Uuid::from_bytes` expects; the
/// final 8 bytes are copied as-is.
#[cfg(feature = "uuid")]
fn decode_guid(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
    let len = match buf.remaining() {
        0 => {
            return Err(TypeError::BufferTooSmall {
                needed: 1,
                available: 0,
            })
        }
        _ => usize::from(buf.get_u8()),
    };
    if len == 0 {
        return Ok(SqlValue::Null);
    }
    if len != 16 {
        return Err(TypeError::InvalidBinary(format!(
            "invalid GUID length: {len}"
        )));
    }

    let available = buf.remaining();
    if available < 16 {
        return Err(TypeError::BufferTooSmall {
            needed: 16,
            available,
        });
    }

    let mut raw = [0u8; 16];
    buf.copy_to_slice(&mut raw);
    raw[..4].reverse();
    raw[4..6].reverse();
    raw[6..8].reverse();

    Ok(SqlValue::Uuid(uuid::Uuid::from_bytes(raw)))
}
502
503#[cfg(not(feature = "uuid"))]
504fn decode_guid(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
505 if buf.remaining() < 1 {
507 return Err(TypeError::BufferTooSmall {
508 needed: 1,
509 available: buf.remaining(),
510 });
511 }
512
513 let len = buf.get_u8() as usize;
514 if len == 0 {
515 return Ok(SqlValue::Null);
516 }
517
518 if buf.remaining() < len {
519 return Err(TypeError::BufferTooSmall {
520 needed: len,
521 available: buf.remaining(),
522 });
523 }
524
525 let data = buf.copy_to_bytes(len);
526 Ok(SqlValue::Binary(data))
527}
528
/// Decodes DECIMAL/NUMERIC into a [`rust_decimal::Decimal`].
///
/// The value is laid out as:
/// - a 1-byte length (0 = NULL);
/// - a sign byte (0 = negative, anything else = non-negative);
/// - a little-endian unsigned mantissa.
///
/// Only the first 16 mantissa bytes are used, and any further bytes are skipped
/// so the buffer stays aligned. The scale comes from `type_info.scale` (0 when
/// absent).
///
/// # Errors
///
/// Returns `InvalidDecimal` when the mantissa/scale pair does not fit
/// `rust_decimal`.
#[cfg(feature = "decimal")]
fn decode_decimal(buf: &mut Bytes, type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
    use rust_decimal::Decimal;

    let len = match buf.remaining() {
        0 => {
            return Err(TypeError::BufferTooSmall {
                needed: 1,
                available: 0,
            })
        }
        _ => usize::from(buf.get_u8()),
    };
    if len == 0 {
        return Ok(SqlValue::Null);
    }

    let available = buf.remaining();
    if available < len {
        return Err(TypeError::BufferTooSmall {
            needed: len,
            available,
        });
    }

    let sign = buf.get_u8();
    let mantissa_len = len - 1;

    let mut mantissa_bytes = [0u8; 16];
    let kept = mantissa_len.min(mantissa_bytes.len());
    buf.copy_to_slice(&mut mantissa_bytes[..kept]);
    buf.advance(mantissa_len - kept);

    let mantissa = u128::from_le_bytes(mantissa_bytes);
    let scale = u32::from(type_info.scale.unwrap_or(0));

    let converted = i128::try_from(mantissa)
        .ok()
        .and_then(|signed| Decimal::try_from_i128_with_scale(signed, scale).ok());

    match converted {
        None => Err(TypeError::InvalidDecimal(format!(
            "NUMERIC value (mantissa {mantissa}, scale {scale}) exceeds \
             rust_decimal's 96-bit/scale-28 range; CAST the column to a \
             narrower NUMERIC, FLOAT, or VARCHAR in the query"
        ))),
        Some(mut value) => {
            if sign == 0 {
                value.set_sign_negative(true);
            }
            Ok(SqlValue::Decimal(value))
        }
    }
}
593
594#[cfg(not(feature = "decimal"))]
595fn decode_decimal(buf: &mut Bytes, _type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
596 if buf.remaining() < 1 {
598 return Err(TypeError::BufferTooSmall {
599 needed: 1,
600 available: buf.remaining(),
601 });
602 }
603
604 let len = buf.get_u8() as usize;
605 if len == 0 {
606 return Ok(SqlValue::Null);
607 }
608
609 if buf.remaining() < len {
610 return Err(TypeError::BufferTooSmall {
611 needed: len,
612 available: buf.remaining(),
613 });
614 }
615
616 buf.advance(len);
617 Ok(SqlValue::String("DECIMAL (feature disabled)".to_string()))
618}
619
/// Decodes a DATE.
///
/// The value is a 1-byte length (0 = NULL, otherwise exactly 3) followed by a
/// 3-byte little-endian count of days since 0001-01-01.
#[cfg(feature = "chrono")]
fn decode_date(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
    let len = match buf.remaining() {
        0 => {
            return Err(TypeError::BufferTooSmall {
                needed: 1,
                available: 0,
            })
        }
        _ => usize::from(buf.get_u8()),
    };
    if len == 0 {
        return Ok(SqlValue::Null);
    }
    if len != 3 {
        return Err(TypeError::InvalidDateTime(format!(
            "invalid DATE length: {len}"
        )));
    }

    let available = buf.remaining();
    if available < 3 {
        return Err(TypeError::BufferTooSmall {
            needed: 3,
            available,
        });
    }

    let days = u32::from_le_bytes([buf.get_u8(), buf.get_u8(), buf.get_u8(), 0]);
    let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("valid date");
    Ok(SqlValue::Date(epoch + chrono::Duration::days(i64::from(days))))
}
655
656#[cfg(not(feature = "chrono"))]
657fn decode_date(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
658 if buf.remaining() < 1 {
659 return Err(TypeError::BufferTooSmall {
660 needed: 1,
661 available: buf.remaining(),
662 });
663 }
664
665 let len = buf.get_u8() as usize;
666 if len == 0 {
667 return Ok(SqlValue::Null);
668 }
669
670 if buf.remaining() < len {
671 return Err(TypeError::BufferTooSmall {
672 needed: len,
673 available: buf.remaining(),
674 });
675 }
676
677 buf.advance(len);
678 Ok(SqlValue::String("DATE (feature disabled)".to_string()))
679}
680
/// Decodes a TIME.
///
/// The value is a 1-byte length (0 = NULL) followed by a little-endian count of
/// 10^-scale second ticks since midnight. The tick width (3-5 bytes) comes from
/// `type_info.scale`, defaulting to 7.
///
/// If the length prefix is larger than the scale requires, the surplus bytes
/// are skipped. All `len` bytes are therefore consumed and the next value
/// starts at the right offset.
///
/// # Errors
///
/// Returns `BufferTooSmall` on truncated input, and `InvalidDateTime` when the
/// length is too short for the scale.
#[cfg(feature = "chrono")]
fn decode_time(buf: &mut Bytes, type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
    let scale = type_info.scale.unwrap_or(7);
    let time_len = time_bytes_for_scale(scale);

    if buf.remaining() < 1 {
        return Err(TypeError::BufferTooSmall {
            needed: 1,
            available: buf.remaining(),
        });
    }

    let len = buf.get_u8() as usize;
    if len == 0 {
        return Ok(SqlValue::Null);
    }

    if buf.remaining() < len {
        return Err(TypeError::BufferTooSmall {
            needed: len,
            available: buf.remaining(),
        });
    }
    if len < time_len {
        return Err(TypeError::InvalidDateTime(format!(
            "TIME length {len} too short for scale {scale}"
        )));
    }

    let mut time_bytes = [0u8; 8];
    buf.copy_to_slice(&mut time_bytes[..time_len]);
    // Without this, surplus bytes would be left in the buffer and the next
    // value would be decoded from the wrong offset.
    buf.advance(len - time_len);

    let intervals = u64::from_le_bytes(time_bytes);
    let time = intervals_to_time(intervals, scale);

    Ok(SqlValue::Time(time))
}
723
724#[cfg(not(feature = "chrono"))]
725fn decode_time(buf: &mut Bytes, _type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
726 if buf.remaining() < 1 {
727 return Err(TypeError::BufferTooSmall {
728 needed: 1,
729 available: buf.remaining(),
730 });
731 }
732
733 let len = buf.get_u8() as usize;
734 if len == 0 {
735 return Ok(SqlValue::Null);
736 }
737
738 if buf.remaining() < len {
739 return Err(TypeError::BufferTooSmall {
740 needed: len,
741 available: buf.remaining(),
742 });
743 }
744
745 buf.advance(len);
746 Ok(SqlValue::String("TIME (feature disabled)".to_string()))
747}
748
/// Decodes a DATETIME2.
///
/// The value is laid out as:
/// - a 1-byte length (0 = NULL);
/// - the time of day as a little-endian tick count (width from
///   `type_info.scale`, default 7);
/// - 3 little-endian bytes of days since 0001-01-01.
///
/// The date always occupies the last 3 bytes. If the length prefix is larger
/// than the scale requires, the surplus bytes belong to the time field and are
/// skipped. The date is then read from its real position and the buffer stays
/// aligned for the next value.
///
/// # Errors
///
/// Returns `BufferTooSmall` on truncated input, and `InvalidDateTime` when the
/// length is too short for the scale.
#[cfg(feature = "chrono")]
fn decode_datetime2(buf: &mut Bytes, type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
    let scale = type_info.scale.unwrap_or(7);
    let time_len = time_bytes_for_scale(scale);

    if buf.remaining() < 1 {
        return Err(TypeError::BufferTooSmall {
            needed: 1,
            available: buf.remaining(),
        });
    }

    let len = buf.get_u8() as usize;
    if len == 0 {
        return Ok(SqlValue::Null);
    }

    if buf.remaining() < len {
        return Err(TypeError::BufferTooSmall {
            needed: len,
            available: buf.remaining(),
        });
    }
    if len < time_len + 3 {
        return Err(TypeError::InvalidDateTime(format!(
            "DATETIME2 length {len} too short for scale {scale}"
        )));
    }

    let mut time_bytes = [0u8; 8];
    buf.copy_to_slice(&mut time_bytes[..time_len]);
    // Skip surplus time bytes so the 3 date bytes are read from the end of the value.
    buf.advance(len - time_len - 3);
    let intervals = u64::from_le_bytes(time_bytes);
    let time = intervals_to_time(intervals, scale);

    let days = buf.get_u8() as u32 | ((buf.get_u8() as u32) << 8) | ((buf.get_u8() as u32) << 16);
    let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("valid date");
    let date = base + chrono::Duration::days(days as i64);

    Ok(SqlValue::DateTime(date.and_time(time)))
}
795
796#[cfg(not(feature = "chrono"))]
797fn decode_datetime2(buf: &mut Bytes, _type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
798 if buf.remaining() < 1 {
799 return Err(TypeError::BufferTooSmall {
800 needed: 1,
801 available: buf.remaining(),
802 });
803 }
804
805 let len = buf.get_u8() as usize;
806 if len == 0 {
807 return Ok(SqlValue::Null);
808 }
809
810 if buf.remaining() < len {
811 return Err(TypeError::BufferTooSmall {
812 needed: len,
813 available: buf.remaining(),
814 });
815 }
816
817 buf.advance(len);
818 Ok(SqlValue::String("DATETIME2 (feature disabled)".to_string()))
819}
820
/// Decodes a DATETIMEOFFSET.
///
/// The value is laid out as:
/// - a 1-byte length (0 = NULL);
/// - the UTC time of day as a little-endian tick count (width from
///   `type_info.scale`, default 7);
/// - 3 little-endian bytes of days since 0001-01-01;
/// - a little-endian `i16` offset in minutes.
///
/// The stored date/time is interpreted as UTC and shifted into the offset.
///
/// Date and offset always occupy the last 5 bytes. If the length prefix is
/// larger than the scale requires, the surplus bytes belong to the time field
/// and are skipped. Date and offset are then read from their real position and
/// the buffer stays aligned.
///
/// # Errors
///
/// Returns `BufferTooSmall` on truncated input. Returns `InvalidDateTime` when
/// the length is too short for the scale or the offset is not representable.
#[cfg(feature = "chrono")]
fn decode_datetimeoffset(buf: &mut Bytes, type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
    use chrono::TimeZone;

    let scale = type_info.scale.unwrap_or(7);
    let time_len = time_bytes_for_scale(scale);

    if buf.remaining() < 1 {
        return Err(TypeError::BufferTooSmall {
            needed: 1,
            available: buf.remaining(),
        });
    }

    let len = buf.get_u8() as usize;
    if len == 0 {
        return Ok(SqlValue::Null);
    }

    if buf.remaining() < len {
        return Err(TypeError::BufferTooSmall {
            needed: len,
            available: buf.remaining(),
        });
    }
    if len < time_len + 5 {
        return Err(TypeError::InvalidDateTime(format!(
            "DATETIMEOFFSET length {len} too short for scale {scale}"
        )));
    }

    let mut time_bytes = [0u8; 8];
    buf.copy_to_slice(&mut time_bytes[..time_len]);
    // Skip surplus time bytes so date + offset are read from the last 5 bytes.
    buf.advance(len - time_len - 5);
    let intervals = u64::from_le_bytes(time_bytes);
    let time = intervals_to_time(intervals, scale);

    let days = buf.get_u8() as u32 | ((buf.get_u8() as u32) << 8) | ((buf.get_u8() as u32) << 16);
    let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("valid date");
    let date = base + chrono::Duration::days(days as i64);

    let offset_minutes = buf.get_i16_le();
    let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
        .ok_or_else(|| TypeError::InvalidDateTime(format!("invalid offset: {offset_minutes}")))?;

    // The wire date/time is UTC; attach the offset without shifting the instant.
    let datetime = offset.from_utc_datetime(&date.and_time(time));

    Ok(SqlValue::DateTimeOffset(datetime))
}
878
879#[cfg(not(feature = "chrono"))]
880fn decode_datetimeoffset(buf: &mut Bytes, _type_info: &TypeInfo) -> Result<SqlValue, TypeError> {
881 if buf.remaining() < 1 {
882 return Err(TypeError::BufferTooSmall {
883 needed: 1,
884 available: buf.remaining(),
885 });
886 }
887
888 let len = buf.get_u8() as usize;
889 if len == 0 {
890 return Ok(SqlValue::Null);
891 }
892
893 if buf.remaining() < len {
894 return Err(TypeError::BufferTooSmall {
895 needed: len,
896 available: buf.remaining(),
897 });
898 }
899
900 buf.advance(len);
901 Ok(SqlValue::String(
902 "DATETIMEOFFSET (feature disabled)".to_string(),
903 ))
904}
905
/// Decodes a legacy fixed 8-byte DATETIME.
///
/// The value is a little-endian `i32` count of days relative to 1900-01-01,
/// followed by a little-endian `u32` count of 1/300-second ticks since
/// midnight.
///
/// Ticks are rounded to the nearest millisecond. This reproduces the
/// .000/.003/.007 values SQL Server reports (e.g. 2 ticks = 6.67 ms -> .007).
/// Truncation would give .006, a value DATETIME can never show.
///
/// # Errors
///
/// Returns `BufferTooSmall` on truncated input. Returns `InvalidDateTime` when
/// the day count overflows the date range or the tick count exceeds one day.
#[cfg(feature = "chrono")]
fn decode_datetime(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
    if buf.remaining() < 8 {
        return Err(TypeError::BufferTooSmall {
            needed: 8,
            available: buf.remaining(),
        });
    }

    let days = buf.get_i32_le();
    let time_300ths = buf.get_u32_le();

    let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).expect("valid date");
    let date = base
        .checked_add_signed(chrono::Duration::days(days as i64))
        .ok_or_else(|| TypeError::InvalidDateTime(format!("DATETIME days out of range: {days}")))?;

    // Round half-up to whole milliseconds: ms = (ticks * 1000 + 150) / 300.
    // The largest valid tick count (23:59:59.997) stays below 86_400_000 ms, and
    // u32::MAX * 1000 fits comfortably in u64.
    let total_ms = (u64::from(time_300ths) * 1000 + 150) / 300;
    let secs = (total_ms / 1000) as u32;
    let nanos = ((total_ms % 1000) * 1_000_000) as u32;

    let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos)
        .ok_or_else(|| TypeError::InvalidDateTime("invalid DATETIME time".to_string()))?;

    Ok(SqlValue::DateTime(date.and_time(time)))
}
935
936#[cfg(not(feature = "chrono"))]
937fn decode_datetime(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
938 if buf.remaining() < 8 {
939 return Err(TypeError::BufferTooSmall {
940 needed: 8,
941 available: buf.remaining(),
942 });
943 }
944
945 buf.advance(8);
946 Ok(SqlValue::String("DATETIME (feature disabled)".to_string()))
947}
948
/// Decodes a fixed 4-byte SMALLDATETIME.
///
/// The value is a little-endian `u16` count of days since 1900-01-01, followed
/// by a little-endian `u16` count of minutes since midnight.
///
/// # Errors
///
/// Returns `BufferTooSmall` on truncated input, and `InvalidDateTime` when the
/// minute count exceeds one day.
#[cfg(feature = "chrono")]
fn decode_smalldatetime(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
    const WIDTH: usize = 4;
    let available = buf.remaining();
    if available < WIDTH {
        return Err(TypeError::BufferTooSmall {
            needed: WIDTH,
            available,
        });
    }

    let (days, minutes) = (buf.get_u16_le(), buf.get_u16_le());

    let date = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).expect("valid date")
        + chrono::Duration::days(i64::from(days));

    chrono::NaiveTime::from_num_seconds_from_midnight_opt(u32::from(minutes) * 60, 0)
        .map(|time| SqlValue::DateTime(date.and_time(time)))
        .ok_or_else(|| TypeError::InvalidDateTime("invalid SMALLDATETIME time".to_string()))
}
970
971#[cfg(not(feature = "chrono"))]
972fn decode_smalldatetime(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
973 if buf.remaining() < 4 {
974 return Err(TypeError::BufferTooSmall {
975 needed: 4,
976 available: buf.remaining(),
977 });
978 }
979
980 buf.advance(4);
981 Ok(SqlValue::String(
982 "SMALLDATETIME (feature disabled)".to_string(),
983 ))
984}
985
986fn decode_xml(buf: &mut Bytes) -> Result<SqlValue, TypeError> {
987 if buf.remaining() < 2 {
989 return Err(TypeError::BufferTooSmall {
990 needed: 2,
991 available: buf.remaining(),
992 });
993 }
994
995 let byte_len = buf.get_u16_le() as usize;
996
997 if byte_len == 0xFFFF {
998 return Ok(SqlValue::Null);
999 }
1000
1001 if buf.remaining() < byte_len {
1002 return Err(TypeError::BufferTooSmall {
1003 needed: byte_len,
1004 available: buf.remaining(),
1005 });
1006 }
1007
1008 let utf16_data = buf.copy_to_bytes(byte_len);
1009 let s = decode_utf16_string(&utf16_data)?;
1010 Ok(SqlValue::Xml(s))
1011}
1012
1013pub fn decode_utf16_string(data: &[u8]) -> Result<String, TypeError> {
1015 if data.len() % 2 != 0 {
1016 return Err(TypeError::InvalidEncoding(
1017 "UTF-16 data must have even length".to_string(),
1018 ));
1019 }
1020
1021 let utf16: Vec<u16> = data
1022 .chunks_exact(2)
1023 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
1024 .collect();
1025
1026 String::from_utf16(&utf16).map_err(|e| TypeError::InvalidEncoding(e.to_string()))
1027}
1028
/// Number of wire bytes used by the time-of-day portion at a given
/// fractional-second scale: 3 for scales 0-2, 4 for 3-4, 5 otherwise
/// (scales above 7 are treated like 7).
#[cfg(feature = "chrono")]
fn time_bytes_for_scale(scale: u8) -> usize {
    match scale {
        0..=2 => 3,
        3 | 4 => 4,
        _ => 5,
    }
}
1039
/// Converts a tick count at the given scale into a time of day.
///
/// Each tick is 10^-scale seconds, i.e. 10^(9 - scale) nanoseconds; scales
/// above 7 are treated as 7. The multiplication saturates rather than
/// overflowing. Values that do not form a valid time of day fall back to
/// midnight.
#[cfg(feature = "chrono")]
fn intervals_to_time(intervals: u64, scale: u8) -> chrono::NaiveTime {
    let ns_per_tick = 10u64.pow(9 - u32::from(scale.min(7)));
    let total_ns = intervals.saturating_mul(ns_per_tick);

    let whole_secs = (total_ns / 1_000_000_000) as u32;
    let frac_ns = (total_ns % 1_000_000_000) as u32;

    match chrono::NaiveTime::from_num_seconds_from_midnight_opt(whole_secs, frac_ns) {
        Some(time) => time,
        None => chrono::NaiveTime::from_hms_opt(0, 0, 0).expect("valid time"),
    }
}
1074
1075#[cfg(test)]
1076#[allow(clippy::unwrap_used, clippy::panic)]
1077mod tests {
1078 use super::*;
1079
1080 #[test]
1081 fn test_decode_int() {
1082 let mut buf = Bytes::from_static(&[42, 0, 0, 0]);
1083 let type_info = TypeInfo::int(0x38);
1084 let result = decode_value(&mut buf, &type_info).unwrap();
1085 assert_eq!(result, SqlValue::Int(42));
1086 }
1087
1088 #[cfg(feature = "uuid")]
1093 #[test]
1094 fn test_guid_encode_decode_roundtrip() {
1095 use crate::encode::encode_uuid;
1096 use bytes::BufMut;
1097
1098 let original = uuid::Uuid::parse_str("00112233-4455-6677-8899-aabbccddeeff").unwrap();
1099 let mut encoded = bytes::BytesMut::new();
1100 encode_uuid(original, &mut encoded);
1101 assert_eq!(encoded.len(), 16);
1102
1103 let mut framed = bytes::BytesMut::new();
1105 framed.put_u8(16);
1106 framed.put_slice(&encoded);
1107 let decoded = decode_guid(&mut framed.freeze()).unwrap();
1108 assert_eq!(decoded, SqlValue::Uuid(original));
1109 }
1110
1111 #[cfg(feature = "chrono")]
1112 #[test]
1113 fn hostile_datetime_days_overflow_is_error_not_panic() {
1114 let mut data = Vec::new();
1117 data.extend_from_slice(&i32::MAX.to_le_bytes());
1118 data.extend_from_slice(&0u32.to_le_bytes());
1119 let mut buf = Bytes::from(data);
1120 assert!(decode_datetime(&mut buf).is_err());
1121 }
1122
1123 #[cfg(feature = "chrono")]
1130 #[test]
1131 fn test_datetimeoffset_decodes_wire_as_utc() {
1132 use chrono::TimeZone;
1133
1134 let mut data = Vec::new();
1135 data.push(10u8); let intervals: u64 = 10 * 3600 * 10_000_000; for i in 0..5 {
1138 data.push(((intervals >> (8 * i)) & 0xFF) as u8);
1139 }
1140 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1141 let days = (chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap() - base).num_days() as u32;
1142 data.push((days & 0xFF) as u8);
1143 data.push(((days >> 8) & 0xFF) as u8);
1144 data.push(((days >> 16) & 0xFF) as u8);
1145 data.extend_from_slice(&120i16.to_le_bytes()); let type_info = TypeInfo {
1148 type_id: 0x2B,
1149 length: None,
1150 scale: Some(7),
1151 precision: None,
1152 collation: None,
1153 };
1154 let mut buf = Bytes::from(data);
1155 let value = decode_datetimeoffset(&mut buf, &type_info).unwrap();
1156
1157 let offset = chrono::FixedOffset::east_opt(2 * 3600).unwrap();
1158 let expected = offset.with_ymd_and_hms(2024, 3, 15, 12, 0, 0).unwrap();
1159 match value {
1160 SqlValue::DateTimeOffset(dt) => {
1161 assert_eq!(dt, expected);
1162 assert_eq!(dt.offset().local_minus_utc(), 7200);
1163 assert_eq!(
1164 dt.naive_utc(),
1165 chrono::NaiveDate::from_ymd_opt(2024, 3, 15)
1166 .unwrap()
1167 .and_hms_opt(10, 0, 0)
1168 .unwrap()
1169 );
1170 }
1171 other => panic!("expected DateTimeOffset, got {other:?}"),
1172 }
1173 }
1174
1175 #[test]
1176 fn test_decode_utf16_string() {
1177 let data = [0x41, 0x00, 0x42, 0x00];
1179 let result = decode_utf16_string(&data).unwrap();
1180 assert_eq!(result, "AB");
1181 }
1182
1183 #[test]
1184 fn test_decode_nvarchar() {
1185 let mut buf = Bytes::from_static(&[4, 0, 0x41, 0x00, 0x42, 0x00]);
1187 let type_info = TypeInfo::varchar(100);
1188 let type_info = TypeInfo {
1189 type_id: 0xE7,
1190 ..type_info
1191 };
1192 let result = decode_value(&mut buf, &type_info).unwrap();
1193 assert_eq!(result, SqlValue::String("AB".to_string()));
1194 }
1195
1196 #[test]
1197 fn test_decode_null_nvarchar() {
1198 let mut buf = Bytes::from_static(&[0xFF, 0xFF]);
1200 let type_info = TypeInfo {
1201 type_id: 0xE7,
1202 length: Some(100),
1203 scale: None,
1204 precision: None,
1205 collation: None,
1206 };
1207 let result = decode_value(&mut buf, &type_info).unwrap();
1208 assert_eq!(result, SqlValue::Null);
1209 }
1210
1211 #[cfg(feature = "decimal")]
1216 mod decimal_roundtrip {
1217 use super::*;
1218 use bytes::{BufMut, BytesMut};
1219 use rust_decimal::Decimal;
1220
1221 fn roundtrip_decimal(value: Decimal, precision: u8, scale: u8) -> Decimal {
1223 let mut encode_buf = BytesMut::new();
1225 crate::encode::encode_decimal(value, &mut encode_buf);
1226 let encoded_len = encode_buf.len() as u8; let mut decode_buf = BytesMut::with_capacity(1 + encoded_len as usize);
1230 decode_buf.put_u8(encoded_len);
1231 decode_buf.extend_from_slice(&encode_buf);
1232
1233 let mut bytes = decode_buf.freeze();
1234 let type_info = TypeInfo::decimal(precision, scale);
1235 match decode_value(&mut bytes, &type_info).unwrap() {
1236 SqlValue::Decimal(d) => d,
1237 other => panic!("expected Decimal, got {other:?}"),
1238 }
1239 }
1240
1241 #[test]
1242 fn test_negative_decimal_17_80() {
1243 let d = Decimal::new(-1780, 2); let result = roundtrip_decimal(d, 18, 2);
1245 assert_eq!(result, d, "round-trip of -17.80 must be exact");
1246 }
1247
1248 #[test]
1249 fn test_negative_decimal_0_01() {
1250 let d = Decimal::new(-1, 2); let result = roundtrip_decimal(d, 18, 2);
1252 assert_eq!(result, d, "round-trip of -0.01 must be exact");
1253 }
1254
1255 #[test]
1256 fn test_negative_decimal_large() {
1257 let d = Decimal::new(-9999999999, 2); let result = roundtrip_decimal(d, 18, 2);
1259 assert_eq!(result, d, "round-trip of -99999999.99 must be exact");
1260 }
1261
1262 #[test]
1263 fn test_positive_decimal() {
1264 let d = Decimal::new(1780, 2); let result = roundtrip_decimal(d, 18, 2);
1266 assert_eq!(result, d, "round-trip of 17.80 must be exact");
1267 }
1268
1269 #[test]
1270 fn test_decimal_zero() {
1271 let d = Decimal::ZERO;
1272 let result = roundtrip_decimal(d, 18, 0);
1273 assert_eq!(result, d, "round-trip of 0 must be exact");
1274 }
1275
1276 #[test]
1277 fn test_decimal_max_precision() {
1278 let d = Decimal::new(i64::MAX, 0);
1280 let result = roundtrip_decimal(d, 38, 0);
1281 assert_eq!(result, d, "round-trip of large positive must be exact");
1282 }
1283
1284 #[test]
1285 fn test_decimal_min_precision() {
1286 let d = Decimal::new(i64::MIN + 1, 0);
1287 let result = roundtrip_decimal(d, 38, 0);
1288 assert_eq!(result, d, "round-trip of large negative must be exact");
1289 }
1290
1291 #[test]
1295 fn test_decimal_out_of_range_errors_instead_of_f64() {
1296 let mut buf = BytesMut::new();
1298 buf.put_u8(17); buf.put_u8(1); buf.extend_from_slice(&[0xFF; 16]);
1301
1302 let mut bytes = buf.freeze();
1303 let type_info = TypeInfo::decimal(38, 0);
1304 match decode_value(&mut bytes, &type_info) {
1305 Err(TypeError::InvalidDecimal(_)) => {}
1306 other => panic!("expected InvalidDecimal error, got {other:?}"),
1307 }
1308 }
1309
1310 #[test]
1313 fn test_decimal_oversized_mantissa_keeps_frame_aligned() {
1314 let mut buf = BytesMut::new();
1315 buf.put_u8(18); buf.put_u8(1); buf.put_u8(42); buf.extend_from_slice(&[0u8; 16]); buf.put_u8(0xAB); let mut bytes = buf.freeze();
1322 let type_info = TypeInfo::decimal(38, 0);
1323 let value = decode_value(&mut bytes, &type_info).unwrap();
1324 assert_eq!(value, SqlValue::Decimal(Decimal::new(42, 0)));
1325 assert_eq!(
1326 bytes.remaining(),
1327 1,
1328 "excess mantissa bytes must be consumed, leaving the sentinel"
1329 );
1330 assert_eq!(bytes.get_u8(), 0xAB);
1331 }
1332 }
1333
1334 #[cfg(feature = "chrono")]
1339 mod date_tests {
1340 use bytes::{BufMut, BytesMut};
1341 use chrono::NaiveDate;
1342
1343 #[test]
1344 fn test_encode_date_pre_1900() {
1345 let mut buf = BytesMut::new();
1347 let date = NaiveDate::from_ymd_opt(1753, 1, 1).unwrap();
1348 crate::encode::encode_date(date, &mut buf).unwrap();
1349 assert_eq!(buf.len(), 3, "DATE encoding is always 3 bytes");
1350 }
1351
1352 #[test]
1353 fn test_encode_date_epoch() {
1354 let mut buf = BytesMut::new();
1356 let date = NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1357 crate::encode::encode_date(date, &mut buf).unwrap();
1358 assert_eq!(&buf[..], &[0, 0, 0]);
1360 }
1361
1362 #[test]
1363 fn test_encode_date_max() {
1364 let mut buf = BytesMut::new();
1366 let date = NaiveDate::from_ymd_opt(9999, 12, 31).unwrap();
1367 crate::encode::encode_date(date, &mut buf).unwrap();
1368 assert_eq!(buf.len(), 3, "DATE encoding is always 3 bytes");
1369 let days = buf[0] as u32 | ((buf[1] as u32) << 8) | ((buf[2] as u32) << 16);
1371 assert_eq!(days, 3_652_058);
1372 }
1373
1374 #[test]
1375 fn test_decode_datetime_pre_1900() {
1376 use super::*;
1379
1380 let base = NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1381 let target = NaiveDate::from_ymd_opt(1753, 1, 1).unwrap();
1382 let days = target.signed_duration_since(base).num_days() as i32;
1383
1384 let mut raw = BytesMut::new();
1386 raw.put_i32_le(days);
1387 raw.put_u32_le(0); let mut buf = raw.freeze();
1390 let result = decode_datetime(&mut buf).unwrap();
1391
1392 match result {
1393 SqlValue::DateTime(dt) => {
1394 assert_eq!(dt.date(), target);
1395 }
1396 other => panic!("expected DateTime, got {other:?}"),
1397 }
1398 }
1399
1400 #[test]
1401 fn test_decode_smalldatetime_1900() {
1402 use super::*;
1404
1405 let mut raw = BytesMut::new();
1407 raw.put_u16_le(0);
1408 raw.put_u16_le(0);
1409
1410 let mut buf = raw.freeze();
1411 let result = decode_smalldatetime(&mut buf).unwrap();
1412
1413 match result {
1414 SqlValue::DateTime(dt) => {
1415 assert_eq!(
1416 dt,
1417 NaiveDate::from_ymd_opt(1900, 1, 1)
1418 .unwrap()
1419 .and_hms_opt(0, 0, 0)
1420 .unwrap()
1421 );
1422 }
1423 other => panic!("expected DateTime, got {other:?}"),
1424 }
1425 }
1426 }
1427
1428 #[cfg(feature = "decimal")]
1433 mod proptest_decimal {
1434 use super::*;
1435 use bytes::{BufMut, BytesMut};
1436 use proptest::prelude::*;
1437 use rust_decimal::Decimal;
1438
1439 fn roundtrip_decimal(value: Decimal, scale: u8) -> Decimal {
1441 let mut encode_buf = BytesMut::new();
1442 crate::encode::encode_decimal(value, &mut encode_buf);
1443 let encoded_len = encode_buf.len() as u8;
1444
1445 let mut decode_buf = BytesMut::with_capacity(1 + encoded_len as usize);
1446 decode_buf.put_u8(encoded_len);
1447 decode_buf.extend_from_slice(&encode_buf);
1448
1449 let mut bytes = decode_buf.freeze();
1450 let type_info = TypeInfo::decimal(38, scale);
1451 match decode_value(&mut bytes, &type_info).unwrap() {
1452 SqlValue::Decimal(d) => d,
1453 other => panic!("expected Decimal, got {other:?}"),
1454 }
1455 }
1456
1457 proptest! {
1458 #[test]
1459 fn decimal_roundtrip_scale0(mantissa in -999_999_999_999i64..=999_999_999_999i64) {
1460 let d = Decimal::new(mantissa, 0);
1461 let result = roundtrip_decimal(d, 0);
1462 prop_assert_eq!(result, d);
1463 }
1464
1465 #[test]
1466 fn decimal_roundtrip_scale2(mantissa in -999_999_999_999i64..=999_999_999_999i64) {
1467 let d = Decimal::new(mantissa, 2);
1468 let result = roundtrip_decimal(d, 2);
1469 prop_assert_eq!(result, d);
1470 }
1471
1472 #[test]
1473 fn decimal_roundtrip_various_scales(
1474 mantissa in -999_999_999i64..=999_999_999i64,
1475 scale in 0u8..=10u8,
1476 ) {
1477 let d = Decimal::new(mantissa, scale as u32);
1478 let result = roundtrip_decimal(d, scale);
1479 prop_assert_eq!(result, d);
1480 }
1481 }
1482 }
1483
1484 #[cfg(feature = "chrono")]
1485 mod proptest_date {
1486 use bytes::BytesMut;
1487 use chrono::NaiveDate;
1488 use proptest::prelude::*;
1489
1490 proptest! {
1491 #[test]
1492 fn date_encode_never_panics(
1493 year in 1i32..=9999i32,
1494 month in 1u32..=12u32,
1495 day in 1u32..=28u32, ) {
1497 let date = NaiveDate::from_ymd_opt(year, month, day).unwrap();
1498 let mut buf = BytesMut::new();
1499 crate::encode::encode_date(date, &mut buf).unwrap();
1500 prop_assert_eq!(buf.len(), 3);
1501 }
1502 }
1503 }
1504}