use std::{
any::{Any, type_name},
fmt::{Debug, Display},
future::Future,
hash::Hash,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use futures::{
channel::oneshot::{self, Canceled},
future::BoxFuture,
};
use theta_flume::SendError;
use thiserror::Error;
use tracing::{error, warn};
use crate::{
actor::{Actor, ActorId},
base::{BindingError, Hex, Ident, parse_ident},
compat,
context::BINDINGS,
message::{
Continuation, Escalation, InternalSignal, Message, MsgPack, MsgTx, RawSignal, SigTx,
WeakMsgTx, WeakSigTx,
},
monitor::monitor_local_id,
monitor::monitor_remote_id,
};
#[cfg(feature = "monitor")]
use crate::{
base::MonitorError,
monitor::{AnyUpdateTx, UpdateTx},
};
#[cfg(feature = "remote")]
use {
crate::{
prelude::RemoteError,
remote::{
base::{ActorTypeId, Tag, parse_url},
network::{RecvFrameExt, SendFrameExt},
peer::{LocalPeer, PEER, Peer},
serde::{ActorRefDto, ContinuationDto, ForwardInfo, FromTaggedBytes, MsgPackDto},
},
},
iroh::{
PublicKey,
endpoint::{RecvStream, SendStream},
},
tracing::{debug, trace},
url::Url,
};
#[cfg(wasm_browser)]
use futures::future::LocalBoxFuture;
#[cfg(all(feature = "remote", feature = "monitor"))]
use {
crate::monitor::Update, theta_flume::unbounded_anonymous,
tracing::level_filters::STATIC_MAX_LEVEL,
};
/// Type-erased view of an actor reference (implemented below for both [`ActorRef`] and
/// [`WeakActorRef`]), so references to actors of different types can be handled uniformly.
pub trait AnyActorRef: Debug + Send + Sync + Any {
    /// Id of the referenced actor, or `None` when it can no longer be resolved
    /// (e.g. a weak reference that fails to upgrade).
    fn id(&self) -> Option<ActorId>;
    /// Decode `bytes` as the actor's message type selected by `tag` and enqueue it
    /// without a reply continuation.
    #[cfg(feature = "remote")]
    fn send_tagged_bytes(&self, tag: Tag, bytes: Vec<u8>) -> Result<(), BytesSendError>;
    /// A boxed clone of the concrete strong reference for downcasting, or `None` if
    /// none is available.
    fn as_any(&self) -> Option<Box<dyn Any>>;
    /// Serialize the reference to bytes.
    #[cfg(feature = "remote")]
    fn serialize(&self) -> Result<Vec<u8>, BindingError>;
    /// Spawn a task that reads framed messages for this actor from `in_stream` and
    /// writes framed replies to `reply_stream` on behalf of `peer`.
    #[cfg(feature = "remote")]
    fn spawn_export_task(
        &self,
        peer: Peer,
        in_stream: RecvStream,
        reply_stream: SendStream,
    ) -> Result<(), BindingError>;
    /// Register a monitor through `hdl` whose updates are serialized and written as
    /// frames on `tx_stream` for `peer`.
    #[cfg(all(feature = "remote", feature = "monitor"))]
    fn monitor_as_bytes(
        &self,
        peer: Peer,
        hdl: ActorHdl,
        tx_stream: SendStream,
    ) -> Result<(), MonitorError>;
    /// Implementation id of the actor type.
    #[cfg(feature = "remote")]
    fn ty_id(&self) -> ActorTypeId;
    /// Number of senders attached to the actor's message channel.
    #[cfg(feature = "remote")]
    fn sender_count(&self) -> usize;
}
/// Strong reference to an actor of type `A`, wrapping the actor's message sender.
///
/// `Clone` is implemented by hand further down so that `A` itself need not be `Clone`.
#[derive(Debug)]
pub struct ActorRef<A: Actor>(pub(crate) MsgTx<A>);
/// Weak counterpart of [`ActorRef`]; it must be upgraded before messages can be sent.
#[derive(Debug)]
pub struct WeakActorRef<A: Actor>(pub(crate) WeakMsgTx<A>);
/// Handle to an actor's signal channel; unlike [`ActorRef`] it is not generic over
/// the actor type.
#[derive(Debug, Clone)]
pub struct ActorHdl(pub(crate) SigTx);
/// Weak counterpart of [`ActorHdl`].
#[derive(Debug, Clone)]
pub struct WeakActorHdl(pub(crate) WeakSigTx);
/// A pending ask created by [`ActorRef::ask`]. Awaiting it (via `IntoFuture`) sends
/// `msg` to `target` and resolves to the handler's return value.
pub struct MsgRequest<'a, A, M>
where
    A: Actor,
    M: Send + Message<A> + 'static,
{
    target: &'a ActorRef<A>,
    msg: M,
}
/// A pending signal that expects an acknowledgement. Awaiting it sends `sig` to
/// `target_hdl` together with an ack channel.
pub struct SignalRequest<'a> {
    target_hdl: &'a ActorHdl,
    sig: InternalSignal,
}
/// Wraps a request so that awaiting it fails with `RequestError::Timeout` once
/// `duration` elapses.
pub struct Deadline<'a, R>
where
    R: IntoFuture,
{
    request: R,
    duration: Duration,
    // Carries the lifetime `'a` used by the boxed future produced in `into_future`.
    _phantom: PhantomData<&'a ()>,
}
/// Failure delivering a tagged, postcard-encoded message to a local actor.
#[cfg(feature = "remote")]
#[derive(Debug, Error)]
pub enum BytesSendError {
    /// The bytes could not be decoded into the actor's message type.
    #[error(transparent)]
    DeserializeError(#[from] postcard::Error),
    /// The message could not be delivered (the send failed or a weak reference could
    /// not be upgraded); the original tag and bytes are handed back.
    #[error(transparent)]
    SendError(#[from] SendError<(Tag, Vec<u8>)>),
}
/// Errors produced while awaiting a request (an ask, a signal acknowledgement, or a
/// [`Deadline`]). `T` is the payload type handed back when sending fails.
#[derive(Debug, Error)]
pub enum RequestError<T> {
    /// A reply/acknowledgement channel was dropped before answering.
    #[error(transparent)]
    Cancelled(#[from] Canceled),
    /// The request could not be enqueued; the unsent payload is carried inside.
    #[error(transparent)]
    SendError(#[from] SendError<T>),
    /// Receiving on a closed channel (not constructed in this file).
    #[error("receiving on a closed channel")]
    RecvError,
    /// The reply did not have the expected type.
    #[error("downcast failed")]
    DowncastError,
    /// A remote reply could not be decoded.
    #[cfg(feature = "remote")]
    #[error(transparent)]
    DeserializeError(#[from] postcard::Error),
    /// The request did not complete within its deadline.
    #[error("timeout")]
    Timeout,
}
impl<A> ActorRef<A>
where
    A: Actor,
{
    /// Id of the referenced actor.
    pub fn id(&self) -> ActorId {
        self.0.id()
    }

    /// Pointer-derived id of the underlying sender (see `MsgTx::ptr_id`), as opposed to
    /// the actor's [`ActorId`].
    pub fn ref_id(&self) -> usize {
        self.0.ptr_id()
    }

    /// The actor id's raw bytes as an [`Ident`].
    pub fn ident(&self) -> Ident {
        *self.0.id().as_bytes()
    }

    /// Whether the actor id is the nil id.
    pub fn is_nil(&self) -> bool {
        self.0.id().is_nil()
    }

    /// Fire-and-forget: enqueue `msg` with no reply continuation.
    ///
    /// # Errors
    /// Returns the converted message and continuation if sending fails.
    pub fn tell<M>(&self, msg: M) -> Result<(), SendError<(A::Msg, Continuation)>>
    where
        M: Message<A>,
    {
        self.send(msg.into(), Continuation::Nil)
    }

    /// Build a request for `msg`. Awaiting the returned [`MsgRequest`] sends it and
    /// resolves to the handler's return value; use `.timeout(..)` to bound the wait.
    pub const fn ask<M>(&self, msg: M) -> MsgRequest<'_, A, M>
    where
        M: Message<A>,
    {
        MsgRequest { target: self, msg }
    }

    /// Send `msg` to this actor and deliver its return value, as a message, to `target`.
    ///
    /// A spawned task waits for the reply and sends it on to `target`. With the `remote`
    /// feature the reply may instead be a `oneshot::Sender<ForwardInfo>`, in which case
    /// the target's id and message tag are sent back through it, presumably so the
    /// remote side can deliver the result directly.
    ///
    /// # Errors
    /// Returns the converted message and continuation if sending to this actor fails.
    pub fn forward<M, B>(
        &self,
        msg: M,
        target: ActorRef<B>,
    ) -> Result<(), SendError<(A::Msg, Continuation)>>
    where
        M: Message<A>,
        B: Actor,
        <M as Message<A>>::Return: Message<B>,
    {
        let (tx, rx) = oneshot::channel::<Box<dyn Any + Send>>();
        compat::spawn(async move {
            // The continuation was dropped without replying: nothing to forward.
            let Ok(ret) = rx.await else {
                return;
            };
            let b_msg = match ret.downcast::<<M as Message<A>>::Return>() {
                Err(ret) => {
                    #[cfg(not(feature = "remote"))]
                    {
                        // `ret` is only inspected further in remote builds; consume it
                        // here so non-remote builds do not warn about an unused binding.
                        drop(ret);
                        return error!(
                            %target,
                            expected = type_name::<<M as Message<A>>::Return>(),
                            "failed to downcast forwarded msg"
                        );
                    }
                    #[cfg(feature = "remote")]
                    {
                        let Ok(tx) = ret.downcast::<oneshot::Sender<ForwardInfo>>() else {
                            return error!(
                                %target,
                                expected_type = type_name::<<M as Message<A>>::Return>(),
                                "failed to downcast forwarded msg (expected type or oneshot::Sender<ForwardInfo>)"
                            );
                        };
                        let forward_info = ForwardInfo {
                            actor_id: target.id(),
                            tag: <<M as Message<A>>::Return as Message<B>>::TAG,
                        };
                        if tx.send(forward_info).is_err() {
                            return error!(%target, "failed to send forward info");
                        }
                        return debug!(%target, "delegated forwarding",);
                    }
                }
                Ok(b_msg) => b_msg,
            };
            if let Err(_err) = target.send((*b_msg).into(), Continuation::Nil) {
                warn!(%target, "forward target terminated, message dropped");
            }
        });
        let continuation = Continuation::forward(tx);
        self.send(msg.into(), continuation)
    }

    /// Enqueue an already-converted message with an explicit continuation `k` that
    /// describes what happens to the handler's return value.
    pub fn send(
        &self,
        msg: A::Msg,
        k: Continuation,
    ) -> Result<(), SendError<(A::Msg, Continuation)>> {
        self.0.send((msg, k))
    }

    /// Create a [`WeakActorRef`] to the same actor.
    pub fn downgrade(&self) -> WeakActorRef<A> {
        WeakActorRef(self.0.downgrade())
    }

    /// Whether the underlying message channel is closed.
    pub fn is_closed(&self) -> bool {
        self.0.is_closed()
    }

    /// Monitor this actor. It is monitored locally when no imported public key is known
    /// for its id, and otherwise on the peer identified by that key.
    #[cfg(all(feature = "monitor", feature = "remote"))]
    pub async fn monitor(&self, tx: UpdateTx<A>) -> Result<(), RemoteError> {
        match LocalPeer::inst().get_import_public_key(&self.id()) {
            None => self.monitor_local(tx).map_err(RemoteError::MonitorError),
            Some(public_key) => self.monitor_remote(public_key, tx).await,
        }
    }

    /// Monitor this actor in the local process, sending updates to `tx`.
    #[cfg(feature = "monitor")]
    pub fn monitor_local(&self, tx: UpdateTx<A>) -> Result<(), MonitorError> {
        monitor_local_id(self.id(), tx)
    }

    /// Monitor this actor on the peer identified by `public_key`, sending updates to `tx`.
    #[cfg(all(feature = "monitor", feature = "remote"))]
    pub async fn monitor_remote(
        &self,
        public_key: PublicKey,
        tx: UpdateTx<A>,
    ) -> Result<(), RemoteError> {
        monitor_remote_id(self.id(), public_key, tx).await
    }

    /// Resolve an actor by identifier or URL.
    ///
    /// Input that does not parse as a URL is looked up locally as an identifier.
    /// A URL is split by `parse_url` into an identifier and a peer public key, and the
    /// identifier is looked up on that peer.
    #[cfg(feature = "remote")]
    pub async fn lookup(ident_or_url: impl AsRef<str>) -> Result<Self, RemoteError> {
        let ident_or_url = ident_or_url.as_ref();
        match Url::parse(ident_or_url) {
            Err(_) => Ok(Self::lookup_local(ident_or_url)?),
            Ok(url) => {
                let (ident, public_key) = parse_url(&url)?;
                Ok(Self::lookup_remote_impl(ident, public_key).await??)
            }
        }
    }

    /// Look up `ident` on the peer identified by `public_key`.
    ///
    /// The outer `Result` carries identifier-parse and remote failures. The inner one
    /// carries the binding lookup result returned by the peer.
    #[cfg(feature = "remote")]
    pub async fn lookup_remote(
        ident: impl AsRef<str>,
        public_key: PublicKey,
    ) -> Result<Result<Self, BindingError>, RemoteError> {
        let ident = parse_ident(ident.as_ref())?;
        Self::lookup_remote_impl(ident, public_key).await
    }

    /// Look up a locally bound actor by identifier string.
    ///
    /// # Errors
    /// Fails if the identifier does not parse, is not bound, or is bound to an actor of
    /// a different type.
    pub fn lookup_local(ident: impl AsRef<str>) -> Result<Self, BindingError> {
        let ident = parse_ident(ident.as_ref())?;
        Self::lookup_local_impl(&ident)
    }

    #[cfg(feature = "remote")]
    pub(crate) async fn lookup_remote_impl(
        ident: Ident,
        public_key: PublicKey,
    ) -> Result<Result<Self, BindingError>, RemoteError> {
        let peer = LocalPeer::inst().get_or_connect_peer(public_key);
        peer.lookup(ident).await
    }

    pub(crate) fn lookup_local_impl(ident: &Ident) -> Result<Self, BindingError> {
        // `as_any` already returns an owned, boxed clone of the binding. The registry
        // entry is released at the end of this statement, and the box is unwrapped below
        // rather than cloned a second time.
        let any_actor = BINDINGS
            .get(ident)
            .ok_or(BindingError::NotFound)?
            .as_any()
            .ok_or(BindingError::NotFound)?;
        any_actor
            .downcast::<Self>()
            .map(|actor| *actor)
            .map_err(|_| BindingError::DowncastError)
    }
}
impl<A> Clone for ActorRef<A>
where
A: Actor,
{
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Two references are equal exactly when they refer to the same actor id.
impl<A: Actor> PartialEq for ActorRef<A> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.0.id(), other.0.id());
        lhs == rhs
    }
}

/// Id equality is a full equivalence relation.
impl<A: Actor> Eq for ActorRef<A> {}
/// Hashes only the actor id, so hashing agrees with `PartialEq`.
impl<A: Actor> Hash for ActorRef<A> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let actor_id = self.0.id();
        Hash::hash(&actor_id, state);
    }
}
/// Orders references by actor id, consistent with the id-based `PartialEq` and `Hash`.
/// Delegates to [`Ord::cmp`], so the partial and total orders cannot diverge.
impl<A> PartialOrd for ActorRef<A>
where
    A: Actor,
{
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Total order by actor id. This lets `ActorRef` be sorted or used as a key in ordered
/// collections such as `BTreeMap`.
impl<A> Ord for ActorRef<A>
where
    A: Actor,
{
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.id().cmp(&other.0.id())
    }
}
/// Renders as `<actor type name>(<hex-encoded actor id>)`.
impl<A: Actor> Display for ActorRef<A> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let actor_id = self.0.id();
        let type_label = type_name::<A>();
        write!(f, "{type_label}({})", Hex(actor_id.as_bytes()))
    }
}
/// Type-erased operations on a strong reference; every method acts on this live handle.
impl<A: Actor + Any> AnyActorRef for ActorRef<A> {
    /// Always `Some`: a strong reference can always report its actor's id.
    fn id(&self) -> Option<ActorId> {
        Some(self.0.id())
    }
    /// Decodes `bytes` into `A::Msg` (variant selected by `tag`) and enqueues it with no
    /// reply continuation.
    ///
    /// If the send fails, the decoded message is dropped and the original `(tag, bytes)`
    /// are returned inside `BytesSendError::SendError`.
    #[cfg(feature = "remote")]
    fn send_tagged_bytes(&self, tag: Tag, bytes: Vec<u8>) -> Result<(), BytesSendError> {
        trace!(tag, bytes = bytes.len(), target = % self, "sending tagged bytes",);
        let msg = <A::Msg as FromTaggedBytes>::from(tag, &bytes)?;
        self.send(msg, Continuation::Nil)
            .map_err(|_| SendError((tag, bytes)))?;
        Ok(())
    }
    /// Boxes a clone of this reference so callers can downcast it back to `ActorRef<A>`.
    fn as_any(&self) -> Option<Box<dyn Any>> {
        Some(Box::new(self.clone()))
    }
    /// Postcard-encodes this reference through its `ActorRefDto` form.
    #[cfg(feature = "remote")]
    fn serialize(&self) -> Result<Vec<u8>, BindingError> {
        Ok(postcard::to_stdvec(&ActorRefDto::from(self))?)
    }
    /// Spawns a task, run inside `PEER.scope(peer, ..)`, that serves this actor to a
    /// remote peer.
    ///
    /// Each frame read from `rx_stream` is decoded as a `MsgPackDto<A>` and delivered
    /// according to its continuation:
    /// - `Reply`: the task waits for the actor's serialized reply and writes it as one
    ///   frame on `reply_stream`. The next frame is not read until that has happened, so
    ///   asks are served one at a time and replies are written in request order.
    /// - `Nil`: delivered as a tell.
    /// - `Forward { .. }`: the DTO is converted into a `Continuation` and delivered.
    ///
    /// The task ends when reading a frame fails, when sending to the actor fails, or when
    /// a reply oneshot is cancelled. Frames that fail to decode are logged and skipped,
    /// and no reply is written for them. A failed reply write is logged and the loop
    /// continues. This method always returns `Ok(())`; failures are only logged.
    #[cfg(feature = "remote")]
    fn spawn_export_task(
        &self,
        peer: Peer,
        mut rx_stream: RecvStream,
        mut reply_stream: SendStream,
    ) -> Result<(), BindingError> {
        compat::spawn({
            let this = self.clone();
            PEER.scope(peer.clone(), async move {
                trace!(actor = % this, "listening exported",);
                // A single receive buffer, cleared and reused for every frame.
                let mut buf = Vec::new();
                loop {
                    buf.clear();
                    if let Err(err) = rx_stream.recv_frame_into(&mut buf).await {
                        return warn!(actor = % this, % err, "stopped exported",);
                    }
                    #[cfg(feature = "verbose")]
                    debug!(actor = % this, bytes = buf.len(), "received remote message",);
                    // Undecodable frames are skipped rather than ending the export.
                    let (msg, k_dto) = match postcard::from_bytes::<MsgPackDto<A>>(&buf) {
                        Err(err) => {
                            error!(% err, "failed to deserialize remote message");
                            continue;
                        }
                        Ok(msg_k_dto) => msg_k_dto,
                    };
                    match k_dto {
                        ContinuationDto::Reply => {
                            // The actor's reply comes back already serialized through this oneshot.
                            let (reply_tx, reply_rx) = oneshot::channel::<Vec<u8>>();
                            let k = Continuation::BinReply {
                                peer: peer.clone(),
                                reply_tx,
                            };
                            #[cfg(feature = "verbose")]
                            debug!(to = % this, msg = ? msg, "dispatching ask to actor");
                            if let Err(err) = this.send(msg, k) {
                                break warn!(actor = % this, % err, "stopped exported");
                            }
                            // Awaited inline: no further frames are read until this reply is ready.
                            let Ok(bytes) = reply_rx.await else {
                                break warn!(
                                    actor = % this, "reply oneshot cancelled (actor dropped?)"
                                );
                            };
                            if let Err(err) = reply_stream.send_frame(bytes).await {
                                warn!(
                                    actor = % this, % err,
                                    "failed to send reply on bi-stream, skipping"
                                );
                            }
                        }
                        ContinuationDto::Nil => {
                            let k = Continuation::Nil;
                            #[cfg(feature = "verbose")]
                            debug!(to = % this, msg = ? msg, "dispatching tell to actor");
                            if let Err(err) = this.send(msg, k) {
                                break warn!(actor = % this, % err, "stopped exported");
                            }
                        }
                        ContinuationDto::Forward { .. } => {
                            let k: Continuation = k_dto.into();
                            if let Err(err) = this.send(msg, k) {
                                break warn!(actor = % this, % err, "stopped exported");
                            }
                        }
                    }
                }
            })
        });
        Ok(())
    }
    /// Streams this actor's monitor updates to a remote peer.
    ///
    /// The method creates an unbounded update channel and spawns a task, inside
    /// `PEER.scope(peer, ..)`, that postcard-encodes each `Update<A>` and writes it as a
    /// frame on `tx_stream`. It then registers the channel's sender with `hdl.monitor`.
    /// Updates that fail to encode are skipped. The task ends when the channel yields
    /// `None` or a write fails.
    ///
    /// # Errors
    /// `MonitorError::SigSendError` if the monitor signal cannot be sent through `hdl`.
    #[cfg(all(feature = "remote", feature = "monitor"))]
    fn monitor_as_bytes(
        &self,
        peer: Peer,
        hdl: ActorHdl,
        mut tx_stream: SendStream,
    ) -> Result<(), MonitorError> {
        let (tx, rx) = unbounded_anonymous::<Update<A>>();
        // Both branches run the same forwarding loop. The logging variant is used only
        // when WARN-level events are compiled in (`STATIC_MAX_LEVEL`). Otherwise the
        // `peer` and `self` strings could never be logged, so they are not formatted.
        if STATIC_MAX_LEVEL >= tracing::Level::WARN {
            let (remote, actor) = (format!("{peer}"), format!("{self}"));
            compat::spawn(PEER.scope(peer, async move {
                trace!(% actor, % remote, "remote monitoring local",);
                loop {
                    let Some(update) = rx.recv().await else {
                        return warn!(
                            % actor, % remote, "remote monitoring channel closed"
                        );
                    };
                    let bytes = match postcard::to_stdvec(&update) {
                        Err(err) => {
                            warn!(
                                % actor, % remote, % err, "failed to serialize update"
                            );
                            continue;
                        }
                        Ok(buf) => buf,
                    };
                    if let Err(err) = tx_stream.send_frame(bytes).await {
                        break warn!(
                            % actor, % remote, % err, "failed to send serialized update"
                        );
                    }
                }
            }));
        } else {
            compat::spawn(PEER.scope(peer, async move {
                loop {
                    let Some(update) = rx.recv().await else {
                        return;
                    };
                    let Ok(bytes) = postcard::to_stdvec(&update) else {
                        continue;
                    };
                    if tx_stream.send_frame(bytes).await.is_err() {
                        break;
                    }
                }
            }));
        }
        // NOTE(review): on failure `tx` is dropped with the error. That presumably
        // closes the update channel and ends the task spawned above — confirm.
        hdl.monitor(Box::new(tx))
            .map_err(|_| MonitorError::SigSendError)?;
        Ok(())
    }
    /// Implementation id of the actor type `A`.
    #[cfg(feature = "remote")]
    fn ty_id(&self) -> ActorTypeId {
        A::IMPL_ID
    }
    /// Sender count reported by the underlying message channel.
    #[cfg(feature = "remote")]
    fn sender_count(&self) -> usize {
        self.0.sender_count()
    }
}
impl<A> WeakActorRef<A>
where
A: Actor,
{
pub fn ref_id(&self) -> usize {
self.0.ptr_id()
}
pub fn upgrade(&self) -> Option<ActorRef<A>> {
self.0.upgrade().map(|tx| ActorRef(tx))
}
}
impl<A> Clone for WeakActorRef<A>
where
A: Actor,
{
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Type-erased operations on a weak reference. Every operation first upgrades to a
/// strong [`ActorRef`] and delegates to it. When the upgrade fails, each method returns
/// its "gone" value (`None`, `0`, or an error).
impl<A> AnyActorRef for WeakActorRef<A>
where
    A: Actor + Any,
{
    fn id(&self) -> Option<ActorId> {
        let strong = self.upgrade()?;
        Some(strong.id())
    }

    #[cfg(feature = "remote")]
    fn send_tagged_bytes(&self, tag: Tag, bytes: Vec<u8>) -> Result<(), BytesSendError> {
        let Some(strong) = self.upgrade() else {
            return Err(BytesSendError::SendError(SendError((tag, bytes))));
        };
        strong.send_tagged_bytes(tag, bytes)
    }

    fn as_any(&self) -> Option<Box<dyn Any>> {
        self.upgrade().and_then(|strong| strong.as_any())
    }

    #[cfg(feature = "remote")]
    fn serialize(&self) -> Result<Vec<u8>, BindingError> {
        self.upgrade()
            .ok_or(BindingError::NotFound)
            .and_then(|strong| strong.serialize())
    }

    #[cfg(feature = "remote")]
    fn spawn_export_task(
        &self,
        peer: Peer,
        rx_stream: RecvStream,
        reply_stream: SendStream,
    ) -> Result<(), BindingError> {
        let Some(strong) = self.upgrade() else {
            return Err(BindingError::NotFound);
        };
        strong.spawn_export_task(peer, rx_stream, reply_stream)
    }

    #[cfg(all(feature = "remote", feature = "monitor"))]
    fn monitor_as_bytes(
        &self,
        peer: Peer,
        hdl: ActorHdl,
        tx_stream: SendStream,
    ) -> Result<(), MonitorError> {
        let Some(strong) = self.upgrade() else {
            return Err(MonitorError::SigSendError);
        };
        strong.monitor_as_bytes(peer, hdl, tx_stream)
    }

    #[cfg(feature = "remote")]
    fn ty_id(&self) -> ActorTypeId {
        A::IMPL_ID
    }

    /// `0` when the reference cannot be upgraded.
    ///
    /// NOTE(review): the temporary upgraded handle is itself alive while the count is
    /// read, so it is presumably included in the result. Confirm this is intended.
    #[cfg(feature = "remote")]
    fn sender_count(&self) -> usize {
        self.upgrade().map_or(0, |strong| strong.sender_count())
    }
}
impl ActorHdl {
    /// Id of the actor behind this signal channel.
    // Presumably unused in some feature combinations, hence the allow.
    #[allow(dead_code)]
    pub(crate) fn id(&self) -> ActorId {
        let sig_tx = &self.0;
        sig_tx.id()
    }

    /// Build a [`SignalRequest`]. Awaiting it sends `sig` with an acknowledgement
    /// channel and waits for the acknowledgement.
    // Presumably unused in some feature combinations, hence the allow.
    #[allow(dead_code)]
    pub(crate) const fn ask_sig(&self, sig: InternalSignal) -> SignalRequest<'_> {
        SignalRequest {
            sig,
            target_hdl: self,
        }
    }

    /// Send `sig` without requesting an acknowledgement.
    // Presumably unused in some feature combinations, hence the allow.
    #[allow(dead_code)]
    pub(crate) fn tell_sig(&self, sig: InternalSignal) -> Result<(), SendError<RawSignal>> {
        let raw = sig.into_raw(None);
        self.raw_send(raw)
    }

    /// Create a [`WeakActorHdl`] for the same signal channel.
    pub(crate) fn downgrade(&self) -> WeakActorHdl {
        let weak_tx = self.0.downgrade();
        WeakActorHdl(weak_tx)
    }

    /// Send an escalation signal carrying `this_hdl` and `escalation` to this actor.
    /// `this_hdl` is presumably the handle of the escalating actor.
    pub(crate) fn escalate(
        &self,
        this_hdl: Self,
        escalation: Escalation,
    ) -> Result<(), SendError<RawSignal>> {
        let signal = RawSignal::Escalation(this_hdl, escalation);
        self.raw_send(signal)
    }

    /// Register `tx` as a monitor update sink on this actor.
    #[cfg(feature = "monitor")]
    pub(crate) fn monitor(&self, tx: AnyUpdateTx) -> Result<(), SendError<RawSignal>> {
        let signal = RawSignal::Monitor(tx);
        self.raw_send(signal)
    }

    /// Push an already-built raw signal onto the signal channel.
    pub(crate) fn raw_send(&self, raw_sig: RawSignal) -> Result<(), SendError<RawSignal>> {
        let sig_tx = &self.0;
        sig_tx.send(raw_sig)
    }
}
/// Handles are equal when they refer to the same underlying signal channel.
impl PartialEq for ActorHdl {
    fn eq(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.0, &other.0);
        mine.same_channel(theirs)
    }
}
impl WeakActorHdl {
pub(crate) fn upgrade(&self) -> Option<ActorHdl> {
Some(ActorHdl(self.0.upgrade()?))
}
}
impl<'a, A, M> MsgRequest<'a, A, M>
where
    A: Actor,
    M: Message<A>,
{
    /// Bound the wait for the reply. The returned [`Deadline`] resolves to
    /// `RequestError::Timeout` if the ask does not complete within `duration`.
    pub const fn timeout(self, duration: Duration) -> Deadline<'a, Self> {
        Deadline {
            duration,
            request: self,
            _phantom: PhantomData,
        }
    }
}
impl<'a, A, M> IntoFuture for MsgRequest<'a, A, M>
where
    A: Actor,
    M: Message<A>,
{
    type Output = Result<<M as Message<A>>::Return, RequestError<MsgPack<A>>>;
    type IntoFuture = BoxFuture<'a, Self::Output>;

    /// Sends the message with a reply continuation and resolves to the handler's
    /// return value.
    ///
    /// Nothing is sent until the returned future is first polled. A local reply arrives
    /// as a boxed `M::Return`. With the `remote` feature the reply may instead be a
    /// `oneshot::Receiver<(Peer, Vec<u8>)>`; its bytes are awaited and postcard-decoded
    /// inside `PEER.sync_scope(peer, ..)`.
    ///
    /// # Errors
    /// - `SendError` if sending to the target fails (the payload is carried back).
    /// - `Cancelled` if a reply channel is dropped before answering.
    /// - `DowncastError` if the reply has an unexpected type.
    /// - `DeserializeError` (remote only) if the remote reply bytes cannot be decoded.
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move {
            let (tx, rx) = oneshot::channel();
            // Go through `ActorRef::send` rather than the raw sender, like `tell`/`forward`.
            self.target
                .send(self.msg.into(), Continuation::reply(tx))?;
            let ret = rx.await?;
            match ret.downcast::<M::Return>() {
                Ok(res) => Ok(*res),
                Err(ret) => {
                    #[cfg(not(feature = "remote"))]
                    {
                        // `ret` is only inspected further in remote builds; consume it
                        // here so non-remote builds do not warn about an unused binding.
                        drop(ret);
                        return Err(RequestError::DowncastError);
                    }
                    #[cfg(feature = "remote")]
                    {
                        let Ok(remote_reply_rx) =
                            ret.downcast::<oneshot::Receiver<(Peer, Vec<u8>)>>()
                        else {
                            return Err(RequestError::DowncastError);
                        };
                        let (peer, bytes) = remote_reply_rx.await?;
                        let res =
                            PEER.sync_scope(peer, || postcard::from_bytes::<M::Return>(&bytes))?;
                        Ok(res)
                    }
                }
            }
        })
    }
}
impl<'a> SignalRequest<'a> {
    /// Bound the wait for the acknowledgement. The returned [`Deadline`] resolves to
    /// `RequestError::Timeout` if no acknowledgement arrives within `duration`.
    pub const fn timeout(self, duration: Duration) -> Deadline<'a, Self> {
        Deadline {
            duration,
            request: self,
            _phantom: PhantomData,
        }
    }
}
/// Future produced by awaiting a [`SignalRequest`]. It resolves once the signal is
/// acknowledged, or immediately with the send error if the signal could not be sent.
pub enum SignalAckFuture {
    /// The signal was sent; waiting on its acknowledgement channel.
    #[doc(hidden)]
    Waiting(oneshot::Receiver<()>),
    /// Sending failed; the error is taken out on the first poll.
    #[doc(hidden)]
    Err(Option<RequestError<RawSignal>>),
}
impl Future for SignalAckFuture {
    type Output = Result<(), RequestError<RawSignal>>;

