1use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::prelude::*;
30
/// Well-known system stored procedures that an RPC request can reference by
/// numeric id instead of by name.
///
/// The discriminant is the 16-bit id written on the wire right after the
/// `0xFFFF` "procedure by id" marker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ProcId {
    /// `sp_cursor`
    Cursor = 1,
    /// `sp_cursoropen`
    CursorOpen = 2,
    /// `sp_cursorprepare`
    CursorPrepare = 3,
    /// `sp_cursorexecute`
    CursorExecute = 4,
    /// `sp_cursorprepexec`
    CursorPrepExec = 5,
    /// `sp_cursorunprepare`
    CursorUnprepare = 6,
    /// `sp_cursorfetch`
    CursorFetch = 7,
    /// `sp_cursoroption`
    CursorOption = 8,
    /// `sp_cursorclose`
    CursorClose = 9,
    /// `sp_executesql`
    ExecuteSql = 10,
    /// `sp_prepare`
    Prepare = 11,
    /// `sp_execute`
    Execute = 12,
    /// `sp_prepexec`
    PrepExec = 13,
    /// `sp_prepexecrpc`
    PrepExecRpc = 14,
    /// `sp_unprepare`
    Unprepare = 15,
}
69
/// Option flags written as a little-endian `u16` right after the procedure
/// name/id of an RPC request.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RpcOptionFlags {
    /// Bit `0x0001`: ask the server to recompile the procedure.
    pub with_recompile: bool,
    /// Bit `0x0002`: the "no metadata" option flag.
    pub no_metadata: bool,
    /// Bit `0x0004`: the "reuse metadata" option flag.
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    const WITH_RECOMPILE: u16 = 0x0001;
    const NO_METADATA: u16 = 0x0002;
    const REUSE_METADATA: u16 = 0x0004;

    /// Creates a flag set with every option cleared.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets or clears the recompile flag (bit `0x0001`).
    #[must_use]
    pub fn with_recompile(mut self, value: bool) -> Self {
        self.with_recompile = value;
        self
    }

    /// Sets or clears the no-metadata flag (bit `0x0002`).
    #[must_use]
    pub fn with_no_metadata(mut self, value: bool) -> Self {
        self.no_metadata = value;
        self
    }

    /// Sets or clears the reuse-metadata flag (bit `0x0004`).
    #[must_use]
    pub fn with_reuse_metadata(mut self, value: bool) -> Self {
        self.reuse_metadata = value;
        self
    }

    /// Packs the options into the 16-bit value written on the wire.
    pub fn encode(&self) -> u16 {
        let bit = |enabled: bool, mask: u16| if enabled { mask } else { 0 };
        bit(self.with_recompile, Self::WITH_RECOMPILE)
            | bit(self.no_metadata, Self::NO_METADATA)
            | bit(self.reuse_metadata, Self::REUSE_METADATA)
    }
}
109
/// Status flags attached to each RPC parameter, packed into a single byte.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// Bit `0x01`: the parameter is passed by reference (OUTPUT); set by
    /// [`ParamFlags::output`].
    pub by_ref: bool,
    /// Bit `0x02`: the "default value" flag.
    pub default: bool,
    /// Bit `0x08`: the "encrypted" flag.
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns a flag set with every bit cleared.
    pub fn new() -> Self {
        ParamFlags {
            by_ref: false,
            default: false,
            encrypted: false,
        }
    }

    /// Marks the parameter as by-reference (OUTPUT), leaving the other bits
    /// untouched.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into the status byte written after the parameter name.
    pub fn encode(&self) -> u8 {
        let bit = |enabled: bool, mask: u8| if enabled { mask } else { 0 };
        bit(self.by_ref, 0x01) | bit(self.default, 0x02) | bit(self.encrypted, 0x08)
    }
}
149
150#[derive(Debug, Clone)]
152pub struct TypeInfo {
153 pub type_id: u8,
155 pub max_length: Option<u16>,
157 pub precision: Option<u8>,
159 pub scale: Option<u8>,
161 pub collation: Option<[u8; 5]>,
163}
164
165impl TypeInfo {
166 pub fn int() -> Self {
168 Self {
169 type_id: 0x26, max_length: Some(4),
171 precision: None,
172 scale: None,
173 collation: None,
174 }
175 }
176
177 pub fn bigint() -> Self {
179 Self {
180 type_id: 0x26, max_length: Some(8),
182 precision: None,
183 scale: None,
184 collation: None,
185 }
186 }
187
188 pub fn smallint() -> Self {
190 Self {
191 type_id: 0x26, max_length: Some(2),
193 precision: None,
194 scale: None,
195 collation: None,
196 }
197 }
198
199 pub fn tinyint() -> Self {
201 Self {
202 type_id: 0x26, max_length: Some(1),
204 precision: None,
205 scale: None,
206 collation: None,
207 }
208 }
209
210 pub fn bit() -> Self {
212 Self {
213 type_id: 0x68, max_length: Some(1),
215 precision: None,
216 scale: None,
217 collation: None,
218 }
219 }
220
221 pub fn float() -> Self {
223 Self {
224 type_id: 0x6D, max_length: Some(8),
226 precision: None,
227 scale: None,
228 collation: None,
229 }
230 }
231
232 pub fn real() -> Self {
234 Self {
235 type_id: 0x6D, max_length: Some(4),
237 precision: None,
238 scale: None,
239 collation: None,
240 }
241 }
242
243 pub fn nvarchar(max_len: u16) -> Self {
245 Self {
246 type_id: 0xE7, max_length: Some(max_len * 2), precision: None,
249 scale: None,
250 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
252 }
253 }
254
255 pub fn nvarchar_max() -> Self {
257 Self {
258 type_id: 0xE7, max_length: Some(0xFFFF), precision: None,
261 scale: None,
262 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
263 }
264 }
265
266 pub fn varbinary(max_len: u16) -> Self {
268 Self {
269 type_id: 0xA5, max_length: Some(max_len),
271 precision: None,
272 scale: None,
273 collation: None,
274 }
275 }
276
277 pub fn uniqueidentifier() -> Self {
279 Self {
280 type_id: 0x24, max_length: Some(16),
282 precision: None,
283 scale: None,
284 collation: None,
285 }
286 }
287
288 pub fn date() -> Self {
290 Self {
291 type_id: 0x28, max_length: None,
293 precision: None,
294 scale: None,
295 collation: None,
296 }
297 }
298
299 pub fn datetime2(scale: u8) -> Self {
301 Self {
302 type_id: 0x2A, max_length: None,
304 precision: None,
305 scale: Some(scale),
306 collation: None,
307 }
308 }
309
310 pub fn decimal(precision: u8, scale: u8) -> Self {
312 Self {
313 type_id: 0x6C, max_length: Some(17), precision: Some(precision),
316 scale: Some(scale),
317 collation: None,
318 }
319 }
320
321 pub fn encode(&self, buf: &mut BytesMut) {
323 if self.type_id != 0xF3 {
326 buf.put_u8(self.type_id);
327 }
328
329 match self.type_id {
331 0x26 | 0x68 | 0x6D => {
332 if let Some(len) = self.max_length {
334 buf.put_u8(len as u8);
335 }
336 }
337 0xE7 | 0xA5 | 0xEF => {
338 if let Some(len) = self.max_length {
340 buf.put_u16_le(len);
341 }
342 if let Some(collation) = self.collation {
344 buf.put_slice(&collation);
345 }
346 }
347 0x24 => {
348 if let Some(len) = self.max_length {
350 buf.put_u8(len as u8);
351 }
352 }
353 0x29..=0x2B => {
354 if let Some(scale) = self.scale {
356 buf.put_u8(scale);
357 }
358 }
359 0x6C | 0x6A => {
360 if let Some(len) = self.max_length {
362 buf.put_u8(len as u8);
363 }
364 if let Some(precision) = self.precision {
365 buf.put_u8(precision);
366 }
367 if let Some(scale) = self.scale {
368 buf.put_u8(scale);
369 }
370 }
371 _ => {}
372 }
373 }
374}
375
376#[derive(Debug, Clone)]
378pub struct RpcParam {
379 pub name: String,
381 pub flags: ParamFlags,
383 pub type_info: TypeInfo,
385 pub value: Option<Bytes>,
387}
388
389impl RpcParam {
390 pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
392 Self {
393 name: name.into(),
394 flags: ParamFlags::default(),
395 type_info,
396 value: Some(value),
397 }
398 }
399
400 pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
402 Self {
403 name: name.into(),
404 flags: ParamFlags::default(),
405 type_info,
406 value: None,
407 }
408 }
409
410 pub fn int(name: impl Into<String>, value: i32) -> Self {
412 let mut buf = BytesMut::with_capacity(4);
413 buf.put_i32_le(value);
414 Self::new(name, TypeInfo::int(), buf.freeze())
415 }
416
417 pub fn bigint(name: impl Into<String>, value: i64) -> Self {
419 let mut buf = BytesMut::with_capacity(8);
420 buf.put_i64_le(value);
421 Self::new(name, TypeInfo::bigint(), buf.freeze())
422 }
423
424 pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
426 let mut buf = BytesMut::new();
427 for code_unit in value.encode_utf16() {
429 buf.put_u16_le(code_unit);
430 }
431 let char_len = value.chars().count();
432 let type_info = if char_len > 4000 {
433 TypeInfo::nvarchar_max()
434 } else {
435 TypeInfo::nvarchar(char_len.max(1) as u16)
436 };
437 Self::new(name, type_info, buf.freeze())
438 }
439
440 #[must_use]
442 pub fn as_output(mut self) -> Self {
443 self.flags = self.flags.output();
444 self
445 }
446
447 pub fn encode(&self, buf: &mut BytesMut) {
449 let name_len = self.name.encode_utf16().count() as u8;
451 buf.put_u8(name_len);
452 if name_len > 0 {
453 for code_unit in self.name.encode_utf16() {
454 buf.put_u16_le(code_unit);
455 }
456 }
457
458 buf.put_u8(self.flags.encode());
460
461 self.type_info.encode(buf);
463
464 if let Some(ref value) = self.value {
466 match self.type_info.type_id {
468 0x26 => {
469 buf.put_u8(value.len() as u8);
471 buf.put_slice(value);
472 }
473 0x68 | 0x6D => {
474 buf.put_u8(value.len() as u8);
476 buf.put_slice(value);
477 }
478 0xE7 | 0xA5 => {
479 if self.type_info.max_length == Some(0xFFFF) {
481 let total_len = value.len() as u64;
484 buf.put_u64_le(total_len);
485 buf.put_u32_le(value.len() as u32);
486 buf.put_slice(value);
487 buf.put_u32_le(0); } else {
489 buf.put_u16_le(value.len() as u16);
490 buf.put_slice(value);
491 }
492 }
493 0x24 => {
494 buf.put_u8(value.len() as u8);
496 buf.put_slice(value);
497 }
498 0x28 => {
499 buf.put_slice(value);
501 }
502 0x2A => {
503 buf.put_u8(value.len() as u8);
505 buf.put_slice(value);
506 }
507 0x6C => {
508 buf.put_u8(value.len() as u8);
510 buf.put_slice(value);
511 }
512 0xF3 => {
513 buf.put_slice(value);
517 }
518 _ => {
519 buf.put_u8(value.len() as u8);
521 buf.put_slice(value);
522 }
523 }
524 } else {
525 match self.type_info.type_id {
527 0xE7 | 0xA5 => {
528 if self.type_info.max_length == Some(0xFFFF) {
530 buf.put_u64_le(0xFFFFFFFFFFFFFFFF); } else {
532 buf.put_u16_le(0xFFFF);
533 }
534 }
535 _ => {
536 buf.put_u8(0); }
538 }
539 }
540 }
541}
542
543#[derive(Debug, Clone)]
545pub struct RpcRequest {
546 proc_name: Option<String>,
548 proc_id: Option<ProcId>,
550 options: RpcOptionFlags,
552 params: Vec<RpcParam>,
554}
555
556impl RpcRequest {
557 pub fn named(proc_name: impl Into<String>) -> Self {
559 Self {
560 proc_name: Some(proc_name.into()),
561 proc_id: None,
562 options: RpcOptionFlags::default(),
563 params: Vec::new(),
564 }
565 }
566
567 pub fn by_id(proc_id: ProcId) -> Self {
569 Self {
570 proc_name: None,
571 proc_id: Some(proc_id),
572 options: RpcOptionFlags::default(),
573 params: Vec::new(),
574 }
575 }
576
577 pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
595 let mut request = Self::by_id(ProcId::ExecuteSql);
596
597 request.params.push(RpcParam::nvarchar("", sql));
599
600 if !params.is_empty() {
602 let declarations = Self::build_param_declarations(¶ms);
603 request.params.push(RpcParam::nvarchar("", &declarations));
604 }
605
606 request.params.extend(params);
608
609 request
610 }
611
612 fn build_param_declarations(params: &[RpcParam]) -> String {
614 params
615 .iter()
616 .map(|p| {
617 let name = if p.name.starts_with('@') {
618 p.name.clone()
619 } else if p.name.is_empty() {
620 format!(
622 "@p{}",
623 params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
624 )
625 } else {
626 format!("@{}", p.name)
627 };
628
629 let type_name: String = match p.type_info.type_id {
630 0x26 => match p.type_info.max_length {
631 Some(1) => "tinyint".to_string(),
632 Some(2) => "smallint".to_string(),
633 Some(4) => "int".to_string(),
634 Some(8) => "bigint".to_string(),
635 _ => "int".to_string(),
636 },
637 0x68 => "bit".to_string(),
638 0x6D => match p.type_info.max_length {
639 Some(4) => "real".to_string(),
640 _ => "float".to_string(),
641 },
642 0xE7 => {
643 if p.type_info.max_length == Some(0xFFFF) {
644 "nvarchar(max)".to_string()
645 } else {
646 let len = p.type_info.max_length.unwrap_or(4000) / 2;
647 format!("nvarchar({})", len)
648 }
649 }
650 0xA5 => {
651 if p.type_info.max_length == Some(0xFFFF) {
652 "varbinary(max)".to_string()
653 } else {
654 let len = p.type_info.max_length.unwrap_or(8000);
655 format!("varbinary({})", len)
656 }
657 }
658 0x24 => "uniqueidentifier".to_string(),
659 0x28 => "date".to_string(),
660 0x2A => {
661 let scale = p.type_info.scale.unwrap_or(7);
662 format!("datetime2({})", scale)
663 }
664 0x6C => {
665 let precision = p.type_info.precision.unwrap_or(18);
666 let scale = p.type_info.scale.unwrap_or(0);
667 format!("decimal({}, {})", precision, scale)
668 }
669 _ => "sql_variant".to_string(),
670 };
671
672 format!("{} {}", name, type_name)
673 })
674 .collect::<Vec<_>>()
675 .join(", ")
676 }
677
678 pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
680 let mut request = Self::by_id(ProcId::Prepare);
681
682 request
684 .params
685 .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
686
687 let declarations = Self::build_param_declarations(params);
689 request
690 .params
691 .push(RpcParam::nvarchar("@params", &declarations));
692
693 request.params.push(RpcParam::nvarchar("@stmt", sql));
695
696 request.params.push(RpcParam::int("@options", 1));
698
699 request
700 }
701
702 pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
704 let mut request = Self::by_id(ProcId::Execute);
705
706 request.params.push(RpcParam::int("@handle", handle));
708
709 request.params.extend(params);
711
712 request
713 }
714
715 pub fn unprepare(handle: i32) -> Self {
717 let mut request = Self::by_id(ProcId::Unprepare);
718 request.params.push(RpcParam::int("@handle", handle));
719 request
720 }
721
722 #[must_use]
724 pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
725 self.options = options;
726 self
727 }
728
729 #[must_use]
731 pub fn param(mut self, param: RpcParam) -> Self {
732 self.params.push(param);
733 self
734 }
735
736 #[must_use]
740 pub fn encode(&self) -> Bytes {
741 self.encode_with_transaction(0)
742 }
743
744 #[must_use]
756 pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
757 let mut buf = BytesMut::with_capacity(256);
758
759 let all_headers_start = buf.len();
762 buf.put_u32_le(0); buf.put_u32_le(18); buf.put_u16_le(0x0002); buf.put_u64_le(transaction_descriptor); buf.put_u32_le(1); let all_headers_len = buf.len() - all_headers_start;
773 let len_bytes = (all_headers_len as u32).to_le_bytes();
774 buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
775
776 if let Some(proc_id) = self.proc_id {
778 buf.put_u16_le(0xFFFF); buf.put_u16_le(proc_id as u16);
781 } else if let Some(ref proc_name) = self.proc_name {
782 let name_len = proc_name.encode_utf16().count() as u16;
784 buf.put_u16_le(name_len);
785 write_utf16_string(&mut buf, proc_name);
786 }
787
788 buf.put_u16_le(self.options.encode());
790
791 for param in &self.params {
793 param.encode(&mut buf);
794 }
795
796 buf.freeze()
797 }
798}
799
#[cfg(test)]
#[allow(clippy::unwrap_used)] // tests may unwrap freely
mod tests {
    use super::*;

