edge-transformers 0.1.2

A Rust wrapper over ONNX Runtime that implements Hugging Face's Optimum pipelines for inference and generates bindings for C# and C.
Documentation
use std::ffi::CString;
use std::path::Path;

use interoptopus::patterns::slice::FFISlice;
use interoptopus::patterns::string::AsciiPointer;
use interoptopus::{ffi_service, ffi_service_ctor, ffi_service_method, ffi_type};

use crate::error::Result;
use crate::ffi::{
    error::FFIError, DeviceFFI, EnvContainer, GraphOptimizationLevelFFI, StringBatch,
};
use crate::{ClassPrediction, TaggedString, TokenClassPrediction, TokenClassificationPipeline};

/// FFI-safe mirror of [`TaggedString`]: the input string plus the token-level
/// predictions made for it.
///
/// Both fields are borrowed views: `input_string` points at a nul-terminated
/// byte buffer and `tags` at a slice owned elsewhere (the service struct keeps
/// backing buffers alive between calls), so a value is only valid while its
/// backing storage is.
#[repr(C)]
#[ffi_type]
pub struct TaggedStringFFI<'a> {
    pub input_string: AsciiPointer<'a>,
    pub tags: FFISlice<'a, TokenClassPredictionFFI<'a>>,
}

impl<'a> TaggedStringFFI<'a> {
    /// Builds an FFI view over `tagged_string`, with `tags` borrowing the
    /// caller-owned `tags_ffi_buf` (the pipeline keeps that buffer in
    /// `tags_buf_ffi` precisely so it outlives the returned value).
    ///
    /// Accepts any slice, not just `&Vec` — `&Vec<T>` callers still work via
    /// deref coercion.
    ///
    /// # Panics
    /// Panics if `input_string` contains an interior nul byte or cannot be
    /// represented as an `AsciiPointer`.
    ///
    /// NOTE(review): the `CString` built below is a temporary dropped at the
    /// end of the `let` statement, yet `input_string` keeps pointing at its
    /// bytes — confirm whether `AsciiPointer::from_slice_with_nul` copies the
    /// data; if not, this hands C callers a dangling pointer and the
    /// nul-terminated bytes should be stored in a caller-owned buffer instead.
    pub fn from_tagged_string(
        tagged_string: &'a TaggedString,
        tags_ffi_buf: &'a [TokenClassPredictionFFI<'a>],
    ) -> Self {
        let input_string = AsciiPointer::from_slice_with_nul(
            CString::new(tagged_string.input_string.as_bytes())
                .unwrap()
                .as_bytes_with_nul(),
        )
        .expect("Failed to convert input string to AsciiPointer");
        let tags = FFISlice::from_slice(tags_ffi_buf);
        Self { input_string, tags }
    }
}

impl Default for TaggedStringFFI<'_> {
    /// Empty value handed back by service methods that panic
    /// (`on_panic = "return_default"`).
    fn default() -> Self {
        Self {
            // `from_slice_with_nul` expects the buffer to include its nul
            // terminator; `b""` has none, so use `b"\0"`. This also gives C
            // callers a valid empty C string (the buffer is 'static, so the
            // pointer can never dangle) instead of a zero-length buffer.
            input_string: AsciiPointer::from_slice_with_nul(b"\0").unwrap(),
            tags: FFISlice::from_slice(&[]),
        }
    }
}

/// FFI-safe mirror of [`TokenClassPrediction`]: the best class for one token,
/// all per-class predictions, and the token's `start`/`end` offsets in the
/// input (cast to `u32` when constructed).
///
/// `all` borrows a slice owned by the service struct's backing buffers.
#[repr(C)]
#[ffi_type]
pub struct TokenClassPredictionFFI<'a> {
    pub best: ClassPredictionFFI<'a>,
    pub all: FFISlice<'a, ClassPredictionFFI<'a>>,
    pub start: u32,
    pub end: u32,
}

