refinery_core/drivers/tiberius.rs

1use crate::traits::r#async::{AsyncMigrate, AsyncQuery, AsyncTransaction};
2use crate::util::SchemaVersion;
3use crate::Migration;
4
5use async_trait::async_trait;
6use futures::{
7    io::{AsyncRead, AsyncWrite},
8    TryStreamExt,
9};
10use tiberius::{error::Error, Client, QueryItem};
11use time::format_description::well_known::Rfc3339;
12use time::OffsetDateTime;
13
14async fn query_applied_migrations<S: AsyncRead + AsyncWrite + Unpin + Send>(
15    client: &mut Client<S>,
16    query: &str,
17) -> Result<Vec<Migration>, Error> {
18    let mut rows = client.simple_query(query).await?;
19    let mut applied = Vec::new();
20    // Unfortunately too many unwraps as `Row::get` maps to Option<T> instead of T
21    while let Some(item) = rows.try_next().await? {
22        if let QueryItem::Row(row) = item {
23            let version = row.get::<SchemaVersion, usize>(0).unwrap();
24            let applied_on: &str = row.get::<&str, usize>(2).unwrap();
25            // Safe to call unwrap, as we stored it in RFC3339 format on the database
26            let applied_on = OffsetDateTime::parse(applied_on, &Rfc3339).unwrap();
27            let checksum: String = row.get::<&str, usize>(3).unwrap().to_string();
28
29            applied.push(Migration::applied(
30                version,
31                row.get::<&str, usize>(1).unwrap().to_string(),
32                applied_on,
33                checksum
34                    .parse::<u64>()
35                    .expect("checksum must be a valid u64"),
36            ));
37        }
38    }
39
40    Ok(applied)
41}
42
43#[async_trait]
44impl<S> AsyncTransaction for Client<S>
45where
46    S: AsyncRead + AsyncWrite + Unpin + Send,
47{
48    type Error = Error;
49
50    async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
51        &mut self,
52        queries: T,
53    ) -> Result<usize, Self::Error> {
54        // Tiberius doesn't support transactions, see https://github.com/prisma/tiberius/issues/28
55        self.simple_query("BEGIN TRAN T1;").await?;
56        let mut count = 0;
57        for query in queries {
58            // Drop the returning `QueryStream<'a>` to avoid compiler complaning regarding lifetimes
59            if let Err(err) = self.simple_query(query).await.map(drop) {
60                if let Err(err) = self.simple_query("ROLLBACK TRAN T1").await {
61                    log::error!("could not ROLLBACK transaction, {}", err);
62                }
63                return Err(err);
64            }
65            count += 1;
66        }
67        self.simple_query("COMMIT TRAN T1").await?;
68        Ok(count as usize)
69    }
70}
71
72#[async_trait]
73impl<S> AsyncQuery<Vec<Migration>> for Client<S>
74where
75    S: AsyncRead + AsyncWrite + Unpin + Send,
76{
77    async fn query(
78        &mut self,
79        query: &str,
80    ) -> Result<Vec<Migration>, <Self as AsyncTransaction>::Error> {
81        let applied = query_applied_migrations(self, query).await?;
82        Ok(applied)
83    }
84}
85
86impl<S> AsyncMigrate for Client<S>
87where
88    S: AsyncRead + AsyncWrite + Unpin + Send,
89{
90    fn assert_migrations_table_query(migration_table_name: &str) -> String {
91        format!(
92            "IF NOT EXISTS(SELECT 1 FROM sys.Tables WHERE  Name = N'{migration_table_name}')
93         BEGIN
94           CREATE TABLE {migration_table_name}(
95             version INT PRIMARY KEY,
96             name VARCHAR(255),
97             applied_on VARCHAR(255),
98             checksum VARCHAR(255));
99         END"
100        )
101    }
102}