use async_trait::async_trait;
use bytes::Bytes;
use prost::Message;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use crate::error::{Error, Result};
/// Cancellation scope attached to a [`Stream`] (see `Stream::context`).
///
/// Cloning a `Context` yields another handle to the same underlying
/// [`CancellationToken`], so cancelling any clone cancels them all; use
/// [`Context::child`] to derive a nested scope instead.
#[derive(Clone, Debug)]
pub struct Context {
    /// Token shared by this context and all of its clones.
    cancel_token: CancellationToken,
}
impl Default for Context {
fn default() -> Self {
Self::new()
}
}
impl Context {
    /// Creates a context backed by a brand-new, uncancelled token.
    pub fn new() -> Self {
        Self::with_cancel_token(CancellationToken::new())
    }

    /// Wraps an existing token.
    ///
    /// The context shares state with `cancel_token`: cancelling any clone of
    /// that token cancels this context too.
    pub fn with_cancel_token(cancel_token: CancellationToken) -> Self {
        Self { cancel_token }
    }

    /// Derives a nested context.
    ///
    /// Cancelling `self` also cancels the child, while cancelling the child
    /// leaves `self` untouched.
    pub fn child(&self) -> Self {
        let child_token = self.cancel_token.child_token();
        Self::with_cancel_token(child_token)
    }

    /// Borrows the underlying token (e.g. to hand to APIs that take a
    /// [`CancellationToken`] directly).
    pub fn cancel_token(&self) -> &CancellationToken {
        &self.cancel_token
    }

    /// Requests cancellation of this context, its clones, and its children.
    pub fn cancel(&self) {
        self.cancel_token.cancel()
    }

    /// Reports whether cancellation has already been requested.
    pub fn is_cancelled(&self) -> bool {
        self.cancel_token.is_cancelled()
    }

    /// Waits until this context is cancelled.
    ///
    /// The returned future borrows `self`; see [`Context::cancellation`] for an
    /// owned, `'static` variant.
    pub async fn cancelled(&self) {
        self.cancel_token.cancelled().await;
    }

    /// Returns a future that resolves once this context is cancelled.
    ///
    /// Unlike [`Context::cancelled`], the future owns its own clone of the
    /// token, so it is `Send + 'static` and may be spawned or stored without
    /// keeping `self` borrowed.
    pub fn cancellation(&self) -> impl std::future::Future<Output = ()> + Send + 'static {
        let owned_token = self.cancel_token.clone();
        async move { owned_token.cancelled().await }
    }
}
/// Bidirectional stream of raw byte payloads.
///
/// Implementors must be `Send + Sync` so a stream can be shared between tasks
/// (for example as an [`ArcStream`]). Typed protobuf helpers are provided by
/// [`StreamExt`].
#[async_trait]
pub trait Stream: Send + Sync {
    /// Returns the cancellation context associated with this stream.
    fn context(&self) -> &Context;

    /// Sends one payload to the peer.
    async fn send_bytes(&self, data: Bytes) -> Result<()>;

    /// Receives the next payload from the peer.
    ///
    /// [`StreamExt::msg_recv`] decodes each returned buffer as exactly one
    /// message, so implementations are expected to yield whole payloads.
    async fn recv_bytes(&self) -> Result<Bytes>;

    /// Closes the sending direction of the stream.
    ///
    /// NOTE(review): semantics inferred from the name (half-close, receiving
    /// presumably still allowed) — confirm against implementations.
    async fn close_send(&self) -> Result<()>;

    /// Closes the stream.
    ///
    /// NOTE(review): whether this implies `close_send` or cancels `context()`
    /// is not defined here — confirm against implementations.
    async fn close(&self) -> Result<()>;
}
#[async_trait]
pub trait StreamExt: Stream {
async fn msg_send<M: Message + Send + Sync>(&self, msg: &M) -> Result<()> {
let data = msg.encode_to_vec();
self.send_bytes(Bytes::from(data)).await
}
async fn msg_recv<M: Message + Default>(&self) -> Result<M> {
let data = self.recv_bytes().await?;
M::decode(&data[..]).map_err(Error::InvalidMessage)
}
}
/// Every [`Stream`] (including unsized `dyn Stream`) gets the [`StreamExt`]
/// helpers for free.
impl<S> StreamExt for S where S: Stream + ?Sized {}
/// Lets a shared `Arc<T>` be used wherever a [`Stream`] is expected by
/// forwarding every call to the pointee.
#[async_trait]
impl<T: Stream + ?Sized> Stream for Arc<T> {
    fn context(&self) -> &Context {
        T::context(self)
    }

