use std::fmt::Debug;
use std::hash::Hash;
use std::panic::RefUnwindSafe;
use std::sync::Arc;
use bon::Builder;
use tracing::Span;
use super::FactoryMessage;
use crate::concurrency::Duration;
use crate::concurrency::SystemTime;
#[cfg(feature = "cluster")]
use crate::message::BoxedDowncastErr;
use crate::ActorRef;
#[cfg(feature = "cluster")]
use crate::BytesConvertable;
use crate::Message;
use crate::RpcReplyPort;
/// Bound required of every job key when cluster support is enabled: the usual hashable-key
/// requirements plus `BytesConvertable`, so the key can be serialized alongside remote jobs.
#[cfg(feature = "cluster")]
pub trait JobKey:
    'static + Send + Sync + Debug + Clone + Hash + Eq + PartialEq + BytesConvertable
{
}
// Blanket impl: any type meeting the bounds is a `JobKey`; nothing needs to opt in by hand.
#[cfg(feature = "cluster")]
impl<K> JobKey for K where
    K: 'static + Send + Sync + Debug + Clone + Hash + Eq + PartialEq + BytesConvertable
{
}
/// Bound required of every job key: debuggable, hashable, cloneable, comparable for equality and
/// shareable across threads.
#[cfg(not(feature = "cluster"))]
pub trait JobKey: 'static + Send + Sync + Debug + Clone + Hash + Eq + PartialEq {}
// Blanket impl: any type meeting the bounds is a `JobKey`; nothing needs to opt in by hand.
#[cfg(not(feature = "cluster"))]
impl<K> JobKey for K where K: 'static + Send + Sync + Debug + Clone + Hash + Eq + PartialEq {}
/// Per-job metadata: lifecycle timestamps, an optional time-to-live and an optional tracing span.
#[derive(Debug, PartialEq, Clone)]
pub struct JobOptions {
    /// Set to "now" at construction; `Job::is_expired` measures the TTL from this instant.
    submit_time: SystemTime,
    /// Set to "now" at construction and overwritten by `Job::set_factory_time`.
    factory_time: SystemTime,
    /// Set to "now" at construction and overwritten by `Job::set_worker_time`.
    worker_time: SystemTime,
    /// Maximum age of the job; `None` means the job never expires.
    ttl: Option<Duration>,
    /// Span captured at construction when the `message_span_propogation` feature is enabled.
    span: Option<Span>,
}
impl JobOptions {
    /// Creates options stamped with the current time and the given TTL.
    ///
    /// When the `message_span_propogation` feature is enabled, the current tracing span is
    /// captured as well; otherwise `span` is `None`.
    pub fn new(ttl: Option<Duration>) -> Self {
        let span = {
            #[cfg(feature = "message_span_propogation")]
            {
                Some(Span::current())
            }
            #[cfg(not(feature = "message_span_propogation"))]
            {
                None
            }
        };
        // Read the clock once so all three timestamps start out identical. Three separate
        // `now()` calls would yield slightly different instants and cost three clock reads.
        let now = SystemTime::now();
        Self {
            submit_time: now,
            factory_time: now,
            worker_time: now,
            ttl,
            span,
        }
    }
    /// The job's time-to-live, if any.
    pub fn ttl(&self) -> Option<Duration> {
        self.ttl
    }
    /// Replaces the job's time-to-live (`None` disables expiry).
    pub fn set_ttl(&mut self, ttl: Option<Duration>) {
        self.ttl = ttl;
    }
    /// When the options were created (or, for deserialized jobs, the transmitted submit time).
    pub fn submit_time(&self) -> SystemTime {
        self.submit_time
    }
    /// The most recent worker timestamp.
    pub fn worker_time(&self) -> SystemTime {
        self.worker_time
    }
    /// The most recent factory timestamp.
    pub fn factory_time(&self) -> SystemTime {
        self.factory_time
    }
    /// A clone of the captured tracing span, if any.
    pub fn span(&self) -> Option<Span> {
        self.span.clone()
    }
    /// Moves the captured span out, leaving `None` behind.
    pub(crate) fn take_span(&mut self) -> Option<Span> {
        self.span.take()
    }
}
impl Default for JobOptions {
    /// Options stamped with the current time and no TTL.
    fn default() -> Self {
        Self::new(None)
    }
}
#[cfg(feature = "cluster")]
impl BytesConvertable for JobOptions {
    /// Encodes the options as 16 bytes: big-endian `u64` nanoseconds since the Unix epoch for
    /// `submit_time`, followed by big-endian `u64` nanoseconds of `ttl`.
    ///
    /// `factory_time`, `worker_time` and `span` are not transmitted. A `ttl` of `None` (or a zero
    /// TTL) is encoded as 0, which decodes back to `None`.
    ///
    /// # Panics
    /// Panics if `submit_time` is earlier than the Unix epoch.
    fn into_bytes(self) -> Vec<u8> {
        // Whole nanoseconds saturated at `u64::MAX` (~584 years). A plain `as u64` keeps only the
        // low 64 bits of the `u128`, so a huge TTL such as `Duration::MAX` would wrap to an
        // arbitrary (possibly tiny) value on the receiving side.
        fn saturating_nanos(d: Duration) -> u64 {
            u64::try_from(d.as_nanos()).unwrap_or(u64::MAX)
        }
        let submit_time = saturating_nanos(
            self.submit_time
                .duration_since(std::time::UNIX_EPOCH)
                .expect("Time went backwards"),
        )
        .to_be_bytes();
        let ttl = self.ttl.map(saturating_nanos).unwrap_or(0).to_be_bytes();
        let mut data = vec![0u8; 16];
        data[0..8].copy_from_slice(&submit_time);
        data[8..16].copy_from_slice(&ttl);
        data
    }
    /// Decodes the 16-byte layout produced by `into_bytes`.
    ///
    /// Input of any other length yields fresh default options (current time, no TTL) with no span.
    /// Decoded options never carry a span, and their factory/worker times are "now".
    fn from_bytes(mut data: Vec<u8>) -> Self {
        if data.len() != 16 {
            return Self {
                span: None,
                ..Default::default()
            };
        }
        let ttl_bytes = data.split_off(8);
        let submit_time =
            u64::from_be_bytes(data.try_into().expect("8 bytes remain after the length check"));
        let ttl =
            u64::from_be_bytes(ttl_bytes.try_into().expect("8 bytes split off after the length check"));
        Self {
            submit_time: std::time::UNIX_EPOCH + Duration::from_nanos(submit_time),
            // 0 is the wire encoding of "no TTL"
            ttl: (ttl > 0).then_some(Duration::from_nanos(ttl)),
            span: None,
            ..Default::default()
        }
    }
}
/// A unit of work submitted to a factory: a keyed message plus its [`JobOptions`].
#[derive(Builder)]
pub struct Job<TKey, TMsg>
where
    TKey: JobKey,
    TMsg: Message,
{
    /// The job's key.
    pub key: TKey,
    /// The message payload.
    pub msg: TMsg,
    /// Timing/TTL/span metadata; defaults to `JobOptions::default()` when using the builder.
    #[builder(default = JobOptions::default())]
    pub options: JobOptions,
    /// Optional acknowledgement port. `Job::accept` replies `None` through it, and `Job::reject`
    /// hands the job itself back as `Some(job)`.
    pub accepted: Option<RpcReplyPort<Option<Self>>>,
}
impl<TKey, TMsg> Debug for Job<TKey, TMsg>
where
    TKey: JobKey,
    TMsg: Message,
{
    /// Formats the key, the options and whether an acceptance port is attached.
    /// The message payload is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Job")
            // `TKey: JobKey` implies `Debug`; without the key, job log lines are unidentifiable
            .field("key", &self.key)
            .field("options", &self.options)
            .field("has_accepted", &self.accepted.is_some())
            .finish()
    }
}
#[cfg(feature = "cluster")]
impl<TKey, TMsg> Job<TKey, TMsg>
where
    TKey: JobKey,
    TMsg: Message,
{
    /// Byte length of the encoded `JobOptions` prefix inside the job metadata.
    const OPTIONS_LEN: usize = 16;

    /// Splits the job into its metadata bytes (encoded options followed by encoded key) and the
    /// inner message. The acceptance port, if any, is dropped.
    fn serialize_meta(self) -> (Vec<u8>, TMsg) {
        let options_bytes = self.options.into_bytes();
        let key_bytes = self.key.into_bytes();
        let mut meta = vec![0u8; Self::OPTIONS_LEN + key_bytes.len()];
        meta[..Self::OPTIONS_LEN].copy_from_slice(&options_bytes);
        meta[Self::OPTIONS_LEN..].copy_from_slice(&key_bytes);
        (meta, self.msg)
    }

    /// Inverse of `serialize_meta`: decodes the key and options from the metadata bytes.
    ///
    /// # Errors
    /// Returns `BoxedDowncastErr` if the metadata is missing or shorter than the options prefix.
    fn deserialize_meta(
        maybe_bytes: Option<Vec<u8>>,
    ) -> Result<(TKey, JobOptions), BoxedDowncastErr> {
        let mut meta_bytes = maybe_bytes.ok_or(BoxedDowncastErr)?;
        // `Vec::split_off` panics when `at > len`. This metadata comes off the wire, so a
        // truncated payload must be rejected instead of crashing the receiving actor.
        if meta_bytes.len() < Self::OPTIONS_LEN {
            return Err(BoxedDowncastErr);
        }
        let key_bytes = meta_bytes.split_off(Self::OPTIONS_LEN);
        Ok((
            TKey::from_bytes(key_bytes),
            JobOptions::from_bytes(meta_bytes),
        ))
    }
}
#[cfg(feature = "cluster")]
impl<TKey, TMsg> Message for Job<TKey, TMsg>
where
    TKey: JobKey,
    TMsg: Message,
{
    /// A job can cross the wire exactly when its inner message can.
    fn serializable() -> bool {
        TMsg::serializable()
    }

    /// Serializes the inner message and attaches the job's key and options as its `metadata`.
    ///
    /// Inner messages that serialize to `CallReply` are rejected with `BoxedDowncastErr`.
    fn serialize(self) -> Result<crate::message::SerializedMessage, BoxedDowncastErr> {
        use crate::message::SerializedMessage as Wire;

        let (meta_bytes, msg) = self.serialize_meta();
        match msg.serialize()? {
            Wire::CallReply(_, _) => Err(BoxedDowncastErr),
            Wire::Call {
                variant,
                args,
                reply,
                ..
            } => Ok(Wire::Call {
                variant,
                args,
                reply,
                metadata: Some(meta_bytes),
            }),
            Wire::Cast { variant, args, .. } => Ok(Wire::Cast {
                variant,
                args,
                metadata: Some(meta_bytes),
            }),
        }
    }

    /// Rebuilds a job from a serialized message whose `metadata` carries the key and options.
    ///
    /// The inner message is deserialized with its metadata stripped. The result has no
    /// acceptance port.
    fn deserialize(bytes: crate::message::SerializedMessage) -> Result<Self, BoxedDowncastErr> {
        use crate::message::SerializedMessage as Wire;

        // Separate the job metadata from the payload the inner message understands
        let (metadata, inner) = match bytes {
            Wire::CallReply(_, _) => return Err(BoxedDowncastErr),
            Wire::Cast {
                variant,
                args,
                metadata,
            } => (
                metadata,
                Wire::Cast {
                    variant,
                    args,
                    metadata: None,
                },
            ),
            Wire::Call {
                variant,
                args,
                reply,
                metadata,
            } => (
                metadata,
                Wire::Call {
                    variant,
                    args,
                    reply,
                    metadata: None,
                },
            ),
        };
        let (key, options) = Self::deserialize_meta(metadata)?;
        let msg = TMsg::deserialize(inner)?;
        Ok(Self {
            key,
            msg,
            options,
            accepted: None,
        })
    }
}
impl<TKey, TMsg> Job<TKey, TMsg>
where
    TKey: JobKey,
    TMsg: Message,
{
    /// Returns `true` if the job has a TTL and more than that TTL has elapsed since its submit
    /// time.
    ///
    /// A submit time lying in the future never counts as expired. This can happen if the system
    /// clock was moved backwards, or if the job was deserialized from a node whose clock runs
    /// ahead. The previous `elapsed().unwrap()` panicked in that situation.
    pub fn is_expired(&self) -> bool {
        match self.options.ttl {
            Some(ttl) => matches!(
                self.options.submit_time.elapsed(),
                Ok(elapsed) if elapsed > ttl
            ),
            None => false,
        }
    }
    /// Stamps the options' factory time with the current time.
    pub(crate) fn set_factory_time(&mut self) {
        self.options.factory_time = SystemTime::now();
    }
    /// Stamps the options' worker time with the current time.
    pub(crate) fn set_worker_time(&mut self) {
        self.options.worker_time = SystemTime::now();
    }
    /// If an acceptance port is attached, consumes it and replies `None` (accepted).
    /// Send failures are ignored.
    pub(crate) fn accept(&mut self) {
        if let Some(port) = self.accepted.take() {
            let _ = port.send(None);
        }
    }
    /// If an acceptance port is attached, hands the job back through it as `Some(job)`.
    /// Without a port the job is simply dropped. Send failures are ignored.
    pub(crate) fn reject(mut self) {
        if let Some(port) = self.accepted.take() {
            let _ = port.send(Some(self));
        }
    }
}
/// How many times a `RetriableMessage` may be resubmitted after being dropped uncompleted.
#[derive(Debug)]
pub enum MessageRetryStrategy {
    /// Resubmit without limit.
    RetryForever,
    /// Resubmit at most this many more times.
    Count(usize),
    /// Never resubmit.
    NoRetry,
}
impl MessageRetryStrategy {
    /// Whether at least one more resubmission is allowed.
    fn has_retries(&self) -> bool {
        matches!(self, Self::RetryForever | Self::Count(1..))
    }
    /// The strategy for the next attempt, after one retry has been consumed.
    /// `Count(1)` and `Count(0)` both collapse to `NoRetry`.
    fn decrement(&self) -> Self {
        match *self {
            Self::RetryForever => Self::RetryForever,
            Self::Count(remaining) if remaining > 1 => Self::Count(remaining - 1),
            Self::Count(_) | Self::NoRetry => Self::NoRetry,
        }
    }
}
/// A message that resubmits itself to its factory when dropped before being marked completed,
/// as long as its `strategy` still allows retries.
pub struct RetriableMessage<TKey: JobKey, TMessage: Message> {
    /// Key of the job this message belongs to; passed to the retry hook.
    pub key: TKey,
    /// The payload; `None` once `completed()` was called or the payload moved into a retry.
    pub message: Option<TMessage>,
    /// Remaining retry budget.
    pub strategy: MessageRetryStrategy,
    /// Optional callback invoked with the key right before a resubmission.
    #[allow(clippy::type_complexity)]
    pub retry_hook: Option<Arc<dyn Fn(&TKey) + 'static + Send + Sync + RefUnwindSafe>>,
    /// Options and factory handle needed to resubmit; without it no retry happens.
    retry_state: Option<(JobOptions, ActorRef<FactoryMessage<TKey, Self>>)>,
}
impl<TKey, TMsg> Debug for RetriableMessage<TKey, TMsg>
where
    TKey: JobKey,
    TMsg: Message,
{
    /// Prints the key and strategy, and only the presence of the payload, hook and retry state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            key,
            message,
            strategy,
            retry_hook,
            retry_state,
        } = self;
        f.debug_struct("RetriableMessage")
            .field("key", key)
            .field("strategy", strategy)
            .field("message", &message.is_some())
            .field("retry_hook", &retry_hook.is_some())
            .field("retry_state", &retry_state.is_some())
            .finish()
    }
}
// Uses the `Message` trait's default methods for the cluster build.
#[cfg(feature = "cluster")]
impl<TKey: JobKey, TMessage: Message> Message for RetriableMessage<TKey, TMessage> {}
/// Dropping a retriable message that still holds its payload, has retries left and has captured
/// retry state resubmits it to the captured factory, with one retry consumed.
impl<TKey: JobKey, TMessage: Message> Drop for RetriableMessage<TKey, TMessage> {
    fn drop(&mut self) {
        tracing::trace!("Drop handler for retriable message executing {self:?}");
        // Retries exhausted, or the payload was already completed/moved: nothing to resubmit
        if !self.strategy.has_retries() || self.message.is_none() {
            return;
        }
        // Without captured retry state there is no factory to resubmit to
        let Some((options, factory)) = self.retry_state.as_ref() else {
            return;
        };
        // Move the payload and hook into a fresh message carrying the decremented budget
        let msg = Self {
            key: self.key.clone(),
            message: self.message.take(),
            strategy: self.strategy.decrement(),
            retry_state: Some((options.clone(), factory.clone())),
            retry_hook: self.retry_hook.take(),
        };
        let job = Job {
            accepted: None,
            key: self.key.clone(),
            msg,
            options: options.clone(),
        };
        if let Some(handler) = job.msg.retry_hook.as_ref() {
            let key = std::panic::AssertUnwindSafe(&job.key);
            let f = handler.clone();
            // A panicking hook must neither unwind out of `drop` nor prevent the resubmission
            _ = std::panic::catch_unwind(move || (f)(*key));
        }
        // Log the budget the resubmitted job actually carries (post-decrement). `self.strategy`
        // still holds the value of the attempt that was just spent.
        tracing::trace!(
            "A retriable job is being resubmitted to the factory. Number of retries left {:?}",
            job.msg.strategy
        );
        // Best effort: if the factory is gone, the job is dropped
        _ = factory.cast(FactoryMessage::Dispatch(job));
    }
}
impl<TKey, TMsg> ActorRef<FactoryMessage<TKey, RetriableMessage<TKey, TMsg>>>
where
    TKey: JobKey,
    TMsg: Message,
{
    /// Wraps `job` in a [`RetriableMessage`] bound to this factory and dispatches it.
    ///
    /// If the wrapped message is later dropped without being marked completed, it resubmits
    /// itself here according to `strategy`.
    ///
    /// # Errors
    /// Returns the boxed messaging error if the dispatch cast fails.
    #[allow(clippy::type_complexity)]
    pub fn submit_retriable_job(
        &self,
        job: Job<TKey, TMsg>,
        strategy: MessageRetryStrategy,
    ) -> Result<(), Box<crate::MessagingErr<FactoryMessage<TKey, RetriableMessage<TKey, TMsg>>>>>
    {
        let wrapped = RetriableMessage::from_job(job, strategy, self.clone());
        self.cast(FactoryMessage::Dispatch(wrapped))
            .map_err(Box::new)
    }
}
impl<TKey: JobKey, TMessage: Message> RetriableMessage<TKey, TMessage> {
    /// Creates a retriable message with no retry hook and no captured retry state.
    /// It will not resubmit itself until `capture_retry_state` is called.
    pub fn new(key: TKey, message: TMessage, strategy: MessageRetryStrategy) -> Self {
        Self {
            retry_hook: None,
            retry_state: None,
            strategy,
            message: Some(message),
            key,
        }
    }
    /// Installs a callback invoked with the key just before each resubmission.
    pub fn set_retry_hook(&mut self, f: impl Fn(&TKey) + 'static + Send + Sync + RefUnwindSafe) {
        let hook: Arc<dyn Fn(&TKey) + 'static + Send + Sync + RefUnwindSafe> = Arc::new(f);
        self.retry_hook = Some(hook);
    }
    /// Converts a plain job into a job carrying a retriable message bound to `factory`.
    ///
    /// The resulting job has no acceptance port; any port on the input job is dropped.
    pub fn from_job(
        Job {
            key, msg, options, ..
        }: Job<TKey, TMessage>,
        strategy: MessageRetryStrategy,
        factory: ActorRef<FactoryMessage<TKey, Self>>,
    ) -> Job<TKey, Self> {
        let retriable = Self {
            key: key.clone(),
            message: Some(msg),
            strategy,
            retry_hook: None,
            retry_state: Some((options.clone(), factory)),
        };
        Job {
            accepted: None,
            key,
            msg: retriable,
            options,
        }
    }
    /// Records the options and factory handle used to resubmit this message when it is dropped.
    pub fn capture_retry_state(
        &mut self,
        options: &JobOptions,
        factory: ActorRef<FactoryMessage<TKey, Self>>,
    ) {
        let state = (options.clone(), factory);
        self.retry_state = Some(state);
    }
    /// Marks the message as done: clears the retry budget and drops the payload, so dropping
    /// this message no longer triggers a resubmission.
    pub fn completed(&mut self) {
        self.strategy = MessageRetryStrategy::NoRetry;
        self.message = None;
    }
}
#[cfg(feature = "cluster")]
#[cfg(test)]
mod tests {
    use super::super::FactoryMessage;
    use super::Job;
    use crate::concurrency::Duration;
    use crate::factory::JobOptions;
    use crate::message::SerializedMessage;
    use crate::serialization::BytesConvertable;
    use crate::Message;
    use crate::RpcReplyPort;

