use std::borrow::Cow;
use either::Either;
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::FutureExt;
use serde::Deserialize;
use serde_json::Value;
use sqlx::database::HasStatement;
use sqlx::{Database, Describe, Execute, Executor, PgPool, Postgres, Transaction};
use tracing::trace;
use typed_builder::TypedBuilder;
use uuid::Uuid;
use crate::message::{DeserializeMessage, GenericMessage, Message, MetadataRef};
use crate::Result;
/// Builds a `SELECT` statement that reads every message column from the
/// message-db function call given as a string literal, e.g.
/// `message_db_fn!("message_store.get_stream_messages($1, $2, $3, $4)")`.
///
/// The `data` and `metadata` columns are cast to `jsonb` in the select list.
/// The result is a `&'static str` produced by `concat!`, so it can be passed
/// straight to `sqlx::query_as`.
macro_rules! message_db_fn {
    ($function:literal) => {
        concat!(
            "\n",
            "SELECT\n",
            "id,\n",
            "stream_name,\n",
            "\"type\",\n",
            "\"position\",\n",
            "global_position,\n",
            "data::jsonb,\n",
            "metadata::jsonb,\n",
            "time\n",
            "FROM ",
            $function
        )
    };
}
/// The Postgres transaction type handed to the callback of
/// [`MessageStore::transaction`].
pub type MessageStoreTransaction<'a> = Transaction<'a, Postgres>;

/// Client for a message-db database.
///
/// Queries are run against an owned Postgres connection pool. The store can
/// also be used directly as an `Executor`, which forwards to that pool.
#[derive(Debug, Clone)]
pub struct MessageStore {
    /// Connection pool used for queries and for opening transactions.
    pool: PgPool,
}
/// Optional settings for [`MessageStore::write_message`].
///
/// Build with `WriteMessageOpts::builder()`; every field defaults to `None`.
#[derive(TypedBuilder, Debug, Clone, Default, Eq, PartialEq)]
pub struct WriteMessageOpts<'a> {
    /// Message id to write. When `None`, `write_message` generates a random
    /// UUID v4 string instead.
    #[builder(setter(strip_option), default)]
    id: Option<&'a str>,
    /// Metadata serialized to JSON and bound as the `metadata` argument of
    /// `message_store.write_message`.
    #[builder(setter(strip_option), default)]
    metadata: Option<MetadataRef<'a>>,
    /// Bound as the `expected_version` argument (`$6`) of
    /// `message_store.write_message`; `None` is bound as SQL `NULL`.
    #[builder(setter(strip_option), default)]
    expected_version: Option<i64>,
}
/// Optional settings for [`MessageStore::get_stream_messages`].
///
/// Build with `GetStreamMessagesOpts::builder()`; every field defaults to
/// `None`, which is bound as SQL `NULL`.
#[derive(TypedBuilder, Debug, Clone, Default, Eq, PartialEq)]
pub struct GetStreamMessagesOpts<'a> {
    /// Bound as `$2` (`position`) of `message_store.get_stream_messages`.
    #[builder(setter(strip_option), default)]
    position: Option<i64>,
    /// Bound as `$3` (`batch_size`) of `message_store.get_stream_messages`.
    #[builder(setter(strip_option), default)]
    batch_size: Option<i64>,
    /// Bound as `$4` (`condition`) of `message_store.get_stream_messages`.
    #[builder(setter(strip_option), default)]
    condition: Option<&'a str>,
}
/// Optional settings for [`MessageStore::get_category_messages`].
///
/// Build with `GetCategoryMessagesOpts::builder()`; every field defaults to
/// `None`, which is bound as SQL `NULL`. Fields are crate-visible so other
/// modules can read them.
#[derive(TypedBuilder, Debug, Clone, Default, Eq, PartialEq)]
pub struct GetCategoryMessagesOpts<'a> {
    /// Bound as `$2` (`position`) of `message_store.get_category_messages`.
    #[builder(setter(strip_option), default)]
    pub(crate) position: Option<i64>,
    /// Bound as `$3` (`batch_size`).
    #[builder(setter(strip_option), default)]
    pub(crate) batch_size: Option<i64>,
    /// Bound as `$4` (`correlation`).
    #[builder(setter(strip_option), default)]
    pub(crate) correlation: Option<&'a str>,
    /// Bound as `$5` (`consumer_group_member`).
    #[builder(setter(strip_option), default)]
    pub(crate) consumer_group_member: Option<i64>,
    /// Bound as `$6` (`consumer_group_size`).
    #[builder(setter(strip_option), default)]
    pub(crate) consumer_group_size: Option<i64>,
    /// Bound as `$7` (`condition`).
    #[builder(setter(strip_option), default)]
    pub(crate) condition: Option<&'a str>,
}
impl MessageStore {
    /// Opens a Postgres connection pool for `url` and wraps it in a store.
    ///
    /// # Errors
    ///
    /// Returns an error if `PgPool::connect` fails.
    pub async fn connect(url: &str) -> Result<Self> {
        Ok(MessageStore {
            pool: PgPool::connect(url).await?,
        })
    }

