#![deny(unsafe_op_in_unsafe_fn)]
use std::any::TypeId;
use std::collections::{btree_map, BTreeMap};
use std::fmt;
use std::marker::PhantomData;
use std::sync::Mutex;
use std::time::Duration;
use sys::Buffer;
use crate::sats::db::def::{ColumnDef, ConstraintDef, IndexDef, SequenceDef, TableDef};
use crate::timestamp::with_timestamp_set;
use crate::{sys, ReducerContext, ScheduleToken, SpacetimeType, TableType, Timestamp};
use spacetimedb_lib::de::{self, Deserialize, SeqProductAccess};
use spacetimedb_lib::sats::db::auth::{StAccess, StTableType};
use spacetimedb_lib::sats::typespace::TypespaceBuilder;
use spacetimedb_lib::sats::{impl_deserialize, impl_serialize, AlgebraicType, AlgebraicTypeRef, ProductTypeElement};
use spacetimedb_lib::ser::{Serialize, SerializeSeqProduct};
use spacetimedb_lib::{bsatn, Address, Identity, MiscModuleExport, ModuleDef, ReducerDef, TableDesc, TypeAlias};
use spacetimedb_primitives::*;
pub use once_cell::sync::{Lazy, OnceCell};
pub fn invoke_reducer<'a, A: Args<'a>, T>(
reducer: impl Reducer<'a, A, T>,
sender: Buffer,
client_address: Buffer,
timestamp: u64,
args: &'a [u8],
epilogue: impl FnOnce(Result<(), &str>),
) -> Buffer {
let ctx = assemble_context(sender, timestamp, client_address);
let SerDeArgs(args) = bsatn::from_slice(args).expect("unable to decode args");
let res = with_timestamp_set(ctx.timestamp, || {
let res: Result<(), Box<str>> = reducer.invoke(ctx, args);
epilogue(res.as_ref().map(|()| ()).map_err(|e| &**e));
res
});
cvt_result(res)
}
pub fn create_index(index_name: &str, table_id: TableId, index_type: sys::raw::IndexType, col_ids: Vec<u8>) -> Buffer {
let result = sys::create_index(index_name, table_id, index_type as u8, &col_ids);
cvt_result(result.map_err(cvt_errno))
}
/// Builds a [`ReducerContext`] from the raw values the host passes to a reducer call.
///
/// - `sender` is read as a 32-byte identity.
/// - `timestamp` is microseconds since the Unix epoch.
/// - `client_address` is read as a 16-byte address. It maps to `None` when it equals
///   `Address::__DUMMY`.
fn assemble_context(sender: Buffer, timestamp: u64, client_address: Buffer) -> ReducerContext {
    let caller = Identity::from_byte_array(sender.read_array::<32>());
    let called_at = Timestamp::UNIX_EPOCH + Duration::from_micros(timestamp);
    let raw_address = Address::from_arr(&client_address.read_array::<16>());
    // The dummy address is the host's way of saying "no client address".
    let maybe_address = (raw_address != Address::__DUMMY).then_some(raw_address);
    ReducerContext {
        sender: caller,
        timestamp: called_at,
        address: maybe_address,
    }
}
/// Renders a host error number as a boxed message using its `Display` form.
fn cvt_errno(errno: sys::Errno) -> Box<str> {
    errno.to_string().into()
}
/// Converts a reducer/host-call outcome into the buffer returned to the host.
///
/// Success maps to `Buffer::INVALID`. An error message is copied into a newly
/// allocated buffer.
fn cvt_result(res: Result<(), Box<str>>) -> Buffer {
    let Err(errmsg) = res else {
        return Buffer::INVALID;
    };
    Buffer::alloc(errmsg.as_bytes())
}
/// A callable reducer taking the argument tuple `A`.
///
/// `T` is a marker type. This file implements the trait for functions that take the
/// arguments either with a leading [`ReducerContext`] (`ContextArg`) or without one
/// (`NoContextArg`).
pub trait Reducer<'de, A, T>
where
    A: Args<'de>,
{
    /// Invokes the reducer and normalizes its return value into a `Result`.
    fn invoke(&self, ctx: ReducerContext, args: A) -> Result<(), Box<str>>;
}
/// Compile-time metadata about a reducer.
pub trait ReducerInfo {
/// Name of the reducer. Used as the `ReducerDef` name and passed to `sys::schedule`.
const NAME: &'static str;
/// Optional parameter names. The `Args::schema` implementations in this file use only the
/// trailing entries, one per argument-tuple element. Any extra leading entries are ignored
/// (presumably a context parameter; confirm at the definition site).
const ARG_NAMES: &'static [Option<&'static str>];
/// Type-erased entry point, pushed into the reducer table that `__call_reducer__` dispatches on.
const INVOKE: ReducerFn;
}
/// Extra metadata for repeating reducers.
pub trait RepeaterInfo: ReducerInfo {
/// Delay from "now" that `schedule_repeater` uses when computing the scheduled time.
const REPEAT_INTERVAL: Duration;
}
pub trait Args<'de>: Sized {
const LEN: usize;
fn visit_seq_product<A: SeqProductAccess<'de>>(prod: A) -> Result<Self, A::Error>;
fn serialize_seq_product<S: SerializeSeqProduct>(&self, prod: &mut S) -> Result<(), S::Error>;
fn schema<I: ReducerInfo>(typespace: &mut impl TypespaceBuilder) -> ReducerDef;
}
/// Values accepted as the arguments of a scheduled reducer call.
///
/// Either an [`Args`] tuple itself, or (via the tuple impls in this file) the same
/// tuple prefixed by a [`ReducerContext`], which is discarded.
pub trait ScheduleArgs<'de>
where
    Self: Sized,
{
    /// The argument tuple actually serialized for the scheduled call.
    type Args: Args<'de>;

    /// Strips anything that is not part of the serialized arguments.
    fn into_args(self) -> Self::Args;
}
/// An argument tuple schedules as itself.
impl<'de, T> ScheduleArgs<'de> for T
where
    T: Args<'de>,
{
    type Args = T;

    fn into_args(self) -> T {
        self
    }
}
/// Return types a reducer function may have, normalized to `Result<(), Box<str>>`.
pub trait ReducerResult {
/// Converts the reducer's return value. An `Err` carries the failure message.
fn into_result(self) -> Result<(), Box<str>>;
}
/// A reducer returning `()` always succeeds.
impl ReducerResult for () {
    #[inline]
    fn into_result(self) -> Result<(), Box<str>> {
        Ok(())
    }
}
/// A reducer returning `Result<(), E>` fails with the `Debug` rendering of its error.
impl<E> ReducerResult for Result<(), E>
where
    E: fmt::Debug,
{
    #[inline]
    fn into_result(self) -> Result<(), Box<str>> {
        match self {
            Ok(()) => Ok(()),
            Err(err) => Err(format!("{err:?}").into_boxed_str()),
        }
    }
}
/// Marker for types accepted as reducer parameters: any deserializable type, plus
/// [`ReducerContext`].
pub trait ReducerArg<'de> {}
impl<'de, T> ReducerArg<'de> for T where T: Deserialize<'de> {}
impl ReducerArg<'_> for ReducerContext {}

