1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
use anyhow::{anyhow, Context};
use async_trait::async_trait;
use bb8::Pool;
use bb8_tiberius::ConnectionManager;
use chrono::{DateTime, Utc};
use eventful::{JournalStore, Snapshot, SnapshotStore};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::fmt::Debug;
use tiberius::Row;

/// Event-journal persistence backed by SQL Server.
///
/// Cloning is cheap: `bb8::Pool` is internally reference-counted, so clones
/// share the same underlying connection pool.
#[derive(Clone)]
pub struct SqlServerJournalStore {
    // Shared SQL Server connection pool; connections are checked out per call.
    pub pool: Pool<ConnectionManager>,
}

/// Snapshot persistence backed by SQL Server (the `snapshot` table).
///
/// Cloning is cheap: `bb8::Pool` is internally reference-counted, so clones
/// share the same underlying connection pool.
#[derive(Clone)]
pub struct SqlServerSnapshotStore {
    // Shared SQL Server connection pool; connections are checked out per call.
    pub pool: Pool<ConnectionManager>,
}

#[async_trait]
impl JournalStore for SqlServerJournalStore {
    /// Returns the highest offset currently stored in the journal, or 0 when
    /// the journal is empty.
    ///
    /// # Errors
    /// Fails if a pooled connection cannot be acquired, the query fails, or
    /// the result does not contain exactly one row.
    async fn get_initial_offset(&self) -> anyhow::Result<u64> {
        const SQL: &str = "select max(offset) as max_offset from journal;";

        tracing::info!("Getting initial sequence number");

        let mut conn = self.pool.get().await?;
        let rows: Vec<Row> = conn.query(SQL, &[]).await?.into_first_result().await?;

        // MAX() always yields exactly one row; anything else is a protocol error.
        if rows.len() != 1 {
            return Err(anyhow!("Expected 1 row, got {}", rows.len()));
        }

        // MAX() over an empty table returns NULL; treat that as offset 0.
        let max_offset: i64 = rows[0].try_get(0).context("failed to read max_offset")?.unwrap_or(0);

        tracing::info!("Got {} initial sequence number", max_offset);

        Ok(max_offset as u64)
    }

    /// Loads every journal entry with `offset >= offset`, in ascending order,
    /// as `(offset, raw message bytes)` pairs.
    ///
    /// # Errors
    /// Fails on connection/query errors, or if a row contains a NULL offset or
    /// message (a data-integrity problem surfaced as an error rather than a
    /// panic).
    // TODO - this should stream events directly from the database rather than returning a vector of all events
    async fn get_events_from_journal(&self, offset: u64) -> anyhow::Result<Vec<(u64, Vec<u8>)>> {
        const SQL: &str = "select offset,message from journal where journal.offset >= @p1 order by offset asc;";
        let mut conn = self.pool.get().await?;
        let offset = offset as i64;
        let rows: Vec<Row> = conn.query(SQL, &[&offset]).await?.into_first_result().await?;
        tracing::debug!("get_events_from_journal from offset {}, got {} rows", offset, rows.len());

        let mut r = Vec::with_capacity(rows.len());

        for row in rows {
            // NULLs here indicate corrupt data; propagate an error instead of
            // panicking (the previous `expect` would abort the process).
            let row_offset: i64 = row
                .try_get(0)
                .context("failed to read offset")?
                .ok_or_else(|| anyhow!("journal.offset was unexpectedly null"))?;
            let data: &[u8] = row
                .try_get(1)
                .context("failed to read data")?
                .ok_or_else(|| anyhow!("journal.message was unexpectedly null"))?;
            r.push((row_offset as u64, data.to_vec()));
        }

        Ok(r)
    }

    /// Loads the raw message bytes of every event for one entity, starting at
    /// `offset`, in ascending offset order. Returns an empty vector when the
    /// entity has no events at or beyond the offset.
    ///
    /// # Errors
    /// Fails on connection/query errors, or if a message column is NULL.
    async fn load_entity_events(&self, entity_type_name: &str, persistence_id: &str, offset: u64) -> anyhow::Result<Vec<Vec<u8>>> {
        const SQL: &str = "select message from journal where entity_type = @P1 and persistence_id = @P2 and offset >= @P3 order by offset asc;";
        let mut conn = self.pool.get().await?;
        let sql_compatible_offset = offset as i64;
        let rows: Vec<Row> = conn.query(SQL, &[&entity_type_name, &persistence_id, &sql_compatible_offset]).await?.into_first_result().await?;
        tracing::debug!("load_entity {} ({}), got {} rows from journal", entity_type_name, persistence_id, rows.len());

        // The loop naturally yields an empty Vec for zero rows, so no special
        // case is needed. A NULL message is reported as an error rather than
        // panicking via `unwrap`.
        let mut v = Vec::with_capacity(rows.len());
        for row in rows {
            let message: &[u8] = row
                .try_get(0)
                .context("failed to read message")?
                .ok_or_else(|| anyhow!("journal.message was unexpectedly null"))?;
            v.push(message.to_vec());
        }

        Ok(v)
    }

