// edge-transformers 0.1.2
//
// A Rust wrapper over ONNX Runtime that implements Hugging Face's Optimum pipelines for inference and generates bindings for C# and C.
// Documentation
use std::ffi::CString;
use std::path::Path;

use interoptopus::patterns::slice::FFISlice;
use interoptopus::patterns::string::AsciiPointer;
use interoptopus::{ffi_service, ffi_service_ctor, ffi_service_method, ffi_type};

use crate::error::Result;
use crate::ffi::{
    error::FFIError, DeviceFFI, EnvContainer, GraphOptimizationLevelFFI, StringBatch,
};
use crate::{ClassPrediction, Prediction, SequenceClassificationPipeline};

/// C-compatible view of a [`Prediction`] handed across the FFI boundary.
///
/// The struct owns nothing: both fields carry the lifetime `'a` of data
/// stored elsewhere (see [`PredictionFFI::from_prediction`]).
#[repr(C)]
#[ffi_type]
pub struct PredictionFFI<'a> {
    /// FFI form of the prediction's `best` entry.
    pub best: ClassPredictionFFI<'a>,
    /// Slice over externally stored per-class predictions.
    pub all: FFISlice<'a, ClassPredictionFFI<'a>>,
}

impl Default for PredictionFFI<'_> {
    /// Builds an "empty" prediction: a default label pointer, a score of
    /// `0.0` and a default (empty) slice of class predictions.
    ///
    /// NOTE(review): presumably the value returned by service methods declared
    /// with `on_panic = "return_default"` — confirm against interoptopus.
    fn default() -> Self {
        let empty_best = ClassPredictionFFI {
            label: AsciiPointer::default(),
            score: 0.0,
        };
        Self {
            best: empty_best,
            all: FFISlice::default(),
        }
    }
}

impl<'a> PredictionFFI<'a> {
    /// Assembles the FFI view of `prediction`.
    ///
    /// * `prediction` – source prediction; its `best` entry is converted with
    ///   `ClassPredictionFFI::from`.
    /// * `prediction_all_ffi` – already converted per-class predictions; the
    ///   returned `all` slice points into this vector, so the vector must
    ///   outlive the returned value (enforced by `'a`).
    pub fn from_prediction(
        prediction: &'a Prediction,
        prediction_all_ffi: &'a Vec<ClassPredictionFFI<'a>>,
    ) -> Self {
        Self {
            best: ClassPredictionFFI::from(&prediction.best),
            all: FFISlice::from_slice(&prediction_all_ffi[..]),
        }
    }
}

/// C-compatible view of a single [`ClassPrediction`]: a label and its score.
#[repr(C)]
#[ffi_type]
pub struct ClassPredictionFFI<'a> {
    /// Pointer to the NUL-terminated label text.
    pub label: AsciiPointer<'a>,
    /// Score copied from the source `ClassPrediction`.
    pub score: f32,
}

impl<'a> From<&ClassPrediction> for ClassPredictionFFI<'a> {
    /// Converts a [`ClassPrediction`] into its FFI representation.
    ///
    /// The returned `label` pointer must remain valid after this function
    /// returns, because it is handed to foreign code. A buffer local to this
    /// function would be freed on return and leave the pointer dangling, so
    /// every distinct label is interned once in a process-wide table and
    /// intentionally never freed. Memory use is bounded by the number of
    /// distinct labels seen.
    ///
    /// # Panics
    ///
    /// Panics if the label contains an interior NUL byte.
    fn from(class_prediction: &ClassPrediction) -> Self {
        // Process-wide intern table: label text -> leaked NUL-terminated copy.
        static INTERNED_LABELS: std::sync::Mutex<
            std::collections::BTreeMap<String, &'static std::ffi::CStr>,
        > = std::sync::Mutex::new(std::collections::BTreeMap::new());

        let label_text = class_prediction.label.as_str();

        // A poisoned lock only means another thread panicked while holding it;
        // entries are inserted atomically, so the table is still usable.
        let mut table = INTERNED_LABELS
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        let interned: &'static std::ffi::CStr = match table.get(label_text).copied() {
            Some(existing) => existing,
            None => {
                let owned = CString::new(label_text)
                    .expect("class label must not contain an interior NUL byte");
                // Leak on purpose: the foreign caller may read the label at
                // any time after we return.
                let leaked: &'static std::ffi::CStr = Box::leak(owned.into_boxed_c_str());
                table.insert(label_text.to_owned(), leaked);
                leaked
            }
        };

        Self {
            label: AsciiPointer::from_slice_with_nul(interned.to_bytes_with_nul())
                .expect("Failed to convert CString to AsciiPointer"),
            score: class_prediction.score,
        }
    }
}

/// Opaque FFI handle around a [`SequenceClassificationPipeline`].
///
/// Besides the pipeline itself it keeps the buffers that back the values
/// returned by `classify` / `classify_batch`, so the pointers handed to the
/// caller refer to memory owned by this object.
#[ffi_type(opaque)]
pub struct SequenceClassificationPipelineFFI<'a> {
    /// The wrapped pipeline.
    pub model: SequenceClassificationPipeline<'a>,
    /// Predictions produced by the most recent classify call.
    output_buf: Vec<Prediction>,
    /// FFI views over `output_buf`, filled by `classify_batch`.
    output_buf_vec: Vec<PredictionFFI<'a>>,
    /// Backing storage for the `all` slices of returned predictions.
    class_preds_buf_vec: Vec<Vec<ClassPredictionFFI<'a>>>,
}

