1use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::prelude::*;
30
/// Built-in system procedures that an RPC request can invoke by number
/// instead of by name.
///
/// When a request is encoded by id, the marker `0xFFFF` is written followed
/// by the variant's `u16` discriminant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[non_exhaustive]
pub enum ProcId {
    /// `sp_cursor`
    Cursor = 1,
    /// `sp_cursoropen`
    CursorOpen = 2,
    /// `sp_cursorprepare`
    CursorPrepare = 3,
    /// `sp_cursorexecute`
    CursorExecute = 4,
    /// `sp_cursorprepexec`
    CursorPrepExec = 5,
    /// `sp_cursorunprepare`
    CursorUnprepare = 6,
    /// `sp_cursorfetch`
    CursorFetch = 7,
    /// `sp_cursoroption`
    CursorOption = 8,
    /// `sp_cursorclose`
    CursorClose = 9,
    /// `sp_executesql`
    ExecuteSql = 10,
    /// `sp_prepare`
    Prepare = 11,
    /// `sp_execute`
    Execute = 12,
    /// `sp_prepexec`
    PrepExec = 13,
    /// `sp_prepexecrpc`
    PrepExecRpc = 14,
    /// `sp_unprepare`
    Unprepare = 15,
}
70
/// Option flags written right after the procedure name/id of an RPC request.
///
/// Each field maps to one bit of the `u16` produced by [`RpcOptionFlags::encode`].
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// Bit 0 (`0x0001`): recompile option.
    pub with_recompile: bool,
    /// Bit 1 (`0x0002`): no-metadata option.
    pub no_metadata: bool,
    /// Bit 2 (`0x0004`): reuse-metadata option.
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Returns a flag set with every option cleared.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of these flags with the recompile option set to `value`.
    #[must_use]
    pub fn with_recompile(self, value: bool) -> Self {
        Self {
            with_recompile: value,
            ..self
        }
    }

    /// Packs the options into their on-wire bit field.
    pub fn encode(&self) -> u16 {
        u16::from(self.with_recompile)
            | (u16::from(self.no_metadata) << 1)
            | (u16::from(self.reuse_metadata) << 2)
    }
}
110
/// Per-parameter status flags written after each parameter name.
///
/// Each field maps to one bit of the `u8` produced by [`ParamFlags::encode`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// Bit 0 (`0x01`): pass by reference (output parameter).
    pub by_ref: bool,
    /// Bit 1 (`0x02`): use the parameter's default value.
    pub default: bool,
    /// Bit 3 (`0x08`): encrypted value.
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns a flag set with every flag cleared.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of these flags marked as an output (by-reference) parameter.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into their on-wire status byte.
    pub fn encode(&self) -> u8 {
        u8::from(self.by_ref) | (u8::from(self.default) << 1) | (u8::from(self.encrypted) << 3)
    }
}
150
151#[derive(Debug, Clone)]
153pub struct TypeInfo {
154 pub type_id: u8,
156 pub max_length: Option<u16>,
158 pub precision: Option<u8>,
160 pub scale: Option<u8>,
162 pub collation: Option<[u8; 5]>,
164 pub tvp_type_name: Option<String>,
166}
167
168impl TypeInfo {
169 pub fn int() -> Self {
171 Self {
172 type_id: 0x26, max_length: Some(4),
174 precision: None,
175 scale: None,
176 collation: None,
177 tvp_type_name: None,
178 }
179 }
180
181 pub fn bigint() -> Self {
183 Self {
184 type_id: 0x26, max_length: Some(8),
186 precision: None,
187 scale: None,
188 collation: None,
189 tvp_type_name: None,
190 }
191 }
192
193 pub fn smallint() -> Self {
195 Self {
196 type_id: 0x26, max_length: Some(2),
198 precision: None,
199 scale: None,
200 collation: None,
201 tvp_type_name: None,
202 }
203 }
204
205 pub fn tinyint() -> Self {
207 Self {
208 type_id: 0x26, max_length: Some(1),
210 precision: None,
211 scale: None,
212 collation: None,
213 tvp_type_name: None,
214 }
215 }
216
217 pub fn bit() -> Self {
219 Self {
220 type_id: 0x68, max_length: Some(1),
222 precision: None,
223 scale: None,
224 collation: None,
225 tvp_type_name: None,
226 }
227 }
228
229 pub fn float() -> Self {
231 Self {
232 type_id: 0x6D, max_length: Some(8),
234 precision: None,
235 scale: None,
236 collation: None,
237 tvp_type_name: None,
238 }
239 }
240
241 pub fn real() -> Self {
243 Self {
244 type_id: 0x6D, max_length: Some(4),
246 precision: None,
247 scale: None,
248 collation: None,
249 tvp_type_name: None,
250 }
251 }
252
253 pub fn nvarchar(max_len: u16) -> Self {
255 Self {
256 type_id: 0xE7, max_length: Some(max_len * 2), precision: None,
259 scale: None,
260 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
262 tvp_type_name: None,
263 }
264 }
265
266 pub fn nvarchar_max() -> Self {
268 Self {
269 type_id: 0xE7, max_length: Some(0xFFFF), precision: None,
272 scale: None,
273 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
274 tvp_type_name: None,
275 }
276 }
277
278 pub fn varbinary(max_len: u16) -> Self {
280 Self {
281 type_id: 0xA5, max_length: Some(max_len),
283 precision: None,
284 scale: None,
285 collation: None,
286 tvp_type_name: None,
287 }
288 }
289
290 pub fn uniqueidentifier() -> Self {
292 Self {
293 type_id: 0x24, max_length: Some(16),
295 precision: None,
296 scale: None,
297 collation: None,
298 tvp_type_name: None,
299 }
300 }
301
302 pub fn date() -> Self {
304 Self {
305 type_id: 0x28, max_length: None,
307 precision: None,
308 scale: None,
309 collation: None,
310 tvp_type_name: None,
311 }
312 }
313
314 pub fn datetime2(scale: u8) -> Self {
316 Self {
317 type_id: 0x2A, max_length: None,
319 precision: None,
320 scale: Some(scale),
321 collation: None,
322 tvp_type_name: None,
323 }
324 }
325
326 pub fn decimal(precision: u8, scale: u8) -> Self {
328 Self {
329 type_id: 0x6C, max_length: Some(17), precision: Some(precision),
332 scale: Some(scale),
333 collation: None,
334 tvp_type_name: None,
335 }
336 }
337
338 pub fn tvp(type_name: impl Into<String>) -> Self {
343 Self {
344 type_id: 0xF3, max_length: None,
346 precision: None,
347 scale: None,
348 collation: None,
349 tvp_type_name: Some(type_name.into()),
350 }
351 }
352
353 pub fn encode(&self, buf: &mut BytesMut) {
355 if self.type_id != 0xF3 {
358 buf.put_u8(self.type_id);
359 }
360
361 match self.type_id {
363 0x26 | 0x68 | 0x6D => {
364 if let Some(len) = self.max_length {
366 buf.put_u8(len as u8);
367 }
368 }
369 0xE7 | 0xA5 | 0xEF => {
370 if let Some(len) = self.max_length {
372 buf.put_u16_le(len);
373 }
374 if let Some(collation) = self.collation {
376 buf.put_slice(&collation);
377 }
378 }
379 0x24 => {
380 if let Some(len) = self.max_length {
382 buf.put_u8(len as u8);
383 }
384 }
385 0x29..=0x2B => {
386 if let Some(scale) = self.scale {
388 buf.put_u8(scale);
389 }
390 }
391 0x6C | 0x6A => {
392 if let Some(len) = self.max_length {
394 buf.put_u8(len as u8);
395 }
396 if let Some(precision) = self.precision {
397 buf.put_u8(precision);
398 }
399 if let Some(scale) = self.scale {
400 buf.put_u8(scale);
401 }
402 }
403 _ => {}
404 }
405 }
406}
407
408#[derive(Debug, Clone)]
410pub struct RpcParam {
411 pub name: String,
413 pub flags: ParamFlags,
415 pub type_info: TypeInfo,
417 pub value: Option<Bytes>,
419}
420
421impl RpcParam {
422 pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
424 Self {
425 name: name.into(),
426 flags: ParamFlags::default(),
427 type_info,
428 value: Some(value),
429 }
430 }
431
432 pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
434 Self {
435 name: name.into(),
436 flags: ParamFlags::default(),
437 type_info,
438 value: None,
439 }
440 }
441
442 pub fn int(name: impl Into<String>, value: i32) -> Self {
444 let mut buf = BytesMut::with_capacity(4);
445 buf.put_i32_le(value);
446 Self::new(name, TypeInfo::int(), buf.freeze())
447 }
448
449 pub fn bigint(name: impl Into<String>, value: i64) -> Self {
451 let mut buf = BytesMut::with_capacity(8);
452 buf.put_i64_le(value);
453 Self::new(name, TypeInfo::bigint(), buf.freeze())
454 }
455
456 pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
458 let mut buf = BytesMut::new();
459 for code_unit in value.encode_utf16() {
461 buf.put_u16_le(code_unit);
462 }
463 let char_len = value.chars().count();
464 let type_info = if char_len > 4000 {
465 TypeInfo::nvarchar_max()
466 } else {
467 TypeInfo::nvarchar(char_len.max(1) as u16)
468 };
469 Self::new(name, type_info, buf.freeze())
470 }
471
472 #[must_use]
474 pub fn as_output(mut self) -> Self {
475 self.flags = self.flags.output();
476 self
477 }
478
479 pub fn encode(&self, buf: &mut BytesMut) {
481 let name_len = self.name.encode_utf16().count() as u8;
483 buf.put_u8(name_len);
484 if name_len > 0 {
485 for code_unit in self.name.encode_utf16() {
486 buf.put_u16_le(code_unit);
487 }
488 }
489
490 buf.put_u8(self.flags.encode());
492
493 self.type_info.encode(buf);
495
496 if let Some(ref value) = self.value {
498 match self.type_info.type_id {
500 0x26 => {
501 buf.put_u8(value.len() as u8);
503 buf.put_slice(value);
504 }
505 0x68 | 0x6D => {
506 buf.put_u8(value.len() as u8);
508 buf.put_slice(value);
509 }
510 0xE7 | 0xA5 => {
511 if self.type_info.max_length == Some(0xFFFF) {
513 let total_len = value.len() as u64;
516 buf.put_u64_le(total_len);
517 buf.put_u32_le(value.len() as u32);
518 buf.put_slice(value);
519 buf.put_u32_le(0); } else {
521 buf.put_u16_le(value.len() as u16);
522 buf.put_slice(value);
523 }
524 }
525 0x24 => {
526 buf.put_u8(value.len() as u8);
528 buf.put_slice(value);
529 }
530 0x28 => {
531 buf.put_slice(value);
533 }
534 0x2A => {
535 buf.put_u8(value.len() as u8);
537 buf.put_slice(value);
538 }
539 0x6C => {
540 buf.put_u8(value.len() as u8);
542 buf.put_slice(value);
543 }
544 0xF3 => {
545 buf.put_slice(value);
549 }
550 _ => {
551 buf.put_u8(value.len() as u8);
553 buf.put_slice(value);
554 }
555 }
556 } else {
557 match self.type_info.type_id {
559 0xE7 | 0xA5 => {
560 if self.type_info.max_length == Some(0xFFFF) {
562 buf.put_u64_le(0xFFFFFFFFFFFFFFFF); } else {
564 buf.put_u16_le(0xFFFF);
565 }
566 }
567 _ => {
568 buf.put_u8(0); }
570 }
571 }
572 }
573}
574
575#[derive(Debug, Clone)]
577pub struct RpcRequest {
578 proc_name: Option<String>,
580 proc_id: Option<ProcId>,
582 options: RpcOptionFlags,
584 params: Vec<RpcParam>,
586}
587
588impl RpcRequest {
589 pub fn named(proc_name: impl Into<String>) -> Self {
591 Self {
592 proc_name: Some(proc_name.into()),
593 proc_id: None,
594 options: RpcOptionFlags::default(),
595 params: Vec::new(),
596 }
597 }
598
599 pub fn by_id(proc_id: ProcId) -> Self {
601 Self {
602 proc_name: None,
603 proc_id: Some(proc_id),
604 options: RpcOptionFlags::default(),
605 params: Vec::new(),
606 }
607 }
608
609 pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
627 let mut request = Self::by_id(ProcId::ExecuteSql);
628
629 request.params.push(RpcParam::nvarchar("", sql));
631
632 if !params.is_empty() {
634 let declarations = Self::build_param_declarations(¶ms);
635 request.params.push(RpcParam::nvarchar("", &declarations));
636 }
637
638 request.params.extend(params);
640
641 request
642 }
643
644 fn build_param_declarations(params: &[RpcParam]) -> String {
646 params
647 .iter()
648 .map(|p| {
649 let name = if p.name.starts_with('@') {
650 p.name.clone()
651 } else if p.name.is_empty() {
652 format!(
654 "@p{}",
655 params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
656 )
657 } else {
658 format!("@{}", p.name)
659 };
660
661 let type_name: String = match p.type_info.type_id {
662 0x26 => match p.type_info.max_length {
663 Some(1) => "tinyint".to_string(),
664 Some(2) => "smallint".to_string(),
665 Some(4) => "int".to_string(),
666 Some(8) => "bigint".to_string(),
667 _ => "int".to_string(),
668 },
669 0x68 => "bit".to_string(),
670 0x6D => match p.type_info.max_length {
671 Some(4) => "real".to_string(),
672 _ => "float".to_string(),
673 },
674 0xE7 => {
675 if p.type_info.max_length == Some(0xFFFF) {
676 "nvarchar(max)".to_string()
677 } else {
678 let len = p.type_info.max_length.unwrap_or(4000) / 2;
679 format!("nvarchar({len})")
680 }
681 }
682 0xA5 => {
683 if p.type_info.max_length == Some(0xFFFF) {
684 "varbinary(max)".to_string()
685 } else {
686 let len = p.type_info.max_length.unwrap_or(8000);
687 format!("varbinary({len})")
688 }
689 }
690 0x24 => "uniqueidentifier".to_string(),
691 0x28 => "date".to_string(),
692 0x2A => {
693 let scale = p.type_info.scale.unwrap_or(7);
694 format!("datetime2({scale})")
695 }
696 0x6C => {
697 let precision = p.type_info.precision.unwrap_or(18);
698 let scale = p.type_info.scale.unwrap_or(0);
699 format!("decimal({precision}, {scale})")
700 }
701 0xF3 => {
702 if let Some(ref tvp_name) = p.type_info.tvp_type_name {
705 format!("{tvp_name} READONLY")
706 } else {
707 "sql_variant".to_string()
709 }
710 }
711 _ => "sql_variant".to_string(),
712 };
713
714 format!("{name} {type_name}")
715 })
716 .collect::<Vec<_>>()
717 .join(", ")
718 }
719
720 pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
722 let mut request = Self::by_id(ProcId::Prepare);
723
724 request
726 .params
727 .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
728
729 let declarations = Self::build_param_declarations(params);
731 request
732 .params
733 .push(RpcParam::nvarchar("@params", &declarations));
734
735 request.params.push(RpcParam::nvarchar("@stmt", sql));
737
738 request.params.push(RpcParam::int("@options", 1));
740
741 request
742 }
743
744 pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
746 let mut request = Self::by_id(ProcId::Execute);
747
748 request.params.push(RpcParam::int("@handle", handle));
750
751 request.params.extend(params);
753
754 request
755 }
756
757 pub fn unprepare(handle: i32) -> Self {
759 let mut request = Self::by_id(ProcId::Unprepare);
760 request.params.push(RpcParam::int("@handle", handle));
761 request
762 }
763
764 #[must_use]
766 pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
767 self.options = options;
768 self
769 }
770
771 #[must_use]
773 pub fn param(mut self, param: RpcParam) -> Self {
774 self.params.push(param);
775 self
776 }
777
778 #[must_use]
782 pub fn encode(&self) -> Bytes {
783 self.encode_with_transaction(0)
784 }
785
786 #[must_use]
798 pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
799 let mut buf = BytesMut::with_capacity(256);
800
801 let all_headers_start = buf.len();
804 buf.put_u32_le(0); buf.put_u32_le(18); buf.put_u16_le(0x0002); buf.put_u64_le(transaction_descriptor); buf.put_u32_le(1); let all_headers_len = buf.len() - all_headers_start;
815 let len_bytes = (all_headers_len as u32).to_le_bytes();
816 buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
817
818 if let Some(proc_id) = self.proc_id {
820 buf.put_u16_le(0xFFFF); buf.put_u16_le(proc_id as u16);
823 } else if let Some(ref proc_name) = self.proc_name {
824 let name_len = proc_name.encode_utf16().count() as u16;
826 buf.put_u16_le(name_len);
827 write_utf16_string(&mut buf, proc_name);
828 }
829
830 buf.put_u16_le(self.options.encode());
832
833 for param in &self.params {
835 param.encode(&mut buf);
836 }
837
838 buf.freeze()
839 }
840}
841
#[cfg(test)]
#[allow(clippy::unwrap_used)] // tests may unwrap values they just constructed
mod tests {
    use super::*;

