use super::config::{LocalInferenceParams, LocalLlmConfig};
use anyhow::{Result, anyhow};
use brainwires_core::message::{Message, Role};
use std::sync::Arc;
#[cfg(feature = "native")]
use async_trait::async_trait;
#[cfg(feature = "native")]
use brainwires_core::message::{ChatResponse, StreamChunk, Usage};
#[cfg(feature = "native")]
use brainwires_core::provider::{ChatOptions, Provider};
#[cfg(feature = "native")]
use brainwires_core::tool::Tool;
#[cfg(feature = "native")]
use futures::stream::BoxStream;
#[cfg(feature = "llama-cpp-2")]
use llama_cpp_2::{
context::params::LlamaContextParams, llama_backend::LlamaBackend, llama_batch::LlamaBatch,
model::AddBos, model::LlamaModel, model::params::LlamaModelParams, sampling::LlamaSampler,
};
/// Local LLM provider driven by a [`LocalLlmConfig`].
///
/// With the `llama-cpp-2` feature the provider keeps a llama backend and a model,
/// each behind its own mutex and each created as `None` by `new`. Without the
/// feature only the configuration is stored.
pub struct LocalLlmProvider {
/// Configuration this provider was built from.
config: LocalLlmConfig,
/// llama.cpp backend handle; starts as `None` and is filled in by `load`.
#[cfg(feature = "llama-cpp-2")]
backend: std::sync::Mutex<Option<LlamaBackend>>,
/// Loaded model; starts as `None`, filled in by `load`, cleared by `unload`.
#[cfg(feature = "llama-cpp-2")]
model: std::sync::Mutex<Option<LlamaModel>>,
/// Zero-sized stand-in field used when the `llama-cpp-2` feature is off.
#[cfg(not(feature = "llama-cpp-2"))]
_placeholder: std::marker::PhantomData<()>,
}
impl LocalLlmProvider {
/// Builds a provider from `config` after validating it.
///
/// The backend and model slots start out empty; nothing is loaded here.
///
/// # Errors
/// Returns the failure reported by `config.validate()`, wrapped in `anyhow`.
#[cfg(feature = "llama-cpp-2")]
pub fn new(config: LocalLlmConfig) -> Result<Self> {
    if let Err(problem) = config.validate() {
        return Err(anyhow!(problem));
    }
    let provider = Self {
        config,
        backend: std::sync::Mutex::new(None),
        model: std::sync::Mutex::new(None),
    };
    Ok(provider)
}
/// Builds a provider from `config` after validating it (build without `llama-cpp-2`).
///
/// # Errors
/// Returns the failure reported by `config.validate()`, wrapped in `anyhow`.
#[cfg(not(feature = "llama-cpp-2"))]
pub fn new(config: LocalLlmConfig) -> Result<Self> {
    if let Err(problem) = config.validate() {
        return Err(anyhow!(problem));
    }
    let provider = Self {
        config,
        _placeholder: std::marker::PhantomData,
    };
    Ok(provider)
}
/// Shortcut for `Self::new` using the `LocalLlmConfig::lfm2_350m` preset for
/// the model file at `model_path`.
///
/// # Errors
/// Propagates any error from `Self::new`.
pub fn lfm2_350m(model_path: std::path::PathBuf) -> Result<Self> {
    let preset = LocalLlmConfig::lfm2_350m(model_path);
    Self::new(preset)
}
/// Shortcut for `Self::new` using the `LocalLlmConfig::lfm2_1_2b` preset for
/// the model file at `model_path`.
///
/// # Errors
/// Propagates any error from `Self::new`.
pub fn lfm2_1_2b(model_path: std::path::PathBuf) -> Result<Self> {
    let preset = LocalLlmConfig::lfm2_1_2b(model_path);
    Self::new(preset)
}
pub fn config(&self) -> &LocalLlmConfig {
&self.config
}
/// Reports whether the model slot currently holds a model.
///
/// A poisoned model mutex is reported as "not loaded".
#[cfg(feature = "llama-cpp-2")]
pub async fn is_loaded(&self) -> bool {
    match self.model.lock() {
        Ok(slot) => slot.is_some(),
        Err(_poisoned) => false,
    }
}
/// Always reports `false` in builds without the `llama-cpp-2` feature.
#[cfg(not(feature = "llama-cpp-2"))]
pub async fn is_loaded(&self) -> bool {
false
}
/// Initialises the llama backend if it is missing, then loads the configured
/// model if it is missing. Returns immediately once a model is present.
///
/// NOTE(review): llama-cpp-2's `LlamaBackend::init` is believed to fail when a
/// backend already exists in the process; confirm how this behaves when several
/// providers (for example a `LocalLlmPool`) each call `load`.
///
/// # Errors
/// Fails when a mutex is poisoned, the backend cannot be initialised, or the
/// model file cannot be loaded.
#[cfg(feature = "llama-cpp-2")]
pub async fn load(&self) -> Result<()> {
    // Step 1: ensure a backend exists. The guard lives only inside this scope so
    // the backend lock is released before the model lock is taken below.
    {
        let mut backend_slot = self
            .backend
            .lock()
            .map_err(|e| anyhow!("Lock poisoned: {}", e))?;
        if backend_slot.is_none() {
            let fresh = LlamaBackend::init()
                .map_err(|e| anyhow!("Failed to initialize llama backend: {:?}", e))?;
            backend_slot.replace(fresh);
        }
    }

    // Step 2: take the model lock first, then the backend lock. This is the same
    // order `generate_impl_sync` uses, so the two cannot deadlock each other.
    let mut model_slot = self
        .model
        .lock()
        .map_err(|e| anyhow!("Lock poisoned: {}", e))?;
    if model_slot.is_some() {
        return Ok(());
    }

    tracing::info!("Loading local model: {}", self.config.name);
    let model_params = LlamaModelParams::default().with_n_gpu_layers(self.config.gpu_layers);

    let backend_slot = self
        .backend
        .lock()
        .map_err(|e| anyhow!("Lock poisoned: {}", e))?;
    let Some(backend) = backend_slot.as_ref() else {
        return Err(anyhow!("Backend not initialized"));
    };

    let loaded = LlamaModel::load_from_file(backend, &self.config.model_path, &model_params)
        .map_err(|e| anyhow!("Failed to load model: {:?}", e))?;
    *model_slot = Some(loaded);

    tracing::info!("Local model loaded successfully: {}", self.config.name);
    Ok(())
}
/// Stub used when the crate is built without the `llama-cpp-2` feature.
///
/// # Errors
/// Always fails with a message telling the caller to enable the feature.
#[cfg(not(feature = "llama-cpp-2"))]
pub async fn load(&self) -> Result<()> {
    let reason = anyhow!("Local LLM support is not enabled. Build with --features llama-cpp-2");
    Err(reason)
}
/// Drops the loaded model, if any, and logs the unload.
///
/// The backend slot is left untouched. If the model mutex is poisoned the slot
/// is not cleared, but the log line is still emitted.
#[cfg(feature = "llama-cpp-2")]
pub async fn unload(&self) {
    if let Ok(mut slot) = self.model.lock() {
        // Taking the value out drops the model while the lock is still held.
        drop(slot.take());
    }
    tracing::info!("Local model unloaded: {}", self.config.name);
}
/// No-op in builds without the `llama-cpp-2` feature.
#[cfg(not(feature = "llama-cpp-2"))]
pub async fn unload(&self) {
}
/// Flattens `messages` into one prompt string using the model's chat template.
///
/// * The system text is `system` when given, otherwise the text of the first
///   `Role::System` message that has text.
/// * When a system text exists and the template contains `{system}`, the part of
///   the template before `{user}` is emitted once with `{system}` substituted.
/// * Each user message is followed by the template text between `{user}` and the
///   next `{` (a newline when the template has no `{user}` placeholder).
/// * Assistant messages are followed by a newline; tool messages are prefixed
///   with `[Tool Result]: ` and followed by a newline.
/// * System messages inside `messages` are never emitted in the loop.
#[cfg_attr(not(feature = "native"), allow(dead_code))]
fn format_prompt(&self, messages: &[Message], system: Option<&str>) -> String {
    let template = self.config.model_type.chat_template();

    // An explicit system prompt wins over any system message in the history.
    let system_text: Option<String> = match system {
        Some(explicit) => Some(String::from(explicit)),
        None => messages
            .iter()
            .filter(|m| m.role == Role::System)
            .find_map(|m| m.text().map(String::from)),
    };

    // Text appended after every user message; it depends only on the template.
    let after_user: &str = if template.contains("{user}") {
        template
            .split("{user}")
            .nth(1)
            .and_then(|rest| rest.split("{").next())
            .unwrap_or("\n")
    } else {
        "\n"
    };

    let mut prompt = String::new();

    if let Some(sys) = &system_text {
        if template.contains("{system}") {
            let head = template.split("{user}").next().unwrap_or("");
            prompt.push_str(&head.replace("{system}", sys.as_str()));
        }
    }

    for msg in messages {
        // Pick what goes before and after the message text for each role.
        let (before, after) = match msg.role {
            Role::System => continue,
            Role::User => ("", after_user),
            Role::Assistant => ("", "\n"),
            Role::Tool => ("[Tool Result]: ", "\n"),
        };
        if let Some(text) = msg.text() {
            prompt.push_str(before);
            prompt.push_str(text);
            prompt.push_str(after);
        }
    }

    prompt
}
/// Runs blocking text generation for `prompt` on the loaded model.
///
/// A fresh llama context is created per call. The prompt is tokenized and
/// decoded as a single batch, then tokens are sampled one at a time until one
/// of the following happens:
///
/// * the model emits an end-of-generation token,
/// * the output ends with one of the model's stop tokens (the stop text is
///   removed from the result),
/// * the output ends with one of `params.stop_sequences` (left in the result),
/// * `params.max_tokens` tokens have been produced.
///
/// Stop sequences are checked right after each piece is appended, so a stop
/// marker completed by the last sampled token is still detected, and no decode
/// step is spent on a token that will never be followed by another sample.
/// Empty stop strings are ignored because every string ends with `""`.
///
/// The returned text has leading and trailing whitespace trimmed.
///
/// # Errors
/// Fails when a mutex is poisoned, the model or backend is not loaded, the
/// prompt yields no tokens, or any llama.cpp call (context creation,
/// tokenization, batching, decoding, detokenization) reports an error.
#[cfg(feature = "llama-cpp-2")]
fn generate_impl_sync(&self, prompt: &str, params: &LocalInferenceParams) -> Result<String> {
    // Lock order is model, then backend; `load` uses the same order.
    let model_guard = self
        .model
        .lock()
        .map_err(|e| anyhow!("Lock poisoned: {}", e))?;
    let model = model_guard
        .as_ref()
        .ok_or_else(|| anyhow!("Model not loaded"))?;
    let backend_guard = self
        .backend
        .lock()
        .map_err(|e| anyhow!("Lock poisoned: {}", e))?;
    let backend = backend_guard
        .as_ref()
        .ok_or_else(|| anyhow!("Backend not initialized"))?;

    let mut ctx_params = LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(self.config.context_size))
        .with_n_batch(self.config.batch_size);
    if let Some(threads) = self.config.num_threads {
        ctx_params = ctx_params.with_n_threads(threads as i32);
    }
    let mut ctx = model
        .new_context(backend, ctx_params)
        .map_err(|e| anyhow!("Failed to create context: {:?}", e))?;

