shimmy 1.3.4

Lightweight 5MB Ollama alternative with native SafeTensors support. No Python dependencies, 2x faster loading.
Documentation
#![allow(clippy::too_many_arguments)]
use anyhow::{anyhow, Result};
use async_trait::async_trait;

use super::{GenOptions, InferenceEngine, LoadedModel, ModelSpec};

#[cfg(feature = "llama")]
use std::sync::Mutex;
#[cfg(feature = "llama")]
use tracing::info;

#[derive(Default)]
pub struct LlamaEngine;
impl LlamaEngine {
    pub fn new() -> Self {
        Self
    }
}

#[async_trait]
impl InferenceEngine for LlamaEngine {
    async fn load(&self, spec: &ModelSpec) -> Result<Box<dyn LoadedModel>> {
        #[cfg(feature = "llama")]
        {
            use llama_cpp_2 as llama;
            use std::num::NonZeroU32;
            let be = llama::llama_backend::LlamaBackend::init()?;
            let model = llama::model::LlamaModel::load_from_file(
                &be,
                &spec.base_path,
                &Default::default(),
            )?;
            let ctx_params = llama::context::params::LlamaContextParams::default()
                .with_n_ctx(NonZeroU32::new(spec.ctx_len as u32))
                .with_n_batch(2048)
                .with_n_ubatch(512)
                .with_n_threads(
                    spec.n_threads.unwrap_or(
                        std::thread::available_parallelism()
                            .map(|n| n.get() as i32)
                            .unwrap_or(4),
                    ),
                )
                .with_n_threads_batch(
                    spec.n_threads.unwrap_or(
                        std::thread::available_parallelism()
                            .map(|n| n.get() as i32)
                            .unwrap_or(4),
                    ),
                );
            let ctx_tmp = model.new_context(&be, ctx_params)?;
            if let Some(ref lora) = spec.lora_path {
                // Check if it's a SafeTensors file and convert if needed
                let lora_path = if lora.extension().and_then(|s| s.to_str()) == Some("safetensors")
                {
                    // For now, provide helpful error message for SafeTensors files
                    return Err(anyhow!(
                        "SafeTensors LoRA detected: {}. Please convert to GGUF format first.",
                        lora.display()
                    ));
                } else {
                    lora.clone()
                };

                let mut adapter = model.lora_adapter_init(&lora_path)?;
                ctx_tmp
                    .lora_adapter_set(&mut adapter, 1.0)
                    .map_err(|e| anyhow!("lora set: {e:?}"))?;
                info!(adapter=%lora_path.display(), "LoRA adapter attached");
            }
            // Store both model and context together to maintain proper lifetimes
            // The context lifetime is tied to &model; storing both in the same struct ensures safety
            let ctx: llama::context::LlamaContext<'static> =
                unsafe { std::mem::transmute(ctx_tmp) };
            Ok(Box::new(LlamaLoaded {
                _be: be,
                model,
                ctx: Mutex::new(ctx),
            }))
        }
        #[cfg(not(feature = "llama"))]
        {
            let _ = spec; // silence unused warning
            Ok(Box::new(LlamaFallback))
        }
    }
}

#[cfg(feature = "llama")]
struct LlamaLoaded {
    _be: llama_cpp_2::llama_backend::LlamaBackend,
    model: llama_cpp_2::model::LlamaModel,
    ctx: Mutex<llama_cpp_2::context::LlamaContext<'static>>,
}

#[cfg(feature = "llama")]
// The llama.cpp context & model use raw pointers internally and are !Send by default.
// We wrap access in a Mutex and only perform FFI calls while holding the lock, so it's
// sound to mark the container Send + Sync for our usage (single-threaded mutable access).
unsafe impl Send for LlamaLoaded {}
#[cfg(feature = "llama")]
unsafe impl Sync for LlamaLoaded {}

