use anyhow::Result;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::TensorRef;
use std::path::Path;
use tokenizers::Tokenizer;
use tracing::{debug, info};
/// Text-embedding encoder that pairs an ONNX Runtime session with a `tokenizers` tokenizer.
///
/// Construct it with `EmbeddingModel::load`, then call `encode` / `encode_batch`.
pub struct EmbeddingModel {
    // ONNX Runtime session that runs the model (needs `&mut` to run inference).
    session: Session,
    // Tokenizer that turns input text into token ids and an attention mask.
    tokenizer: Tokenizer,
    // Number of floats per token row in the model output and per returned embedding.
    dimension: usize,
    // Fixed sequence length that every input is padded/truncated to.
    max_length: usize,
}
impl EmbeddingModel {
    /// Width of the embedding vectors this wrapper assumes the model produces.
    ///
    /// NOTE(review): hard-coded rather than read from the ONNX graph; it must equal the
    /// model's hidden size or pooling will read the wrong elements — confirm per model.
    const DIMENSION: usize = 384;

    /// Sequence length every input is padded or truncated to before inference.
    const MAX_LENGTH: usize = 512;

    /// Loads the tokenizer and the ONNX model from disk.
    ///
    /// The session is built with graph optimization level 3 and 4 intra-op threads.
    ///
    /// # Errors
    ///
    /// Returns an error if the tokenizer file cannot be loaded, or if any step of building
    /// the ONNX Runtime session (builder creation, optimization level, thread count, model
    /// loading) fails.
    pub fn load(model_path: &Path, tokenizer_path: &Path) -> Result<Self> {
        info!("加载嵌入模型: {:?}", model_path);
        info!("加载分词器: {:?}", tokenizer_path);
        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow::anyhow!("加载分词器失败: {}", e))?;
        let session = Session::builder()
            .map_err(|e| anyhow::anyhow!("创建 Session Builder 失败: {:?}", e))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| anyhow::anyhow!("设置优化级别失败: {:?}", e))?
            .with_intra_threads(4)
            .map_err(|e| anyhow::anyhow!("设置线程数失败: {:?}", e))?
            .commit_from_file(model_path)
            .map_err(|e| anyhow::anyhow!("加载 ONNX 模型失败: {:?}", e))?;
        debug!("模型加载成功");
        Ok(Self {
            session,
            tokenizer,
            dimension: Self::DIMENSION,
            max_length: Self::MAX_LENGTH,
        })
    }

    /// Encodes `text` into a single L2-normalized embedding of length `dimension`.
    ///
    /// Steps:
    /// 1. Tokenize the text with special tokens enabled.
    /// 2. Zero-pad or truncate to `max_length`; truncation keeps any trailing special tokens.
    /// 3. Run the model with `input_ids`, `attention_mask` and all-zero `token_type_ids`.
    /// 4. Mean-pool the first output over the attended positions and normalize the result.
    ///
    /// # Errors
    ///
    /// Returns an error if tokenization, tensor construction, inference, or output
    /// extraction fails.
    pub fn encode(&mut self, text: &str) -> Result<Vec<f32>> {
        // Copy the config up front so no `self` access is needed while `outputs`
        // (which borrows the session) is alive.
        let max_length = self.max_length;
        let dimension = self.dimension;

        let encoding =
            self.tokenizer.encode(text, true).map_err(|e| anyhow::anyhow!("分词失败: {}", e))?;
        let ids = encoding.get_ids();
        let attention_mask = encoding.get_attention_mask();
        let kept =
            Self::truncated_positions(ids.len(), encoding.get_special_tokens_mask(), max_length);

        // Positions past the kept tokens stay 0, i.e. padding with a zero attention mask.
        let mut input_ids = vec![0i64; max_length];
        let mut input_mask = vec![0i64; max_length];
        let token_type_ids = vec![0i64; max_length];
        for (dst, &src) in kept.iter().enumerate() {
            input_ids[dst] = i64::from(ids[src]);
            input_mask[dst] = i64::from(attention_mask[src]);
        }

        let shape = vec![1_usize, max_length];
        let input_ids_tensor = TensorRef::from_array_view((shape.clone(), input_ids.as_slice()))
            .map_err(|e| anyhow::anyhow!("创建 input_ids tensor 失败: {:?}", e))?;
        let attention_mask_tensor =
            TensorRef::from_array_view((shape.clone(), input_mask.as_slice()))
                .map_err(|e| anyhow::anyhow!("创建 attention_mask tensor 失败: {:?}", e))?;
        let token_type_ids_tensor = TensorRef::from_array_view((shape, token_type_ids.as_slice()))
            .map_err(|e| anyhow::anyhow!("创建 token_type_ids tensor 失败: {:?}", e))?;

        let outputs = self
            .session
            .run(ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor,
            ])
            .map_err(|e| anyhow::anyhow!("ONNX 推理失败: {:?}", e))?;
        let output_tensor = &outputs[0];
        let output_array = output_tensor
            .try_extract_tensor::<f32>()
            .map_err(|e| anyhow::anyhow!("提取输出张量失败: {:?}", e))?;

        // `output_array.1` is already a slice reference; pass it straight through.
        let embeddings = Self::mean_pooling(output_array.1, &input_mask, max_length, dimension)?;
        Ok(Self::normalize(&embeddings))
    }

    /// Encodes each text in order with [`Self::encode`], stopping at the first error.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by [`Self::encode`].
    pub fn encode_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        texts.iter().map(|text| self.encode(text)).collect()
    }

    /// Returns, in order, the source token positions that survive truncation to `max_length`.
    ///
    /// Sequences that already fit are kept whole. For longer sequences, the head is kept,
    /// but the trailing run of special tokens (per `special_tokens_mask`) is moved to the
    /// end of the kept window. The model therefore still sees the terminator the
    /// tokenizer appended; for BERT-style tokenizers this is typically the separator token.
    /// If there is no trailing special run, or it alone would not fit, plain head
    /// truncation is used.
    fn truncated_positions(
        len: usize,
        special_tokens_mask: &[u32],
        max_length: usize,
    ) -> Vec<usize> {
        if len <= max_length {
            return (0..len).collect();
        }
        let tail = special_tokens_mask
            .iter()
            .rev()
            .take_while(|&&flag| flag == 1)
            .count()
            .min(len);
        if tail == 0 || tail >= max_length {
            return (0..max_length).collect();
        }
        (0..max_length - tail).chain(len - tail..len).collect()
    }

    /// Averages the hidden-state rows of every position whose attention mask equals `1`.
    ///
    /// `hidden_states` is read as consecutive rows of `dimension` floats, one per position.
    /// This assumes a row-major `[1, max_length, dimension]` first model output
    /// (TODO confirm per model). Only the first `max_length` mask entries are considered.
    ///
    /// Elements that would lie past the end of `hidden_states` are skipped instead of
    /// raising an error, but their position still counts toward the divisor. The
    /// all-zero vector is returned when no position is attended.
    ///
    /// Currently never returns `Err`; the `Result` is kept for the existing call site.
    fn mean_pooling(
        hidden_states: &[f32],
        attention_mask: &[i64],
        max_length: usize,
        dimension: usize,
    ) -> Result<Vec<f32>> {
        let mut pooled = vec![0.0f32; dimension];
        let mut count = 0usize;
        let attended = attention_mask
            .iter()
            .take(max_length)
            .enumerate()
            .filter(|&(_, &mask)| mask == 1);
        for (position, _) in attended {
            // `get(start..)` yields a possibly-short (or absent) row, so reads never go
            // out of bounds.
            if let Some(row) = hidden_states.get(position * dimension..) {
                for (acc, &value) in pooled.iter_mut().zip(row) {
                    *acc += value;
                }
            }
            count += 1;
        }
        if count > 0 {
            let divisor = count as f32;
            pooled.iter_mut().for_each(|value| *value /= divisor);
        }
        Ok(pooled)
    }

    /// Scales `vec` to unit L2 norm. A zero-norm vector is returned unchanged (as a copy).
    fn normalize(vec: &[f32]) -> Vec<f32> {
        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            vec.iter().map(|x| x / norm).collect()
        } else {
            vec.to_vec()
        }
    }

    /// Cosine similarity between `a` and `b`.
    ///
    /// Returns `0.0` when the slices differ in length or either has zero norm.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }
        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a > 0.0 && norm_b > 0.0 {
            dot_product / (norm_a * norm_b)
        } else {
            0.0
        }
    }

    /// Length of the vectors returned by [`Self::encode`].
    pub fn dimension(&self) -> usize {
        self.dimension
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for `f32` comparisons.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_normalize() {
        let vec = vec![3.0f32, 4.0];
        let normalized = EmbeddingModel::normalize(&vec);
        assert!((normalized[0] - 0.6).abs() < EPS);
        assert!((normalized[1] - 0.8).abs() < EPS);
    }

    #[test]
    fn test_normalize_zero_vector_is_unchanged() {
        let zeros = [0.0f32; 3];
        let normalized = EmbeddingModel::normalize(&zeros);
        assert_eq!(normalized, zeros.to_vec());
    }

    #[test]
    fn test_cosine_similarity_vectors() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![1.0f32, 0.0, 0.0];
        let similarity = EmbeddingModel::cosine_similarity(&a, &b);
        assert!((similarity - 1.0).abs() < EPS);
        let c = vec![1.0f32, 0.0, 0.0];
        let d = vec![0.0f32, 1.0, 0.0];
        let similarity2 = EmbeddingModel::cosine_similarity(&c, &d);
        assert!((similarity2 - 0.0).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_length_mismatch_is_zero() {
        let a = [1.0f32, 0.0];
        let b = [1.0f32, 0.0, 0.0];
        assert_eq!(EmbeddingModel::cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_mean_pooling_ignores_masked_positions() {
        // Three positions with dimension 2; the last position is padding and must be ignored.
        let hidden: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 100.0, 100.0];
        let mask: [i64; 3] = [1, 1, 0];
        let pooled = EmbeddingModel::mean_pooling(&hidden, &mask, 3, 2)
            .expect("mean_pooling does not fail for well-formed input");
        assert_eq!(pooled.len(), 2);
        assert!((pooled[0] - 2.0).abs() < EPS);
        assert!((pooled[1] - 3.0).abs() < EPS);
    }

    #[test]
    fn test_mean_pooling_all_masked_returns_zeros() {
        let hidden: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let mask: [i64; 2] = [0, 0];
        let pooled = EmbeddingModel::mean_pooling(&hidden, &mask, 2, 2)
            .expect("mean_pooling does not fail for well-formed input");
        assert_eq!(pooled, vec![0.0f32; 2]);
    }
}