    let tokens = model
        .str_to_token(prompt, AddBos::Always)
        .map_err(|e| anyhow!("Tokenization failed: {:?}", e))?;
    // Guard the `len - 1` below: an empty token list would underflow.
    let Some(last_index) = tokens.len().checked_sub(1) else {
        return Err(anyhow!("Prompt produced no tokens"));
    };

    // The whole prompt goes into one batch; only the final token requests logits.
    let mut batch = LlamaBatch::new(self.config.batch_size as usize, 1);
    for (i, token) in tokens.iter().enumerate() {
        batch
            .add(*token, i as i32, &[0], i == last_index)
            .map_err(|e| anyhow!("Failed to add token to batch: {:?}", e))?;
    }
    ctx.decode(&mut batch)
        .map_err(|e| anyhow!("Prompt processing failed: {:?}", e))?;

    let mut sampler = LlamaSampler::chain_simple([
        LlamaSampler::temp(params.temperature),
        LlamaSampler::top_p(params.top_p, 1),
        LlamaSampler::top_k(params.top_k as i32),
        LlamaSampler::penalties(64, params.repeat_penalty, 0.0, 0.0),
        // Constant seed, as in the original implementation.
        LlamaSampler::dist(42),
    ]);

    let mut decoder = encoding_rs::UTF_8.new_decoder();
    let mut output = String::new();
    let stop_tokens = self.config.model_type.stop_tokens();
    let mut generated = 0u32;
    let mut cur_pos = tokens.len() as i32;

