// cdbc_mssql/connection/prepare.rs

use std::collections::HashSet;
use std::sync::Arc;

use cdbc::decode::Decode;
use cdbc::Error;
use either::Either;
use once_cell::sync::Lazy;
use regex::Regex;

use crate::protocol::done::Status;
use crate::protocol::message::Message;
use crate::protocol::packet::PacketType;
use crate::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
use crate::statement::MssqlStatementMetadata;
use crate::{Mssql, MssqlArguments, MssqlConnection, MssqlTypeInfo, MssqlValueRef};

15pub fn prepare(
16    conn: &mut MssqlConnection,
17    sql: &str,
18) -> Result<Arc<MssqlStatementMetadata>, Error> {
19    if let Some(metadata) = conn.cache_statement.get_mut(sql) {
20        return Ok(metadata.clone());
21    }
22
23    // NOTE: this does not support unicode identifiers; as we don't even support
24    //       named parameters (yet) this is probably fine, for now
25
26    static PARAMS_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"@p[[:alnum:]]+").unwrap());
27
28    let mut params = String::new();
29
30    for m in PARAMS_RE.captures_iter(sql) {
31        if !params.is_empty() {
32            params.push_str(",");
33        }
34
35        params.push_str(&m[0]);
36
37        // NOTE: this means that a query! of `SELECT @p1` will have the macros believe
38        //       it will return nvarchar(1); this is a greater issue with `query!` that we
39        //       we need to circle back to. This doesn't happen much in practice however.
40        params.push_str(" nvarchar(1)");
41    }
42
43    let params = if params.is_empty() {
44        None
45    } else {
46        Some(&*params)
47    };
48
49    let mut args = MssqlArguments::default();
50
51    args.declare("", 0_i32);
52    args.add_unnamed(params);
53    args.add_unnamed(sql);
54    args.add_unnamed(0x0001_i32); // 1 = SEND_METADATA
55
56    conn.stream.write_packet(
57        PacketType::Rpc,
58        RpcRequest {
59            transaction_descriptor: conn.stream.transaction_descriptor,
60            arguments: &args,
61            // [sp_prepare] will emit the column meta data
62            // small issue is that we need to declare all the used placeholders with a "fallback" type
63            // we currently use regex to collect them; false positives are *okay* but false
64            // negatives would break the query
65            procedure: Either::Right(Procedure::Prepare),
66            options: OptionFlags::empty(),
67        },
68    );
69
70    conn.stream.flush()?;
71    conn.stream.wait_until_ready()?;
72    conn.stream.pending_done_count += 1;
73
74    let mut id: Option<i32> = None;
75
76    loop {
77        let message = conn.stream.recv_message()?;
78
79        match message {
80            Message::DoneProc(done) | Message::Done(done) => {
81                if !done.status.contains(Status::DONE_MORE) {
82                    // done with prepare
83                    conn.stream.handle_done(&done);
84                    break;
85                }
86            }
87
88            Message::ReturnValue(rv) => {
89                id = <i32 as Decode<Mssql>>::decode(MssqlValueRef {
90                    data: rv.value.as_ref(),
91                    type_info: MssqlTypeInfo(rv.type_info),
92                })
93                .ok();
94            }
95
96            _ => {}
97        }
98    }
99
100    if let Some(id) = id {
101        let mut args = MssqlArguments::default();
102        args.add_unnamed(id);
103
104        conn.stream.write_packet(
105            PacketType::Rpc,
106            RpcRequest {
107                transaction_descriptor: conn.stream.transaction_descriptor,
108                arguments: &args,
109                procedure: Either::Right(Procedure::Unprepare),
110                options: OptionFlags::empty(),
111            },
112        );
113
114        conn.stream.flush()?;
115        conn.stream.wait_until_ready()?;
116        conn.stream.pending_done_count += 1;
117
118        loop {
119            let message = conn.stream.recv_message()?;
120
121            match message {
122                Message::DoneProc(done) | Message::Done(done) => {
123                    if !done.status.contains(Status::DONE_MORE) {
124                        // done with unprepare
125                        conn.stream.handle_done(&done);
126                        break;
127                    }
128                }
129
130                _ => {}
131            }
132        }
133    }
134
135    let metadata = Arc::new(MssqlStatementMetadata {
136        columns: conn.stream.columns.as_ref().clone(),
137        column_names: conn.stream.column_names.as_ref().clone(),
138    });
139
140    conn.cache_statement.insert(sql, metadata.clone());
141
142    Ok(metadata)
143}