    /// Minimal job key whose byte form is the `u64` encoding of `item`.
    #[derive(Eq, Hash, PartialEq, Clone, Debug)]
    struct TestKey {
        item: u64,
    }
    impl crate::BytesConvertable for TestKey {
        fn from_bytes(bytes: Vec<u8>) -> Self {
            Self {
                item: u64::from_bytes(bytes),
            }
        }
        fn into_bytes(self) -> Vec<u8> {
            self.item.into_bytes()
        }
    }

    /// Test message with one cast variant (`A`) and one call variant (`B`).
    #[derive(Debug)]
    enum TestMessage {
        #[allow(dead_code)]
        A(String),
        #[allow(dead_code)]
        B(String, RpcReplyPort<u128>),
    }
    impl crate::Message for TestMessage {
        fn serializable() -> bool {
            true
        }
        fn serialize(
            self,
        ) -> Result<crate::message::SerializedMessage, crate::message::BoxedDowncastErr> {
            match self {
                Self::A(args) => Ok(crate::message::SerializedMessage::Cast {
                    variant: "A".to_string(),
                    args: <String as BytesConvertable>::into_bytes(args),
                    metadata: None,
                }),
                Self::B(args, _reply) => {
                    let (tx, _rx) = crate::concurrency::oneshot();
                    Ok(crate::message::SerializedMessage::Call {
                        variant: "B".to_string(),
                        args: <String as BytesConvertable>::into_bytes(args),
                        reply: tx.into(),
                        metadata: None,
                    })
                }
            }
        }
        fn deserialize(
            bytes: crate::message::SerializedMessage,
        ) -> Result<Self, crate::message::BoxedDowncastErr> {
            match bytes {
                crate::message::SerializedMessage::Cast { variant, args, .. } => {
                    match variant.as_str() {
                        "A" => Ok(Self::A(<String as BytesConvertable>::from_bytes(args))),
                        _ => Err(crate::message::BoxedDowncastErr),
                    }
                }
                crate::message::SerializedMessage::Call { variant, args, .. } => {
                    match variant.as_str() {
                        "B" => {
                            let (tx, _rx) = crate::concurrency::oneshot();
                            Ok(Self::B(
                                <String as BytesConvertable>::from_bytes(args),
                                tx.into(),
                            ))
                        }
                        _ => Err(crate::message::BoxedDowncastErr),
                    }
                }
                _ => Err(crate::message::BoxedDowncastErr),
            }
        }
    }

