use std::{future::Future, pin::Pin, sync::Arc};
use crate::{
ConnectionId, MaybeSend, MaybeSendFuture, MaybeSync, Metadata, MethodId, RequestCall,
RequestId, RequestResponse, SchemaRecvTracker, SelfRef, VoxError,
};
/// Pinned, boxed trait-object future producing `T`, bounded by
/// `MaybeSendFuture` and living for at most `'a`.
pub type BoxFut<'a, T> = Pin<Box<dyn MaybeSendFuture<Output = T> + 'a>>;
/// Result of a call: a self-referential `RequestResponse` wrapped in
/// `crate::WithTracker`, or a `VoxError` (with its default type parameter).
pub type CallResult = Result<crate::WithTracker<SelfRef<RequestResponse<'static>>>, VoxError>;
/// A single-use handle for answering one request with a typed `Result<T, E>`.
///
/// Implementors only provide [`Call::reply`]; [`Call::ok`] and [`Call::err`]
/// are shorthands that wrap their argument and forward to it.
pub trait Call<'wire, T: facet::Facet<'wire> + MaybeSend, E: facet::Facet<'wire> + MaybeSend>:
    MaybeSend
{
    /// Sends `result` as the answer to this call, consuming the handle.
    fn reply(self, result: Result<T, E>) -> impl Future<Output = ()> + MaybeSend;

    /// Answers successfully with `value`; same as `self.reply(Ok(value))`.
    fn ok(self, value: T) -> impl Future<Output = ()> + MaybeSend
    where
        Self: Sized,
    {
        Self::reply(self, Result::Ok(value))
    }

    /// Answers with the application error `error`; same as `self.reply(Err(error))`.
    fn err(self, error: E) -> impl Future<Output = ()> + MaybeSend
    where
        Self: Sized,
    {
        Self::reply(self, Result::Err(error))
    }
}
/// Destination for the response to a single request.
///
/// Implementors must provide [`ReplySink::send_reply`]. The error helpers are
/// built on top of it, and the accessors (`channel_binder`, `request_id`,
/// `connection_id`) default to `None`.
pub trait ReplySink: MaybeSend + MaybeSync + 'static {
    /// Sends `response` back to the caller, consuming the sink.
    fn send_reply(self, response: RequestResponse<'_>) -> impl Future<Output = ()> + MaybeSend;

    /// Replies with `Err(error)`, encoded as `Result<(), VoxError<E>>`.
    ///
    /// The payload is built with the safe `Payload::outgoing` constructor.
    /// Metadata and schemas are left at their `Default` values.
    fn send_error<E: for<'a> facet::Facet<'a> + MaybeSend>(
        self,
        error: VoxError<E>,
    ) -> impl Future<Output = ()> + MaybeSend
    where
        Self: Sized,
    {
        use crate::{Payload, RequestResponse};
        async move {
            let wire: Result<(), VoxError<E>> = Err(error);
            self.send_reply(RequestResponse {
                ret: Payload::outgoing(&wire),
                metadata: Default::default(),
                schemas: Default::default(),
            })
            .await;
        }
    }

    /// Replies with `Err(error)`, encoded as `Result<T, VoxError<E>>`.
    ///
    /// Unlike [`ReplySink::send_error`], the response shape carries the real
    /// `Ok` type `T`, although no `T` value is ever sent. This lets the
    /// response match the method's declared result type. Metadata and schemas
    /// are left at their `Default` values.
    fn send_typed_error<'wire, T, E>(
        self,
        error: VoxError<E>,
    ) -> impl Future<Output = ()> + MaybeSend
    where
        Self: Sized,
        T: facet::Facet<'wire> + MaybeSend,
        E: facet::Facet<'wire> + MaybeSend,
    {
        use crate::{Payload, RequestResponse};
        async move {
            let wire: Result<T, VoxError<E>> = Err(error);
            let ptr = facet::PtrConst::new((&wire as *const Result<T, VoxError<E>>).cast::<u8>());
            let shape = <Result<T, VoxError<E>> as facet::Facet<'wire>>::SHAPE;
            // SAFETY: `ptr` points at `wire`, and `shape` is the `Facet` shape of
            // exactly `wire`'s type. `wire` is owned by this async block and is
            // only dropped after `send_reply(..).await` has completed, so it
            // outlives every use of the payload inside this future.
            // NOTE(review): presumably the unchecked constructor is used because
            // `T`/`E` are only `Facet<'wire>` rather than `for<'a> Facet<'a>`;
            // confirm against `Payload::outgoing`'s bounds.
            let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
            self.send_reply(RequestResponse {
                ret,
                metadata: Default::default(),
                schemas: Default::default(),
            })
            .await;
        }
    }

    /// Channel binder associated with this reply, if any. Defaults to `None`.
    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
        None
    }

    /// Id of the request being answered, if known. Defaults to `None`.
    fn request_id(&self) -> Option<RequestId> {
        None
    }

    /// Id of the connection carrying this reply, if known. Defaults to `None`.
    fn connection_id(&self) -> Option<ConnectionId> {
        None
    }
}
/// Processes incoming `RequestCall`s and answers them through a reply sink `R`.
///
/// Only [`Handler::handle`] is required. The per-method queries have
/// conservative defaults that implementors may override.
pub trait Handler<R>: MaybeSend + MaybeSync + 'static
where
    R: ReplySink,
{
    /// Handles one request.
    ///
    /// `call` is the incoming request, `reply` is the sink used to answer it,
    /// and `schemas` is the schema receive tracker supplied by the caller.
    fn handle(
        &self,
        call: SelfRef<RequestCall<'static>>,
        reply: R,
        schemas: Arc<SchemaRecvTracker>,
    ) -> impl Future<Output = ()> + MaybeSend + '_;

    /// Retry policy for `method_id`. Defaults to `RetryPolicy::VOLATILE` for all methods.
    fn retry_policy(&self, _method_id: MethodId) -> crate::RetryPolicy {
        crate::RetryPolicy::VOLATILE
    }

    /// Whether the arguments of `method_id` carry channels. Defaults to `false`.
    fn args_have_channels(&self, _method_id: MethodId) -> bool {
        false
    }

    /// Static wire shape of the response for `method_id`, if known. Defaults to `None`.
    fn response_wire_shape(&self, _method_id: MethodId) -> Option<&'static facet::Shape> {
        None
    }
}
/// No-op handler: accepts every request and completes without replying.
///
/// NOTE(review): `_reply` is dropped without `send_reply` being called, so
/// whatever the sink does on drop (if anything) decides what the caller
/// observes. Arguments are kept as named bindings so they are dropped when
/// the future finishes, not at call time.
impl<R: ReplySink> Handler<R> for () {
    async fn handle(
        &self,
        _call: SelfRef<RequestCall<'static>>,
        _reply: R,
        _schemas: Arc<SchemaRecvTracker>,
    ) {
    }
}
/// A response's return value paired with the metadata that came with it.
///
/// Dereferences to `ret`, so the value can be used directly through the parts.
pub struct ResponseParts<'a, T> {
    /// The returned value.
    pub ret: T,
    /// Metadata attached to the response.
    pub metadata: Metadata<'a>,
}

