ambi 0.1.7

A flexible, multi-backend, customizable AI agent framework, written entirely in Rust.
Documentation

The listing below is the dedicated engine thread behind the llama.cpp provider: it owns the backend, model, context, and batch, and serves inference commands received over an unbounded channel until it is told to shut down.
// src/llm/providers/llama_cpp/engine/thread.rs

use crate::llm::providers::llama_cpp::command::LlamaCommand;
use crate::llm::providers::llama_cpp::session::InferenceSession;
use crate::types::config::LlamaEngineConfig;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use log::{error, info};
use std::num::NonZeroU32;
use std::panic::{self, AssertUnwindSafe};
use std::path::Path;
use std::thread::{self, JoinHandle};
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};

/// Spawns the dedicated llama.cpp engine thread and returns the command
/// sender together with the thread's `JoinHandle`.
pub(crate) fn spawn_engine_thread(
    cfg: LlamaEngineConfig,
) -> crate::error::Result<(UnboundedSender<LlamaCommand>, JoinHandle<()>)> {
    let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<LlamaCommand>();

    let handle = thread::spawn(move || {
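        // Catch panics from the engine body so they are logged instead of vanishing with the thread.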
        let result = panic::catch_unwind(AssertUnwindSafe(|| {
            engine_main(cfg, cmd_rx);
        }));
        if let Err(panic_err) = result {
            let msg = if let Some(s) = panic_err.downcast_ref::<&str>() {
                s.to_string()
            } else if let Some(s) = panic_err.downcast_ref::<String>() {
                s.clone()
            } else {
                "unknown panic".to_string()
            };
            error!("Engine thread panicked: {}", msg);
        }
    });

    Ok((cmd_tx, handle))
}

/// Body of the engine thread: initialises the backend, model, context, and
/// batch, then serves commands until the channel closes or `Shutdown` arrives.
fn engine_main(cfg: LlamaEngineConfig, mut cmd_rx: UnboundedReceiver<LlamaCommand>) {
    // 1. Initialise backend, model, context, batch
    let backend = match LlamaBackend::init() {
        Ok(b) => b,
        Err(e) => {
            error!("Backend init failed: {}", e);
            return;
        }
    };

    // Offload layers to the GPU only when requested; otherwise keep everything on the CPU.
    let n_gpu_layers = if cfg.use_gpu { cfg.n_gpu_layers } else { 0 };
    let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);

    let model =
        match LlamaModel::load_from_file(&backend, Path::new(&cfg.model_path), &model_params) {
            Ok(m) => m,
            Err(e) => {
                error!("Model loading failed: {}", e);
                return;
            }
        };

    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(Some(NonZeroU32::new(cfg.n_ctx).expect("n_ctx must be > 0")))
        .with_n_threads(num_cpus::get() as i32);

    let mut context = match model.new_context(&backend, ctx_params) {
        Ok(c) => c,
        Err(e) => {
            error!("Failed to create context: {}", e);
            return;
        }
    };

    let mut batch = LlamaBatch::new(cfg.n_tokens, cfg.n_seq_max);
    let mut session = InferenceSession::new();

    info!("Llama engine thread started successfully.");

    // 2. Command loop
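    // blocking_recv is appropriate here: this is a plain OS thread, not a tokio task.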
    while let Some(cmd) = cmd_rx.blocking_recv() {
        match cmd {
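            // Non-streaming chat: collect every generated piece, then reply once with the full text.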
            LlamaCommand::Chat { prompt, reply_tx } => {
                let mut full_response = String::new();
                let res = InferenceSession::run_inference(
                    &prompt,
                    &model,
                    &mut context,
                    &mut batch,
                    &mut session,
                    &cfg,
                    |piece| {
                        full_response.push_str(&piece);
                        true
                    },
                );
                let _ = reply_tx.send(res.map(|_| full_response));
            }
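            // Streaming chat: forward each piece as it is produced; stop early if the receiver is dropped.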
            LlamaCommand::ChatStream {
                prompt,
                chunk_tx,
                done_tx,
            } => {
                let res = InferenceSession::run_inference(
                    &prompt,
                    &model,
                    &mut context,
                    &mut batch,
                    &mut session,
                    &cfg,
                    |piece| chunk_tx.blocking_send(Ok(piece)).is_ok(),
                );
                if let Err(e) = res {
                    let _ = chunk_tx.blocking_send(Err(e));
                }
                let _ = done_tx.send(());
            }
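            // Drop conversation state and clear the KV cache and batch before the next exchange.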
            LlamaCommand::Reset => {
                session.reset();
                context.clear_kv_cache();
                batch.clear();
            }
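            // Score a single sentence with the loaded model and send the entropy result back.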
            LlamaCommand::EvaluateEntropy { sentence, reply_tx } => {
                let res = InferenceSession::evaluate_entropy(
                    &sentence,
                    &model,
                    &mut context,
                    &mut batch,
                    &mut session,
                );
                let _ = reply_tx.send(res);
            }
            LlamaCommand::Shutdown => {
                info!("Engine thread received shutdown command. Exiting gracefully.");
                break;
            }
        }
    }
    info!("Llama engine thread finished.");
}
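
For orientation, the async side of the crate only ever talks to this thread through the `LlamaCommand` channel returned by `spawn_engine_thread`. The sketch below is not taken from ambi's source; it shows one plausible shape of `LlamaCommand` and a minimal async caller, inferred from how the commands are destructured in the loop above. The channel and result types are assumptions: `reply_tx` and `done_tx` as tokio oneshot senders, `chunk_tx` as a bounded tokio mpsc sender (suggested by the `blocking_send` call), and `f64` as a stand-in entropy score.

// Hypothetical sketch: not part of ambi's source. Variant and field names
// mirror the destructuring in the engine loop above; channel and result
// types are assumptions.

use tokio::sync::{mpsc, oneshot};

pub enum LlamaCommand {
    Chat {
        prompt: String,
        reply_tx: oneshot::Sender<crate::error::Result<String>>,
    },
    ChatStream {
        prompt: String,
        chunk_tx: mpsc::Sender<crate::error::Result<String>>,
        done_tx: oneshot::Sender<()>,
    },
    Reset,
    EvaluateEntropy {
        sentence: String,
        reply_tx: oneshot::Sender<crate::error::Result<f64>>, // f64 is an assumed score type
    },
    Shutdown,
}

/// Minimal async caller: send one Chat command and await the full reply.
async fn chat_once(
    cmd_tx: &mpsc::UnboundedSender<LlamaCommand>,
    prompt: &str,
) -> crate::error::Result<String> {
    // Pair the command with a oneshot channel so the engine thread can answer.
    let (reply_tx, reply_rx) = oneshot::channel();
    cmd_tx
        .send(LlamaCommand::Chat {
            prompt: prompt.to_string(),
            reply_tx,
        })
        .expect("engine thread has shut down");
    // The engine thread runs the inference and sends back the accumulated text.
    reply_rx.await.expect("engine dropped the reply channel")
}

The split keeps all llama.cpp state confined to a single OS thread, while async callers only pay for a channel send and an await.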