use bytes::BytesMut;
use core::marker::PhantomData;
use serde::de::DeserializeOwned;
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio_util::codec::Decoder;
use crate::de::{ParserConfig, from_slice, from_slice_with_config};
use crate::error::{Error, Result};
/// Deserializes a single YAML document read from `reader` until EOF.
///
/// Shorthand for [`from_async_reader_with_config`] using
/// [`ParserConfig::default`].
///
/// # Errors
///
/// Propagates any read or deserialization error from
/// [`from_async_reader_with_config`].
pub async fn from_async_reader<R, T>(reader: &mut R) -> Result<T>
where
    R: AsyncRead + Unpin,
    T: DeserializeOwned + 'static,
{
    let defaults = ParserConfig::default();
    from_async_reader_with_config(reader, &defaults).await
}
/// Deserializes a single YAML document read from `reader` until EOF, using
/// `config` for parsing.
///
/// Reading is bounded by `config.max_document_length`. A leading byte-order
/// mark (as detected by `doc_boundary::strip_bom`) is removed before the
/// bytes are handed to `from_slice_with_config`.
///
/// # Errors
///
/// Returns an error if reading from `reader` fails or if the bytes cannot be
/// deserialized into `T`.
pub async fn from_async_reader_with_config<R, T>(reader: &mut R, config: &ParserConfig) -> Result<T>
where
    R: AsyncRead + Unpin,
    T: DeserializeOwned + 'static,
{
    let body = strip_bom_owned(drain_bounded(reader, config.max_document_length).await?);
    from_slice_with_config(body.as_slice(), config)
}
/// Deserializes every YAML document in `reader` (read until EOF) into a `Vec`.
///
/// Shorthand for [`from_async_reader_multi_with_config`] using
/// [`ParserConfig::default`].
///
/// # Errors
///
/// Propagates any read, UTF-8, or deserialization error from
/// [`from_async_reader_multi_with_config`].
pub async fn from_async_reader_multi<R, T>(reader: &mut R) -> Result<Vec<T>>
where
    R: AsyncRead + Unpin,
    T: DeserializeOwned + 'static,
{
    let defaults = ParserConfig::default();
    from_async_reader_multi_with_config(reader, &defaults).await
}
pub async fn from_async_reader_multi_with_config<R, T>(
reader: &mut R,
config: &ParserConfig,
) -> Result<Vec<T>>
where
R: AsyncRead + Unpin,
T: DeserializeOwned + 'static,
{
let buf = drain_bounded(reader, config.max_document_length).await?;
let buf = strip_bom_owned(buf);
let text = core::str::from_utf8(&buf)
.map_err(|e| Error::from(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))?;
let docs = crate::doc_boundary::split_documents(text, config.max_documents);
let mut results = Vec::with_capacity(docs.len());
for doc in docs {
results.push(crate::from_str_with_config::<T>(doc, config)?);
}
Ok(results)
}
async fn drain_bounded<R>(reader: &mut R, max_bytes: usize) -> Result<Vec<u8>>
where
R: AsyncRead + Unpin,
{
let mut buf = Vec::new();
let take = u64::try_from(max_bytes).unwrap_or(u64::MAX);
let mut limited = reader.take(take);
let _ = limited.read_to_end(&mut buf).await.map_err(Error::from)?;
Ok(buf)
}
/// Removes a leading byte-order mark from `buf` in place and returns the buffer.
///
/// Bytes are removed only when `doc_boundary::strip_bom` reports a 3-byte
/// BOM (the length of the UTF-8 BOM). Otherwise `buf` is returned unchanged.
/// The existing allocation is reused.
fn strip_bom_owned(buf: Vec<u8>) -> Vec<u8> {
    const BOM_LEN: usize = 3;
    let mut owned = buf;
    if crate::doc_boundary::strip_bom(&owned) == BOM_LEN {
        // Shift the payload left over the BOM, then drop the stale tail.
        let kept = owned.len() - BOM_LEN;
        owned.copy_within(BOM_LEN.., 0);
        owned.truncate(kept);
    }
    owned
}
#[derive(Debug, Clone)]
pub struct YamlDecoder<T> {
config: ParserConfig,
max_frame_size: Option<usize>,
_marker: PhantomData<fn() -> T>,
}
impl<T> Default for YamlDecoder<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> YamlDecoder<T> {
#[must_use]
pub fn new() -> Self {
Self {
config: ParserConfig::default(),
max_frame_size: None,
_marker: PhantomData,
}
}
#[must_use]
pub fn with_config(config: ParserConfig) -> Self {
Self {
config,
max_frame_size: None,
_marker: PhantomData,
}
}
#[must_use]
pub fn max_frame_size(mut self, max: usize) -> Self {
self.max_frame_size = Some(max);
self
}
}
impl<T> Decoder for YamlDecoder<T>
where
T: DeserializeOwned + 'static,
{
type Item = T;
type Error = Error;
fn decode(&mut self, src: &mut BytesMut) -> core::result::Result<Option<T>, Error> {
if let Some(max) = self.max_frame_size {
if src.len() > max {
return Err(Error::from(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"noyalib YamlDecoder: buffer {} > max_frame_size {}",
src.len(),
max
),
)));
}
}
loop {
let bytes: &[u8] = src.as_ref();
let Some(end) = find_doc_boundary(bytes) else {
return Ok(None);
};
let doc = src.split_to(end);
if doc.iter().all(u8::is_ascii_whitespace) {
continue;
}
let parsed = from_slice_with_config::<T>(&doc, &self.config)?;
return Ok(Some(parsed));
}
}
fn decode_eof(&mut self, src: &mut BytesMut) -> core::result::Result<Option<T>, Error> {
if src.is_empty() {
return Ok(None);
}
if let Some(v) = self.decode(src)? {
return Ok(Some(v));
}
if src.iter().all(u8::is_ascii_whitespace) {
src.clear();
return Ok(None);
}
let doc = src.split();
let parsed = from_slice::<T>(&doc)?;
Ok(Some(parsed))
}
}
fn find_doc_boundary(bytes: &[u8]) -> Option<usize> {
crate::doc_boundary::next_marker_after(bytes, 0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use serde::Deserialize;
    use tokio::io::BufReader;

    /// Small target type used throughout these tests.
    #[derive(Debug, Deserialize, PartialEq)]
    struct Pkg {
        name: String,
        version: String,
    }

    /// Two LF-terminated documents, each introduced by an explicit `---`.
    const TWO_DOCS_LF: &[u8] = b"---\nname: a\nversion: '1'\n---\nname: b\nversion: '2'\n";

    /// Adapts an in-memory byte slice into an `AsyncRead + Unpin` source.
    fn async_source(bytes: &[u8]) -> BufReader<&[u8]> {
        BufReader::new(bytes)
    }

    /// Copies `bytes` into a fresh decoder buffer.
    fn frame_buffer(bytes: &[u8]) -> BytesMut {
        BytesMut::from(bytes)
    }

    /// Collects the `name` field of each decoded package, in order.
    fn names(pkgs: &[Pkg]) -> Vec<&str> {
        pkgs.iter().map(|p| p.name.as_str()).collect()
    }

    #[tokio::test]
    async fn reader_parses_single_document() {
        let mut source = async_source(b"name: noyalib\nversion: 0.0.6\n");
        let parsed: Pkg = from_async_reader(&mut source).await.unwrap();
        let expected = Pkg {
            name: String::from("noyalib"),
            version: String::from("0.0.6"),
        };
        assert_eq!(parsed, expected);
    }

    #[tokio::test]
    async fn reader_multi_parses_each_document() {
        let mut source = async_source(TWO_DOCS_LF);
        let pkgs: Vec<Pkg> = from_async_reader_multi(&mut source).await.unwrap();
        assert_eq!(names(&pkgs), ["a", "b"]);
    }

    #[test]
    fn decoder_emits_first_complete_document() {
        let mut decoder = YamlDecoder::<Pkg>::new();
        let mut buffer = frame_buffer(b"name: a\nversion: '1'\n---\nname: b\nversion: '2'\n");

        let first = decoder.decode(&mut buffer).unwrap().unwrap();
        assert_eq!(first.name, "a");

        let second = decoder.decode_eof(&mut buffer).unwrap().unwrap();
        assert_eq!(second.name, "b");
    }

    #[test]
    fn decoder_returns_none_on_incomplete_buffer() {
        let mut decoder = YamlDecoder::<Pkg>::new();
        let mut buffer = frame_buffer(b"name: a\n");
        let outcome = decoder.decode(&mut buffer).unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn reader_with_config_respects_overrides() {
        let config = ParserConfig::default();
        let mut source = async_source(b"name: x\nversion: '1'\n");
        let parsed: Pkg = from_async_reader_with_config(&mut source, &config)
            .await
            .unwrap();
        assert_eq!(parsed.name, "x");
    }

    #[test]
    fn decoder_with_config_constructor() {
        let _configured: YamlDecoder<Pkg> = YamlDecoder::with_config(ParserConfig::default());
        let _defaulted: YamlDecoder<Pkg> = Default::default();
        let _rendered = format!("{:?}", YamlDecoder::<Pkg>::new());
    }

    #[test]
    fn decoder_eof_on_empty_buffer_returns_none() {
        let mut decoder = YamlDecoder::<Pkg>::new();
        let mut buffer = BytesMut::new();
        let outcome = decoder.decode_eof(&mut buffer).unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn decoder_skips_whitespace_only_preamble() {
        let mut decoder = YamlDecoder::<Pkg>::new();
        let mut buffer = frame_buffer(b"\n\n---\nname: q\nversion: '2'\n");
        let pkg = decoder.decode_eof(&mut buffer).unwrap().unwrap();
        assert_eq!(pkg.name, "q");
    }

    #[test]
    fn decoder_eof_drains_trailing_whitespace() {
        let mut decoder = YamlDecoder::<Pkg>::new();
        let mut buffer = frame_buffer(b"name: r\nversion: '3'\n\n\n");
        let pkg = decoder.decode_eof(&mut buffer).unwrap().unwrap();
        assert_eq!(pkg.name, "r");
    }

    #[tokio::test]
    async fn reader_multi_handles_invalid_utf8() {
        let mut source = async_source(&[0xFFu8, 0xFE, 0xFD]);
        let outcome: Result<Vec<Pkg>> = from_async_reader_multi(&mut source).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn find_doc_boundary_handles_short_input() {
        for input in [&b""[..], &b"abc"[..]] {
            assert!(find_doc_boundary(input).is_none());
        }
    }

    #[tokio::test]
    async fn reader_strips_leading_bom() {
        let mut source = async_source(b"\xEF\xBB\xBFname: x\nversion: '1'\n");
        let parsed: Pkg = from_async_reader(&mut source).await.unwrap();
        assert_eq!(parsed.name, "x");
    }

    #[tokio::test]
    async fn reader_multi_accepts_crlf() {
        let crlf: &[u8] = b"---\r\nname: a\r\nversion: '1'\r\n---\r\nname: b\r\nversion: '2'\r\n";
        let mut source = async_source(crlf);
        let pkgs: Vec<Pkg> = from_async_reader_multi(&mut source).await.unwrap();
        assert_eq!(names(&pkgs), ["a", "b"]);
    }

    #[tokio::test]
    async fn reader_caps_at_max_document_length() {
        let oversized = format!("a: {}", "x".repeat(10_000));
        let config = ParserConfig {
            max_document_length: 64,
            ..ParserConfig::default()
        };
        let mut source = async_source(oversized.as_bytes());
        let outcome = from_async_reader_with_config::<_, Pkg>(&mut source, &config).await;
        let _ = outcome.expect_err("expected parse error after truncation");
    }

    #[test]
    fn decoder_rejects_oversize_frame() {
        let mut decoder = YamlDecoder::<Pkg>::new().max_frame_size(16);
        let mut buffer = frame_buffer(b"name: long-name-no-marker-yet-need-more-bytes");
        let message = decoder.decode(&mut buffer).unwrap_err().to_string();
        assert!(message.contains("max_frame_size"));
    }

    #[test]
    fn decoder_accepts_crlf_boundary() {
        let mut decoder = YamlDecoder::<Pkg>::new();
        let mut buffer =
            frame_buffer(b"name: a\r\nversion: '1'\r\n---\r\nname: b\r\nversion: '2'\r\n");
        let first = decoder.decode(&mut buffer).unwrap().unwrap();
        assert_eq!(first.name, "a");
    }

    #[tokio::test]
    async fn reader_multi_with_config_routes_through() {
        let config = ParserConfig::default();
        let mut source = async_source(TWO_DOCS_LF);
        let pkgs: Vec<Pkg> = from_async_reader_multi_with_config(&mut source, &config)
            .await
            .unwrap();
        assert_eq!(pkgs.len(), 2);
    }

    #[test]
    fn find_doc_boundary_skips_leading_marker() {
        assert_eq!(find_doc_boundary(b"---\na: 1\n"), None);
        assert_eq!(find_doc_boundary(b"---\na: 1\n---\nb: 2\n"), Some(9));
    }
}