use std::future::Future;
use std::io::Cursor;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use anyhow::{Context, Result};
use ciborium::Value;
use serde::de::{self, DeserializeOwned, Deserializer};
use serde::{Deserialize, Serialize};
use crate::{actor::Actor, context::Ctx};
pub const TUPLE_ARITY_MAX: usize = 16;
pub(crate) type BoxActionFuture = Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>;
pub trait Action: serde::Serialize + DeserializeOwned + Send + Sync + 'static {
type Output: serde::Serialize + DeserializeOwned + Send + 'static;
const NAME: &'static str;
}
pub fn encode_positional<T: Serialize>(value: &T) -> Result<Vec<u8>> {
let mut encoded = Vec::new();
ciborium::into_writer(value, &mut encoded).context("encode action args as cbor")?;
let value: Value = ciborium::from_reader(Cursor::new(&encoded))
.context("decode action args into cbor value")?;
let value = positional_value(value);
encode_value(&value)
}
pub fn decode_positional<T: DeserializeOwned>(args: &[u8]) -> Result<T> {
let value = if args.is_empty() {
Value::Array(Vec::new())
} else {
ciborium::from_reader(Cursor::new(args)).context("decode action args from cbor")?
};
let value = match value {
Value::Null => Value::Array(Vec::new()),
value => value,
};
match decode_value::<T>(&value) {
Ok(value) => Ok(value),
Err(first_error) => match &value {
Value::Array(values) if values.is_empty() => decode_value(&Value::Null)
.or_else(|_| Err(first_error).context("decode positional action args as unit")),
Value::Array(values) if values.len() == 1 => decode_value(&values[0])
.or_else(|_| Err(first_error).context("decode positional action args as newtype")),
_ => Err(first_error).context("decode positional action args"),
},
}
}
fn positional_value(value: Value) -> Value {
match value {
Value::Map(entries) => Value::Array(entries.into_iter().map(|(_, value)| value).collect()),
Value::Array(values) => Value::Array(values),
Value::Null => Value::Array(Vec::new()),
value => Value::Array(vec![value]),
}
}
fn encode_value(value: &Value) -> Result<Vec<u8>> {
let mut encoded = Vec::new();
ciborium::into_writer(value, &mut encoded).context("encode positional action args as cbor")?;
Ok(encoded)
}
fn decode_value<T: DeserializeOwned>(value: &Value) -> Result<T> {
crate::event::deserialize_cbor_value(value.clone())
.map_err(|error| anyhow::anyhow!(error.to_string()))
.context("decode positional action args from cbor")
}
pub trait Handles<A: Action>: Actor + Sized {
type Future: Future<Output = Result<A::Output>> + Send + 'static;
fn handle(self: Arc<Self>, ctx: Ctx<Self>, action: A) -> Self::Future;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ActionEntry<A: Actor> {
pub name: &'static str,
_p: PhantomData<fn() -> A>,
}
impl<A: Actor> ActionEntry<A> {
pub const fn new(name: &'static str) -> Self {
Self {
name,
_p: PhantomData,
}
}
}
pub trait ActionSet<A: Actor>: Send + Sync + 'static {
fn entries() -> Vec<ActionEntry<A>>;
fn dispatch(actor: Arc<A>, ctx: Ctx<A>, name: &str, args: &[u8]) -> Option<BoxActionFuture>;
}
impl<A: Actor> ActionSet<A> for () {
fn entries() -> Vec<ActionEntry<A>> {
Vec::new()
}
fn dispatch(
_actor: Arc<A>,
_ctx: Ctx<A>,
_name: &str,
_args: &[u8],
) -> Option<BoxActionFuture> {
None
}
}
macro_rules! impl_action_set {
($($action:ident),+) => {
impl<Act, $($action),+> ActionSet<Act> for ($($action,)+)
where
Act: Actor + $(Handles<$action> +)+,
$($action: Action,)+
{
fn entries() -> Vec<ActionEntry<Act>> {
vec![$(ActionEntry::new(<$action as Action>::NAME)),+]
}
fn dispatch(
actor: Arc<Act>,
ctx: Ctx<Act>,
name: &str,
args: &[u8],
) -> Option<BoxActionFuture> {
$(
if name == <$action as Action>::NAME {
let args = args.to_vec();
return Some(Box::pin(async move {
let action = decode_positional::<$action>(&args).with_context(|| {
format!("decode action '{}' args", <$action as Action>::NAME)
})?;
let output = <Act as Handles<$action>>::handle(actor, ctx, action).await?;
encode_cbor(&output, "action response")
}));
}
)+
None
}
}
};
}
impl_action_set!(A0);
impl_action_set!(A0, A1);
impl_action_set!(A0, A1, A2);
impl_action_set!(A0, A1, A2, A3);
impl_action_set!(A0, A1, A2, A3, A4);
impl_action_set!(A0, A1, A2, A3, A4, A5);
impl_action_set!(A0, A1, A2, A3, A4, A5, A6);
impl_action_set!(A0, A1, A2, A3, A4, A5, A6, A7);
impl_action_set!(A0, A1, A2, A3, A4, A5, A6, A7, A8);
impl_action_set!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9);
impl_action_set!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10);
impl_action_set!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11);
impl_action_set!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12);
impl_action_set!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13);
impl_action_set!(
A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14
);
impl_action_set!(
A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15
);
fn encode_cbor<T: Serialize>(value: &T, label: &str) -> Result<Vec<u8>> {
let mut encoded = Vec::new();
ciborium::into_writer(value, &mut encoded)
.with_context(|| format!("encode {label} as cbor"))?;
Ok(encoded)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Raw;
impl<'de> Deserialize<'de> for Raw {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let _ = de::IgnoredAny::deserialize(deserializer)?;
Err(de::Error::custom(
"rivetkit::action::Raw cannot be deserialized; use Action::raw_args() or Action::decode_as(...) instead",
))
}
}
#[cfg(test)]
mod tests {
use std::future::{Ready, ready};
use std::sync::Arc;
use anyhow::Result;
use serde::Deserialize;
use serde::Serialize;
use serde::de::value::{Error as ValueError, UnitDeserializer};
use super::{Action, ActionSet, Handles, Raw, decode_positional, encode_positional};
use crate::{actor::Actor, context::Ctx};
#[test]
fn raw_deserialize_fails_with_guidance() {
let err = Raw::deserialize(UnitDeserializer::<ValueError>::new())
.expect_err("Raw should refuse serde decoding");
let message = err.to_string();
assert!(message.contains("Action::raw_args()"));
assert!(message.contains("Action::decode_as"));
}
struct TestActor;
impl Actor for TestActor {
type State = ();
type Input = ();
type Actions = ();
type Events = ();
type Queue = ();
type ConnParams = ();
type ConnState = ();
type Action = Raw;
}
#[derive(Debug, Serialize, Deserialize)]
struct First;
impl Action for First {
type Output = ();
const NAME: &'static str = "first";
}
#[derive(Debug, Serialize, Deserialize)]
struct Second;
impl Action for Second {
type Output = ();
const NAME: &'static str = "second";
}
impl Handles<First> for TestActor {
type Future = Ready<Result<()>>;
fn handle(self: Arc<Self>, _ctx: Ctx<Self>, _action: First) -> Self::Future {
ready(Ok(()))
}
}
impl Handles<Second> for TestActor {
type Future = Ready<Result<()>>;
fn handle(self: Arc<Self>, _ctx: Ctx<Self>, _action: Second) -> Self::Future {
ready(Ok(()))
}
}
#[test]
fn action_set_unit_registers_nothing() {
assert!(<() as ActionSet<TestActor>>::entries().is_empty());
}
#[test]
fn action_set_tuple_registers_names_in_order() {
let entries = <(First, Second) as ActionSet<TestActor>>::entries();
assert_eq!(
entries.iter().map(|entry| entry.name).collect::<Vec<_>>(),
["first", "second",]
);
}
#[test]
fn action_set_tuple_supports_one_and_max_arity() {
assert_eq!(
<(First,) as ActionSet<TestActor>>::entries()
.iter()
.map(|entry| entry.name)
.collect::<Vec<_>>(),
["first"]
);
type MaxActions = (
First,
First,
First,
First,
First,
First,
First,
First,
First,
First,
First,
First,
First,
First,
First,
First,
);
let entries = <MaxActions as ActionSet<TestActor>>::entries();
assert_eq!(entries.len(), super::TUPLE_ARITY_MAX);
assert!(entries.iter().all(|entry| entry.name == "first"));
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct NamedArgs {
first: String,
second: String,
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct TupleArgs(String, String);
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct NewtypeArg(u32);
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct UnitArg;
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct Nested {
value: u32,
label: String,
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct WithNested {
nested: Nested,
flag: bool,
}
#[test]
fn positional_encode_has_ts_byte_parity() {
assert_eq!(
encode_positional(&NamedArgs {
first: "a".into(),
second: "b".into(),
})
.expect("encode named args"),
vec![0x82, 0x61, b'a', 0x61, b'b']
);
assert_eq!(
encode_positional(&NewtypeArg(5)).expect("encode newtype arg"),
vec![0x81, 0x05]
);
assert_eq!(
encode_positional(&UnitArg).expect("encode unit arg"),
vec![0x80]
);
}
#[test]
fn positional_round_trips_arg_shapes() {
let named = NamedArgs {
first: "a".into(),
second: "b".into(),
};
assert_eq!(
decode_positional::<NamedArgs>(&encode_positional(&named).expect("encode named args"))
.expect("decode named args"),
named
);
let tuple = TupleArgs("a".into(), "b".into());
assert_eq!(
decode_positional::<TupleArgs>(&encode_positional(&tuple).expect("encode tuple args"))
.expect("decode tuple args"),
tuple
);
assert_eq!(
decode_positional::<NewtypeArg>(
&encode_positional(&NewtypeArg(5)).expect("encode newtype arg")
)
.expect("decode newtype arg"),
NewtypeArg(5)
);
assert_eq!(
decode_positional::<UnitArg>(&encode_positional(&UnitArg).expect("encode unit arg"))
.expect("decode unit arg"),
UnitArg
);
}
#[test]
fn positional_decode_accepts_named_struct_seq_and_map() {
let from_seq = decode_positional::<NamedArgs>(&cbor(&vec!["a", "b"]))
.expect("decode named args from positional seq");
assert_eq!(
from_seq,
NamedArgs {
first: "a".into(),
second: "b".into(),
}
);
let from_map = decode_positional::<NamedArgs>(&cbor(&NamedArgs {
first: "a".into(),
second: "b".into(),
}))
.expect("decode named args from map");
assert_eq!(from_map, from_seq);
}
#[test]
fn positional_decode_uses_field_order() {
let decoded = decode_positional::<NamedArgs>(&cbor(&vec!["first", "second"]))
.expect("decode ordered fields");
assert_eq!(decoded.first, "first");
assert_eq!(decoded.second, "second");
let err = decode_positional::<NamedArgs>(&cbor(&vec![7, 8]))
.expect_err("wrong positional field types should fail");
assert!(err.to_string().contains("decode positional action args"));
}
#[test]
fn positional_encode_leaves_nested_struct_as_map() {
let bytes = encode_positional(&WithNested {
nested: Nested {
value: 7,
label: "inside".into(),
},
flag: true,
})
.expect("encode nested args");
let value: ciborium::Value =
ciborium::from_reader(std::io::Cursor::new(bytes)).expect("decode cbor value");
let ciborium::Value::Array(values) = value else {
panic!("top-level args should be an array");
};
assert_eq!(values.len(), 2);
assert!(matches!(values[0], ciborium::Value::Map(_)));
assert!(matches!(values[1], ciborium::Value::Bool(true)));
}
fn cbor<T: Serialize>(value: &T) -> Vec<u8> {
let mut encoded = Vec::new();
ciborium::into_writer(value, &mut encoded).expect("encode test value as cbor");
encoded
}
}