use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};
use crate::client::Client;
use crate::error::Result;
use crate::state::ConnectionState;
use crate::stream::ProcedureResult;
pub struct ProcedureBuilder<'a, S: ConnectionState> {
client: &'a mut Client<S>,
proc_name: String,
params: Vec<RpcParam>,
}
impl<'a, S: ConnectionState> ProcedureBuilder<'a, S> {
pub(crate) fn new(client: &'a mut Client<S>, proc_name: &str) -> Self {
Self {
client,
proc_name: proc_name.to_string(),
params: Vec::new(),
}
}
pub fn input(&mut self, name: &str, value: &(dyn crate::ToSql + Sync)) -> &mut Self {
match Client::<S>::convert_single_param(
name,
value,
self.client.send_unicode(),
self.client.server_collation(),
) {
Ok(param) => self.params.push(param),
Err(e) => {
tracing::warn!(name = name, error = %e, "failed to convert input parameter");
self.params
.push(RpcParam::null(name, RpcTypeInfo::nvarchar(1)));
}
}
self
}
pub fn output_int(&mut self, name: &str) -> &mut Self {
self.params
.push(RpcParam::null(name, RpcTypeInfo::int()).as_output());
self
}
pub fn output_bigint(&mut self, name: &str) -> &mut Self {
self.params
.push(RpcParam::null(name, RpcTypeInfo::bigint()).as_output());
self
}
pub fn output_nvarchar(&mut self, name: &str, max_len: u16) -> &mut Self {
let type_info = if max_len == 0 {
RpcTypeInfo::nvarchar_max()
} else {
RpcTypeInfo::nvarchar(max_len)
};
self.params
.push(RpcParam::null(name, type_info).as_output());
self
}
pub fn output_bit(&mut self, name: &str) -> &mut Self {
self.params
.push(RpcParam::null(name, RpcTypeInfo::bit()).as_output());
self
}
pub fn output_float(&mut self, name: &str) -> &mut Self {
self.params
.push(RpcParam::null(name, RpcTypeInfo::float()).as_output());
self
}
pub fn output_decimal(&mut self, name: &str, precision: u8, scale: u8) -> &mut Self {
self.params
.push(RpcParam::null(name, RpcTypeInfo::decimal(precision, scale)).as_output());
self
}
pub fn output_raw(&mut self, name: &str, type_info: RpcTypeInfo) -> &mut Self {
self.params
.push(RpcParam::null(name, type_info).as_output());
self
}
pub async fn execute(&mut self) -> Result<ProcedureResult> {
let mut rpc = RpcRequest::named(&self.proc_name);
for param in self.params.drain(..) {
rpc = rpc.param(param);
}
self.client.send_rpc(&rpc).await?;
self.client.read_procedure_result().await
}
}
#[cfg(test)]
// Tests may use `unwrap` freely; a failing unwrap is simply a failing test.
#[allow(clippy::unwrap_used)]
mod tests {
    use tds_protocol::rpc::{RpcParam, TypeInfo as RpcTypeInfo};

    /// Builds a NULL, by-reference output parameter, the same shape the builder's
    /// `output_*` methods push.
    fn null_output(name: &str, type_info: RpcTypeInfo) -> RpcParam {
        RpcParam::null(name, type_info).as_output()
    }

    #[test]
    fn test_output_int_has_by_ref_flag() {
        let param = null_output("@result", RpcTypeInfo::int());
        assert_eq!(param.name, "@result");
        assert!(param.value.is_none());
        assert!(param.flags.by_ref);
    }

    #[test]
    fn test_output_nvarchar_max() {
        let param = null_output("@msg", RpcTypeInfo::nvarchar_max());
        // NVARCHAR(MAX) is encoded with the 0xFFFF length sentinel.
        assert_eq!(param.type_info.max_length, Some(0xFFFF));
        assert!(param.flags.by_ref);
    }

    #[test]
    fn test_output_decimal_precision_scale() {
        let param = null_output("@total", RpcTypeInfo::decimal(18, 2));
        let (precision, scale) = (param.type_info.precision, param.type_info.scale);
        assert_eq!(precision, Some(18));
        assert_eq!(scale, Some(2));
        assert!(param.flags.by_ref);
    }
}