    while generated < params.max_tokens {
        let token = sampler.sample(&ctx, -1);
        if model.is_eog_token(token) {
            break;
        }

        let piece = model
            .token_to_piece(token, &mut decoder, false, None)
            .map_err(|e| anyhow!("Token decode failed: {:?}", e))?;
        output.push_str(&piece);

        // Model-defined stop markers end generation and are stripped.
        let model_stop_len = stop_tokens
            .iter()
            .find(|stop| !stop.is_empty() && output.ends_with(*stop))
            .map(|stop| stop.len());
        if let Some(len) = model_stop_len {
            output.truncate(output.len() - len);
            break;
        }

        // Caller-supplied stop sequences end generation but stay in the output,
        // as before. NOTE(review): confirm this asymmetry with model stop tokens
        // is intended.
        let custom_stop = params
            .stop_sequences
            .iter()
            .any(|stop| !stop.is_empty() && output.ends_with(stop));
        if custom_stop {
            break;
        }

        generated += 1;
        if generated >= params.max_tokens {
            // Budget exhausted: decoding this token would only feed a sample
            // that is never taken.
            break;
        }

        batch.clear();
        batch
            .add(token, cur_pos, &[0], true)
            .map_err(|e| anyhow!("Failed to add token: {:?}", e))?;
        ctx.decode(&mut batch)
            .map_err(|e| anyhow!("Generation failed: {:?}", e))?;
        cur_pos += 1;
    }

