use actr_protocol::{
ActorResult, ActrError, ActrId, ActrType, DataStream, PayloadType, RpcRequest,
};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::future::BoxFuture;
use prost::Message as ProstMessage;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use crate::{Context, Dest, MediaSample};
use super::context_helpers::{
actr_id_from_wit, actr_id_to_wit, actr_type_to_wit, data_stream_from_wit, data_stream_to_wit,
dest_to_wit, payload_type_to_wit,
};
use super::generated::actr::workload::host as wit_host;
use super::generated::actr::workload::types as wit_types;
pub(crate) fn wit_actr_error_to_proto(e: wit_types::ActrError) -> ActrError {
match e {
wit_types::ActrError::Unavailable(msg) => ActrError::Unavailable(msg),
wit_types::ActrError::TimedOut => ActrError::TimedOut,
wit_types::ActrError::NotFound(msg) => ActrError::NotFound(msg),
wit_types::ActrError::PermissionDenied(msg) => ActrError::PermissionDenied(msg),
wit_types::ActrError::InvalidArgument(msg) => ActrError::InvalidArgument(msg),
wit_types::ActrError::UnknownRoute(msg) => ActrError::UnknownRoute(msg),
wit_types::ActrError::DependencyNotFound(p) => ActrError::DependencyNotFound {
service_name: p.service_name,
message: p.message,
},
wit_types::ActrError::DecodeFailure(msg) => ActrError::DecodeFailure(msg),
wit_types::ActrError::NotImplemented(msg) => ActrError::NotImplemented(msg),
wit_types::ActrError::Internal(msg) => ActrError::Internal(msg),
}
}
pub(crate) fn proto_actr_error_to_wit(e: ActrError) -> wit_types::ActrError {
match e {
ActrError::Unavailable(msg) => wit_types::ActrError::Unavailable(msg),
ActrError::TimedOut => wit_types::ActrError::TimedOut,
ActrError::NotFound(msg) => wit_types::ActrError::NotFound(msg),
ActrError::PermissionDenied(msg) => wit_types::ActrError::PermissionDenied(msg),
ActrError::InvalidArgument(msg) => wit_types::ActrError::InvalidArgument(msg),
ActrError::UnknownRoute(msg) => wit_types::ActrError::UnknownRoute(msg),
ActrError::DependencyNotFound {
service_name,
message,
} => wit_types::ActrError::DependencyNotFound(wit_types::DependencyNotFoundPayload {
service_name,
message,
}),
ActrError::DecodeFailure(msg) => wit_types::ActrError::DecodeFailure(msg),
ActrError::NotImplemented(msg) => wit_types::ActrError::NotImplemented(msg),
ActrError::Internal(msg) => wit_types::ActrError::Internal(msg),
}
}
/// [`Context`] implementation for guest code compiled to WASM.
///
/// Holds a snapshot of the identity values obtained from the host (see
/// `WasmContext::from_host`); every messaging operation is forwarded to a
/// `wit_host` import.
#[derive(Clone)]
pub(crate) struct WasmContext {
    /// Request id as reported by `wit_host::get_request_id`.
    request_id: String,
    /// Id of this actor, as reported by `wit_host::get_self_id`.
    self_id: ActrId,
    /// Id of the calling actor; `None` when the host reports no caller.
    caller_id: Option<ActrId>,
}
/// Type-erased handler for a registered data stream: receives each chunk together
/// with the sender's id and returns a boxed future that resolves to the outcome.
type StreamCallback =
    Arc<dyn Fn(DataStream, ActrId) -> BoxFuture<'static, ActorResult<()>> + Send + Sync>;

