#![allow(
clippy::clone_on_copy,
clippy::if_same_then_else,
clippy::match_single_binding
)]
use crate::{
InfernoError,
ai_features::sampling::{Sampler, SamplingConfig, SamplingStrategy},
ai_features::streaming::{StreamConfig, StreamToken, create_stream_channel},
backends::{
BackendConfig, BackendType, InferenceBackend, InferenceMetrics, InferenceParams,
TokenStream,
},
models::ModelInfo,
};
use anyhow::Result;
use async_stream::stream;
use llama_cpp_2::{
context::{LlamaContext, params::LlamaContextParams},
llama_backend::LlamaBackend,
llama_batch::LlamaBatch,
model::{AddBos, LlamaModel, Special, params::LlamaModelParams},
sampling::LlamaSampler,
token::LlamaToken,
};
use std::{num::NonZeroU32, sync::Arc, time::Instant};
use tracing::{debug, info, warn};
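/// GGUF inference backend backed by llama.cpp through the `llama-cpp-2` bindings.
///
/// The llama.cpp backend handle and the loaded model are stored behind `Arc`s so
/// the blocking generation closures can own cheap clones of them while running on
/// `spawn_blocking` threads.
///
/// Rough usage sketch (illustrative only; actual import paths depend on how the
/// crate exposes these types, and `model_info` must describe a real GGUF file):
///
/// ```ignore
/// let mut backend = GgufBackend::new(BackendConfig::default())?;
/// backend.load_model(&model_info).await?;
/// let reply = backend.infer("Hello", &InferenceParams::default()).await?;
/// ```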
pub struct GgufBackend {
config: BackendConfig,
backend: Option<Arc<LlamaBackend>>,
model: Option<Arc<LlamaModel>>,
model_info: Option<ModelInfo>,
metrics: Option<InferenceMetrics>,
}
impl GgufBackend {
pub fn new(config: BackendConfig) -> Result<Self> {
info!("Initializing GGUF backend with real llama.cpp support");
let backend = Self {
config,
backend: None,
model: None,
model_info: None,
metrics: None,
};
// Reject invalid configurations up front so callers fail fast instead of
// discovering the problem at model-load time.
backend.validate_config()?;
Ok(backend)
}
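/// Sanity-checks the backend configuration: warns on very large context sizes
/// and rejects anything below 256 tokens.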
fn validate_config(&self) -> Result<()> {
if self.config.context_size > 32768 {
warn!(
"Very large context size may impact performance: {}",
self.config.context_size
);
}
if self.config.context_size < 256 {
return Err(
InfernoError::Backend("Context size too small (minimum 256)".to_string()).into(),
);
}
Ok(())
}
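/// Tokenizes `text` with the loaded llama.cpp model on a blocking thread,
/// prepending a BOS token, and returns the raw token ids.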
async fn real_tokenize(&self, text: &str) -> Result<Vec<i32>> {
let model = self
.model
.as_ref()
.ok_or_else(|| InfernoError::Backend("Model not loaded".to_string()))?;
debug!(
"Tokenizing text of length: {} with real llama.cpp",
text.len()
);
let tokens = tokio::task::spawn_blocking({
let model = model.clone();
let text = text.to_string();
move || {
model
.str_to_token(&text, AddBos::Always)
.map_err(|e| InfernoError::Backend(format!("Tokenization failed: {}", e)))
}
})
.await
.map_err(|e| InfernoError::Backend(format!("Tokenization task failed: {}", e)))?
.map_err(anyhow::Error::from)?;
let token_ids: Vec<i32> = tokens.iter().map(|t| t.0).collect();
debug!("Tokenized text into {} tokens", token_ids.len());
Ok(token_ids)
}
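/// Converts token ids back to text on a blocking thread; tokens that cannot be
/// rendered are replaced with an `[UNK_<id>]` placeholder instead of failing.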
async fn real_detokenize(&self, tokens: &[i32]) -> Result<String> {
let model = self
.model
.as_ref()
.ok_or_else(|| InfernoError::Backend("Model not loaded".to_string()))?;
debug!("Detokenizing {} tokens with real llama.cpp", tokens.len());
let text = tokio::task::spawn_blocking({
let model = model.clone();
let tokens = tokens.to_vec();
move || {
let mut result = String::new();
for &token in &tokens {
match model.token_to_str(LlamaToken(token), Special::Tokenize) {
Ok(token_str) => result.push_str(&token_str),
Err(e) => {
warn!("Failed to convert token {} to string: {}", token, e);
result.push_str(&format!("[UNK_{}]", token));
}
}
}
Ok::<String, InfernoError>(result)
}
})
.await
.map_err(|e| InfernoError::Backend(format!("Detokenization task failed: {}", e)))?
.map_err(anyhow::Error::from)?;
Ok(text)
}
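/// Rough token-count heuristic used only for metrics: roughly one token per
/// 3.5 characters or 1.3 tokens per word, whichever is larger, minimum 1.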
fn estimate_token_count(&self, text: &str) -> u32 {
let word_count = text.split_whitespace().count();
let char_count = text.len();
let char_based = (char_count as f32 / 3.5).ceil() as u32;
let word_based = (word_count as f32 * 1.3).ceil() as u32;
char_based.max(word_based).max(1)
}
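/// Runs a full non-streaming generation on a blocking thread: builds a fresh
/// llama.cpp context, decodes the prompt in one batch, then samples tokens one
/// at a time until EOS or `max_tokens` and detokenizes the result.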
async fn generate_response(&mut self, input: &str, params: &InferenceParams) -> Result<String> {
debug!(
"🔥 Generating response for input of length: {} with Metal GPU acceleration",
input.len()
);
let model = self
.model
.as_ref()
.ok_or_else(|| InfernoError::Backend("Model not loaded".to_string()))?
.clone();
let backend = self
.backend
.as_ref()
.ok_or_else(|| InfernoError::Backend("Backend not initialized".to_string()))?
.clone();
let input_str = input.to_string();
let context_size = self.config.context_size;
let batch_size = self.config.batch_size;
let max_tokens = params.max_tokens;
let temperature = params.temperature;
let top_k = params.top_k;
let top_p = params.top_p;
let seed = params.seed;
let response = tokio::task::spawn_blocking(move || {
let ctx_params = LlamaContextParams::default()
.with_n_ctx(NonZeroU32::new(context_size))
.with_n_batch(batch_size);
let mut context = model
.new_context(&backend, ctx_params)
.map_err(|e| InfernoError::Backend(format!("Failed to create context: {}", e)))?;
let input_tokens = model
.str_to_token(&input_str, AddBos::Always)
.map_err(|e| InfernoError::Backend(format!("Failed to tokenize: {}", e)))?;
debug!("📝 Tokenized {} tokens from input", input_tokens.len());
let n_ctx = context.n_ctx();
let mut batch = LlamaBatch::new(n_ctx as usize, 1);
for (i, token) in input_tokens.iter().enumerate() {
let is_last = i == input_tokens.len() - 1;
batch
.add(*token, i as i32, &[0], is_last)
.map_err(|e| {
InfernoError::Backend(format!("Failed to add token to batch: {}", e))
})?;
}
context
.decode(&mut batch)
.map_err(|e| InfernoError::Backend(format!("Failed to decode batch: {}", e)))?;
debug!("⚡ Input processed through Metal GPU");
let sampling_config = SamplingConfig {
strategy: if input_str.is_empty() || temperature.abs() < 0.01 {
SamplingStrategy::Greedy
} else {
SamplingStrategy::TopKP
},
temperature: temperature.clamp(0.1, 2.0),
top_k: top_k.max(1),
top_p: top_p.clamp(0.0, 1.0),
repeat_penalty: 1.1,
seed,
};
let strategy = sampling_config.strategy;
let temperature = sampling_config.temperature;
let mut sampler = Sampler::new(sampling_config);
let mut output_tokens = Vec::new();
let max_new_tokens = max_tokens as usize;
debug!(
"🔀 Starting token generation with sampling strategy: {:?}, temp: {:.2}",
strategy, temperature
);
for _ in 0..max_new_tokens {
let candidates_llama: Vec<_> = context.candidates().collect();
let candidates: Vec<(i32, f32, f32)> = candidates_llama
.iter()
.map(|c| (c.id().0, c.logit(), c.p()))
.collect();
let next_token = sampler.sample_from_candidates(&candidates).ok_or_else(|| {
InfernoError::Backend("No candidates available for sampling".to_string())
})?;
if next_token == model.token_eos().0 {
debug!("🏁 End of generation token encountered");
break;
}
output_tokens.push(next_token);
batch.clear();
batch
.add(
LlamaToken(next_token),
input_tokens.len() as i32 + output_tokens.len() as i32 - 1,
&[0],
true,
)
.map_err(|e| {
InfernoError::Backend(format!("Failed to add output token: {}", e))
})?;
context.decode(&mut batch).map_err(|e| {
InfernoError::Backend(format!("Failed to decode output token: {}", e))
})?;
}
let llama_tokens: Vec<LlamaToken> =
output_tokens.iter().map(|&t| LlamaToken(t)).collect();
let response = model
.tokens_to_str(&llama_tokens, Special::Tokenize)
.map_err(|e| InfernoError::Backend(format!("Failed to detokenize: {}", e)))?;
debug!("✅ Generated {} tokens via Metal GPU", output_tokens.len());
Ok::<String, InfernoError>(response)
})
.await
.map_err(|e| InfernoError::Backend(format!("Inference task failed: {}", e)))??;
Ok(response)
}
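/// Streaming variant of `generate_response`: the same decode/sample loop runs
/// on a blocking thread, but each decoded token is pushed through a channel
/// and surfaced to the caller as an async `TokenStream`.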
async fn generate_stream(
&mut self,
input: &str,
params: &InferenceParams,
) -> Result<TokenStream> {
info!("🌊 Starting GGUF streaming inference with Metal GPU");
let model = self
.model
.as_ref()
.ok_or_else(|| InfernoError::Backend("Model not loaded".to_string()))?
.clone();
let backend = self
.backend
.as_ref()
.ok_or_else(|| InfernoError::Backend("Backend not initialized".to_string()))?
.clone();
let input_str = input.to_string();
let context_size = self.config.context_size;
let batch_size = self.config.batch_size;
let max_tokens = params.max_tokens;
let temperature = params.temperature;
let top_k = params.top_k;
let top_p = params.top_p;
let seed = params.seed;
let stream_config = StreamConfig {
buffer_size: 64,
include_timing: false,
max_tokens_per_sec: 0,
};
let (tx, rx) = create_stream_channel(stream_config);
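// The producer side runs on a blocking thread and sends tokens (or error
// markers) through the channel; the async stream below simply drains the
// receiver until the sender is dropped.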
tokio::task::spawn_blocking(move || {
let start_time = std::time::Instant::now();
let ctx_params = LlamaContextParams::default()
.with_n_ctx(std::num::NonZeroU32::new(context_size))
.with_n_batch(batch_size);
let mut context = match model.new_context(&backend, ctx_params) {
Ok(ctx) => ctx,
Err(e) => {
let _ = tx.blocking_send(StreamToken {
content: format!("Error: Failed to create context: {}", e),
sequence: 0,
is_valid: false,
timestamp_ms: Some(start_time.elapsed().as_millis() as u64),
});
return;
}
};
let input_tokens =
match model.str_to_token(&input_str, llama_cpp_2::model::AddBos::Always) {
Ok(tokens) => tokens,
Err(e) => {
let _ = tx.blocking_send(StreamToken {
content: format!("Error: Tokenization failed: {}", e),
sequence: 0,
is_valid: false,
timestamp_ms: Some(start_time.elapsed().as_millis() as u64),
});
return;
}
};
debug!("📝 Tokenized {} tokens from input", input_tokens.len());
let n_ctx = context.n_ctx();
let mut batch = LlamaBatch::new(n_ctx as usize, 1);
for (i, token) in input_tokens.iter().enumerate() {
let is_last = i == input_tokens.len() - 1;
if let Err(e) = batch.add(*token, i as i32, &[0], is_last) {
let _ = tx.blocking_send(StreamToken {
content: format!("Error: Failed to add token to batch: {}", e),
sequence: 0,
is_valid: false,
timestamp_ms: Some(start_time.elapsed().as_millis() as u64),
});
return;
}
}
if let Err(e) = context.decode(&mut batch) {
let _ = tx.blocking_send(StreamToken {
content: format!("Error: Failed to decode batch: {}", e),
sequence: 0,
is_valid: false,
timestamp_ms: Some(start_time.elapsed().as_millis() as u64),
});
return;
}
debug!("⚡ Input processed through Metal GPU");
let sampling_config = SamplingConfig {
strategy: if input_str.is_empty() || temperature.abs() < 0.01 {
SamplingStrategy::Greedy
} else {
SamplingStrategy::TopKP
},
temperature: temperature.clamp(0.1, 2.0),
top_k: top_k.max(1),
top_p: top_p.clamp(0.0, 1.0),
repeat_penalty: 1.1,
seed,
};
let strategy = sampling_config.strategy;
let temp = sampling_config.temperature;
let mut sampler = Sampler::new(sampling_config);
let max_new_tokens = max_tokens as usize;
let mut sequence = 0u32;
debug!(
"🔀 Starting streaming token generation with strategy: {:?}, temp: {:.2}",
strategy, temp
);
for _ in 0..max_new_tokens {
let candidates_llama: Vec<_> = context.candidates().collect();
let candidates: Vec<(i32, f32, f32)> = candidates_llama
.iter()
.map(|c| (c.id().0, c.logit(), c.p()))
.collect();
let next_token = match sampler.sample_from_candidates(&candidates) {
Some(token) => token,
None => {
let _ = tx.blocking_send(StreamToken {
content: "[ERROR: No candidates available]".to_string(),
sequence,
is_valid: false,
timestamp_ms: Some(start_time.elapsed().as_millis() as u64),
});
break;
}
};
if next_token == model.token_eos().0 {
debug!("🏁 End of generation token encountered");
break;
}
match model.token_to_str(
llama_cpp_2::token::LlamaToken(next_token),
llama_cpp_2::model::Special::Tokenize,
) {
Ok(token_str) => {
let stream_token = StreamToken {
content: token_str.clone(),
sequence,
is_valid: true,
timestamp_ms: Some(start_time.elapsed().as_millis() as u64),
};
if tx.blocking_send(stream_token).is_err() {
debug!("🛑 Stream receiver disconnected, stopping generation");
break;
}
}
Err(_) => {
let stream_token = StreamToken::invalid(sequence)
.with_timing(start_time.elapsed().as_millis() as u64);
let _ = tx.blocking_send(stream_token);
}
}
sequence += 1;
batch.clear();
if let Err(e) = batch.add(
llama_cpp_2::token::LlamaToken(next_token),
input_tokens.len() as i32 + sequence as i32 - 1,
&[0],
true,
) {
debug!("Failed to add output token: {}", e);
break;
}
if let Err(e) = context.decode(&mut batch) {
debug!("Failed to decode output token: {}", e);
break;
}
}
debug!(
"✅ Streaming complete: generated {} tokens in {:?}",
sequence,
start_time.elapsed()
);
});
let result_stream = stream! {
let mut rx = rx;
while let Some(stream_token) = rx.recv().await {
yield Ok(stream_token.content);
}
};
Ok(Box::pin(result_stream))
}
}
#[async_trait::async_trait]
impl InferenceBackend for GgufBackend {
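// Loading validates the config, checks that the file exists, is plausibly
// sized, and starts with the GGUF magic, then initializes the llama.cpp
// backend on a blocking thread and loads the model with full GPU offload
// when `gpu_enabled` is set.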
async fn load_model(&mut self, model_info: &ModelInfo) -> Result<()> {
info!("Loading GGUF model: {}", model_info.path.display());
self.validate_config()?;
if !model_info.path.exists() {
return Err(InfernoError::Backend(format!(
"Model file not found: {}",
model_info.path.display()
))
.into());
}
let file_size = std::fs::metadata(&model_info.path)
.map_err(|e| InfernoError::Backend(format!("Cannot read model file metadata: {}", e)))?
.len();
if file_size < 1024 {
return Err(InfernoError::Backend(
"Model file appears to be too small to be a valid GGUF file".to_string(),
)
.into());
}
let mut file = std::fs::File::open(&model_info.path)
.map_err(|e| InfernoError::Backend(format!("Cannot open model file: {}", e)))?;
let mut magic = [0u8; 4];
use std::io::Read;
file.read_exact(&mut magic)
.map_err(|e| InfernoError::Backend(format!("Cannot read model file header: {}", e)))?;
if &magic != b"GGUF" {
return Err(InfernoError::Backend(
"File is not a valid GGUF model (missing GGUF magic bytes)".to_string(),
)
.into());
}
debug!("GGUF file validation passed");
debug!("Model file size: {} bytes", file_size);
debug!(
"Config - GPU enabled: {}, Context size: {}, Batch size: {}",
self.config.gpu_enabled, self.config.context_size, self.config.batch_size
);
info!(
"Initializing llama.cpp model from: {}",
model_info.path.display()
);
let backend = Arc::new(
tokio::task::spawn_blocking(|| {
LlamaBackend::init().map_err(|e| {
InfernoError::Backend(format!("Failed to initialize llama backend: {}", e))
})
})
.await
.map_err(|e| {
InfernoError::Backend(format!("Backend initialization task failed: {}", e))
})?
.map_err(anyhow::Error::from)?,
);
let n_gpu_layers = if self.config.gpu_enabled { 999 } else { 0 };
info!(
"🎯 GGUF backend - GPU enabled: {}, GPU layers: {}",
self.config.gpu_enabled, n_gpu_layers
);
let model_params = LlamaModelParams::default()
.with_n_gpu_layers(n_gpu_layers)
.with_use_mlock(false);
let model = {
let path = &model_info.path;
LlamaModel::load_from_file(&backend, path, &model_params)
.map_err(|e| InfernoError::Backend(format!("Failed to load GGUF model: {}", e)))?
};
self.backend = Some(backend);
self.model = Some(Arc::new(model));
self.model_info = Some(model_info.clone());
info!("✅ GGUF model loaded successfully with Metal GPU support");
Ok(())
}
async fn unload_model(&mut self) -> Result<()> {
info!("Unloading GGUF model");
self.backend = None;
self.model = None;
self.model_info = None;
self.metrics = None;
Ok(())
}
async fn is_loaded(&self) -> bool {
self.model.is_some() && self.backend.is_some()
}
async fn get_model_info(&self) -> Option<ModelInfo> {
self.model_info.as_ref().cloned()
}
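// Non-streaming inference: tokenizes for the prompt-token count, runs the
// generation, and records timing/throughput in `InferenceMetrics`.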
async fn infer(&mut self, input: &str, params: &InferenceParams) -> Result<String> {
if !self.is_loaded().await {
return Err(InfernoError::Backend("Model not loaded".to_string()).into());
}
let start_time = Instant::now();
info!("Starting GGUF inference");
let input_tokens = self.real_tokenize(input).await?;
let prompt_tokens = input_tokens.len() as u32;
let prompt_time = start_time.elapsed();
let response = self.generate_response(input, params).await?;
let completion_time = start_time.elapsed() - prompt_time;
let total_time = start_time.elapsed();
let completion_tokens = self.estimate_token_count(&response);
let total_tokens = prompt_tokens + completion_tokens;
self.metrics = Some(InferenceMetrics {
total_tokens,
prompt_tokens,
completion_tokens,
total_time_ms: total_time.as_millis() as u64,
tokens_per_second: if completion_time.as_secs_f32() > 0.0 {
completion_tokens as f32 / completion_time.as_secs_f32()
} else {
0.0
},
prompt_time_ms: prompt_time.as_millis() as u64,
completion_time_ms: completion_time.as_millis() as u64,
});
info!(
"GGUF inference completed: {} tokens in {:.2}s ({:.1} tok/s)",
completion_tokens,
completion_time.as_secs_f32(),
completion_tokens as f32 / completion_time.as_secs_f32().max(0.001)
);
Ok(response)
}
async fn infer_stream(&mut self, input: &str, params: &InferenceParams) -> Result<TokenStream> {
if !self.is_loaded().await {
return Err(InfernoError::Backend("Model not loaded".to_string()).into());
}
info!("Starting GGUF streaming inference");
self.generate_stream(input, params).await
}
async fn get_embeddings(&mut self, input: &str) -> Result<Vec<f32>> {
if !self.is_loaded().await {
return Err(InfernoError::Backend("Model not loaded".to_string()).into());
}
info!("Computing GGUF embeddings for input");
let tokens = self.real_tokenize(input).await?;
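// NOTE: these are deterministic pseudo-embeddings derived from token ids and
// positions with a fixed 768 dimensions, not hidden states from the model;
// treat them as a placeholder until llama.cpp embedding support is wired in.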
let embedding_dim = 768;
let embeddings: Vec<f32> = (0..embedding_dim)
.map(|i| {
let mut value = 0.0f32;
for (pos, &token) in tokens.iter().enumerate() {
let pos_factor = (pos as f32 + 1.0).ln();
let token_factor = (token as f32).sin();
value += (i as f32 * 0.01 + pos_factor * 0.1 + token_factor * 0.05).sin();
}
value / (tokens.len() as f32).sqrt()
})
.collect();
debug!(
"Generated {} dimensional embeddings for {} tokens",
embeddings.len(),
tokens.len()
);
Ok(embeddings)
}
fn get_backend_type(&self) -> BackendType {
BackendType::Gguf
}
fn get_metrics(&self) -> Option<InferenceMetrics> {
self.metrics.as_ref().cloned()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::ModelInfo;
use chrono::Utc;
use std::path::PathBuf;
use tempfile::tempdir;
#[tokio::test]
async fn test_gguf_backend_creation() {
let config = BackendConfig::default();
let backend = GgufBackend::new(config);
assert!(backend.is_ok());
let backend = backend.expect("Failed to create GgufBackend for test");
assert_eq!(backend.get_backend_type(), BackendType::Gguf);
assert!(!backend.is_loaded().await);
}
#[tokio::test]
async fn test_gguf_backend_config_validation() {
let mut config = BackendConfig::default();
config.context_size = 100;
let backend = GgufBackend::new(config);
assert!(backend.is_err());
}
#[tokio::test]
async fn test_gguf_tokenization() {
let config = BackendConfig::default();
let backend = GgufBackend::new(config).expect("Failed to create GgufBackend for test");
let result = backend.real_tokenize("hello world").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_gguf_model_loading_invalid_file() {
let config = BackendConfig::default();
let mut backend = GgufBackend::new(config).expect("Failed to create GgufBackend for test");
let model_info = ModelInfo {
path: PathBuf::from("/non/existent/file.gguf"),
name: "test".to_string(),
file_path: PathBuf::from("/non/existent/file.gguf"),
backend_type: "gguf".to_string(),
format: "gguf".to_string(),
size: 0,
size_bytes: 0,
checksum: None,
modified: Utc::now(),
metadata: std::collections::HashMap::new(),
};
let result = backend.load_model(&model_info).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_gguf_model_loading_invalid_magic() {
let config = BackendConfig::default();
let mut backend = GgufBackend::new(config).expect("Failed to create GgufBackend for test");
let dir = tempdir().expect("Failed to create temporary directory for test");
let model_path = dir.path().join("fake.gguf");
std::fs::write(&model_path, b"FAKE model file content")
.expect("Failed to write fake model file for test");
let model_info = ModelInfo {
path: model_path.clone(),
name: "fake".to_string(),
file_path: model_path,
backend_type: "gguf".to_string(),
format: "gguf".to_string(),
size: 24,
size_bytes: 24,
checksum: None,
modified: Utc::now(),
metadata: std::collections::HashMap::new(),
};
let result = backend.load_model(&model_info).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("GGUF magic bytes"));
}
#[tokio::test]
async fn test_gguf_model_loading_valid_magic() {
let config = BackendConfig::default();
let mut backend = GgufBackend::new(config).expect("Failed to create GgufBackend for test");
let dir = tempdir().expect("Failed to create temporary directory for test");
let model_path = dir.path().join("valid.gguf");
// A correct GGUF magic followed by zero padding passes the size and
// magic-byte pre-checks, but it is not a loadable model, so the failure
// must come from llama.cpp itself rather than from the magic check.
let mut content = b"GGUF".to_vec();
content.extend_from_slice(&[0u8; 1024]);
std::fs::write(&model_path, &content).expect("Failed to write model file for test");
let model_info = ModelInfo {
path: model_path.clone(),
name: "valid".to_string(),
file_path: model_path,
backend_type: "gguf".to_string(),
format: "gguf".to_string(),
size: content.len() as u64,
size_bytes: content.len() as u64,
checksum: None,
modified: Utc::now(),
metadata: std::collections::HashMap::new(),
};
let result = backend.load_model(&model_info).await;
let err = result.expect_err("header-only file should not load as a real model");
assert!(!err.to_string().contains("GGUF magic bytes"));
assert!(!backend.is_loaded().await);
}
#[tokio::test]
async fn test_gguf_inference_without_model() {
let config = BackendConfig::default();
let mut backend = GgufBackend::new(config).expect("Failed to create GgufBackend for test");
let params = InferenceParams::default();
let result = backend.infer("test input", &params).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Model not loaded"));
}
#[tokio::test]
async fn test_gguf_estimate_token_count() {
let config = BackendConfig::default();
let backend = GgufBackend::new(config).expect("Failed to create GgufBackend for test");
let count = backend.estimate_token_count("hello world test");
assert!(count > 0);
assert!(count <= 10);
let count_empty = backend.estimate_token_count("");
assert_eq!(count_empty, 1);
}
}