    /// Runs `callback` inside a new transaction and returns its result.
    ///
    /// The transaction is committed only when `callback` resolves to `Ok`.
    /// When it resolves to `Err`, the error is returned and the transaction
    /// is dropped without being committed. sqlx documents dropping an
    /// in-progress transaction as rolling it back.
    ///
    /// # Errors
    ///
    /// Returns an error if the transaction cannot be started or committed,
    /// or whatever error `callback` produced.
    pub fn transaction<'a, F, R>(&'a self, callback: F) -> BoxFuture<'a, Result<R>>
    where
        for<'c> F:
            'a + FnOnce(&'c mut MessageStoreTransaction<'a>) -> BoxFuture<'c, Result<R>> + Send,
        R: Send,
    {
        async move {
            let mut tx = self.pool.begin().await?;
            let result = callback(&mut tx).await?;
            tx.commit().await?;
            Ok(result)
        }
        .boxed()
    }

    /// Writes one message by calling
    /// `message_store.write_message(id, stream_name, type, data, metadata, expected_version)`.
    ///
    /// When `opts.id` is `None`, a random UUID v4 string is used as the
    /// message id. Unset `metadata` / `expected_version` are bound as SQL
    /// `NULL`. Emits a `trace` event with the id, stream, type and position.
    ///
    /// Returns the position value produced by `message_store.write_message`.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails, including any error raised by
    /// the database function itself.
    ///
    /// # Panics
    ///
    /// Panics if `opts.metadata` cannot be serialized into a JSON value.
    pub async fn write_message<'e, 'c: 'e, E>(
        executor: E,
        stream_name: &str,
        msg_type: &str,
        data: &Value,
        opts: &WriteMessageOpts<'_>,
    ) -> Result<i64>
    where
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let id = opts
            .id
            .map(Cow::Borrowed)
            .unwrap_or_else(|| Cow::Owned(Uuid::new_v4().to_string()));
        // NOTE(review): assumes `MetadataRef`'s `Serialize` impl cannot fail
        // (e.g. it has no maps with non-string keys). A failure here is a bug
        // in that type, not a runtime condition, so panic with a clear message
        // instead of a bare `unwrap`.
        let metadata = opts
            .metadata
            .as_ref()
            .map(serde_json::to_value)
            .transpose()
            .expect("message metadata must serialize to a JSON value");
        let position =
            sqlx::query_scalar("SELECT message_store.write_message($1, $2, $3, $4, $5, $6)")
                .bind(&id)
                .bind(stream_name)
                .bind(msg_type)
                .bind(data)
                .bind(metadata)
                .bind(opts.expected_version)
                .fetch_one(executor)
                .await?;
        trace!(%id, %stream_name, %msg_type, %position, "wrote message");
        Ok(position)
    }

    /// Writes `messages` to `stream_name`, in order, inside a single transaction.
    ///
    /// Each tuple is `(message type, data, options)` and is written with
    /// [`MessageStore::write_message`]. If any write fails, the error is
    /// returned and the transaction is not committed.
    ///
    /// Returns the position of the last message written. Returns `-1` when
    /// `messages` is empty; in that case nothing is written and no
    /// transaction is opened.
    ///
    /// # Errors
    ///
    /// Returns an error if the transaction cannot be started or committed,
    /// or if any individual write fails.
    pub async fn write_messages(
        &self,
        stream_name: &str,
        messages: &[(&str, &Value, &WriteMessageOpts<'_>)],
    ) -> Result<i64> {
        // Nothing to write: avoid acquiring a connection for an empty
        // begin/commit round trip.
        if messages.is_empty() {
            return Ok(-1);
        }
        self.transaction(|tx| {
            async move {
                let mut version = -1;
                for (msg_type, data, opts) in messages {
                    version =
                        MessageStore::write_message(&mut *tx, stream_name, msg_type, data, opts)
                            .await?;
                }
                Ok(version)
            }
            .boxed()
        })
        .await
    }

    /// Reads messages from `stream_name` via `message_store.get_stream_messages`.
    ///
    /// `opts` fields are bound, in order, as position, batch size and
    /// condition; unset fields are bound as SQL `NULL`. Rows are converted
    /// into `Message<T>` with `deserialize_messages`.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails or the rows cannot be converted.
    pub async fn get_stream_messages<'e, 'c: 'e, T, E>(
        executor: E,
        stream_name: &str,
        opts: &GetStreamMessagesOpts<'_>,
    ) -> Result<Vec<Message<T>>>
    where
        T: for<'de> Deserialize<'de>,
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let messages: Vec<GenericMessage> = sqlx::query_as(message_db_fn!(
            "message_store.get_stream_messages($1, $2, $3, $4)"
        ))
        .bind(stream_name)
        .bind(opts.position)
        .bind(opts.batch_size)
        .bind(opts.condition)
        .fetch_all(executor)
        .await?;
        messages.deserialize_messages()
    }

    /// Reads messages from the category `category_name` via
    /// `message_store.get_category_messages`.
    ///
    /// `opts` fields are bound, in order, as position, batch size,
    /// correlation, consumer group member, consumer group size and condition.
    /// Unset fields are bound as SQL `NULL`. Rows are converted into
    /// `Message<T>` with `deserialize_messages`.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails or the rows cannot be converted.
    pub async fn get_category_messages<'e, 'c: 'e, T, E>(
        executor: E,
        category_name: &str,
        opts: &GetCategoryMessagesOpts<'_>,
    ) -> Result<Vec<Message<T>>>
    where
        T: for<'de> Deserialize<'de>,
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let messages: Vec<GenericMessage> = sqlx::query_as(message_db_fn!(
            "message_store.get_category_messages($1, $2, $3, $4, $5, $6, $7)"
        ))
        .bind(category_name)
        .bind(opts.position)
        .bind(opts.batch_size)
        .bind(opts.correlation)
        .bind(opts.consumer_group_member)
        .bind(opts.consumer_group_size)
        .bind(opts.condition)
        .fetch_all(executor)
        .await?;
        messages.deserialize_messages()
    }

    /// Reads the last message of `stream_name` via
    /// `message_store.get_last_stream_message`.
    ///
    /// `msg_type` is passed as the function's second (`type`) argument;
    /// `None` is bound as SQL `NULL`. Returns `None` when the query yields no
    /// row.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails or the row cannot be converted.
    pub async fn get_last_stream_message<'e, 'c: 'e, T, E>(
        executor: E,
        stream_name: &str,
        msg_type: Option<&str>,
    ) -> Result<Option<Message<T>>>
    where
        T: for<'de> Deserialize<'de>,
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let message: Option<GenericMessage> = sqlx::query_as(message_db_fn!(
            "message_store.get_last_stream_message($1, $2)"
        ))
        .bind(stream_name)
        .bind(msg_type)
        .fetch_optional(executor)
        .await?;
        message.deserialize_messages()
    }

    /// Returns the result of `message_store.stream_version(stream_name)`.
    ///
    /// `None` when the function returns SQL `NULL`. message-db documents this
    /// as the case for a stream with no messages.
    pub async fn stream_version<'e, 'c: 'e, E>(
        executor: E,
        stream_name: &str,
    ) -> Result<Option<i64>>
    where
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let version = sqlx::query_scalar("SELECT * FROM message_store.stream_version($1)")
            .bind(stream_name)
            .fetch_one(executor)
            .await?;
        Ok(version)
    }

    /// Returns the result of `message_store.id(stream_name)`, or `None` when
    /// the function returns SQL `NULL`.
    pub async fn id<'e, 'c: 'e, E>(executor: E, stream_name: &str) -> Result<Option<String>>
    where
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let id = sqlx::query_scalar("SELECT * FROM message_store.id($1)")
            .bind(stream_name)
            .fetch_one(executor)
            .await?;
        Ok(id)
    }

    /// Returns the result of `message_store.cardinal_id(stream_name)`, or
    /// `None` when the function returns SQL `NULL`.
    pub async fn cardinal_id<'e, 'c: 'e, E>(
        executor: E,
        stream_name: &str,
    ) -> Result<Option<String>>
    where
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let id = sqlx::query_scalar("SELECT * FROM message_store.cardinal_id($1)")
            .bind(stream_name)
            .fetch_one(executor)
            .await?;
        Ok(id)
    }

    /// Returns the result of `message_store.category(stream_name)`.
    pub async fn category<'e, 'c: 'e, E>(executor: E, stream_name: &str) -> Result<String>
    where
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let category = sqlx::query_scalar("SELECT * FROM message_store.category($1)")
            .bind(stream_name)
            .fetch_one(executor)
            .await?;
        Ok(category)
    }

    /// Returns the result of `message_store.is_category(stream_name)`.
    pub async fn is_category<'e, 'c: 'e, E>(executor: E, stream_name: &str) -> Result<bool>
    where
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let is_category = sqlx::query_scalar("SELECT * FROM message_store.is_category($1)")
            .bind(stream_name)
            .fetch_one(executor)
            .await?;
        Ok(is_category)
    }

    /// Calls `message_store.acquire_lock(stream_name)` and returns its result.
    ///
    /// NOTE(review): message-db takes a transaction-scoped advisory lock
    /// here, so this is presumably only useful with a transaction executor.
    /// Confirm this against the message-db docs.
    pub async fn acquire_lock<'e, 'c: 'e, E>(executor: E, stream_name: &str) -> Result<i64>
    where
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let lock = sqlx::query_scalar("SELECT * FROM message_store.acquire_lock($1)")
            .bind(stream_name)
            .fetch_one(executor)
            .await?;
        Ok(lock)
    }

    /// Returns the result of `message_store.hash_64(value)`.
    pub async fn hash_64<'e, 'c: 'e, E>(executor: E, value: &str) -> Result<i64>
    where
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let hash = sqlx::query_scalar("SELECT * FROM message_store.hash_64($1)")
            .bind(value)
            .fetch_one(executor)
            .await?;
        Ok(hash)
    }

    /// Returns the result of `message_store.message_store_version()`.
    pub async fn message_store_version<'e, 'c: 'e, E>(executor: E) -> Result<String>
    where
        E: 'e + Executor<'c, Database = Postgres>,
    {
        let version = sqlx::query_scalar("SELECT * FROM message_store.message_store_version()")
            .fetch_one(executor)
            .await?;
        Ok(version)
    }
}
impl<'c> Executor<'c> for &MessageStore {
type Database = Postgres;
fn fetch_many<'e, 'q: 'e, E: 'q>(
self,
query: E,
) -> BoxStream<
'e,
Result<
Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
sqlx::Error,
>,
>
where
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.pool.fetch_many(query)
}
fn fetch_optional<'e, 'q: 'e, E: 'q>(
self,
query: E,
) -> BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, sqlx::Error>>
where
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.pool.fetch_optional(query)
}
fn prepare_with<'e, 'q: 'e>(
self,
sql: &'q str,
parameters: &'e [<Self::Database as Database>::TypeInfo],
) -> BoxFuture<'e, Result<<Self::Database as HasStatement<'q>>::Statement, sqlx::Error>>
where
'c: 'e,
{
self.pool.prepare_with(sql, parameters)
}
fn describe<'e, 'q: 'e>(
self,
sql: &'q str,
) -> BoxFuture<'e, Result<Describe<Self::Database>, sqlx::Error>>
where
'c: 'e,
{
self.pool.describe(sql)
}
}