    Ok(output.trim().to_string())
}
/// Stub used when the crate is built without the `llama-cpp-2` feature.
///
/// # Errors
/// Always fails with a message telling the caller to enable the feature.
#[cfg(not(feature = "llama-cpp-2"))]
fn generate_impl_sync(&self, _prompt: &str, _params: &LocalInferenceParams) -> Result<String> {
    let reason = anyhow!("Local LLM support is not enabled. Build with --features llama-cpp-2");
    Err(reason)
}
/// Generates a completion for `prompt`, loading the model first when needed.
///
/// The borrowed `prompt` and `params` are passed straight through; no copies
/// are made.
///
/// NOTE(review): inference runs synchronously inside this `async fn`, so the
/// calling executor thread is occupied until generation finishes.
///
/// # Errors
/// Propagates failures from `load` and from `generate_impl_sync`.
pub async fn generate(&self, prompt: &str, params: &LocalInferenceParams) -> Result<String> {
    if !self.is_loaded().await {
        self.load().await?;
    }
    self.generate_impl_sync(prompt, params)
}
/// Generates a completion for `prompt` using `LocalInferenceParams::routing()`.
///
/// # Errors
/// Propagates any error from `generate`.
pub async fn route(&self, prompt: &str) -> Result<String> {
    let routing_params = LocalInferenceParams::routing();
    self.generate(prompt, &routing_params).await
}
/// Generates a completion for `prompt` using `LocalInferenceParams::factual()`.
///
/// # Errors
/// Propagates any error from `generate`.
pub async fn process(&self, prompt: &str) -> Result<String> {
    let factual_params = LocalInferenceParams::factual();
    self.generate(prompt, &factual_params).await
}
}
#[cfg(feature = "native")]
impl LocalLlmProvider {
    /// Maps generic [`ChatOptions`] onto [`LocalInferenceParams`].
    ///
    /// Temperature falls back to `0.7`, the token budget to the configured
    /// `max_tokens`, and stop sequences to an empty list. Every other field
    /// comes from `LocalInferenceParams::default()`.
    fn inference_params(&self, options: &ChatOptions) -> LocalInferenceParams {
        LocalInferenceParams {
            temperature: options.temperature.unwrap_or(0.7),
            max_tokens: options.max_tokens.unwrap_or(self.config.max_tokens),
            stop_sequences: options.stop.clone().unwrap_or_default(),
            ..Default::default()
        }
    }

    /// Estimates usage as one token per four bytes of text.
    ///
    /// This is a heuristic, not a tokenizer count. Values that do not fit in
    /// `u32` saturate instead of wrapping.
    fn estimated_usage(prompt: &str, completion: &str) -> Usage {
        let approx = |text: &str| u32::try_from(text.len() / 4).unwrap_or(u32::MAX);
        Usage::new(approx(prompt), approx(completion))
    }
}

#[cfg(feature = "native")]
#[async_trait]
impl Provider for LocalLlmProvider {
    /// Returns the configured `id`.
    // The getter is called `name` but reads `config.id`, which is what the
    // `misnamed_getters` lint would otherwise flag.
    #[allow(clippy::misnamed_getters)]
    fn name(&self) -> &str {
        &self.config.id
    }

    /// Returns the configured `max_tokens` as the output limit.
    fn max_output_tokens(&self) -> Option<u32> {
        Some(self.config.max_tokens)
    }

    /// Formats `messages` into a prompt, generates a completion and wraps it in
    /// a [`ChatResponse`].
    ///
    /// `_tools` is ignored. Usage numbers are estimates and `finish_reason` is
    /// always `"stop"`.
    ///
    /// # Errors
    /// Propagates any error from `generate`.
    async fn chat(
        &self,
        messages: &[Message],
        _tools: Option<&[Tool]>,
        options: &ChatOptions,
    ) -> Result<ChatResponse> {
        let prompt = self.format_prompt(messages, options.system.as_deref());
        let params = self.inference_params(options);

        let response_text = self.generate(&prompt, &params).await?;
        let usage = Self::estimated_usage(&prompt, &response_text);

        Ok(ChatResponse {
            message: Message::assistant(response_text),
            usage,
            finish_reason: Some("stop".to_string()),
        })
    }