/// Global registry mapping stream ids to their handlers, created on first access.
///
/// `register_stream` / `unregister_stream` write to it; [`dispatch_registered_stream`]
/// reads from it.
fn stream_callbacks() -> &'static Mutex<HashMap<String, StreamCallback>> {
    static REGISTRY: OnceLock<Mutex<HashMap<String, StreamCallback>>> = OnceLock::new();
    REGISTRY.get_or_init(Mutex::default)
}
/// Routes an incoming stream chunk to the handler registered for its `stream_id`.
///
/// Both arguments arrive in WIT form and are converted before lookup. NOTE(review):
/// presumably invoked from a guest export when the host delivers a chunk — confirm.
///
/// # Errors
/// - `ActrError::Internal` if the registry mutex is poisoned.
/// - `ActrError::NotFound` if no handler is registered for the chunk's stream id.
/// - Otherwise, whatever the handler itself returns.
pub(crate) async fn dispatch_registered_stream(
    chunk: wit_types::DataStream,
    sender: wit_types::ActrId,
) -> ActorResult<()> {
    let stream = data_stream_from_wit(chunk);
    let origin = actr_id_from_wit(&sender);

    // The lock guard is a temporary of this statement, so it is released before
    // the handler is awaited below.
    let handler = stream_callbacks()
        .lock()
        .map_err(|_| ActrError::Internal("stream callback registry poisoned".into()))?
        .get(&stream.stream_id)
        .cloned();

    let Some(handler) = handler else {
        return Err(ActrError::NotFound(format!(
            "no stream callback registered for '{}'",
            stream.stream_id
        )));
    };

    handler(stream, origin).await
}
impl WasmContext {
pub(crate) async fn from_host() -> Self {
let self_id = actr_id_from_wit(&wit_host::get_self_id().await);
let caller_id = wit_host::get_caller_id()
.await
.map(|id| actr_id_from_wit(&id));
let request_id = wit_host::get_request_id().await;
Self {
self_id,
caller_id,
request_id,
}
}
#[allow(dead_code)]
pub(crate) fn lifecycle_placeholder() -> Self {
Self {
self_id: ActrId::default(),
caller_id: None,
request_id: String::new(),
}
}
}
#[async_trait(?Send)]
impl Context for WasmContext {
    fn self_id(&self) -> &ActrId {
        &self.self_id
    }

    fn caller_id(&self) -> Option<&ActrId> {
        self.caller_id.as_ref()
    }

    fn request_id(&self) -> &str {
        &self.request_id
    }

    /// Encodes `request`, sends it to `target` through the host under
    /// `R::route_key()`, and decodes the reply as `R::Response`.
    ///
    /// # Errors
    /// Host failures are converted with `wit_actr_error_to_proto`. A reply that
    /// fails protobuf decoding yields `ActrError::DecodeFailure`.
    async fn call<R: RpcRequest>(&self, target: &Dest, request: R) -> ActorResult<R::Response> {
        let payload = request.encode_to_vec();
        let result = wit_host::call(dest_to_wit(target), R::route_key().to_string(), payload).await;
        match result {
            Ok(bytes) => R::Response::decode(bytes.as_slice())
                .map_err(|e| ActrError::DecodeFailure(format!("response decode failed: {e}"))),
            Err(e) => Err(wit_actr_error_to_proto(e)),
        }
    }

    /// Encodes `message` and hands it to the host's `tell`, which returns only
    /// success or an error (no reply payload).
    async fn tell<R: RpcRequest>(&self, target: &Dest, message: R) -> ActorResult<()> {
        let payload = message.encode_to_vec();
        wit_host::tell(dest_to_wit(target), R::route_key().to_string(), payload)
            .await
            .map_err(wit_actr_error_to_proto)
    }

    /// Registers `callback` as the handler for `stream_id`, locally and with the host.
    ///
    /// The callback is stored in the local registry first. NOTE(review): presumably
    /// so that chunks arriving as soon as the host registration takes effect already
    /// find a handler — confirm.
    ///
    /// If the host rejects the registration, the local change is rolled back:
    /// - the new handler is removed;
    /// - any handler it displaced for the same id is reinstated.
    ///
    /// This leaves a failed call without side effects on the registry.
    ///
    /// # Errors
    /// - `ActrError::Internal` if the registry mutex is poisoned.
    /// - Otherwise, the host's error converted with `wit_actr_error_to_proto`.
    async fn register_stream<F>(&self, stream_id: String, callback: F) -> ActorResult<()>
    where
        F: Fn(DataStream, ActrId) -> BoxFuture<'static, ActorResult<()>> + Send + Sync + 'static,
    {
        // The guard is a temporary of this statement and is released before the await.
        let displaced = stream_callbacks()
            .lock()
            .map_err(|_| ActrError::Internal("stream callback registry poisoned".into()))?
            .insert(stream_id.clone(), Arc::new(callback));

        if let Err(err) = wit_host::register_stream(stream_id.clone()).await {
            // Undo the local insert so the registry does not hold a handler for a
            // registration the caller was told failed. If the lock is poisoned the
            // rollback is skipped; the host error is reported either way.
            if let Ok(mut callbacks) = stream_callbacks().lock() {
                match displaced {
                    Some(previous) => {
                        callbacks.insert(stream_id, previous);
                    }
                    None => {
                        callbacks.remove(&stream_id);
                    }
                }
            }
            return Err(wit_actr_error_to_proto(err));
        }
        Ok(())
    }