    async fn send_bytes(&self, data: Bytes) -> Result<()> {
        T::send_bytes(self, data).await
    }

    async fn recv_bytes(&self) -> Result<Bytes> {
        T::recv_bytes(self).await
    }

    async fn close_send(&self) -> Result<()> {
        T::close_send(self).await
    }

    async fn close(&self) -> Result<()> {
        T::close(self).await
    }
}
/// Lets an owned `Box<T>` be used wherever a [`Stream`] is expected by
/// forwarding every call to the boxed value.
#[async_trait]
impl<T: Stream + ?Sized> Stream for Box<T> {
    fn context(&self) -> &Context {
        T::context(self)
    }

    async fn send_bytes(&self, data: Bytes) -> Result<()> {
        T::send_bytes(self, data).await
    }

    async fn recv_bytes(&self) -> Result<Bytes> {
        T::recv_bytes(self).await
    }

    async fn close_send(&self) -> Result<()> {
        T::close_send(self).await
    }

    async fn close(&self) -> Result<()> {
        T::close(self).await
    }
}
/// Owned, type-erased stream.
pub type BoxStream = Box<dyn Stream + 'static>;
/// Reference-counted, type-erased stream; clones refer to the same stream.
pub type ArcStream = Arc<dyn Stream + 'static>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound for awaiting cancellation in async tests, so a regression
    /// makes the test fail instead of hanging the test runner forever.
    const CANCEL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

    #[test]
    fn test_context_new() {
        let ctx = Context::new();
        assert!(!ctx.is_cancelled());
    }

    #[test]
    fn test_context_default() {
        let ctx = Context::default();
        assert!(!ctx.is_cancelled());
    }

    #[test]
    fn test_context_cancel() {
        let ctx = Context::new();
        ctx.cancel();
        assert!(ctx.is_cancelled());
        // The exposed token must reflect the same state.
        assert!(ctx.cancel_token().is_cancelled());
    }

    #[test]
    fn test_context_with_cancel_token_shares_state() {
        let token = CancellationToken::new();
        let ctx = Context::with_cancel_token(token.clone());
        assert!(!ctx.is_cancelled());
        token.cancel();
        assert!(ctx.is_cancelled());
    }

    #[test]
    fn test_context_clone_shares_state() {
        let ctx = Context::new();
        let clone = ctx.clone();
        clone.cancel();
        assert!(ctx.is_cancelled());
    }

    #[test]
    fn test_context_child() {
        let parent = Context::new();
        let child = parent.child();
        assert!(!parent.is_cancelled());
        assert!(!child.is_cancelled());
        parent.cancel();
        assert!(parent.is_cancelled());
        assert!(child.is_cancelled());
    }

    #[test]
    fn test_context_child_independent() {
        let parent = Context::new();
        let child = parent.child();
        child.cancel();
        assert!(!parent.is_cancelled());
        assert!(child.is_cancelled());
    }

    #[tokio::test]
    async fn test_context_cancelled_future() {
        let ctx = Context::new();
        let ctx_clone = ctx.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            ctx_clone.cancel();
        });
        // Bounded wait: if cancellation never propagates, fail rather than hang.
        tokio::time::timeout(CANCEL_TIMEOUT, ctx.cancelled())
            .await
            .expect("context was not cancelled in time");
        assert!(ctx.is_cancelled());
    }

    #[tokio::test]
    async fn test_context_cancellation_is_owned() {
        let ctx = Context::new();
        // `cancellation()` is `Send + 'static`, so it can be spawned without
        // keeping `ctx` borrowed.
        let waiter = tokio::spawn(ctx.cancellation());
        ctx.cancel();
        tokio::time::timeout(CANCEL_TIMEOUT, waiter)
            .await
            .expect("cancellation future did not resolve in time")
            .expect("cancellation task panicked");
    }
}