#[cfg(feature = "llama")]
#[async_trait]
impl LoadedModel for LlamaLoaded {
    async fn generate(
        &self,
        prompt: &str,
        opts: GenOptions,
        mut on_token: Option<Box<dyn FnMut(String) + Send>>,
    ) -> Result<String> {
        use llama_cpp_2::{
            llama_batch::LlamaBatch,
            model::{AddBos, Special},
            sampling::LlamaSampler,
        };
        let mut ctx = self.ctx.lock().unwrap();
        let tokens = self.model.str_to_token(prompt, AddBos::Always)?;

        // Create batch with explicit logits configuration
        let mut batch = LlamaBatch::new(tokens.len(), 1);
        for (i, &token) in tokens.iter().enumerate() {
            // Only request logits for the last token in the initial batch
            let logits = i == tokens.len() - 1;
            batch.add(token, i as i32, &[0], logits)?;
        }
        ctx.decode(&mut batch)?;

        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::temp(opts.temperature),
            LlamaSampler::top_p(opts.top_p, 1),
            LlamaSampler::top_k(opts.top_k),
            // API changed order: (repeat_last_n, freq_penalty, presence_penalty, penalty)
            LlamaSampler::penalties(64, 0.0, 0.0, opts.repeat_penalty),
            LlamaSampler::greedy(),
        ])
        .with_tokens(tokens.iter().copied());

        let mut out = String::new();
        let mut all_tokens = tokens;
        for _ in 0..opts.max_tokens {
            // Sample from the last (and only) position with logits
            let token = sampler.sample(&ctx, -1);
            if self.model.is_eog_token(token) {
                break;
            }
            // Use Plaintext to avoid re-tokenizing control tokens into special forms
            let piece = self.model.token_to_str(token, Special::Plaintext)?;
            let start = out.len();
            out.push_str(&piece);
            if let Some(cb) = on_token.as_mut() {
                cb(out[start..].to_string());
            }

            let mut step = LlamaBatch::new(1, 1);
            step.add(token, all_tokens.len() as i32, &[0], true)?;
            ctx.decode(&mut step)?;
            all_tokens.push(token);
        }
        Ok(out)
    }
}

/// Fallback implementation when llama.cpp feature is not enabled
/// Returns informative message directing users to enable the feature
#[cfg(not(feature = "llama"))]
struct LlamaFallback;

#[cfg(not(feature = "llama"))]
#[async_trait]
impl LoadedModel for LlamaFallback {
    async fn generate(
        &self,
        prompt: &str,
        _opts: GenOptions,
        mut on_token: Option<Box<dyn FnMut(String) + Send>>,
    ) -> Result<String> {
        let fallback_msg =
            "Llama.cpp support not enabled. Build with --features llama for full functionality.";
        if let Some(cb) = on_token.as_mut() {
            cb(fallback_msg.to_string());
        }
        Ok(format!("[INFO] {} Input: {}", fallback_msg, prompt))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llama_engine_initialization() {
        let engine = LlamaEngine::new();
        // LlamaEngine is a unit struct, just test creation
        assert_eq!(std::mem::size_of_val(&engine), 0);
    }

    #[tokio::test]
    async fn test_model_loading_validation() {
        let _engine = LlamaEngine::new();
        let _spec = ModelSpec {
            name: "test".to_string(),
            base_path: "/nonexistent".into(),
            lora_path: None,
            template: Some("chatml".to_string()),
            ctx_len: 2048,
            n_threads: None,
        };

        // let result = engine.load(&spec).await; // Commented to avoid test file dependencies
        // assert!(result.is_err()); // Test spec structure instead
    }

    #[test]
    fn test_model_spec_validation() {
        let spec = ModelSpec {
            name: "valid".to_string(),
            base_path: "test.gguf".into(),
            lora_path: None,
            template: Some("chatml".to_string()),
            ctx_len: 4096,
            n_threads: Some(4),
        };

        assert_eq!(spec.name, "valid");
        assert_eq!(spec.ctx_len, 4096);
        assert!(spec.template.is_some());
    }
}