use std::sync::Arc;
use anyhow::Context;
use burn::{
data::dataloader::batcher::Batcher,
prelude::Backend,
};
use crate::core::{
FirehoseRowBatch,
operations::executor::FirehoseBatchExecutor,
};
/// Converts the raw items handed to a batcher into a [`FirehoseRowBatch`].
///
/// Implementors must be `Send + Sync` so they can be shared behind an `Arc`
/// (as [`FirehoseExecutorBatcher`] does).
pub trait BatcherInputAdapter<I: Send + Sync + Clone + std::fmt::Debug + 'static>:
    Send + Sync
{
    /// Builds one row batch from all of `inputs`, taking ownership of them.
    ///
    /// # Errors
    /// Returns an error when the implementation cannot build a batch from `inputs`.
    fn apply(&self, inputs: Vec<I>) -> anyhow::Result<FirehoseRowBatch>;
}
/// Turns a [`FirehoseRowBatch`] into the batcher's output type `O` on a given
/// burn device.
///
/// In this module it is invoked after the executor has processed the batch.
/// Implementors must be `Send + Sync` so they can be shared behind an `Arc`.
pub trait BatcherOutputAdapter<B: Backend, O: Send + Clone + std::fmt::Debug + 'static>:
    Send + Sync
{
    /// Produces the output for `batch`, using `device` for any backend-specific work.
    ///
    /// # Errors
    /// Returns an error when the implementation cannot convert `batch`.
    fn apply(&self, batch: &FirehoseRowBatch, device: &B::Device) -> anyhow::Result<O>;
}
/// A burn `Batcher` that delegates the actual work to a [`FirehoseBatchExecutor`].
///
/// Each batch flows through three stages: `input_adapter` builds a
/// [`FirehoseRowBatch`] from the items, `executor` runs over that batch, and
/// `output_adapter` converts the result into `O`.
///
/// The trait bounds mirror those required by the adapter traits; they are
/// needed here for the `dyn` field types to be well-formed.
pub struct FirehoseExecutorBatcher<
    B: Backend,
    I: Send + Sync + Clone + std::fmt::Debug + 'static,
    O: Send + Clone + std::fmt::Debug + 'static,
> {
    /// Stage 2: runs over the assembled row batch.
    executor: Arc<dyn FirehoseBatchExecutor>,
    /// Stage 1: converts raw items into a row batch.
    input_adapter: Arc<dyn BatcherInputAdapter<I>>,
    /// Stage 3: converts the processed row batch into the output type.
    output_adapter: Arc<dyn BatcherOutputAdapter<B, O>>,
}
impl<B, I, O> FirehoseExecutorBatcher<B, I, O>
where
    B: Backend,
    I: Send + Sync + Clone + std::fmt::Debug + 'static,
    O: Send + Clone + std::fmt::Debug + 'static,
{
    /// Creates a batcher from its three pipeline stages.
    ///
    /// For every batch, `input_adapter` runs first, then `executor`, then
    /// `output_adapter`.
    pub fn new(
        executor: Arc<dyn FirehoseBatchExecutor>,
        input_adapter: Arc<dyn BatcherInputAdapter<I>>,
        output_adapter: Arc<dyn BatcherOutputAdapter<B, O>>,
    ) -> Self {
        Self {
            executor,
            input_adapter,
            output_adapter,
        }
    }

    /// Fallible core of batching: adapt `items`, execute the batch, adapt the result.
    ///
    /// # Errors
    /// Returns the first stage error encountered. Each stage wraps its error
    /// with its own context message, so the report identifies which stage failed
    /// (the input adapter, the executor, or the output adapter).
    fn batch_result(&self, items: Vec<I>, device: &B::Device) -> anyhow::Result<O> {
        let mut batch = self
            .input_adapter
            .apply(items)
            .context("Failed to run batch input adapter")?;

        // The executor works on the batch in place (it takes `&mut`), so the
        // same `batch` value is what the output adapter reads afterwards.
        self.executor
            .execute_batch(&mut batch)
            .context("Failed to execute batch")?;

        self.output_adapter
            .apply(&batch, device)
            .context("Failed to run batch output adapter")
    }
}
impl<B, I, O> Batcher<B, I, O> for FirehoseExecutorBatcher<B, I, O>
where
    B: Backend,
    I: Send + Sync + Clone + std::fmt::Debug + 'static,
    O: Send + Clone + std::fmt::Debug + 'static,
{
    /// Builds one batch of type `O` from `items` on `device`.
    ///
    /// # Panics
    /// `Batcher::batch` returns `O` rather than a `Result`, so any failure in
    /// the adapter/executor pipeline panics. The panic message is followed by
    /// the underlying error's `Debug` output.
    fn batch(&self, items: Vec<I>, device: &B::Device) -> O {
        // Stage-neutral message: the inner error context already says which
        // stage failed, so this must not claim it was the executor.
        self.batch_result(items, device)
            .expect("FirehoseExecutorBatcher failed to build a batch")
    }
}