use std::{any::TypeId, sync::Arc};
use serde::{Serialize, de::DeserializeOwned};
use crate::registry::StateRegistryError;
/// Errors produced by [`Repository`] implementations and the helpers in this module.
#[derive(thiserror::Error, Debug)]
pub enum RepositoryError {
    /// A failure described only by a message.
    #[error("Internal error: {0}")]
    Internal(String),

    /// A JSON (de)serialization failure. `#[from]` lets `?` convert a
    /// `serde_json::Error` into this variant.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),

    /// A failure reported by the SDK-managed database layer.
    #[error(transparent)]
    Database(#[from] crate::sdk_managed::DatabaseError),

    /// A failure reported by the state registry, such as
    /// `StateRegistryError::DatabaseNotInitialized` when a repository is required but absent.
    #[error(transparent)]
    StateRegistry(#[from] StateRegistryError),
}
/// Accessor for an optional repository handle that turns "missing" into an error.
pub trait RepositoryOption<V>
where
    V: RepositoryItem,
{
    /// Borrows the repository handle if one is present.
    ///
    /// # Errors
    ///
    /// Returns a [`RepositoryError`] when no repository is available.
    fn require(&self) -> Result<&Arc<dyn Repository<V>>, RepositoryError>;
}
impl<V: RepositoryItem> RepositoryOption<V> for Option<Arc<dyn Repository<V>>> {
    /// Borrows the contained repository handle.
    ///
    /// # Errors
    ///
    /// Returns [`RepositoryError::StateRegistry`] wrapping
    /// [`StateRegistryError::DatabaseNotInitialized`] when the option is `None`.
    fn require(&self) -> Result<&Arc<dyn Repository<V>>, RepositoryError> {
        // `ok_or_else` builds and converts the error only on the `None` path;
        // `ok_or` would do that work on every call, including the successful ones.
        self.as_ref()
            .ok_or_else(|| StateRegistryError::DatabaseNotInitialized.into())
    }
}
/// Asynchronous storage for items of type `V`, addressed by `V::Key`.
///
/// Handles are shared as `Arc<dyn Repository<V>>`, which is why the trait requires
/// `Send + Sync`.
#[async_trait::async_trait]
pub trait Repository<V: RepositoryItem>: Send + Sync {
    /// Returns the item stored under `key`, if any.
    async fn get(&self, key: V::Key) -> Result<Option<V>, RepositoryError>;

    /// Returns all stored items.
    async fn list(&self) -> Result<Vec<V>, RepositoryError>;

    /// Stores `value` under `key`.
    async fn set(&self, key: V::Key, value: V) -> Result<(), RepositoryError>;

    /// Stores every `(key, value)` pair in `values`.
    async fn set_bulk(&self, values: Vec<(V::Key, V)>) -> Result<(), RepositoryError>;

    /// Removes the item stored under `key`.
    async fn remove(&self, key: V::Key) -> Result<(), RepositoryError>;

    /// Removes the items stored under each of `keys`.
    async fn remove_bulk(&self, keys: Vec<V::Key>) -> Result<(), RepositoryError>;

    /// Removes all stored items.
    async fn remove_all(&self) -> Result<(), RepositoryError>;

    /// Replaces the stored items with `values`.
    ///
    /// The default implementation awaits `remove_all` and then `set_bulk` as two separate
    /// calls, so it is not atomic. If `set_bulk` fails, `remove_all` has already run.
    /// Implementations that can do both in one transaction may want to override this.
    async fn replace_all(&self, values: Vec<(V::Key, V)>) -> Result<(), RepositoryError> {
        // Clearing must happen before the bulk write; a failure here stops before writing.
        self.remove_all().await?;
        self.set_bulk(values).await?;
        Ok(())
    }
}
/// A type that can be stored in a [`Repository`].
///
/// The `Internal` supertrait is implemented by `register_repository_item!`, which is the
/// intended way to implement this trait.
pub trait RepositoryItem: Internal + Serialize + DeserializeOwned + Send + Sync + 'static {
    /// Key type used to address items. It must be convertible with `ToString`.
    type Key: ToString + Send + Sync + 'static;

    /// Name of the repository. `register_repository_item!` checks it with
    /// `validate_registry_name` at compile time.
    const NAME: &'static str;

    /// Returns the `TypeId` of the implementing type.
    ///
    /// NOTE: where `std::any::Any` is also in scope, a path call like `Item::type_id()` may be
    /// reported as ambiguous with `Any::type_id`. Use `<Item as RepositoryItem>::type_id()` then.
    fn type_id() -> TypeId {
        TypeId::of::<Self>()
    }

    /// Returns this type's `TypeId` and `NAME` bundled as a [`RepositoryItemData`].
    fn data() -> RepositoryItemData {
        RepositoryItemData::new::<Self>()
    }
}
/// Type-erased description of a [`RepositoryItem`]: its `TypeId` and its `NAME`.
// NOTE(review): both fields are read by the accessors below, so the reason for this
// `allow(dead_code)` is not visible here. Confirm whether it is still needed.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
pub struct RepositoryItemData {
    type_id: TypeId,
    name: &'static str,
}

impl RepositoryItemData {
    /// Captures the `TypeId` and `NAME` of the item type `T`.
    pub fn new<T: RepositoryItem>() -> Self {
        let name = T::NAME;
        let type_id = TypeId::of::<T>();
        Self { type_id, name }
    }

    /// The `TypeId` of the described item type.
    pub fn type_id(&self) -> TypeId {
        self.type_id
    }

