use std::{
    collections::{BTreeMap, BTreeSet},
    sync::{Arc, RwLock},
};

use ankurah_core::{
    error::RetrievalError,
    property::Backends,
    storage::{StorageCollection, StorageEngine},
};
use ankurah_proto::State;

use futures_util::TryStreamExt;

pub mod predicate;
pub mod value;

use value::PGValue;

use ankurah_proto::{Clock, CollectionId, Event, ID};
use async_trait::async_trait;
use bb8_postgres::{tokio_postgres::NoTls, PostgresConnectionManager};
use tokio_postgres::{error::SqlState, types::ToSql};
use tracing::{info, warn};

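/// PostgreSQL-backed [`StorageEngine`] implementation built on a bb8 connection pool.
///
/// A minimal construction sketch (marked `ignore` because the connection string and
/// the surrounding async context are illustrative assumptions, not part of this file):
///
/// ```ignore
/// use bb8_postgres::{tokio_postgres::NoTls, PostgresConnectionManager};
///
/// // Placeholder connection parameters; point them at a real server.
/// let manager = PostgresConnectionManager::new_from_stringlike("host=localhost user=postgres dbname=ankurah", NoTls)?;
/// let pool = bb8::Pool::builder().build(manager).await?;
/// let engine = Postgres::new(pool)?;
/// ```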
pub struct Postgres {
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
}

impl Postgres {
    pub fn new(pool: bb8::Pool<PostgresConnectionManager<NoTls>>) -> anyhow::Result<Self> { Ok(Self { pool }) }

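    /// Returns `true` if `collection` contains only characters that are safe to splice
    /// into SQL identifiers (alphanumerics plus `_`, `.` and `:`), since table and
    /// column names cannot be bound as query parameters.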
    pub fn sane_name(collection: &str) -> bool {
        collection.chars().all(|c| c.is_alphanumeric() || matches!(c, '_' | '.' | ':'))
    }
}

#[async_trait]
impl StorageEngine for Postgres {
    type Value = PGValue;

    async fn collection(&self, collection_id: &CollectionId) -> Result<std::sync::Arc<dyn StorageCollection>, RetrievalError> {
        if !Postgres::sane_name(collection_id.as_str()) {
            return Err(RetrievalError::InvalidBucketName);
        }

        let mut client = self.pool.get().await.map_err(|err| RetrievalError::storage(err))?;

        // `current_database()` is what information_schema reports as `table_catalog`,
        // which is how rebuild_columns_cache scopes its lookups.
        let schema = client.query_one("SELECT current_database()", &[]).await.map_err(|err| RetrievalError::storage(err))?;
        let schema: String = schema.get("current_database");

        let bucket = PostgresBucket {
            pool: self.pool.clone(),
            schema,
            collection_id: collection_id.clone(),
            columns: Arc::new(RwLock::new(Vec::new())),
        };

        bucket.create_event_table(&mut client).await?;
        bucket.create_state_table(&mut client).await?;
        bucket.rebuild_columns_cache(&mut client).await?;

        Ok(Arc::new(bucket))
    }
}

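/// Column metadata cached from `information_schema.columns` for a collection's state table.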
#[derive(Clone, Debug)]
pub struct PostgresColumn {
    pub name: String,
    pub is_nullable: bool,
    pub data_type: String,
}

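/// Per-collection storage backed by two tables: `"{collection}"` for materialized
/// entity state and `"{collection}_event"` for the event log.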
pub struct PostgresBucket {
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
    collection_id: CollectionId,
    schema: String,
    columns: Arc<RwLock<Vec<PostgresColumn>>>,
}

impl PostgresBucket {
    /// Table holding one row of materialized state per entity.
    fn state_table(&self) -> String { self.collection_id.as_str().to_owned() }

    /// Table holding the event log for this collection.
    pub fn event_table(&self) -> String { format!("{}_event", self.collection_id.as_str()) }

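    /// Reloads the cached column list for this collection's state table from
    /// `information_schema.columns`.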
    pub async fn rebuild_columns_cache(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
        let column_query =
            r#"SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_catalog = $1 AND table_name = $2;"#;
        let mut new_columns = Vec::new();
        info!("Querying existing columns: {:?}, [{:?}, {:?}]", column_query, &self.schema, &self.collection_id.as_str());
        let rows = client.query(column_query, &[&self.schema, &self.collection_id.as_str()]).await?;
        for row in rows {
            let is_nullable: String = row.get("is_nullable");
            new_columns.push(PostgresColumn {
                name: row.get("column_name"),
                is_nullable: is_nullable.eq("YES"),
                data_type: row.get("data_type"),
            })
        }

        let mut columns = self.columns.write().unwrap();
        *columns = new_columns;
        drop(columns);

        Ok(())
    }

    pub fn existing_columns(&self) -> Vec<String> {
        let columns = self.columns.read().unwrap();
        columns.iter().map(|column| column.name.clone()).collect()
    }

    pub fn column(&self, column_name: &str) -> Option<PostgresColumn> {
        let columns = self.columns.read().unwrap();
        columns.iter().find(|column| column.name == column_name).cloned()
    }

    pub fn has_column(&self, column_name: &str) -> bool { self.column(column_name).is_some() }

    pub async fn create_event_table(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
        let create_query = format!(
            r#"CREATE TABLE IF NOT EXISTS "{}"("id" UUID UNIQUE, "entity_id" UUID, "operations" bytea, "parent" UUID[])"#,
            self.event_table()
        );

        info!("Applying DDL: {}", create_query);
        client.execute(&create_query, &[]).await?;
        Ok(())
    }

    pub async fn create_state_table(&self, client: &mut tokio_postgres::Client) -> anyhow::Result<()> {
        let create_query =
            format!(r#"CREATE TABLE IF NOT EXISTS "{}"("id" UUID UNIQUE, "state_buffer" BYTEA, "head" UUID[])"#, self.state_table());

        info!("Applying DDL: {}", create_query);
        match client.execute(&create_query, &[]).await {
            Ok(_) => Ok(()),
            Err(err) => {
                warn!("Error applying DDL: {}", err);
                Err(err.into())
            }
        }
    }

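    /// Adds the given materialized columns to the state table (silently skipping any
    /// with unsafe names) and then refreshes the column cache.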
    pub async fn add_missing_columns(
        &self,
        client: &mut tokio_postgres::Client,
        missing: Vec<(String, &'static str)>,
    ) -> anyhow::Result<()> {
        for (column, datatype) in missing {
            if Postgres::sane_name(&column) {
                let alter_query = format!(r#"ALTER TABLE "{}" ADD COLUMN "{}" {}"#, self.state_table(), column, datatype);
                info!("Running: {}", alter_query);
                match client.execute(&alter_query, &[]).await {
                    Ok(_) => {}
                    Err(err) => {
                        warn!("Error adding column: {} to table: {} - rebuilding columns cache", err, self.state_table());
                        self.rebuild_columns_cache(client).await?;
                        return Err(err.into());
                    }
                }
            }
        }

        self.rebuild_columns_cache(client).await?;
        Ok(())
    }
}

#[async_trait]
impl StorageCollection for PostgresBucket {
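    // Upsert strategy: serialize the state buffers, lazily add any missing
    // materialized columns, and use a CTE to return the previous head so the caller
    // can tell whether the stored clock actually changed.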
    async fn set_state(&self, id: ID, state: &State) -> anyhow::Result<bool> {
        let state_buffers = bincode::serialize(&state.state_buffers)?;
        let ulid: ulid::Ulid = id.into();
        let uuid: uuid::Uuid = ulid.into();

        let head_uuids: Vec<uuid::Uuid> = (&state.head).into();

        if head_uuids.is_empty() {
            warn!("Empty head detected for entity {}", id);
        }

        let mut client = self.pool.get().await?;

        let backends = Backends::from_state_buffers(state)?;
        let mut columns: Vec<String> = vec!["id".to_owned(), "state_buffer".to_owned(), "head".to_owned()];
        let mut params: Vec<&(dyn ToSql + Sync)> = Vec::new();
        params.push(&uuid);
        params.push(&state_buffers);
        params.push(&head_uuids);

        // Collect the materialized property values, creating columns for any we have
        // not seen before.
        let mut materialized: Vec<(String, Option<PGValue>)> = Vec::new();
        for (column, value) in backends.property_values() {
            let pg_value: Option<PGValue> = value.map(|value| value.into());
            if !self.has_column(&column) {
                if let Some(ref pg_value) = pg_value {
                    self.add_missing_columns(&mut client, vec![(column.clone(), pg_value.postgres_type())]).await?;
                } else {
                    continue;
                }
            }

            materialized.push((column.clone(), pg_value));
        }

        for (name, parameter) in &materialized {
            columns.push(name.clone());

            match parameter {
                Some(value) => match value {
                    PGValue::CharacterVarying(string) => params.push(string),
                    PGValue::SmallInt(number) => params.push(number),
                    PGValue::Integer(number) => params.push(number),
                    PGValue::BigInt(number) => params.push(number),
                    PGValue::Bytea(bytes) => params.push(bytes),
                },
                None => params.push(&None::<i32>),
            }
        }

        let columns_str = columns.iter().map(|name| format!("\"{}\"", name)).collect::<Vec<String>>().join(", ");
        let values_str = params.iter().enumerate().map(|(index, _)| format!("${}", index + 1)).collect::<Vec<String>>().join(", ");
        let columns_update_str = columns
            .iter()
            .enumerate()
            .skip(1)
            .map(|(index, name)| format!("\"{}\" = ${}", name, index + 1))
            .collect::<Vec<String>>()
            .join(", ");

        let query = format!(
            r#"WITH old_state AS (
                SELECT "head" FROM "{0}" WHERE "id" = $1
            )
            INSERT INTO "{0}"({1}) VALUES({2})
            ON CONFLICT("id") DO UPDATE SET {3}
            RETURNING (SELECT "head" FROM old_state) as old_head"#,
            self.state_table(),
            columns_str,
            values_str,
            columns_update_str
        );

        info!("Querying: {}", query);
        let row = match client.query_one(&query, params.as_slice()).await {
            Ok(row) => row,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.state_table() {
                            self.create_state_table(&mut *client).await?;
                            return self.set_state(id, state).await;
                        }
                    }
                    _ => {}
                }

                return Err(err.into());
            }
        };

        let old_head: Option<Vec<uuid::Uuid>> = row.get("old_head");
        let changed = match old_head {
            None => true,
            Some(old_head) => {
                let old_clock: Clock = old_head.into();
                old_clock != state.head
            }
        };

        info!("Changed: {}", changed);
        Ok(changed)
    }

    async fn get_state(&self, id: ID) -> Result<State, RetrievalError> {
        let ulid: ulid::Ulid = id.into();
        let uuid: uuid::Uuid = ulid.into();

        let query = format!(r#"SELECT "id", "state_buffer", "head" FROM "{}" WHERE "id" = $1"#, self.state_table());

        let mut client = match self.pool.get().await {
            Ok(client) => client,
            Err(err) => {
                return Err(RetrievalError::StorageError(err.into()));
            }
        };

        info!("Getting state: {}", query);
        let row = match client.query_one(&query, &[&uuid]).await {
            Ok(row) => row,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::RowCount => {
                        return Err(RetrievalError::NotFound(id));
                    }
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.state_table() {
                            self.create_state_table(&mut client).await.map_err(|e| RetrievalError::StorageError(e.into()))?;
                            return Err(RetrievalError::NotFound(id));
                        }
                    }
                    _ => {}
                }

                return Err(RetrievalError::StorageError(err.into()));
            }
        };

        info!("Row: {:?}", row);
        let row_id: uuid::Uuid = row.get("id");
        assert_eq!(row_id, uuid);

        let serialized_buffers: Vec<u8> = row.get("state_buffer");
        let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&serialized_buffers)?;

        Ok(State { state_buffers, head: row.get::<_, Vec<uuid::Uuid>>("head").into() })
    }

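    // Translates the ankql predicate into SQL. If the predicate references a column
    // that has not been materialized yet, the query is retried with that column
    // assumed to be NULL.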
    async fn fetch_states(&self, predicate: &ankql::ast::Predicate) -> Result<Vec<(ID, State)>, RetrievalError> {
        info!("Fetching states for: {:?}", predicate);
        let client = self.pool.get().await.map_err(|err| RetrievalError::StorageError(Box::new(err)))?;

        let mut results = Vec::new();

        let mut ankql_sql = predicate::Sql::new();
        ankql_sql.predicate(&predicate);

        let (sql, args) = ankql_sql.collapse();

        let filtered_query = if !sql.is_empty() {
            format!(r#"SELECT "id", "state_buffer", "head" FROM "{}" WHERE {}"#, self.state_table(), sql)
        } else {
            format!(r#"SELECT "id", "state_buffer", "head" FROM "{}""#, self.state_table())
        };

        info!("SQL: {} with args: {:?}", filtered_query, args);

        let rows = match client.query_raw(&filtered_query, args).await {
            Ok(stream) => match stream.try_collect::<Vec<_>>().await {
                Ok(rows) => rows,
                Err(err) => return Err(RetrievalError::StorageError(err.into())),
            },
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.state_table() {
                            return Ok(Vec::new());
                        }
                    }
                    ErrorKind::UndefinedColumn { table, column } => {
                        warn!("Undefined column: {} in table: {:?}", column, table);
                        match table {
                            Some(table) if table == self.state_table() => {
                                return self.fetch_states(&predicate.assume_null(&[column])).await;
                            }
                            None => {
                                return self.fetch_states(&predicate.assume_null(&[column])).await;
                            }
                            _ => {}
                        }
                    }
                    _ => {}
                }

                return Err(RetrievalError::StorageError(err.into()));
            }
        };

        for row in rows {
            let uuid: uuid::Uuid = row.get(0);
            let state_buffer: Vec<u8> = row.get(1);
            let id = ID::from_ulid(ulid::Ulid::from(uuid));

            let state_buffers: BTreeMap<String, Vec<u8>> = bincode::deserialize(&state_buffer)?;

            let entity_state = State { state_buffers, head: row.get::<_, Vec<uuid::Uuid>>(2).into() };

            results.push((id, entity_state));
        }

        Ok(results)
    }

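    // Appends a single event row, creating the event table lazily if it does not
    // exist yet.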
    async fn add_event(&self, entity_event: &Event) -> anyhow::Result<bool> {
        let event_id = uuid::Uuid::from(ulid::Ulid::from(entity_event.id));
        let entity_id = uuid::Uuid::from(ulid::Ulid::from(entity_event.entity_id));
        let operations = bincode::serialize(&entity_event.operations)?;
        let parent_uuids: Vec<uuid::Uuid> = (&entity_event.parent).into();

        let query = format!(r#"INSERT INTO "{0}"("id", "entity_id", "operations", "parent") VALUES($1, $2, $3, $4)"#, self.event_table());

        let mut client = self.pool.get().await?;
        info!("Running: {}", query);
        let affected = match client.execute(&query, &[&event_id, &entity_id, &operations, &parent_uuids]).await {
            Ok(affected) => affected,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.event_table() {
                            self.create_event_table(&mut *client).await?;
                            return self.add_event(entity_event).await;
                        }
                    }
                    _ => {}
                }

                return Err(err.into());
            }
        };

        Ok(affected > 0)
    }

    async fn get_events(&self, entity_id: ID) -> Result<Vec<Event>, ankurah_core::error::RetrievalError> {
        let query = format!(r#"SELECT "id", "operations", "parent" FROM "{0}" WHERE "entity_id" = $1"#, self.event_table());

        let entity_uuid = uuid::Uuid::from(ulid::Ulid::from(entity_id));

        let mut client = self.pool.get().await.map_err(|err| RetrievalError::storage(err))?;
        info!("Running: {}", query);
        let rows = match client.query(&query, &[&entity_uuid]).await {
            Ok(rows) => rows,
            Err(err) => {
                let kind = error_kind(&err);
                match kind {
                    ErrorKind::UndefinedTable { table } => {
                        if table == self.event_table() {
                            self.create_event_table(&mut *client).await?;
                            return Ok(Vec::new());
                        }
                    }
                    _ => {}
                }

                return Err(RetrievalError::storage(err));
            }
        };

        let mut events = Vec::new();
        for row in rows {
            let event_id: uuid::Uuid = row.get("id");
            let event_id = ID::from_ulid(event_id.into());
            let operations_binary: Vec<u8> = row.get("operations");
            let operations = bincode::deserialize(&operations_binary)?;
            let parent: Vec<uuid::Uuid> = row.get("parent");
            let parent = parent.into_iter().map(|uuid| ID::from_ulid(uuid.into())).collect::<BTreeSet<_>>();
            let clock = Clock::new(parent);

            events.push(Event {
                id: event_id,
                collection: self.collection_id.clone(),
                entity_id,
                operations,
                parent: clock,
            })
        }

        Ok(events)
    }
}

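/// Coarse classification of `tokio_postgres` errors that the storage layer knows how to react to.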
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ErrorKind {
    RowCount,
    UndefinedTable { table: String },
    UndefinedColumn { table: Option<String>, column: String },
    Unknown,
}

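/// Classifies a `tokio_postgres` error by its SQLSTATE, pulling the offending table
/// and column names out of the quoted identifiers in the error message.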
pub fn error_kind(err: &tokio_postgres::Error) -> ErrorKind {
    let string = err.to_string().trim().to_owned();
    let sql_code = err.code().cloned();

    if string == "query returned an unexpected number of rows" {
        return ErrorKind::RowCount;
    }

    // Positions of the double quotes in the error message; the offending identifiers
    // are quoted, e.g. `relation "foo" does not exist`.
    let quote_indices = |s: &str| {
        let mut quotes = Vec::new();
        for (index, c) in s.char_indices() {
            if c == '"' {
                quotes.push(index);
            }
        }
        quotes
    };

    match sql_code {
        Some(SqlState::UNDEFINED_TABLE) => {
            let quotes = quote_indices(&string);
            if quotes.len() < 2 {
                return ErrorKind::Unknown;
            }
            let table = &string[quotes[0] + 1..quotes[1]];
            ErrorKind::UndefinedTable { table: table.to_owned() }
        }
        Some(SqlState::UNDEFINED_COLUMN) => {
            let quotes = quote_indices(&string);
            if quotes.len() < 2 {
                return ErrorKind::Unknown;
            }
            let column = string[quotes[0] + 1..quotes[1]].to_owned();

            // A second pair of quotes, when present, names the table.
            let table = if quotes.len() >= 4 {
                Some(string[quotes[2] + 1..quotes[3]].to_owned())
            } else {
                None
            };

            ErrorKind::UndefinedColumn { table, column }
        }
        _ => ErrorKind::Unknown,
    }
}

#[allow(unused)]
pub struct MissingMaterialized {
    pub name: String,
}