refinery_core/drivers/
tiberius.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use crate::traits::r#async::{AsyncMigrate, AsyncQuery, AsyncTransaction};
use crate::Migration;

use async_trait::async_trait;
use futures::{
    io::{AsyncRead, AsyncWrite},
    TryStreamExt,
};
use tiberius::{error::Error, Client, QueryItem};
use time::format_description::well_known::Rfc3339;
use time::OffsetDateTime;

async fn query_applied_migrations<S: AsyncRead + AsyncWrite + Unpin + Send>(
    client: &mut Client<S>,
    query: &str,
) -> Result<Vec<Migration>, Error> {
    let mut rows = client.simple_query(query).await?;
    let mut applied = Vec::new();
    // Unfortunately too many unwraps as `Row::get` maps to Option<T> instead of T
    while let Some(item) = rows.try_next().await? {
        if let QueryItem::Row(row) = item {
            let version = row.get::<i32, usize>(0).unwrap();
            let applied_on: &str = row.get::<&str, usize>(2).unwrap();
            // Safe to call unwrap, as we stored it in RFC3339 format on the database
            let applied_on = OffsetDateTime::parse(applied_on, &Rfc3339).unwrap();
            let checksum: String = row.get::<&str, usize>(3).unwrap().to_string();

            applied.push(Migration::applied(
                version,
                row.get::<&str, usize>(1).unwrap().to_string(),
                applied_on,
                checksum
                    .parse::<u64>()
                    .expect("checksum must be a valid u64"),
            ));
        }
    }

    Ok(applied)
}

#[async_trait]
impl<S> AsyncTransaction for Client<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    type Error = Error;

    async fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error> {
        // Tiberius doesn't support transactions, see https://github.com/prisma/tiberius/issues/28
        self.simple_query("BEGIN TRAN T1;").await?;
        let mut count = 0;
        for query in queries {
            // Drop the returning `QueryStream<'a>` to avoid compiler complaning regarding lifetimes
            if let Err(err) = self.simple_query(*query).await.map(drop) {
                if let Err(err) = self.simple_query("ROLLBACK TRAN T1").await {
                    log::error!("could not ROLLBACK transaction, {}", err);
                }
                return Err(err);
            }
            count += 1;
        }
        self.simple_query("COMMIT TRAN T1").await?;
        Ok(count as usize)
    }
}

#[async_trait]
impl<S> AsyncQuery<Vec<Migration>> for Client<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Runs `query` and returns the applied migrations it yields.
    ///
    /// All row decoding is delegated to `query_applied_migrations`; any error it
    /// produces is passed through unchanged.
    async fn query(
        &mut self,
        query: &str,
    ) -> Result<Vec<Migration>, <Self as AsyncTransaction>::Error> {
        query_applied_migrations(self, query).await
    }
}

impl<S> AsyncMigrate for Client<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Builds the T-SQL batch that creates the migrations history table if it is missing.
    ///
    /// The existence check uses `OBJECT_ID(name, N'U')`, which resolves
    /// `migration_table_name` the same way the `CREATE TABLE` below does:
    /// - a bare name resolves against the default schema;
    /// - a qualified name such as `dbo.history` resolves against its explicit schema.
    ///
    /// Comparing the bare name against `sys.tables` would ignore the schema: a qualified
    /// name would never match, and a same-named table in another schema would suppress
    /// creation.
    ///
    /// `migration_table_name` is interpolated verbatim, so it must be a trusted identifier.
    fn assert_migrations_table_query(migration_table_name: &str) -> String {
        format!(
            "IF OBJECT_ID(N'{table_name}', N'U') IS NULL
         BEGIN
           CREATE TABLE {table_name}(
             version INT PRIMARY KEY,
             name VARCHAR(255),
             applied_on VARCHAR(255),
             checksum VARCHAR(255));
         END",
            table_name = migration_table_name
        )
    }
}