use futures::stream::{self, BoxStream, StreamExt};
use serde::{Serialize, de::DeserializeOwned};
use tokio::io::AsyncReadExt;
use crate::config::{AsyncInputSpec, AsyncOutputSpec, FileExistsPolicy};
use crate::error::{AggregateError, ErrorPolicy, SingleIoError, Stage};
use crate::format::{self, AsyncFormatRegistry, FormatKind, FormatRegistry};
/// Asynchronous I/O engine that reads values from a list of inputs and
/// writes values to a list of outputs, (de)serializing through a format
/// registry.
pub struct AsyncIoEngine {
/// Registry used to resolve a format kind whenever no synchronous
/// registry is configured.
registry: AsyncFormatRegistry,
/// Optional synchronous registry; when present the read and write paths
/// use it in preference to `registry`.
sync_registry: Option<FormatRegistry>,
/// Decides whether the bulk operations stop at the first failure
/// (`ErrorPolicy::FastFail`) or keep going and collect every failure.
error_policy: ErrorPolicy,
/// Input specifications, processed in order by the bulk read methods.
inputs: Vec<AsyncInputSpec>,
/// Output specifications, processed in order by the bulk write methods.
outputs: Vec<AsyncOutputSpec>,
}
impl AsyncIoEngine {
/// Builds an engine that has no synchronous registry configured.
///
/// All format resolution and (de)serialization then goes through
/// `registry`.
pub fn new(
    registry: AsyncFormatRegistry,
    error_policy: ErrorPolicy,
    inputs: Vec<AsyncInputSpec>,
    outputs: Vec<AsyncOutputSpec>,
) -> Self {
    // No synchronous registry: every code path takes the async branch.
    let sync_registry = None;
    Self {
        registry,
        sync_registry,
        error_policy,
        inputs,
        outputs,
    }
}
/// Builds an engine that additionally carries a synchronous registry.
///
/// When a synchronous registry is present, the read and write paths of the
/// engine consult it instead of the asynchronous `registry`.
pub fn new_with_sync_registry(
    registry: AsyncFormatRegistry,
    sync_registry: FormatRegistry,
    error_policy: ErrorPolicy,
    inputs: Vec<AsyncInputSpec>,
    outputs: Vec<AsyncOutputSpec>,
) -> Self {
    let sync_registry = Some(sync_registry);
    Self {
        registry,
        sync_registry,
        error_policy,
        inputs,
        outputs,
    }
}
pub fn registry(&self) -> &AsyncFormatRegistry {
&self.registry
}
pub fn error_policy(&self) -> ErrorPolicy {
self.error_policy
}
/// Returns the configured input specifications, in their original order.
pub fn inputs(&self) -> &[AsyncInputSpec] {
    self.inputs.as_slice()
}
/// Returns the configured output specifications, in their original order.
pub fn outputs(&self) -> &[AsyncOutputSpec] {
    self.outputs.as_slice()
}
/// Reads one value of type `T` from every configured input, sequentially
/// and in order.
///
/// # Errors
///
/// Returns an [`AggregateError`] holding every per-input failure. Under
/// `ErrorPolicy::FastFail` reading stops at the first failure, so the
/// aggregate then contains exactly one error. Successfully read values are
/// discarded whenever at least one input failed.
pub async fn read_all<T>(&self) -> Result<Vec<T>, AggregateError>
where
    T: DeserializeOwned + Send + 'static,
{
    let stop_on_first = matches!(self.error_policy, ErrorPolicy::FastFail);
    let mut values = Vec::with_capacity(self.inputs.len());
    let mut failures = Vec::new();
    // One scratch buffer is shared by all inputs to avoid reallocating.
    let mut scratch = Vec::new();

    for input in self.inputs.iter() {
        match self.read_one_with_buffer::<T>(input, &mut scratch).await {
            Ok(value) => values.push(value),
            Err(failure) => {
                failures.push(failure);
                if stop_on_first {
                    break;
                }
            }
        }
    }

    if failures.is_empty() {
        return Ok(values);
    }
    Err(AggregateError { errors: failures })
}
/// Streams records of type `T` from all inputs.
///
/// Up to `concurrency` inputs are opened and prepared at the same time;
/// the per-input record streams are then flattened into a single stream.
/// Inputs complete in arbitrary order, so records of different inputs are
/// not emitted in input order.
///
/// A `concurrency` of `0` is treated as `1`: a zero limit passed to
/// `buffer_unordered` never pulls a future from the source, which would
/// leave the returned stream pending forever.
///
/// Failures are reported in-band as `Err` items; the engine's error policy
/// is not consulted here.
pub fn read_records_async<T>(
    &self,
    concurrency: usize,
) -> BoxStream<'_, Result<T, SingleIoError>>
where
    T: DeserializeOwned + Send + 'static,
{
    // Guard against a zero limit, which would stall the stream.
    let limit = concurrency.max(1);
    let futs = self
        .inputs
        .iter()
        .map(|spec| self.records_stream_for_spec_async::<T>(spec));
    stream::iter(futs)
        .buffer_unordered(limit)
        .flat_map(|s| s)
        .boxed()
}
/// Reads a single value from `spec` using a freshly allocated scratch
/// buffer.
///
/// Thin wrapper around `read_one_with_buffer` for callers that have no
/// buffer to reuse.
async fn read_one<T>(&self, spec: &AsyncInputSpec) -> Result<T, SingleIoError>
where
    T: DeserializeOwned + Send + 'static,
{
    let mut scratch: Vec<u8> = Vec::new();
    self.read_one_with_buffer::<T>(spec, &mut scratch).await
}
/// Reads the whole content of `spec` into `buffer` and deserializes it
/// into a single `T`.
///
/// `buffer` is cleared before reading, so its previous content is lost but
/// its allocation is reused.
///
/// # Errors
///
/// * `Stage::Open` — the provider could not be opened, or reading from it
///   failed.
/// * `Stage::ResolveInput` — no usable format could be determined.
/// * `Stage::Parse` — the bytes could not be deserialized.
async fn read_one_with_buffer<T>(
    &self,
    spec: &AsyncInputSpec,
    buffer: &mut Vec<u8>,
) -> Result<T, SingleIoError>
where
    T: DeserializeOwned + Send + 'static,
{
    let mut source = match spec.provider.open().await {
        Ok(source) => source,
        Err(cause) => {
            return Err(SingleIoError {
                stage: Stage::Open,
                target: spec.raw.clone(),
                error: Box::new(cause),
            });
        }
    };

    buffer.clear();
    if let Err(cause) = source.read_to_end(buffer).await {
        // NOTE(review): a read failure is reported under `Stage::Open`;
        // confirm whether a dedicated read stage exists and should be used.
        return Err(SingleIoError {
            stage: Stage::Open,
            target: spec.raw.clone(),
            error: Box::new(cause),
        });
    }

    let sync_registry = match &self.sync_registry {
        Some(sync_registry) => sync_registry,
        None => {
            // Async path: resolve the format first, then deserialize.
            let kind = match self
                .registry
                .resolve(spec.explicit_format.as_ref(), &spec.format_candidates)
            {
                Ok(kind) => kind,
                Err(cause) => {
                    return Err(SingleIoError {
                        stage: Stage::ResolveInput,
                        target: spec.raw.clone(),
                        error: Box::new(cause),
                    });
                }
            };
            return match format::deserialize_async::<T>(kind, buffer).await {
                Ok(value) => Ok(value),
                Err(cause) => Err(SingleIoError {
                    stage: Stage::Parse,
                    target: spec.raw.clone(),
                    error: Box::new(cause),
                }),
            };
        }
    };

    // Sync path: the registry resolves and deserializes in one call, so the
    // stage is derived from the kind of error that comes back.
    sync_registry
        .deserialize_value::<T>(
            spec.explicit_format.as_ref(),
            &spec.format_candidates,
            buffer,
        )
        .map_err(|cause| {
            let resolution_failed = matches!(
                cause,
                format::FormatError::UnknownFormat(_)
                    | format::FormatError::NoFormatMatched
                    | format::FormatError::NotEnabled(_)
            );
            let stage = if resolution_failed {
                Stage::ResolveInput
            } else {
                Stage::Parse
            };
            SingleIoError {
                stage,
                target: spec.raw.clone(),
                error: Box::new(cause),
            }
        })
}
/// Writes the slice `values` to every configured output, sequentially and
/// in order.
///
/// # Errors
///
/// Returns an [`AggregateError`] with every per-output failure. Under
/// `ErrorPolicy::FastFail` the remaining outputs are skipped after the
/// first failure.
pub async fn write_all<T>(&self, values: &[T]) -> Result<(), AggregateError>
where
    T: Serialize + Sync,
{
    let stop_on_first = matches!(self.error_policy, ErrorPolicy::FastFail);
    let mut failures = Vec::new();

    for output in self.outputs.iter() {
        match self.write_one(output, values).await {
            Ok(()) => {}
            Err(failure) => {
                failures.push(failure);
                if stop_on_first {
                    break;
                }
            }
        }
    }

    if failures.is_empty() {
        return Ok(());
    }
    Err(AggregateError { errors: failures })
}
/// Writes the single `value` to every configured output, sequentially and
/// in order.
///
/// # Errors
///
/// Returns an [`AggregateError`] with every per-output failure. Under
/// `ErrorPolicy::FastFail` the remaining outputs are skipped after the
/// first failure.
pub async fn write_one_value<T>(&self, value: &T) -> Result<(), AggregateError>
where
    T: Serialize + Sync,
{
    let stop_on_first = matches!(self.error_policy, ErrorPolicy::FastFail);
    let mut failures = Vec::new();

    for output in self.outputs.iter() {
        match self.write_single(output, value).await {
            Ok(()) => {}
            Err(failure) => {
                failures.push(failure);
                if stop_on_first {
                    break;
                }
            }
        }
    }

    if failures.is_empty() {
        return Ok(());
    }
    Err(AggregateError { errors: failures })
}
/// Serializes the slice `values` and writes the bytes to the output
/// described by `spec`.
///
/// Serialization happens before the output is opened, so a serialization
/// failure never touches the target. After the bytes are written the
/// writer is flushed explicitly: dropping an async writer does not flush
/// it, so without this step buffered data could be lost and late write
/// errors would go unnoticed.
///
/// # Errors
///
/// * `Stage::ResolveOutput` — no usable format could be determined.
/// * `Stage::Serialize` — serialization, writing or flushing failed.
/// * `Stage::Open` — the output could not be opened.
async fn write_one<T>(&self, spec: &AsyncOutputSpec, values: &[T]) -> Result<(), SingleIoError>
where
    T: Serialize + Sync,
{
    // `&values` (a reference to the slice reference) is passed on purpose:
    // it gives the serializer a sized `Serialize` value.
    let bytes = if let Some(sync_registry) = &self.sync_registry {
        match sync_registry.serialize_value(
            spec.explicit_format.as_ref(),
            &spec.format_candidates,
            &values,
        ) {
            Ok(bytes) => bytes,
            Err(e) => {
                let stage = match e {
                    format::FormatError::UnknownFormat(_)
                    | format::FormatError::NoFormatMatched
                    | format::FormatError::NotEnabled(_) => Stage::ResolveOutput,
                    _ => Stage::Serialize,
                };
                return Err(SingleIoError {
                    stage,
                    target: spec.raw.clone(),
                    error: Box::new(e),
                });
            }
        }
    } else {
        let kind = self.resolve_output_kind(spec)?;
        format::serialize_async(kind, &values)
            .await
            .map_err(|e| SingleIoError {
                stage: Stage::Serialize,
                target: spec.raw.clone(),
                error: Box::new(e),
            })?
    };

    let mut writer = self.open_output(spec).await?;
    tokio::io::AsyncWriteExt::write_all(&mut *writer, &bytes)
        .await
        .map_err(|e| SingleIoError {
            stage: Stage::Serialize,
            target: spec.raw.clone(),
            error: Box::new(e),
        })?;
    // Push any buffered bytes to the target before the writer is dropped.
    tokio::io::AsyncWriteExt::flush(&mut *writer)
        .await
        .map_err(|e| SingleIoError {
            stage: Stage::Serialize,
            target: spec.raw.clone(),
            error: Box::new(e),
        })
}
/// Serializes the single `value` and writes the bytes to the output
/// described by `spec`.
///
/// Serialization happens before the output is opened, so a serialization
/// failure never touches the target. After the bytes are written the
/// writer is flushed explicitly: dropping an async writer does not flush
/// it, so without this step buffered data could be lost and late write
/// errors would go unnoticed.
///
/// # Errors
///
/// * `Stage::ResolveOutput` — no usable format could be determined.
/// * `Stage::Serialize` — serialization, writing or flushing failed.
/// * `Stage::Open` — the output could not be opened.
async fn write_single<T>(&self, spec: &AsyncOutputSpec, value: &T) -> Result<(), SingleIoError>
where
    T: Serialize + Sync,
{
    let bytes = if let Some(sync_registry) = &self.sync_registry {
        match sync_registry.serialize_value(
            spec.explicit_format.as_ref(),
            &spec.format_candidates,
            value,
        ) {
            Ok(bytes) => bytes,
            Err(e) => {
                let stage = match e {
                    format::FormatError::UnknownFormat(_)
                    | format::FormatError::NoFormatMatched
                    | format::FormatError::NotEnabled(_) => Stage::ResolveOutput,
                    _ => Stage::Serialize,
                };
                return Err(SingleIoError {
                    stage,
                    target: spec.raw.clone(),
                    error: Box::new(e),
                });
            }
        }
    } else {
        // Same resolution helper as the slice-writing path.
        let kind = self.resolve_output_kind(spec)?;
        format::serialize_async(kind, value)
            .await
            .map_err(|e| SingleIoError {
                stage: Stage::Serialize,
                target: spec.raw.clone(),
                error: Box::new(e),
            })?
    };

    let mut writer = self.open_output(spec).await?;
    tokio::io::AsyncWriteExt::write_all(&mut *writer, &bytes)
        .await
        .map_err(|e| SingleIoError {
            stage: Stage::Serialize,
            target: spec.raw.clone(),
            error: Box::new(e),
        })?;
    // Push any buffered bytes to the target before the writer is dropped.
    tokio::io::AsyncWriteExt::flush(&mut *writer)
        .await
        .map_err(|e| SingleIoError {
            stage: Stage::Serialize,
            target: spec.raw.clone(),
            error: Box::new(e),
        })
}
async fn open_output(
&self,
spec: &AsyncOutputSpec,
) -> Result<Box<dyn tokio::io::AsyncWrite + Unpin + Send>, SingleIoError> {
let result = match spec.file_exists_policy {
FileExistsPolicy::Overwrite => spec.target.open_overwrite().await,
FileExistsPolicy::Append => spec.target.open_append().await,
FileExistsPolicy::Error => spec.target.open_overwrite().await,
};
result.map_err(|e| SingleIoError {
stage: Stage::Open,
target: spec.raw.clone(),
error: Box::new(e),
})
}
/// Resolves the format kind for `spec` through the asynchronous registry.
///
/// # Errors
///
/// A resolution failure is reported with `Stage::ResolveOutput`.
fn resolve_output_kind(&self, spec: &AsyncOutputSpec) -> Result<FormatKind, SingleIoError> {
    let explicit = spec.explicit_format.as_ref();
    match self.registry.resolve(explicit, &spec.format_candidates) {
        Ok(kind) => Ok(kind),
        Err(cause) => Err(SingleIoError {
            stage: Stage::ResolveOutput,
            target: spec.raw.clone(),
            error: Box::new(cause),
        }),
    }
}
async fn records_stream_for_spec_async<'a, T>(
&'a self,
spec: &'a AsyncInputSpec,
) -> BoxStream<'a, Result<T, SingleIoError>>
where
T: DeserializeOwned + Send + 'static,
{
let mut reader = match spec.provider.open().await {
Ok(r) => r,
Err(e) => {
let err = SingleIoError {
stage: Stage::Open,
target: spec.raw.clone(),
error: Box::new(e),
};
return stream::iter(std::iter::once(Err(err))).boxed();
}
};
let mut buffer = Vec::new();
if let Err(e) = reader.read_to_end(&mut buffer).await {
let err = SingleIoError {
stage: Stage::Open,
target: spec.raw.clone(),
error: Box::new(e),
};
return stream::iter(std::iter::once(Err(err))).boxed();
}
let kind = if let Some(sync_registry) = &self.sync_registry {
match sync_registry.resolve(spec.explicit_format.as_ref(), &spec.format_candidates) {
Ok(k) => k,
Err(e) => {
let err = SingleIoError {
stage: Stage::Parse,
target: spec.raw.clone(),
error: Box::new(e),
};
return stream::iter(std::iter::once(Err(err))).boxed();
}
}
} else {
match self
.registry
.resolve(spec.explicit_format.as_ref(), &spec.format_candidates)
{
Ok(k) => k,
Err(e) => {
let err = SingleIoError {
stage: Stage::Parse,
target: spec.raw.clone(),
error: Box::new(e),
};
return stream::iter(std::iter::once(Err(err))).boxed();
}
}
};
let target = spec.raw.clone();
if let FormatKind::Json = kind {
#[cfg(feature = "json")]
{
let reader = std::io::Cursor::new(buffer);
let iter = crate::format::deserialize_json_stream::<T, _>(reader);
return Self::iter_to_stream(iter, target);
}
#[cfg(not(feature = "json"))]
{
let err = SingleIoError {
stage: Stage::Parse,
target,
error: Box::new(crate::format::FormatError::NotEnabled(kind)),
};
return stream::iter(std::iter::once(Err(err))).boxed();
}
}
if let (Some(sync_registry), FormatKind::Custom(_)) = (&self.sync_registry, kind) {
use std::io::Cursor;
let reader: Box<dyn std::io::Read> = Box::new(Cursor::new(buffer));
let iter_result = sync_registry.stream_deserialize_into::<T>(Some(&kind), &[], reader);
let target_for_stream = target.clone();
let iter = match iter_result {
Ok(iter) => iter,
Err(e) => {
let err = SingleIoError {
stage: Stage::Parse,
target,
error: Box::new(e),
};
return stream::iter(std::iter::once(Err(err))).boxed();
}
};
let collected: Vec<Result<T, format::FormatError>> = iter.collect();
return Self::iter_to_stream(collected.into_iter(), target_for_stream);
}
if let FormatKind::Csv = kind {
#[cfg(feature = "csv")]
{
let reader = std::io::Cursor::new(buffer);
let iter = crate::format::deserialize_csv_stream::<T, _>(reader);
return Self::iter_to_stream(iter, target);
}
#[cfg(not(feature = "csv"))]
{
let err = SingleIoError {
stage: Stage::Parse,
target,
error: Box::new(crate::format::FormatError::NotEnabled(kind)),
};
return stream::iter(std::iter::once(Err(err))).boxed();
}
}
if let FormatKind::Yaml = kind {
#[cfg(feature = "yaml")]
{
let reader = std::io::Cursor::new(buffer);
let iter = crate::format::deserialize_yaml_stream::<T, _>(reader);
let collected: Vec<_> = iter.collect();
return Self::iter_to_stream(collected.into_iter(), target);
}
#[cfg(not(feature = "yaml"))]
{
let err = SingleIoError {
stage: Stage::Parse,
target,
error: Box::new(crate::format::FormatError::NotEnabled(kind)),
};
return stream::iter(std::iter::once(Err(err))).boxed();
}
}
if let FormatKind::Plaintext = kind {
#[cfg(feature = "plaintext")]
{
let reader = std::io::Cursor::new(buffer);
let iter = crate::format::deserialize_plaintext_stream::<T, _>(reader);
return Self::iter_to_stream(iter, target);
}
#[cfg(not(feature = "plaintext"))]
{
let err = SingleIoError {
stage: Stage::Parse,
target,
error: Box::new(crate::format::FormatError::NotEnabled(kind)),
};
return stream::iter(std::iter::once(Err(err))).boxed();
}
}
let value = format::deserialize_async::<T>(kind, &buffer).await;
let result = value.map_err(|e| SingleIoError {
stage: Stage::Parse,
target,
error: Box::new(e),
});
stream::iter(std::iter::once(result)).boxed()
}
/// Adapts an iterator of decoding results into a boxed stream.
///
/// Successful items pass through unchanged; every `FormatError` is wrapped
/// in a [`SingleIoError`] tagged with `Stage::Parse` and the given
/// `target`.
fn iter_to_stream<T, I>(iter: I, target: String) -> BoxStream<'static, Result<T, SingleIoError>>
where
    T: DeserializeOwned + Send + 'static,
    I: Iterator<Item = Result<T, format::FormatError>> + Send + 'static,
{
    let records = iter.map(move |item| match item {
        Ok(record) => Ok(record),
        Err(cause) => Err(SingleIoError {
            stage: Stage::Parse,
            // The target name is cloned per error because the closure may
            // be called many times.
            target: target.clone(),
            error: Box::new(cause),
        }),
    });
    stream::iter(records).boxed()
}
/// Streams one value of type `T` per input, reading up to `concurrency`
/// inputs at the same time.
///
/// Results are yielded as the reads complete, so they are not in input
/// order. A `concurrency` of `0` is treated as `1`: a zero limit passed to
/// `buffer_unordered` never pulls a future from the source, which would
/// leave the returned stream pending forever.
///
/// Failures are reported in-band as `Err` items; the engine's error policy
/// is not consulted here.
pub fn read_stream_async<T>(
    &self,
    concurrency: usize,
) -> BoxStream<'_, Result<T, SingleIoError>>
where
    T: DeserializeOwned + Send + 'static,
{
    // Guard against a zero limit, which would stall the stream.
    let limit = concurrency.max(1);
    let futs = self.inputs.iter().map(|spec| self.read_one::<T>(spec));
    stream::iter(futs).buffer_unordered(limit).boxed()
}
}