// cmdhub_cli/inference.rs

1use anyhow::Result;
2use std::path::Path;
3use tract_onnx::prelude::tract_ndarray::Array2;
4use tract_onnx::prelude::*;
5
/// An embedding model held as a ready-to-run tract execution plan.
///
/// Instances are created with [`EmbeddingModel::load`] and queried with
/// [`EmbeddingModel::generate_embedding`].
pub struct EmbeddingModel {
    /// Runnable plan built from the optimized typed graph of the loaded ONNX file.
    plan: SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
}
9
10impl EmbeddingModel {
11    pub fn load(path: &Path) -> Result<Self> {
12        let model = tract_onnx::onnx()
13            .model_for_path(path)?
14            .into_optimized()?
15            .into_runnable()?;
16        Ok(Self { plan: model })
17    }
18
19    pub fn generate_embedding(&self, token_ids: &[i64], mask: &[i64]) -> Result<Vec<f32>> {
20        let input_ids_array = Array2::from_shape_vec((1, 512), token_ids.to_vec())
21            .map_err(|e| anyhow::anyhow!("Failed to create input_ids array: {}", e))?;
22        let input_ids: Tensor = input_ids_array.into();
23
24        let attention_mask_array = Array2::from_shape_vec((1, 512), mask.to_vec())
25            .map_err(|e| anyhow::anyhow!("Failed to create attention_mask array: {}", e))?;
26        let attention_mask: Tensor = attention_mask_array.into();
27
28        let input_count = self.plan.model().inputs.len();
29        let results = if input_count == 3 {
30            let token_type_ids_array = Array2::from_shape_vec((1, 512), vec![0i64; 512])
31                .map_err(|e| anyhow::anyhow!("Failed to create token_type_ids array: {}", e))?;
32            let token_type_ids: Tensor = token_type_ids_array.into();
33            self.plan.run(tvec![
34                input_ids.into(),
35                attention_mask.into(),
36                token_type_ids.into()
37            ])?
38        } else {
39            self.plan
40                .run(tvec![input_ids.into(), attention_mask.into()])?
41        };
42
43        let output_tensor = results[0].to_array_view::<f32>()?;
44        let shape = output_tensor.shape();
45
46        let mut raw_vec = vec![0.0f32; 512];
47
48        if shape.len() == 3 {
49            let seq_len = shape[1];
50            let dim = shape[2];
51            let target_dim = std::cmp::min(dim, 512);
52
53            let mut valid_token_count = 0.0f32;
54            for (t, &m) in mask.iter().enumerate() {
55                if t < seq_len && m > 0 {
56                    let weight = m as f32;
57                    valid_token_count += weight;
58                    for d in 0..target_dim {
59                        raw_vec[d] += output_tensor[[0, t, d]] * weight;
60                    }
61                }
62            }
63
64            if valid_token_count > 0.0 {
65                for val in raw_vec.iter_mut().take(target_dim) {
66                    *val /= valid_token_count;
67                }
68            }
69        } else if shape.len() == 2 {
70            let dim = shape[1];
71            let target_dim = std::cmp::min(dim, 512);
72            for d in 0..target_dim {
73                raw_vec[d] = output_tensor[[0, d]];
74            }
75        } else {
76            anyhow::bail!("Unexpected model output shape: {:?}", shape);
77        }
78
79        // L2 normalization
80        let norm = raw_vec.iter().map(|&x| x * x).sum::<f32>().sqrt();
81        let normalized_vec = if norm > 0.0 {
82            raw_vec.into_iter().map(|x| x / norm).collect()
83        } else {
84            raw_vec
85        };
86
87        Ok(normalized_vec)
88    }
89}