    /// Generates the full completion first, then replays it as a stream.
    ///
    /// The stream yields text chunks of up to ten characters with a 10 ms pause
    /// after each, then an estimated `Usage` chunk, then `Done`. If generation
    /// fails, the stream yields that single error and ends. `_tools` is ignored.
    fn stream_chat<'a>(
        &'a self,
        messages: &'a [Message],
        _tools: Option<&'a [Tool]>,
        options: &'a ChatOptions,
    ) -> BoxStream<'a, Result<StreamChunk>> {
        // Number of characters per emitted text chunk.
        const CHUNK_SIZE: usize = 10;

        let prompt = self.format_prompt(messages, options.system.as_deref());
        let params = self.inference_params(options);

        Box::pin(async_stream::stream! {
            match self.generate(&prompt, &params).await {
                Ok(response) => {
                    // Chunk by characters, not bytes, so multi-byte characters
                    // are never split.
                    let chars: Vec<char> = response.chars().collect();
                    for chunk in chars.chunks(CHUNK_SIZE) {
                        let text: String = chunk.iter().collect();
                        yield Ok(StreamChunk::Text(text));
                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                    }
                    let usage = LocalLlmProvider::estimated_usage(&prompt, &response);
                    yield Ok(StreamChunk::Usage(usage));
                    yield Ok(StreamChunk::Done);
                }
                Err(e) => {
                    yield Err(e);
                }
            }
        })
    }
}
/// Debug output shows only the configuration; the backend and model handles
/// are left out.
impl std::fmt::Debug for LocalLlmProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("LocalLlmProvider");
        repr.field("config", &self.config);
        repr.finish()
    }
}
/// Fixed set of [`LocalLlmProvider`] instances handed out in rotation by `next`.
pub struct LocalLlmPool {
/// Pooled providers, each shared through an `Arc`.
providers: Vec<Arc<LocalLlmProvider>>,
/// Counter incremented by `next`; its value modulo the pool size picks the provider.
current: std::sync::atomic::AtomicUsize,
}
impl LocalLlmPool {
pub fn new(config: LocalLlmConfig, instances: usize) -> Result<Self> {
let mut providers = Vec::with_capacity(instances);
for _ in 0..instances {
providers.push(Arc::new(LocalLlmProvider::new(config.clone())?));
}
Ok(Self {
providers,
current: std::sync::atomic::AtomicUsize::new(0),
})
}
pub fn next(&self) -> Arc<LocalLlmProvider> {
let idx = self
.current
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
% self.providers.len();
self.providers[idx].clone()
}
pub async fn load_all(&self) -> Result<()> {
for provider in &self.providers {
provider.load().await?;
}
Ok(())
}
pub async fn unload_all(&self) {
for provider in &self.providers {
provider.unload().await;
}
}
pub fn size(&self) -> usize {
self.providers.len()
}
pub fn estimated_ram_mb(&self) -> Option<u32> {
self.providers
.first()
.and_then(|p| p.config.estimated_ram_mb)
.map(|ram| ram * self.providers.len() as u32)
}
}
// Unit tests for provider construction and parameter defaults.
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
// Construction is expected to fail for `/tmp/test.gguf`.
// Presumably `LocalLlmConfig::validate` rejects a missing model file; verify
// against the validation code.
#[test]
fn test_provider_creation() {
let config = LocalLlmConfig::lfm2_350m(PathBuf::from("/tmp/test.gguf"));
let result = LocalLlmProvider::new(config);
assert!(result.is_err());
}
// Pins the default temperature and token budget of `LocalInferenceParams`.
#[test]
fn test_inference_params_defaults() {
let params = LocalInferenceParams::default();
assert_eq!(params.temperature, 0.7);
assert_eq!(params.max_tokens, 2048);
}
// NOTE(review): this test builds a config but never constructs a pool or calls
// `estimated_ram_mb`; the assertion only checks that 220 * 4 equals 880.
#[test]
fn test_pool_estimated_ram() {
let _config = LocalLlmConfig {
model_path: PathBuf::from("."), estimated_ram_mb: Some(220),
..Default::default()
};
let expected_ram = 220 * 4;
assert_eq!(expected_ram, 880);
}
}