/// Does nothing at runtime. Naming it forces the compiler to check `T: ReducerArg`.
pub fn assert_reducer_arg<'de, T>()
where
    T: ReducerArg<'de>,
{
}

/// Does nothing at runtime. Naming it forces the compiler to check `T: ReducerResult`.
pub fn assert_reducer_ret<T>()
where
    T: ReducerResult,
{
}

/// Does nothing at runtime. Naming it forces the compiler to check `T: TableType`.
pub const fn assert_table<T>()
where
    T: TableType,
{
}

/// Marker selecting the `Reducer` impl for functions whose first parameter is a `ReducerContext`.
pub struct ContextArg;
/// Marker selecting the `Reducer` impl for functions that take only their arguments.
pub struct NoContextArg;
struct ArgsVisitor<A> {
_marker: PhantomData<A>,
}
impl<'de, A: Args<'de>> de::ProductVisitor<'de> for ArgsVisitor<A> {
type Output = A;
fn product_name(&self) -> Option<&str> {
None
}
fn product_len(&self) -> usize {
A::LEN
}
fn product_kind(&self) -> de::ProductKind {
de::ProductKind::ReducerArgs
}
fn visit_seq_product<Acc: SeqProductAccess<'de>>(self, prod: Acc) -> Result<Self::Output, Acc::Error> {
A::visit_seq_product(prod)
}
fn visit_named_product<Acc: de::NamedProductAccess<'de>>(self, _prod: Acc) -> Result<Self::Output, Acc::Error> {
Err(de::Error::custom("named products not supported"))
}
}
// Generates reducer plumbing for tuples of every arity, from the full identifier list
// down to zero.
//
// Arms:
// - list arm: emits `@impl` for the whole list, then recurses on the list minus its first
//   identifier. For example, `(A, B)` yields impls for `(A, B)`, `(B,)` and `()`.
// - `@impl`: for a single arity, implements:
//   - `Args` for the tuple;
//   - `ScheduleArgs` for the same tuple prefixed by a `ReducerContext`;
//   - `Reducer` for functions taking the arguments with (`ContextArg`) or without
//     (`NoContextArg`) a leading context.
// - `@count` / `@drop`: count the identifiers at compile time as `0 + 1 + 1 + ...`.
macro_rules! impl_reducer {
($($T1:ident $(, $T:ident)*)?) => {
impl_reducer!(@impl $($T1 $(, $T)*)?);
$(impl_reducer!($($T),*);)?
};
(@impl $($T:ident),*) => {
// The identifiers double as the tuple's type parameters and as local binding names.
impl<'de, $($T: SpacetimeType + Deserialize<'de> + Serialize),*> Args<'de> for ($($T,)*) {
const LEN: usize = impl_reducer!(@count $($T)*);
#[allow(non_snake_case)]
#[allow(unused)]
fn visit_seq_product<Acc: SeqProductAccess<'de>>(mut prod: Acc) -> Result<Self, Acc::Error> {
let vis = ArgsVisitor { _marker: PhantomData::<Self> };
// `i` is the index of the element being decoded, reported if it is missing.
let i = 0;
$(
let $T = prod.next_element::<$T>()?.ok_or_else(|| de::Error::missing_field(i, None, &vis))?;
let i = i + 1;
)*
Ok(($($T,)*))
}
fn serialize_seq_product<Ser: SerializeSeqProduct>(&self, _prod: &mut Ser) -> Result<(), Ser::Error> {
#[allow(non_snake_case)]
let ($($T,)*) = self;
$(_prod.serialize_element($T)?;)*
Ok(())
}
#[inline]
fn schema<Info: ReducerInfo>(_typespace: &mut impl TypespaceBuilder) -> ReducerDef {
// Bind the LAST `LEN` entries of `ARG_NAMES`, one per tuple element. Leading
// extras are skipped. This panics if `ARG_NAMES` has fewer entries than the arity.
#[allow(non_snake_case, irrefutable_let_patterns)]
let [.., $($T),*] = Info::ARG_NAMES else { panic!() };
ReducerDef {
name: Info::NAME.into(),
args: vec![
$(ProductTypeElement {
name: $T.map(str::to_owned),
algebraic_type: <$T>::make_type(_typespace),
}),*
],
}
}
}
// Scheduling may be called with the reducer's full parameter list, context included.
// The context is dropped before serialization.
impl<'de, $($T: SpacetimeType + Deserialize<'de> + Serialize),*> ScheduleArgs<'de> for (ReducerContext, $($T,)*) {
type Args = ($($T,)*);
#[allow(clippy::unused_unit)]
fn into_args(self) -> Self::Args {
#[allow(non_snake_case)]
let (_ctx, $($T,)*) = self;
($($T,)*)
}
}
// Reducer functions whose first parameter is the `ReducerContext`.
impl<'de, Func, Ret, $($T: SpacetimeType + Deserialize<'de> + Serialize),*> Reducer<'de, ($($T,)*), ContextArg> for Func
where
Func: Fn(ReducerContext, $($T),*) -> Ret,
Ret: ReducerResult
{
fn invoke(&self, ctx: ReducerContext, args: ($($T,)*)) -> Result<(), Box<str>> {
#[allow(non_snake_case)]
let ($($T,)*) = args;
self(ctx, $($T),*).into_result()
}
}
// Reducer functions that take only their arguments. The context is ignored.
impl<'de, Func, Ret, $($T: SpacetimeType + Deserialize<'de> + Serialize),*> Reducer<'de, ($($T,)*), NoContextArg> for Func
where
Func: Fn($($T),*) -> Ret,
Ret: ReducerResult
{
fn invoke(&self, _ctx: ReducerContext, args: ($($T,)*)) -> Result<(), Box<str>> {
#[allow(non_snake_case)]
let ($($T,)*) = args;
self($($T),*).into_result()
}
}
};
(@count $($T:ident)*) => {
0 $(+ impl_reducer!(@drop $T 1))*
};
// Discards its first token and yields the second; used to emit one `1` per identifier.
(@drop $a:tt $b:tt) => { $b };
}
// Instantiate for every arity from 32 (identifiers `A` through `AF`) down to 0.
impl_reducer!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z, AA, AB, AC, AD, AE, AF);
/// Newtype adapting an [`Args`] tuple to SATS (de)serialization as a reducer-argument product.
struct SerDeArgs<A>(A);
// Deserialization drives `ArgsVisitor`, which decodes elements in order via `Args::visit_seq_product`.
impl_deserialize!(
[A: Args<'de>] SerDeArgs<A>,
de => de.deserialize_product(ArgsVisitor { _marker: PhantomData }).map(Self)
);
// Serialization writes a sequence product of exactly `A::LEN` elements.
impl_serialize!(['de, A: Args<'de>] SerDeArgs<A>, (self, ser) => {
let mut prod = ser.serialize_seq_product(A::LEN)?;
self.0.serialize_seq_product(&mut prod)?;
prod.end()
});
/// Returns the timestamp `duration` from now, for use with `schedule`.
///
/// # Panics
/// Panics if `now + duration` is not representable as a [`Timestamp`]. Because of
/// `#[track_caller]`, the panic is reported at the caller's source location.
#[track_caller]
pub fn schedule_in(duration: Duration) -> Timestamp {
    // The panic must be raised directly in this function body. Inside a closure
    // (e.g. `unwrap_or_else(|| panic!(..))`) it would lose `#[track_caller]`, because
    // closures do not inherit it, and would report this file as the location.
    let Some(when) = Timestamp::now().checked_add(duration) else {
        panic!("{duration:?} is too far into the future to schedule");
    };
    when
}
/// Schedules reducer `R` to run at `time` with `args`.
///
/// Any leading `ReducerContext` is stripped from `args` before BSATN encoding.
/// Returns a token for the scheduled call.
///
/// # Panics
/// Panics if the arguments fail to serialize.
pub fn schedule<'de, R: ReducerInfo>(time: Timestamp, args: impl ScheduleArgs<'de>) -> ScheduleToken<R> {
    let encoded = bsatn::to_vec(&SerDeArgs(args.into_args())).unwrap();
    ScheduleToken::new(sys::schedule(R::NAME, &encoded, time.micros_since_epoch))
}
/// Schedules repeating reducer `I` to run `I::REPEAT_INTERVAL` from now.
///
/// The arguments come from `A::get_now()`. The host's schedule id is discarded.
pub fn schedule_repeater<A: RepeaterArgs, T, I: RepeaterInfo>(_reducer: impl for<'de> Reducer<'de, A, T>) {
    // Order matters: the run time is computed before `get_now()` reads the clock.
    let run_at = schedule_in(I::REPEAT_INTERVAL);
    let encoded = bsatn::to_vec(&SerDeArgs(A::get_now())).unwrap();
    let _ = sys::schedule(I::NAME, &encoded, run_at.micros_since_epoch);
}
/// Argument tuples for repeating reducers. They can be produced from the current time alone.
pub trait RepeaterArgs
where
    Self: for<'de> Args<'de>,
{
    /// Builds the arguments for a scheduled invocation.
    fn get_now() -> Self;
}

