mod auth;
mod config;
mod connection;
mod tls;
#[cfg(any(
feature = "rustls",
feature = "native-tls",
feature = "vendored-openssl"
))]
mod tls_stream;
pub use auth::*;
pub use config::*;
pub(crate) use connection::*;
use crate::tds::stream::ReceivedToken;
use crate::{
result::ExecuteResult,
tds::{
codec::{self, IteratorJoin},
stream::{QueryStream, TokenStream},
},
BulkLoadRequest, ColumnFlag, SqlReadBytes, ToSql,
};
use codec::{BatchRequest, ColumnData, PacketHeader, RpcParam, RpcProcId, TokenRpcRequest};
use enumflags2::BitFlags;
use futures_util::io::{AsyncRead, AsyncWrite};
use futures_util::stream::TryStreamExt;
use std::{borrow::Cow, fmt::Debug};
/// Client handle that owns a single [`Connection`] over the stream type `S`.
///
/// All query methods on this type go through the wrapped connection.
#[derive(Debug)]
pub struct Client<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// The underlying connection every request is sent through.
    pub(crate) connection: Connection<S>,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
/// Builds a `Client` by establishing a [`Connection`] from `config` over
/// the already-opened `tcp_stream`.
///
/// # Errors
///
/// Returns whatever error `Connection::connect` reports.
pub async fn connect(config: Config, tcp_stream: S) -> crate::Result<Client<S>> {
    let connection = Connection::connect(config, tcp_stream).await?;
    Ok(Client { connection })
}
/// Runs `query` with the given `params` and returns an [`ExecuteResult`]
/// built from the connection.
///
/// The connection's stream is flushed first; the statement is then sent as
/// an `RpcProcId::ExecuteSQL` request, with every parameter converted
/// through `ToSql::to_sql`.
///
/// # Errors
///
/// Fails if flushing, sending the request, or building the result fails.
pub async fn execute<'a>(
    &mut self,
    query: impl Into<Cow<'a, str>>,
    params: &[&dyn ToSql],
) -> crate::Result<ExecuteResult> {
    self.connection.flush_stream().await?;

    let statement = Self::rpc_params(query);
    let bound_values = params.iter().map(|value| value.to_sql());
    self.rpc_perform_query(RpcProcId::ExecuteSQL, statement, bound_values)
        .await?;

    ExecuteResult::new(&mut self.connection).await
}
/// Runs `query` with the given `params` and returns a [`QueryStream`]
/// reading from this client's connection.
///
/// The connection's stream is flushed first; the statement is then sent as
/// an `RpcProcId::ExecuteSQL` request. Before returning, the stream is
/// advanced with `forward_to_metadata`.
///
/// # Errors
///
/// Fails if flushing, sending the request, or forwarding to the metadata
/// fails.
pub async fn query<'a, 'b>(
    &'a mut self,
    query: impl Into<Cow<'b, str>>,
    params: &'b [&'b dyn ToSql],
) -> crate::Result<QueryStream<'a>>
where
    'a: 'b,
{
    self.connection.flush_stream().await?;

    let statement = Self::rpc_params(query);
    let bound_values = params.iter().map(|value| value.to_sql());
    self.rpc_perform_query(RpcProcId::ExecuteSQL, statement, bound_values)
        .await?;

    let tokens = TokenStream::new(&mut self.connection).try_unfold();
    let mut rows = QueryStream::new(tokens);
    rows.forward_to_metadata().await?;
    Ok(rows)
}
/// Sends `query` as a batch request (no bound parameters) and returns a
/// [`QueryStream`] reading from this client's connection.
///
/// The connection's stream is flushed first. Before returning, the stream
/// is advanced with `forward_to_metadata`.
///
/// # Errors
///
/// Fails if flushing, sending the batch, or forwarding to the metadata
/// fails.
pub async fn simple_query<'a, 'b>(
    &'a mut self,
    query: impl Into<Cow<'b, str>>,
) -> crate::Result<QueryStream<'a>>
where
    'a: 'b,
{
    self.connection.flush_stream().await?;

    let descriptor = self.connection.context().transaction_descriptor();
    let batch = BatchRequest::new(query, descriptor);
    let packet_id = self.connection.context_mut().next_packet_id();
    self.connection
        .send(PacketHeader::batch(packet_id), batch)
        .await?;

    let tokens = TokenStream::new(&mut self.connection).try_unfold();
    let mut rows = QueryStream::new(tokens);
    rows.forward_to_metadata().await?;
    Ok(rows)
}
/// Prepares a bulk load into `table` and returns the [`BulkLoadRequest`]
/// used to feed it.
///
/// Steps performed:
/// 1. `SELECT TOP 0 * FROM <table>` is sent to obtain the column metadata.
/// 2. Only columns whose flags contain `ColumnFlag::Updateable` are kept.
/// 3. `INSERT BULK <table> (<columns>)` is sent and its response drained.
///
/// Note that `table` is interpolated into both SQL statements verbatim.
///
/// # Errors
///
/// Returns `Error::Protocol` when the metadata query produces no result
/// set metadata, and otherwise propagates any connection error.
pub async fn bulk_insert<'a>(
    &'a mut self,
    table: &'a str,
) -> crate::Result<BulkLoadRequest<'a, S>> {
    self.connection.flush_stream().await?;

    // Zero-row select: only the column metadata of the response is needed.
    let probe_sql = format!("SELECT TOP 0 * FROM {}", table);
    let probe_descriptor = self.connection.context().transaction_descriptor();
    let probe = BatchRequest::new(probe_sql, probe_descriptor);
    let probe_id = self.connection.context_mut().next_packet_id();
    self.connection
        .send(PacketHeader::batch(probe_id), probe)
        .await?;

    // Walk the whole response, remembering the columns of the most recent
    // result set metadata token.
    let discovered = TokenStream::new(&mut self.connection)
        .try_unfold()
        .try_fold(None, |latest, token| async move {
            Ok(match token {
                ReceivedToken::NewResultset(metadata) => Some(metadata.columns.clone()),
                _ => latest,
            })
        })
        .await?;

    let columns: Vec<_> = discovered
        .ok_or_else(|| {
            crate::Error::Protocol("expecting column metadata from query but not found".into())
        })?
        .into_iter()
        .filter(|candidate| candidate.base.flags.contains(ColumnFlag::Updateable))
        .collect();

    self.connection.flush_stream().await?;

    let column_list = columns
        .iter()
        .map(|column| format!("{}", column))
        .join(", ");
    let insert_sql = format!("INSERT BULK {} ({})", table, column_list);
    let insert_descriptor = self.connection.context().transaction_descriptor();
    let insert = BatchRequest::new(insert_sql, insert_descriptor);
    let insert_id = self.connection.context_mut().next_packet_id();
    self.connection
        .send(PacketHeader::batch(insert_id), insert)
        .await?;

    // Drain the response to the INSERT BULK statement before handing the
    // connection over to the bulk load request.
    TokenStream::new(&mut self.connection).flush_done().await?;

    BulkLoadRequest::new(&mut self.connection, columns)
}
pub async fn close(self) -> crate::Result<()> {
self.connection.close().await
}
/// Builds the two leading RPC parameters for a parameterized statement.
///
/// The first entry, `stmt`, carries the query text. The second, `params`,
/// starts out as the integer `0`; `rpc_perform_query` later overwrites its
/// value with the parameter declaration string.
pub(crate) fn rpc_params<'a>(query: impl Into<Cow<'a, str>>) -> Vec<RpcParam<'a>> {
    let statement = RpcParam {
        name: Cow::Borrowed("stmt"),
        flags: BitFlags::empty(),
        value: ColumnData::String(Some(query.into())),
    };
    let declarations = RpcParam {
        name: Cow::Borrowed("params"),
        flags: BitFlags::empty(),
        value: ColumnData::I32(Some(0)),
    };

    vec![statement, declarations]
}
/// Appends the bound values to `rpc_params` and sends the RPC request.
///
/// Each value from `params` is pushed as an additional [`RpcParam`] named
/// `@P1`, `@P2`, ... (numbered from 1). While doing so, a declaration
/// string of the form `@P1 <type>,@P2 <type>` is assembled from each
/// value's `type_name()`. If `rpc_params` contains an entry named
/// `params`, its value is replaced by that declaration string.
///
/// # Errors
///
/// Propagates any error returned by sending the request on the connection.
pub(crate) async fn rpc_perform_query<'a, 'b>(
    &'a mut self,
    proc_id: RpcProcId,
    mut rpc_params: Vec<RpcParam<'b>>,
    params: impl Iterator<Item = ColumnData<'b>>,
) -> crate::Result<()>
where
    'a: 'b,
{
    let mut param_str = String::new();

    for (i, param) in params.enumerate() {
        // Declarations are comma-separated, with no leading separator.
        if i > 0 {
            param_str.push(',');
        }
        param_str.push_str(&format!("@P{} ", i + 1));
        param_str.push_str(&param.type_name());

        rpc_params.push(RpcParam {
            name: Cow::Owned(format!("@P{}", i + 1)),
            flags: BitFlags::empty(),
            value: param,
        });
    }

    // Swap the placeholder value of the `params` entry for the real
    // declaration string now that every value has been seen.
    if let Some(params) = rpc_params.iter_mut().find(|x| x.name == "params") {
        params.value = ColumnData::String(Some(param_str.into()));
    }

    let req = TokenRpcRequest::new(
        proc_id,
        rpc_params,
        self.connection.context().transaction_descriptor(),
    );
    let id = self.connection.context_mut().next_packet_id();
    self.connection.send(PacketHeader::rpc(id), req).await?;

    Ok(())
}
}