use std::{
marker::PhantomData,
pin::Pin,
sync::{
Arc,
Mutex,
},
task::{
Context,
Poll,
},
};
/// Internal lifecycle of a `ZigStream`; it lives behind the stream's
/// `Arc<Mutex<..>>` so `&self` query methods can inspect it.
enum StreamState {
    /// Still receiving: items are pulled from this channel receiver.
    Active {
        receiver: tokio::sync::mpsc::UnboundedReceiver<Result<Vec<u8>, String>>,
    },
    /// Terminal: the receiver reported end of stream, or the stream was dropped
    /// while still active.
    Completed,
    /// Terminal state carrying an error message; treated as finished by
    /// `is_completed` and reported by `error`.
    // NOTE(review): no code visible in this file constructs `Failed` (errors are
    // forwarded as stream items instead), so the variant triggers a
    // "never constructed" warning; the allow below silences it.
    #[allow(dead_code)]
    Failed(String),
}
/// A stream of `Result<T, String>` items fed by an unbounded byte channel.
///
/// Each `Ok(bytes)` sent on the paired sender is turned into a `T` when the
/// stream is polled; see `create_stream` for building a connected pair.
///
/// The marker is `PhantomData<fn() -> T>` rather than `PhantomData<T>`: the
/// stream only *produces* `T` values and never stores one, so it should not
/// inherit `T`'s `Send`/`Sync`/`Unpin` restrictions. Variance in `T` stays
/// covariant, exactly as with `PhantomData<T>`.
pub struct ZigStream<T> {
    state: Arc<Mutex<StreamState>>,
    _phantom: PhantomData<fn() -> T>,
}
impl<T> ZigStream<T> {
pub fn new(receiver: tokio::sync::mpsc::UnboundedReceiver<Result<Vec<u8>, String>>) -> Self {
Self {
state: Arc::new(Mutex::new(StreamState::Active { receiver })),
_phantom: PhantomData,
}
}
pub fn is_completed(&self) -> bool {
matches!(*self.state.lock().unwrap(), StreamState::Completed | StreamState::Failed(_))
}
pub fn error(&self) -> Option<String> {
match &*self.state.lock().unwrap() {
StreamState::Failed(e) => Some(e.clone()),
_ => None,
}
}
}
impl<T> futures::Stream for ZigStream<T>
where
    T: From<Vec<u8>>,
{
    type Item = Result<T, String>;

    /// Yields each received `Ok(bytes)` converted via `T::from`, and forwards
    /// each `Err(msg)` unchanged.
    ///
    /// An `Err` item does not end the stream. The stream ends once `poll_recv`
    /// returns `Ready(None)`; the state then moves to `Completed`, and every
    /// later poll (also from the `Failed` state) returns `Ready(None)`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Only the channel poll and the state transition run under the lock.
        // The caller-supplied `T::from` runs after the guard is released, so a
        // panicking conversion cannot poison the shared state.
        let polled = {
            let mut state = self
                .state
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            match &mut *state {
                StreamState::Active { receiver } => {
                    let polled = receiver.poll_recv(cx);
                    if matches!(polled, Poll::Ready(None)) {
                        *state = StreamState::Completed;
                    }
                    polled
                }
                StreamState::Completed | StreamState::Failed(_) => Poll::Ready(None),
            }
        };
        polled.map(|item| item.map(|result| result.map(T::from)))
    }
}
impl<T> Drop for ZigStream<T> {
    /// Moves a still-active stream to `Completed`, dropping its receiver.
    ///
    /// The lock is taken poison-tolerantly. If a panic poisoned the mutex,
    /// `unwrap()` here would panic inside `drop`, and a panic during unwinding
    /// aborts the whole process.
    fn drop(&mut self) {
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if matches!(*state, StreamState::Active { .. }) {
            // Replacing the variant drops the receiver, closing the channel so
            // further `send`s on the paired sender fail.
            *state = StreamState::Completed;
        }
    }
}
/// Builds a connected pair: the unbounded sender that feeds raw byte results,
/// and the `ZigStream` that yields them converted into `T`.
pub fn create_stream<T>(
) -> (tokio::sync::mpsc::UnboundedSender<Result<Vec<u8>, String>>, ZigStream<T>) {
    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel::<Result<Vec<u8>, String>>();
    let stream: ZigStream<T> = ZigStream::new(receiver);
    (sender, stream)
}
#[cfg(test)]
mod tests {
    use futures::StreamExt;

    use super::*;

    /// Test payload decoded from the first four little-endian bytes; inputs
    /// shorter than four bytes decode to `TestU32(0)`.
    #[derive(Debug, PartialEq)]
    struct TestU32(u32);

    impl From<Vec<u8>> for TestU32 {
        fn from(bytes: Vec<u8>) -> Self {
            if bytes.len() >= 4 {
                TestU32(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
            } else {
                TestU32(0)
            }
        }
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let (tx, stream) = create_stream::<TestU32>();
        futures::pin_mut!(stream);
        drop(tx);
        assert!(stream.next().await.is_none());
        assert!(stream.is_completed());
        // Once ended, the stream stays ended.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_with_data() {
        let (tx, stream) = create_stream::<TestU32>();
        futures::pin_mut!(stream);
        tx.send(Ok(42u32.to_le_bytes().to_vec())).unwrap();
        tx.send(Ok(100u32.to_le_bytes().to_vec())).unwrap();
        drop(tx);
        assert_eq!(stream.next().await, Some(Ok(TestU32(42))));
        assert_eq!(stream.next().await, Some(Ok(TestU32(100))));
        assert!(stream.next().await.is_none());
        assert!(stream.is_completed());
    }

    #[tokio::test]
    async fn test_stream_with_error() {
        let (tx, stream) = create_stream::<TestU32>();
        futures::pin_mut!(stream);
        tx.send(Ok(42u32.to_le_bytes().to_vec())).unwrap();
        tx.send(Err("Test error".to_string())).unwrap();
        drop(tx);
        assert_eq!(stream.next().await, Some(Ok(TestU32(42))));
        // The error payload must reach the consumer unchanged.
        assert_eq!(stream.next().await, Some(Err("Test error".to_string())));
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_early_drop() {
        let (tx, stream) = create_stream::<TestU32>();
        tx.send(Ok(vec![1, 2, 3, 4])).unwrap();
        drop(stream);
        // With the stream gone, the channel is closed and sends are rejected.
        assert!(tx.send(Ok(vec![5, 6, 7, 8])).is_err());
    }

    #[tokio::test]
    async fn test_multiple_consumers() {
        let (tx, stream) = create_stream::<TestU32>();
        let stream_ref = Arc::new(tokio::sync::Mutex::new(stream));
        let stream_ref2 = stream_ref.clone();
        tx.send(Ok(42u32.to_le_bytes().to_vec())).unwrap();
        drop(tx);
        {
            let mut s = stream_ref.lock().await;
            assert_eq!(s.next().await, Some(Ok(TestU32(42))));
        }
        {
            let mut s = stream_ref2.lock().await;
            assert!(s.next().await.is_none());
        }
    }
}