use crate::error::{InfluenceError, Result};
use crate::local::LocalModelConfig;
use std::path::Path;
use tracing::info;
use tokenizers::Tokenizer;
/// GGUF inference backend, compiled when the `gguf` feature is enabled.
///
/// Instances are produced by `GgufBackend::load`, which fills every field below.
#[cfg(feature = "gguf")]
pub struct GgufBackend {
/// Context length copied from `LocalModelConfig::max_seq_len` at load time.
context_size: usize,
/// Quantization label derived from the GGUF filename (e.g. `"Q4_K_M"`), or `"Unknown"`.
quantization: String,
/// Path of the GGUF file this backend was loaded from.
gguf_path: std::path::PathBuf,
/// Loaded llama.cpp model; `load` always stores `Some`, methods report an error on `None`.
model: Option<llama_cpp::LlamaModel>,
// NOTE(review): this field is read by `generate_text_stream`, so the lint suppression
// below may be stale — confirm before removing it.
#[allow(dead_code)]
/// Optional tokenizer loaded from `tokenizer.json` in the model directory when that file
/// exists; used to decode streamed tokens into text.
tokenizer: Option<Tokenizer>,
}
/// Placeholder backend compiled when the `gguf` feature is disabled.
///
/// It keeps the `GgufBackend` name available to callers; its `load` constructor always
/// returns an error, so no instance is created by the code in this file.
#[cfg(not(feature = "gguf"))]
pub struct GgufBackend {
// Private unit field: prevents construction via a struct literal outside this module.
_private: (),
}
#[cfg(feature = "gguf")]
impl GgufBackend {
    /// Lowercase filename markers paired with the quantization label they map to.
    ///
    /// Matching is by substring and the first hit wins, so a more specific marker must
    /// come before any marker it contains (`q4_k_m` before `q4_k`, `bf16` before `f16`).
    /// The long-standing markers stay ahead of the newer families so that filenames
    /// which were already recognised keep the same precedence.
    const QUANTIZATION_MARKERS: &'static [(&'static str, &'static str)] = &[
        ("q2_k", "Q2_K"),
        ("q4_k_m", "Q4_K_M"),
        ("q4_k_s", "Q4_K_S"),
        ("q4_k", "Q4_K"),
        ("q5_k_m", "Q5_K_M"),
        ("q5_k_s", "Q5_K_S"),
        ("q5_k", "Q5_K"),
        ("q6_k", "Q6_K"),
        ("q8_0", "Q8_0"),
        ("bf16", "BF16"),
        ("f16", "F16"),
        ("q3_k_s", "Q3_K_S"),
        ("q3_k_m", "Q3_K_M"),
        ("q3_k_l", "Q3_K_L"),
        ("q3_k", "Q3_K"),
        ("q4_0", "Q4_0"),
        ("q4_1", "Q4_1"),
        ("q5_0", "Q5_0"),
        ("q5_1", "Q5_1"),
        ("f32", "F32"),
    ];

    /// Token id that always ends generation, independent of any caller-supplied EOS id.
    const FALLBACK_STOP_TOKEN: u32 = 0;

    /// Loads a GGUF model from `gguf_path`.
    ///
    /// The quantization label is derived from the filename. If `tokenizer.json` exists in
    /// `config.model_path` it is loaded as well; otherwise the backend runs without an
    /// external tokenizer.
    ///
    /// # Errors
    ///
    /// * `InvalidConfig` if `config.max_seq_len` does not fit the 32-bit context size.
    /// * `GgufParsingError` if `gguf_path` has no usable UTF-8 filename.
    /// * `LocalModelError` if `tokenizer.json` exists but cannot be parsed.
    /// * `GgufError` if the model file itself fails to load.
    pub fn load(config: &LocalModelConfig, gguf_path: &Path) -> Result<Self> {
        // Reject oversized values instead of letting the narrowing cast wrap silently.
        if config.max_seq_len > u32::MAX as usize {
            return Err(InfluenceError::InvalidConfig(format!(
                "max_seq_len {} exceeds the maximum supported context size {}",
                config.max_seq_len,
                u32::MAX
            )));
        }
        let n_ctx = config.max_seq_len as u32;

        info!("Loading GGUF model from: {}", gguf_path.display());
        let quantization = Self::detect_quantization(gguf_path)?;
        info!("Detected quantization: {}", quantization);

        let tokenizer_path = config.model_path.join("tokenizer.json");
        let tokenizer = if tokenizer_path.exists() {
            let loaded = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
                InfluenceError::LocalModelError(format!("Failed to load tokenizer: {}", e))
            })?;
            Some(loaded)
        } else {
            info!("No tokenizer.json found, using model's internal tokenizer");
            None
        };

        let params = llama_cpp::LlamaModelParams {
            n_ctx,
            ..Default::default()
        };
        let model = llama_cpp::LlamaModel::load_from_file(gguf_path, params)
            .map_err(|e| InfluenceError::GgufError(format!("Failed to load GGUF model: {}", e)))?;
        info!("GGUF model loaded successfully (quantization: {})", quantization);

        Ok(Self {
            gguf_path: gguf_path.to_path_buf(),
            context_size: config.max_seq_len,
            quantization,
            model: Some(model),
            tokenizer,
        })
    }

    /// Derives a quantization label from the filename of `path`.
    ///
    /// The filename is lowercased and checked against `QUANTIZATION_MARKERS` in order;
    /// the label of the first marker found is returned, or `"Unknown"` when none match.
    ///
    /// # Errors
    ///
    /// Returns `GgufParsingError` when `path` has no final component or the component is
    /// not valid UTF-8.
    fn detect_quantization(path: &Path) -> Result<String> {
        let filename = path
            .file_name()
            .and_then(|name| name.to_str())
            .ok_or_else(|| InfluenceError::GgufParsingError("Invalid filename".to_string()))?;
        let filename_lower = filename.to_lowercase();

        let label = Self::QUANTIZATION_MARKERS
            .iter()
            .find(|&&(marker, _)| filename_lower.contains(marker))
            .map_or("Unknown", |&(_, label)| label);
        Ok(label.to_string())
    }

    /// Returns the loaded model, or `LocalModelError` when none is present.
    fn loaded_model(&self) -> Result<&llama_cpp::LlamaModel> {
        self.model
            .as_ref()
            .ok_or_else(|| InfluenceError::LocalModelError("GGUF model not loaded".to_string()))
    }

    /// Builds the prediction parameters shared by the generation entry points.
    ///
    /// Out-of-range `max_tokens` / `top_k` values are clamped to the largest value the
    /// target field can hold rather than wrapping. A missing `top_k` is passed as `0`
    /// (presumably "top-k disabled" — confirm against the llama_cpp documentation).
    fn predict_params(
        max_tokens: usize,
        temperature: f32,
        top_p: f32,
        top_k: Option<usize>,
    ) -> llama_cpp::LlamaPredictParams {
        llama_cpp::LlamaPredictParams {
            n_predict: max_tokens.min(u32::MAX as usize) as u32,
            temperature,
            top_p,
            top_k: top_k.unwrap_or(0).min(i32::MAX as usize) as i32,
            ..Default::default()
        }
    }

    /// Reports whether `token` should end generation: either the fallback stop id or the
    /// caller-supplied end-of-sequence id.
    fn is_stop_token(token: u32, eos_token: Option<u32>) -> bool {
        token == Self::FALLBACK_STOP_TOKEN || eos_token == Some(token)
    }

    /// Generates up to `max_tokens` token ids for `prompt`.
    ///
    /// `temperature`, `top_p` and `top_k` are forwarded to the llama_cpp prediction
    /// parameters. Generation ends early when a stop token is produced: token `0`, or
    /// `eos_token` when one is supplied. The stop token itself is included as the last
    /// element of the returned vector.
    ///
    /// # Errors
    ///
    /// * `LocalModelError` if no model is loaded.
    /// * `GgufError` if the underlying session reports a failure.
    pub fn generate_text(
        &mut self,
        prompt: &str,
        max_tokens: usize,
        temperature: f32,
        top_p: f32,
        top_k: Option<usize>,
        eos_token: Option<u32>,
    ) -> Result<Vec<u32>> {
        let model = self.loaded_model()?;
        let mut session = llama_cpp::LlamaSession::new(model);
        let params = Self::predict_params(max_tokens, temperature, top_p, top_k);

        let mut output_tokens = Vec::new();
        // The callback's return value tells the session whether to keep going
        // (`false` requests a stop).
        let mut on_token = |token: u32| {
            output_tokens.push(token);
            !Self::is_stop_token(token, eos_token)
        };
        session
            .advance(prompt, params, Some(&mut on_token))
            .map_err(|e| InfluenceError::GgufError(format!("GGUF generation failed: {}", e)))?;
        Ok(output_tokens)
    }

    /// Generates text for `prompt`, handing each decoded piece to `callback` as it arrives.
    ///
    /// Each token is decoded on its own with the external tokenizer when one was loaded;
    /// without a tokenizer, or when decoding fails, the piece is the placeholder
    /// `<token_N>`. Token `0` ends generation and is not passed to `callback`.
    ///
    /// # Errors
    ///
    /// * `LocalModelError` if no model is loaded.
    /// * The first error returned by `callback`; generation is stopped at that point.
    /// * `GgufError` if the underlying session reports a failure.
    pub fn generate_text_stream<F>(
        &mut self,
        prompt: &str,
        max_tokens: usize,
        temperature: f32,
        top_p: f32,
        top_k: Option<usize>,
        mut callback: F,
    ) -> Result<()>
    where
        F: FnMut(String) -> Result<()>,
    {
        let model = self.loaded_model()?;
        let tokenizer = self.tokenizer.as_ref();
        let params = Self::predict_params(max_tokens, temperature, top_p, top_k);

        // Holds the first failure reported by the user callback so it can be returned
        // once the session has unwound.
        let mut callback_result: Result<()> = Ok(());

        // Scoped so the session and the closure (which borrows `callback_result`) are
        // released before the stored result is inspected.
        let advance_result = {
            let mut session = llama_cpp::LlamaSession::new(model);
            let mut on_token = |token: u32| {
                if Self::is_stop_token(token, None) {
                    return false;
                }
                let piece = match tokenizer {
                    Some(tok) => tok
                        .decode(&[token], false)
                        .unwrap_or_else(|_| format!("<token_{}>", token)),
                    None => format!("<token_{}>", token),
                };
                match callback(piece) {
                    Ok(()) => true,
                    Err(err) => {
                        // Stop generating; the consumer can no longer accept output.
                        callback_result = Err(err);
                        false
                    }
                }
            };
            session
                .advance(prompt, params, Some(&mut on_token))
                .map_err(|e| {
                    InfluenceError::GgufError(format!(
                        "GGUF streaming generation failed: {}",
                        e
                    ))
                })
        };

        // A consumer failure is the root cause, so it takes precedence over any error the
        // session may report for the interrupted run.
        callback_result?;
        advance_result?;
        Ok(())
    }

    /// Computes an embedding vector for `text` using the loaded model.
    ///
    /// # Errors
    ///
    /// * `LocalModelError` if no model is loaded.
    /// * `GgufError` if the model fails to produce an embedding.
    pub fn embed_text(&mut self, text: &str) -> Result<Vec<f32>> {
        let model = self.loaded_model()?;
        let embeddings = model
            .embed_text(text)
            .map_err(|e| InfluenceError::GgufError(format!("GGUF embedding failed: {}", e)))?;
        Ok(embeddings)
    }

    /// Quantization label detected from the model filename.
    pub fn quantization(&self) -> &str {
        &self.quantization
    }

    /// Context length the backend was configured with at load time.
    pub fn context_size(&self) -> usize {
        self.context_size
    }

    /// Path of the GGUF file backing this model.
    pub fn path(&self) -> &Path {
        &self.gguf_path
    }
}
#[cfg(not(feature = "gguf"))]
impl GgufBackend {
    /// Always fails: this build was compiled without the `gguf` feature.
    ///
    /// Both arguments are ignored.
    ///
    /// # Errors
    ///
    /// Returns `InvalidConfig` with a hint to rebuild with the feature enabled.
    pub fn load(_config: &LocalModelConfig, _gguf_path: &Path) -> Result<Self> {
        let message = String::from("GGUF support not enabled. Build with --features gguf");
        Err(InfluenceError::InvalidConfig(message))
    }