    #[test]
    fn test_proc_id_values() {
        assert_eq!(ProcId::ExecuteSql as u16, 0x000A);
        assert_eq!(ProcId::Prepare as u16, 0x000B);
        assert_eq!(ProcId::Execute as u16, 0x000C);
        assert_eq!(ProcId::Unprepare as u16, 0x000F);
    }

    #[test]
    fn test_option_flags_encode() {
        let flags = RpcOptionFlags::new().with_recompile(true);
        assert_eq!(flags.encode(), 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let flags = ParamFlags::new().output();
        assert_eq!(flags.encode(), 0x01);
    }

    #[test]
    fn test_int_param() {
        let param = RpcParam::int("@p1", 42);
        assert_eq!(param.name, "@p1");
        assert_eq!(param.type_info.type_id, 0x26);
        assert!(param.value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // 5 UTF-16 code units * 2 bytes each.
        assert_eq!(param.value.as_ref().unwrap().len(), 10);
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // statement + declarations + one user parameter
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name nvarchar"));
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // @handle, @params, @stmt, @options
        assert_eq!(rpc.params.len(), 4);
        // @handle is an output parameter
        assert!(rpc.params[0].flags.by_ref);
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        // @handle + one user parameter
        assert_eq!(rpc.params.len(), 2);
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        // only @handle
        assert_eq!(rpc.params.len(), 1);
    }
}