use std::{
    collections::{BTreeMap, BTreeSet},
    sync::{Arc, RwLock},
};

use ankurah_core::{
    error::RetrievalError,
    property::Backends,
    storage::{StorageCollection, StorageEngine},
};
use ankurah_proto::State;

use futures_util::TryStreamExt;

pub mod predicate;
pub mod value;

use value::PGValue;

use ankurah_proto::{Clock, CollectionId, Event, ID};
use async_trait::async_trait;
use bb8_postgres::{tokio_postgres::NoTls, PostgresConnectionManager};
use tokio_postgres::{error::SqlState, types::ToSql};
use tracing::{error, info, warn};

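/// A PostgreSQL-backed [`StorageEngine`] that owns a `bb8` connection pool and hands
/// out one [`PostgresBucket`] per collection.
///
/// Minimal construction sketch (the connection string and database name here are
/// illustrative assumptions, not something this module defines):
///
/// ```ignore
/// use bb8_postgres::{tokio_postgres::NoTls, PostgresConnectionManager};
///
/// # async fn connect() -> anyhow::Result<()> {
/// let manager = PostgresConnectionManager::new_from_stringlike("host=localhost user=postgres dbname=ankurah", NoTls)?;
/// let pool = bb8::Pool::builder().build(manager).await?;
/// let engine = Postgres::new(pool)?;
/// # Ok(())
/// # }
/// ```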
pub struct Postgres {
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
}

impl Postgres {
    pub fn new(pool: bb8::Pool<PostgresConnectionManager<NoTls>>) -> anyhow::Result<Self> { Ok(Self { pool }) }

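    /// Returns `true` when `collection` contains only characters that are safe to
    /// interpolate into identifiers in the generated SQL: alphanumerics plus '_',
    /// '.' and ':'.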
    pub fn sane_name(collection: &str) -> bool {
        collection.chars().all(|char| char.is_alphanumeric() || matches!(char, '_' | '.' | ':'))
    }
}

#[async_trait]
impl StorageEngine for Postgres {
    type Value = PGValue;

    async fn collection(&self, collection_id: &CollectionId) -> Result<std::sync::Arc<dyn StorageCollection>, RetrievalError> {
        if !Postgres::sane_name(collection_id.as_str()) {
            return Err(RetrievalError::InvalidBucketName);
        }

        let bucket =
            PostgresBucket { pool: self.pool.clone(), collection_id: collection_id.clone(), columns: Arc::new(RwLock::new(Vec::new())) };

        // Make sure the backing tables exist and the column cache is warm before handing the bucket out.
        let mut client = self.pool.get().await.map_err(RetrievalError::storage)?;
        bucket.create_event_table(&mut client).await?;
        bucket.create_state_table(&mut client).await?;
        bucket.rebuild_columns_cache(&mut client).await?;

        Ok(Arc::new(bucket))
    }
}

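/// One column of a collection's state table, as cached from
/// `information_schema.columns`.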
#[derive(Clone, Debug)]
pub struct PostgresColumn {
    pub name: String,
    pub is_nullable: bool,
    pub data_type: String,
}

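/// Per-collection storage backed by two tables: a state table named after the
/// collection and a companion `<collection>_event` table. Property values are also
/// materialized into dedicated columns, which are created on demand.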
pub struct PostgresBucket {
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
    collection_id: CollectionId,

    columns: Arc<RwLock<Vec<PostgresColumn>>>,
}

impl PostgresBucket {
    fn state_table(&self) -> String { self.collection_id.as_str().to_owned() }

    pub fn event_table(&self) -> String { format!("{}_event", self.collection_id.as_str()) }

    pub async fn rebuild_columns_cache(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
        let schema = "ankurah";

        let column_query =
            r#"SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_catalog = $1 AND table_name = $2;"#;
        let mut new_columns = Vec::new();
        info!("Querying existing columns: {:?}, [{:?}, {:?}]", column_query, &schema, &self.collection_id.as_str());
        let rows = client.query(column_query, &[&schema, &self.collection_id.as_str()]).await?;
        for row in rows {
            let is_nullable: String = row.get("is_nullable");
            new_columns.push(PostgresColumn {
                name: row.get("column_name"),
                is_nullable: is_nullable.eq("YES"),
                data_type: row.get("data_type"),
            })
        }

        let mut columns = self.columns.write().unwrap();
        *columns = new_columns;
        drop(columns);

        Ok(())
    }

    pub fn existing_columns(&self) -> Vec<String> {
        let columns = self.columns.read().unwrap();
        columns.iter().map(|column| column.name.clone()).collect()
    }

    pub fn column(&self, column_name: &str) -> Option<PostgresColumn> {
        let columns = self.columns.read().unwrap();
        columns.iter().find(|column| column.name == column_name).cloned()
    }

    pub fn has_column(&self, column_name: &str) -> bool { self.column(column_name).is_some() }

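    /// Creates the `<collection>_event` table if it does not already exist.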
    pub async fn create_event_table(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
        let create_query = format!(
            r#"CREATE TABLE IF NOT EXISTS "{}"("id" UUID UNIQUE, "entity_id" UUID, "operations" BYTEA, "parent" UUID[])"#,
            self.event_table()
        );

        info!("Applying DDL: {}", create_query);
        client.execute(&create_query, &[]).await?;
        Ok(())
    }

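    /// Creates the collection's state table if it does not already exist. For a
    /// collection named `album` (an illustrative name), the generated DDL looks like:
    ///
    /// ```sql
    /// CREATE TABLE IF NOT EXISTS "album"("id" UUID UNIQUE, "state_buffer" BYTEA, "head" UUID[])
    /// ```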
    pub async fn create_state_table(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
        let create_query =
            format!(r#"CREATE TABLE IF NOT EXISTS "{}"("id" UUID UNIQUE, "state_buffer" BYTEA, "head" UUID[])"#, self.state_table());

        info!("Applying DDL: {}", create_query);
        match client.execute(&create_query, &[]).await {
            Ok(_) => Ok(()),
            Err(err) => {
                error!("Failed to create state table {}: {}", self.state_table(), err);
                Err(err.into())
            }
        }
    }

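    /// Adds any not-yet-materialized columns to the state table via `ALTER TABLE ...
    /// ADD COLUMN`, skipping names that fail [`Postgres::sane_name`], then refreshes
    /// the cached column list.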
    pub async fn add_missing_columns(
        &self,
        client: &mut tokio_postgres::Client,
        missing: Vec<(String, &'static str)>,
    ) -> anyhow::Result<()> {
        for (column, datatype) in missing {
            if Postgres::sane_name(&column) {
                let alter_query = format!(r#"ALTER TABLE "{}" ADD COLUMN "{}" {}"#, self.state_table(), column, datatype);
                info!("Running: {}", alter_query);
                client.execute(&alter_query, &[]).await?;
            }
        }

        self.rebuild_columns_cache(client).await?;
        Ok(())
    }
}

#[async_trait]
impl StorageCollection for PostgresBucket {
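    /// Upserts the serialized state buffers and head clock for `id`, materializing
    /// property values into dedicated columns (created on demand). For a collection
    /// named `album` with a `name` property (illustrative names only), the generated
    /// statement has roughly this shape:
    ///
    /// ```sql
    /// WITH old_state AS (SELECT "head" FROM "album" WHERE "id" = $1)
    /// INSERT INTO "album"("id", "state_buffer", "head", "name") VALUES($1, $2, $3, $4)
    /// ON CONFLICT("id") DO UPDATE SET "state_buffer" = $2, "head" = $3, "name" = $4
    /// RETURNING (SELECT "head" FROM old_state) as old_head
    /// ```
    ///
    /// Returns `true` when the stored head actually changed (or no row existed before).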
    async fn set_state(&self, id: ID, state: &State) -> anyhow::Result<bool> {
        let state_buffers = bincode::serialize(&state.state_buffers)?;
        let ulid: ulid::Ulid = id.into();
        let uuid: uuid::Uuid = ulid.into();

        let head_uuids: Vec<uuid::Uuid> = (&state.head).into();

        if head_uuids.is_empty() {
            warn!("Empty head detected for entity {}", id);
        }

        let mut client = self.pool.get().await?;

        let backends = Backends::from_state_buffers(state)?;
        let mut columns: Vec<String> = vec!["id".to_owned(), "state_buffer".to_owned(), "head".to_owned()];
        let mut params: Vec<&(dyn ToSql + Sync)> = Vec::new();
        params.push(&uuid);
        params.push(&state_buffers);
        params.push(&head_uuids);

        // Materialize property values, creating any columns this collection has not seen before.
        let mut materialized: Vec<(String, Option<PGValue>)> = Vec::new();
        for (column, value) in backends.property_values() {
            let pg_value: Option<PGValue> = value.map(|value| value.into());
            if !self.has_column(&column) {
                if let Some(ref pg_value) = pg_value {
                    self.add_missing_columns(&mut client, vec![(column.clone(), pg_value.postgres_type())]).await?;
                } else {
                    // No value and no existing column: nothing to materialize.
                    continue;
                }
            }

            materialized.push((column.clone(), pg_value));
        }

        for (name, parameter) in &materialized {
            columns.push(name.clone());

            match parameter {
                Some(value) => match value {
                    PGValue::CharacterVarying(string) => params.push(string),
                    PGValue::SmallInt(number) => params.push(number),
                    PGValue::Integer(number) => params.push(number),
                    PGValue::BigInt(number) => params.push(number),
                    PGValue::Bytea(bytes) => params.push(bytes),
                },
                None => params.push(&None::<i32>),
            }
        }

        let columns_str = columns.iter().map(|name| format!("\"{}\"", name)).collect::<Vec<String>>().join(", ");
        let values_str = params.iter().enumerate().map(|(index, _)| format!("${}", index + 1)).collect::<Vec<String>>().join(", ");
        let columns_update_str = columns
            .iter()
            .enumerate()
            .skip(1) // "id" is the conflict target and is never updated.
            .map(|(index, name)| format!("\"{}\" = ${}", name, index + 1))
            .collect::<Vec<String>>()
            .join(", ");

        let query = format!(
            r#"WITH old_state AS (
                SELECT "head" FROM "{0}" WHERE "id" = $1
            )
            INSERT INTO "{0}"({1}) VALUES({2})
            ON CONFLICT("id") DO UPDATE SET {3}
            RETURNING (SELECT "head" FROM old_state) as old_head"#,
            self.state_table(),
            columns_str,
            values_str,
            columns_update_str
        );

        info!("Querying: {}", query);
        let row = match client.query_one(&query, params.as_slice()).await {
            Ok(row) => row,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.state_table() {
                            // First write to a brand-new collection: create the table and retry.
                            self.create_state_table(&mut *client).await?;
                            return self.set_state(id, state).await;
                        }
                    }
                    _ => {}
                }

                return Err(err.into());
            }
        };

        let old_head: Option<Vec<uuid::Uuid>> = row.get("old_head");

        let changed = match old_head {
            // No prior row existed, so this write introduced new state.
            None => true,
            Some(old_head) => {
                let old_clock: Clock = old_head.into();
                old_clock != state.head
            }
        };

        info!("Changed: {}", changed);
        Ok(changed)
    }

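    /// Loads the state for `id`, returning `RetrievalError::NotFound` when the row
    /// (or the state table itself) does not exist yet.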
    async fn get_state(&self, id: ID) -> Result<State, RetrievalError> {
        let ulid: ulid::Ulid = id.into();
        let uuid: uuid::Uuid = ulid.into();

        let query = format!(r#"SELECT "id", "state_buffer", "head" FROM "{}" WHERE "id" = $1"#, self.state_table());

        let mut client = match self.pool.get().await {
            Ok(client) => client,
            Err(err) => {
                return Err(RetrievalError::StorageError(err.into()));
            }
        };

        info!("Getting state: {}", query);
        let row = match client.query_one(&query, &[&uuid]).await {
            Ok(row) => row,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::RowCount => {
                        return Err(RetrievalError::NotFound(id));
                    }
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.state_table() {
                            // The collection has never been written: create the table so later
                            // calls succeed, then report this entity as missing.
                            self.create_state_table(&mut client).await.map_err(|e| RetrievalError::StorageError(e.into()))?;
                            return Err(RetrievalError::NotFound(id));
                        }
                    }
                    _ => {}
                }

                return Err(RetrievalError::StorageError(err.into()));
            }
        };

        info!("Row: {:?}", row);
        let row_id: uuid::Uuid = row.get("id");
        assert_eq!(row_id, uuid);

        let serialized_buffers: Vec<u8> = row.get("state_buffer");
        let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&serialized_buffers)?;

        Ok(State { state_buffers, head: row.get::<_, Vec<uuid::Uuid>>("head").into() })
    }

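    /// Fetches every `(ID, State)` pair whose materialized columns satisfy `predicate`.
    /// A missing state table yields an empty result, and predicates that reference
    /// columns which were never materialized are retried with those columns assumed NULL.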
    async fn fetch_states(&self, predicate: &ankql::ast::Predicate) -> Result<Vec<(ID, State)>, RetrievalError> {
        info!("Fetching states for: {:?}", predicate);
        let client = self.pool.get().await.map_err(|err| RetrievalError::StorageError(Box::new(err)))?;

        let mut results = Vec::new();

        let mut ankql_sql = predicate::Sql::new();
        ankql_sql.predicate(&predicate);

        let (sql, args) = ankql_sql.collapse();

        let filtered_query = if !sql.is_empty() {
            format!(r#"SELECT "id", "state_buffer", "head" FROM "{}" WHERE {}"#, self.state_table(), sql)
        } else {
            format!(r#"SELECT "id", "state_buffer", "head" FROM "{}""#, self.state_table())
        };

        info!("SQL: {} with args: {:?}", filtered_query, args);

        let rows = match client.query_raw(&filtered_query, args).await {
            Ok(stream) => match stream.try_collect::<Vec<_>>().await {
                Ok(rows) => rows,
                Err(err) => return Err(RetrievalError::StorageError(err.into())),
            },
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.state_table() {
                            // The collection has never been written to, so there is nothing to fetch.
                            return Ok(Vec::new());
                        }
                    }
                    ErrorKind::UndefinedColumn { table, column } => {
                        warn!("Undefined column: {} in table: {:?}", column, table);
                        match table {
                            Some(table) if table == self.state_table() => {
                                return self.fetch_states(&predicate.assume_null(&[column])).await;
                            }
                            None => {
                                return self.fetch_states(&predicate.assume_null(&[column])).await;
                            }
                            _ => {}
                        }
                    }
                    _ => {}
                }

                return Err(RetrievalError::StorageError(err.into()));
            }
        };

        for row in rows {
            let uuid: uuid::Uuid = row.get(0);
            let state_buffer: Vec<u8> = row.get(1);
            let id = ID::from_ulid(ulid::Ulid::from(uuid));

            let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&state_buffer)?;

            let entity_state = State { state_buffers, head: row.get::<_, Vec<uuid::Uuid>>(2).into() };

            results.push((id, entity_state));
        }

        Ok(results)
    }

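    /// Inserts a single event row into the `<collection>_event` table, creating the
    /// table and retrying once if it does not exist yet. Returns `true` when a row
    /// was inserted.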
    async fn add_event(&self, entity_event: &Event) -> anyhow::Result<bool> {
        let event_id = uuid::Uuid::from(ulid::Ulid::from(entity_event.id));
        let entity_id = uuid::Uuid::from(ulid::Ulid::from(entity_event.entity_id));
        let operations = bincode::serialize(&entity_event.operations)?;
        let parent_uuids: Vec<uuid::Uuid> = (&entity_event.parent).into();

        let query = format!(r#"INSERT INTO "{0}"("id", "entity_id", "operations", "parent") VALUES($1, $2, $3, $4)"#, self.event_table());

        let mut client = self.pool.get().await?;
        info!("Running: {}", query);
        let affected = match client.execute(&query, &[&event_id, &entity_id, &operations, &parent_uuids]).await {
            Ok(affected) => affected,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.event_table() {
                            self.create_event_table(&mut *client).await?;
                            return self.add_event(entity_event).await;
                        }
                    }
                    _ => {}
                }

                return Err(err.into());
            }
        };

        Ok(affected > 0)
    }

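    /// Returns all events recorded for `entity_id`, creating the event table (and
    /// returning an empty list) when it does not exist yet.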
    async fn get_events(&self, entity_id: ID) -> Result<Vec<Event>, ankurah_core::error::RetrievalError> {
        let query = format!(r#"SELECT "id", "operations", "parent" FROM "{0}" WHERE "entity_id" = $1"#, self.event_table());

        let entity_uuid = uuid::Uuid::from(ulid::Ulid::from(entity_id));

        let mut client = self.pool.get().await.map_err(RetrievalError::storage)?;
        info!("Running: {}", query);
        let rows = match client.query(&query, &[&entity_uuid]).await {
            Ok(rows) => rows,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.event_table() {
                            self.create_event_table(&mut *client).await?;
                            return Ok(Vec::new());
                        }
                    }
                    _ => {}
                }

                return Err(RetrievalError::storage(err));
            }
        };

        let mut events = Vec::new();
        for row in rows {
            let event_id: uuid::Uuid = row.get("id");
            let event_id = ID::from_ulid(event_id.into());
            let operations_binary: Vec<u8> = row.get("operations");
            let operations = bincode::deserialize(&operations_binary)?;
            let parent: Vec<uuid::Uuid> = row.get("parent");
            let parent = parent.into_iter().map(|uuid| ID::from_ulid(uuid.into())).collect::<BTreeSet<_>>();
            let clock = Clock::new(parent);

            events.push(Event {
                id: event_id,
                collection: self.collection_id.clone(),
                entity_id,
                operations,
                parent: clock,
            })
        }

        Ok(events)
    }
}

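/// Broad classification of `tokio_postgres` errors that the storage layer knows how
/// to recover from.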
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ErrorKind {
    RowCount,
    UndefinedTable { table: String },
    UndefinedColumn { table: Option<String>, column: String },
    Unknown,
}

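/// Classifies a `tokio_postgres` error by SQLSTATE, extracting the offending table or
/// column name from the double-quoted identifiers in the server's error message.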
pub fn error_kind(err: &tokio_postgres::Error) -> ErrorKind {
    let string = err.to_string().trim().to_owned();
    let sql_code = err.code().cloned();

    if string == "query returned an unexpected number of rows" {
        return ErrorKind::RowCount;
    }

    // Collect the byte offsets of the double quotes that delimit identifiers in the message.
    let quote_indices = |s: &str| {
        let mut quotes = Vec::new();
        for (index, char) in s.char_indices() {
            if char == '"' {
                quotes.push(index);
            }
        }
        quotes
    };

    match sql_code {
        Some(SqlState::UNDEFINED_TABLE) => {
            let quotes = quote_indices(&string);
            if quotes.len() < 2 {
                return ErrorKind::Unknown;
            }
            let table = &string[quotes[0] + 1..quotes[1]];
            ErrorKind::UndefinedTable { table: table.to_owned() }
        }
        Some(SqlState::UNDEFINED_COLUMN) => {
            let quotes = quote_indices(&string);
            if quotes.len() < 2 {
                return ErrorKind::Unknown;
            }
            let column = string[quotes[0] + 1..quotes[1]].to_owned();

            tracing::debug!("quotes: {:?}, column: {}", quotes, column);

            // The message only names a table when the column reference was qualified.
            let table = if quotes.len() >= 4 { Some(string[quotes[2] + 1..quotes[3]].to_owned()) } else { None };

            ErrorKind::UndefinedColumn { table, column }
        }
        _ => ErrorKind::Unknown,
    }
}

#[allow(unused)]
pub struct MissingMaterialized {
    pub name: String,
}