refinery_core/drivers/tiberius.rs

1use crate::traits::r#async::{AsyncMigrate, AsyncQuery, AsyncTransaction};
2use crate::Migration;
3
4use async_trait::async_trait;
5use futures::{
6    io::{AsyncRead, AsyncWrite},
7    TryStreamExt,
8};
9use tiberius::{error::Error, Client, QueryItem};
10use time::format_description::well_known::Rfc3339;
11use time::OffsetDateTime;
12
13async fn query_applied_migrations<S: AsyncRead + AsyncWrite + Unpin + Send>(
14    client: &mut Client<S>,
15    query: &str,
16) -> Result<Vec<Migration>, Error> {
17    let mut rows = client.simple_query(query).await?;
18    let mut applied = Vec::new();
19    // Unfortunately too many unwraps as `Row::get` maps to Option<T> instead of T
20    while let Some(item) = rows.try_next().await? {
21        if let QueryItem::Row(row) = item {
22            let version = row.get::<i32, usize>(0).unwrap();
23            let applied_on: &str = row.get::<&str, usize>(2).unwrap();
24            // Safe to call unwrap, as we stored it in RFC3339 format on the database
25            let applied_on = OffsetDateTime::parse(applied_on, &Rfc3339).unwrap();
26            let checksum: String = row.get::<&str, usize>(3).unwrap().to_string();
27
28            applied.push(Migration::applied(
29                version,
30                row.get::<&str, usize>(1).unwrap().to_string(),
31                applied_on,
32                checksum
33                    .parse::<u64>()
34                    .expect("checksum must be a valid u64"),
35            ));
36        }
37    }
38
39    Ok(applied)
40}
41
42#[async_trait]
43impl<S> AsyncTransaction for Client<S>
44where
45    S: AsyncRead + AsyncWrite + Unpin + Send,
46{
47    type Error = Error;
48
49    async fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error> {
50        // Tiberius doesn't support transactions, see https://github.com/prisma/tiberius/issues/28
51        self.simple_query("BEGIN TRAN T1;").await?;
52        let mut count = 0;
53        for query in queries {
54            // Drop the returning `QueryStream<'a>` to avoid compiler complaning regarding lifetimes
55            if let Err(err) = self.simple_query(*query).await.map(drop) {
56                if let Err(err) = self.simple_query("ROLLBACK TRAN T1").await {
57                    log::error!("could not ROLLBACK transaction, {}", err);
58                }
59                return Err(err);
60            }
61            count += 1;
62        }
63        self.simple_query("COMMIT TRAN T1").await?;
64        Ok(count as usize)
65    }
66}
67
68#[async_trait]
69impl<S> AsyncQuery<Vec<Migration>> for Client<S>
70where
71    S: AsyncRead + AsyncWrite + Unpin + Send,
72{
73    async fn query(
74        &mut self,
75        query: &str,
76    ) -> Result<Vec<Migration>, <Self as AsyncTransaction>::Error> {
77        let applied = query_applied_migrations(self, query).await?;
78        Ok(applied)
79    }
80}
81
82impl<S> AsyncMigrate for Client<S>
83where
84    S: AsyncRead + AsyncWrite + Unpin + Send,
85{
86    fn assert_migrations_table_query(migration_table_name: &str) -> String {
87        format!(
88            "IF NOT EXISTS(SELECT 1 FROM sys.Tables WHERE  Name = N'{table_name}')
89         BEGIN
90           CREATE TABLE {table_name}(
91             version INT PRIMARY KEY,
92             name VARCHAR(255),
93             applied_on VARCHAR(255),
94             checksum VARCHAR(255));
95         END",
96            table_name = migration_table_name
97        )
98    }
99}