ankurah_storage_postgres/
lib.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    sync::{Arc, RwLock},
4};
5
6use ankurah_core::{
7    error::RetrievalError,
8    property::Backends,
9    storage::{StorageCollection, StorageEngine},
10};
11use ankurah_proto::State;
12
13use futures_util::TryStreamExt;
14
15pub mod predicate;
16pub mod value;
17
18use value::PGValue;
19
20use ankurah_proto::{Clock, CollectionId, Event, ID};
21use async_trait::async_trait;
22use bb8_postgres::{tokio_postgres::NoTls, PostgresConnectionManager};
23use tokio_postgres::{error::SqlState, types::ToSql};
24use tracing::{error, info, warn};
25
/// Storage engine backed by PostgreSQL, talking to the database through a shared `bb8`
/// connection pool.
pub struct Postgres {
    /// Connection pool; a clone of it is handed to every `PostgresBucket` created by `collection`.
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
}
29
30impl Postgres {
31    pub fn new(pool: bb8::Pool<PostgresConnectionManager<NoTls>>) -> anyhow::Result<Self> { Ok(Self { pool: pool }) }
32
33    // TODO: newtype this to `BucketName(&str)` with a constructor that
34    // only accepts a subset of characters.
35    pub fn sane_name(collection: &str) -> bool {
36        for char in collection.chars() {
37            match char {
38                char if char.is_alphanumeric() => {}
39                char if char.is_numeric() => {}
40                '_' | '.' | ':' => {}
41                _ => return false,
42            }
43        }
44
45        true
46    }
47}
48
49#[async_trait]
50impl StorageEngine for Postgres {
51    type Value = PGValue;
52
53    async fn collection(&self, collection_id: &CollectionId) -> Result<std::sync::Arc<dyn StorageCollection>, RetrievalError> {
54        if !Postgres::sane_name(collection_id.as_str()) {
55            return Err(RetrievalError::InvalidBucketName);
56        }
57
58        let bucket =
59            PostgresBucket { pool: self.pool.clone(), collection_id: collection_id.clone(), columns: Arc::new(RwLock::new(Vec::new())) };
60
61        // Try to create the table if it doesn't exist
62        let mut client = self.pool.get().await.map_err(|err| RetrievalError::storage(err))?;
63        bucket.create_event_table(&mut client).await?;
64        bucket.create_state_table(&mut client).await?;
65        bucket.rebuild_columns_cache(&mut client).await?;
66
67        Ok(Arc::new(bucket))
68    }
69}
70
/// One column of a collection's state table, as read from `information_schema.columns`
/// by `PostgresBucket::rebuild_columns_cache`.
#[derive(Clone, Debug)]
pub struct PostgresColumn {
    /// The `column_name` value.
    pub name: String,
    /// `true` when the `is_nullable` value is exactly `"YES"`.
    pub is_nullable: bool,
    /// The `data_type` value, stored verbatim.
    pub data_type: String,
}
77
/// Storage for a single collection: a state table named after the collection id and an
/// event table named `{collection_id}_event`.
pub struct PostgresBucket {
    /// Connection pool shared with the `Postgres` engine that created this bucket.
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
    /// Collection this bucket stores; also the source of both table names.
    collection_id: CollectionId,

