1use std::{
2 collections::BTreeMap,
3 sync::{Arc, RwLock},
4};
5
6use ankurah_core::{
7 error::{MutationError, RetrievalError, StateError},
8 property::backend::backend_from_string,
9 storage::{StorageCollection, StorageEngine},
10};
11use ankurah_proto::{Attestation, AttestationSet, Attested, EntityState, EventId, OperationSet, State, StateBuffers};
12
13use futures_util::{pin_mut, TryStreamExt};
14
15pub mod sql_builder;
16pub mod value;
17
18use value::PGValue;
19
20use ankurah_proto::{Clock, CollectionId, EntityId, Event};
21use async_trait::async_trait;
22use bb8_postgres::{tokio_postgres::NoTls, PostgresConnectionManager};
23use tokio_postgres::{error::SqlState, types::ToSql};
24use tracing::{debug, error, info, warn};
25
/// Postgres-backed storage engine.
///
/// Wraps a bb8 connection pool; the pool is cloned into every
/// `PostgresBucket` handed out by `StorageEngine::collection`.
pub struct Postgres {
    // Shared connection pool (bb8 pools are internally reference-counted,
    // so cloning is cheap).
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
}
29
30impl Postgres {
31 pub fn new(pool: bb8::Pool<PostgresConnectionManager<NoTls>>) -> anyhow::Result<Self> { Ok(Self { pool }) }
32
33 pub async fn open(uri: &str) -> anyhow::Result<Self> {
34 let manager = PostgresConnectionManager::new_from_stringlike(uri, NoTls)?;
35 let pool = bb8::Pool::builder().build(manager).await?;
36 Self::new(pool)
37 }
38
39 pub fn sane_name(collection: &str) -> bool {
42 for char in collection.chars() {
43 match char {
44 char if char.is_alphanumeric() => {}
45 char if char.is_numeric() => {}
46 '_' | '.' | ':' => {}
47 _ => return false,
48 }
49 }
50
51 true
52 }
53}
54
#[async_trait]
impl StorageEngine for Postgres {
    type Value = PGValue;

    /// Return (and lazily provision) the storage bucket for `collection_id`.
    ///
    /// Rejects names that are unsafe to interpolate into DDL, then ensures the
    /// state and event tables exist and primes the bucket's column cache.
    async fn collection(&self, collection_id: &CollectionId) -> Result<std::sync::Arc<dyn StorageCollection>, RetrievalError> {
        // The collection id becomes a table name, so it must pass the
        // identifier whitelist before any SQL is built from it.
        if !Postgres::sane_name(collection_id.as_str()) {
            return Err(RetrievalError::InvalidBucketName);
        }

        let mut client = self.pool.get().await.map_err(RetrievalError::storage)?;

        // NOTE(review): despite the field name, `schema` holds the *database*
        // name (current_database()); it is later matched against
        // information_schema.columns.table_catalog, which is also the catalog,
        // so the two are consistent.
        let schema = client.query_one("SELECT current_database()", &[]).await.map_err(RetrievalError::storage)?;
        let schema = schema.get("current_database");

        let bucket = PostgresBucket {
            pool: self.pool.clone(),
            schema,
            collection_id: collection_id.clone(),
            columns: Arc::new(RwLock::new(Vec::new())),
        };

        // All three calls are idempotent (CREATE TABLE IF NOT EXISTS /
        // cache rebuild), so repeated `collection()` calls are safe.
        bucket.create_state_table(&mut client).await?;
        bucket.create_event_table(&mut client).await?;
        bucket.rebuild_columns_cache(&mut client).await?;

        Ok(Arc::new(bucket))
    }

