burn-central-inference 0.6.0

Inference contracts and adapters for the Burn Central SDK.
Documentation
use crate::writer::InferenceWriterError;
use crate::{Inference, InferenceWriter, InferenceWriterChannel};
use serde::{Serialize, de::DeserializeOwned};
use std::error::Error;
use std::fmt;
use std::marker::PhantomData;
use std::time::Duration;

/// Type-erased inference entry point that works on raw bytes.
///
/// The trait carries no generic parameters: the typed input arrives as a byte
/// slice and all output goes through a boxed [`ErasedInferenceWriter`].
/// The byte encoding is chosen by the implementor; `JsonInference` in this
/// module uses JSON.
pub trait ErasedInference: Send + Sync {
    /// Runs one inference over the encoded `input`, reporting through `writer`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the call cannot be carried out.
    /// The exact failure conditions are defined by the implementor.
    fn infer_bytes(
        &self,
        input: &[u8],
        writer: Box<dyn ErasedInferenceWriter>,
    ) -> Result<(), String>;
}

/// Type-erased sink for the results of an [`ErasedInference`] call.
///
/// Outputs are passed as already-encoded bytes and errors as plain strings,
/// so implementors never see the concrete output type.
pub trait ErasedInferenceWriter: Send + Sync {
    /// Accepts one encoded output value.
    ///
    /// Returns `Err` with a message if the sink could not accept it.
    fn write_bytes(&self, output: Vec<u8>) -> Result<(), String>;
    /// Accepts an error message produced during inference.
    ///
    /// Returns `Err` with a message if the sink could not accept it.
    fn error(&self, error: String) -> Result<(), String>;
    /// Signals completion.
    // NOTE(review): `duration` is presumably the elapsed inference time; confirm against the caller.
    fn finish(&self, duration: Duration);
}

/// Adapter that pairs an inner value `T` with the input type `I` and output
/// type `O` it is meant to be used with.
///
/// `I` and `O` are recorded only at the type level; no value of either type
/// is stored.
pub struct JsonInference<T, I, O> {
    /// The wrapped value.
    inner: T,
    /// Zero-sized marker tying `I` and `O` to this adapter. Using a function
    /// pointer type means the marker does not own an `I` or an `O`.
    _types: PhantomData<fn(I, O)>,
}

impl<T, I, O> JsonInference<T, I, O> {
    /// Wraps `inner` in a new adapter.
    pub fn new(inner: T) -> Self {
        let _types = PhantomData;
        JsonInference { inner, _types }
    }
}

/// Byte-level inference for any `Inference` whose input can be read from JSON
/// and whose output can be written as JSON.
impl<T, I, O> ErasedInference for JsonInference<T, I, O>
where
    T: Inference<Input = I, Output = O> + Send + Sync + 'static,
    I: DeserializeOwned + Send + 'static,
    O: Serialize + Send + 'static,
{
    /// Decodes `input` as JSON into `I`, then runs the wrapped inference with a
    /// typed writer that forwards to `writer`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error text when `input` cannot be decoded as
    /// `I`. In that case `writer` is dropped without any method being called
    /// on it. Once decoding succeeds this method always returns `Ok(())`.
    fn infer_bytes(
        &self,
        input: &[u8],
        writer: Box<dyn ErasedInferenceWriter>,
    ) -> Result<(), String> {
        let request = match serde_json::from_slice::<I>(input) {
            Ok(decoded) => decoded,
            Err(cause) => return Err(cause.to_string()),
        };

        // Bridge the typed writer expected by `Inference::infer` onto the
        // byte-oriented writer supplied by the caller.
        let typed_writer = InferenceWriter::new(Box::new(JsonInferenceWriterChannel::<O> {
            writer,
            _types: PhantomData,
        }));

        self.inner.infer(request, typed_writer);
        Ok(())
    }
}

/// Typed writer channel for outputs of type `O` that forwards everything to a
/// type-erased [`ErasedInferenceWriter`].
struct JsonInferenceWriterChannel<O> {
    /// Destination for encoded outputs, error messages and the finish signal.
    writer: Box<dyn ErasedInferenceWriter>,
    /// Zero-sized marker for the output type; no `O` value is stored.
    _types: PhantomData<fn(O)>,
}

/// Minimal error type wrapping a plain message, so a `String` can be boxed as
/// a `std::error::Error`.
#[derive(Debug)]
struct StringError(String);

impl fmt::Display for StringError {
    /// Displays exactly the wrapped message, honouring any width, fill or
    /// precision flags set on the formatter.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let StringError(message) = self;
        fmt::Display::fmt(message, formatter)
    }
}

// No `source`: the message is the whole error.
impl Error for StringError {}

/// Forwards typed writer calls to the erased writer, encoding outputs as JSON.
impl<O> InferenceWriterChannel<O> for JsonInferenceWriterChannel<O>
where
    O: Serialize,
{
    /// Encodes `output` as JSON bytes and hands them to the erased writer.
    ///
    /// # Errors
    ///
    /// Returns `InferenceWriterError::Unknown` wrapping either the
    /// `serde_json` encoding error or the message the erased writer rejected
    /// the bytes with.
    fn write(&self, output: O) -> Result<(), InferenceWriterError> {
        let payload = match serde_json::to_vec(&output) {
            Ok(encoded) => encoded,
            Err(cause) => return Err(InferenceWriterError::Unknown(Box::new(cause))),
        };

        match self.writer.write_bytes(payload) {
            Ok(()) => Ok(()),
            Err(message) => Err(InferenceWriterError::Unknown(Box::new(StringError(message)))),
        }
    }

    /// Forwards the display text of `error` to the erased writer; the error
    /// value itself is not passed on.
    ///
    /// # Errors
    ///
    /// Returns `InferenceWriterError::Unknown` wrapping the message the erased
    /// writer failed with.
    fn error(&self, error: Box<dyn Error + Send + Sync>) -> Result<(), InferenceWriterError> {
        let description = error.to_string();
        match self.writer.error(description) {
            Ok(()) => Ok(()),
            Err(message) => Err(InferenceWriterError::Unknown(Box::new(StringError(message)))),
        }
    }

    /// Passes the finish signal and its `duration` straight through.
    fn finish(&self, duration: Duration) {
        let sink = &self.writer;
        sink.finish(duration);
    }
}