1use crate::backend::{
13 Backend, BackendCapabilities, EmbedError, EmbedResult, GenerateError, TokenEvent, TokenEventV2,
14 TokenStream, TokenStreamV2,
15};
16use async_trait::async_trait;
17use inferd_proto::embed::{EmbedResolved, EmbedUsage};
18use inferd_proto::v2::{ResolvedV2, StopReasonV2, UsageV2};
19use inferd_proto::{Resolved, StopReason, Usage};
20use std::sync::Arc;
21use std::sync::atomic::{AtomicBool, Ordering};
22use tokio_stream::wrappers::ReceiverStream;
23
24#[derive(Debug, Clone, Default)]
26pub struct MockConfig {
27 pub pre_stream_error: Option<MockError>,
30 pub mid_stream_drop_after: Option<usize>,
33 pub tokens: Vec<String>,
37 pub token_delay_ms: Option<u64>,
42}
43
/// Failure kinds a mock can be configured to return before producing output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MockError {
    /// Surfaced as the backend's `NotReady` error.
    NotReady,
    /// Surfaced as the backend's `InvalidRequest` error.
    InvalidRequest,
    /// Surfaced as the backend's `Unavailable` error.
    Unavailable,
}
54
55impl From<MockError> for GenerateError {
56 fn from(e: MockError) -> Self {
57 match e {
58 MockError::NotReady => GenerateError::NotReady,
59 MockError::InvalidRequest => GenerateError::InvalidRequest("mock".into()),
60 MockError::Unavailable => GenerateError::Unavailable("mock".into()),
61 }
62 }
63}
64
65pub struct Mock {
67 name: &'static str,
68 ready: Arc<AtomicBool>,
69 config: MockConfig,
70}
71
72impl Mock {
73 pub fn new() -> Self {
76 Self::with_config(MockConfig {
77 tokens: vec!["mock-response".into()],
78 ..Default::default()
79 })
80 }
81
82 pub fn with_config(config: MockConfig) -> Self {
84 Self {
85 name: "mock",
86 ready: Arc::new(AtomicBool::new(true)),
87 config,
88 }
89 }
90
91 pub fn set_ready(&self, ready: bool) {
94 self.ready.store(ready, Ordering::SeqCst);
95 }
96}
97
98impl Default for Mock {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104#[async_trait]
105impl Backend for Mock {
106 fn name(&self) -> &str {
107 self.name
108 }
109
110 fn ready(&self) -> bool {
111 self.ready.load(Ordering::SeqCst)
112 }
113
114 fn capabilities(&self) -> BackendCapabilities {
119 BackendCapabilities {
120 v2: true,
121 thinking: true,
122 embed: true,
123 ..BackendCapabilities::default()
124 }
125 }
126
127 async fn generate(&self, _req: Resolved) -> Result<TokenStream, GenerateError> {
128 if let Some(err) = self.config.pre_stream_error {
129 return Err(err.into());
130 }
131 if !self.ready() {
132 return Err(GenerateError::NotReady);
133 }
134
135 let tokens = self.config.tokens.clone();
136 let drop_after = self.config.mid_stream_drop_after;
137 let token_delay = self
138 .config
139 .token_delay_ms
140 .map(std::time::Duration::from_millis);
141 let (tx, rx) = tokio::sync::mpsc::channel(8);
142
143 tokio::spawn(async move {
146 let mut completion_tokens: u32 = 0;
147 for (emitted, tok) in tokens.into_iter().enumerate() {
148 if let Some(n) = drop_after
149 && emitted >= n
150 {
151 return;
153 }
154 if let Some(d) = token_delay {
155 tokio::time::sleep(d).await;
156 }
157 if tx.send(TokenEvent::Token(tok)).await.is_err() {
158 return; }
160 completion_tokens = completion_tokens.saturating_add(1);
161 }
162 let _ = tx
163 .send(TokenEvent::Done {
164 stop_reason: StopReason::End,
165 usage: Usage {
166 prompt_tokens: 0,
167 completion_tokens,
168 },
169 })
170 .await;
171 });
172
173 Ok(Box::pin(ReceiverStream::new(rx)))
174 }
175
176 async fn generate_v2(&self, _req: ResolvedV2) -> Result<TokenStreamV2, GenerateError> {
181 if let Some(err) = self.config.pre_stream_error {
182 return Err(err.into());
183 }
184 if !self.ready() {
185 return Err(GenerateError::NotReady);
186 }
187
188 let tokens = self.config.tokens.clone();
189 let drop_after = self.config.mid_stream_drop_after;
190 let token_delay = self
191 .config
192 .token_delay_ms
193 .map(std::time::Duration::from_millis);
194 let (tx, rx) = tokio::sync::mpsc::channel(8);
195
196 tokio::spawn(async move {
197 let mut output_tokens: u32 = 0;
198 for (emitted, tok) in tokens.into_iter().enumerate() {
199 if let Some(n) = drop_after
200 && emitted >= n
201 {
202 return;
203 }
204 if let Some(d) = token_delay {
205 tokio::time::sleep(d).await;
206 }
207 if tx.send(TokenEventV2::Text(tok)).await.is_err() {
208 return;
209 }
210 output_tokens = output_tokens.saturating_add(1);
211 }
212 let _ = tx
213 .send(TokenEventV2::Done {
214 stop_reason: StopReasonV2::EndTurn,
215 usage: UsageV2 {
216 input_tokens: 0,
217 output_tokens,
218 },
219 })
220 .await;
221 });
222
223 Ok(Box::pin(ReceiverStream::new(rx)))
224 }
225
226 async fn embed(&self, req: EmbedResolved) -> Result<EmbedResult, EmbedError> {
236 if let Some(err) = self.config.pre_stream_error {
237 return Err(match err {
238 MockError::NotReady => EmbedError::NotReady,
239 MockError::InvalidRequest => EmbedError::InvalidRequest("mock".into()),
240 MockError::Unavailable => EmbedError::Unavailable("mock".into()),
241 });
242 }
243 if !self.ready() {
244 return Err(EmbedError::NotReady);
245 }
246
247 let dimensions = req.dimensions.unwrap_or(8);
248 let mut input_tokens: u32 = 0;
249 let embeddings = req
250 .input
251 .iter()
252 .map(|s| {
253 input_tokens = input_tokens.saturating_add(s.len() as u32);
254 let len_f = s.len() as f32;
255 (0..dimensions)
256 .map(|i| (i as f32 + 1.0) / (len_f + 1.0))
257 .collect()
258 })
259 .collect();
260
261 Ok(EmbedResult {
262 embeddings,
263 dimensions,
264 model: "mock".into(),
265 usage: EmbedUsage { input_tokens },
266 })
267 }
268}