use anyhow::Result;
use std::path::Path;
use std::time::Instant;
/// Timing and resource figures describing a single generation call.
///
/// All fields are plain numbers so the record can be cloned and
/// (de)serialized through serde.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct InferenceMetrics {
    /// Count of tokens attributed to the generated output.
    pub tokens_generated: usize,

    /// Delay before the first token, in milliseconds.
    pub time_to_first_token_ms: f64,

    /// Generation throughput, in tokens per second.
    pub tokens_per_second: f64,

    /// Duration of the whole call, in milliseconds.
    pub total_time_ms: f64,

    /// Memory figure for the call, in megabytes.
    pub peak_memory_mb: f64,
}
/// Per-request generation settings handed to an inference backend.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Upper bound on the number of tokens to generate.
    pub max_tokens: usize,

    /// Sampling temperature.
    pub temperature: f32,

    /// Top-p sampling parameter.
    pub top_p: f32,

    /// Repetition penalty factor.
    pub repeat_penalty: f32,

    /// Optional seed value; `None` when no seed is specified.
    pub seed: Option<u64>,

    /// Optional thread count; `None` when no count is specified.
    pub n_threads: Option<usize>,
}
impl Default for InferenceConfig {
fn default() -> Self {
Self {
max_tokens: 512,
temperature: 0.7,
top_p: 0.9,
repeat_penalty: 1.1,
seed: None,
n_threads: None,
}
}
}
/// Interface implemented by every inference backend.
///
/// Implementors provide model loading and text generation; the trait supplies
/// a default [`generate_with_metrics`](InferenceBackend::generate_with_metrics)
/// that wraps [`generate`](InferenceBackend::generate) with timing and memory
/// measurements.
#[async_trait::async_trait]
pub trait InferenceBackend: Send + Sync {
    /// Loads the model found at `model_path`.
    ///
    /// # Errors
    /// Returns an error when the backend fails to load the model.
    async fn load_model(&mut self, model_path: &Path) -> Result<()>;

    /// Reports whether a model is currently loaded.
    fn is_loaded(&self) -> bool;

    /// Generates text for `prompt` using the settings in `config`.
    ///
    /// # Errors
    /// Returns an error when generation fails.
    async fn generate(&self, prompt: &str, config: &InferenceConfig) -> Result<String>;

    /// Generates text for `prompt`, optionally accompanied by raw media bytes
    /// (`media_data`) and a string describing them (`media_type`).
    ///
    /// # Errors
    /// Returns an error when generation fails.
    async fn generate_multimodal(
        &self,
        prompt: &str,
        media_data: Option<&[u8]>,
        media_type: Option<&str>,
        config: &InferenceConfig,
    ) -> Result<String>;

    /// Exposes the backend as [`std::any::Any`], e.g. so callers can downcast
    /// to the concrete backend type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Runs [`generate`](InferenceBackend::generate) and returns its output
    /// together with measurements of the call.
    ///
    /// The measurements are derived as follows:
    /// * `tokens_generated` — number of whitespace-separated words in the
    ///   output (an approximation; no tokenizer is consulted here).
    /// * `total_time_ms` — wall-clock duration of the `generate` call,
    ///   including the sub-millisecond fraction.
    /// * `time_to_first_token_ms` — equal to `total_time_ms`: `generate`
    ///   returns the whole response at once, so no token is observable
    ///   earlier than that.
    /// * `tokens_per_second` — `tokens_generated` divided by the elapsed
    ///   seconds, or `0.0` when the elapsed time is zero.
    /// * `peak_memory_mb` — growth of
    ///   [`get_memory_usage_mb`](InferenceBackend::get_memory_usage_mb)
    ///   across the call, never below `0.0`.
    ///
    /// # Errors
    /// Propagates any error returned by `generate`.
    async fn generate_with_metrics(
        &self,
        prompt: &str,
        config: &InferenceConfig,
    ) -> Result<(String, InferenceMetrics)> {
        let start = Instant::now();
        let initial_memory = self.get_memory_usage_mb();

        let result = self.generate(prompt, config).await?;

        let elapsed_secs = start.elapsed().as_secs_f64();
        // Derived from fractional seconds so sub-millisecond time is kept.
        let total_time_ms = elapsed_secs * 1000.0;
        let tokens = result.split_whitespace().count();

        // Guard the division: a zero elapsed time would otherwise produce
        // `inf` (or `NaN` for an empty response).
        let tokens_per_second = if elapsed_secs > 0.0 {
            tokens as f64 / elapsed_secs
        } else {
            0.0
        };

        // Usage may shrink during the call; a negative figure is meaningless
        // for this field, so clamp at zero.
        let memory_growth_mb = (self.get_memory_usage_mb() - initial_memory).max(0.0);

        let metrics = InferenceMetrics {
            tokens_generated: tokens,
            // Non-streaming path: the first token becomes available only when
            // the complete response does.
            time_to_first_token_ms: total_time_ms,
            tokens_per_second,
            total_time_ms,
            peak_memory_mb: memory_growth_mb,
        };
        Ok((result, metrics))
    }

    /// Returns the backend's name.
    fn name(&self) -> &str;

    /// Returns the backend's current memory usage in megabytes.
    ///
    /// The default implementation returns `0.0`; backends that can measure
    /// their usage should override it.
    fn get_memory_usage_mb(&self) -> f64 {
        0.0
    }

    /// Reports whether this backend is available.
    fn is_available() -> bool
    where
        Self: Sized;
}
/// Identifies which inference backend to use.
#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum BackendType {
    /// The Ollama backend.
    Ollama,
}
impl Default for BackendType {
fn default() -> Self {
BackendType::Ollama
}
}