refinery_core/drivers/tiberius.rs

1use crate::traits::r#async::{AsyncMigrate, AsyncQuery, AsyncTransaction};
2use crate::Migration;
3
4use async_trait::async_trait;
5use futures::{
6    io::{AsyncRead, AsyncWrite},
7    TryStreamExt,
8};
9use tiberius::{error::Error, Client, QueryItem};
10use time::format_description::well_known::Rfc3339;
11use time::OffsetDateTime;
12
13async fn query_applied_migrations<S: AsyncRead + AsyncWrite + Unpin + Send>(
14    client: &mut Client<S>,
15    query: &str,
16) -> Result<Vec<Migration>, Error> {
17    let mut rows = client.simple_query(query).await?;
18    let mut applied = Vec::new();
19    // Unfortunately too many unwraps as `Row::get` maps to Option<T> instead of T
20    while let Some(item) = rows.try_next().await? {
21        if let QueryItem::Row(row) = item {
22            let version = row.get::<i32, usize>(0).unwrap();
23            let applied_on: &str = row.get::<&str, usize>(2).unwrap();
24            // Safe to call unwrap, as we stored it in RFC3339 format on the database
25            let applied_on = OffsetDateTime::parse(applied_on, &Rfc3339).unwrap();
26            let checksum: String = row.get::<&str, usize>(3).unwrap().to_string();
27
28            applied.push(Migration::applied(
29                version,
30                row.get::<&str, usize>(1).unwrap().to_string(),
31                applied_on,
32                checksum
33                    .parse::<u64>()
34                    .expect("checksum must be a valid u64"),
35            ));
36        }
37    }
38
39    Ok(applied)
40}
41
42#[async_trait]
43impl<S> AsyncTransaction for Client<S>
44where
45    S: AsyncRead + AsyncWrite + Unpin + Send,
46{
47    type Error = Error;
48
49    async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
50        &mut self,
51        queries: T,
52    ) -> Result<usize, Self::Error> {
53        // Tiberius doesn't support transactions, see https://github.com/prisma/tiberius/issues/28
54        self.simple_query("BEGIN TRAN T1;").await?;
55        let mut count = 0;
56        for query in queries {
57            // Drop the returning `QueryStream<'a>` to avoid compiler complaning regarding lifetimes
58            if let Err(err) = self.simple_query(query).await.map(drop) {
59                if let Err(err) = self.simple_query("ROLLBACK TRAN T1").await {
60                    log::error!("could not ROLLBACK transaction, {}", err);
61                }
62                return Err(err);
63            }
64            count += 1;
65        }
66        self.simple_query("COMMIT TRAN T1").await?;
67        Ok(count as usize)
68    }
69}
70
71#[async_trait]
72impl<S> AsyncQuery<Vec<Migration>> for Client<S>
73where
74    S: AsyncRead + AsyncWrite + Unpin + Send,
75{
76    async fn query(
77        &mut self,
78        query: &str,
79    ) -> Result<Vec<Migration>, <Self as AsyncTransaction>::Error> {
80        let applied = query_applied_migrations(self, query).await?;
81        Ok(applied)
82    }
83}
84
85impl<S> AsyncMigrate for Client<S>
86where
87    S: AsyncRead + AsyncWrite + Unpin + Send,
88{
89    fn assert_migrations_table_query(migration_table_name: &str) -> String {
90        format!(
91            "IF NOT EXISTS(SELECT 1 FROM sys.Tables WHERE  Name = N'{table_name}')
92         BEGIN
93           CREATE TABLE {table_name}(
94             version INT PRIMARY KEY,
95             name VARCHAR(255),
96             applied_on VARCHAR(255),
97             checksum VARCHAR(255));
98         END",
99            table_name = migration_table_name
100        )
101    }
102}