    /// Placeholder quantization label for the disabled backend.
    pub fn quantization(&self) -> &str {
        "N/A"
    }

    /// Placeholder context size for the disabled backend; always zero.
    pub fn context_size(&self) -> usize {
        0
    }

    /// Placeholder model path for the disabled backend; always empty.
    pub fn path(&self) -> &Path {
        Path::new("")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs quantization detection on `filename` and returns the label, panicking if the
    /// detector rejects the path.
    #[cfg(feature = "gguf")]
    fn detect(filename: &str) -> String {
        GgufBackend::detect_quantization(Path::new(filename)).unwrap()
    }

    /// Every recognised marker maps to its label, regardless of filename case.
    #[test]
    #[cfg(feature = "gguf")]
    fn test_detect_quantization() {
        let cases = [
            ("model-q2_k.gguf", "Q2_K"),
            ("model-q4_k.gguf", "Q4_K"),
            ("model-q4_k_m.gguf", "Q4_K_M"),
            ("model-q5_k.gguf", "Q5_K"),
            ("model-q5_k_m.gguf", "Q5_K_M"),
            ("model-q6_k.gguf", "Q6_K"),
            ("model-q8_0.gguf", "Q8_0"),
            ("model-f16.gguf", "F16"),
            // Detection is case-insensitive.
            ("MODEL-Q2_K.GGUF", "Q2_K"),
            ("Model-Q4_K_M.GgUF", "Q4_K_M"),
        ];
        for &(filename, expected) in &cases {
            assert_eq!(detect(filename), expected, "filename: {}", filename);
        }
    }

    /// Filenames without a quantization marker yield the `"Unknown"` label.
    #[test]
    #[cfg(feature = "gguf")]
    fn test_detect_quantization_unknown() {
        for &filename in &["model.gguf", "model.bin"] {
            assert_eq!(detect(filename), "Unknown", "filename: {}", filename);
        }
    }

    /// Paths with no final component are rejected.
    #[test]
    #[cfg(feature = "gguf")]
    fn test_detect_quantization_invalid_path() {
        for &raw in &["", "/"] {
            assert!(
                GgufBackend::detect_quantization(Path::new(raw)).is_err(),
                "path: {:?}",
                raw
            );
        }
    }

    /// Without the `gguf` feature, loading fails with an explanatory message.
    #[test]
    #[cfg(not(feature = "gguf"))]
    fn test_gguf_disabled() {
        let config = LocalModelConfig::default();
        // `.err().expect(..)` rather than `unwrap_err()`: the latter requires the `Ok`
        // type to implement `Debug`, which this test should not depend on.
        let err = GgufBackend::load(&config, Path::new("test.gguf"))
            .err()
            .expect("load must fail when the gguf feature is disabled");
        assert!(err.to_string().contains("GGUF support not enabled"));
    }
}