    #[test]
    fn test_proc_id_values() {
        assert_eq!(ProcId::ExecuteSql as u16, 0x000A);
        assert_eq!(ProcId::Prepare as u16, 0x000B);
        assert_eq!(ProcId::Execute as u16, 0x000C);
        assert_eq!(ProcId::Unprepare as u16, 0x000F);
    }

    #[test]
    fn test_option_flags_encode() {
        let flags = RpcOptionFlags::new().with_recompile(true);
        assert_eq!(flags.encode(), 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let flags = ParamFlags::new().output();
        assert_eq!(flags.encode(), 0x01);
    }

    #[test]
    fn test_int_param() {
        let param = RpcParam::int("@p1", 42);
        assert_eq!(param.name, "@p1");
        assert_eq!(param.type_info.type_id, 0x26);
        assert!(param.value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // 5 UTF-16 code units, 2 bytes each.
        assert_eq!(param.value.as_ref().unwrap().len(), 10);
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // statement + declarations + one user parameter
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name nvarchar"));
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_all_headers_and_proc_id_layout() {
        let encoded = RpcRequest::execute_sql("SELECT 1", vec![]).encode_with_transaction(7);
        assert_eq!(&encoded[0..4], &22u32.to_le_bytes()); // ALL_HEADERS total length
        assert_eq!(&encoded[4..8], &18u32.to_le_bytes()); // header length
        assert_eq!(&encoded[8..10], &0x0002u16.to_le_bytes()); // header type
        assert_eq!(&encoded[10..18], &7u64.to_le_bytes()); // transaction descriptor
        assert_eq!(&encoded[18..22], &1u32.to_le_bytes()); // outstanding requests
        assert_eq!(&encoded[22..24], &0xFFFFu16.to_le_bytes()); // proc-by-id marker
        assert_eq!(&encoded[24..26], &(ProcId::ExecuteSql as u16).to_le_bytes());
        assert_eq!(&encoded[26..28], &0u16.to_le_bytes()); // option flags
    }

    #[test]
    fn test_null_int_param_encoding() {
        let mut buf = BytesMut::new();
        RpcParam::null("", TypeInfo::int()).encode(&mut buf);
        // name len, flags, INTN token, max len, NULL (zero length)
        assert_eq!(&buf[..], &[0x00, 0x00, 0x26, 0x04, 0x00]);
    }

    #[test]
    fn test_int_param_encoding() {
        let mut buf = BytesMut::new();
        RpcParam::int("", 42).encode(&mut buf);
        assert_eq!(&buf[..], &[0x00, 0x00, 0x26, 0x04, 0x04, 42, 0, 0, 0]);
    }

    #[test]
    fn test_output_param_name_and_flag() {
        let mut buf = BytesMut::new();
        RpcParam::int("@h", 1).as_output().encode(&mut buf);
        assert_eq!(&buf[..6], &[0x02, b'@', 0x00, b'h', 0x00, 0x01]);
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        assert_eq!(rpc.params.len(), 4);
        // @handle is an OUTPUT parameter
        assert!(rpc.params[0].flags.by_ref);
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        // handle + one user parameter
        assert_eq!(rpc.params.len(), 2);
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        // handle only
        assert_eq!(rpc.params.len(), 1);
    }
}