edge-transformers 0.1.2

A Rust wrapper over ONNXRuntime that implements Huggingface's Optimum pipelines for inference and generates bindings for C# and C.
Documentation
use std::ffi::CString;
use std::path::Path;

use interoptopus::patterns::option::FFIOption;
use interoptopus::patterns::slice::FFISlice;
use interoptopus::patterns::string::AsciiPointer;
use interoptopus::{ffi_service, ffi_service_ctor, ffi_service_method, ffi_type};

use crate::error::Result;
use crate::ffi::{
    error::FFIError, DeviceFFI, EnvContainer, GraphOptimizationLevelFFI, StringBatch,
    UseAsciiStringPattern,
};
use crate::sampling::{ArgmaxSampler, RandomSampler, TopKSampler};
use crate::Seq2SeqGenerationPipeline;

#[ffi_type(opaque)]
pub struct Seq2SeqGenerationPipelineFFI<'a> {
    pub model: Seq2SeqGenerationPipeline<'a>,
    pub output_buf: Vec<String>,
    pub output_buf_ffi: Vec<UseAsciiStringPattern<'a>>,
}

#[ffi_service(error = "FFIError", prefix = "onnx_seq2seq_")]
impl<'a> Seq2SeqGenerationPipelineFFI<'a> {
    #[ffi_service_ctor]
    pub fn from_pretrained(
        env: &'a EnvContainer,
        model_id: AsciiPointer<'a>,
        device: DeviceFFI,
        optimization: GraphOptimizationLevelFFI,
    ) -> Result<Self> {
        let model = Seq2SeqGenerationPipeline::from_pretrained(
            env.env.clone(),
            model_id.as_c_str().unwrap().to_string_lossy().to_string(),
            device.into(),
            optimization.into(),
        )?;
        Ok(Self {
            model,
            output_buf: Vec::new(),
            output_buf_ffi: Vec::new(),
        })
    }

    #[ffi_service_ctor]
    pub fn create_from_files(
        env: &'a EnvContainer,
        model_path: AsciiPointer<'a>,
        tokenizer_config_path: AsciiPointer<'a>,
        special_tokens_map_path: AsciiPointer<'a>,
        device: DeviceFFI,
        optimization: GraphOptimizationLevelFFI,
    ) -> Result<Self> {
        let model = Seq2SeqGenerationPipeline::new_from_files(
            env.env.clone(),
            Path::new(&*model_path.as_c_str().unwrap().to_string_lossy()).to_path_buf(),
            Path::new(&*tokenizer_config_path.as_c_str().unwrap().to_string_lossy()).to_path_buf(),
            Path::new(
                &*special_tokens_map_path
                    .as_c_str()
                    .unwrap()
                    .to_string_lossy(),
            )
            .to_path_buf(),
            device.into(),
            optimization.into(),
        )?;
        Ok(Self {
            model,
            output_buf: Vec::new(),
            output_buf_ffi: Vec::new(),
        })
    }

