sapient-models 0.2.1

Pre-built LLM architecture graph builders for SAPIENT — Llama, Mistral, Phi, Gemma, GPT-2, BERT, Qwen
Documentation
//! Real transformer forward passes (logits and embeddings) for text generation.

pub mod backend;
pub mod common;
mod llama;
mod phi;

use std::path::{Path, PathBuf};

use anyhow::{bail, Context, Result};
use sapient_hub::model_info::{ArchType, ModelInfo};
use sapient_hub::resolver::WeightFormat;

use crate::gguf_weights::load_gguf_hf_weights;

pub use backend::{mac_gpu_support, LlmBackendKind, MacGpuSupport};
pub use llama::LlamaForward;
pub use phi::PhiForward;

/// Native inference engine, selected by model architecture, that keeps a KV cache
/// across forward calls (see `reset_cache` and the `use_cache` flag of `forward_logits`).
pub enum ForwardEngine {
    /// Engine for Llama-family checkpoints: Llama, Qwen, Gemma, and Mixtral.
    Llama(LlamaForward),
    /// Engine for Phi checkpoints.
    Phi(PhiForward),
}

/// Infers the on-disk weight format from the extension of the first entry in `weight_paths`.
///
/// Only the first path is inspected. Matching is ASCII case-insensitive, so `model.GGUF`
/// and `model.gguf` are treated the same. The recognised extensions are `gguf`,
/// `safetensors`, and `bin` (PyTorch).
///
/// Returns [`WeightFormat::Unknown`] in any of these cases:
/// - `weight_paths` is empty;
/// - the first path has no extension;
/// - the extension is not valid UTF-8;
/// - the extension is not one of the recognised ones.
fn weight_format_from_paths(weight_paths: &[PathBuf]) -> WeightFormat {
    // Normalise case so files produced on case-insensitive filesystems (or renamed by
    // hand) are not rejected just because of an upper-case extension.
    let ext = weight_paths
        .first()
        .and_then(|p| p.extension())
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase);

    match ext.as_deref() {
        Some("gguf") => WeightFormat::Gguf,
        Some("safetensors") => WeightFormat::Safetensors,
        Some("bin") => WeightFormat::PyTorchBin,
        _ => WeightFormat::Unknown,
    }
}

impl ForwardEngine {
    /// Builds an engine from a model description and its weight files.
    ///
    /// Equivalent to [`ForwardEngine::from_weight_paths`], which uses
    /// [`LlmBackendKind::Auto`].
    ///
    /// # Errors
    ///
    /// Same as [`ForwardEngine::from_weight_paths_with_backend`].
    pub fn from_pretrained(info: ModelInfo, weight_paths: &[PathBuf]) -> Result<Self> {
        Self::from_weight_paths(info, weight_paths)
    }

    /// Builds an engine from weight files, letting the backend be chosen automatically
    /// ([`LlmBackendKind::Auto`]).
    ///
    /// # Errors
    ///
    /// Same as [`ForwardEngine::from_weight_paths_with_backend`].
    pub fn from_weight_paths(info: ModelInfo, weight_paths: &[PathBuf]) -> Result<Self> {
        Self::from_weight_paths_with_backend(info, weight_paths, LlmBackendKind::Auto)
    }

