refinery_core/drivers/
tiberius.rs1use crate::traits::r#async::{AsyncMigrate, AsyncQuery, AsyncTransaction};
2use crate::Migration;
3
4use async_trait::async_trait;
5use futures::{
6 io::{AsyncRead, AsyncWrite},
7 TryStreamExt,
8};
9use tiberius::{error::Error, Client, QueryItem};
10use time::format_description::well_known::Rfc3339;
11use time::OffsetDateTime;
12
13async fn query_applied_migrations<S: AsyncRead + AsyncWrite + Unpin + Send>(
14 client: &mut Client<S>,
15 query: &str,
16) -> Result<Vec<Migration>, Error> {
17 let mut rows = client.simple_query(query).await?;
18 let mut applied = Vec::new();
19 while let Some(item) = rows.try_next().await? {
21 if let QueryItem::Row(row) = item {
22 let version = row.get::<i32, usize>(0).unwrap();
23 let applied_on: &str = row.get::<&str, usize>(2).unwrap();
24 let applied_on = OffsetDateTime::parse(applied_on, &Rfc3339).unwrap();
26 let checksum: String = row.get::<&str, usize>(3).unwrap().to_string();
27
28 applied.push(Migration::applied(
29 version,
30 row.get::<&str, usize>(1).unwrap().to_string(),
31 applied_on,
32 checksum
33 .parse::<u64>()
34 .expect("checksum must be a valid u64"),
35 ));
36 }
37 }
38
39 Ok(applied)
40}
41
42#[async_trait]
43impl<S> AsyncTransaction for Client<S>
44where
45 S: AsyncRead + AsyncWrite + Unpin + Send,
46{
47 type Error = Error;
48
49 async fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error> {
50 self.simple_query("BEGIN TRAN T1;").await?;
52 let mut count = 0;
53 for query in queries {
54 if let Err(err) = self.simple_query(*query).await.map(drop) {
56 if let Err(err) = self.simple_query("ROLLBACK TRAN T1").await {
57 log::error!("could not ROLLBACK transaction, {}", err);
58 }
59 return Err(err);
60 }
61 count += 1;
62 }
63 self.simple_query("COMMIT TRAN T1").await?;
64 Ok(count as usize)
65 }
66}
67
68#[async_trait]
69impl<S> AsyncQuery<Vec<Migration>> for Client<S>
70where
71 S: AsyncRead + AsyncWrite + Unpin + Send,
72{
73 async fn query(
74 &mut self,
75 query: &str,
76 ) -> Result<Vec<Migration>, <Self as AsyncTransaction>::Error> {
77 let applied = query_applied_migrations(self, query).await?;
78 Ok(applied)
79 }
80}
81
82impl<S> AsyncMigrate for Client<S>
83where
84 S: AsyncRead + AsyncWrite + Unpin + Send,
85{
86 fn assert_migrations_table_query(migration_table_name: &str) -> String {
87 format!(
88 "IF NOT EXISTS(SELECT 1 FROM sys.Tables WHERE Name = N'{table_name}')
89 BEGIN
90 CREATE TABLE {table_name}(
91 version INT PRIMARY KEY,
92 name VARCHAR(255),
93 applied_on VARCHAR(255),
94 checksum VARCHAR(255));
95 END",
96 table_name = migration_table_name
97 )
98 }
99}