use std::pin::Pin;
use std::task::{Context, Poll};
use fidius_core::frame::Frame;
use fidius_core::Value;
use futures::stream::{Stream, StreamExt};
use crate::error::CallError;
use crate::executor::PluginExecutor;
/// A type-erased stream of chunks produced by a streaming plugin call.
///
/// Each item is either a decoded [`Value`] or a [`CallError`]. This is the type returned by
/// [`StreamExecutor::call_streaming`].
pub struct ChunkStream {
    /// Boxed and pinned so every `ChunkStream` has the same concrete type regardless of how
    /// the underlying stream was built.
    inner: Pin<Box<dyn Stream<Item = Result<Value, CallError>> + Send>>,
}

// `#[derive(Debug)]` is impossible because the inner trait object is not `Debug`; a manual impl
// still lets callers use `Result::unwrap_err`/`expect_err` on `Result<ChunkStream, _>` and lets
// containing types derive `Debug`.
impl std::fmt::Debug for ChunkStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChunkStream").finish_non_exhaustive()
    }
}
impl ChunkStream {
pub fn new<S>(stream: S) -> Self
where
S: Stream<Item = Result<Value, CallError>> + Send + 'static,
{
Self {
inner: Box::pin(stream),
}
}
pub fn from_frame_bytes<S, D>(frames: S, decode_item: D) -> Self
where
S: Stream<Item = Vec<u8>> + Send + 'static,
D: Fn(&[u8]) -> Result<Value, CallError> + Send + 'static,
{
let stream = futures::stream::unfold(
(frames.boxed(), decode_item, false),
|(mut src, decode_item, done)| async move {
if done {
return None;
}
match src.next().await {
None => Some((Err(CallError::StreamAborted), (src, decode_item, true))),
Some(bytes) => match Frame::decode(&bytes) {
Err(e) => Some((
Err(CallError::MalformedFrame(e.to_string())),
(src, decode_item, true),
)),
Ok(Frame::Item(payload)) => match decode_item(&payload) {
Ok(v) => Some((Ok(v), (src, decode_item, false))),
Err(e) => Some((Err(e), (src, decode_item, true))),
},
Ok(Frame::End) => None,
Ok(Frame::Error(pe)) => {
Some((Err(CallError::Plugin(pe)), (src, decode_item, true)))
}
},
}
},
);
Self::new(stream)
}
pub fn from_frames<D>(frames: Vec<Frame>, decode_item: D) -> Self
where
D: Fn(&[u8]) -> Result<Value, CallError> + Send + 'static,
{
let bytes: Vec<Vec<u8>> = frames
.iter()
.map(|f| f.encode().expect("frame encodes"))
.collect();
Self::from_frame_bytes(futures::stream::iter(bytes), decode_item)
}
}
/// `ChunkStream` is a thin delegate over its boxed inner stream.
impl Stream for ChunkStream {
    type Item = Result<Value, CallError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `ChunkStream`'s only field is a `Pin<Box<_>>`, which is `Unpin`, so the field can be
        // reached through an ordinary mutable borrow of the pinned wrapper.
        self.inner.as_mut().poll_next(cx)
    }

    /// Forwards the inner stream's bounds instead of the default `(0, None)`, so consumers
    /// such as `collect` can preallocate when the wrapped stream knows its length.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// A [`PluginExecutor`] that can additionally invoke plugin methods in streaming mode.
#[async_trait::async_trait]
pub trait StreamExecutor: PluginExecutor {
    /// Invokes plugin method `method` with `args` and returns its output as a [`ChunkStream`].
    ///
    /// NOTE(review): `method` is presumably an index into the plugin's method table; confirm
    /// against the `PluginExecutor` implementation.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the call fails before a stream is produced. Failures that occur while
    /// the stream is being consumed are delivered as `Err` items of the returned stream.
async fn call_streaming(&self, method: usize, args: Value) -> Result<ChunkStream, CallError>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use fidius_core::error::PluginError;
    use fidius_core::to_value;

    /// Builds an item frame carrying the wire encoding of `v`.
    fn item(v: i64) -> Frame {
        Frame::Item(fidius_core::wire::serialize(&v).unwrap())
    }

