use crate::{Result, WireError};
use futures::stream::Stream;
use futures::StreamExt;
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
/// A stream adapter that deserializes each JSON [`Value`] yielded by an inner
/// stream into a typed `T`.
///
/// Items are mapped one-to-one:
/// - an `Ok(value)` from the inner stream becomes the result of deserializing
///   `value` into `T`;
/// - an `Err` from the inner stream is passed through unchanged;
/// - the adapter ends when the inner stream ends.
pub struct TypedJsonStream<T> {
    inner: Box<dyn Stream<Item = Result<Value>> + Send + Unpin>,
    // `fn() -> T` rather than `T`: this type never stores a `T`, it only
    // produces them, so its `Send`/`Unpin` auto traits should not depend on
    // `T`. That is what lets the `Stream` impl below drop the `T: Unpin` bound.
    _phantom: PhantomData<fn() -> T>,
}

impl<T: DeserializeOwned> TypedJsonStream<T> {
    /// Wraps `inner`, a boxed stream of JSON values, so that its items are
    /// deserialized into `T` as they are polled.
    pub fn new(inner: Box<dyn Stream<Item = Result<Value>> + Send + Unpin>) -> Self {
        Self {
            inner,
            _phantom: PhantomData,
        }
    }

    /// Deserializes a single JSON value into `T`, recording metrics.
    ///
    /// On success, records the elapsed deserialization time (in whole
    /// milliseconds) and a success counter. On failure, records a failure
    /// counter tagged `"serde_error"`.
    ///
    /// # Errors
    ///
    /// Returns [`WireError::Deserialization`] with the type name of `T` and the
    /// serde error message when `value` cannot be deserialized into `T`.
    fn deserialize_value(value: Value) -> Result<T> {
        // First argument passed to every metrics helper below. NOTE(review):
        // presumably a source/provider label this adapter cannot know; confirm
        // against `crate::metrics`.
        const METRICS_SOURCE: &str = "unknown";

        // `type_name` is a `&'static str`; an owned copy is only made on the
        // error path, where `WireError` needs one.
        let type_name = std::any::type_name::<T>();
        let deser_start = std::time::Instant::now();

        match serde_json::from_value::<T>(value) {
            Ok(result) => {
                // `as_millis` is `u128`. Clamp first so the cast cannot
                // silently truncate.
                let duration_ms =
                    deser_start.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
                crate::metrics::histograms::deserialization_duration(
                    METRICS_SOURCE,
                    type_name,
                    duration_ms,
                );
                crate::metrics::counters::deserialization_success(METRICS_SOURCE, type_name);
                Ok(result)
            }
            Err(e) => {
                crate::metrics::counters::deserialization_failure(
                    METRICS_SOURCE,
                    type_name,
                    "serde_error",
                );
                Err(WireError::Deserialization {
                    type_name: type_name.to_owned(),
                    details: e.to_string(),
                })
            }
        }
    }
}