    /// Builds an engine from weight files on an explicitly chosen backend.
    ///
    /// The weight format is inferred from the extension of the first path:
    /// - GGUF: only that first file is handed to
    ///   [`ForwardEngine::from_gguf_with_backend`].
    /// - safetensors / PyTorch `.bin`: the full slice of paths is passed to the
    ///   architecture-specific loader.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - no path is given, or the format cannot be determined;
    /// - `info.arch` has no native engine for the detected format;
    /// - the underlying weight loader fails.
    pub fn from_weight_paths_with_backend(
        info: ModelInfo,
        weight_paths: &[PathBuf],
        backend: LlmBackendKind,
    ) -> Result<Self> {
        match weight_format_from_paths(weight_paths) {
            WeightFormat::Gguf => {
                // The GGUF loader takes a single file; any further paths are not used.
                let path = weight_paths
                    .first()
                    .context("GGUF model has no weight path")?;
                Self::from_gguf_with_backend(info, path, backend)
            }
            WeightFormat::Safetensors | WeightFormat::PyTorchBin => match info.arch {
                ArchType::Llama | ArchType::Qwen | ArchType::Gemma | ArchType::Mixtral => {
                    Ok(Self::Llama(LlamaForward::from_files_with_backend(
                        info,
                        weight_paths,
                        backend,
                    )?))
                }
                ArchType::Phi => Ok(Self::Phi(PhiForward::from_files_with_backend(
                    info,
                    weight_paths,
                    backend,
                )?)),
                // The caller is already supplying safetensors/bin weights here, so the
                // hint lists the supported architectures rather than a file format.
                other => bail!(
                    "architecture {other:?} does not yet have a native forward engine — \
                     supported architectures are Llama, Qwen, Gemma, Mixtral, and Phi"
                ),
            },
            WeightFormat::Unknown => match weight_paths.first() {
                Some(path) => bail!(
                    "unknown or missing weight file format for {} — \
                     expected a .gguf, .safetensors, or .bin file",
                    path.display()
                ),
                None => bail!(
                    "unknown or missing weight file format — no weight paths were provided"
                ),
            },
        }
    }

    /// Builds an engine from a single GGUF file, letting the backend be chosen
    /// automatically ([`LlmBackendKind::Auto`]).
    ///
    /// # Errors
    ///
    /// Same as [`ForwardEngine::from_gguf_with_backend`].
    pub fn from_gguf(info: ModelInfo, path: &Path) -> Result<Self> {
        Self::from_gguf_with_backend(info, path, LlmBackendKind::Auto)
    }

    /// Builds an engine from a single GGUF file on an explicitly chosen backend.
    ///
    /// Only Llama-family architectures (Llama, Qwen, Gemma, Mixtral) are accepted.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - `info.arch` is Phi or any other non-Llama-family architecture (the file is not
    ///   read in that case);
    /// - the GGUF file cannot be loaded;
    /// - the engine cannot be constructed from the loaded weights.
    pub fn from_gguf_with_backend(
        info: ModelInfo,
        path: &Path,
        backend: LlmBackendKind,
    ) -> Result<Self> {
        // Check the architecture before touching the file, so we never read (potentially
        // large) weight data only to reject the model afterwards.
        match info.arch {
            ArchType::Llama | ArchType::Qwen | ArchType::Gemma | ArchType::Mixtral => {
                let weights = load_gguf_hf_weights(path).with_context(|| {
                    format!("failed to load GGUF weights from {}", path.display())
                })?;
                Ok(Self::Llama(LlamaForward::from_weights_with_backend(
                    info, weights, backend,
                )?))
            }
            ArchType::Phi => {
                bail!("GGUF Phi models are not yet supported — use safetensors weights")
            }
            other => bail!(
                "architecture {other:?} does not yet support GGUF loading — \
                 try a Llama-family GGUF model or use safetensors weights"
            ),
        }
    }

    /// Clears the engine's KV cache by delegating to the architecture-specific engine.
    pub fn reset_cache(&mut self) {
        match self {
            Self::Llama(f) => f.reset_cache(),
            Self::Phi(f) => f.reset_cache(),
        }
    }

    /// Runs a forward pass over `input_ids` and returns the resulting logits.
    ///
    /// `use_cache` is forwarded unchanged to the architecture-specific engine.
    /// NOTE(review): presumably `true` reuses/extends the KV cache from earlier calls;
    /// confirm against `LlamaForward` / `PhiForward`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying engine.
    pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
        match self {
            Self::Llama(f) => f.forward_logits(input_ids, use_cache),
            Self::Phi(f) => f.forward_logits(input_ids, use_cache),
        }
    }

    /// Computes an embedding vector for `input_ids` using the architecture-specific engine.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying engine.
    pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
        match self {
            Self::Llama(f) => f.embed(input_ids),
            Self::Phi(f) => f.embed(input_ids),
        }
    }
}