impl<'a> TokenClassPredictionFFI<'a> {
    /// Builds an FFI view over one token prediction; `all` borrows the
    /// caller-owned `prediction_all_ffi`, which must outlive the returned
    /// value (the pipeline keeps it in `prediction_all_buf_ffi` for this).
    ///
    /// Accepts any slice, not just `&Vec` — `&Vec<T>` callers still work via
    /// deref coercion.
    pub fn from_token_class_prediction(
        token_class_prediction: &'a TokenClassPrediction,
        prediction_all_ffi: &'a [ClassPredictionFFI<'a>],
    ) -> Self {
        let best = ClassPredictionFFI::from(&token_class_prediction.best);
        let all = FFISlice::from_slice(prediction_all_ffi);
        Self {
            best,
            all,
            // Offsets into the input; `as u32` silently truncates values above
            // u32::MAX — assumes inputs are far shorter than 4 GiB, TODO
            // confirm that invariant holds for the pipeline.
            start: token_class_prediction.start as u32,
            end: token_class_prediction.end as u32,
        }
    }
}

/// FFI-safe mirror of [`ClassPrediction`]: a class label (nul-terminated
/// ASCII pointer) and its confidence score.
#[repr(C)]
#[ffi_type]
pub struct ClassPredictionFFI<'a> {
    pub label: AsciiPointer<'a>,
    pub score: f32,
}

/// Converts a [`ClassPrediction`] into its FFI representation.
///
/// NOTE(review): `v` is a local `Vec` dropped when this function returns,
/// while the returned `label` (whose lifetime `'a` is unconstrained by the
/// input) still points into it — this looks like a use-after-free handed to
/// the C caller. Confirm whether `AsciiPointer::from_slice_with_nul` copies
/// the bytes; if not, the nul-terminated label bytes need to live in a
/// caller-owned buffer, as is done for the tag/prediction buffers.
impl<'a> From<&ClassPrediction> for ClassPredictionFFI<'a> {
    fn from(class_prediction: &ClassPrediction) -> Self {
        // Owned, nul-terminated copy of the label. Panics on interior nul.
        let v = Vec::from(
            CString::new(class_prediction.label.as_str())
                .unwrap()
                .to_bytes_with_nul(),
        );
        Self {
            label: AsciiPointer::from_slice_with_nul(v.as_ref())
                .expect("Failed to convert CString to AsciiPointer"),
            score: class_prediction.score,
        }
    }
}

/// Opaque FFI service wrapping a [`TokenClassificationPipeline`].
///
/// The `*_buf*` fields own the data that the `FFISlice`s returned by
/// `tag`/`tag_batch` point into. They are overwritten on every call, so any
/// previously returned result is only valid until the next call on the same
/// instance.
#[ffi_type(opaque)]
pub struct TokenClassificationPipelineFFI<'a> {
    pub model: TokenClassificationPipeline<'a>,
    // Raw pipeline outputs; all FFI buffers below borrow from these.
    output_buf: Vec<TaggedString>,
    // FFI views returned from `tag_batch` (one per input string).
    output_buf_ffi: Vec<TaggedStringFFI<'a>>,
    // Per-string token predictions backing `TaggedStringFFI::tags`.
    tags_buf_ffi: Vec<Vec<TokenClassPredictionFFI<'a>>>,
    // Per-string, per-token class predictions backing
    // `TokenClassPredictionFFI::all`.
    prediction_all_buf_ffi: Vec<Vec<Vec<ClassPredictionFFI<'a>>>>,
}

/// FFI service for token classification; interoptopus generates C/C# entry
/// points prefixed `onnx_token_classification_`. Errors map to `FFIError`.
#[ffi_service(error = "FFIError", prefix = "onnx_token_classification_")]
impl<'a> TokenClassificationPipelineFFI<'a> {
    /// Constructs the pipeline from a model id string — presumably a
    /// Hugging Face hub name resolved by
    /// `TokenClassificationPipeline::from_pretrained` (TODO confirm there).
    #[ffi_service_ctor]
    pub fn from_pretrained(
        env: &'a EnvContainer,
        model_id: AsciiPointer<'a>,
        device: DeviceFFI,
        optimization: GraphOptimizationLevelFFI,
    ) -> Result<Self> {
        let model = TokenClassificationPipeline::from_pretrained(
            env.env.clone(),
            model_id.as_c_str().unwrap().to_string_lossy().to_string(),
            device.into(),
            optimization.into(),
        )?;
        // Buffers start empty; they are (re)filled by `tag`/`tag_batch`.
        Ok(Self {
            model,
            output_buf: Vec::new(),
            output_buf_ffi: Vec::new(),
            tags_buf_ffi: Vec::new(),
            prediction_all_buf_ffi: Vec::new(),
        })
    }