    /// Cached column list of the state table, replaced wholesale by `rebuild_columns_cache`.
    columns: Arc<RwLock<Vec<PostgresColumn>>>,
}
84
85impl PostgresBucket {
86    fn state_table(&self) -> String { format!("{}", self.collection_id.as_str()) }
87
88    pub fn event_table(&self) -> String { format!("{}_event", self.collection_id.as_str()) }
89
90    /// Rebuild the cache of columns in the table.
91    pub async fn rebuild_columns_cache(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
92        let schema = "ankurah";
93
94        let column_query = format!(
95            r#"SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_catalog = $1 AND table_name = $2;"#,
96        );
97        let mut new_columns = Vec::new();
98        info!("Querying existing columns: {:?}, [{:?}, {:?}]", column_query, &schema, &self.collection_id.as_str());
99        let rows = client.query(&column_query, &[&schema, &self.collection_id.as_str()]).await?;
100        for row in rows {
101            let is_nullable: String = row.get("is_nullable");
102            new_columns.push(PostgresColumn {
103                name: row.get("column_name"),
104                is_nullable: is_nullable.eq("YES"),
105                data_type: row.get("data_type"),
106            })
107        }
108
109        let mut columns = self.columns.write().unwrap();
110        *columns = new_columns;
111        drop(columns);
112
113        Ok(())
114    }
115
116    pub fn existing_columns(&self) -> Vec<String> {
117        let columns = self.columns.read().unwrap();
118        columns.iter().map(|column| column.name.clone()).collect()
119    }
120
121    pub fn column(&self, column_name: &String) -> Option<PostgresColumn> {
122        let columns = self.columns.read().unwrap();
123        columns.iter().find(|column| column.name == *column_name).cloned()
124    }
125
126    pub fn has_column(&self, column_name: &String) -> bool { self.column(column_name).is_some() }
127
128    pub async fn create_event_table(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
129        let create_query = format!(
130            r#"CREATE TABLE IF NOT EXISTS "{}"("id" UUID UNIQUE, "entity_id" UUID, "operations" bytea, "parent" UUID[])"#,
131            self.event_table()
132        );
133
134        info!("Applying DDL: {}", create_query);
135        client.execute(&create_query, &[]).await?;
136        Ok(())
137    }
138
139    pub async fn create_state_table(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
140        let create_query =
141            format!(r#"CREATE TABLE IF NOT EXISTS "{}"("id" UUID UNIQUE, "state_buffer" BYTEA, "head" UUID[])"#, self.state_table());
142
143        info!("Applying DDL: {}", create_query);
144        match client.execute(&create_query, &[]).await {
145            Ok(_) => Ok(()),
146            Err(err) => {
147                info!("Error: {}", err);
148                // if err.code() == Some(SqlState::UNIQUE_VIOLATION) {
149                //     Ok(())
150                // } else {
151                Err(err.into())
152                // }
153            }
154        }
155    }
156
157    pub async fn add_missing_columns(
158        &self,
159        client: &mut tokio_postgres::Client,
160        missing: Vec<(String, &'static str)>, // column name, datatype
161    ) -> anyhow::Result<()> {
162        for (column, datatype) in missing {
163            if Postgres::sane_name(&column) {
164                let alter_query = format!(r#"ALTER TABLE "{}" ADD COLUMN "{}" {}"#, self.state_table(), column, datatype,);
165                info!("Running: {}", alter_query);
166                client.execute(&alter_query, &[]).await?;
167            }
168        }
169
170        self.rebuild_columns_cache(client).await?;
171        Ok(())
172    }
173}
174
175#[async_trait]
176impl StorageCollection for PostgresBucket {
177    async fn set_state(&self, id: ID, state: &State) -> anyhow::Result<bool> {
178        let state_buffers = bincode::serialize(&state.state_buffers)?;
179        let ulid: ulid::Ulid = id.into();
180        let uuid: uuid::Uuid = ulid.into();
181
182        let head_uuids: Vec<uuid::Uuid> = (&state.head).into();
183
184        // Ensure head is not empty for new records
185        if head_uuids.is_empty() {
186            warn!("Warning: Empty head detected for entity {}", id);
187        }
188
189        let mut client = self.pool.get().await?;
190
191        let backends = Backends::from_state_buffers(state)?;
192        let mut columns: Vec<String> = vec!["id".to_owned(), "state_buffer".to_owned(), "head".to_owned()];
193        let mut params: Vec<&(dyn ToSql + Sync)> = Vec::new();
194        params.push(&uuid);
195        params.push(&state_buffers);
196        params.push(&head_uuids);
197
198        let mut materialized: Vec<(String, Option<PGValue>)> = Vec::new();
199        for (column, value) in backends.property_values() {
200            let pg_value: Option<PGValue> = value.map(|value| value.into());
201            if !self.has_column(&column) {
202                // We don't have the column yet and we know the type.
203                if let Some(ref pg_value) = pg_value {
204                    self.add_missing_columns(&mut client, vec![(column.clone(), pg_value.postgres_type())]).await?;
205                } else {
206                    // The column doesn't exist yet and we don't have a value.
207                    // This means the entire column is already null/none so we
208                    // don't need to set anything.
209                    continue;
210                }
211            }
212
213            materialized.push((column.clone(), pg_value));
214        }
215
216        for (name, parameter) in &materialized {
217            columns.push(name.clone());
218
219            match &parameter {
220                Some(value) => match value {
221                    PGValue::CharacterVarying(string) => params.push(string),
222                    PGValue::SmallInt(number) => params.push(number),
223                    PGValue::Integer(number) => params.push(number),
224                    PGValue::BigInt(number) => params.push(number),
225                    PGValue::Bytea(bytes) => params.push(bytes),
226                },
227                None => params.push(&None::<i32>),
228            }
229        }
230
231        let columns_str = columns.iter().map(|name| format!("\"{}\"", name)).collect::<Vec<String>>().join(", ");
232        let values_str = params.iter().enumerate().map(|(index, _)| format!("${}", index + 1)).collect::<Vec<String>>().join(", ");
233        let columns_update_str = columns
234            .iter()
235            .enumerate()
236            .skip(1) // Skip "id"
237            .map(|(index, name)| format!("\"{}\" = ${}", name, index + 1))
238            .collect::<Vec<String>>()
239            .join(", ");
240
241        // be careful with sql injection via bucket name
242        let query = format!(
243            r#"WITH old_state AS (
244                SELECT "head" FROM "{0}" WHERE "id" = $1
245            )
246            INSERT INTO "{0}"({1}) VALUES({2})
247            ON CONFLICT("id") DO UPDATE SET {3}
248            RETURNING (SELECT "head" FROM old_state) as old_head"#,
249            self.state_table(),
250            columns_str,
251            values_str,
252            columns_update_str
253        );
254
255        info!("Querying: {}", query);
256        let row = match client.query_one(&query, params.as_slice()).await {
257            Ok(row) => row,
258            Err(err) => {
259                let kind = error_kind(&err);
260                match kind {
261                    ErrorKind::UndefinedTable { table } => {
262                        if table == self.state_table() {
263                            self.create_state_table(&mut *client).await?;
264                            return self.set_state(id, state).await; // retry
265                        }
266                    }
267                    _ => {}
268                }
269
270                return Err(err.into());
271            }
272        };
273
274        // If this is a new entity (no old_head), or if the heads are different, return true
275        let old_head: Option<Vec<uuid::Uuid>> = row.get("old_head");
276        let changed = match old_head {
277            None => true, // New entity
278            Some(old_head) => {
279                let old_clock: Clock = old_head.into();
280                old_clock != state.head
281            }
282        };
283
284        info!("Changed: {}", changed);
285        Ok(changed)
286    }
287
288    async fn get_state(&self, id: ID) -> Result<State, RetrievalError> {
289        let ulid: ulid::Ulid = id.into();
290        let uuid: uuid::Uuid = ulid.into();
291
292        // be careful with sql injection via bucket name
293        let query = format!(r#"SELECT "id", "state_buffer", "head" FROM "{}" WHERE "id" = $1"#, self.state_table());
294
295        let mut client = match self.pool.get().await {
296            Ok(client) => client,
297            Err(err) => {
298                return Err(RetrievalError::StorageError(err.into()));
299            }
300        };
301
302        info!("Getting state: {}", query);
303        let row = match client.query_one(&query, &[&uuid]).await {
304            Ok(row) => row,
305            Err(err) => {
306                let kind = error_kind(&err);
307                match kind {
308                    ErrorKind::RowCount => {
309                        return Err(RetrievalError::NotFound(id));
310                    }
311                    ErrorKind::UndefinedTable { table } => {
312                        if table == self.state_table() {
313                            self.create_state_table(&mut client).await.map_err(|e| RetrievalError::StorageError(e.into()))?;
314                            return Err(RetrievalError::NotFound(id));
315                        }
316                    }
317                    _ => {}
318                }
319
320                return Err(RetrievalError::StorageError(err.into()));
321            }
322        };
323
324        info!("Row: {:?}", row);
325        let row_id: uuid::Uuid = row.get("id");
326        assert_eq!(row_id, uuid);
327
328        let serialized_buffers: Vec<u8> = row.get("state_buffer");
329        let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&serialized_buffers)?;
330
331        Ok(State { state_buffers, head: row.get::<_, Vec<uuid::Uuid>>("head").into() })
332    }
333
334    async fn fetch_states(&self, predicate: &ankql::ast::Predicate) -> Result<Vec<(ID, State)>, RetrievalError> {
335        println!("Fetching states for: {:?}", predicate);
336        let client = self.pool.get().await.map_err(|err| RetrievalError::StorageError(Box::new(err)))?;
337
338        let mut results = Vec::new();
339
340        let mut ankql_sql = predicate::Sql::new();
341        ankql_sql.predicate(&predicate);
342
343        let (sql, args) = ankql_sql.collapse();
344
345        let filtered_query = if !sql.is_empty() {
346            format!(r#"SELECT "id", "state_buffer", "head" FROM "{}" WHERE {}"#, self.state_table(), sql,)
347        } else {
348            format!(r#"SELECT "id", "state_buffer", "head" FROM "{}""#, self.state_table())
349        };
350
351        info!("SQL: {} with args: {:?}", filtered_query, args);
352
353        let rows = match client.query_raw(&filtered_query, args).await {
354            Ok(stream) => match stream.try_collect::<Vec<_>>().await {
355                Ok(rows) => rows,
356                Err(err) => return Err(RetrievalError::StorageError(err.into())),
357            },
358            Err(err) => {
359                let kind = error_kind(&err);
360                match kind {
361                    ErrorKind::UndefinedTable { table } => {
362                        if table == self.state_table() {
363                            // Table doesn't exist yet, return empty results
364                            return Ok(Vec::new());
365                        }
366                    }
367                    ErrorKind::UndefinedColumn { table, column } => {
368                        println!("Undefined column: {} in table: {:?}", column, table);
369                        match table {
370                            Some(table) if table == self.state_table() => {
371                                // Modify the predicate treating this column as NULL and retry
372                                return self.fetch_states(&predicate.assume_null(&[column])).await;
373                            }
374                            None => {
375                                return self.fetch_states(&predicate.assume_null(&[column])).await;
376                            }
377                            _ => {}
378                        }
379                    }
380                    _ => {}
381                }
382
383                return Err(RetrievalError::StorageError(err.into()));
384            }
385        };
386
387        for row in rows {
388            let uuid: uuid::Uuid = row.get(0);
389            let state_buffer: Vec<u8> = row.get(1);
390            let id = ID::from_ulid(ulid::Ulid::from(uuid));
391
392            let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&state_buffer)?;
393
394            let entity_state = State { state_buffers, head: row.get::<_, Vec<uuid::Uuid>>(2).into() };
395
396            results.push((id, entity_state));
397        }
398
399        Ok(results)
400    }
401
402    /// Postgres Event Table:
403    /// {bucket_name}_event
404    /// event_id uuid, // `ID`/`ULID`
405    /// entity_id uuid, // `ID`/`ULID`
406    /// operations bytea, // `Vec<Operation>`
407    /// clock bytea, // `Clock`
408    async fn add_event(&self, entity_event: &Event) -> anyhow::Result<bool> {
409        let event_id = uuid::Uuid::from(ulid::Ulid::from(entity_event.id));
410        let entity_id = uuid::Uuid::from(ulid::Ulid::from(entity_event.entity_id));
411        let operations = bincode::serialize(&entity_event.operations)?;
412        let parent_uuids: Vec<uuid::Uuid> = (&entity_event.parent).into();
413
414        // Does it even matter if this conflicts?
415        // One peers event should match any duplicates, so taking the first
416        // event we receive from a peer should be fine.
417        let query = format!(r#"INSERT INTO "{0}"("id", "entity_id", "operations", "parent") VALUES($1, $2, $3, $4)"#, self.event_table(),);
418
419        let mut client = self.pool.get().await?;
420        info!("Running: {}", query);
421        let affected = match client.execute(&query, &[&event_id, &entity_id, &operations, &parent_uuids]).await {
422            Ok(affected) => affected,
423            Err(err) => {
424                let kind = error_kind(&err);
425                match kind {
426                    ErrorKind::UndefinedTable { table } => {
427                        if table == self.event_table() {
428                            self.create_event_table(&mut *client).await?;
429                            return self.add_event(entity_event).await; // retry
430                        }
431                    }
432                    _ => {}
433                }
434
435                return Err(err.into());
436            }
437        };
438
439        Ok(affected > 0)
440    }
441
442    async fn get_events(&self, entity_id: ID) -> Result<Vec<Event>, ankurah_core::error::RetrievalError> {
443        let query = format!(r#"SELECT "id", "operations", "parent" FROM "{0}" WHERE "entity_id" = $1"#, self.event_table(),);
444
445        let entity_uuid = uuid::Uuid::from(ulid::Ulid::from(entity_id));
446
447        let mut client = self.pool.get().await.map_err(|err| RetrievalError::storage(err))?;
448        info!("Running: {}", query);
449        let rows = match client.query(&query, &[&entity_uuid]).await {
450            Ok(rows) => rows,
451            Err(err) => {
452                let kind = error_kind(&err);
453                match kind {
454                    ErrorKind::UndefinedTable { table } => {
455                        if table == self.event_table() {
456                            self.create_event_table(&mut *client).await?;
457                            return Ok(Vec::new());
458                        }
459                    }
460                    _ => {}
461                }
462
463                return Err(RetrievalError::storage(err));
464            }
465        };
466
467        let mut events = Vec::new();
468        for row in rows {
469            let event_id: uuid::Uuid = row.get("id");
470            let event_id = ID::from_ulid(event_id.into());
471            let operations_binary: Vec<u8> = row.get("operations");
472            let operations = bincode::deserialize(&operations_binary)?;
473            let parent: Vec<uuid::Uuid> = row.get("parent");
474            let parent = parent.into_iter().map(|uuid| ID::from_ulid(uuid.into())).collect::<BTreeSet<_>>();
475            let clock = Clock::new(parent);
476
477            events.push(Event {
478                id: event_id,
479                collection: self.collection_id.clone(),
480                entity_id: entity_id,
481                operations: operations,
482                parent: clock,
483            })
484        }
485
486        Ok(events)
487    }
488}
489
// Workaround: rust-postgres doesn't let us ask for the error kind directly.
// TODO: remove this when https://github.com/sfackler/rust-postgres/pull/1185
//       gets merged
/// Coarse classification of a `tokio_postgres::Error`, as produced by `error_kind`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ErrorKind {
    /// The error text was "query returned an unexpected number of rows".
    RowCount,
    /// SQLSTATE `UNDEFINED_TABLE`; `table` is the first double-quoted name in the error text.
    UndefinedTable { table: String },
    /// SQLSTATE `UNDEFINED_COLUMN`; `column` is the first double-quoted name in the error
    /// text and `table` the second one, when the message names a table at all.
    UndefinedColumn { table: Option<String>, column: String },
    /// Any error that is not one of the cases above.
    Unknown,
}
500
501pub fn error_kind(err: &tokio_postgres::Error) -> ErrorKind {
502    let string = err.to_string().trim().to_owned();
503    let _db_error = err.as_db_error();
504    let sql_code = err.code().cloned();
505
506    if string == "query returned an unexpected number of rows" {
507        return ErrorKind::RowCount;
508    }
509
510    // Useful for adding new errors
511    // error!("postgres error: {:?}", err);
512    // error!("db_err: {:?}", err.as_db_error());
513    // error!("sql_code: {:?}", err.code());
514    // error!("err: {:?}", err);
515    // error!("err: {:?}", err.to_string());
516
517    let quote_indices = |s: &str| {
518        let mut quotes = Vec::new();
519        for (index, char) in s.char_indices() {
520            match char {
521                '"' => quotes.push(index),
522                _ => {}
523            }
524        }
525        quotes
526    };
527
528    match sql_code {
529        Some(SqlState::UNDEFINED_TABLE) => {
530            // relation "album" does not exist
531            let quotes = quote_indices(&string);
532            let table = &string[quotes[0] + 1..quotes[1]];
533            ErrorKind::UndefinedTable { table: table.to_owned() }
534        }
535        Some(SqlState::UNDEFINED_COLUMN) => {
536            // Handle both formats:
537            // "column "name" of relation "album" does not exist"
538            // "column "status" does not exist"
539            let quotes = quote_indices(&string);
540            let column = string[quotes[0] + 1..quotes[1]].to_owned();
541
542            println!("quotes: {:?}", quotes);
543            println!("column: {}", column);
544
545            let table = if quotes.len() >= 4 {
546                // Full format with table name
547                Some(string[quotes[2] + 1..quotes[3]].to_owned())
548            } else {
549                // Short format without table name, use empty string
550                None
551            };
552
553            ErrorKind::UndefinedColumn { table, column }
554        }
555        _ => ErrorKind::Unknown,
556    }
557}
558
// NOTE(review): nothing in this file constructs or reads this type, which is presumably
// why the unused lint is silenced — confirm whether it is still needed.
#[allow(unused)]
pub struct MissingMaterialized {
    // Presumably the name of a property that has no materialized column yet — TODO confirm.
    pub name: String,
}