use std::pin::Pin;
use std::sync::Arc;

use ndarray::Axis;
use ort::{ExecutionProviderDispatch, GraphOptimizationLevel, LoggingLevel, SessionBuilder};

use crate::embedding::Embedding;

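/// Sentence embedder backed by ONNX Runtime: a HuggingFace tokenizer plus a
/// session over a BERT-style encoder (ids / attention mask / type ids in,
/// hidden states out), mean-pooled into one vector per input.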
pub struct Semantic {
    model_ref: &'static [u8],
    tokenizer: Arc<tokenizers::Tokenizer>,
    session: Arc<ort::Session>,
}

impl Drop for Semantic {
    fn drop(&mut self) {
        // `model_ref` was leaked in `init_semantic`; rebuild the box so the
        // bytes are freed exactly once. (The old `ManuallyDrop` dance only
        // dropped the reference itself, a no-op, so the buffer leaked.)
        // SAFETY: the pointer came from `Box::leak` and is freed nowhere
        // else; ONNX Runtime copied the model while the session was built,
        // so the session dropped after this body holds no reference into it.
        unsafe {
            drop(Box::from_raw(self.model_ref as *const [u8] as *mut [u8]));
        }
    }
}

impl Semantic {
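    /// Builds a pinned `Semantic` from raw model and tokenizer bytes. The
    /// body is synchronous; the `async` signature just keeps the constructor
    /// awaitable at async call sites.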
    pub async fn initialize(model: Vec<u8>, tokenizer_data: Vec<u8>) -> Result<Pin<Box<Semantic>>, SemanticError> {
        let semantic = Self::init_semantic(model, tokenizer_data)?;

        Ok(Box::pin(semantic))
    }

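    /// Synchronous initializer: commits the ONNX Runtime environment (CPU
    /// execution provider), loads the tokenizer, and builds a session over
    /// the leaked model bytes.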
    pub fn init_semantic(model: Vec<u8>, tokenizer_data: Vec<u8>) -> Result<Semantic, SemanticError> {
        ort::init()
            .with_name("Encode")
            .with_log_level(LoggingLevel::Warning)
            .with_execution_providers([ExecutionProviderDispatch::CPU(Default::default())])
            .commit()
            .map_err(|_| SemanticError::InitBuildOrtEnv)?;

        // Intra-op thread count for the session; defaults to 1 when the
        // variable is unset or unparsable.
        let threads = std::env::var("NUM_OMP_THREADS")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(1);

        let tokenizer: Arc<tokenizers::Tokenizer> = tokenizers::Tokenizer::from_bytes(tokenizer_data)
            .map_err(|_| SemanticError::TokenizeEncodeByteError)?
            .into();

        // Leak the model bytes so the session builder can borrow them for
        // `'static`; `Drop` rebuilds and frees the allocation. Going through
        // `into_boxed_slice` first makes the later `Box::from_raw` sound
        // (no spare capacity to lose track of).
        let model_ref: &'static [u8] = Box::leak(model.into_boxed_slice());

        let semantic = Self {
            model_ref,
            tokenizer,
            session: SessionBuilder::new()
                .map_err(|_| SemanticError::InitSessionBuilder)?
                .with_optimization_level(GraphOptimizationLevel::Level3)
                .map_err(|_| SemanticError::InitSessionOptimization)?
                .with_intra_threads(threads)
                .map_err(|_| SemanticError::InitSessionThreads)?
                .with_model_from_memory(model_ref)
                .map_err(|_| SemanticError::InitModelReadError)?
                .into(),
        };

        Ok(semantic)
    }

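    /// Embeds a single sequence: tokenize, run the encoder over a batch of
    /// one, then mean-pool the last hidden state over the token axis.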
    pub fn embed(&self, sequence: &str) -> Result<Embedding, SemanticError> {
        let encoding = self.tokenizer.encode(sequence, true)
            .map_err(|_| SemanticError::TokenizeEncodeError)?;

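        // `tokenizers` yields u32 ids, but the ONNX graph expects i64
        // tensors for its inputs, hence the widening below.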
        let input_ids = encoding.get_ids().iter().map(|item| *item as i64).collect::<Vec<_>>();
        let attention_mask = encoding.get_attention_mask().iter().map(|item| *item as i64).collect::<Vec<_>>();
        let token_type_ids = encoding.get_type_ids().iter().map(|item| *item as i64).collect::<Vec<_>>();

        let sequence_length = input_ids.len();

        // Reshape each flat token vector into a (batch = 1, sequence_length)
        // tensor, since the model expects batched input.
        let input_ids = ndarray::CowArray::from(&input_ids)
            .into_shape((1, sequence_length))
            .map_err(|_| SemanticError::ShapeError)?
            .into_dyn();
        let input_ids = ort::Value::from_array(&input_ids)
            .map_err(|_| SemanticError::ShapeError)?;

        let attention_mask = ndarray::CowArray::from(&attention_mask)
            .into_shape((1, sequence_length))
            .map_err(|_| SemanticError::ShapeError)?
            .into_dyn();
        let attention_mask = ort::Value::from_array(&attention_mask)
            .map_err(|_| SemanticError::ShapeError)?;

        let token_type_ids = ndarray::CowArray::from(&token_type_ids)
            .into_shape((1, sequence_length))
            .map_err(|_| SemanticError::ShapeError)?
            .into_dyn();
        let token_type_ids = ort::Value::from_array(&token_type_ids)
            .map_err(|_| SemanticError::ShapeError)?;

        // Run inference; the first output is the last hidden state with
        // shape (batch, sequence_length, hidden_size).
        let outputs = self.session
            .run(ort::inputs![input_ids, attention_mask, token_type_ids]
                .map_err(|_| SemanticError::InferenceError)?)
            .map_err(|_| SemanticError::InferenceError)?;

        let output_tensor = outputs[0]
            .extract_tensor::<f32>()
            .map_err(|_| SemanticError::InferenceError)?;
        let sequence_embedding = &*output_tensor.view();

        // Mean-pool over the token axis (axis 1) to collapse per-token
        // embeddings into one fixed-size sentence embedding.
        let pooled = sequence_embedding
            .mean_axis(Axis(1))
            .ok_or(SemanticError::ShapeError)?;

        Ok(Embedding(pooled.as_slice().ok_or(SemanticError::ShapeError)?.to_vec()))
    }
}

type Result<T, E = SemanticError> = std::result::Result<T, E>;

#[derive(Debug, thiserror::Error)]
pub enum SemanticError {
    #[error("failed to encode the input sequence")]
    TokenizeEncodeError,
    #[error("failed to load the tokenizer from bytes")]
    TokenizeEncodeByteError,
    #[error("failed to reshape or convert a model input/output")]
    ShapeError,
    #[error("failed to create the ONNX Runtime session builder")]
    InitSessionBuilder,
    #[error("failed to set the graph optimization level")]
    InitSessionOptimization,
    #[error("failed to initialize the ONNX Runtime environment")]
    InitBuildOrtEnv,
    #[error("failed to configure session threads")]
    InitSessionThreads,
    #[error("failed to load the model from memory")]
    InitModelReadError,
    #[error("failed to read the tokenizer data")]
    InitTokenizerReadError,
    #[error("model inference failed")]
    InferenceError,
}
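
// A minimal usage sketch (assumptions, not project facts): the asset paths
// below are hypothetical; point them at the model and tokenizer files your
// build actually ships, then drop the `#[ignore]`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "requires model assets on disk"]
    fn embed_smoke_test() -> std::result::Result<(), Box<dyn std::error::Error>> {
        // Hypothetical asset locations; adjust to your layout.
        let model = std::fs::read("model/model.onnx")?;
        let tokenizer_data = std::fs::read("model/tokenizer.json")?;

        let semantic = Semantic::init_semantic(model, tokenizer_data)?;
        let _embedding = semantic.embed("hello world")?;
        Ok(())
    }
}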