    /// A dropped acknowledgement sender maps to `RequestError::Cancelled`. A stored send
    /// error is yielded once.
    ///
    /// # Panics
    /// Panics if the `Err` variant is polled again after yielding its error.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.get_mut() {
            Self::Err(slot) => {
                let failure = slot
                    .take()
                    .expect("SignalAckFuture polled after completion");
                Poll::Ready(Err(failure))
            }
            Self::Waiting(ack_rx) => Pin::new(ack_rx)
                .poll(cx)
                .map_err(RequestError::Cancelled),
        }
    }
}
impl IntoFuture for SignalRequest<'_> {
type Output = Result<(), RequestError<RawSignal>>;
type IntoFuture = SignalAckFuture;
fn into_future(self) -> Self::IntoFuture {
let (tx, rx) = oneshot::channel();
match self.target_hdl.raw_send(self.sig.into_raw(Some(tx))) {
Ok(()) => SignalAckFuture::Waiting(rx),
Err(e) => SignalAckFuture::Err(Some(e.into())),
}
}
}
#[cfg(not(wasm_browser))]
impl<'a, R, T, S> IntoFuture for Deadline<'a, R>
where
R: 'a + IntoFuture<Output = Result<T, RequestError<S>>> + Send,
R::IntoFuture: Send,
{
type Output = Result<T, RequestError<S>>;
type IntoFuture = BoxFuture<'a, Self::Output>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
match compat::timeout(self.duration, self.request.into_future()).await {
Err(_) => Err(RequestError::Timeout),
Ok(result) => result,
}
})
}
}
/// Browser variant: the same timeout race as the native impl, but boxed as a local
/// (non-`Send`) future.
#[cfg(wasm_browser)]
impl<'a, R, T, S> IntoFuture for Deadline<'a, R>
where
    R: 'a + IntoFuture<Output = Result<T, RequestError<S>>>,
{
    type Output = Result<T, RequestError<S>>;
    type IntoFuture = LocalBoxFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        let Deadline {
            request, duration, ..
        } = self;
        Box::pin(async move {
            compat::timeout(duration, request.into_future())
                .await
                .unwrap_or_else(|_elapsed| Err(RequestError::Timeout))
        })
    }
}