    /// Constructs the pipeline from local model/tokenizer file paths.
    #[ffi_service_ctor]
    pub fn create_from_files(
        env: &'a EnvContainer,
        model_path: AsciiPointer<'a>,
        tokenizer_config_path: AsciiPointer<'a>,
        special_tokens_map_path: AsciiPointer<'a>,
        device: DeviceFFI,
        optimization: GraphOptimizationLevelFFI,
    ) -> Result<Self> {
        let model = TokenClassificationPipeline::new_from_files(
            env.env.clone(),
            Path::new(model_path.as_str().unwrap()).to_path_buf(),
            Path::new(tokenizer_config_path.as_str().unwrap()).to_path_buf(),
            Path::new(special_tokens_map_path.as_str().unwrap()).to_path_buf(),
            device.into(),
            optimization.into(),
            None,
        )?;
        Ok(Self {
            model,
            output_buf: Vec::new(),
            output_buf_ffi: Vec::new(),
            tags_buf_ffi: Vec::new(),
            prediction_all_buf_ffi: Vec::new(),
        })
    }

    /// Tags a single input string.
    ///
    /// The returned value borrows buffers stored on `s`, so it is only valid
    /// until the next `tag`/`tag_batch` call on this instance. Any panic
    /// (e.g. the `unwrap`s below on invalid input or model failure) is turned
    /// into `TaggedStringFFI::default()` by `on_panic = "return_default"`.
    #[ffi_service_method(on_panic = "return_default")]
    pub fn tag(
        s: &'a mut TokenClassificationPipelineFFI<'a>,
        input: AsciiPointer,
    ) -> TaggedStringFFI<'a> {
        let input_str = input.as_c_str().unwrap().to_string_lossy().to_string();
        let output = s.model.tag(&input_str).unwrap();
        // Stash the raw output on `self` so the FFI views below can borrow it.
        s.output_buf = vec![output];
        // Build innermost buffers first (per-token class predictions)...
        s.prediction_all_buf_ffi = vec![s.output_buf[0]
            .tags
            .iter()
            .map(|x| x.all.iter().map(|x| ClassPredictionFFI::from(x)).collect())
            .collect()];
        // ...then the token predictions that point into them.
        s.tags_buf_ffi = vec![s.output_buf[0]
            .tags
            .iter()
            .zip(s.prediction_all_buf_ffi.last().unwrap().iter())
            .map(|(x, y)| TokenClassPredictionFFI::from_token_class_prediction(x, y))
            .collect()];
        TaggedStringFFI::from_tagged_string(&s.output_buf[0], s.tags_buf_ffi.last().unwrap())
    }

    /// Tags a batch of input strings; same buffer-lifetime and panic-to-default
    /// behavior as `tag`. The returned slice points into `s.output_buf_ffi`.
    #[ffi_service_method(on_panic = "return_default")]
    pub fn tag_batch(
        s: &'a mut TokenClassificationPipelineFFI<'a>,
        input: StringBatch,
    ) -> FFISlice<'a, TaggedStringFFI<'a>> {
        let output = s.model.tag_batch(input.batch).unwrap();
        s.output_buf = output;
        // Build buffers innermost-first so each level can borrow the previous:
        // class predictions -> token predictions -> tagged strings.
        s.prediction_all_buf_ffi = s
            .output_buf
            .iter()
            .map(|x| {
                x.tags
                    .iter()
                    .map(|x| x.all.iter().map(|x| x.into()).collect())
                    .collect()
            })
            .collect();
        s.tags_buf_ffi = s
            .output_buf
            .iter()
            .zip(s.prediction_all_buf_ffi.iter())
            .map(|(x, y)| {
                x.tags
                    .iter()
                    .zip(y.iter())
                    .map(|(x, y)| TokenClassPredictionFFI::from_token_class_prediction(x, y))
                    .collect()
            })
            .collect();
        s.output_buf_ffi = s
            .output_buf
            .iter()
            .zip(s.tags_buf_ffi.iter())
            .map(|(x, y)| TaggedStringFFI::from_tagged_string(x, y))
            .collect();
        FFISlice::from_slice(s.output_buf_ffi.as_slice())
    }
}