use crate::llm::providers::llama_cpp::command::LlamaCommand;
use crate::llm::providers::llama_cpp::session::InferenceSession;
use crate::types::config::LlamaEngineConfig;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use log::{error, info};
use std::num::NonZeroU32;
use std::panic::{self, AssertUnwindSafe};
use std::path::Path;
use std::thread::{self, JoinHandle};
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
/// Spawns the dedicated OS thread that owns the llama.cpp engine and returns
/// a command channel plus the thread's join handle.
///
/// The engine runs on a plain `std::thread` (not a tokio task) because
/// inference is blocking, CPU/GPU-bound work. Panics inside the engine are
/// caught so they are logged instead of silently unwinding the thread.
pub(crate) fn spawn_engine_thread(
    cfg: LlamaEngineConfig,
) -> crate::error::Result<(UnboundedSender<LlamaCommand>, JoinHandle<()>)> {
    let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<LlamaCommand>();

    let handle = thread::spawn(move || {
        // AssertUnwindSafe: the closure consumes cfg/cmd_rx, so nothing
        // outside the thread can observe a broken invariant after a panic.
        if let Err(payload) = panic::catch_unwind(AssertUnwindSafe(|| engine_main(cfg, cmd_rx))) {
            // Panic payloads are almost always `&str` or `String`.
            let msg = payload
                .downcast_ref::<&str>()
                .map(|s| s.to_string())
                .or_else(|| payload.downcast_ref::<String>().cloned())
                .unwrap_or_else(|| "unknown panic".to_string());
            error!("Engine thread panicked: {}", msg);
        }
    });

    Ok((cmd_tx, handle))
}
fn engine_main(cfg: LlamaEngineConfig, mut cmd_rx: UnboundedReceiver<LlamaCommand>) {
let backend = match LlamaBackend::init() {
Ok(b) => b,
Err(e) => {
error!("Backend init failed: {}", e);
return;
}
};
let mut model_params = LlamaModelParams::default();
if cfg.use_gpu {
model_params = model_params.with_n_gpu_layers(cfg.n_gpu_layers);
} else {
model_params = model_params.with_n_gpu_layers(0);
}
let model =
match LlamaModel::load_from_file(&backend, Path::new(&cfg.model_path), &model_params) {
Ok(m) => m,
Err(e) => {
error!("Model loading failed: {}", e);
return;
}
};
let mut ctx_params = LlamaContextParams::default();
ctx_params = ctx_params.with_n_ctx(Option::from(
NonZeroU32::new(cfg.n_ctx).expect("n_ctx must be > 0"),
));
ctx_params = ctx_params.with_n_threads(num_cpus::get() as i32);
let mut context = match model.new_context(&backend, ctx_params) {
Ok(c) => c,
Err(e) => {
error!("Failed to create context: {}", e);
return;
}
};
let mut batch = LlamaBatch::new(cfg.n_tokens, cfg.n_seq_max);
let mut session = InferenceSession::new();
info!("Llama engine thread started successfully.");
while let Some(cmd) = cmd_rx.blocking_recv() {
match cmd {
LlamaCommand::Chat { prompt, reply_tx } => {
let mut full_response = String::new();
let res = InferenceSession::run_inference(
&prompt,
&model,
&mut context,
&mut batch,
&mut session,
&cfg,
|piece| {
full_response.push_str(&piece);
true
},
);
let _ = reply_tx.send(res.map(|_| full_response));
}
LlamaCommand::ChatStream {
prompt,
chunk_tx,
done_tx,
} => {
let res = InferenceSession::run_inference(
&prompt,
&model,
&mut context,
&mut batch,
&mut session,
&cfg,
|piece| chunk_tx.blocking_send(Ok(piece)).is_ok(),
);
if let Err(e) = res {
let _ = chunk_tx.blocking_send(Err(e));
}
let _ = done_tx.send(());
}
LlamaCommand::Reset => {
session.reset();
context.clear_kv_cache();
batch.clear();
}
LlamaCommand::EvaluateEntropy { sentence, reply_tx } => {
let res = InferenceSession::evaluate_entropy(
&sentence,
&model,
&mut context,
&mut batch,
&mut session,
);
let _ = reply_tx.send(res);
}
LlamaCommand::Shutdown => {
info!("Engine thread received shutdown command. Exiting gracefully.");
break;
}
}
}
info!("Llama engine thread finished.");
}