    /// The described item type's `RepositoryItem::NAME`.
    pub fn name(&self) -> &'static str {
        self.name
    }
}
/// Returns `true` when `name` is an acceptable repository name.
///
/// An acceptable name is non-empty and consists only of ASCII letters (`a`–`z`, `A`–`Z`) and
/// underscores. Digits, spaces, punctuation and non-ASCII characters are rejected.
///
/// This is a `const fn` so that `register_repository_item!` can check names at compile time.
pub const fn validate_registry_name(name: &str) -> bool {
    let bytes = name.as_bytes();
    // The per-byte check below passes vacuously for "", but an empty name cannot identify
    // a repository, so reject it explicitly.
    if bytes.is_empty() {
        return false;
    }
    // Iterators are not usable in a `const fn`, hence the manual index loop.
    let mut i = 0;
    while i < bytes.len() {
        if !matches!(bytes[i], b'a'..=b'z' | b'A'..=b'Z' | b'_') {
            return false;
        }
        i += 1;
    }
    true
}
/// An ordered list of steps that add or remove repository item types.
#[derive(Debug, Clone)]
pub struct RepositoryMigrations {
    /// Steps, applied in list order.
    pub(crate) steps: Vec<RepositoryMigrationStep>,
    /// Set by `RepositoryMigrations::new` to the number of steps.
    // NOTE(review): not read anywhere in this file, hence the `allow(dead_code)`.
    // Presumably reserved for schema versioning; confirm.
    #[allow(dead_code)]
    pub(crate) version: u32,
}
/// One change to the set of registered repository item types.
#[derive(Debug, Clone, Copy)]
pub enum RepositoryMigrationStep {
    /// Registers the described item type.
    Add(RepositoryItemData),
    /// Unregisters the described item type. Matching is by `TypeId`.
    Remove(RepositoryItemData),
}
impl RepositoryMigrations {
    /// Creates a migration set whose `version` is the number of steps.
    pub fn new(steps: Vec<RepositoryMigrationStep>) -> Self {
        // `as` would truncate beyond `u32::MAX` steps; a migration list that long is not
        // a realistic input.
        let version = steps.len() as u32;
        Self { version, steps }
    }

    /// Replays the steps in order and returns the item types that remain registered.
    ///
    /// - `Add` registers a type. If that `TypeId` is already registered, its data is
    ///   replaced in place.
    /// - `Remove` unregisters a type by `TypeId`. It is a no-op if the type is not registered.
    ///
    /// The result is ordered by the position at which each surviving type was added. A type
    /// that was removed and added again appears at the end. The order is therefore the same
    /// on every run, unlike iterating a `HashMap`, whose order is randomized per process.
    pub fn into_repository_items(self) -> Vec<RepositoryItemData> {
        // A linear scan per step is fine here: migration lists hold only a handful of types.
        let mut items: Vec<RepositoryItemData> = Vec::new();
        for step in self.steps {
            match step {
                RepositoryMigrationStep::Add(data) => {
                    match items.iter_mut().find(|item| item.type_id == data.type_id) {
                        Some(existing) => *existing = data,
                        None => items.push(data),
                    }
                }
                RepositoryMigrationStep::Remove(data) => {
                    items.retain(|item| item.type_id != data.type_id);
                }
            }
        }
        items
    }
}
/// Implements `RepositoryItem` for a type.
///
/// Invocation shape: `register_repository_item!(KeyType => ItemType, "Name");`
///
/// The expansion is wrapped in an anonymous `const _: () = { ... };` item and:
/// - implements the hidden `___internal::Internal` marker trait for `ItemType`, which
///   `RepositoryItem` requires as a supertrait;
/// - implements `RepositoryItem` for `ItemType` with `NAME = "Name"` and `Key = KeyType`;
/// - asserts `validate_registry_name("Name")` inside that constant. An invalid name is
///   therefore rejected with a compile-time error carrying the message below.
///
/// `$name` must be a string literal, because it is used as a `&'static str`. `ItemType` must
/// also satisfy the remaining `RepositoryItem` bounds (`Serialize`, `DeserializeOwned`,
/// `Send`, `Sync`, `'static`) on its own.
#[macro_export]
macro_rules! register_repository_item {
($keyty:ty => $ty:ty, $name:literal) => {
// The unnamed constant scopes the impls and evaluates the assertion at compile time
// without introducing any new names at the call site.
const _: () = {
impl $crate::repository::___internal::Internal for $ty {}
impl $crate::repository::RepositoryItem for $ty {
const NAME: &'static str = $name;
type Key = $keyty;
}
// A panicking `assert!` during constant evaluation becomes a compile error.
assert!(
$crate::repository::validate_registry_name($name),
concat!(
"Repository name '",
$name,
"' must contain only alphabetic characters and underscores"
)
)
};
};
}
/// Support items for `register_repository_item!`. Hidden from rustdoc; not meant for direct use.
#[doc(hidden)]
pub mod ___internal {
/// Marker supertrait of `RepositoryItem`, implemented by `register_repository_item!`.
///
/// It must be publicly reachable because the exported macro names it as
/// `$crate::repository::___internal::Internal` from the invoking crate. `#[doc(hidden)]` only
/// hides it from documentation; it does not stop other code from implementing it directly.
pub trait Internal {}
}
// Crate-local shorthand so `RepositoryItem`'s supertrait list can say just `Internal`.
pub(crate) use ___internal::Internal;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_name() {
        // (candidate name, whether it should be accepted)
        let cases = [
            ("valid", true),
            ("Valid_Name", true),
            ("Invalid-Name", false),
            ("Invalid Name", false),
            ("Invalid.Name", false),
            ("Invalid123", false),
        ];
        for (candidate, expected) in cases {
            // Pair each verdict with its input so a failing comparison shows which case broke.
            assert_eq!(
                (candidate, validate_registry_name(candidate)),
                (candidate, expected)
            );
        }
    }
}