xz_embed/embedder/
mock.rs1use async_trait::async_trait;
2use std::fmt::Debug;
3use std::sync::Mutex;
4
5use crate::error::EmbedError;
6use crate::traits::{EmbedModelInfo, EmbedPricing, EmbeddingModel};
7
/// Test double for [`EmbeddingModel`] whose responses are configured up front.
///
/// The configuration methods (`expect_embed`, `set_error`, `set_output`) take
/// `&mut self`. `embed` only reads the configuration through `&self`, locking
/// each `Mutex` briefly.
#[derive(Debug)]
pub struct MockEmbedder {
    /// Static metadata reported by `model_info`, `max_batch_size` and `dimensions`.
    info: EmbedModelInfo,
    /// When `Some`, the exact batch `embed` must receive; any other batch is rejected.
    expected_input: Mutex<Option<Vec<String>>>,
    /// Canned vectors returned by `embed`; when empty, zero vectors are synthesized.
    mock_output: Mutex<Vec<Vec<f32>>>,
    /// When `Some`, `embed` fails with an `EmbedError::Model` whose message embeds this error.
    should_error: Mutex<Option<EmbedError>>,
}
16
17impl MockEmbedder {
18 pub fn new(dimensions: usize, max_batch_size: usize) -> Self {
20 Self {
21 info: EmbedModelInfo {
22 name: "mock-embedder".into(),
23 display_name: "Mock Embedder".into(),
24 supported_dimensions: None,
25 current_dimension: dimensions,
26 max_input_tokens: 1024,
27 max_batch_size,
28 pricing: EmbedPricing {
29 input_per_million: 0.0,
30 },
31 },
32 expected_input: Mutex::new(None),
33 mock_output: Mutex::new(vec![]),
34 should_error: Mutex::new(None),
35 }
36 }
37
38 pub fn expect_embed(&mut self, inputs: Vec<&str>, outputs: Vec<Vec<f32>>) -> &mut Self {
40 *self.expected_input.get_mut().unwrap() = Some(inputs.iter().map(|s| s.to_string()).collect());
41 *self.mock_output.get_mut().unwrap() = outputs;
42 self
43 }
44
45 pub fn set_error(&mut self, error: EmbedError) {
47 *self.should_error.get_mut().unwrap() = Some(error);
48 }
49
50 pub fn set_output(&mut self, vectors: Vec<Vec<f32>>) {
52 *self.mock_output.get_mut().unwrap() = vectors;
53 }
54}
55
56#[async_trait]
57impl EmbeddingModel for MockEmbedder {
58 async fn embed(&self, input: &[&str]) -> Result<Vec<Vec<f32>>, EmbedError> {
59 if input.is_empty() {
60 return Err(EmbedError::EmptyBatch);
61 }
62
63 if let Some(ref err) = *self.should_error.lock().unwrap() {
65 return Err(EmbedError::Model(format!("Mock error: {err}")));
66 }
67
68 if let Some(ref expected) = *self.expected_input.lock().unwrap() {
70 let actual: Vec<String> = input.iter().map(|s| s.to_string()).collect();
71 if &actual != expected {
72 return Err(EmbedError::Model(format!(
73 "输入不匹配: expected {expected:?}, got {actual:?}"
74 )));
75 }
76 }
77
78 let output = self.mock_output.lock().unwrap();
79 if !output.is_empty() {
80 if output.len() != input.len() {
81 return Err(EmbedError::Model(format!(
82 "输出数量不匹配: expected {}, got {}",
83 input.len(),
84 output.len()
85 )));
86 }
87 return Ok(output.clone());
88 }
89
90 Ok(vec![vec![0.0; self.info.current_dimension]; input.len()])
92 }
93
94 fn model_info(&self) -> &EmbedModelInfo {
95 &self.info
96 }
97
98 fn max_batch_size(&self) -> usize {
99 self.info.max_batch_size
100 }
101
102 fn dimensions(&self) -> usize {
103 self.info.current_dimension
104 }
105}