use anyhow::Result;
use base64::prelude::*;
use ndarray::Axis;
use num_cpus;
use ort::{
Environment, GraphOptimizationLevel, LoggingLevel, InMemorySession, SessionBuilder,
Value,
};
use tokenizers::utils::truncation::*;
use tokenizers::{Encoding, Tokenizer};
/// An in-memory ONNX embedding model paired with the tokenizer that prepares
/// its input. Both are configured once in `EmbeddingSession::new` and then
/// used read-only by the inference methods.
#[allow(dead_code)]
pub struct EmbeddingSession {
    // ONNX Runtime session; `'static` because it borrows the model bytes
    // passed to `new`, which must outlive the session.
    session: InMemorySession<'static>,
    // HuggingFace tokenizer with left-truncation configured at build time.
    tokenizer: Tokenizer,
}
impl EmbeddingSession {
pub fn new(
session_name: &str,
model_bytes: &'static [u8],
tokenizer_bytes: &[u8],
tokenizer_max_len: usize,
threads: i16, ) -> Self {
let environment = Environment::builder()
.with_name(session_name)
.with_log_level(LoggingLevel::Error)
.build()
.unwrap()
.into_arc();
let mut session_builder = SessionBuilder::new(&environment)
.unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3)
.unwrap();
let threads = if threads > 0 {
std::cmp::min(threads, num_cpus::get() as i16)
} else {
num_cpus::get() as i16
};
if threads > 0 {
session_builder =
session_builder.with_intra_threads(threads).unwrap();
}
let session = session_builder.with_model_from_memory(model_bytes).unwrap();
let mut tokenizer = Tokenizer::from_bytes(tokenizer_bytes).unwrap();
let _ = tokenizer.with_truncation(Some(TruncationParams {
max_length: tokenizer_max_len as usize, direction: TruncationDirection::Left, strategy: TruncationStrategy::LongestFirst, stride: 0, }));
Self {
session: session,
tokenizer: tokenizer,
}
}
pub fn count_tokens(
&self,
sequence: &str,
) -> Result<usize> {
let tokenizer_output = self.tokenizer.encode(sequence, true).unwrap();
Ok(tokenizer_output.get_ids().len())
}
pub fn embed(
&self,
sequence: &str,
) -> Result<Vec<f32>> {
fn create_ndarray<F>(
tokenizer_output: &Encoding,
func: F,
) -> ndarray::Array2<i64>
where
F: Fn(&Encoding) -> &[u32],
{
ndarray::Array::from_shape_vec(
(1, tokenizer_output.len()),
func(tokenizer_output).iter().map(|&x| x as i64).collect(),
)
.unwrap()
}
fn create_cow_array<F>(
tokenizer_output: &Encoding,
func: F,
) -> ndarray::CowArray<'_, i64, ndarray::Dim<[usize; 2]>>
where
F: Fn(&Encoding) -> &[u32],
{
ndarray::CowArray::from(create_ndarray(tokenizer_output, func))
}
let tokenizer_output = self.tokenizer.encode(sequence, true).unwrap();
let outputs = self.session.run(vec![
Value::from_array(
self.session.allocator(),
&create_cow_array(&tokenizer_output, Encoding::get_ids)
.into_dyn(),
)
.unwrap(),
Value::from_array(
self.session.allocator(),
&create_cow_array(
&tokenizer_output,
Encoding::get_attention_mask,
)
.into_dyn(),
)
.unwrap(),
Value::from_array(
self.session.allocator(),
&create_cow_array(&tokenizer_output, Encoding::get_type_ids)
.into_dyn(),
)
.unwrap(),
])?;
let result = outputs[0]
.try_extract() .unwrap()
.view()
.mean_axis(Axis(1)) .unwrap()
.to_owned()
.as_slice()
.unwrap()
.to_vec();
Ok(result)
}
pub fn binary_quantize(
&self,
unquantized_vector: Vec<f32>,
) -> Result<Vec<i64>> {
let packed_vector: Vec<i64> = unquantized_vector
.iter()
.map(|&x| if x > 0.0 { 1 } else { 0 }) .collect::<Vec<i8>>() .chunks(64) .map(|chunk| {
chunk.iter().fold(0, |acc, &bit| (acc << 1) | (bit as i64))
}) .collect();
Ok(packed_vector)
}
pub fn display_binary(
&self,
vector: Vec<i64>,
) -> Result<String> {
let result = vector
.iter()
.map(|&num| format!("{:064b}", num))
.collect::<Vec<String>>()
.join("");
Ok(result)
}
pub fn display_base64(
&self,
vector: Vec<i64>,
) -> Result<String> {
let bytes: Vec<u8> = vector
.iter()
.flat_map(|&i| i.to_le_bytes().to_vec())
.collect();
let result = BASE64_STANDARD.encode(&bytes);
Ok(result)
}
}