    #[ffi_service_method(on_panic = "return_default")]
    pub fn generate_topk_sampling(
        &mut self,
        input: AsciiPointer,
        decoder_input: AsciiPointer,
        max_length: i32,
        topk: i32,
        temperature: f32,
    ) -> AsciiPointer<'a> {
        let sampler = TopKSampler::new(topk as usize, temperature);
        let decoder_input = decoder_input
            .as_c_str()
            .unwrap()
            .to_string_lossy()
            .to_string();
        let decoder_input_option = match decoder_input.as_str() {
            "" => None,
            _ => Some(decoder_input.as_str()),
        };
        let output = self
            .model
            .generate(
                &*input.as_c_str().unwrap().to_string_lossy(),
                decoder_input_option,
                max_length,
                &sampler,
            )
            .unwrap();
        AsciiPointer::from_slice_with_nul(CString::new(output).unwrap().to_bytes_with_nul())
            .expect("Failed to convert CString to AsciiPointer")
    }

    #[ffi_service_method(on_panic = "return_default")]
    pub fn generate_random_sampling(
        &mut self,
        input: AsciiPointer<'a>,
        decoder_input: AsciiPointer,
        max_length: i32,
        temperature: f32,
    ) -> AsciiPointer<'a> {
        let sampler = RandomSampler::new(temperature);
        let decoder_input = decoder_input
            .as_c_str()
            .unwrap()
            .to_string_lossy()
            .to_string();
        let decoder_input_option = match decoder_input.as_str() {
            "" => None,
            _ => Some(decoder_input.as_str()),
        };
        let output = self
            .model
            .generate(
                &*input.as_c_str().unwrap().to_string_lossy(),
                decoder_input_option,
                max_length,
                &sampler,
            )
            .unwrap();

        AsciiPointer::from_slice_with_nul(CString::new(output).unwrap().to_bytes_with_nul())
            .expect("Failed to convert CString to AsciiPointer")
    }

    #[ffi_service_method(on_panic = "return_default")]
    pub fn generate_argmax(
        &mut self,
        input: AsciiPointer<'a>,
        decoder_input: AsciiPointer,
        max_length: i32,
    ) -> AsciiPointer<'a> {
        let sampler = ArgmaxSampler::new();
        let decoder_input = decoder_input
            .as_c_str()
            .unwrap()
            .to_string_lossy()
            .to_string();
        let decoder_input_option = match decoder_input.as_str() {
            "" => None,
            _ => Some(decoder_input.as_str()),
        };
        let output = self
            .model
            .generate(
                &*input.as_c_str().unwrap().to_string_lossy(),
                decoder_input_option,
                max_length,
                &sampler,
            )
            .unwrap();
        AsciiPointer::from_slice_with_nul(CString::new(output).unwrap().to_bytes_with_nul())
            .expect("Failed to convert CString to AsciiPointer")
    }

    #[ffi_service_method(on_panic = "return_default")]
    pub fn generate_topk_sampling_batch(
        s: &'a mut Seq2SeqGenerationPipelineFFI,
        input: StringBatch,
        decoder_input: FFIOption<StringBatch>,
        max_length: i32,
        topk: i32,
        temperature: f32,
    ) -> FFISlice<'a, UseAsciiStringPattern<'a>> {
        let sampler = TopKSampler::new(topk as usize, temperature);
        s.output_buf = s
            .model
            .generate_batch(
                input.batch,
                decoder_input.into_option().map(|s| s.batch),
                max_length,
                &sampler,
            )
            .unwrap();
        s.output_buf_ffi = s
            .output_buf
            .iter()
            .map(|s| {
                AsciiPointer::from_slice_with_nul(
                    CString::new(s.as_str()).unwrap().to_bytes_with_nul(),
                )
                .expect("Failed to convert CString to AsciiPointer")
            })
            .map(|s| UseAsciiStringPattern { ascii_string: s })
            .collect();
        FFISlice::from_slice(s.output_buf_ffi.as_slice())
    }

    #[ffi_service_method(on_panic = "return_default")]
    pub fn generate_random_sampling_batch(
        s: &'a mut Seq2SeqGenerationPipelineFFI,
        input: StringBatch,
        decoder_input: FFIOption<StringBatch>,
        max_length: i32,
        temperature: f32,
    ) -> FFISlice<'a, UseAsciiStringPattern<'a>> {
        let sampler = RandomSampler::new(temperature);
        s.output_buf = s
            .model
            .generate_batch(
                input.batch,
                decoder_input.into_option().map(|s| s.batch),
                max_length,
                &sampler,
            )
            .unwrap();
        s.output_buf_ffi = s
            .output_buf
            .iter()
            .map(|s| {
                AsciiPointer::from_slice_with_nul(
                    CString::new(s.as_str()).unwrap().to_bytes_with_nul(),
                )
                .expect("Failed to convert CString to AsciiPointer")
            })
            .map(|s| UseAsciiStringPattern { ascii_string: s })
            .collect();
        FFISlice::from_slice(s.output_buf_ffi.as_slice())
    }

    #[ffi_service_method(on_panic = "return_default")]
    pub fn generate_argmax_batch(
        s: &'a mut Seq2SeqGenerationPipelineFFI,
        input: StringBatch,
        decoder_input: FFIOption<StringBatch>,
        max_length: i32,
    ) -> FFISlice<'a, UseAsciiStringPattern<'a>> {
        let sampler = ArgmaxSampler::new();
        s.output_buf = s
            .model
            .generate_batch(
                input.batch,
                decoder_input.into_option().map(|s| s.batch),
                max_length,
                &sampler,
            )
            .unwrap();
        s.output_buf_ffi = s
            .output_buf
            .iter()
            .map(|s| {
                AsciiPointer::from_slice_with_nul(
                    CString::new(s.as_str()).unwrap().to_bytes_with_nul(),
                )
                .expect("Failed to convert CString to AsciiPointer")
            })
            .map(|s| UseAsciiStringPattern { ascii_string: s })
            .collect();
        FFISlice::from_slice(s.output_buf_ffi.as_slice())
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[ignore]
    #[test]
    fn test_generate_topk_sampling() -> Result<()> {
        let e = EnvContainer::new()?;
        let mut pipeline = Seq2SeqGenerationPipelineFFI::create_from_files(
            &e,
            AsciiPointer::from_slice_with_nul(
                CString::new("resources/seq2seq/model.onnx")?.to_bytes_with_nul(),
            )?,
            AsciiPointer::from_slice_with_nul(
                CString::new("resources/seq2seq/tokenizer.json")?.to_bytes_with_nul(),
            )?,
            AsciiPointer::from_slice_with_nul(
                CString::new("resources/seq2seq/special_tokens_map.json")?.to_bytes_with_nul(),
            )?,
            DeviceFFI::CPU,
            GraphOptimizationLevelFFI::Level3,
        )
        .unwrap();

        let output = pipeline.generate_topk_sampling(
            AsciiPointer::from_slice_with_nul(b"translate English to French: How old are you?\0")
                .unwrap(),
            AsciiPointer::from_slice_with_nul(b"\0").unwrap(),
            32,
            5,
            1.0,
        );
        println!(
            "{}",
            output.as_c_str().unwrap().to_string_lossy().to_string()
        );
        Ok(())
    }

    #[ignore]
    #[test]
    fn test_generate_topk_sampling_batch() -> Result<()> {
        let e = EnvContainer::new()?;
        let mut pipeline = Seq2SeqGenerationPipelineFFI::create_from_files(
            &e,
            AsciiPointer::from_slice_with_nul(
                CString::new("resources/seq2seq/model.onnx")?.to_bytes_with_nul(),
            )?,
            AsciiPointer::from_slice_with_nul(
                CString::new("resources/seq2seq/tokenizer.json")?.to_bytes_with_nul(),
            )?,
            AsciiPointer::from_slice_with_nul(
                CString::new("resources/seq2seq/special_tokens_map.json")?.to_bytes_with_nul(),
            )?,
            DeviceFFI::CPU,
            GraphOptimizationLevelFFI::Level3,
        )
        .unwrap();
        let b = StringBatch {
            batch: vec![
                "translate English to French: How old are you?".to_string(),
                "translate English to French: What is your name?".to_string(),
            ],
        };
        let b_dec = StringBatch {
            batch: vec!["<pad>".to_string(), "<pad>".to_string()],
        };
        let output = Seq2SeqGenerationPipelineFFI::generate_topk_sampling_batch(
            &mut pipeline,
            b,
            FFIOption::some(b_dec),
            32,
            5,
            1.0,
        );
        println!(
            "{:?}",
            output.as_slice()[0]
                .ascii_string
                .as_c_str()
                .unwrap()
                .to_string_lossy()
                .to_string()
        );
        println!(
            "{:?}",
            output.as_slice()[1]
                .ascii_string
                .as_c_str()
                .unwrap()
                .to_string_lossy()
                .to_string()
        );
        Ok(())
    }
}