    /// Item decoder used by most tests: wire bytes -> `i64` -> `Value`.
    fn decode_i64(b: &[u8]) -> Result<Value, CallError> {
        fidius_core::wire::deserialize::<i64>(b)
            .map(|n| to_value(&n).unwrap())
            .map_err(|e| CallError::Deserialization(e.to_string()))
    }

    /// Item decoder that rejects every payload, to exercise the decode-failure path.
    fn reject_all(_: &[u8]) -> Result<Value, CallError> {
        Err(CallError::Deserialization("rejected".to_string()))
    }

    /// Drains `s` to completion, keeping every item in order.
    async fn collect(mut s: ChunkStream) -> Vec<Result<Value, CallError>> {
        let mut out = Vec::new();
        while let Some(x) = s.next().await {
            out.push(x);
        }
        out
    }

    #[tokio::test]
    async fn items_then_clean_end() {
        let s = ChunkStream::from_frames(vec![item(1), item(2), item(3), Frame::End], decode_i64);
        let vals: Vec<i64> = collect(s)
            .await
            .into_iter()
            .map(|r| fidius_core::from_value(r.unwrap()).unwrap())
            .collect();
        assert_eq!(vals, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn native_value_stream_via_new() {
        let items = vec![Ok(to_value(&"a").unwrap()), Ok(to_value(&"b").unwrap())];
        let s = ChunkStream::new(futures::stream::iter(items));
        let got: Vec<String> = collect(s)
            .await
            .into_iter()
            .map(|r| fidius_core::from_value(r.unwrap()).unwrap())
            .collect();
        assert_eq!(got, vec!["a".to_string(), "b".to_string()]);
    }

    #[tokio::test]
    async fn error_frame_terminates_after_one_err() {
        let s = ChunkStream::from_frames(
            vec![
                item(1),
                Frame::Error(PluginError::new("BOOM", "broke")),
                item(2),
            ],
            decode_i64,
        );
        let got = collect(s).await;
        assert_eq!(got.len(), 2);
        assert!(matches!(got[0], Ok(_)));
        assert!(matches!(got[1], Err(CallError::Plugin(_))));
    }

    #[tokio::test]
    async fn missing_terminal_is_abort() {
        let s = ChunkStream::from_frames(vec![item(7)], decode_i64);
        let got = collect(s).await;
        assert_eq!(got.len(), 2);
        assert!(matches!(got[0], Ok(_)));
        assert!(matches!(got[1], Err(CallError::StreamAborted)));
    }

    #[tokio::test]
    async fn empty_source_is_abort() {
        let s = ChunkStream::from_frames(Vec::new(), decode_i64);
        let got = collect(s).await;
        assert_eq!(got.len(), 1);
        assert!(matches!(got[0], Err(CallError::StreamAborted)));
    }

    #[tokio::test]
    async fn malformed_frame_surfaces_then_stops() {
        let s = ChunkStream::from_frame_bytes(
            futures::stream::iter(vec![
                item(1).encode().unwrap(),
                vec![99, 0, 0, 0, 0],
                item(2).encode().unwrap(),
            ]),
            decode_i64,
        );
        let got = collect(s).await;
        assert_eq!(got.len(), 2);
        assert!(matches!(got[0], Ok(_)));
        assert!(matches!(got[1], Err(CallError::MalformedFrame(_))));
    }

    #[tokio::test]
    async fn item_decode_failure_terminates_after_one_err() {
        let s = ChunkStream::from_frames(vec![item(1), item(2), Frame::End], reject_all);
        let got = collect(s).await;
        assert_eq!(got.len(), 1);
        assert!(matches!(got[0], Err(CallError::Deserialization(_))));
    }

    #[tokio::test]
    async fn frames_after_end_are_ignored() {
        let s = ChunkStream::from_frames(vec![item(1), Frame::End, item(2)], decode_i64);
        let vals: Vec<i64> = collect(s)
            .await
            .into_iter()
            .map(|r| fidius_core::from_value(r.unwrap()).unwrap())
            .collect();
        assert_eq!(vals, vec![1]);
    }

    #[tokio::test]
    async fn empty_stream_just_ends() {
        let s = ChunkStream::from_frames(vec![Frame::End], decode_i64);
        assert!(collect(s).await.is_empty());
    }
}