// FFI service for sequence classification. Exported through interoptopus with
// the `onnx_classification_` prefix; fallible constructors report `FFIError`.
//
// Values returned by `classify` / `classify_batch` point into buffers owned by
// the service object and are replaced by the next classify call.
#[ffi_service(error = "FFIError", prefix = "onnx_classification_")]
impl<'a> SequenceClassificationPipelineFFI<'a> {
    /// Creates a pipeline via `SequenceClassificationPipeline::from_pretrained`.
    ///
    /// * `env` – environment container whose `env` handle is cloned into the pipeline.
    /// * `model_id` – model identifier as a C string (decoded lossily as UTF-8).
    /// * `device`, `optimization` – converted with `into()` and forwarded.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying constructor.
    ///
    /// # Panics
    ///
    /// Panics if `model_id.as_c_str()` yields `None`.
    #[ffi_service_ctor]
    pub fn from_pretrained(
        env: &'a EnvContainer,
        model_id: AsciiPointer<'a>,
        device: DeviceFFI,
        optimization: GraphOptimizationLevelFFI,
    ) -> Result<Self> {
        let model_name = model_id.as_c_str().unwrap().to_string_lossy().into_owned();
        let model = SequenceClassificationPipeline::from_pretrained(
            env.env.clone(),
            model_name,
            device.into(),
            optimization.into(),
        )?;
        Ok(Self {
            model,
            output_buf: Vec::new(),
            output_buf_vec: Vec::new(),
            class_preds_buf_vec: Vec::new(),
        })
    }

    /// Creates a pipeline via `SequenceClassificationPipeline::new_from_files`.
    ///
    /// * `env` – environment container whose `env` handle is cloned into the pipeline.
    /// * `model_path`, `tokenizer_config_path`, `special_tokens_map_path` –
    ///   file paths as C strings (decoded lossily as UTF-8).
    /// * `device`, `optimization` – converted with `into()` and forwarded.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying constructor.
    ///
    /// # Panics
    ///
    /// Panics if `as_c_str()` yields `None` for any of the path arguments.
    #[ffi_service_ctor]
    pub fn create_from_files(
        env: &'a EnvContainer,
        model_path: AsciiPointer<'a>,
        tokenizer_config_path: AsciiPointer<'a>,
        special_tokens_map_path: AsciiPointer<'a>,
        device: DeviceFFI,
        optimization: GraphOptimizationLevelFFI,
    ) -> Result<Self> {
        let model_path =
            Path::new(&*model_path.as_c_str().unwrap().to_string_lossy()).to_path_buf();
        let tokenizer_config_path =
            Path::new(&*tokenizer_config_path.as_c_str().unwrap().to_string_lossy()).to_path_buf();
        let special_tokens_map_path =
            Path::new(&*special_tokens_map_path.as_c_str().unwrap().to_string_lossy())
                .to_path_buf();
        let model = SequenceClassificationPipeline::new_from_files(
            env.env.clone(),
            model_path,
            tokenizer_config_path,
            special_tokens_map_path,
            device.into(),
            optimization.into(),
            // NOTE(review): trailing optional argument left unset — confirm
            // what it configures in `new_from_files`.
            None,
        )?;
        Ok(Self {
            model,
            output_buf: Vec::new(),
            output_buf_vec: Vec::new(),
            class_preds_buf_vec: Vec::new(),
        })
    }

    /// Classifies a single input string.
    ///
    /// The returned value points into buffers owned by `s`; it is only valid
    /// until the next `classify` / `classify_batch` call on the same object.
    ///
    /// # Panics
    ///
    /// Panics if `input.as_c_str()` yields `None` or the pipeline returns an
    /// error; the method is declared with `on_panic = "return_default"`.
    #[ffi_service_method(on_panic = "return_default")]
    pub fn classify(
        s: &'a mut SequenceClassificationPipelineFFI,
        input: AsciiPointer,
    ) -> PredictionFFI<'a> {
        let text = input.as_c_str().unwrap().to_string_lossy();
        let prediction = s.model.classify(&*text).unwrap();
        s.output_buf = vec![prediction];

        // Discard the per-class buffers of earlier calls; otherwise this
        // vector would grow by one entry on every call for the lifetime of
        // the service object.
        s.class_preds_buf_vec.clear();
        s.class_preds_buf_vec.push(
            s.output_buf
                .last()
                .unwrap()
                .all
                .iter()
                .map(|class_pred| ClassPredictionFFI::from(class_pred))
                .collect(),
        );

        PredictionFFI::from_prediction(
            s.output_buf.last().unwrap(),
            s.class_preds_buf_vec.last().unwrap(),
        )
    }

    /// Classifies a batch of input strings.
    ///
    /// The returned slice and everything it references live in buffers owned
    /// by `s`, all of which are replaced on every call.
    ///
    /// # Panics
    ///
    /// Panics if the pipeline returns an error; the method is declared with
    /// `on_panic = "return_default"`.
    #[ffi_service_method(on_panic = "return_default")]
    pub fn classify_batch(
        s: &'a mut SequenceClassificationPipelineFFI<'a>,
        input: StringBatch,
    ) -> FFISlice<'a, PredictionFFI<'a>> {
        s.output_buf = s.model.classify_batch(input.batch).unwrap();

        // One vector of FFI class predictions per prediction, index-aligned
        // with `output_buf`.
        s.class_preds_buf_vec = s
            .output_buf
            .iter()
            .map(|prediction| {
                prediction
                    .all
                    .iter()
                    .map(|class_pred| ClassPredictionFFI::from(class_pred))
                    .collect()
            })
            .collect();

        s.output_buf_vec = s
            .output_buf
            .iter()
            .zip(s.class_preds_buf_vec.iter())
            .map(|(prediction, class_preds)| {
                PredictionFFI::from_prediction(prediction, class_preds)
            })
            .collect();

        FFISlice::from_slice(s.output_buf_vec.as_slice())
    }
}