/// A repeater that takes no arguments.
impl RepeaterArgs for () {
    fn get_now() -> Self {}
}

/// A repeater that receives the current time as its only argument.
impl RepeaterArgs for (Timestamp,) {
    fn get_now() -> Self {
        let now = Timestamp::now();
        (now,)
    }
}
/// Queues `f` to run against the module builder when `__describe_module__` is called.
///
/// # Panics
/// Panics if the describer list's mutex is poisoned.
fn register_describer(f: fn(&mut ModuleBuilder)) {
    let mut describers = DESCRIBERS.lock().unwrap();
    describers.push(f);
}
pub fn register_reftype<T: SpacetimeType>() {
register_describer(|module| {
T::make_type(module);
})
}
pub fn register_table<T: TableType>() {
register_describer(|module| {
let data = *T::make_type(module).as_ref().unwrap();
let columns = module
.module
.typespace
.with_type(&data)
.resolve_refs()
.and_then(|x| {
if let Ok(x) = x.into_product() {
let cols: Vec<ColumnDef> = x.into();
Some(cols)
} else {
None
}
})
.expect("Fail to retrieve the columns from the module");
let indexes: Vec<_> = T::INDEXES.iter().copied().map(Into::into).collect();
let constraints: Vec<_> = T::COLUMN_ATTRS
.iter()
.enumerate()
.map(|(col_pos, x)| {
let col = &columns[col_pos];
let kind = match (*x).try_into() {
Ok(x) => x,
Err(_) => Constraints::unset(),
};
ConstraintDef::for_column(T::TABLE_NAME, &col.col_name, kind, ColList::new(col_pos.into()))
})
.collect();
let sequences: Vec<_> = T::COLUMN_ATTRS
.iter()
.enumerate()
.filter_map(|(col_pos, x)| {
let col = &columns[col_pos];
if x.kind() == AttributeKind::AUTO_INC {
Some(SequenceDef::for_column(T::TABLE_NAME, &col.col_name, col_pos.into()))
} else {
None
}
})
.collect();
let schema = TableDef::new(T::TABLE_NAME.into(), columns)
.with_type(StTableType::User)
.with_access(StAccess::for_name(T::TABLE_NAME))
.with_constraints(constraints)
.with_sequences(sequences)
.with_indexes(indexes);
let schema = TableDesc { schema, data };
module.module.tables.push(schema)
})
}
impl From<crate::IndexDesc<'_>> for IndexDef {
    /// Converts a compile-time index description into an [`IndexDef`].
    ///
    /// The resulting index is always marked non-unique.
    ///
    /// # Panics
    /// Panics if the column list cannot be built. Per the message, this happens when the
    /// description lists no columns.
    fn from(index: crate::IndexDesc<'_>) -> IndexDef {
        let built = index
            .col_ids
            .iter()
            .map(|&col| col.into())
            .collect::<ColListBuilder>()
            .build();
        let columns = match built {
            Ok(columns) => columns,
            Err(_) => panic!("Need at least one column in IndexDesc for index `{}`", index.name),
        };
        IndexDef {
            index_name: index.name.to_string(),
            is_unique: false,
            index_type: index.ty,
            columns,
        }
    }
}
/// Registers a describer that adds reducer `I` (taking arguments `A`) to the module.
///
/// The describer appends the reducer's schema to the module definition and its entry
/// point to the reducer table, in lockstep. Presumably the host's reducer id is the
/// shared position in both lists — confirm against the host.
pub fn register_reducer<'a, A: Args<'a>, T, I: ReducerInfo>(_: impl Reducer<'a, A, T>) {
    register_describer(|builder| {
        let reducer_def = A::schema::<I>(builder);
        builder.module.reducers.push(reducer_def);
        builder.reducers.push(I::INVOKE);
    })
}
/// Accumulates the module definition while registered describers run.
#[derive(Default)]
struct ModuleBuilder {
    /// The module definition being built. `__describe_module__` serializes it.
    module: ModuleDef,
    /// Reducer entry points, appended alongside `module.reducers`.
    reducers: Vec<ReducerFn>,
    /// Typespace slot already assigned to each registered Rust type.
    type_map: BTreeMap<TypeId, AlgebraicTypeRef>,
}

