ankurah_storage_postgres/
lib.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    sync::{Arc, RwLock},
4};
5
6use ankurah_core::{
7    error::RetrievalError,
8    property::Backends,
9    storage::{StorageCollection, StorageEngine},
10};
11use ankurah_proto::State;
12
13use futures_util::TryStreamExt;
14
15pub mod predicate;
16pub mod value;
17
18use value::PGValue;
19
20use ankurah_proto::{Clock, CollectionId, Event, ID};
21use async_trait::async_trait;
22use bb8_postgres::{tokio_postgres::NoTls, PostgresConnectionManager};
23use tokio_postgres::{error::SqlState, types::ToSql};
24use tracing::{error, info, warn};
25
26pub struct Postgres {
27    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
28}
29
30impl Postgres {
31    pub fn new(pool: bb8::Pool<PostgresConnectionManager<NoTls>>) -> anyhow::Result<Self> { Ok(Self { pool: pool }) }
32
33    // TODO: newtype this to `BucketName(&str)` with a constructor that
34    // only accepts a subset of characters.
35    pub fn sane_name(collection: &str) -> bool {
36        for char in collection.chars() {
37            match char {
38                char if char.is_alphanumeric() => {}
39                char if char.is_numeric() => {}
40                '_' | '.' | ':' => {}
41                _ => return false,
42            }
43        }
44
45        true
46    }
47}
48
49#[async_trait]
50impl StorageEngine for Postgres {
51    type Value = PGValue;
52
53    async fn collection(&self, collection_id: &CollectionId) -> Result<std::sync::Arc<dyn StorageCollection>, RetrievalError> {
54        if !Postgres::sane_name(collection_id.as_str()) {
55            return Err(RetrievalError::InvalidBucketName);
56        }
57
58        let mut client = self.pool.get().await.map_err(|err| RetrievalError::storage(err))?;
59
60        // get the current schema from the database
61        let schema = client.query_one("SELECT current_database()", &[]).await.map_err(|err| RetrievalError::storage(err))?;
62        let schema = schema.get("current_database");
63
64        let bucket = PostgresBucket {
65            pool: self.pool.clone(),
66            schema,
67            collection_id: collection_id.clone(),
68            columns: Arc::new(RwLock::new(Vec::new())),
69        };
70
71        // Try to create the table if it doesn't exist
72        bucket.create_event_table(&mut client).await?;
73        bucket.create_state_table(&mut client).await?;
74        bucket.rebuild_columns_cache(&mut client).await?;
75
76        Ok(Arc::new(bucket))
77    }
78}
79
/// Description of a single table column as cached from
/// `information_schema.columns`.
#[derive(Clone, Debug)]
pub struct PostgresColumn {
    /// Value of `column_name`.
    pub name: String,
    /// `true` when `is_nullable` was reported as `"YES"`.
    pub is_nullable: bool,
    /// Value of `data_type`.
    pub data_type: String,
}
86
87pub struct PostgresBucket {
88    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
89    collection_id: CollectionId,
90    schema: String,
91    columns: Arc<RwLock<Vec<PostgresColumn>>>,
92}
93
94impl PostgresBucket {
95    fn state_table(&self) -> String { format!("{}", self.collection_id.as_str()) }
96
97    pub fn event_table(&self) -> String { format!("{}_event", self.collection_id.as_str()) }
98
99    /// Rebuild the cache of columns in the table.
100    pub async fn rebuild_columns_cache(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
101        let column_query = format!(
102            r#"SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_catalog = $1 AND table_name = $2;"#,
103        );
104        let mut new_columns = Vec::new();
105        info!("Querying existing columns: {:?}, [{:?}, {:?}]", column_query, &self.schema, &self.collection_id.as_str());
106        let rows = client.query(&column_query, &[&self.schema, &self.collection_id.as_str()]).await?;
107        for row in rows {
108            let is_nullable: String = row.get("is_nullable");
109            new_columns.push(PostgresColumn {
110                name: row.get("column_name"),
111                is_nullable: is_nullable.eq("YES"),
112                data_type: row.get("data_type"),
113            })
114        }
115
116        let mut columns = self.columns.write().unwrap();
117        *columns = new_columns;
118        drop(columns);
119
120        Ok(())
121    }
122
123    pub fn existing_columns(&self) -> Vec<String> {
124        let columns = self.columns.read().unwrap();
125        columns.iter().map(|column| column.name.clone()).collect()
126    }
127
128    pub fn column(&self, column_name: &String) -> Option<PostgresColumn> {
129        let columns = self.columns.read().unwrap();
130        columns.iter().find(|column| column.name == *column_name).cloned()
131    }
132
133    pub fn has_column(&self, column_name: &String) -> bool { self.column(column_name).is_some() }
134
135    pub async fn create_event_table(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
136        let create_query = format!(
137            r#"CREATE TABLE IF NOT EXISTS "{}"("id" UUID UNIQUE, "entity_id" UUID, "operations" bytea, "parent" UUID[])"#,
138            self.event_table()
139        );
140
141        info!("Applying DDL: {}", create_query);
142        client.execute(&create_query, &[]).await?;
143        Ok(())
144    }
145
146    pub async fn create_state_table(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
147        let create_query =
148            format!(r#"CREATE TABLE IF NOT EXISTS "{}"("id" UUID UNIQUE, "state_buffer" BYTEA, "head" UUID[])"#, self.state_table());
149
150        info!("Applying DDL: {}", create_query);
151        match client.execute(&create_query, &[]).await {
152            Ok(_) => Ok(()),
153            Err(err) => {
154                info!("Error: {}", err);
155                // if err.code() == Some(SqlState::UNIQUE_VIOLATION) {
156                //     Ok(())
157                // } else {
158                Err(err.into())
159                // }
160            }
161        }
162    }
163
164    pub async fn add_missing_columns(
165        &self,
166        client: &mut tokio_postgres::Client,
167        missing: Vec<(String, &'static str)>, // column name, datatype
168    ) -> anyhow::Result<()> {
169        for (column, datatype) in missing {
170            if Postgres::sane_name(&column) {
171                let alter_query = format!(r#"ALTER TABLE "{}" ADD COLUMN "{}" {}"#, self.state_table(), column, datatype,);
172                info!("Running: {}", alter_query);
173                match client.execute(&alter_query, &[]).await {
174                    Ok(_) => {}
175                    Err(err) => {
176                        warn!("Error adding column: {} to table: {} - rebuilding columns cache", err, self.state_table());
177                        self.rebuild_columns_cache(client).await?;
178                        return Err(err.into());
179                    }
180                }
181            }
182        }
183
184        self.rebuild_columns_cache(client).await?;
185        Ok(())
186    }
187}
188
189#[async_trait]
190impl StorageCollection for PostgresBucket {
191    async fn set_state(&self, id: ID, state: &State) -> anyhow::Result<bool> {
192        let state_buffers = bincode::serialize(&state.state_buffers)?;
193        let ulid: ulid::Ulid = id.into();
194        let uuid: uuid::Uuid = ulid.into();
195
196        let head_uuids: Vec<uuid::Uuid> = (&state.head).into();
197
198        // Ensure head is not empty for new records
199        if head_uuids.is_empty() {
200            warn!("Warning: Empty head detected for entity {}", id);
201        }
202
203        let mut client = self.pool.get().await?;
204
205        let backends = Backends::from_state_buffers(state)?;
206        let mut columns: Vec<String> = vec!["id".to_owned(), "state_buffer".to_owned(), "head".to_owned()];
207        let mut params: Vec<&(dyn ToSql + Sync)> = Vec::new();
208        params.push(&uuid);
209        params.push(&state_buffers);
210        params.push(&head_uuids);
211
212        let mut materialized: Vec<(String, Option<PGValue>)> = Vec::new();
213        for (column, value) in backends.property_values() {
214            let pg_value: Option<PGValue> = value.map(|value| value.into());
215            if !self.has_column(&column) {
216                // We don't have the column yet and we know the type.
217                if let Some(ref pg_value) = pg_value {
218                    self.add_missing_columns(&mut client, vec![(column.clone(), pg_value.postgres_type())]).await?;
219                } else {
220                    // The column doesn't exist yet and we don't have a value.
221                    // This means the entire column is already null/none so we
222                    // don't need to set anything.
223                    continue;
224                }
225            }
226
227            materialized.push((column.clone(), pg_value));
228        }
229
230        for (name, parameter) in &materialized {
231            columns.push(name.clone());
232
233            match &parameter {
234                Some(value) => match value {
235                    PGValue::CharacterVarying(string) => params.push(string),
236                    PGValue::SmallInt(number) => params.push(number),
237                    PGValue::Integer(number) => params.push(number),
238                    PGValue::BigInt(number) => params.push(number),
239                    PGValue::Bytea(bytes) => params.push(bytes),
240                },
241                None => params.push(&None::<i32>),
242            }
243        }
244
245        let columns_str = columns.iter().map(|name| format!("\"{}\"", name)).collect::<Vec<String>>().join(", ");
246        let values_str = params.iter().enumerate().map(|(index, _)| format!("${}", index + 1)).collect::<Vec<String>>().join(", ");
247        let columns_update_str = columns
248            .iter()
249            .enumerate()
250            .skip(1) // Skip "id"
251            .map(|(index, name)| format!("\"{}\" = ${}", name, index + 1))
252            .collect::<Vec<String>>()
253            .join(", ");
254
255        // be careful with sql injection via bucket name
256        let query = format!(
257            r#"WITH old_state AS (
258                SELECT "head" FROM "{0}" WHERE "id" = $1
259            )
260            INSERT INTO "{0}"({1}) VALUES({2})
261            ON CONFLICT("id") DO UPDATE SET {3}
262            RETURNING (SELECT "head" FROM old_state) as old_head"#,
263            self.state_table(),
264            columns_str,
265            values_str,
266            columns_update_str
267        );
268
269        info!("Querying: {}", query);
270        let row = match client.query_one(&query, params.as_slice()).await {
271            Ok(row) => row,
272            Err(err) => {
273                let kind = error_kind(&err);
274                match kind {
275                    ErrorKind::UndefinedTable { table } => {
276                        if table == self.state_table() {
277                            self.create_state_table(&mut *client).await?;
278                            return self.set_state(id, state).await; // retry
279                        }
280                    }
281                    _ => {}
282                }
283
284                return Err(err.into());
285            }
286        };
287
288        // If this is a new entity (no old_head), or if the heads are different, return true
289        let old_head: Option<Vec<uuid::Uuid>> = row.get("old_head");
290        let changed = match old_head {
291            None => true, // New entity
292            Some(old_head) => {
293                let old_clock: Clock = old_head.into();
294                old_clock != state.head
295            }
296        };
297
298        info!("Changed: {}", changed);
299        Ok(changed)
300    }
301
302    async fn get_state(&self, id: ID) -> Result<State, RetrievalError> {
303        let ulid: ulid::Ulid = id.into();
304        let uuid: uuid::Uuid = ulid.into();
305
306        // be careful with sql injection via bucket name
307        let query = format!(r#"SELECT "id", "state_buffer", "head" FROM "{}" WHERE "id" = $1"#, self.state_table());
308
309        let mut client = match self.pool.get().await {
310            Ok(client) => client,
311            Err(err) => {
312                return Err(RetrievalError::StorageError(err.into()));
313            }
314        };
315
316        info!("Getting state: {}", query);
317        let row = match client.query_one(&query, &[&uuid]).await {
318            Ok(row) => row,
319            Err(err) => {
320                let kind = error_kind(&err);
321                match kind {
322                    ErrorKind::RowCount => {
323                        return Err(RetrievalError::NotFound(id));
324                    }
325                    ErrorKind::UndefinedTable { table } => {
326                        if table == self.state_table() {
327                            self.create_state_table(&mut client).await.map_err(|e| RetrievalError::StorageError(e.into()))?;
328                            return Err(RetrievalError::NotFound(id));
329                        }
330                    }
331                    _ => {}
332                }
333
334                return Err(RetrievalError::StorageError(err.into()));
335            }
336        };
337
338        info!("Row: {:?}", row);
339        let row_id: uuid::Uuid = row.get("id");
340        assert_eq!(row_id, uuid);
341
342        let serialized_buffers: Vec<u8> = row.get("state_buffer");
343        let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&serialized_buffers)?;
344
345        Ok(State { state_buffers, head: row.get::<_, Vec<uuid::Uuid>>("head").into() })
346    }
347
348    async fn fetch_states(&self, predicate: &ankql::ast::Predicate) -> Result<Vec<(ID, State)>, RetrievalError> {
349        println!("Fetching states for: {:?}", predicate);
350        let client = self.pool.get().await.map_err(|err| RetrievalError::StorageError(Box::new(err)))?;
351
352        let mut results = Vec::new();
353
354        let mut ankql_sql = predicate::Sql::new();
355        ankql_sql.predicate(&predicate);
356
357        let (sql, args) = ankql_sql.collapse();
358
359        let filtered_query = if !sql.is_empty() {
360            format!(r#"SELECT "id", "state_buffer", "head" FROM "{}" WHERE {}"#, self.state_table(), sql,)
361        } else {
362            format!(r#"SELECT "id", "state_buffer", "head" FROM "{}""#, self.state_table())
363        };
364
365        info!("SQL: {} with args: {:?}", filtered_query, args);
366
367        let rows = match client.query_raw(&filtered_query, args).await {
368            Ok(stream) => match stream.try_collect::<Vec<_>>().await {
369                Ok(rows) => rows,
370                Err(err) => return Err(RetrievalError::StorageError(err.into())),
371            },
372            Err(err) => {
373                let kind = error_kind(&err);
374                match kind {
375                    ErrorKind::UndefinedTable { table } => {
376                        if table == self.state_table() {
377                            // Table doesn't exist yet, return empty results
378                            return Ok(Vec::new());
379                        }
380                    }
381                    ErrorKind::UndefinedColumn { table, column } => {
382                        println!("Undefined column: {} in table: {:?}", column, table);
383                        match table {
384                            Some(table) if table == self.state_table() => {
385                                // Modify the predicate treating this column as NULL and retry
386                                return self.fetch_states(&predicate.assume_null(&[column])).await;
387                            }
388                            None => {
389                                return self.fetch_states(&predicate.assume_null(&[column])).await;
390                            }
391                            _ => {}
392                        }
393                    }
394                    _ => {}
395                }
396
397                return Err(RetrievalError::StorageError(err.into()));
398            }
399        };
400
401        for row in rows {
402            let uuid: uuid::Uuid = row.get(0);
403            let state_buffer: Vec<u8> = row.get(1);
404            let id = ID::from_ulid(ulid::Ulid::from(uuid));
405
406            let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&state_buffer)?;
407
408            let entity_state = State { state_buffers, head: row.get::<_, Vec<uuid::Uuid>>(2).into() };
409
410            results.push((id, entity_state));
411        }
412
413        Ok(results)
414    }
415
416    /// Postgres Event Table:
417    /// {bucket_name}_event
418    /// event_id uuid, // `ID`/`ULID`
419    /// entity_id uuid, // `ID`/`ULID`
420    /// operations bytea, // `Vec<Operation>`
421    /// clock bytea, // `Clock`
422    async fn add_event(&self, entity_event: &Event) -> anyhow::Result<bool> {
423        let event_id = uuid::Uuid::from(ulid::Ulid::from(entity_event.id));
424        let entity_id = uuid::Uuid::from(ulid::Ulid::from(entity_event.entity_id));
425        let operations = bincode::serialize(&entity_event.operations)?;
426        let parent_uuids: Vec<uuid::Uuid> = (&entity_event.parent).into();
427
428        // Does it even matter if this conflicts?
429        // One peers event should match any duplicates, so taking the first
430        // event we receive from a peer should be fine.
431        let query = format!(r#"INSERT INTO "{0}"("id", "entity_id", "operations", "parent") VALUES($1, $2, $3, $4)"#, self.event_table(),);
432
433        let mut client = self.pool.get().await?;
434        info!("Running: {}", query);
435        let affected = match client.execute(&query, &[&event_id, &entity_id, &operations, &parent_uuids]).await {
436            Ok(affected) => affected,
437            Err(err) => {
438                let kind = error_kind(&err);
439                match kind {
440                    ErrorKind::UndefinedTable { table } => {
441                        if table == self.event_table() {
442                            self.create_event_table(&mut *client).await?;
443                            return self.add_event(entity_event).await; // retry
444                        }
445                    }
446                    _ => {}
447                }
448
449                return Err(err.into());
450            }
451        };
452
453        Ok(affected > 0)
454    }
455
456    async fn get_events(&self, entity_id: ID) -> Result<Vec<Event>, ankurah_core::error::RetrievalError> {
457        let query = format!(r#"SELECT "id", "operations", "parent" FROM "{0}" WHERE "entity_id" = $1"#, self.event_table(),);
458
459        let entity_uuid = uuid::Uuid::from(ulid::Ulid::from(entity_id));
460
461        let mut client = self.pool.get().await.map_err(|err| RetrievalError::storage(err))?;
462        info!("Running: {}", query);
463        let rows = match client.query(&query, &[&entity_uuid]).await {
464            Ok(rows) => rows,
465            Err(err) => {
466                let kind = error_kind(&err);
467                match kind {
468                    ErrorKind::UndefinedTable { table } => {
469                        if table == self.event_table() {
470                            self.create_event_table(&mut *client).await?;
471                            return Ok(Vec::new());
472                        }
473                    }
474                    _ => {}
475                }
476
477                return Err(RetrievalError::storage(err));
478            }
479        };
480
481        let mut events = Vec::new();
482        for row in rows {
483            let event_id: uuid::Uuid = row.get("id");
484            let event_id = ID::from_ulid(event_id.into());
485            let operations_binary: Vec<u8> = row.get("operations");
486            let operations = bincode::deserialize(&operations_binary)?;
487            let parent: Vec<uuid::Uuid> = row.get("parent");
488            let parent = parent.into_iter().map(|uuid| ID::from_ulid(uuid.into())).collect::<BTreeSet<_>>();
489            let clock = Clock::new(parent);
490
491            events.push(Event {
492                id: event_id,
493                collection: self.collection_id.clone(),
494                entity_id: entity_id,
495                operations: operations,
496                parent: clock,
497            })
498        }
499
500        Ok(events)
501    }
502}
503
// Workaround: rust-postgres does not let us ask for the error kind directly.
// TODO: remove this when https://github.com/sfackler/rust-postgres/pull/1185
//       gets merged
/// Coarse classification of a `tokio_postgres::Error`, as produced by
/// [`error_kind`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ErrorKind {
    /// The query returned an unexpected number of rows.
    RowCount,
    /// The server reported an undefined table; `table` is the name quoted in
    /// the error message.
    UndefinedTable { table: String },
    /// The server reported an undefined column; `table` is only present when
    /// the error message names the relation.
    UndefinedColumn { table: Option<String>, column: String },
    /// Any error that is not classified above.
    Unknown,
}
514
515pub fn error_kind(err: &tokio_postgres::Error) -> ErrorKind {
516    let string = err.to_string().trim().to_owned();
517    let _db_error = err.as_db_error();
518    let sql_code = err.code().cloned();
519
520    if string == "query returned an unexpected number of rows" {
521        return ErrorKind::RowCount;
522    }
523
524    // Useful for adding new errors
525    // error!("postgres error: {:?}", err);
526    // error!("db_err: {:?}", err.as_db_error());
527    // error!("sql_code: {:?}", err.code());
528    // error!("err: {:?}", err);
529    // error!("err: {:?}", err.to_string());
530
531    let quote_indices = |s: &str| {
532        let mut quotes = Vec::new();
533        for (index, char) in s.char_indices() {
534            match char {
535                '"' => quotes.push(index),
536                _ => {}
537            }
538        }
539        quotes
540    };
541
542    match sql_code {
543        Some(SqlState::UNDEFINED_TABLE) => {
544            // relation "album" does not exist
545            let quotes = quote_indices(&string);
546            let table = &string[quotes[0] + 1..quotes[1]];
547            ErrorKind::UndefinedTable { table: table.to_owned() }
548        }
549        Some(SqlState::UNDEFINED_COLUMN) => {
550            // Handle both formats:
551            // "column "name" of relation "album" does not exist"
552            // "column "status" does not exist"
553            let quotes = quote_indices(&string);
554            let column = string[quotes[0] + 1..quotes[1]].to_owned();
555
556            println!("quotes: {:?}", quotes);
557            println!("column: {}", column);
558
559            let table = if quotes.len() >= 4 {
560                // Full format with table name
561                Some(string[quotes[2] + 1..quotes[3]].to_owned())
562            } else {
563                // Short format without table name, use empty string
564                None
565            };
566
567            ErrorKind::UndefinedColumn { table, column }
568        }
569        _ => ErrorKind::Unknown,
570    }
571}
572
/// Names a materialized column that is missing.
///
/// NOTE(review): nothing in this file constructs or reads this type, hence
/// the `allow(unused)`; confirm whether it is still needed.
#[allow(unused)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissingMaterialized {
    /// Name of the missing column.
    pub name: String,
}