    /// Drop every table in the connected database's `public` schema.
    ///
    /// Returns `Ok(false)` when there was nothing to drop, `Ok(true)` after a
    /// successful drop. All drops run in a single transaction, so the
    /// operation is all-or-nothing.
    async fn delete_all_collections(&self) -> Result<bool, MutationError> {
        let mut client = self.pool.get().await.map_err(|err| MutationError::General(Box::new(err)))?;

        let query = r#"
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'public'
        "#;

        let rows = client.query(query, &[]).await.map_err(|err| MutationError::General(Box::new(err)))?;
        if rows.is_empty() {
            return Ok(false);
        }

        let transaction = client.transaction().await.map_err(|err| MutationError::General(Box::new(err)))?;

        for row in rows {
            let table_name: String = row.get("table_name");
            // Table names come from information_schema (not user input) and
            // are double-quoted to survive mixed-case / special characters.
            let drop_query = format!(r#"DROP TABLE IF EXISTS "{}""#, table_name);
            transaction.execute(&drop_query, &[]).await.map_err(|err| MutationError::General(Box::new(err)))?;
        }

        transaction.commit().await.map_err(|err| MutationError::General(Box::new(err)))?;

        Ok(true)
    }
}
116
/// A single column as reported by `information_schema.columns`, cached by
/// `PostgresBucket` so property materialization can check column existence
/// without a round trip.
#[derive(Clone, Debug)]
pub struct PostgresColumn {
    // Column name exactly as stored in the catalog.
    pub name: String,
    // Derived from information_schema's "YES"/"NO" `is_nullable` text.
    pub is_nullable: bool,
    // Postgres data type name (e.g. "character varying", "bigint").
    pub data_type: String,
}
123
/// Per-collection storage handle: one state table (`<collection>`) plus one
/// event-log table (`<collection>_event`) in a single Postgres database.
pub struct PostgresBucket {
    // Shared pool, cloned from the parent `Postgres` engine.
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
    collection_id: CollectionId,
    // NOTE(review): actually the *database* name (current_database()),
    // matched against information_schema's `table_catalog`.
    schema: String,
    // Cached column metadata for the state table; rebuilt after DDL changes.
    columns: Arc<RwLock<Vec<PostgresColumn>>>,
}
130
131impl PostgresBucket {
132 fn state_table(&self) -> String { self.collection_id.as_str().to_string() }
133
134 pub fn event_table(&self) -> String { format!("{}_event", self.collection_id.as_str()) }
135
136 pub async fn rebuild_columns_cache(&self, client: &mut tokio_postgres::Client) -> Result<(), StateError> {
138 debug!("PostgresBucket({}).rebuild_columns_cache", self.collection_id);
139 let column_query =
140 r#"SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_catalog = $1 AND table_name = $2;"#
141 .to_string();
142 let mut new_columns = Vec::new();
143 debug!("Querying existing columns: {:?}, [{:?}, {:?}]", column_query, &self.schema, &self.collection_id.as_str());
144 let rows = client
145 .query(&column_query, &[&self.schema, &self.collection_id.as_str()])
146 .await
147 .map_err(|err| StateError::DDLError(Box::new(err)))?;
148 for row in rows {
149 let is_nullable: String = row.get("is_nullable");
150 new_columns.push(PostgresColumn {
151 name: row.get("column_name"),
152 is_nullable: is_nullable.eq("YES"),
153 data_type: row.get("data_type"),
154 })
155 }
156
157 let mut columns = self.columns.write().unwrap();
158 *columns = new_columns;
159 drop(columns);
160
161 Ok(())
162 }
163
164 pub fn existing_columns(&self) -> Vec<String> {
165 let columns = self.columns.read().unwrap();
166 columns.iter().map(|column| column.name.clone()).collect()
167 }
168
169 pub fn column(&self, column_name: &String) -> Option<PostgresColumn> {
170 let columns = self.columns.read().unwrap();
171 columns.iter().find(|column| column.name == *column_name).cloned()
172 }
173
174 pub fn has_column(&self, column_name: &String) -> bool { self.column(column_name).is_some() }
175
176 pub async fn create_event_table(&self, client: &mut tokio_postgres::Client) -> Result<(), StateError> {
177 let create_query = format!(
178 r#"CREATE TABLE IF NOT EXISTS "{}"(
179 "id" character(43) PRIMARY KEY,
180 "entity_id" character(22),
181 "operations" bytea,
182 "parent" character(43)[],
183 "attestations" bytea
184 )"#,
185 self.event_table()
186 );
187
188 debug!("{create_query}");
189 client.execute(&create_query, &[]).await.map_err(|e| StateError::DDLError(Box::new(e)))?;
190 Ok(())
191 }
192
193 pub async fn create_state_table(&self, client: &mut tokio_postgres::Client) -> Result<(), StateError> {
194 let create_query = format!(
195 r#"CREATE TABLE IF NOT EXISTS "{}"(
196 "id" character(22) PRIMARY KEY,
197 "state_buffer" BYTEA,
198 "head" character(43)[],
199 "attestations" BYTEA[]
200 )"#,
201 self.state_table()
202 );
203
204 debug!("{create_query}");
205 match client.execute(&create_query, &[]).await {
206 Ok(_) => Ok(()),
207 Err(err) => {
208 error!("Error: {}", err);
209 Err(StateError::DDLError(Box::new(err)))
210 }
211 }
212 }
213
214 pub async fn add_missing_columns(
215 &self,
216 client: &mut tokio_postgres::Client,
217 missing: Vec<(String, &'static str)>, ) -> Result<(), StateError> {
219 for (column, datatype) in missing {
220 if Postgres::sane_name(&column) {
221 let alter_query = format!(r#"ALTER TABLE "{}" ADD COLUMN "{}" {}"#, self.state_table(), column, datatype,);
222 info!("PostgresBucket({}).add_missing_columns: {}", self.collection_id, alter_query);
223 match client.execute(&alter_query, &[]).await {
224 Ok(_) => {}
225 Err(err) => {
226 warn!("Error adding column: {} to table: {} - rebuilding columns cache", err, self.state_table());
227 self.rebuild_columns_cache(client).await?;
228 return Err(StateError::DDLError(Box::new(err)));
229 }
230 }
231 }
232 }
233
234 self.rebuild_columns_cache(client).await?;
235 Ok(())
236 }
237}
238
239#[async_trait]
240impl StorageCollection for PostgresBucket {
241 async fn set_state(&self, state: Attested<EntityState>) -> Result<bool, MutationError> {
242 let state_buffers = bincode::serialize(&state.payload.state.state_buffers)?;
243 let attestations: Vec<Vec<u8>> = state.attestations.iter().map(bincode::serialize).collect::<Result<Vec<_>, _>>()?;
244 let id = state.payload.entity_id;
245
246 if state.payload.state.head.is_empty() {
248 warn!("Warning: Empty head detected for entity {}", id);
249 }
250
251 let mut client = self.pool.get().await.map_err(|err| MutationError::General(err.into()))?;
252
253 let mut columns: Vec<String> = vec!["id".to_owned(), "state_buffer".to_owned(), "head".to_owned(), "attestations".to_owned()];
254 let mut params: Vec<&(dyn ToSql + Sync)> = Vec::new();
255 params.push(&id);
256 params.push(&state_buffers);
257 params.push(&state.payload.state.head);
258 params.push(&attestations);
259
260 let mut materialized: Vec<(String, Option<PGValue>)> = Vec::new();
261 let mut seen_properties = std::collections::HashSet::new();
262
263 for (name, state_buffer) in state.payload.state.state_buffers.iter() {
265 let backend = backend_from_string(name, Some(state_buffer))?;
266 for (column, value) in backend.property_values() {
267 if !seen_properties.insert(column.clone()) {
268 continue;
273 }
274
275 let pg_value: Option<PGValue> = value.map(|value| value.into());
276 if !self.has_column(&column) {
277 if let Some(ref pg_value) = pg_value {
279 self.add_missing_columns(&mut client, vec![(column.clone(), pg_value.postgres_type())]).await?;
280 } else {
281 continue;
285 }
286 }
287
288 materialized.push((column.clone(), pg_value));
289 }
290 }
291
292 for (name, parameter) in &materialized {
293 columns.push(name.clone());
294
295 match ¶meter {
296 Some(value) => match value {
297 PGValue::CharacterVarying(string) => params.push(string),
298 PGValue::SmallInt(number) => params.push(number),
299 PGValue::Integer(number) => params.push(number),
300 PGValue::BigInt(number) => params.push(number),
301 PGValue::DoublePrecision(float) => params.push(float),
302 PGValue::Bytea(bytes) => params.push(bytes),
303 PGValue::Boolean(bool) => params.push(bool),
304 },
305 None => params.push(&UntypedNull),
306 }
307 }
308
309 let columns_str = columns.iter().map(|name| format!("\"{}\"", name)).collect::<Vec<String>>().join(", ");
310 let values_str = params.iter().enumerate().map(|(index, _)| format!("${}", index + 1)).collect::<Vec<String>>().join(", ");
311 let columns_update_str = columns
312 .iter()
313 .enumerate()
314 .skip(1) .map(|(index, name)| format!("\"{}\" = ${}", name, index + 1))
316 .collect::<Vec<String>>()
317 .join(", ");
318
319 let query = format!(
321 r#"WITH old_state AS (
322 SELECT "head" FROM "{0}" WHERE "id" = $1
323 )
324 INSERT INTO "{0}"({1}) VALUES({2})
325 ON CONFLICT("id") DO UPDATE SET {3}
326 RETURNING (SELECT "head" FROM old_state) as old_head"#,
327 self.state_table(),
328 columns_str,
329 values_str,
330 columns_update_str
331 );
332
333 debug!("PostgresBucket({}).set_state: {}", self.collection_id, query);
334 let row = match client.query_one(&query, params.as_slice()).await {
335 Ok(row) => row,
336 Err(err) => {
337 let kind = error_kind(&err);
338 if let ErrorKind::UndefinedTable { table } = kind {
339 if table == self.state_table() {
340 self.create_state_table(&mut client).await?;
341 return self.set_state(state).await; }
343 }
344
345 return Err(StateError::DDLError(Box::new(err)).into());
346 }
347 };
348
349 let old_head: Option<Clock> = row.get("old_head");
351 let changed = match old_head {
352 None => true, Some(old_head) => old_head != state.payload.state.head,
354 };
355
356 debug!("PostgresBucket({}).set_state: Changed: {}", self.collection_id, changed);
357 Ok(changed)
358 }
359
360 async fn get_state(&self, id: EntityId) -> Result<Attested<EntityState>, RetrievalError> {
361 let query = format!(r#"SELECT "id", "state_buffer", "head", "attestations" FROM "{}" WHERE "id" = $1"#, self.state_table());
363
364 let mut client = match self.pool.get().await {
365 Ok(client) => client,
366 Err(err) => {
367 return Err(RetrievalError::StorageError(err.into()));
368 }
369 };
370
371 debug!("PostgresBucket({}).get_state: {}", self.collection_id, query);
372 let rows = match client.query(&query, &[&id]).await {
373 Ok(rows) => rows,
374 Err(err) => {
375 let kind = error_kind(&err);
376 if let ErrorKind::UndefinedTable { table } = kind {
377 if table == self.state_table() {
378 self.create_state_table(&mut client).await.map_err(|e| RetrievalError::StorageError(e.into()))?;
379 return Err(RetrievalError::EntityNotFound(id));
380 }
381 }
382 return Err(RetrievalError::StorageError(err.into()));
383 }
384 };
385
386 let row = match rows.into_iter().next() {
387 Some(row) => row,
388 None => return Err(RetrievalError::EntityNotFound(id)),
389 };
390
391 debug!("PostgresBucket({}).get_state: Row: {:?}", self.collection_id, row);
392 let row_id: EntityId = row.try_get("id").map_err(RetrievalError::storage)?;
393 assert_eq!(row_id, id);
394
395 let serialized_buffers: Vec<u8> = row.try_get("state_buffer").map_err(RetrievalError::storage)?;
396 let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&serialized_buffers).map_err(RetrievalError::storage)?;
397 let head: Clock = row.try_get("head").map_err(RetrievalError::storage)?;
398 let attestation_bytes: Vec<Vec<u8>> = row.try_get("attestations").map_err(RetrievalError::storage)?;
399 let attestations = attestation_bytes
400 .into_iter()
401 .map(|bytes| bincode::deserialize(&bytes))
402 .collect::<Result<Vec<Attestation>, _>>()
403 .map_err(RetrievalError::storage)?;
404
405 Ok(Attested {
406 payload: EntityState {
407 entity_id: id,
408 collection: self.collection_id.clone(),
409 state: State { state_buffers: StateBuffers(state_buffers), head },
410 },
411 attestations: AttestationSet(attestations),
412 })
413 }
414
415 async fn fetch_states(&self, selection: &ankql::ast::Selection) -> Result<Vec<Attested<EntityState>>, RetrievalError> {
416 debug!("fetch_states: {:?}", selection);
417 let client = self.pool.get().await.map_err(|err| RetrievalError::StorageError(Box::new(err)))?;
418
419 let mut results = Vec::new();
420 let mut builder = SqlBuilder::with_fields(vec!["id", "state_buffer", "head", "attestations"]);
421 builder.table_name(self.state_table());
422 builder.selection(selection)?;
423
424 let (sql, args) = builder.build()?;
425 debug!("PostgresBucket({}).fetch_states: SQL: {} with args: {:?}", self.collection_id, sql, args);
426
427 let stream = match client.query_raw(&sql, args).await {
428 Ok(stream) => stream,
429 Err(err) => {
430 let kind = error_kind(&err);
431 match kind {
432 ErrorKind::UndefinedTable { table } => {
433 if table == self.state_table() {
434 return Ok(Vec::new());
436 }
437 }
438 ErrorKind::UndefinedColumn { table, column } => {
439 debug!("Undefined column: {} in table: {:?}, {}", column, table, self.state_table());
442 let new_selection = selection.assume_null(&[column]);
443 return self.fetch_states(&new_selection).await;
444 }
445 _ => {}
446 }
447
448 return Err(RetrievalError::StorageError(err.into()));
449 }
450 };
451 pin_mut!(stream);
452
453 while let Some(row) = stream.try_next().await.map_err(RetrievalError::storage)? {
454 let id: EntityId = row.try_get(0).map_err(RetrievalError::storage)?;
455 let state_buffer: Vec<u8> = row.try_get(1).map_err(RetrievalError::storage)?;
456 let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&state_buffer).map_err(RetrievalError::storage)?;
457 let head: Clock = row.try_get("head").map_err(RetrievalError::storage)?;
458 let attestation_bytes: Vec<Vec<u8>> = row.try_get("attestations").map_err(RetrievalError::storage)?;
459 let attestations = attestation_bytes
460 .into_iter()
461 .map(|bytes| bincode::deserialize(&bytes))
462 .collect::<Result<Vec<Attestation>, _>>()
463 .map_err(RetrievalError::storage)?;
464
465 results.push(Attested {
466 payload: EntityState {
467 entity_id: id,
468 collection: self.collection_id.clone(),
469 state: State { state_buffers: StateBuffers(state_buffers), head },
470 },
471 attestations: AttestationSet(attestations),
472 });
473 }
474
475 Ok(results)
476 }
477
478 async fn add_event(&self, entity_event: &Attested<Event>) -> Result<bool, MutationError> {
479 let operations = bincode::serialize(&entity_event.payload.operations)?;
480 let attestations = bincode::serialize(&entity_event.attestations)?;
481
482 let query = format!(
483 r#"INSERT INTO "{0}"("id", "entity_id", "operations", "parent", "attestations") VALUES($1, $2, $3, $4, $5)
484 ON CONFLICT ("id") DO NOTHING"#,
485 self.event_table(),
486 );
487
488 let mut client = self.pool.get().await.map_err(|err| MutationError::General(err.into()))?;
489 debug!("PostgresBucket({}).add_event: {}", self.collection_id, query);
490 let affected = match client
491 .execute(
492 &query,
493 &[&entity_event.payload.id(), &entity_event.payload.entity_id, &operations, &entity_event.payload.parent, &attestations],
494 )
495 .await
496 {
497 Ok(affected) => affected,
498 Err(err) => {
499 let kind = error_kind(&err);
500 match kind {
501 ErrorKind::UndefinedTable { table } => {
502 if table == self.event_table() {
503 self.create_event_table(&mut client).await?;
504 return self.add_event(entity_event).await; }
506 }
507 _ => {
508 error!("PostgresBucket({}).add_event: Error: {:?}", self.collection_id, err);
509 }
510 }
511
512 return Err(StateError::DMLError(Box::new(err)).into());
513 }
514 };
515
516 Ok(affected > 0)
517 }
518
519 async fn get_events(&self, event_ids: Vec<EventId>) -> Result<Vec<Attested<Event>>, RetrievalError> {
520 if event_ids.is_empty() {
521 return Ok(Vec::new());
522 }
523
524 let query = format!(
525 r#"SELECT "id", "entity_id", "operations", "parent", "attestations" FROM "{0}" WHERE "id" = ANY($1)"#,
526 self.event_table(),
527 );
528
529 let client = self.pool.get().await.map_err(RetrievalError::storage)?;
530 let rows = match client.query(&query, &[&event_ids]).await {
531 Ok(rows) => rows,
532 Err(err) => {
533 let kind = error_kind(&err);
534 match kind {
535 ErrorKind::UndefinedTable { table } if table == self.event_table() => return Ok(Vec::new()),
536 _ => return Err(RetrievalError::storage(err)),
537 }
538 }
539 };
540
541 let mut events = Vec::new();
542 for row in rows {
543 let entity_id: EntityId = row.try_get("entity_id").map_err(RetrievalError::storage)?;
544 let operations: OperationSet = row.try_get("operations").map_err(RetrievalError::storage)?;
545 let parent: Clock = row.try_get("parent").map_err(RetrievalError::storage)?;
546 let attestations_binary: Vec<u8> = row.try_get("attestations").map_err(RetrievalError::storage)?;
547 let attestations: Vec<Attestation> = bincode::deserialize(&attestations_binary).map_err(RetrievalError::storage)?;
548
549 let event = Attested {
550 payload: Event { collection: self.collection_id.clone(), entity_id, operations, parent },
551 attestations: AttestationSet(attestations),
552 };
553 events.push(event);
554 }
555 Ok(events)
556 }
557
558 async fn dump_entity_events(&self, entity_id: EntityId) -> Result<Vec<Attested<Event>>, ankurah_core::error::RetrievalError> {
559 let query =
560 format!(r#"SELECT "id", "operations", "parent", "attestations" FROM "{0}" WHERE "entity_id" = $1"#, self.event_table(),);
561
562 let client = self.pool.get().await.map_err(RetrievalError::storage)?;
563 debug!("PostgresBucket({}).get_events: {}", self.collection_id, query);
564 let rows = match client.query(&query, &[&entity_id]).await {
565 Ok(rows) => rows,
566 Err(err) => {
567 let kind = error_kind(&err);
568 if let ErrorKind::UndefinedTable { table } = kind {
569 if table == self.event_table() {
570 return Ok(Vec::new());
571 }
572 }
573
574 return Err(RetrievalError::storage(err));
575 }
576 };
577
578 let mut events = Vec::new();
579 for row in rows {
580 let operations_binary: Vec<u8> = row.try_get("operations").map_err(RetrievalError::storage)?;
582 let operations = bincode::deserialize(&operations_binary).map_err(RetrievalError::storage)?;
583 let parent: Clock = row.try_get("parent").map_err(RetrievalError::storage)?;
584 let attestations_binary: Vec<u8> = row.try_get("attestations").map_err(RetrievalError::storage)?;
585 let attestations: Vec<Attestation> = bincode::deserialize(&attestations_binary).map_err(RetrievalError::storage)?;
586
587 events.push(Attested {
588 payload: Event { collection: self.collection_id.clone(), entity_id, operations, parent },
589 attestations: AttestationSet(attestations),
590 });
591 }
592
593 Ok(events)
594 }
595}
596
/// Coarse classification of `tokio_postgres` errors that the storage layer
/// reacts to; produced by [`error_kind`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ErrorKind {
    // Client-side "query returned an unexpected number of rows"
    // (e.g. from `query_one`).
    RowCount,
    // SQLSTATE 42P01 (undefined_table); `table` parsed from the message.
    UndefinedTable { table: String },
    // SQLSTATE 42703 (undefined_column); `table` is present only when the
    // server message names the relation.
    UndefinedColumn { table: Option<String>, column: String },
    // Any error we don't specifically classify.
    Unknown,
    // Recognized SQLSTATE but an unparseable message; carries the raw text.
    PostgresError(String),
}
608
609pub fn error_kind(err: &tokio_postgres::Error) -> ErrorKind {
610 let string = err.as_db_error().map(|e| e.message()).unwrap_or_default().trim().to_owned();
611 let _db_error = err.as_db_error();
612 let sql_code = err.code().cloned();
613
614 let err_string = err.to_string();
616 if err_string.contains("query returned an unexpected number of rows") || string == "query returned an unexpected number of rows" {
617 return ErrorKind::RowCount;
618 }
619
620 debug!("postgres error: {:?}", err);
627
628 let quote_indices = |s: &str| {
629 let mut quotes = Vec::new();
630 for (index, char) in s.char_indices() {
631 if char == '"' {
632 quotes.push(index)
633 }
634 }
635 quotes
636 };
637
638 match sql_code {
639 Some(SqlState::UNDEFINED_TABLE) => {
640 let quotes = quote_indices(&string);
642 if quotes.len() >= 2 {
643 let table = &string[quotes[0] + 1..quotes[1]];
644 ErrorKind::UndefinedTable { table: table.to_owned() }
645 } else {
646 ErrorKind::PostgresError(string.clone())
647 }
648 }
649 Some(SqlState::UNDEFINED_COLUMN) => {
650 let quotes = quote_indices(&string);
654 if quotes.len() >= 2 {
655 let column = string[quotes[0] + 1..quotes[1]].to_owned();
656
657 let table = if quotes.len() >= 4 {
658 Some(string[quotes[2] + 1..quotes[3]].to_owned())
660 } else {
661 None
663 };
664
665 ErrorKind::UndefinedColumn { table, column }
666 } else {
667 ErrorKind::PostgresError(string.clone())
668 }
669 }
670 _ => ErrorKind::Unknown,
671 }
672}
673
/// Describes a property column that has not been materialized into the state
/// table. Currently unused in this file (hence `#[allow(unused)]`);
/// presumably retained for future schema-evolution work — verify before
/// removing.
#[allow(unused)]
pub struct MissingMaterialized {
    pub name: String,
}
678
679use bytes::BytesMut;
680use tokio_postgres::types::{to_sql_checked, IsNull, Type};
681
682use crate::sql_builder::SqlBuilder;
683
/// A typeless SQL NULL parameter: always serializes as NULL and accepts any
/// target column type. Used by `set_state` when a materialized property has
/// no value (so no concrete `PGValue` type is available).
#[derive(Debug)]
struct UntypedNull;

impl ToSql for UntypedNull {
    // Never writes bytes; signals NULL regardless of the column's type.
    fn to_sql(&self, _ty: &Type, _out: &mut BytesMut) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> { Ok(IsNull::Yes) }

    // Claim compatibility with every postgres type, since NULL is valid for
    // any nullable column.
    fn accepts(_ty: &Type) -> bool {
        true }

    to_sql_checked!();
}