use tokio::sync::mpsc;
use tracing::{debug, error, warn};
use oxillama_runtime::engine::InferenceEngine;
use crate::queue::{BatchRequest, UsageStats};
/// Starts the inference worker loop on Tokio's blocking thread pool.
///
/// Ownership of both `engine` and `rx` moves into the spawned task, which
/// runs [`run_worker`] until the request channel is closed. The task's join
/// handle is discarded, so the worker is detached and its completion is not
/// observed by the caller.
///
/// # Panics
///
/// `tokio::task::spawn_blocking` requires an active Tokio runtime; see the
/// Tokio documentation for its behavior when called outside of one.
pub fn spawn_inference_worker(engine: InferenceEngine, rx: mpsc::Receiver<BatchRequest>) {
    let worker_loop = move || run_worker(engine, rx);
    // Discarding the handle detaches the task; it keeps running regardless.
    let _ = tokio::task::spawn_blocking(worker_loop);
}
fn run_worker(mut engine: InferenceEngine, mut rx: mpsc::Receiver<BatchRequest>) {
tracing::info!("inference worker started");
while let Some(req) = rx.blocking_recv() {
debug!(req = ?req, "processing inference request");
match req {
BatchRequest::Generate {
prompt,
max_tokens,
config,
reply,
} => {
engine.reset();
let mut completion_tokens = 0usize;
let prompt_tokens = engine.tokenize(&prompt).map(|t| t.len()).unwrap_or(0);
let result = engine
.generate_with_config(&prompt, max_tokens, config, |_| {
completion_tokens += 1;
})
.map(|text| {
let usage = UsageStats {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
};
(text, usage)
})
.map_err(|e| e.to_string());
if reply.send(result).is_err() {
warn!("Generate reply channel closed before result was delivered");
}
}
BatchRequest::GenerateStream {
prompt,
max_tokens,
config,
mut callback,
reply,
} => {
engine.reset();
let mut completion_tokens = 0usize;
let prompt_tokens = engine.tokenize(&prompt).map(|t| t.len()).unwrap_or(0);
let result = engine
.generate_with_config(&prompt, max_tokens, config, |token| {
completion_tokens += 1;
callback(token);
})
.map(|_| UsageStats {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
})
.map_err(|e| e.to_string());
if reply.send(result).is_err() {
warn!("GenerateStream reply channel closed before result was delivered");
}
}
BatchRequest::Embed { text, reply } => {
engine.reset();
let result = engine.embed(&text).map_err(|e| e.to_string());
if reply.send(result).is_err() {
warn!("Embed reply channel closed before result was delivered");
}
}
}
}
error!("inference worker channel closed — no more requests can be processed");
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxillama_runtime::{engine::EngineConfig, sampling::SamplerConfig};
    use tokio::sync::oneshot;

    /// Reply payload carried by a `Generate` request.
    type GenerateReply = Result<(String, crate::queue::UsageStats), String>;

    /// Spawns a worker around a default-configured engine (no model is loaded
    /// by these tests) and returns the sending half of its request channel.
    fn start_unloaded_worker() -> mpsc::Sender<BatchRequest> {
        let unloaded = InferenceEngine::new(EngineConfig::default());
        let (request_tx, request_rx) = mpsc::channel::<BatchRequest>(4);
        spawn_inference_worker(unloaded, request_rx);
        request_tx
    }

    /// A `Generate` request against an unloaded engine must reply with `Err`.
    #[tokio::test]
    async fn test_worker_returns_error_for_unloaded_engine() {
        let requests = start_unloaded_worker();
        let (reply_tx, reply_rx) = oneshot::channel::<GenerateReply>();

        let request = BatchRequest::Generate {
            prompt: "test".to_string(),
            max_tokens: 8,
            config: SamplerConfig::default(),
            reply: reply_tx,
        };
        requests.send(request).await.expect("send should succeed");

        let result = reply_rx.await.expect("reply future should resolve");
        assert!(
            result.is_err(),
            "unloaded engine should produce an error, got: {result:?}"
        );
    }

    /// An `Embed` request against an unloaded engine must reply with `Err`.
    #[tokio::test]
    async fn test_worker_embed_error_for_unloaded_engine() {
        let requests = start_unloaded_worker();
        let (reply_tx, reply_rx) = oneshot::channel::<Result<Vec<f32>, String>>();

        let request = BatchRequest::Embed {
            text: "hello world".to_string(),
            reply: reply_tx,
        };
        requests.send(request).await.expect("send should succeed");

        let result = reply_rx.await.expect("reply future should resolve");
        assert!(
            result.is_err(),
            "unloaded engine embed should produce an error, got: {result:?}"
        );
    }
}