    /// Inserts a single serialized event into the journal at the given offset.
    ///
    /// # Errors
    /// Fails if a pooled connection cannot be acquired or the insert fails
    /// (e.g. an offset/primary-key conflict).
    async fn persist_event_to_journal(&self, entity_type_name: &str, event_type: &str, event_date: &DateTime<Utc>, bytes: &[u8], persistence_id: &str, offset: u64) -> anyhow::Result<()> {
        tracing::trace!("persist_event_to_journal: {:?} bytes", bytes.len());

        const SQL: &str = "insert into journal(offset,entity_type,persistence_id,event_type,message,event_date) values (@p1,@p2,@p3,@p4,@p5,@p6);";

        let mut conn = self.pool.get().await?;
        let offset = offset as i64;
        let results = conn.execute(SQL, &[&offset, &entity_type_name, &persistence_id, &event_type, &bytes, event_date]).await?;

        tracing::trace!("persist_event_to_journal rows_affected: {:?}", results.rows_affected());

        Ok(())
    }
}

#[async_trait]
impl SnapshotStore for SqlServerSnapshotStore {
    async fn read_snapshot<'de, S>(&self, name: &str) -> anyhow::Result<Option<Snapshot<S>>>
    where
        S: DeserializeOwned,
    {
        const SQL: &str = "select offset, value from snapshot where name = @P1;";
        let mut conn = self.pool.get().await?;
        let rows: Vec<Row> = conn.query(SQL, &[&name]).await?.into_first_result().await?;
        tracing::debug!("read_snapshot {}, got {} rows", name, rows.len());
        if rows.is_empty() {
            Ok(None)
        } else {
            let offset: i64 = rows[0].try_get(0).context("failed to read snapshot offset value from data row")?.unwrap();
            let bytes: &[u8] = rows[0].try_get(1).context("failed to read snapshot value from data row")?.unwrap();
            let json = serde_json::from_slice::<S>(bytes).context("failed to deserialise snapshot from json")?;
            Ok(Some(Snapshot { offset: offset as u64, value: json }))
        }
    }

    async fn write_snapshot<S>(&self, name: &str, offset: u64, value: &S) -> anyhow::Result<()>
    where
        S: Debug + Serialize + Sync,
    {
        tracing::trace!("write_snapshot at offset: {} for value: {:?}", offset, value);

        const SQL: &str = "merge snapshot as s
using (select @p1 name, @p2 offset, @p3 value) as p
on s.name = p.name
when matched then
    update
    set offset = p.offset, value = p.value
when not matched then
    insert (name, offset, value)
    values (@p1, @p2, @p3);";

        let mut conn = self.pool.get().await?;
        let json = serde_json::to_string(value).context("failed to serialise snapshot to json")?;
        let bytes = json.as_bytes();
        let sql_compatible_offset = offset as i64;
        let results = conn.execute(SQL, &[&name, &sql_compatible_offset, &bytes]).await?;

        tracing::trace!("write_snapshot rows_affected: {:?}", results.rows_affected());

        Ok(())
    }
}

/// Builds a bb8 connection pool (max 6 connections, 15 s checkout timeout)
/// for the given SQL Server connection string.
///
/// # Errors
/// Returns `bb8_tiberius::Error` if the connection string cannot be parsed or
/// the pool cannot be built.
pub async fn build_connection(connection_string: &str) -> Result<Pool<ConnectionManager>, bb8_tiberius::Error> {
    // SECURITY: do not log the connection string itself — it typically embeds
    // credentials. Log only that a pool is being created.
    tracing::info!("Building SQL Server connection pool");
    let mgr = ConnectionManager::build(connection_string)?;
    Pool::builder()
        .connection_timeout(std::time::Duration::from_secs(15))
        .max_size(6)
        .build(mgr)
        .await
}