    /// Removes the local handler for `stream_id`, then asks the host to unregister it.
    ///
    /// The local handler is dropped even if the host call subsequently fails.
    async fn unregister_stream(&self, stream_id: &str) -> ActorResult<()> {
        stream_callbacks()
            .lock()
            .map_err(|_| ActrError::Internal("stream callback registry poisoned".into()))?
            .remove(stream_id);
        wit_host::unregister_stream(stream_id.to_string())
            .await
            .map_err(wit_actr_error_to_proto)
    }

    /// Forwards one data-stream chunk to `target` through the host.
    async fn send_data_stream(
        &self,
        target: &Dest,
        chunk: DataStream,
        payload_type: PayloadType,
    ) -> ActorResult<()> {
        wit_host::send_data_stream(
            dest_to_wit(target),
            data_stream_to_wit(chunk),
            payload_type_to_wit(payload_type),
        )
        .await
        .map_err(wit_actr_error_to_proto)
    }

    /// Asks the host's `discover` import for an actor id matching `target_type`.
    async fn discover_route_candidate(&self, target_type: &ActrType) -> ActorResult<ActrId> {
        wit_host::discover(actr_type_to_wit(target_type))
            .await
            .map(|id| actr_id_from_wit(&id))
            .map_err(wit_actr_error_to_proto)
    }

    /// Sends an already-encoded payload to a concrete actor under `route_key` and
    /// returns the raw reply bytes without decoding them.
    async fn call_raw(
        &self,
        target: &ActrId,
        route_key: &str,
        payload: Bytes,
    ) -> ActorResult<Bytes> {
        wit_host::call_raw(
            actr_id_to_wit(target),
            route_key.to_string(),
            payload.to_vec(),
        )
        .await
        .map(Bytes::from)
        .map_err(wit_actr_error_to_proto)
    }

    /// Always fails: media tracks are not available in the WASM environment.
    async fn register_media_track<F>(&self, _track_id: String, _callback: F) -> ActorResult<()>
    where
        F: Fn(MediaSample, ActrId) -> BoxFuture<'static, ActorResult<()>> + Send + Sync + 'static,
    {
        Err(media_tracks_unsupported())
    }

    /// Always fails: media tracks are not available in the WASM environment.
    async fn unregister_media_track(&self, _track_id: &str) -> ActorResult<()> {
        Err(media_tracks_unsupported())
    }

    /// Always fails: media tracks are not available in the WASM environment.
    async fn send_media_sample(
        &self,
        _target: &Dest,
        _track_id: &str,
        _sample: MediaSample,
    ) -> ActorResult<()> {
        Err(media_tracks_unsupported())
    }

    /// Always fails: media tracks are not available in the WASM environment.
    async fn add_media_track(
        &self,
        _target: &Dest,
        _track_id: &str,
        _codec: &str,
        _media_type: &str,
    ) -> ActorResult<()> {
        Err(media_tracks_unsupported())
    }

    /// Always fails: media tracks are not available in the WASM environment.
    async fn remove_media_track(&self, _target: &Dest, _track_id: &str) -> ActorResult<()> {
        Err(media_tracks_unsupported())
    }
}

/// The single `NotImplemented` error returned by every media-track operation
/// of `WasmContext`.
fn media_tracks_unsupported() -> ActrError {
    ActrError::NotImplemented("WebRTC media tracks are not supported in WASM environment".into())
}