    type TheJob = Job<TestKey, TestMessage>;

    /// Round-trips a cast job (no TTL) and a call job (1s TTL) through serialize/deserialize.
    #[test]
    #[cfg_attr(
        not(all(target_arch = "wasm32", target_os = "unknown")),
        tracing_test::traced_test
    )]
    fn test_job_serialization() {
        let job_a = TheJob {
            key: TestKey { item: 123 },
            msg: TestMessage::A("Hello".to_string()),
            options: JobOptions::default(),
            accepted: None,
        };
        let expected_a = TheJob {
            key: TestKey { item: 123 },
            msg: TestMessage::A("Hello".to_string()),
            options: job_a.options.clone(),
            accepted: None,
        };
        let serialized_a = job_a.serialize().expect("Failed to serialize job A");
        let deserialized_a =
            TheJob::deserialize(serialized_a).expect("Failed to deserialize job A");
        assert_eq!(expected_a.key, deserialized_a.key);
        assert_eq!(
            expected_a.options.submit_time,
            deserialized_a.options.submit_time
        );
        assert_eq!(expected_a.options.ttl, deserialized_a.options.ttl);
        if let TestMessage::A(the_msg) = deserialized_a.msg {
            assert_eq!("Hello".to_string(), the_msg);
        } else {
            panic!("Failed to deserialize the message payload");
        }
        let job_b = TheJob {
            key: TestKey { item: 456 },
            msg: TestMessage::B("Hi".to_string(), crate::concurrency::oneshot().0.into()),
            options: JobOptions {
                ttl: Some(Duration::from_millis(1000)),
                ..Default::default()
            },
            accepted: None,
        };
        let expected_b = TheJob {
            key: TestKey { item: 456 },
            msg: TestMessage::B("Hi".to_string(), crate::concurrency::oneshot().0.into()),
            options: job_b.options.clone(),
            accepted: None,
        };
        let serialized_b = job_b.serialize().expect("Failed to serialize job B");
        let deserialized_b =
            TheJob::deserialize(serialized_b).expect("Failed to deserialize job B");
        assert_eq!(expected_b.key, deserialized_b.key);
        assert_eq!(
            expected_b.options.submit_time,
            deserialized_b.options.submit_time
        );
        assert_eq!(expected_b.options.ttl, deserialized_b.options.ttl);
        if let TestMessage::B(the_msg, _) = deserialized_b.msg {
            assert_eq!("Hi".to_string(), the_msg);
        } else {
            panic!("Failed to deserialize the message payload");
        }
    }

    /// A `FactoryMessage::Dispatch(job)` must serialize identically to the bare job.
    #[test]
    #[cfg_attr(
        not(all(target_arch = "wasm32", target_os = "unknown")),
        tracing_test::traced_test
    )]
    fn test_factory_message_serialization() {
        let job_a = TheJob {
            key: TestKey { item: 123 },
            msg: TestMessage::A("Hello".to_string()),
            options: JobOptions::default(),
            accepted: None,
        };
        let expected_a = TheJob {
            key: TestKey { item: 123 },
            msg: TestMessage::A("Hello".to_string()),
            options: job_a.options.clone(),
            accepted: None,
        };
        let msg = FactoryMessage::Dispatch(job_a);
        let serialized_a = msg.serialize().expect("Failed to serialize");
        let serialized_a_prime = expected_a.serialize().expect("Failed to serialize");
        if let (
            SerializedMessage::Cast {
                variant: variant1,
                args: args1,
                metadata: metadata1,
            },
            SerializedMessage::Cast {
                variant: variant2,
                args: args2,
                metadata: metadata2,
            },
        ) = (serialized_a, serialized_a_prime)
        {
            assert_eq!(variant1, variant2);
            assert_eq!(args1, args2);
            assert_eq!(metadata1, metadata2);
        } else {
            panic!("Non-cast serialization")
        }
    }
}