1use anyhow::Result;
2use std::path::Path;
3use tract_onnx::prelude::tract_ndarray::Array2;
4use tract_onnx::prelude::*;
5
6pub struct EmbeddingModel {
7 plan: SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
8}
9
10impl EmbeddingModel {
11 pub fn load(path: &Path) -> Result<Self> {
12 let model = tract_onnx::onnx()
13 .model_for_path(path)?
14 .into_optimized()?
15 .into_runnable()?;
16 Ok(Self { plan: model })
17 }
18
19 pub fn generate_embedding(&self, token_ids: &[i64], mask: &[i64]) -> Result<Vec<f32>> {
20 let input_ids_array = Array2::from_shape_vec((1, 512), token_ids.to_vec())
21 .map_err(|e| anyhow::anyhow!("Failed to create input_ids array: {}", e))?;
22 let input_ids: Tensor = input_ids_array.into();
23
24 let attention_mask_array = Array2::from_shape_vec((1, 512), mask.to_vec())
25 .map_err(|e| anyhow::anyhow!("Failed to create attention_mask array: {}", e))?;
26 let attention_mask: Tensor = attention_mask_array.into();
27
28 let input_count = self.plan.model().inputs.len();
29 let results = if input_count == 3 {
30 let token_type_ids_array = Array2::from_shape_vec((1, 512), vec![0i64; 512])
31 .map_err(|e| anyhow::anyhow!("Failed to create token_type_ids array: {}", e))?;
32 let token_type_ids: Tensor = token_type_ids_array.into();
33 self.plan.run(tvec![
34 input_ids.into(),
35 attention_mask.into(),
36 token_type_ids.into()
37 ])?
38 } else {
39 self.plan
40 .run(tvec![input_ids.into(), attention_mask.into()])?
41 };
42
43 let output_tensor = results[0].to_array_view::<f32>()?;
44 let shape = output_tensor.shape();
45
46 let mut raw_vec = vec![0.0f32; 384];
47
48 if shape.len() == 3 {
49 let seq_len = shape[1];
50 let dim = shape[2];
51 let target_dim = std::cmp::min(dim, 384);
52
53 let mut valid_token_count = 0.0f32;
54 for (t, &m) in mask.iter().enumerate() {
55 if t < seq_len && m > 0 {
56 let weight = m as f32;
57 valid_token_count += weight;
58 for d in 0..target_dim {
59 raw_vec[d] += output_tensor[[0, t, d]] * weight;
60 }
61 }
62 }
63
64 if valid_token_count > 0.0 {
65 for val in raw_vec.iter_mut().take(target_dim) {
66 *val /= valid_token_count;
67 }
68 }
69 } else if shape.len() == 2 {
70 let dim = shape[1];
71 let target_dim = std::cmp::min(dim, 384);
72 for d in 0..target_dim {
73 raw_vec[d] = output_tensor[[0, d]];
74 }
75 } else {
76 anyhow::bail!("Unexpected model output shape: {:?}", shape);
77 }
78
79 let norm = raw_vec.iter().map(|&x| x * x).sum::<f32>().sqrt();
81 let normalized_vec = if norm > 0.0 {
82 raw_vec.into_iter().map(|x| x / norm).collect()
83 } else {
84 raw_vec
85 };
86
87 Ok(normalized_vec)
88 }
89}