#![allow(clippy::cast_precision_loss)]
#![allow(unreachable_pub)]
#![allow(dead_code)]
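//! Abstractions over inference runtime backends (realizar, llama.cpp, vLLM,
//! Ollama): request/response types with latency metrics, a pluggable
//! `RuntimeBackend` trait, a mock backend, a registry, and per-runtime
//! configuration.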
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::error::RealizarError;
#[cfg(feature = "bench-http")]
use crate::http_client::{CompletionRequest, ModelHttpClient, OllamaOptions, OllamaRequest};
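/// Identifies an inference runtime.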
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RuntimeType {
Realizar,
LlamaCpp,
Vllm,
Ollama,
}
impl RuntimeType {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Self::Realizar => "realizar",
Self::LlamaCpp => "llama-cpp",
Self::Vllm => "vllm",
Self::Ollama => "ollama",
}
}
#[must_use]
pub fn parse(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"realizar" => Some(Self::Realizar),
"llama-cpp" | "llama.cpp" | "llamacpp" => Some(Self::LlamaCpp),
"vllm" => Some(Self::Vllm),
"ollama" => Some(Self::Ollama),
_ => None,
}
}
}
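/// Parameters for a single text-generation request.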
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
pub prompt: String,
pub max_tokens: usize,
pub temperature: f64,
pub stop: Vec<String>,
}
impl Default for InferenceRequest {
fn default() -> Self {
Self {
prompt: String::new(),
max_tokens: 100,
temperature: 0.7,
stop: Vec::new(),
}
}
}
impl InferenceRequest {
#[must_use]
pub fn new(prompt: &str) -> Self {
Self {
prompt: prompt.to_string(),
..Default::default()
}
}
#[must_use]
pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
self.max_tokens = max_tokens;
self
}
#[must_use]
pub fn with_temperature(mut self, temperature: f64) -> Self {
self.temperature = temperature;
self
}
#[must_use]
pub fn with_stop(mut self, stop: Vec<String>) -> Self {
self.stop = stop;
self
}
}
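/// Output and timing measurements from a completed inference call:
/// generated text, token count, time to first token (`ttft_ms`),
/// total wall-clock time, and per-token inter-token latencies (`itl_ms`).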
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
pub text: String,
pub tokens_generated: usize,
pub ttft_ms: f64,
pub total_time_ms: f64,
pub itl_ms: Vec<f64>,
}
impl InferenceResponse {
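    /// End-to-end throughput in tokens per second; 0.0 if `total_time_ms` is not positive.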
#[must_use]
pub fn tokens_per_second(&self) -> f64 {
if self.total_time_ms <= 0.0 {
return 0.0;
}
(self.tokens_generated as f64) / (self.total_time_ms / 1000.0)
}
}
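/// Metadata a backend reports about itself: runtime kind, version,
/// streaming support, and the currently loaded model, if any.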
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendInfo {
pub runtime_type: RuntimeType,
pub version: String,
pub supports_streaming: bool,
pub loaded_model: Option<String>,
}
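/// A pluggable inference backend. `load_model` is a no-op by default.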
pub trait RuntimeBackend: Send + Sync {
fn info(&self) -> BackendInfo;
fn inference(&self, request: &InferenceRequest) -> Result<InferenceResponse, RealizarError>;
    fn load_model(&mut self, _model_path: &str) -> Result<(), RealizarError> {
        Ok(())
    }
}
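/// A synthetic backend that returns canned text with latencies derived from a
/// fixed time-to-first-token and a tokens-per-second rate, so callers can be
/// exercised without a real model.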
pub struct MockBackend {
ttft_ms: f64,
tokens_per_second: f64,
}
impl MockBackend {
#[must_use]
pub fn new(ttft_ms: f64, tokens_per_second: f64) -> Self {
Self {
ttft_ms,
tokens_per_second,
}
}
}
impl RuntimeBackend for MockBackend {
fn info(&self) -> BackendInfo {
BackendInfo {
runtime_type: RuntimeType::Realizar,
version: env!("CARGO_PKG_VERSION").to_string(),
supports_streaming: true,
loaded_model: None,
}
}
fn inference(&self, request: &InferenceRequest) -> Result<InferenceResponse, RealizarError> {
let tokens = request.max_tokens.min(100);
let gen_time_ms = (tokens as f64) / self.tokens_per_second * 1000.0;
Ok(InferenceResponse {
text: "Mock response".to_string(),
tokens_generated: tokens,
ttft_ms: self.ttft_ms,
total_time_ms: self.ttft_ms + gen_time_ms,
itl_ms: vec![gen_time_ms / tokens as f64; tokens],
})
}
}
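/// Maps each `RuntimeType` to its registered backend implementation.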
pub struct BackendRegistry {
backends: HashMap<RuntimeType, Box<dyn RuntimeBackend>>,
}
impl BackendRegistry {
#[must_use]
pub fn new() -> Self {
Self {
backends: HashMap::new(),
}
}
pub fn register(&mut self, runtime: RuntimeType, backend: Box<dyn RuntimeBackend>) {
self.backends.insert(runtime, backend);
}
#[must_use]
pub fn get(&self, runtime: RuntimeType) -> Option<&dyn RuntimeBackend> {
self.backends.get(&runtime).map(AsRef::as_ref)
}
#[must_use]
pub fn list(&self) -> Vec<RuntimeType> {
self.backends.keys().copied().collect()
}
}
impl Default for BackendRegistry {
fn default() -> Self {
Self::new()
}
}
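/// Settings for a llama.cpp backend: binary path, optional model, GPU layer
/// count, context size, and thread count.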
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlamaCppConfig {
pub binary_path: String,
pub model_path: Option<String>,
pub n_gpu_layers: u32,
pub ctx_size: usize,
pub threads: usize,
}
impl Default for LlamaCppConfig {
fn default() -> Self {
Self {
binary_path: "llama-cli".to_string(),
model_path: None,
n_gpu_layers: 0,
ctx_size: 2048,
threads: 4,
}
}
}
impl LlamaCppConfig {
#[must_use]
pub fn new(binary_path: &str) -> Self {
Self {
binary_path: binary_path.to_string(),
..Default::default()
}
}
#[must_use]
pub fn with_model(mut self, model_path: &str) -> Self {
self.model_path = Some(model_path.to_string());
self
}
#[must_use]
pub fn with_gpu_layers(mut self, layers: u32) -> Self {
self.n_gpu_layers = layers;
self
}
#[must_use]
pub fn with_ctx_size(mut self, ctx_size: usize) -> Self {
self.ctx_size = ctx_size;
self
}
#[must_use]
pub fn with_threads(mut self, threads: usize) -> Self {
self.threads = threads;
self
}
}
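/// Settings for reaching a vLLM server over HTTP: base URL, API version, and
/// optional model name and API key.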
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VllmConfig {
pub base_url: String,
pub api_version: String,
pub model: Option<String>,
pub api_key: Option<String>,
}
impl Default for VllmConfig {
fn default() -> Self {
Self {
base_url: "http://localhost:8000".to_string(),
api_version: "v1".to_string(),
model: None,
api_key: None,
}
}
}
impl VllmConfig {
#[must_use]
pub fn new(base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
..Default::default()
}
}
#[must_use]
pub fn with_model(mut self, model: &str) -> Self {
self.model = Some(model.to_string());
self
}
#[must_use]
pub fn with_api_key(mut self, api_key: &str) -> Self {
self.api_key = Some(api_key.to_string());
self
}
}
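/// A backend driven by a `LlamaCppConfig`.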
pub struct LlamaCppBackend {
config: LlamaCppConfig,
}
include!("backend.rs");
include!("runtime_type.rs");