use crate::error::Result;
use crate::llm::providers::llama_cpp::command::LlamaCommand;
use crate::llm::providers::llama_cpp::thread;
use crate::llm::LLMEngineTrait;
use crate::types::config::LlamaEngineConfig;
use crate::types::LLMRequest;
use async_trait::async_trait;
use std::thread::JoinHandle;
use tokio::sync::mpsc::{Sender, UnboundedSender};
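
/// Handle to a llama.cpp model running on a dedicated engine thread.
///
/// All work is dispatched to that thread as `LlamaCommand`s over `cmd_tx`;
/// the thread's `JoinHandle` is kept so the thread can be joined on drop.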
pub struct LlamaEngine {
    pub(crate) cmd_tx: UnboundedSender<LlamaCommand>,
    _handle: Option<JoinHandle<()>>,
}

impl LlamaEngine {
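    /// Validates `cfg`, installs the llama.cpp log callback, and spawns the
    /// engine thread that will own the model.
    ///
    /// A minimal usage sketch; it assumes `LlamaEngineConfig` implements
    /// `Default` (its fields are defined elsewhere in the crate):
    ///
    /// ```ignore
    /// let engine = LlamaEngine::load(LlamaEngineConfig::default())?;
    /// ```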
    pub fn load(cfg: LlamaEngineConfig) -> Result<Self> {
        cfg.validate()?;
        // Route llama.cpp's native logging through our callback before the
        // engine thread (and the model) comes up.
        unsafe {
            llama_cpp_sys_2::llama_log_set(
                Some(super::callback::llama_log_callback),
                std::ptr::null_mut(),
            );
        }
        let (cmd_tx, handle) = thread::spawn_engine_thread(cfg)?;
        Ok(Self {
            cmd_tx,
            _handle: Some(handle),
        })
    }
}
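
/// On drop, the engine asks its thread to shut down and then joins it, so
/// the thread never outlives the handle.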
impl Drop for LlamaEngine {
    fn drop(&mut self) {
        // Best-effort shutdown: the engine thread may already have exited,
        // so a failed send (and a panicked join) are deliberately ignored.
        let _ = self.cmd_tx.send(LlamaCommand::Shutdown);
        if let Some(handle) = self._handle.take() {
            let _ = handle.join();
        }
    }
}
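
/// The trait methods are thin wrappers: each forwards to an `*_internal`
/// helper implemented elsewhere on `LlamaEngine`.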
#[async_trait]
impl LLMEngineTrait for LlamaEngine {
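    /// Runs a full completion for `request` and returns the generated text.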
    async fn chat(&mut self, request: LLMRequest) -> Result<String> {
        self.chat_internal(&request.formatted_prompt).await
    }
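
    /// Streams chunks of generated text back over `tx` as they are produced.
    ///
    /// A hedged consumer-side sketch; the request construction is assumed,
    /// since `LLMRequest`'s fields are defined elsewhere. The producer is
    /// moved onto its own task so the bounded channel cannot stall the
    /// consumer:
    ///
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    /// tokio::spawn(async move { engine.chat_stream(request, tx).await });
    /// while let Some(chunk) = rx.recv().await {
    ///     print!("{}", chunk?);
    /// }
    /// ```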
    async fn chat_stream(&mut self, request: LLMRequest, tx: Sender<Result<String>>) {
        self.stream_internal(&request.formatted_prompt, tx).await;
    }
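
    /// Resets the model context so the next request starts fresh.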
    fn reset_context(&mut self) {
        self.reset_internal();
    }
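
    /// Computes an entropy score for `sentence` under the loaded model.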
    async fn evaluate_sentence_entropy(&mut self, sentence: &str) -> Result<f32> {
        self.evaluate_sentence_entropy_internal(sentence).await
    }
}