tiberius/client.rs
1mod auth;
2mod config;
3mod connection;
4
5#[cfg(all(windows, feature = "winauth"))]
6mod sspi;
7mod tls;
8#[cfg(any(
9 feature = "rustls",
10 feature = "native-tls",
11 feature = "vendored-openssl"
12))]
13mod tls_stream;
14
15pub use auth::*;
16pub use config::*;
17pub(crate) use connection::*;
18
19use crate::tds::stream::ReceivedToken;
20use crate::{
21 result::ExecuteResult,
22 tds::{
23 codec::{self, IteratorJoin},
24 stream::{QueryStream, TokenStream},
25 },
26 BulkLoadColumns, BulkLoadRequest, ColumnFlag, SqlReadBytes, ToSql,
27};
28use codec::{BatchRequest, ColumnData, PacketHeader, RpcParam, RpcProcId, TokenRpcRequest};
29use enumflags2::BitFlags;
30use futures_util::io::{AsyncRead, AsyncWrite};
31use futures_util::stream::TryStreamExt;
32use std::{borrow::Cow, fmt::Debug};
33
34/// `Client` is the main entry point to the SQL Server, providing query
35/// execution capabilities.
36///
37/// A `Client` is created using the [`Config`], defining the needed
38/// connection options and capabilities.
39///
40/// # Example
41///
42/// ```no_run
43/// # use tiberius::{Config, AuthMethod};
44/// use tokio_util::compat::TokioAsyncWriteCompatExt;
45///
46/// # #[tokio::main]
47/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
48/// let mut config = Config::new();
49///
50/// config.host("0.0.0.0");
51/// config.port(1433);
52/// config.authentication(AuthMethod::sql_server("SA", "<Mys3cureP4ssW0rD>"));
53///
54/// let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
55/// tcp.set_nodelay(true)?;
56/// // Client is ready to use.
57/// let client = tiberius::Client::connect(config, tcp.compat_write()).await?;
58/// # Ok(())
59/// # }
60/// ```
61///
62/// [`Config`]: struct.Config.html
63#[derive(Debug)]
64pub struct Client<S: AsyncRead + AsyncWrite + Unpin + Send> {
65 pub(crate) connection: Connection<S>,
66}
67
68impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
69 /// Uses an instance of [`Config`] to specify the connection
70 /// options required to connect to the database using an established
71 /// tcp connection
72 ///
73 /// [`Config`]: struct.Config.html
74 pub async fn connect(config: Config, tcp_stream: S) -> crate::Result<Client<S>> {
75 Ok(Client {
76 connection: Connection::connect(config, tcp_stream).await?,
77 })
78 }
79
80 /// Executes SQL statements in the SQL Server, returning the number rows
81 /// affected. Useful for `INSERT`, `UPDATE` and `DELETE` statements. The
82 /// `query` can define the parameter placement by annotating them with
83 /// `@PN`, where N is the index of the parameter, starting from `1`. If
84 /// executing multiple queries at a time, delimit them with `;` and refer to
85 /// [`ExecuteResult`] how to get results for the separate queries.
86 ///
87 /// For mapping of Rust types when writing, see the documentation for
88 /// [`ToSql`]. For reading data from the database, see the documentation for
89 /// [`FromSql`].
90 ///
91 /// This API is not quite suitable for dynamic query parameters. In these
92 /// cases using a [`Query`] object might be easier.
93 ///
94 /// # Example
95 ///
96 /// ```no_run
97 /// # use tiberius::Config;
98 /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
99 /// # use std::env;
100 /// # #[tokio::main]
101 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
102 /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
103 /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
104 /// # );
105 /// # let config = Config::from_ado_string(&c_str)?;
106 /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
107 /// # tcp.set_nodelay(true)?;
108 /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
109 /// let results = client
110 /// .execute(
111 /// "INSERT INTO ##Test (id) VALUES (@P1), (@P2), (@P3)",
112 /// &[&1i32, &2i32, &3i32],
113 /// )
114 /// .await?;
115 /// # Ok(())
116 /// # }
117 /// ```
118 ///
119 /// [`ExecuteResult`]: struct.ExecuteResult.html
120 /// [`ToSql`]: trait.ToSql.html
121 /// [`FromSql`]: trait.FromSql.html
122 /// [`Query`]: struct.Query.html
123 pub async fn execute<'a>(
124 &mut self,
125 query: impl Into<Cow<'a, str>>,
126 params: &[&dyn ToSql],
127 ) -> crate::Result<ExecuteResult> {
128 self.connection.flush_stream().await?;
129 let rpc_params = Self::rpc_params(query);
130
131 let params = params.iter().map(|s| s.to_sql());
132 self.rpc_perform_query(RpcProcId::ExecuteSQL, rpc_params, params)
133 .await?;
134
135 ExecuteResult::new(&mut self.connection).await
136 }
137
138 /// Executes SQL statements in the SQL Server, returning resulting rows.
139 /// Useful for `SELECT` statements. The `query` can define the parameter
140 /// placement by annotating them with `@PN`, where N is the index of the
141 /// parameter, starting from `1`. If executing multiple queries at a time,
142 /// delimit them with `;` and refer to [`QueryStream`] on proper stream
143 /// handling.
144 ///
145 /// For mapping of Rust types when writing, see the documentation for
146 /// [`ToSql`]. For reading data from the database, see the documentation for
147 /// [`FromSql`].
148 ///
149 /// This API can be cumbersome for dynamic query parameters. In these cases,
150 /// if fighting too much with the compiler, using a [`Query`] object might be
151 /// easier.
152 ///
153 /// # Example
154 ///
155 /// ```
156 /// # use tiberius::Config;
157 /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
158 /// # use std::env;
159 /// # #[tokio::main]
160 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
161 /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
162 /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
163 /// # );
164 /// # let config = Config::from_ado_string(&c_str)?;
165 /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
166 /// # tcp.set_nodelay(true)?;
167 /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
168 /// let stream = client
169 /// .query(
170 /// "SELECT @P1, @P2, @P3",
171 /// &[&1i32, &2i32, &3i32],
172 /// )
173 /// .await?;
174 /// # Ok(())
175 /// # }
176 /// ```
177 ///
178 /// [`QueryStream`]: struct.QueryStream.html
179 /// [`Query`]: struct.Query.html
180 /// [`ToSql`]: trait.ToSql.html
181 /// [`FromSql`]: trait.FromSql.html
182 pub async fn query<'a, 'b>(
183 &'a mut self,
184 query: impl Into<Cow<'b, str>>,
185 params: &'b [&'b dyn ToSql],
186 ) -> crate::Result<QueryStream<'a>>
187 where
188 'a: 'b,
189 {
190 self.connection.flush_stream().await?;
191 let rpc_params = Self::rpc_params(query);
192
193 let params = params.iter().map(|p| p.to_sql());
194 self.rpc_perform_query(RpcProcId::ExecuteSQL, rpc_params, params)
195 .await?;
196
197 let ts = TokenStream::new(&mut self.connection);
198 let mut result = QueryStream::new(ts.try_unfold());
199 result.forward_to_metadata().await?;
200
201 Ok(result)
202 }
203
204 /// Execute multiple queries, delimited with `;` and return multiple result
205 /// sets; one for each query.
206 ///
207 /// # Example
208 ///
209 /// ```
210 /// # use tiberius::Config;
211 /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
212 /// # use std::env;
213 /// # #[tokio::main]
214 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
215 /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
216 /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
217 /// # );
218 /// # let config = Config::from_ado_string(&c_str)?;
219 /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
220 /// # tcp.set_nodelay(true)?;
221 /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
222 /// let row = client.simple_query("SELECT 1 AS col").await?.into_row().await?.unwrap();
223 /// assert_eq!(Some(1i32), row.get("col"));
224 /// # Ok(())
225 /// # }
226 /// ```
227 ///
228 /// # Warning
229 ///
230 /// Do not use this with any user specified input. Please resort to prepared
231 /// statements using the [`query`] method.
232 ///
233 /// [`query`]: #method.query
234 pub async fn simple_query<'a, 'b>(
235 &'a mut self,
236 query: impl Into<Cow<'b, str>>,
237 ) -> crate::Result<QueryStream<'a>>
238 where
239 'a: 'b,
240 {
241 self.connection.flush_stream().await?;
242
243 let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
244
245 let id = self.connection.context_mut().next_packet_id();
246 self.connection.send(PacketHeader::batch(id), req).await?;
247
248 let ts = TokenStream::new(&mut self.connection);
249
250 let mut result = QueryStream::new(ts.try_unfold());
251 result.forward_to_metadata().await?;
252
253 Ok(result)
254 }
255
256 /// Execute a `BULK INSERT` statement, efficiantly storing a large number of
257 /// rows to a specified table. Note: make sure the input row follows the same
258 /// schema as the table, otherwise calling `send()` will return an error.
259 ///
260 /// # Example
261 ///
262 /// ```
263 /// # use tiberius::{Config, IntoRow};
264 /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
265 /// # use std::env;
266 /// # #[tokio::main]
267 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
268 /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
269 /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
270 /// # );
271 /// # let config = Config::from_ado_string(&c_str)?;
272 /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
273 /// # tcp.set_nodelay(true)?;
274 /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
275 /// let create_table = r#"
276 /// CREATE TABLE ##bulk_test (
277 /// id INT IDENTITY PRIMARY KEY,
278 /// val INT NOT NULL
279 /// )
280 /// "#;
281 ///
282 /// client.simple_query(create_table).await?;
283 ///
284 /// // Start the bulk insert with the client.
285 /// let mut req = client.bulk_insert("##bulk_test").await?;
286 ///
287 /// for i in [0i32, 1i32, 2i32] {
288 /// let row = (i).into_row();
289 ///
290 /// // The request will handle flushing to the wire in an optimal way,
291 /// // balancing between memory usage and IO performance.
292 /// req.send(row).await?;
293 /// }
294 ///
295 /// // The request must be finalized.
296 /// let res = req.finalize().await?;
297 /// assert_eq!(3, res.total());
298 /// # Ok(())
299 /// # }
300 /// ```
301 pub async fn bulk_insert<'a>(
302 &'a mut self,
303 table: &str,
304 ) -> crate::Result<BulkLoadRequest<'a, S>> {
305 let columns = self.bulk_insert_columns(table).await?;
306 self.bulk_insert_with_columns(table, columns).await
307 }
308
309 /// Returns updateable target column metadata for a future bulk insert.
310 ///
311 /// This method only sends a metadata query. It does not start the
312 /// `INSERT BULK` protocol flow, so callers can validate the target table
313 /// and fail without needing to finalize an empty bulk-load request.
314 pub async fn bulk_insert_columns(
315 &mut self,
316 table: &str,
317 ) -> crate::Result<BulkLoadColumns<'static>> {
318 self.connection.flush_stream().await?;
319
320 let query = format!("SELECT TOP 0 * FROM {}", table);
321
322 let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
323
324 let id = self.connection.context_mut().next_packet_id();
325 self.connection.send(PacketHeader::batch(id), req).await?;
326
327 let token_stream = TokenStream::new(&mut self.connection).try_unfold();
328
329 let columns = token_stream
330 .try_fold(None, |mut columns, token| async move {
331 if let ReceivedToken::NewResultset(metadata) = token {
332 columns = Some(metadata.columns.clone());
333 };
334
335 Ok(columns)
336 })
337 .await?;
338
339 // now start bulk upload
340 let columns: Vec<_> = columns
341 .ok_or_else(|| {
342 crate::Error::Protocol("expecting column metadata from query but not found".into())
343 })?
344 .into_iter()
345 .filter(|column| column.base.flags.contains(ColumnFlag::Updateable))
346 .collect();
347
348 Ok(BulkLoadColumns::new(columns))
349 }
350
351 /// Starts a bulk insert using previously discovered target columns.
352 pub async fn bulk_insert_with_columns<'a>(
353 &'a mut self,
354 table: &str,
355 columns: BulkLoadColumns<'a>,
356 ) -> crate::Result<BulkLoadRequest<'a, S>> {
357 let columns = columns.into_inner();
358
359 self.connection.flush_stream().await?;
360 let col_data = columns.iter().map(|c| format!("{}", c)).join(", ");
361 let query = format!("INSERT BULK {} ({})", table, col_data);
362
363 let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
364 let id = self.connection.context_mut().next_packet_id();
365
366 self.connection.send(PacketHeader::batch(id), req).await?;
367
368 let ts = TokenStream::new(&mut self.connection);
369 ts.flush_done().await?;
370
371 BulkLoadRequest::new(&mut self.connection, columns)
372 }
373
374 /// Closes this database connection explicitly.
375 pub async fn close(self) -> crate::Result<()> {
376 self.connection.close().await
377 }
378
379 pub(crate) fn rpc_params<'a>(query: impl Into<Cow<'a, str>>) -> Vec<RpcParam<'a>> {
380 vec![
381 RpcParam {
382 name: Cow::Borrowed("stmt"),
383 flags: BitFlags::empty(),
384 value: ColumnData::String(Some(query.into())),
385 },
386 RpcParam {
387 name: Cow::Borrowed("params"),
388 flags: BitFlags::empty(),
389 value: ColumnData::I32(Some(0)),
390 },
391 ]
392 }
393
394 pub(crate) async fn rpc_perform_query<'a, 'b>(
395 &'a mut self,
396 proc_id: RpcProcId,
397 mut rpc_params: Vec<RpcParam<'b>>,
398 params: impl Iterator<Item = ColumnData<'b>>,
399 ) -> crate::Result<()>
400 where
401 'a: 'b,
402 {
403 let mut param_str = String::new();
404
405 for (i, param) in params.enumerate() {
406 if i > 0 {
407 param_str.push(',')
408 }
409 param_str.push_str(&format!("@P{} ", i + 1));
410 param_str.push_str(¶m.type_name());
411
412 rpc_params.push(RpcParam {
413 name: Cow::Owned(format!("@P{}", i + 1)),
414 flags: BitFlags::empty(),
415 value: param,
416 });
417 }
418
419 if let Some(params) = rpc_params.iter_mut().find(|x| x.name == "params") {
420 params.value = ColumnData::String(Some(param_str.into()));
421 }
422
423 let req = TokenRpcRequest::new(
424 proc_id,
425 rpc_params,
426 self.connection.context().transaction_descriptor(),
427 );
428
429 let id = self.connection.context_mut().next_packet_id();
430 self.connection.send(PacketHeader::rpc(id), req).await?;
431
432 Ok(())
433 }
434}