impl<T> std::ops::Deref for ResponseParts<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Self { ret, .. } = self;
        ret
    }
}
/// `Call` adapter that encodes the typed result and forwards it to a reply sink.
///
/// The struct places no bound on `R`, following the "no bounds on data
/// structures" guideline. Only the `Call` implementation requires
/// `R: ReplySink`, which is where the sink is actually used.
pub struct SinkCall<R> {
    reply: R,
}

impl<R> SinkCall<R> {
    /// Wraps `reply` so the request can be answered through the `Call` API.
    pub fn new(reply: R) -> Self {
        Self { reply }
    }
}
/// Answers the call by encoding `result` as `Result<T, VoxError<E>>` and
/// handing it to the wrapped reply sink.
///
/// Application errors are wrapped in `VoxError::User`, giving the same wire
/// type that `ReplySink::send_typed_error` produces for protocol errors.
impl<'wire, T, E, R> Call<'wire, T, E> for SinkCall<R>
where
    R: ReplySink,
    T: facet::Facet<'wire> + MaybeSend,
    E: facet::Facet<'wire> + MaybeSend,
{
    async fn reply(self, result: Result<T, E>) {
        use crate::{Payload, RequestResponse};

        // Application-level errors travel as `VoxError::User` on the wire.
        let wire: Result<T, VoxError<E>> = match result {
            Ok(value) => Ok(value),
            Err(user_error) => Err(VoxError::User(user_error)),
        };
        let erased = facet::PtrConst::new((&wire as *const Result<T, VoxError<E>>).cast::<u8>());
        let wire_shape = <Result<T, VoxError<E>> as facet::Facet<'wire>>::SHAPE;
        // SAFETY: `erased` points at `wire`, and `wire_shape` is the `Facet`
        // shape of exactly `wire`'s type. `wire` is a local of this future and
        // is only dropped after the `send_reply` future below has completed, so
        // it outlives every use of the payload.
        let ret = unsafe { Payload::outgoing_unchecked(erased, wire_shape) };
        self.reply
            .send_reply(RequestResponse {
                ret,
                metadata: Default::default(),
                schemas: Default::default(),
            })
            .await;
    }
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::{MaybeSend, Metadata, Payload, RequestCall, RequestResponse};

    use super::{Call, Handler, ReplySink, ResponseParts};

    /// `Call` that records the result it was replied with instead of sending it.
    struct RecordingCall<T, E> {
        observed: Arc<Mutex<Option<Result<T, E>>>>,
    }

    impl<'wire, T, E> Call<'wire, T, E> for RecordingCall<T, E>
    where
        T: facet::Facet<'wire> + MaybeSend + Send + 'static,
        E: facet::Facet<'wire> + MaybeSend + Send + 'static,
    {
        async fn reply(self, result: Result<T, E>) {
            let mut guard = self.observed.lock().expect("recording mutex poisoned");
            *guard = Some(result);
        }
    }

    /// `ReplySink` that records whether `send_reply` ran and whether the
    /// payload it received was a `Payload::Value`.
    struct RecordingReplySink {
        saw_send_reply: Arc<Mutex<bool>>,
        saw_outgoing_payload: Arc<Mutex<bool>>,
    }

    impl ReplySink for RecordingReplySink {
        async fn send_reply(self, response: RequestResponse<'_>) {
            let mut saw_send_reply = self
                .saw_send_reply
                .lock()
                .expect("send-reply mutex poisoned");
            *saw_send_reply = true;
            let mut saw_outgoing = self
                .saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned");
            *saw_outgoing = matches!(response.ret, Payload::Value { .. });
        }
    }

    #[tokio::test]
    async fn call_ok_and_err_route_through_reply() {
        let observed_ok: Arc<Mutex<Option<Result<u32, &'static str>>>> = Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_ok),
        }
        .ok(7)
        .await;
        assert!(matches!(
            *observed_ok.lock().expect("ok mutex poisoned"),
            Some(Ok(7))
        ));
        let observed_err: Arc<Mutex<Option<Result<u32, &'static str>>>> =
            Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_err),
        }
        .err("boom")
        .await;
        assert!(matches!(
            *observed_err.lock().expect("err mutex poisoned"),
            Some(Err("boom"))
        ));
    }

    #[tokio::test]
    async fn reply_sink_send_error_uses_outgoing_payload_and_reply_path() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let sink = RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        };
        sink.send_error(crate::VoxError::<String>::Cancelled).await;
        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    /// `send_typed_error` must emit a `Result<T, VoxError<E>>` shape whose
    /// `Ok` variant still describes `T` (here a 2-tuple).
    #[tokio::test]
    async fn reply_sink_send_typed_error_preserves_ok_shape() {
        use crate::{
            SchemaKind, TypeRef, VariantPayload, VoxError, build_registry, extract_schemas,
        };

        /// Sink that records the root type of the schema extracted from the reply shape.
        struct ShapeReplySink {
            observed_root: Arc<Mutex<Option<TypeRef>>>,
        }

        impl ReplySink for ShapeReplySink {
            async fn send_reply(self, response: RequestResponse<'_>) {
                let Payload::Value { shape, .. } = response.ret else {
                    panic!("typed error should use outgoing payload");
                };
                let extracted = extract_schemas(shape).expect("response shape should extract");
                *self
                    .observed_root
                    .lock()
                    .expect("observed-root mutex poisoned") = Some(extracted.root.clone());
            }
        }

        let observed_root = Arc::new(Mutex::new(None));
        ShapeReplySink {
            observed_root: Arc::clone(&observed_root),
        }
        .send_typed_error::<(String, i32), String>(VoxError::Cancelled)
        .await;
        let root = observed_root
            .lock()
            .expect("observed-root mutex poisoned")
            .clone()
            .expect("typed error should record a root");
        let extracted =
            extract_schemas(<Result<(String, i32), VoxError<String>> as facet::Facet>::SHAPE)
                .expect("expected result shape should extract");
        let registry = build_registry(&extracted.schemas);
        // Resolve against the registry by reference (was mis-encoded as `®istry`).
        let root_kind = root.resolve_kind(&registry).expect("root should resolve");
        let SchemaKind::Enum { variants, .. } = root_kind else {
            panic!("expected result enum root");
        };
        let ok_variant = variants
            .iter()
            .find(|variant| variant.name == "Ok")
            .expect("Result should have Ok variant");
        let VariantPayload::Newtype { type_ref } = &ok_variant.payload else {
            panic!("Ok variant should be newtype");
        };
        match type_ref
            .resolve_kind(&registry)
            .expect("Ok payload should resolve")
        {
            SchemaKind::Tuple { elements } => {
                assert_eq!(elements.len(), 2, "Ok tuple should have two elements");
            }
            other => panic!("expected Ok payload to be tuple, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn unit_handler_is_noop() {
        let req = crate::SelfRef::owning(
            crate::Backing::Boxed(Box::<[u8]>::default()),
            RequestCall {
                method_id: crate::MethodId(1),
                metadata: Metadata::default(),
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            },
        );
        ().handle(
            req,
            RecordingReplySink {
                saw_send_reply: Arc::new(Mutex::new(false)),
                saw_outgoing_payload: Arc::new(Mutex::new(false)),
            },
            Arc::new(crate::SchemaRecvTracker::new()),
        )
        .await;
    }

    #[test]
    fn response_parts_deref_exposes_ret() {
        let parts = ResponseParts {
            ret: 42_u32,
            metadata: Metadata::default(),
        };
        assert_eq!(*parts, 42);
    }

    #[test]
    fn default_channel_binder_accessor_for_reply_sink_returns_none() {
        let sink = RecordingReplySink {
            saw_send_reply: Arc::new(Mutex::new(false)),
            saw_outgoing_payload: Arc::new(Mutex::new(false)),
        };
        assert!(sink.channel_binder().is_none());
    }
}