1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
use anyhow::anyhow;
use ndarray::{Array1, Array2, ArrayBase, Axis, Ix2, OwnedRepr};
use std::array::TryFromSliceError;
use std::path::Path;
use tokenizers::Tokenizer;
/// Maximum number of tokens (including any special tokens added by the tokenizer) that
/// [`Bge::create_embeddings`] forwards to the model in a single call.
const MODEL_INPUT_LIMIT: usize = 512;

/// Errors returned by [`Bge::create_embeddings`].
#[derive(Debug, thiserror::Error)]
pub enum BgeError {
    /// The tokenized input is longer than [`MODEL_INPUT_LIMIT`]; carries the actual token count.
    // `.0` refers to the variant's field; a bare `0` would be the integer literal zero and the
    // message would always report "got: 0".
    #[error(
        "Number of tokens in the input exceed the model limit. Limit: {}, got: {}",
        MODEL_INPUT_LIMIT,
        .0
    )]
    LargeInput(usize),
    /// An error reported by ONNX Runtime (building inputs, running the session, extracting outputs).
    #[error(transparent)]
    OnnxRuntimeError(#[from] ort::Error),
    /// Any other failure, e.g. tokenization errors or an unexpected output shape.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
pub struct Bge {
tokenizer: Tokenizer,
model: ort::Session,
}
impl Bge {
/// Creates a new instance of `Bge` by loading a tokenizer and a model from the specified file paths.
///
/// # Arguments
///
/// * `tokenizer_file_path` - A path to the file containing the tokenizer configuration.
/// * `model_file_path` - A path to the ONNX model file.
///
/// # Returns
///
/// If successful, returns an `Ok(Self)` containing a new instance of `Bge`. On failure, returns an `Err(anyhow::Error)`
/// detailing the error encountered during the loading process.
///
/// # Errors
///
/// This function can fail if:
/// - The paths provided do not point to valid files.
/// - The tokenizer or model file cannot be correctly parsed or loaded, possibly due to format issues or
/// compatibility problems.
///
/// # Examples
///
/// ```
/// let bge = bge::Bge::from_files("path/to/tokenizer.json", "path/to/model.onnx");
/// match bge {
/// Ok(instance) => println!("Bge instance created successfully."),
/// Err(e) => eprintln!("Failed to create Bge instance: {}", e),
/// }
/// ```
pub fn from_files<P>(tokenizer_file_path: P, model_file_path: P) -> anyhow::Result<Self>
where
P: AsRef<Path>,
{
let tokenizer = Tokenizer::from_file(tokenizer_file_path.as_ref().to_str().unwrap())
.map_err(|e| anyhow!(e))?;
let model = ort::Session::builder()?.commit_from_file(model_file_path)?;
Ok(Self { tokenizer, model })
}
/// Generates embeddings for a given input text using the model.
///
/// This method tokenizes the input text, performs necessary preprocessing,
/// and then runs the model to produce embeddings. The embeddings are normalized
/// before being returned.
///
/// # Arguments
///
/// * `input` - The text input for which embeddings should be generated.
///
/// # Returns
///
/// If successful, returns a `Result` containing a fixed-size array of `f32` elements representing
/// the generated embeddings. On failure, returns a `BgeError` detailing the nature of the error.
///
/// # Errors
///
/// This method can return an error in several cases:
/// - `BgeError::LargeInput` if the input text produces more tokens than the model can accept.
/// - `BgeError::OnnxRuntimeError` for errors related to running the ONNX model.
/// - `BgeError::Other` for all other errors, including issues with tokenization or tensor extraction.
///
/// # Examples
///
/// ```
/// # let bge = bge::Bge::from_files("path/to/tokenizer.json", "path/to/model.onnx").unwrap();
/// let embeddings = bge.create_embeddings("This is a sample text.");
/// match embeddings {
/// Ok(embeds) => println!("Embeddings: {:?}", embeds),
/// Err(e) => eprintln!("Error generating embeddings: {}", e),
/// }
/// ```
pub fn create_embeddings(&self, input: &str) -> Result<[f32; 384], BgeError> {
let encoding = self
.tokenizer
.encode(input, true)
.map_err(|e| BgeError::Other(anyhow!(e)))?;
let encoding_ids = encoding.get_ids();
let tokens_count = encoding_ids.len();
if tokens_count > MODEL_INPUT_LIMIT {
return Err(BgeError::LargeInput(tokens_count));
}
let input_ids =
Array1::from_vec(encoding_ids.iter().map(|v| *v as i64).collect()).insert_axis(Axis(0));
let attention_mask: ArrayBase<OwnedRepr<i64>, Ix2> = Array2::ones([1, tokens_count]);
let token_type_ids: ArrayBase<OwnedRepr<i64>, Ix2> = Array2::zeros([1, tokens_count]);
let inputs = ort::inputs! {
"input_ids" => input_ids.view(),
"attention_mask" => attention_mask.view(),
"token_type_ids" => token_type_ids.view(),
}
.map_err(BgeError::OnnxRuntimeError)?;
let outputs = self.model.run(inputs).map_err(BgeError::OnnxRuntimeError)?;
let output = outputs["last_hidden_state"]
.try_extract_tensor()
.map_err(BgeError::OnnxRuntimeError)?;
let view = output.view();
let slice = view.rows().into_iter().next().unwrap().to_slice().unwrap();
let mut res: [f32; 384] = slice
.try_into()
.map_err(|e: TryFromSliceError| BgeError::Other(e.into()))?;
normalize(&mut res);
Ok(res)
}
}
/// Rescales `vec` in place so that its Euclidean (L2) length becomes 1.
///
/// A vector whose length is exactly zero is returned unchanged, which avoids dividing by zero.
fn normalize(vec: &mut [f32]) {
    let squared_sum: f32 = vec.iter().map(|component| component * component).sum();
    let length = squared_sum.sqrt();
    if length == 0.0 {
        return;
    }
    vec.iter_mut().for_each(|component| *component /= length);
}
#[cfg(test)]
mod tests {
    use super::*;
    mod test_data;

    /// Largest per-component difference tolerated between the computed embedding and the
    /// stored reference. ONNX Runtime output can differ in the last bits across CPUs, thread
    /// counts and runtime versions, so a bit-exact comparison is brittle.
    const TOLERANCE: f32 = 1e-5;

    /// End-to-end check: load the bundled assets, embed a fixed sentence and compare the result
    /// against the stored reference embedding.
    #[test]
    fn it_works() {
        let bge = Bge::from_files("assets/tokenizer.json", "assets/model.onnx")
            .expect("failed to load assets/tokenizer.json and assets/model.onnx");
        let embedding = bge
            .create_embeddings("Some input text to generate embeddings for.")
            .expect("embedding generation failed");

        assert_eq!(embedding.len(), test_data::TEST_EMBEDDING_RESULT.len());
        for (i, (actual, expected)) in embedding
            .iter()
            .zip(test_data::TEST_EMBEDDING_RESULT.iter())
            .enumerate()
        {
            // Written as `<=` so that a NaN component fails the assertion.
            assert!(
                (actual - expected).abs() <= TOLERANCE,
                "component {i} differs: got {actual}, expected {expected}"
            );
        }
    }
}