use std::sync::{Arc, OnceLock};
use thiserror::Error;
/// Errors returned by the engine registry functions in this module and by
/// [`InferenceEngine`] / [`EmbeddingEngine`] implementations.
#[derive(Debug, Error)]
pub enum EngineError {
/// [`infer`] was called before any inference engine was registered.
#[error("no inference engine registered on this node")]
NoInferenceEngine,
/// [`embed`] was called before any embedding engine was registered.
#[error("no embedding engine registered on this node")]
NoEmbeddingEngine,
/// A requested model could not be found; the payload describes the model.
///
/// NOTE(review): not constructed in this module — presumably returned by engine
/// implementations; confirm what the payload contains (e.g. hex id vs. name).
#[error("model not found: {0}")]
ModelNotFound(String),
/// Free-form error message. The `register_*` functions in this module use it to
/// report that an engine was already registered.
#[error("engine: {0}")]
Other(String),
}
/// A backend that performs model inference. It is installed once per process via
/// [`register_inference_engine`] and invoked through [`infer`].
///
/// The `Send + Sync + 'static` bounds allow an `Arc<dyn InferenceEngine>` to be stored
/// in a `static` and called from any thread.
pub trait InferenceEngine: Send + Sync + 'static {
/// Runs the model identified by `model_id` on `prompt` and returns the output bytes.
///
/// `model_id` is an opaque 32-byte identifier (presumably a content hash — TODO confirm).
/// The encoding of `prompt` and of the returned bytes is up to the implementation.
fn infer(&self, model_id: &[u8; 32], prompt: &[u8]) -> Result<Vec<u8>, EngineError>;
}
/// A backend that produces vector embeddings. It is installed once per process via
/// [`register_embedding_engine`] and invoked through [`embed`].
///
/// The `Send + Sync + 'static` bounds allow an `Arc<dyn EmbeddingEngine>` to be stored
/// in a `static` and called from any thread.
pub trait EmbeddingEngine: Send + Sync + 'static {
/// Embeds `text` and returns the resulting vector.
///
/// NOTE(review): `dim` is presumably the requested vector length; nothing in this
/// module checks that the returned `Vec` has `dim` elements — confirm with implementers.
fn embed(&self, dim: usize, text: &[u8]) -> Result<Vec<f32>, EngineError>;
}
// Process-wide engine slots. Each `OnceLock` can be set at most once, so the first
// successful `register_*` call wins for the lifetime of the process.
static INFER: OnceLock<Arc<dyn InferenceEngine>> = OnceLock::new();
static EMBED: OnceLock<Arc<dyn EmbeddingEngine>> = OnceLock::new();
/// Installs `e` as the process-wide inference engine used by [`infer`].
///
/// Registration happens at most once: later calls leave the existing engine in place.
///
/// # Errors
///
/// Returns [`EngineError::Other`] if an inference engine is already registered. In that
/// case the rejected `e` is dropped.
pub fn register_inference_engine(e: Arc<dyn InferenceEngine>) -> Result<(), EngineError> {
    match INFER.set(e) {
        Ok(()) => Ok(()),
        // `OnceLock::set` hands the rejected value back; we intentionally discard it.
        Err(_) => Err(EngineError::Other(String::from(
            "inference engine already registered",
        ))),
    }
}
/// Installs `e` as the process-wide embedding engine used by [`embed`].
///
/// Registration happens at most once: later calls leave the existing engine in place.
///
/// # Errors
///
/// Returns [`EngineError::Other`] if an embedding engine is already registered. In that
/// case the rejected `e` is dropped.
pub fn register_embedding_engine(e: Arc<dyn EmbeddingEngine>) -> Result<(), EngineError> {
    match EMBED.set(e) {
        Ok(()) => Ok(()),
        // `OnceLock::set` hands the rejected value back; we intentionally discard it.
        Err(_) => Err(EngineError::Other(String::from(
            "embedding engine already registered",
        ))),
    }
}
/// Runs inference through the registered [`InferenceEngine`].
///
/// `model_id` and `prompt` are forwarded unchanged to the engine.
///
/// # Errors
///
/// Returns [`EngineError::NoInferenceEngine`] if no engine has been registered.
/// Otherwise it returns whatever error the engine itself reports.
pub fn infer(model_id: &[u8; 32], prompt: &[u8]) -> Result<Vec<u8>, EngineError> {
    let Some(engine) = INFER.get() else {
        return Err(EngineError::NoInferenceEngine);
    };
    engine.infer(model_id, prompt)
}
/// Computes an embedding through the registered [`EmbeddingEngine`].
///
/// `dim` and `text` are forwarded unchanged to the engine. This function does not
/// check the length of the returned vector.
///
/// # Errors
///
/// Returns [`EngineError::NoEmbeddingEngine`] if no engine has been registered.
/// Otherwise it returns whatever error the engine itself reports.
pub fn embed(dim: usize, text: &[u8]) -> Result<Vec<f32>, EngineError> {
    let Some(engine) = EMBED.get() else {
        return Err(EngineError::NoEmbeddingEngine);
    };
    engine.embed(dim, text)
}
/// Returns `true` once an inference engine has been installed, i.e. when [`infer`]
/// will no longer fail with [`EngineError::NoInferenceEngine`].
pub fn inference_engine_registered() -> bool {
    let slot = INFER.get();
    slot.is_some()
}
/// Returns `true` once an embedding engine has been installed, i.e. when [`embed`]
/// will no longer fail with [`EngineError::NoEmbeddingEngine`].
pub fn embedding_engine_registered() -> bool {
    let slot = EMBED.get();
    slot.is_some()
}
// Unit tests for the engine registry.
//
// NOTE(review): `INFER` and `EMBED` are process-wide `OnceLock`s shared by every test in
// this test binary, and they can never be cleared. The two "without engine" tests below
// only pass while no test registers an engine. A future test that calls a `register_*`
// function would make them order-dependent and flaky.
#[cfg(test)]
mod tests {
use super::*;
// With no inference engine registered, `infer` must report NoInferenceEngine rather
// than any other error or an `Ok`.
#[test]
fn infer_without_engine_returns_not_registered() {
let id = [0u8; 32];
match infer(&id, b"hello") {
Err(EngineError::NoInferenceEngine) => {}
other => panic!("expected NoInferenceEngine, got {other:?}"),
}
}
// With no embedding engine registered, `embed` must report NoEmbeddingEngine.
#[test]
fn embed_without_engine_returns_not_registered() {
match embed(8, b"hello") {
Err(EngineError::NoEmbeddingEngine) => {}
other => panic!("expected NoEmbeddingEngine, got {other:?}"),
}
}
// Pins the exact `Display` text of every variant, since callers may surface these
// messages to operators.
#[test]
fn engine_error_display_is_useful() {
assert_eq!(
EngineError::NoInferenceEngine.to_string(),
"no inference engine registered on this node"
);
assert_eq!(
EngineError::NoEmbeddingEngine.to_string(),
"no embedding engine registered on this node"
);
assert_eq!(
EngineError::ModelNotFound("abc".into()).to_string(),
"model not found: abc"
);
assert_eq!(
EngineError::Other("boom".into()).to_string(),
"engine: boom"
);
}
}