use crate::backend::{
Backend, BackendCapabilities, EmbedError, EmbedResult, GenerateError, TokenEvent, TokenEventV2,
TokenStream, TokenStreamV2,
};
use async_trait::async_trait;
use inferd_proto::embed::{EmbedResolved, EmbedUsage};
use inferd_proto::v2::{ResolvedV2, StopReasonV2, UsageV2};
use inferd_proto::{Resolved, StopReason, Usage};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio_stream::wrappers::ReceiverStream;
/// Knobs that script how a [`Mock`] backend behaves.
///
/// The derived `Default` leaves every option unset and the token list empty.
#[derive(Debug, Clone, Default)]
pub struct MockConfig {
    /// When set, `generate`, `generate_v2` and `embed` return the matching
    /// error immediately, before the readiness flag is even consulted.
    pub pre_stream_error: Option<MockError>,
    /// When `Some(n)`, the streaming task stops before sending the token at
    /// index `n` and never sends the terminal `Done` event. If `n` is not
    /// smaller than the number of tokens, the stream completes normally.
    pub mid_stream_drop_after: Option<usize>,
    /// Tokens emitted, in order, by the generation streams.
    pub tokens: Vec<String>,
    /// Optional pause, in milliseconds, awaited before each token is sent.
    pub token_delay_ms: Option<u64>,
}
/// Failure kinds a [`MockConfig`] can ask the mock backend to report.
///
/// Each variant is translated into the same-named variant of the backend's
/// real error types when a request is rejected up front.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MockError {
    /// Reported as a "not ready" error.
    NotReady,
    /// Reported as an "invalid request" error carrying the message `"mock"`.
    InvalidRequest,
    /// Reported as an "unavailable" error carrying the message `"mock"`.
    Unavailable,
}
/// Converts a scripted mock failure into the generation error it stands for.
///
/// Variants that carry a message use the fixed text `"mock"`.
impl From<MockError> for GenerateError {
    fn from(kind: MockError) -> Self {
        match kind {
            MockError::NotReady => Self::NotReady,
            MockError::InvalidRequest => Self::InvalidRequest("mock".into()),
            MockError::Unavailable => Self::Unavailable("mock".into()),
        }
    }
}
/// In-memory backend whose behaviour is scripted by a [`MockConfig`].
pub struct Mock {
    /// Name reported through `Backend::name`.
    name: &'static str,
    /// Shared readiness flag; toggled by `set_ready`, read by `Backend::ready`.
    ready: Arc<AtomicBool>,
    /// Script describing what the generation and embedding calls should do.
    config: MockConfig,
}
impl Mock {
pub fn new() -> Self {
Self::with_config(MockConfig {
tokens: vec!["mock-response".into()],
..Default::default()
})
}
pub fn with_config(config: MockConfig) -> Self {
Self {
name: "mock",
ready: Arc::new(AtomicBool::new(true)),
config,
}
}
pub fn set_ready(&self, ready: bool) {
self.ready.store(ready, Ordering::SeqCst);
}
}
impl Default for Mock {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Backend for Mock {
fn name(&self) -> &str {
self.name
}
fn ready(&self) -> bool {
self.ready.load(Ordering::SeqCst)
}
fn capabilities(&self) -> BackendCapabilities {
BackendCapabilities {
v2: true,
thinking: true,
embed: true,
..BackendCapabilities::default()
}
}
async fn generate(&self, _req: Resolved) -> Result<TokenStream, GenerateError> {
if let Some(err) = self.config.pre_stream_error {
return Err(err.into());
}
if !self.ready() {
return Err(GenerateError::NotReady);
}
let tokens = self.config.tokens.clone();
let drop_after = self.config.mid_stream_drop_after;
let token_delay = self
.config
.token_delay_ms
.map(std::time::Duration::from_millis);
let (tx, rx) = tokio::sync::mpsc::channel(8);
tokio::spawn(async move {
let mut completion_tokens: u32 = 0;
for (emitted, tok) in tokens.into_iter().enumerate() {
if let Some(n) = drop_after
&& emitted >= n
{
return;
}
if let Some(d) = token_delay {
tokio::time::sleep(d).await;
}
if tx.send(TokenEvent::Token(tok)).await.is_err() {
return; }
completion_tokens = completion_tokens.saturating_add(1);
}
let _ = tx
.send(TokenEvent::Done {
stop_reason: StopReason::End,
usage: Usage {
prompt_tokens: 0,
completion_tokens,
},
})
.await;
});
Ok(Box::pin(ReceiverStream::new(rx)))
}
async fn generate_v2(&self, _req: ResolvedV2) -> Result<TokenStreamV2, GenerateError> {
if let Some(err) = self.config.pre_stream_error {
return Err(err.into());
}
if !self.ready() {
return Err(GenerateError::NotReady);
}
let tokens = self.config.tokens.clone();
let drop_after = self.config.mid_stream_drop_after;
let token_delay = self
.config
.token_delay_ms
.map(std::time::Duration::from_millis);
let (tx, rx) = tokio::sync::mpsc::channel(8);
tokio::spawn(async move {
let mut output_tokens: u32 = 0;
for (emitted, tok) in tokens.into_iter().enumerate() {
if let Some(n) = drop_after
&& emitted >= n
{
return;
}
if let Some(d) = token_delay {
tokio::time::sleep(d).await;
}
if tx.send(TokenEventV2::Text(tok)).await.is_err() {
return;
}
output_tokens = output_tokens.saturating_add(1);
}
let _ = tx
.send(TokenEventV2::Done {
stop_reason: StopReasonV2::EndTurn,
usage: UsageV2 {
input_tokens: 0,
output_tokens,
},
})
.await;
});
Ok(Box::pin(ReceiverStream::new(rx)))
}
async fn embed(&self, req: EmbedResolved) -> Result<EmbedResult, EmbedError> {
if let Some(err) = self.config.pre_stream_error {
return Err(match err {
MockError::NotReady => EmbedError::NotReady,
MockError::InvalidRequest => EmbedError::InvalidRequest("mock".into()),
MockError::Unavailable => EmbedError::Unavailable("mock".into()),
});
}
if !self.ready() {
return Err(EmbedError::NotReady);
}
let dimensions = req.dimensions.unwrap_or(8);
let mut input_tokens: u32 = 0;
let embeddings = req
.input
.iter()
.map(|s| {
input_tokens = input_tokens.saturating_add(s.len() as u32);
let len_f = s.len() as f32;
(0..dimensions)
.map(|i| (i as f32 + 1.0) / (len_f + 1.0))
.collect()
})
.collect();
Ok(EmbedResult {
embeddings,
dimensions,
model: "mock".into(),
usage: EmbedUsage { input_tokens },
})
}
}