impl TypespaceBuilder for ModuleBuilder {
    /// Returns a `Ref` to the typespace slot for `typeid`, creating the slot on first use.
    ///
    /// On first use, the steps run in this order:
    /// 1. A unit placeholder is added to the typespace and recorded in `type_map`.
    /// 2. If `name` is given, a `TypeAlias` export pointing at the slot is recorded.
    /// 3. `make_ty` builds the real type, which replaces the placeholder.
    ///
    /// Because the slot is recorded before `make_ty` runs, a nested `add` for the same
    /// `typeid` made from inside `make_ty` gets the same `Ref`.
    fn add(
        &mut self,
        typeid: TypeId,
        name: Option<&'static str>,
        make_ty: impl FnOnce(&mut Self) -> AlgebraicType,
    ) -> AlgebraicType {
        let vacant = match self.type_map.entry(typeid) {
            btree_map::Entry::Occupied(known) => return AlgebraicType::Ref(*known.get()),
            btree_map::Entry::Vacant(vacant) => vacant,
        };
        let slot_ref = self.module.typespace.add(AlgebraicType::unit());
        vacant.insert(slot_ref);

        if let Some(alias_name) = name {
            let alias = TypeAlias {
                name: alias_name.to_owned(),
                ty: slot_ref,
            };
            self.module.misc_exports.push(MiscModuleExport::TypeAlias(alias));
        }

        let built = make_ty(self);
        self.module.typespace[slot_ref] = built;
        AlgebraicType::Ref(slot_ref)
    }
}
/// Describer callbacks queued by `register_describer`. `__describe_module__` runs them all, in order.
static DESCRIBERS: Mutex<Vec<fn(&mut ModuleBuilder)>> = Mutex::new(Vec::new());
/// Type-erased reducer entry point, called as `(sender, caller_address, timestamp, args)`.
pub type ReducerFn = fn(Buffer, Buffer, u64, &[u8]) -> Buffer;
/// Reducer table indexed by the id passed to `__call_reducer__`. Set once by `__describe_module__`.
static REDUCERS: OnceCell<Vec<ReducerFn>> = OnceCell::new();
/// Host entry point that describes the module.
///
/// Runs every queued describer against a fresh `ModuleBuilder` and serializes the resulting
/// module definition with BSATN. It then stores the collected reducer table for
/// `__call_reducer__` and returns the serialized bytes in a new buffer.
///
/// # Panics
/// Panics if:
/// - the describer mutex is poisoned;
/// - serialization fails;
/// - it is called more than once (the reducer table can only be set once).
#[no_mangle]
extern "C" fn __describe_module__() -> Buffer {
    let mut module = ModuleBuilder::default();
    for describer in &*DESCRIBERS.lock().unwrap() {
        describer(&mut module)
    }
    let bytes = bsatn::to_vec(&module.module).expect("unable to serialize typespace");
    // Report a second initialization explicitly. The previous `.ok().unwrap()` panicked with
    // "called `Option::unwrap()` on a `None` value", which hides the actual cause.
    if REDUCERS.set(module.reducers).is_err() {
        panic!("__describe_module__ called more than once: the reducer table is already initialized");
    }
    Buffer::alloc(&bytes)
}
/// Host entry point that runs the reducer with table index `id`.
///
/// Reads the argument bytes from `args` and forwards the raw sender, caller address,
/// and timestamp to the stored entry point.
///
/// # Panics
/// Panics with a descriptive message if:
/// - `__describe_module__` has not initialized the reducer table;
/// - `id` is out of range.
#[no_mangle]
extern "C" fn __call_reducer__(
    id: usize,
    sender: Buffer,
    caller_address: Buffer,
    timestamp: u64,
    args: Buffer,
) -> Buffer {
    let reducers = REDUCERS
        .get()
        .expect("__call_reducer__ called before __describe_module__ initialized the reducer table");
    let args = args.read();
    let Some(&reducer) = reducers.get(id) else {
        panic!("no reducer with id {id}: the module defines {} reducers", reducers.len());
    };
    reducer(sender, caller_address, timestamp, &args)
}