cdbc_mssql/connection/prepare.rs
2use cdbc::decode::Decode;
3use crate::protocol::done::Status;
4use crate::protocol::message::Message;
5use crate::protocol::packet::PacketType;
6use crate::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
7use crate::statement::MssqlStatementMetadata;
8use crate::{Mssql, MssqlArguments, MssqlConnection, MssqlTypeInfo, MssqlValueRef};
9use either::Either;
10use regex::Regex;
11use std::sync::Arc;
12use once_cell::sync::Lazy;
13use cdbc::Error;
14
15pub fn prepare(
16 conn: &mut MssqlConnection,
17 sql: &str,
18) -> Result<Arc<MssqlStatementMetadata>, Error> {
19 if let Some(metadata) = conn.cache_statement.get_mut(sql) {
20 return Ok(metadata.clone());
21 }
22
23 static PARAMS_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"@p[[:alnum:]]+").unwrap());
27
28 let mut params = String::new();
29
30 for m in PARAMS_RE.captures_iter(sql) {
31 if !params.is_empty() {
32 params.push_str(",");
33 }
34
35 params.push_str(&m[0]);
36
37 params.push_str(" nvarchar(1)");
41 }
42
43 let params = if params.is_empty() {
44 None
45 } else {
46 Some(&*params)
47 };
48
49 let mut args = MssqlArguments::default();
50
51 args.declare("", 0_i32);
52 args.add_unnamed(params);
53 args.add_unnamed(sql);
54 args.add_unnamed(0x0001_i32); conn.stream.write_packet(
57 PacketType::Rpc,
58 RpcRequest {
59 transaction_descriptor: conn.stream.transaction_descriptor,
60 arguments: &args,
61 procedure: Either::Right(Procedure::Prepare),
66 options: OptionFlags::empty(),
67 },
68 );
69
70 conn.stream.flush()?;
71 conn.stream.wait_until_ready()?;
72 conn.stream.pending_done_count += 1;
73
74 let mut id: Option<i32> = None;
75
76 loop {
77 let message = conn.stream.recv_message()?;
78
79 match message {
80 Message::DoneProc(done) | Message::Done(done) => {
81 if !done.status.contains(Status::DONE_MORE) {
82 conn.stream.handle_done(&done);
84 break;
85 }
86 }
87
88 Message::ReturnValue(rv) => {
89 id = <i32 as Decode<Mssql>>::decode(MssqlValueRef {
90 data: rv.value.as_ref(),
91 type_info: MssqlTypeInfo(rv.type_info),
92 })
93 .ok();
94 }
95
96 _ => {}
97 }
98 }
99
100 if let Some(id) = id {
101 let mut args = MssqlArguments::default();
102 args.add_unnamed(id);
103
104 conn.stream.write_packet(
105 PacketType::Rpc,
106 RpcRequest {
107 transaction_descriptor: conn.stream.transaction_descriptor,
108 arguments: &args,
109 procedure: Either::Right(Procedure::Unprepare),
110 options: OptionFlags::empty(),
111 },
112 );
113
114 conn.stream.flush()?;
115 conn.stream.wait_until_ready()?;
116 conn.stream.pending_done_count += 1;
117
118 loop {
119 let message = conn.stream.recv_message()?;
120
121 match message {
122 Message::DoneProc(done) | Message::Done(done) => {
123 if !done.status.contains(Status::DONE_MORE) {
124 conn.stream.handle_done(&done);
126 break;
127 }
128 }
129
130 _ => {}
131 }
132 }
133 }
134
135 let metadata = Arc::new(MssqlStatementMetadata {
136 columns: conn.stream.columns.as_ref().clone(),
137 column_names: conn.stream.column_names.as_ref().clone(),
138 });
139
140 conn.cache_statement.insert(sql, metadata.clone());
141
142 Ok(metadata)
143}