use anyhow::Result;
use serde::{Deserialize, Serialize};
#[cfg(feature = "candle")]
pub mod candle_backend;
#[cfg(feature = "onnx")]
pub mod ort_backend;
/// Outcome of a single text-generation call
/// (the return payload of `LocalInferenceBackend::generate`).
///
/// Field meanings below are read off the field names; the values themselves
/// are filled in by backend implementations that are not visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResult {
/// The produced text.
pub text: String,
/// Token count attributed to the prompt (presumably as counted by the
/// backend's own tokenizer; confirm against implementors).
pub prompt_tokens: u32,
/// Token count attributed to the generated completion.
pub completion_tokens: u32,
/// Generation time; the `_ms` suffix indicates milliseconds.
pub generation_ms: f64,
/// Throughput figure reported by the backend.
/// NOTE(review): whether this is derived from `completion_tokens` and
/// `generation_ms` is up to each backend; verify before relying on it.
pub tokens_per_second: f64,
}
/// Outcome of an embedding call
/// (the return payload of `LocalInferenceBackend::embed`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResult {
/// Embedding vectors as `f32` components.
/// NOTE(review): presumably one inner vector per input string, in input
/// order; confirm against the backend implementations.
pub embeddings: Vec<Vec<f32>>,
/// Token count reported for the whole call.
pub total_tokens: u32,
}
/// Sampling and length settings handed to `LocalInferenceBackend::generate`.
///
/// How each value is interpreted is decided by the individual backend; this
/// type only carries the numbers. Use [`Default::default`] for the stock
/// configuration and override individual fields with struct-update syntax.
#[derive(Debug, Clone)]
pub struct GenerationParams {
    /// Upper bound on the number of tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f64,
    /// Nucleus (top-p) sampling threshold.
    pub top_p: f64,
    /// Repetition penalty factor.
    pub repetition_penalty: f64,
    /// Optional RNG seed; `None` leaves seeding to the backend.
    pub seed: Option<u64>,
}

impl Default for GenerationParams {
    /// Stock settings: 2048 max tokens, temperature 0.7, top-p 0.9,
    /// repetition penalty 1.1, and no fixed seed.
    fn default() -> Self {
        let max_tokens: u32 = 2048;
        let temperature: f64 = 0.7;
        let top_p: f64 = 0.9;
        let repetition_penalty: f64 = 1.1;
        let seed: Option<u64> = None;

        GenerationParams {
            max_tokens,
            temperature,
            top_p,
            repetition_penalty,
            seed,
        }
    }
}
/// Common interface implemented by every local inference backend.
///
/// Implementors must be `Send + Sync`. Every method takes `&self`, so an
/// implementor that tracks loaded models needs interior mutability for that
/// state.
pub trait LocalInferenceBackend: Send + Sync {
/// Name of this backend.
fn name(&self) -> &str;
/// Loads the model found at `path`.
///
/// NOTE(review): the returned `String` is presumably the handle expected by
/// `unload_model`, `generate` and `embed`; confirm against implementors.
fn load_model(&self, path: &str) -> Result<String>;
/// Unloads the model identified by `handle`.
fn unload_model(&self, handle: &str) -> Result<()>;
/// Lists the models currently loaded (presumably as handles; TODO confirm).
fn loaded_models(&self) -> Vec<String>;
/// Runs text generation for `prompt` on the model identified by `handle`,
/// using the settings in `params`.
fn generate(
&self,
handle: &str,
prompt: &str,
params: &GenerationParams,
) -> Result<InferenceResult>;
/// Computes embeddings for `inputs` on the model identified by `handle`.
fn embed(&self, handle: &str, inputs: &[String]) -> Result<EmbeddingResult>;
/// Reports whether this backend can handle model files with the given
/// extension.
///
/// `backend_for_model` passes the extension lowercased and without a leading
/// dot, and passes an empty string when the path has no usable extension.
fn supports_format(&self, extension: &str) -> bool;
/// Estimates the memory needed for the model at `path`; the `_mb` suffix
/// indicates megabytes.
fn estimate_memory_mb(&self, path: &str) -> Result<u64>;
}
/// Constructs one fresh instance of every backend compiled into this build.
///
/// Backends are gated on Cargo features: the `candle` backend is added first,
/// followed by the `onnx` backend. With neither feature enabled the returned
/// vector is empty.
pub fn available_backends() -> Vec<Box<dyn LocalInferenceBackend>> {
    // `mut` is unused when no backend feature is enabled, hence the allow.
    #[allow(unused_mut)]
    let mut registry: Vec<Box<dyn LocalInferenceBackend>> = Vec::new();

    #[cfg(feature = "candle")]
    registry.push(Box::new(candle_backend::CandleBackend::new()));

    #[cfg(feature = "onnx")]
    registry.push(Box::new(ort_backend::OrtBackend::new()));

    registry
}
/// Picks a backend able to handle the model file at `path`, based solely on
/// the file extension.
///
/// The extension is lowercased before being offered to each backend; a path
/// with no extension (or a non-UTF-8 one) is offered as the empty string.
/// Backends are tried in the order produced by `available_backends`, and the
/// first one whose `supports_format` accepts the extension is returned.
/// Returns `None` when no compiled-in backend accepts it.
pub fn backend_for_model(path: &str) -> Option<Box<dyn LocalInferenceBackend>> {
    let raw_extension = std::path::Path::new(path)
        .extension()
        .and_then(|e| e.to_str());
    let extension = match raw_extension {
        Some(raw) => raw.to_lowercase(),
        None => String::new(),
    };

    available_backends()
        .into_iter()
        .find(|candidate| candidate.supports_format(&extension))
}
/// Reports whether this build was compiled with at least one local backend
/// feature (`candle` or `onnx`).
///
/// The answer is fixed at compile time.
pub fn has_local_backend() -> bool {
    cfg!(any(feature = "candle", feature = "onnx"))
}
/// Lists the names of the backend features compiled into this build.
///
/// Yields `"candle"` and/or `"onnx"`, in that order, depending on which Cargo
/// features are enabled; the vector is empty when neither is.
pub fn backend_names() -> Vec<&'static str> {
    // One row per known backend: (compiled in?, name).
    let table: [(bool, &'static str); 2] = [
        (cfg!(feature = "candle"), "candle"),
        (cfg!(feature = "onnx"), "onnx"),
    ];

    table
        .iter()
        .filter(|row| row.0)
        .map(|row| row.1)
        .collect()
}