// No `T: Unpin` bound is needed: `PhantomData<fn() -> T>` keeps `Self: Unpin`
// for every `T`, which `poll_next_unpin` on the inner stream requires.
impl<T: DeserializeOwned> Stream for TypedJsonStream<T> {
    type Item = Result<T>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Pending`, end-of-stream and upstream errors pass straight through.
        // Only `Ok` values are deserialized.
        self.inner
            .poll_next_unpin(cx)
            .map(|item| item.map(|res| res.and_then(Self::deserialize_value)))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exactly one output item per inner item, so the inner bounds hold as-is.
        self.inner.size_hint()
    }
}
#[cfg(test)]
mod tests {
    // Tests may use `unwrap` even where the crate otherwise denies it.
    #![allow(clippy::unwrap_used)]

    use super::*;
    use futures::FutureExt;

    #[test]
    fn test_typed_stream_creation() {
        let _stream: TypedJsonStream<serde_json::Value> =
            TypedJsonStream::new(Box::new(futures::stream::empty()));

        #[derive(serde::Deserialize, Debug)]
        #[allow(dead_code)] // fields exist only to give the type a shape
        struct TestType {
            id: String,
        }

        let _stream: TypedJsonStream<TestType> =
            TypedJsonStream::new(Box::new(futures::stream::empty()));
    }

    #[test]
    fn test_deserialize_valid_value() {
        let json = serde_json::json!({
            "id": "123",
            "name": "Test"
        });

        #[derive(serde::Deserialize)]
        #[allow(dead_code)] // fields are read below, kept for parity with siblings
        struct TestType {
            id: String,
            name: String,
        }

        let result = TypedJsonStream::<TestType>::deserialize_value(json);
        let item = result.unwrap_or_else(|e| panic!("expected Ok for valid JSON, got: {e}"));
        assert_eq!(item.id, "123");
        assert_eq!(item.name, "Test");
    }

    #[test]
    fn test_deserialize_missing_field() {
        let json = serde_json::json!({
            "id": "123"
        });

        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)] // fields are never read; only deserialization matters
        struct TestType {
            id: String,
            name: String,
        }

        let result = TypedJsonStream::<TestType>::deserialize_value(json);
        match result {
            Err(WireError::Deserialization { type_name, details }) => {
                assert!(type_name.contains("TestType"));
                assert!(details.contains("name"));
            }
            other => panic!("expected Deserialization error for missing field, got: {other:?}"),
        }
    }

    #[test]
    fn test_deserialize_type_mismatch() {
        let json = serde_json::json!({
            "id": "123",
            "count": "not a number"
        });

        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)] // fields are never read; only deserialization matters
        struct TestType {
            id: String,
            count: i32,
        }

        let result = TypedJsonStream::<TestType>::deserialize_value(json);
        match result {
            Err(WireError::Deserialization { type_name, details }) => {
                assert!(type_name.contains("TestType"));
                assert!(details.contains("invalid") || details.contains("type"));
            }
            other => panic!("expected Deserialization error for type mismatch, got: {other:?}"),
        }
    }

    #[test]
    fn test_deserialize_value_type() {
        let json = serde_json::json!({
            "id": "123",
            "name": "Test"
        });

        let result = TypedJsonStream::<serde_json::Value>::deserialize_value(json.clone());
        let value =
            result.unwrap_or_else(|e| panic!("expected Ok for Value escape hatch, got: {e}"));
        assert_eq!(value, json);
    }

    #[test]
    fn test_stream_maps_items_and_passes_errors_through() {
        #[derive(Debug, serde::Deserialize)]
        struct Item {
            id: u32,
        }

        let values: Vec<Result<Value>> = vec![
            Ok(serde_json::json!({ "id": 1 })),
            Ok(serde_json::json!({ "id": "not a number" })),
            Err(WireError::Deserialization {
                type_name: "upstream".to_string(),
                details: "boom".to_string(),
            }),
        ];
        let mut stream: TypedJsonStream<Item> =
            TypedJsonStream::new(Box::new(futures::stream::iter(values)));

        // `stream::iter` is always ready, so `now_or_never` never sees `Pending`.
        match stream.next().now_or_never() {
            Some(Some(Ok(item))) => assert_eq!(item.id, 1),
            other => panic!("expected first item to deserialize, got: {other:?}"),
        }
        match stream.next().now_or_never() {
            Some(Some(Err(WireError::Deserialization { type_name, .. }))) => {
                assert!(type_name.contains("Item"));
            }
            other => panic!("expected a deserialization error, got: {other:?}"),
        }
        match stream.next().now_or_never() {
            Some(Some(Err(WireError::Deserialization { type_name, details }))) => {
                assert_eq!(type_name, "upstream");
                assert_eq!(details, "boom");
            }
            other => panic!("expected upstream error to pass through, got: {other:?}"),
        }
        match stream.next().now_or_never() {
            Some(None) => {}
            other => panic!("expected end of stream, got: {other:?}"),
        }
    }

    #[test]
    fn test_phantom_data_has_no_size() {
        use std::mem::size_of;

        // The phantom marker is zero-sized, so the wrapper must be exactly as
        // large as the boxed inner stream it holds.
        assert_eq!(
            size_of::<TypedJsonStream<serde_json::Value>>(),
            size_of::<Box<dyn Stream<Item = Result<Value>> + Send + Unpin>>(),
        );
    }
}