use std::{io, io::BufRead, path::PathBuf, pin::Pin};
use clap::ValueEnum;
use futures::Stream;
use s2_sdk::types::{AppendRecord, SequencedRecord};
use tokio::{
fs::{File, OpenOptions},
io::{AsyncBufReadExt, AsyncWrite, BufWriter},
sync::mpsc,
};
use tokio_stream::wrappers::{LinesStream, ReceiverStream};
use tracing::trace;
use crate::error::RecordParseError;
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum RecordFormat {
#[default]
#[clap(alias = "")]
Text,
#[clap(alias = "raw")]
Json,
#[clap(aliases = ["base64", "json-binsafe"])]
JsonBase64,
}
#[derive(Debug, Clone)]
pub enum RecordsIn {
File(PathBuf),
Stdin,
}
#[derive(Debug, Clone)]
pub enum RecordsOut {
File(PathBuf),
Stdout,
}
impl RecordsIn {
pub async fn reader(
&self,
) -> io::Result<Pin<Box<dyn Stream<Item = io::Result<String>> + Send>>> {
match self {
RecordsIn::File(path) => {
let file = File::open(path).await?;
let stream: Pin<Box<dyn Stream<Item = io::Result<String>> + Send>> =
Box::pin(LinesStream::new(tokio::io::BufReader::new(file).lines()));
Ok(stream)
}
RecordsIn::Stdin => Ok(Box::pin(stdio_lines_stream(std::io::stdin()))),
}
}
}
impl RecordsOut {
pub async fn writer(&self) -> io::Result<Box<dyn AsyncWrite + Send + Unpin>> {
match self {
RecordsOut::File(path) => {
trace!(?path, "opening file writer");
let file = OpenOptions::new()
.write(true)
.create(true)
.append(true)
.open(path)
.await?;
Ok(Box::new(BufWriter::new(file)))
}
RecordsOut::Stdout => {
trace!("stdout writer");
Ok(Box::new(BufWriter::new(tokio::io::stdout())))
}
}
}
}
fn stdio_lines_stream<F>(f: F) -> ReceiverStream<io::Result<String>>
where
F: std::io::Read + Send + 'static,
{
let lines = std::io::BufReader::new(f).lines();
let (tx, rx) = mpsc::channel(s2_sdk::types::RECORD_BATCH_MAX.count);
let _handle = std::thread::spawn(move || {
for line in lines {
if tx.blocking_send(line).is_err() {
return;
}
}
});
ReceiverStream::new(rx)
}
pub fn parse_records_input_source(s: &str) -> Result<RecordsIn, io::Error> {
match s {
"" | "-" => Ok(RecordsIn::Stdin),
_ => Ok(RecordsIn::File(PathBuf::from(s))),
}
}
pub fn parse_records_output_source(s: &str) -> Result<RecordsOut, io::Error> {
match s {
"" | "-" => Ok(RecordsOut::Stdout),
_ => Ok(RecordsOut::File(PathBuf::from(s))),
}
}
pub trait RecordParser<I>
where
I: Stream<Item = io::Result<String>> + Send + Unpin,
{
type RecordStream: Stream<Item = Result<AppendRecord, RecordParseError>> + Send + Unpin;
fn parse_records(lines: I) -> Self::RecordStream;
}
pub trait RecordWriter {
async fn write_record(
record: &SequencedRecord,
writer: &mut (impl AsyncWrite + Unpin),
) -> io::Result<()>;
}
pub use body::TextFormatter;
pub type JsonFormatter = json::Formatter<false>;
pub type JsonBase64Formatter = json::Formatter<true>;
mod body {
use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use futures::{Stream, StreamExt};
use s2_sdk::types::{AppendRecord, SequencedRecord};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use super::{RecordParseError, RecordParser, RecordWriter};
pub struct TextFormatter;
impl RecordWriter for TextFormatter {
async fn write_record(
record: &SequencedRecord,
writer: &mut (impl AsyncWrite + Unpin),
) -> io::Result<()> {
let s = String::from_utf8_lossy(&record.body);
writer.write_all(s.as_bytes()).await
}
}
impl<I> RecordParser<I> for TextFormatter
where
I: Stream<Item = io::Result<String>> + Send + Unpin,
{
type RecordStream = RecordStream<I>;
fn parse_records(lines: I) -> Self::RecordStream {
RecordStream(lines)
}
}
pub struct RecordStream<S>(S);
impl<S> Stream for RecordStream<S>
where
S: Stream<Item = io::Result<String>> + Send + Unpin,
{
type Item = Result<AppendRecord, RecordParseError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.0.poll_next_unpin(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
Poll::Ready(Some(Ok(s))) => Poll::Ready(Some(
AppendRecord::new(s).map_err(|e| RecordParseError::Parse(e.to_string())),
)),
}
}
}
}
mod json {
use std::{
borrow::Cow,
io,
pin::Pin,
task::{Context, Poll},
};
use base64ct::{Base64, Encoding};
use bytes::Bytes;
use futures::{Stream, StreamExt};
use s2_sdk::types::{AppendRecord, Header, SequencedRecord};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use super::{RecordParseError, RecordParser, RecordWriter};
#[derive(Debug, Clone, Default)]
struct CowStr<'a, const BIN_SAFE: bool>(Cow<'a, str>);
impl<const BIN_SAFE: bool> CowStr<'_, BIN_SAFE> {
fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
type OwnedCowStr<const BIN_SAFE: bool> = CowStr<'static, BIN_SAFE>;
impl<'a, const BIN_SAFE: bool> From<&'a [u8]> for CowStr<'a, BIN_SAFE> {
fn from(value: &'a [u8]) -> Self {
Self(if BIN_SAFE {
Base64::encode_string(value).into()
} else {
String::from_utf8_lossy(value)
})
}
}
impl<const BIN_SAFE: bool> TryFrom<OwnedCowStr<BIN_SAFE>> for Bytes {
type Error = String;
fn try_from(value: OwnedCowStr<BIN_SAFE>) -> Result<Self, Self::Error> {
let CowStr(s) = value;
Ok(if BIN_SAFE {
Base64::decode_vec(&s).map_err(|_| format!("invalid base64: {s}"))?
} else {
s.into_owned().into_bytes()
}
.into())
}
}
impl<const BIN_SAFE: bool> Serialize for CowStr<'_, BIN_SAFE> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.0.serialize(serializer)
}
}
impl<'de, const BIN_SAFE: bool> Deserialize<'de> for OwnedCowStr<BIN_SAFE> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
String::deserialize(deserializer).map(|s| CowStr(s.into()))
}
}
pub struct Formatter<const BIN_SAFE: bool>;
#[derive(Debug, Clone, Serialize)]
struct SerializableSequencedRecord<'a, const BIN_SAFE: bool> {
seq_num: u64,
timestamp: u64,
#[serde(skip_serializing_if = "Vec::is_empty")]
headers: Vec<(CowStr<'a, BIN_SAFE>, CowStr<'a, BIN_SAFE>)>,
#[serde(skip_serializing_if = "CowStr::is_empty")]
body: CowStr<'a, BIN_SAFE>,
}
impl<'a, const BIN_SAFE: bool> From<&'a SequencedRecord>
for SerializableSequencedRecord<'a, BIN_SAFE>
{
fn from(value: &'a SequencedRecord) -> Self {
let SequencedRecord {
timestamp,
seq_num,
headers,
body,
..
} = value;
let headers: Vec<(CowStr<BIN_SAFE>, CowStr<BIN_SAFE>)> = headers
.iter()
.map(|h| (h.name.as_ref().into(), h.value.as_ref().into()))
.collect();
let body: CowStr<BIN_SAFE> = body.as_ref().into();
SerializableSequencedRecord {
timestamp: *timestamp,
seq_num: *seq_num,
headers,
body,
}
}
}
impl<const BIN_SAFE: bool> RecordWriter for Formatter<BIN_SAFE> {
async fn write_record(
record: &SequencedRecord,
writer: &mut (impl AsyncWrite + Unpin),
) -> io::Result<()> {
let record: SerializableSequencedRecord<BIN_SAFE> = record.into();
let s = serde_json::to_string(&record).map_err(io::Error::other)?;
writer.write_all(s.as_bytes()).await
}
}
impl<const BIN_SAFE: bool, I> RecordParser<I> for Formatter<BIN_SAFE>
where
I: Stream<Item = io::Result<String>> + Send + Unpin,
{
type RecordStream = RecordStream<BIN_SAFE, I>;
fn parse_records(lines: I) -> Self::RecordStream {
RecordStream(lines)
}
}
#[derive(Debug, Clone, Deserialize)]
struct DeserializableAppendRecord<const BIN_SAFE: bool> {
timestamp: Option<u64>,
#[serde(default)]
headers: Vec<(OwnedCowStr<BIN_SAFE>, OwnedCowStr<BIN_SAFE>)>,
#[serde(default)]
body: OwnedCowStr<BIN_SAFE>,
}
impl<const BIN_SAFE: bool> TryFrom<DeserializableAppendRecord<BIN_SAFE>> for AppendRecord {
type Error = String;
fn try_from(value: DeserializableAppendRecord<BIN_SAFE>) -> Result<Self, Self::Error> {
let DeserializableAppendRecord {
timestamp,
headers,
body,
} = value;
let body_bytes: Bytes = body.try_into()?;
let mut record = AppendRecord::new(body_bytes).map_err(|e| e.to_string())?;
if !headers.is_empty() {
let parsed_headers: Vec<Header> = headers
.into_iter()
.map(|(name, value)| {
let name_bytes: Bytes = name.try_into()?;
let value_bytes: Bytes = value.try_into()?;
Ok(Header::new(name_bytes, value_bytes))
})
.collect::<Result<Vec<_>, String>>()?;
record = record
.with_headers(parsed_headers)
.map_err(|e| e.to_string())?;
}
if let Some(ts) = timestamp {
record = record.with_timestamp(ts);
}
Ok(record)
}
}
pub struct RecordStream<const BIN_SAFE: bool, S>(S);
impl<const BIN_SAFE: bool, S> Stream for RecordStream<BIN_SAFE, S>
where
S: Stream<Item = io::Result<String>> + Send + Unpin,
{
type Item = Result<AppendRecord, RecordParseError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn parse_record<const BIN_SAFE: bool>(
s: String,
) -> Result<AppendRecord, RecordParseError> {
let append_record: DeserializableAppendRecord<BIN_SAFE> =
serde_json::from_str(&s).map_err(|e| RecordParseError::Parse(e.to_string()))?;
Ok(append_record.try_into()?)
}
match self.0.poll_next_unpin(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
Poll::Ready(Some(Ok(s))) => Poll::Ready(Some(parse_record::<BIN_SAFE>(s))),
}
}
}
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use futures::{StreamExt, executor::block_on};
use proptest::{prelude::*, test_runner::TestCaseResult};
use s2_sdk::types::Header;
use super::*;
fn ascii_string_strategy(max_len: usize) -> impl Strategy<Value = String> {
prop::collection::vec(0x20u8..=0x7e, 0..=max_len)
.prop_map(|bytes| String::from_utf8(bytes).unwrap())
}
fn string_strategy(max_len: usize) -> impl Strategy<Value = String> {
prop::collection::vec(any::<char>(), 0..=max_len)
.prop_map(|chars| chars.into_iter().collect())
}
fn bytes_strategy(max_len: usize) -> impl Strategy<Value = Vec<u8>> {
prop::collection::vec(any::<u8>(), 0..=max_len)
}
fn sequenced_record(
seq_num: u64,
timestamp: u64,
headers: Vec<Header>,
body: Bytes,
) -> SequencedRecord {
SequencedRecord::from_parts(seq_num, timestamp, headers, body)
}
async fn parse_text_line(line: String) -> AppendRecord {
let lines = futures::stream::iter(vec![Ok(line)]);
let mut stream = TextFormatter::parse_records(lines);
let record = stream.next().await.unwrap().unwrap();
assert!(stream.next().await.is_none());
record
}
async fn parse_json_line<const BIN_SAFE: bool>(line: String) -> AppendRecord {
let lines = futures::stream::iter(vec![Ok(line)]);
let mut stream =
<super::json::Formatter<BIN_SAFE> as RecordParser<_>>::parse_records(lines);
let record = stream.next().await.unwrap().unwrap();
assert!(stream.next().await.is_none());
record
}
fn prop_assert_headers_eq<E>(actual: &[Header], expected: &[(E, E)]) -> TestCaseResult
where
E: AsRef<[u8]>,
{
prop_assert_eq!(actual.len(), expected.len());
for (actual, (expected_name, expected_value)) in actual.iter().zip(expected.iter()) {
prop_assert_eq!(actual.name.as_ref(), expected_name.as_ref());
prop_assert_eq!(actual.value.as_ref(), expected_value.as_ref());
}
Ok(())
}
#[tokio::test]
async fn text_parse_records() {
let lines =
futures::stream::iter(vec![Ok("line one".to_string()), Ok("line two".to_string())]);
let mut stream = TextFormatter::parse_records(lines);
let r1 = stream.next().await.unwrap().unwrap();
assert_eq!(r1.body(), b"line one");
let r2 = stream.next().await.unwrap().unwrap();
assert_eq!(r2.body(), b"line two");
assert!(stream.next().await.is_none());
}
#[tokio::test]
async fn json_parse_records_invalid_json() {
let lines = futures::stream::iter(vec![Ok("not json".to_string())]);
let mut stream = <JsonFormatter as RecordParser<_>>::parse_records(lines);
assert!(stream.next().await.unwrap().is_err());
}
#[tokio::test]
async fn json_parse_records_empty_body() {
let json_line = r#"{}"#.to_string();
let lines = futures::stream::iter(vec![Ok(json_line)]);
let mut stream = <JsonFormatter as RecordParser<_>>::parse_records(lines);
let r = stream.next().await.unwrap().unwrap();
assert_eq!(r.body(), b"");
assert!(r.headers().is_empty());
assert_eq!(r.timestamp(), None);
}
#[tokio::test]
async fn json_base64_parse_records_invalid_base64() {
let json_line = r#"{"body":"not-valid-base64!!!"}"#.to_string();
let lines = futures::stream::iter(vec![Ok(json_line)]);
let mut stream = <JsonBase64Formatter as RecordParser<_>>::parse_records(lines);
assert!(stream.next().await.unwrap().is_err());
}
#[tokio::test]
async fn text_parse_records_io_error() {
let lines = futures::stream::iter(vec![Err(io::Error::other("test error"))]);
let mut stream = TextFormatter::parse_records(lines);
let result = stream.next().await.unwrap();
assert!(result.is_err());
}
#[tokio::test]
async fn json_parse_records_io_error() {
let lines = futures::stream::iter(vec![Err(io::Error::other("io error"))]);
let mut stream = <JsonFormatter as RecordParser<_>>::parse_records(lines);
assert!(stream.next().await.unwrap().is_err());
}
proptest! {
#[test]
fn text_formatter_write_then_parse_preserves_ascii_body(body in ascii_string_strategy(256)) {
let record = sequenced_record(0, 0, vec![], Bytes::from(body.clone()));
let output = block_on(async {
let mut output = Vec::new();
TextFormatter::write_record(&record, &mut output).await.unwrap();
output
});
prop_assert_eq!(output.as_slice(), body.as_bytes());
let parsed = block_on(parse_text_line(String::from_utf8(output).unwrap()));
prop_assert_eq!(parsed.body(), body.as_bytes());
prop_assert!(parsed.headers().is_empty());
prop_assert_eq!(parsed.timestamp(), None);
}
#[test]
fn json_formatter_write_then_parse_preserves_utf8_fields(
body in string_strategy(256),
headers in prop::collection::vec((string_strategy(32), string_strategy(64)), 0..=8),
timestamp in any::<u64>(),
) {
let record = sequenced_record(
17,
timestamp,
headers
.iter()
.map(|(name, value)| Header::new(Bytes::from(name.clone()), Bytes::from(value.clone())))
.collect(),
Bytes::from(body.clone()),
);
let output = block_on(async {
let mut output = Vec::new();
JsonFormatter::write_record(&record, &mut output).await.unwrap();
output
});
let parsed = block_on(parse_json_line::<false>(String::from_utf8(output).unwrap()));
prop_assert_eq!(parsed.body(), body.as_bytes());
prop_assert_eq!(parsed.timestamp(), Some(timestamp));
prop_assert_headers_eq(parsed.headers(), &headers)?;
}
#[test]
fn json_base64_formatter_write_then_parse_preserves_binary_fields(
body in bytes_strategy(256),
headers in prop::collection::vec((bytes_strategy(32), bytes_strategy(64)), 0..=8),
timestamp in any::<u64>(),
) {
let record = sequenced_record(
17,
timestamp,
headers
.iter()
.map(|(name, value)| Header::new(Bytes::from(name.clone()), Bytes::from(value.clone())))
.collect(),
Bytes::from(body.clone()),
);
let output = block_on(async {
let mut output = Vec::new();
JsonBase64Formatter::write_record(&record, &mut output).await.unwrap();
output
});
let parsed = block_on(parse_json_line::<true>(String::from_utf8(output).unwrap()));
prop_assert_eq!(parsed.body(), body.as_slice());
prop_assert_eq!(parsed.timestamp(), Some(timestamp));
prop_assert_headers_eq(parsed.headers(), &headers)?;
}
}
}