1use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29
/// Well-known system stored procedures that an RPC request can reference by
/// numeric id instead of by name.
///
/// When a request uses one of these, the encoder writes `0xFFFF` followed by
/// the id as a little-endian `u16`, so the discriminants below are exactly the
/// values that go on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ProcId {
    /// `sp_cursor`
    Cursor = 1,
    /// `sp_cursoropen`
    CursorOpen = 2,
    /// `sp_cursorprepare`
    CursorPrepare = 3,
    /// `sp_cursorexecute`
    CursorExecute = 4,
    /// `sp_cursorprepexec`
    CursorPrepExec = 5,
    /// `sp_cursorunprepare`
    CursorUnprepare = 6,
    /// `sp_cursorfetch`
    CursorFetch = 7,
    /// `sp_cursoroption`
    CursorOption = 8,
    /// `sp_cursorclose`
    CursorClose = 9,
    /// `sp_executesql`
    ExecuteSql = 10,
    /// `sp_prepare`
    Prepare = 11,
    /// `sp_execute`
    Execute = 12,
    /// `sp_prepexec`
    PrepExec = 13,
    /// `sp_prepexecrpc`
    PrepExecRpc = 14,
    /// `sp_unprepare`
    Unprepare = 15,
}
68
/// Option flags carried in the `OptionFlags` word of an RPC request.
///
/// Each field maps to one bit of the `u16` produced by `encode`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// Bit 0 (`0x0001`), `fWithRecomp` in MS-TDS.
    pub with_recompile: bool,
    /// Bit 1 (`0x0002`), `fNoMetaData` in MS-TDS.
    pub no_metadata: bool,
    /// Bit 2 (`0x0004`), `fReuseMetaData` in MS-TDS.
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Returns a flag set with every option cleared (identical to `Default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter: returns a copy with `with_recompile` set to `value`.
    #[must_use]
    pub fn with_recompile(self, value: bool) -> Self {
        Self {
            with_recompile: value,
            ..self
        }
    }

    /// Packs the flags into the option word: bit 0 = recompile,
    /// bit 1 = no-metadata, bit 2 = reuse-metadata. All other bits are zero.
    pub fn encode(&self) -> u16 {
        let recompile = u16::from(self.with_recompile);
        let no_meta = u16::from(self.no_metadata) << 1;
        let reuse_meta = u16::from(self.reuse_metadata) << 2;
        recompile | no_meta | reuse_meta
    }
}
108
/// Per-parameter status flags (the `StatusFlags` byte of an RPC parameter).
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// Bit `0x01` (`fByRefValue`): the parameter is passed by reference (output).
    pub by_ref: bool,
    /// Bit `0x02` (`fDefaultValue`).
    pub default: bool,
    /// Bit `0x08` (`fEncrypted`).
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns a flag set with every bit cleared (identical to `Default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style helper marking the parameter as by-reference (output).
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into the status byte. Bit `0x04` and the high bits
    /// are never set.
    pub fn encode(&self) -> u8 {
        let table: [(bool, u8); 3] = [
            (self.by_ref, 0x01),
            (self.default, 0x02),
            (self.encrypted, 0x08),
        ];
        table
            .iter()
            .filter(|(enabled, _)| *enabled)
            .fold(0, |acc, &(_, mask)| acc | mask)
    }
}
148
149#[derive(Debug, Clone)]
151pub struct TypeInfo {
152 pub type_id: u8,
154 pub max_length: Option<u16>,
156 pub precision: Option<u8>,
158 pub scale: Option<u8>,
160 pub collation: Option<[u8; 5]>,
162}
163
164impl TypeInfo {
165 pub fn int() -> Self {
167 Self {
168 type_id: 0x26, max_length: Some(4),
170 precision: None,
171 scale: None,
172 collation: None,
173 }
174 }
175
176 pub fn bigint() -> Self {
178 Self {
179 type_id: 0x26, max_length: Some(8),
181 precision: None,
182 scale: None,
183 collation: None,
184 }
185 }
186
187 pub fn smallint() -> Self {
189 Self {
190 type_id: 0x26, max_length: Some(2),
192 precision: None,
193 scale: None,
194 collation: None,
195 }
196 }
197
198 pub fn tinyint() -> Self {
200 Self {
201 type_id: 0x26, max_length: Some(1),
203 precision: None,
204 scale: None,
205 collation: None,
206 }
207 }
208
209 pub fn bit() -> Self {
211 Self {
212 type_id: 0x68, max_length: Some(1),
214 precision: None,
215 scale: None,
216 collation: None,
217 }
218 }
219
220 pub fn float() -> Self {
222 Self {
223 type_id: 0x6D, max_length: Some(8),
225 precision: None,
226 scale: None,
227 collation: None,
228 }
229 }
230
231 pub fn real() -> Self {
233 Self {
234 type_id: 0x6D, max_length: Some(4),
236 precision: None,
237 scale: None,
238 collation: None,
239 }
240 }
241
242 pub fn nvarchar(max_len: u16) -> Self {
244 Self {
245 type_id: 0xE7, max_length: Some(max_len * 2), precision: None,
248 scale: None,
249 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
251 }
252 }
253
254 pub fn nvarchar_max() -> Self {
256 Self {
257 type_id: 0xE7, max_length: Some(0xFFFF), precision: None,
260 scale: None,
261 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
262 }
263 }
264
265 pub fn varbinary(max_len: u16) -> Self {
267 Self {
268 type_id: 0xA5, max_length: Some(max_len),
270 precision: None,
271 scale: None,
272 collation: None,
273 }
274 }
275
276 pub fn uniqueidentifier() -> Self {
278 Self {
279 type_id: 0x24, max_length: Some(16),
281 precision: None,
282 scale: None,
283 collation: None,
284 }
285 }
286
287 pub fn date() -> Self {
289 Self {
290 type_id: 0x28, max_length: None,
292 precision: None,
293 scale: None,
294 collation: None,
295 }
296 }
297
298 pub fn datetime2(scale: u8) -> Self {
300 Self {
301 type_id: 0x2A, max_length: None,
303 precision: None,
304 scale: Some(scale),
305 collation: None,
306 }
307 }
308
309 pub fn decimal(precision: u8, scale: u8) -> Self {
311 Self {
312 type_id: 0x6C, max_length: Some(17), precision: Some(precision),
315 scale: Some(scale),
316 collation: None,
317 }
318 }
319
320 pub fn encode(&self, buf: &mut BytesMut) {
322 if self.type_id != 0xF3 {
325 buf.put_u8(self.type_id);
326 }
327
328 match self.type_id {
330 0x26 | 0x68 | 0x6D => {
331 if let Some(len) = self.max_length {
333 buf.put_u8(len as u8);
334 }
335 }
336 0xE7 | 0xA5 | 0xEF => {
337 if let Some(len) = self.max_length {
339 buf.put_u16_le(len);
340 }
341 if let Some(collation) = self.collation {
343 buf.put_slice(&collation);
344 }
345 }
346 0x24 => {
347 if let Some(len) = self.max_length {
349 buf.put_u8(len as u8);
350 }
351 }
352 0x29..=0x2B => {
353 if let Some(scale) = self.scale {
355 buf.put_u8(scale);
356 }
357 }
358 0x6C | 0x6A => {
359 if let Some(len) = self.max_length {
361 buf.put_u8(len as u8);
362 }
363 if let Some(precision) = self.precision {
364 buf.put_u8(precision);
365 }
366 if let Some(scale) = self.scale {
367 buf.put_u8(scale);
368 }
369 }
370 _ => {}
371 }
372 }
373}
374
375#[derive(Debug, Clone)]
377pub struct RpcParam {
378 pub name: String,
380 pub flags: ParamFlags,
382 pub type_info: TypeInfo,
384 pub value: Option<Bytes>,
386}
387
388impl RpcParam {
389 pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
391 Self {
392 name: name.into(),
393 flags: ParamFlags::default(),
394 type_info,
395 value: Some(value),
396 }
397 }
398
399 pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
401 Self {
402 name: name.into(),
403 flags: ParamFlags::default(),
404 type_info,
405 value: None,
406 }
407 }
408
409 pub fn int(name: impl Into<String>, value: i32) -> Self {
411 let mut buf = BytesMut::with_capacity(4);
412 buf.put_i32_le(value);
413 Self::new(name, TypeInfo::int(), buf.freeze())
414 }
415
416 pub fn bigint(name: impl Into<String>, value: i64) -> Self {
418 let mut buf = BytesMut::with_capacity(8);
419 buf.put_i64_le(value);
420 Self::new(name, TypeInfo::bigint(), buf.freeze())
421 }
422
423 pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
425 let mut buf = BytesMut::new();
426 for code_unit in value.encode_utf16() {
428 buf.put_u16_le(code_unit);
429 }
430 let char_len = value.chars().count();
431 let type_info = if char_len > 4000 {
432 TypeInfo::nvarchar_max()
433 } else {
434 TypeInfo::nvarchar(char_len.max(1) as u16)
435 };
436 Self::new(name, type_info, buf.freeze())
437 }
438
439 #[must_use]
441 pub fn as_output(mut self) -> Self {
442 self.flags = self.flags.output();
443 self
444 }
445
446 pub fn encode(&self, buf: &mut BytesMut) {
448 let name_len = self.name.encode_utf16().count() as u8;
450 buf.put_u8(name_len);
451 if name_len > 0 {
452 for code_unit in self.name.encode_utf16() {
453 buf.put_u16_le(code_unit);
454 }
455 }
456
457 buf.put_u8(self.flags.encode());
459
460 self.type_info.encode(buf);
462
463 if let Some(ref value) = self.value {
465 match self.type_info.type_id {
467 0x26 => {
468 buf.put_u8(value.len() as u8);
470 buf.put_slice(value);
471 }
472 0x68 | 0x6D => {
473 buf.put_u8(value.len() as u8);
475 buf.put_slice(value);
476 }
477 0xE7 | 0xA5 => {
478 if self.type_info.max_length == Some(0xFFFF) {
480 let total_len = value.len() as u64;
483 buf.put_u64_le(total_len);
484 buf.put_u32_le(value.len() as u32);
485 buf.put_slice(value);
486 buf.put_u32_le(0); } else {
488 buf.put_u16_le(value.len() as u16);
489 buf.put_slice(value);
490 }
491 }
492 0x24 => {
493 buf.put_u8(value.len() as u8);
495 buf.put_slice(value);
496 }
497 0x28 => {
498 buf.put_slice(value);
500 }
501 0x2A => {
502 buf.put_u8(value.len() as u8);
504 buf.put_slice(value);
505 }
506 0x6C => {
507 buf.put_u8(value.len() as u8);
509 buf.put_slice(value);
510 }
511 0xF3 => {
512 buf.put_slice(value);
516 }
517 _ => {
518 buf.put_u8(value.len() as u8);
520 buf.put_slice(value);
521 }
522 }
523 } else {
524 match self.type_info.type_id {
526 0xE7 | 0xA5 => {
527 if self.type_info.max_length == Some(0xFFFF) {
529 buf.put_u64_le(0xFFFFFFFFFFFFFFFF); } else {
531 buf.put_u16_le(0xFFFF);
532 }
533 }
534 _ => {
535 buf.put_u8(0); }
537 }
538 }
539 }
540}
541
542#[derive(Debug, Clone)]
544pub struct RpcRequest {
545 proc_name: Option<String>,
547 proc_id: Option<ProcId>,
549 options: RpcOptionFlags,
551 params: Vec<RpcParam>,
553}
554
555impl RpcRequest {
556 pub fn named(proc_name: impl Into<String>) -> Self {
558 Self {
559 proc_name: Some(proc_name.into()),
560 proc_id: None,
561 options: RpcOptionFlags::default(),
562 params: Vec::new(),
563 }
564 }
565
566 pub fn by_id(proc_id: ProcId) -> Self {
568 Self {
569 proc_name: None,
570 proc_id: Some(proc_id),
571 options: RpcOptionFlags::default(),
572 params: Vec::new(),
573 }
574 }
575
576 pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
594 let mut request = Self::by_id(ProcId::ExecuteSql);
595
596 request.params.push(RpcParam::nvarchar("", sql));
598
599 if !params.is_empty() {
601 let declarations = Self::build_param_declarations(¶ms);
602 request.params.push(RpcParam::nvarchar("", &declarations));
603 }
604
605 request.params.extend(params);
607
608 request
609 }
610
611 fn build_param_declarations(params: &[RpcParam]) -> String {
613 params
614 .iter()
615 .map(|p| {
616 let name = if p.name.starts_with('@') {
617 p.name.clone()
618 } else if p.name.is_empty() {
619 format!(
621 "@p{}",
622 params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
623 )
624 } else {
625 format!("@{}", p.name)
626 };
627
628 let type_name: String = match p.type_info.type_id {
629 0x26 => match p.type_info.max_length {
630 Some(1) => "tinyint".to_string(),
631 Some(2) => "smallint".to_string(),
632 Some(4) => "int".to_string(),
633 Some(8) => "bigint".to_string(),
634 _ => "int".to_string(),
635 },
636 0x68 => "bit".to_string(),
637 0x6D => match p.type_info.max_length {
638 Some(4) => "real".to_string(),
639 _ => "float".to_string(),
640 },
641 0xE7 => {
642 if p.type_info.max_length == Some(0xFFFF) {
643 "nvarchar(max)".to_string()
644 } else {
645 let len = p.type_info.max_length.unwrap_or(4000) / 2;
646 format!("nvarchar({})", len)
647 }
648 }
649 0xA5 => {
650 if p.type_info.max_length == Some(0xFFFF) {
651 "varbinary(max)".to_string()
652 } else {
653 let len = p.type_info.max_length.unwrap_or(8000);
654 format!("varbinary({})", len)
655 }
656 }
657 0x24 => "uniqueidentifier".to_string(),
658 0x28 => "date".to_string(),
659 0x2A => {
660 let scale = p.type_info.scale.unwrap_or(7);
661 format!("datetime2({})", scale)
662 }
663 0x6C => {
664 let precision = p.type_info.precision.unwrap_or(18);
665 let scale = p.type_info.scale.unwrap_or(0);
666 format!("decimal({}, {})", precision, scale)
667 }
668 _ => "sql_variant".to_string(),
669 };
670
671 format!("{} {}", name, type_name)
672 })
673 .collect::<Vec<_>>()
674 .join(", ")
675 }
676
677 pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
679 let mut request = Self::by_id(ProcId::Prepare);
680
681 request
683 .params
684 .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
685
686 let declarations = Self::build_param_declarations(params);
688 request
689 .params
690 .push(RpcParam::nvarchar("@params", &declarations));
691
692 request.params.push(RpcParam::nvarchar("@stmt", sql));
694
695 request.params.push(RpcParam::int("@options", 1));
697
698 request
699 }
700
701 pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
703 let mut request = Self::by_id(ProcId::Execute);
704
705 request.params.push(RpcParam::int("@handle", handle));
707
708 request.params.extend(params);
710
711 request
712 }
713
714 pub fn unprepare(handle: i32) -> Self {
716 let mut request = Self::by_id(ProcId::Unprepare);
717 request.params.push(RpcParam::int("@handle", handle));
718 request
719 }
720
721 #[must_use]
723 pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
724 self.options = options;
725 self
726 }
727
728 #[must_use]
730 pub fn param(mut self, param: RpcParam) -> Self {
731 self.params.push(param);
732 self
733 }
734
735 #[must_use]
739 pub fn encode(&self) -> Bytes {
740 self.encode_with_transaction(0)
741 }
742
743 #[must_use]
755 pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
756 let mut buf = BytesMut::with_capacity(256);
757
758 let all_headers_start = buf.len();
761 buf.put_u32_le(0); buf.put_u32_le(18); buf.put_u16_le(0x0002); buf.put_u64_le(transaction_descriptor); buf.put_u32_le(1); let all_headers_len = buf.len() - all_headers_start;
772 let len_bytes = (all_headers_len as u32).to_le_bytes();
773 buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
774
775 if let Some(proc_id) = self.proc_id {
777 buf.put_u16_le(0xFFFF); buf.put_u16_le(proc_id as u16);
780 } else if let Some(ref proc_name) = self.proc_name {
781 let name_len = proc_name.encode_utf16().count() as u16;
783 buf.put_u16_le(name_len);
784 write_utf16_string(&mut buf, proc_name);
785 }
786
787 buf.put_u16_le(self.options.encode());
789
790 for param in &self.params {
792 param.encode(&mut buf);
793 }
794
795 buf.freeze()
796 }
797}
798
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_proc_id_values() {
        let cases = [
            (ProcId::ExecuteSql, 0x000A_u16),
            (ProcId::Prepare, 0x000B),
            (ProcId::Execute, 0x000C),
            (ProcId::Unprepare, 0x000F),
        ];
        for &(id, wire) in &cases {
            assert_eq!(id as u16, wire);
        }
    }

    #[test]
    fn test_option_flags_encode() {
        let encoded = RpcOptionFlags::new().with_recompile(true).encode();
        assert_eq!(encoded, 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let encoded = ParamFlags::new().output().encode();
        assert_eq!(encoded, 0x01);
    }

    #[test]
    fn test_int_param() {
        let RpcParam {
            name,
            type_info,
            value,
            ..
        } = RpcParam::int("@p1", 42);
        assert_eq!(name, "@p1");
        assert_eq!(type_info.type_id, 0x26);
        assert!(value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // "Alice" is five UTF-16 code units of two bytes each.
        let byte_len = param.value.as_ref().map(|v| v.len()).unwrap();
        assert_eq!(byte_len, 10);
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        // Statement, declaration string, then the single user parameter.
        assert_eq!(
            (rpc.proc_id, rpc.params.len()),
            (Some(ProcId::ExecuteSql), 3)
        );
    }

    #[test]
    fn test_param_declarations() {
        let params = [
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        for fragment in ["@p1 int", "@name nvarchar"].iter() {
            assert!(decls.contains(fragment));
        }
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let encoded = RpcRequest::execute_sql("SELECT 1", vec![]).encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        assert_eq!(rpc.params.len(), 4);
        // The leading handle parameter must be an output parameter.
        assert!(rpc.params.first().unwrap().flags.by_ref);
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        // Handle plus the single user parameter.
        assert_eq!(rpc.params.len(), 2);
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        // Only the handle.
        assert_eq!(rpc.params.len(), 1);
    }
}