1use std::{
2 collections::BTreeMap,
3 sync::{Arc, RwLock},
4};
5
6use ankurah_core::{
7 error::{MutationError, RetrievalError, StateError},
8 property::Backends,
9 storage::{StorageCollection, StorageEngine},
10};
11use ankurah_proto::{Attestation, AttestationSet, Attested, EntityState, EventId, OperationSet, State, StateBuffers};
12
13use futures_util::TryStreamExt;
14
15pub mod predicate;
16pub mod value;
17
18use value::PGValue;
19
20use ankurah_proto::{Clock, CollectionId, EntityId, Event};
21use async_trait::async_trait;
22use bb8_postgres::{tokio_postgres::NoTls, PostgresConnectionManager};
23use tokio_postgres::{error::SqlState, types::ToSql};
24use tracing::{debug, error, info, warn};
25
/// Postgres-backed storage engine.
///
/// Holds a shared `bb8` connection pool; each collection obtained from it
/// gets its own clone of the pool handle.
pub struct Postgres {
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
}
29
30impl Postgres {
31 pub fn new(pool: bb8::Pool<PostgresConnectionManager<NoTls>>) -> anyhow::Result<Self> { Ok(Self { pool }) }
32
33 pub fn sane_name(collection: &str) -> bool {
36 for char in collection.chars() {
37 match char {
38 char if char.is_alphanumeric() => {}
39 char if char.is_numeric() => {}
40 '_' | '.' | ':' => {}
41 _ => return false,
42 }
43 }
44
45 true
46 }
47}
48
49#[async_trait]
50impl StorageEngine for Postgres {
51 type Value = PGValue;
52
53 async fn collection(&self, collection_id: &CollectionId) -> Result<std::sync::Arc<dyn StorageCollection>, RetrievalError> {
54 if !Postgres::sane_name(collection_id.as_str()) {
55 return Err(RetrievalError::InvalidBucketName);
56 }
57
58 let mut client = self.pool.get().await.map_err(RetrievalError::storage)?;
59
60 let schema = client.query_one("SELECT current_database()", &[]).await.map_err(RetrievalError::storage)?;
62 let schema = schema.get("current_database");
63
64 let bucket = PostgresBucket {
65 pool: self.pool.clone(),
66 schema,
67 collection_id: collection_id.clone(),
68 columns: Arc::new(RwLock::new(Vec::new())),
69 };
70
71 bucket.create_state_table(&mut client).await?;
73 bucket.create_event_table(&mut client).await?;
74 bucket.rebuild_columns_cache(&mut client).await?;
75
76 Ok(Arc::new(bucket))
77 }
78
79 async fn delete_all_collections(&self) -> Result<bool, MutationError> {
80 let mut client = self.pool.get().await.map_err(|err| MutationError::General(Box::new(err)))?;
81
82 let query = r#"
84 SELECT table_name
85 FROM information_schema.tables
86 WHERE table_schema = 'public'
87 "#;
88
89 let rows = client.query(query, &[]).await.map_err(|err| MutationError::General(Box::new(err)))?;
90 if rows.is_empty() {
91 return Ok(false);
92 }
93
94 let transaction = client.transaction().await.map_err(|err| MutationError::General(Box::new(err)))?;
96
97 for row in rows {
99 let table_name: String = row.get("table_name");
100 let drop_query = format!(r#"DROP TABLE IF EXISTS "{}""#, table_name);
101 transaction.execute(&drop_query, &[]).await.map_err(|err| MutationError::General(Box::new(err)))?;
102 }
103
104 transaction.commit().await.map_err(|err| MutationError::General(Box::new(err)))?;
106
107 Ok(true)
108 }
109}
110
/// Cached metadata for one column of a collection's state table, as read
/// from `information_schema.columns`.
#[derive(Clone, Debug)]
pub struct PostgresColumn {
    /// Column name as reported by the catalog.
    pub name: String,
    /// True when the catalog reported `is_nullable = 'YES'`.
    pub is_nullable: bool,
    /// Postgres `data_type` string (e.g. `"bytea"`, `"character varying"`).
    pub data_type: String,
}
117
/// A single collection's storage handle: one state table plus one event
/// table, with a process-local cache of the state table's columns.
pub struct PostgresBucket {
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
    collection_id: CollectionId,
    // Holds the current database name (matched against table_catalog),
    // despite the field name — see `StorageEngine::collection`.
    schema: String,
    // Column cache; refreshed by `rebuild_columns_cache` after DDL changes.
    columns: Arc<RwLock<Vec<PostgresColumn>>>,
}
124
125impl PostgresBucket {
126 fn state_table(&self) -> String { self.collection_id.as_str().to_string() }
127
128 pub fn event_table(&self) -> String { format!("{}_event", self.collection_id.as_str()) }
129
130 pub async fn rebuild_columns_cache(&self, client: &mut tokio_postgres::Client) -> Result<(), StateError> {
132 debug!("PostgresBucket({}).rebuild_columns_cache", self.collection_id);
133 let column_query =
134 r#"SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_catalog = $1 AND table_name = $2;"#
135 .to_string();
136 let mut new_columns = Vec::new();
137 debug!("Querying existing columns: {:?}, [{:?}, {:?}]", column_query, &self.schema, &self.collection_id.as_str());
138 let rows = client
139 .query(&column_query, &[&self.schema, &self.collection_id.as_str()])
140 .await
141 .map_err(|err| StateError::DDLError(Box::new(err)))?;
142 for row in rows {
143 let is_nullable: String = row.get("is_nullable");
144 new_columns.push(PostgresColumn {
145 name: row.get("column_name"),
146 is_nullable: is_nullable.eq("YES"),
147 data_type: row.get("data_type"),
148 })
149 }
150
151 let mut columns = self.columns.write().unwrap();
152 *columns = new_columns;
153 drop(columns);
154
155 Ok(())
156 }
157
158 pub fn existing_columns(&self) -> Vec<String> {
159 let columns = self.columns.read().unwrap();
160 columns.iter().map(|column| column.name.clone()).collect()
161 }
162
163 pub fn column(&self, column_name: &String) -> Option<PostgresColumn> {
164 let columns = self.columns.read().unwrap();
165 columns.iter().find(|column| column.name == *column_name).cloned()
166 }
167
168 pub fn has_column(&self, column_name: &String) -> bool { self.column(column_name).is_some() }
169
170 pub async fn create_event_table(&self, client: &mut tokio_postgres::Client) -> Result<(), StateError> {
171 let create_query = format!(
172 r#"CREATE TABLE IF NOT EXISTS "{}"(
173 "id" character(43) PRIMARY KEY,
174 "entity_id" character(22),
175 "operations" bytea,
176 "parent" character(43)[],
177 "attestations" bytea
178 )"#,
179 self.event_table()
180 );
181
182 debug!("{create_query}");
183 client.execute(&create_query, &[]).await.map_err(|e| StateError::DDLError(Box::new(e)))?;
184 Ok(())
185 }
186
187 pub async fn create_state_table(&self, client: &mut tokio_postgres::Client) -> Result<(), StateError> {
188 let create_query = format!(
189 r#"CREATE TABLE IF NOT EXISTS "{}"(
190 "id" character(22) PRIMARY KEY,
191 "state_buffer" BYTEA,
192 "head" character(43)[],
193 "attestations" BYTEA[]
194 )"#,
195 self.state_table()
196 );
197
198 debug!("{create_query}");
199 match client.execute(&create_query, &[]).await {
200 Ok(_) => Ok(()),
201 Err(err) => {
202 error!("Error: {}", err);
203 Err(StateError::DDLError(Box::new(err)))
204 }
205 }
206 }
207
208 pub async fn add_missing_columns(
209 &self,
210 client: &mut tokio_postgres::Client,
211 missing: Vec<(String, &'static str)>, ) -> Result<(), StateError> {
213 for (column, datatype) in missing {
214 if Postgres::sane_name(&column) {
215 let alter_query = format!(r#"ALTER TABLE "{}" ADD COLUMN "{}" {}"#, self.state_table(), column, datatype,);
216 info!("PostgresBucket({}).add_missing_columns: {}", self.collection_id, alter_query);
217 match client.execute(&alter_query, &[]).await {
218 Ok(_) => {}
219 Err(err) => {
220 warn!("Error adding column: {} to table: {} - rebuilding columns cache", err, self.state_table());
221 self.rebuild_columns_cache(client).await?;
222 return Err(StateError::DDLError(Box::new(err)));
223 }
224 }
225 }
226 }
227
228 self.rebuild_columns_cache(client).await?;
229 Ok(())
230 }
231}
232
#[async_trait]
impl StorageCollection for PostgresBucket {
    /// Upserts an entity's state row, materializing each property value into
    /// its own column (created on the fly if missing). Returns `true` when
    /// the stored head actually changed.
    async fn set_state(&self, state: Attested<EntityState>) -> Result<bool, MutationError> {
        let state_buffers = bincode::serialize(&state.payload.state.state_buffers)?;
        // One bincode blob per attestation, stored as a BYTEA[] column.
        let attestations: Vec<Vec<u8>> = state.attestations.iter().map(bincode::serialize).collect::<Result<Vec<_>, _>>()?;
        let id = state.payload.entity_id;

        if state.payload.state.head.is_empty() {
            warn!("Warning: Empty head detected for entity {}", id);
        }

        let mut client = self.pool.get().await.map_err(|err| MutationError::General(err.into()))?;

        let backends = Backends::from_state_buffers(&state.payload.state.state_buffers)?;
        // Fixed columns first. `columns` and `params` must stay in lockstep:
        // placeholders below are numbered purely by position.
        let mut columns: Vec<String> = vec!["id".to_owned(), "state_buffer".to_owned(), "head".to_owned(), "attestations".to_owned()];
        let mut params: Vec<&(dyn ToSql + Sync)> = Vec::new();
        params.push(&id);
        params.push(&state_buffers);
        params.push(&state.payload.state.head);
        params.push(&attestations);

        // Collect (column, value) pairs; the schema evolves with observed
        // properties via add_missing_columns.
        let mut materialized: Vec<(String, Option<PGValue>)> = Vec::new();
        for (column, value) in backends.property_values() {
            let pg_value: Option<PGValue> = value.map(|value| value.into());
            if !self.has_column(&column) {
                if let Some(ref pg_value) = pg_value {
                    self.add_missing_columns(&mut client, vec![(column.clone(), pg_value.postgres_type())]).await?;
                } else {
                    // No value and no column: nothing to write, skip it.
                    continue;
                }
            }

            materialized.push((column.clone(), pg_value));
        }

        // Bind each materialized value. `materialized` outlives `params`, so
        // the references pushed here remain valid until the query runs.
        for (name, parameter) in &materialized {
            columns.push(name.clone());

            match &parameter {
                Some(value) => match value {
                    PGValue::CharacterVarying(string) => params.push(string),
                    PGValue::SmallInt(number) => params.push(number),
                    PGValue::Integer(number) => params.push(number),
                    PGValue::BigInt(number) => params.push(number),
                    PGValue::Bytea(bytes) => params.push(bytes),
                    PGValue::Boolean(bool) => params.push(bool),
                },
                // Absent value: bind a type-agnostic SQL NULL.
                None => params.push(&UntypedNull),
            }
        }

        let columns_str = columns.iter().map(|name| format!("\"{}\"", name)).collect::<Vec<String>>().join(", ");
        let values_str = params.iter().enumerate().map(|(index, _)| format!("${}", index + 1)).collect::<Vec<String>>().join(", ");
        // skip(1): "id" is the conflict target and is never updated.
        let columns_update_str = columns
            .iter()
            .enumerate()
            .skip(1)
            .map(|(index, name)| format!("\"{}\" = ${}", name, index + 1))
            .collect::<Vec<String>>()
            .join(", ");

        // The CTE captures the pre-upsert head so we can report whether this
        // write changed anything; old_head is NULL on a fresh insert.
        let query = format!(
            r#"WITH old_state AS (
                SELECT "head" FROM "{0}" WHERE "id" = $1
            )
            INSERT INTO "{0}"({1}) VALUES({2})
            ON CONFLICT("id") DO UPDATE SET {3}
            RETURNING (SELECT "head" FROM old_state) as old_head"#,
            self.state_table(),
            columns_str,
            values_str,
            columns_update_str
        );

        debug!("PostgresBucket({}).set_state: {}", self.collection_id, query);
        let row = match client.query_one(&query, params.as_slice()).await {
            Ok(row) => row,
            Err(err) => {
                let kind = error_kind(&err);
                // Lazily create the state table and retry once via recursion
                // (async_trait boxes the future, so self-recursion is fine).
                if let ErrorKind::UndefinedTable { table } = kind {
                    if table == self.state_table() {
                        self.create_state_table(&mut client).await?;
                        return self.set_state(state).await;
                    }
                }

                return Err(StateError::DDLError(Box::new(err)).into());
            }
        };

        // NULL old_head (fresh insert) counts as changed.
        let old_head: Option<Clock> = row.get("old_head");
        let changed = match old_head {
            None => true,
            Some(old_head) => old_head != state.payload.state.head,
        };

        debug!("PostgresBucket({}).set_state: Changed: {}", self.collection_id, changed);
        Ok(changed)
    }

    /// Loads a single entity's state row by id.
    ///
    /// # Errors
    /// [`RetrievalError::EntityNotFound`] when no row exists — including
    /// when the table itself was missing (it is created as a side effect).
    async fn get_state(&self, id: EntityId) -> Result<Attested<EntityState>, RetrievalError> {
        let query = format!(r#"SELECT "id", "state_buffer", "head", "attestations" FROM "{}" WHERE "id" = $1"#, self.state_table());

        let mut client = match self.pool.get().await {
            Ok(client) => client,
            Err(err) => {
                return Err(RetrievalError::StorageError(err.into()));
            }
        };

        debug!("PostgresBucket({}).get_state: {}", self.collection_id, query);
        let row = match client.query_one(&query, &[&id]).await {
            Ok(row) => row,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    // query_one returned zero rows: the entity is absent.
                    ErrorKind::RowCount => {
                        return Err(RetrievalError::EntityNotFound(id));
                    }
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.state_table() {
                            // Create the table so later calls succeed, then
                            // report not-found for this lookup.
                            self.create_state_table(&mut client).await.map_err(|e| RetrievalError::StorageError(e.into()))?;
                            return Err(RetrievalError::EntityNotFound(id));
                        }
                    }
                    _ => {}
                }

                return Err(RetrievalError::StorageError(err.into()));
            }
        };

        debug!("PostgresBucket({}).get_state: Row: {:?}", self.collection_id, row);
        let row_id: EntityId = row.try_get("id").map_err(RetrievalError::storage)?;
        // Invariant: the WHERE clause pins the id, so a mismatch is a bug.
        assert_eq!(row_id, id);

        let serialized_buffers: Vec<u8> = row.try_get("state_buffer").map_err(RetrievalError::storage)?;
        let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&serialized_buffers).map_err(RetrievalError::storage)?;
        let head: Clock = row.try_get("head").map_err(RetrievalError::storage)?;
        // Attestations are stored one bincode blob per array element; see
        // set_state.
        let attestation_bytes: Vec<Vec<u8>> = row.try_get("attestations").map_err(RetrievalError::storage)?;
        let attestations = attestation_bytes
            .into_iter()
            .map(|bytes| bincode::deserialize(&bytes))
            .collect::<Result<Vec<Attestation>, _>>()
            .map_err(RetrievalError::storage)?;

        Ok(Attested {
            payload: EntityState {
                entity_id: id,
                collection: self.collection_id.clone(),
                state: State { state_buffers: StateBuffers(state_buffers), head },
            },
            attestations: AttestationSet(attestations),
        })
    }

    /// Fetches all entity states matching `predicate`, compiled to a SQL
    /// WHERE clause. Columns the predicate references but the table lacks
    /// are assumed NULL and the query is retried.
    async fn fetch_states(&self, predicate: &ankql::ast::Predicate) -> Result<Vec<Attested<EntityState>>, RetrievalError> {
        debug!("fetch_states: {:?}", predicate);
        let client = self.pool.get().await.map_err(|err| RetrievalError::StorageError(Box::new(err)))?;

        let mut results = Vec::new();

        let mut ankql_sql = predicate::Sql::new();
        ankql_sql.predicate(predicate).map_err(|err| RetrievalError::StorageError(Box::new(err)))?;

        let (sql, args) = ankql_sql.collapse();

        // An empty predicate compiles to no WHERE clause at all.
        let filtered_query = if !sql.is_empty() {
            format!(r#"SELECT "id", "state_buffer", "head", "attestations" FROM "{}" WHERE {}"#, self.state_table(), sql,)
        } else {
            format!(r#"SELECT "id", "state_buffer", "head", "attestations" FROM "{}""#, self.state_table())
        };

        debug!("PostgresBucket({}).fetch_states: SQL: {} with args: {:?}", self.collection_id, filtered_query, args);

        let rows = match client.query_raw(&filtered_query, args).await {
            Ok(stream) => match stream.try_collect::<Vec<_>>().await {
                Ok(rows) => rows,
                Err(err) => return Err(RetrievalError::StorageError(err.into())),
            },
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    // Missing table: the collection simply has no data yet.
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.state_table() {
                            return Ok(Vec::new());
                        }
                    }
                    // The predicate referenced a never-materialized column:
                    // treat it as NULL and re-run (bounded: each retry
                    // removes one column from the predicate).
                    ErrorKind::UndefinedColumn { table, column } => {
                        debug!("Undefined column: {} in table: {:?}, {}", column, table, self.state_table());
                        match table {
                            Some(table) if table == self.state_table() => {
                                return self.fetch_states(&predicate.assume_null(&[column])).await;
                            }
                            // No table in the error message: assume it was ours.
                            None => {
                                return self.fetch_states(&predicate.assume_null(&[column])).await;
                            }
                            _ => {}
                        }
                    }
                    _ => {}
                }

                return Err(RetrievalError::StorageError(err.into()));
            }
        };

        for row in rows {
            // Positional gets match the SELECT list above.
            let id: EntityId = row.try_get(0).map_err(RetrievalError::storage)?;
            let state_buffer: Vec<u8> = row.try_get(1).map_err(RetrievalError::storage)?;
            let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&state_buffer).map_err(RetrievalError::storage)?;
            let head: Clock = row.try_get("head").map_err(RetrievalError::storage)?;
            let attestation_bytes: Vec<Vec<u8>> = row.try_get("attestations").map_err(RetrievalError::storage)?;
            let attestations = attestation_bytes
                .into_iter()
                .map(|bytes| bincode::deserialize(&bytes))
                .collect::<Result<Vec<Attestation>, _>>()
                .map_err(RetrievalError::storage)?;

            results.push(Attested {
                payload: EntityState {
                    entity_id: id,
                    collection: self.collection_id.clone(),
                    state: State { state_buffers: StateBuffers(state_buffers), head },
                },
                attestations: AttestationSet(attestations),
            });
        }

        Ok(results)
    }

    /// Appends one event to the collection's event table. Returns `true`
    /// when a row was inserted.
    async fn add_event(&self, entity_event: &Attested<Event>) -> Result<bool, MutationError> {
        let operations = bincode::serialize(&entity_event.payload.operations)?;
        // Unlike set_state, the whole attestation set is one bincode blob
        // here (events use a single bytea column).
        let attestations = bincode::serialize(&entity_event.attestations)?;

        let query = format!(
            r#"INSERT INTO "{0}"("id", "entity_id", "operations", "parent", "attestations") VALUES($1, $2, $3, $4, $5)"#,
            self.event_table(),
        );

        let mut client = self.pool.get().await.map_err(|err| MutationError::General(err.into()))?;
        debug!("PostgresBucket({}).add_event: {}", self.collection_id, query);
        let affected = match client
            .execute(
                &query,
                &[&entity_event.payload.id(), &entity_event.payload.entity_id, &operations, &entity_event.payload.parent, &attestations],
            )
            .await
        {
            Ok(affected) => affected,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.event_table() {
                            // Lazily create the event table, then retry once.
                            self.create_event_table(&mut client).await?;
                            return self.add_event(entity_event).await;
                        }
                    }
                    _ => {
                        error!("PostgresBucket({}).add_event: Error: {:?}", self.collection_id, err);
                    }
                }

                return Err(StateError::DMLError(Box::new(err)).into());
            }
        };

        Ok(affected > 0)
    }

    /// Bulk-loads events by id via `= ANY($1)`. Missing ids are silently
    /// omitted; a missing table yields an empty result.
    async fn get_events(&self, event_ids: Vec<EventId>) -> Result<Vec<Attested<Event>>, RetrievalError> {
        if event_ids.is_empty() {
            return Ok(Vec::new());
        }

        let query = format!(
            r#"SELECT "id", "entity_id", "operations", "parent", "attestations" FROM "{0}" WHERE "id" = ANY($1)"#,
            self.event_table(),
        );

        let client = self.pool.get().await.map_err(RetrievalError::storage)?;
        let rows = match client.query(&query, &[&event_ids]).await {
            Ok(rows) => rows,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::UndefinedTable { table } if table == self.event_table() => return Ok(Vec::new()),
                    _ => return Err(RetrievalError::storage(err)),
                }
            }
        };

        let mut events = Vec::new();
        for row in rows {
            let entity_id: EntityId = row.try_get("entity_id").map_err(RetrievalError::storage)?;
            // NOTE(review): "operations" is decoded via FromSql here but via
            // bincode in dump_entity_events — presumably OperationSet's
            // FromSql impl wraps the same bincode format; confirm.
            let operations: OperationSet = row.try_get("operations").map_err(RetrievalError::storage)?;
            let parent: Clock = row.try_get("parent").map_err(RetrievalError::storage)?;
            let attestations_binary: Vec<u8> = row.try_get("attestations").map_err(RetrievalError::storage)?;
            let attestations: Vec<Attestation> = bincode::deserialize(&attestations_binary).map_err(RetrievalError::storage)?;

            let event = Attested {
                payload: Event { collection: self.collection_id.clone(), entity_id, operations, parent },
                attestations: AttestationSet(attestations),
            };
            events.push(event);
        }
        Ok(events)
    }

    /// Returns every stored event for one entity (diagnostic dump). A
    /// missing event table yields an empty result.
    async fn dump_entity_events(&self, entity_id: EntityId) -> Result<Vec<Attested<Event>>, ankurah_core::error::RetrievalError> {
        let query =
            format!(r#"SELECT "id", "operations", "parent", "attestations" FROM "{0}" WHERE "entity_id" = $1"#, self.event_table(),);

        let client = self.pool.get().await.map_err(RetrievalError::storage)?;
        debug!("PostgresBucket({}).get_events: {}", self.collection_id, query);
        let rows = match client.query(&query, &[&entity_id]).await {
            Ok(rows) => rows,
            Err(err) => {
                let kind = error_kind(&err);
                if let ErrorKind::UndefinedTable { table } = kind {
                    if table == self.event_table() {
                        return Ok(Vec::new());
                    }
                }

                return Err(RetrievalError::storage(err));
            }
        };

        let mut events = Vec::new();
        for row in rows {
            let operations_binary: Vec<u8> = row.try_get("operations").map_err(RetrievalError::storage)?;
            let operations = bincode::deserialize(&operations_binary).map_err(RetrievalError::storage)?;
            let parent: Clock = row.try_get("parent").map_err(RetrievalError::storage)?;
            let attestations_binary: Vec<u8> = row.try_get("attestations").map_err(RetrievalError::storage)?;
            let attestations: Vec<Attestation> = bincode::deserialize(&attestations_binary).map_err(RetrievalError::storage)?;

            events.push(Attested {
                payload: Event { collection: self.collection_id.clone(), entity_id, operations, parent },
                attestations: AttestationSet(attestations),
            });
        }

        Ok(events)
    }
}
593
/// Classification of `tokio_postgres` errors that this module reacts to,
/// parsed out of the error's SQLSTATE code and message text by `error_kind`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ErrorKind {
    /// `query_one` matched an unexpected number of rows (used to detect
    /// "no such row" for point lookups).
    RowCount,
    /// SQLSTATE 42P01: the named table does not exist.
    UndefinedTable { table: String },
    /// SQLSTATE 42703: the named column does not exist; the table name is
    /// only present when the server message includes it.
    UndefinedColumn { table: Option<String>, column: String },
    /// Anything this module does not specifically handle.
    Unknown,
}
604
605pub fn error_kind(err: &tokio_postgres::Error) -> ErrorKind {
606 let string = err.to_string().trim().to_owned();
607 let _db_error = err.as_db_error();
608 let sql_code = err.code().cloned();
609
610 if string == "query returned an unexpected number of rows" {
611 return ErrorKind::RowCount;
612 }
613
614 let quote_indices = |s: &str| {
622 let mut quotes = Vec::new();
623 for (index, char) in s.char_indices() {
624 if char == '"' {
625 quotes.push(index)
626 }
627 }
628 quotes
629 };
630
631 match sql_code {
632 Some(SqlState::UNDEFINED_TABLE) => {
633 let quotes = quote_indices(&string);
635 let table = &string[quotes[0] + 1..quotes[1]];
636 ErrorKind::UndefinedTable { table: table.to_owned() }
637 }
638 Some(SqlState::UNDEFINED_COLUMN) => {
639 let quotes = quote_indices(&string);
643 let column = string[quotes[0] + 1..quotes[1]].to_owned();
644
645 let table = if quotes.len() >= 4 {
646 Some(string[quotes[2] + 1..quotes[3]].to_owned())
648 } else {
649 None
651 };
652
653 ErrorKind::UndefinedColumn { table, column }
654 }
655 _ => ErrorKind::Unknown,
656 }
657}
658
/// Placeholder describing a property column that has not been materialized
/// yet. Currently unused — kept for a planned feature; see #[allow(unused)].
#[allow(unused)]
pub struct MissingMaterialized {
    pub name: String,
}
663
664use bytes::BytesMut;
665use tokio_postgres::types::{to_sql_checked, IsNull, Type};
666
/// A SQL NULL that can be bound to a parameter of any Postgres type; used by
/// `set_state` when a materialized property has no value.
#[derive(Debug)]
struct UntypedNull;
669
impl ToSql for UntypedNull {
    // Always encodes NULL, regardless of the target column's type.
    fn to_sql(&self, _ty: &Type, _out: &mut BytesMut) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> { Ok(IsNull::Yes) }

    // Claim compatibility with every type so NULL can be bound anywhere.
    fn accepts(_ty: &Type) -> bool {
        true }

    to_sql_checked!();
}