//! kyro 0.1.1
//!
//! A high-performance ML inference engine
#![allow(dead_code)]

use crate::distributed::DistributedContext;
use crate::model::config::LlamaConfig;
use crate::model::llama::LlamaModel;
use crate::model::pipeline::PipelineContext;
use crate::model::quantized::QuantizedLlama;
use candle_core::Device;
use std::path::Path;

/// A model produced by `ModelLoader::load`.
///
/// Which variant is returned depends on the loader's `is_gguf` flag.
pub enum LoadedModel {
    /// Returned for non-GGUF paths: a `LlamaModel` built from the
    /// `.safetensors` files found in the model directory.
    Standard(LlamaModel),
    /// Returned for GGUF paths: a `QuantizedLlama` obtained from
    /// `QuantizedLlama::load_gguf`.
    Quantized(QuantizedLlama),
}

/// Records where a model lives and what `ModelLoader::new` determined about
/// it, so that `ModelLoader::load` can later build the model.
pub struct ModelLoader {
    /// Configuration parsed from `config.json` inside `model_path`.
    /// `None` when the path was classified as GGUF (no config file is read
    /// in that case).
    pub config: Option<LlamaConfig>,
    /// The path given to `new`. For GGUF it is handed to
    /// `QuantizedLlama::load_gguf`; otherwise it is read as a directory
    /// containing `config.json` and `.safetensors` files.
    pub model_path: std::path::PathBuf,
    /// Whether `new` classified `model_path` as a GGUF checkpoint.
    pub is_gguf: bool,
}

impl ModelLoader {
    /// Prepares a loader for the model located at `model_path`.
    ///
    /// The path is classified as GGUF when its extension is `gguf` **or**
    /// when the substring `.gguf` occurs anywhere in the lossily decoded
    /// path. For GGUF paths no configuration is read and `config` is `None`.
    /// Otherwise `model_path` is treated as a directory and
    /// `<model_path>/config.json` is parsed with `LlamaConfig::from_file`.
    ///
    /// # Errors
    ///
    /// For non-GGUF paths, propagates whatever error
    /// `LlamaConfig::from_file` reports for `config.json`.
    pub fn new<P: AsRef<Path>>(model_path: P) -> anyhow::Result<Self> {
        let model_path = model_path.as_ref().to_path_buf();

        // NOTE(review): the substring test also matches names such as
        // `model.gguf.part1` or a directory called `foo.gguf.d`. It is kept
        // as-is because callers may depend on this loose detection.
        let is_gguf = model_path.extension().is_some_and(|e| e == "gguf")
            || model_path.to_string_lossy().contains(".gguf");

        // Only directory-style checkpoints carry a `config.json`.
        let config = if is_gguf {
            None
        } else {
            Some(LlamaConfig::from_file(model_path.join("config.json"))?)
        };

        Ok(Self {
            config,
            model_path,
            is_gguf,
        })
    }

    /// Builds the model described by this loader on `device`.
    ///
    /// * GGUF: delegates to `QuantizedLlama::load_gguf` and returns
    ///   `LoadedModel::Quantized`. `dist` is not used on this path.
    /// * Otherwise: memory-maps every `*.safetensors` file directly inside
    ///   `model_path` as `F16`, derives a `PipelineContext` from
    ///   `dist.rank`, `dist.world_size` and `config.num_hidden_layers`, and
    ///   returns `LoadedModel::Standard`.
    ///
    /// # Errors
    ///
    /// Returns an error when
    /// * `is_gguf` is `false` but `config` is `None` (possible because the
    ///   fields are public and can be set inconsistently),
    /// * the model directory or one of its entries cannot be read,
    /// * the directory contains no `.safetensors` file,
    /// * or any of the underlying loaders/constructors fails.
    pub fn load(
        &self,
        device: &Device,
        dist: std::sync::Arc<DistributedContext>,
    ) -> anyhow::Result<LoadedModel> {
        if self.is_gguf {
            let q_model = QuantizedLlama::load_gguf(&self.model_path, device)?;
            return Ok(LoadedModel::Quantized(q_model));
        }

        // `new` always fills `config` for non-GGUF models, but the fields are
        // public, so report an inconsistent loader as an error, not a panic.
        let config = self.config.as_ref().ok_or_else(|| {
            anyhow::anyhow!("ModelLoader has no config for a non-GGUF model")
        })?;

        let mut tensors_files = Vec::new();
        for entry in std::fs::read_dir(&self.model_path)? {
            let path = entry?.path();
            if path.extension().is_some_and(|ext| ext == "safetensors") {
                tensors_files.push(path);
            }
        }

        if tensors_files.is_empty() {
            return Err(anyhow::anyhow!("No .safetensors files found"));
        }

        // `read_dir` yields entries in a platform-dependent order; sort so the
        // shard list handed to the VarBuilder is the same on every run.
        tensors_files.sort_unstable();

        // SAFETY: `from_mmaped_safetensors` memory-maps the files, which is
        // only sound while they are not modified or truncated. This function
        // cannot enforce that; it relies on the checkpoint directory being
        // left untouched while the model is alive.
        let vb = unsafe {
            candle_nn::VarBuilder::from_mmaped_safetensors(
                &tensors_files,
                candle_core::DType::F16,
                device,
            )?
        };

        // NOTE(review): the `as usize` casts assume `rank` and `world_size`
        // are non-negative and fit in `usize` — confirm against
        // `DistributedContext`.
        let pipeline_ctx = PipelineContext::new(
            dist.rank as usize,
            dist.world_size as usize,
            config.num_hidden_layers,
        );

        let model = LlamaModel::new(config, vb, device, dist, pipeline_ctx)?;
        Ok(LoadedModel::Standard(model))
    }
}