//! moe-llm-core 1.3.6
//!
//! Part of the MoE-13 Ternary Intelligence Stack.
//!
//! Documentation
use candle_core::{Result, Tensor, IndexOp};
use candle_nn::{Module, VarBuilder};
use super::attention::Attention;
use super::moe::MoeBlock;
use super::mlp::Mlp;
use super::config::TransformerConfig;
use super::ternary_linear::TernaryLinear;

/// One pre-norm transformer layer.
///
/// The layer has a self-attention sublayer followed by a feed-forward sublayer. Each sublayer
/// normalises its input, transforms it, and adds the result back onto the un-normalised input
/// (a residual connection).
///
/// The feed-forward sublayer is either a mixture-of-experts block or a dense MLP. [`Block::new`]
/// builds exactly one of the two.
pub struct Block {
    /// Attention sublayer, fed the `ln1`-normalised input.
    attention: Attention,
    /// Mixture-of-experts feed-forward; built when `config.num_experts > 0`.
    moe: Option<MoeBlock>,
    /// Dense feed-forward; built when the config requests no experts.
    mlp: Option<Mlp>,
    /// Norm applied before attention.
    ln1: candle_nn::LayerNorm,
    /// Norm applied before the feed-forward sublayer.
    ln2: candle_nn::LayerNorm,
}

impl Block {
    /// Creates the layer's weights under `vb`.
    ///
    /// Weights go under the prefixes `attn`, `moe` or `mlp`, `ln1` and `ln2`. `threshold` is
    /// forwarded unchanged to the attention and feed-forward constructors.
    ///
    /// # Errors
    /// Propagates any error raised while creating or loading a sub-module.
    pub fn new(config: &TransformerConfig, vb: VarBuilder, threshold: f32) -> Result<Self> {
        let hidden = config.hidden_size;
        let attention = Attention::new(hidden, config.num_heads, vb.pp("attn"), threshold)?;

        // Only one feed-forward variant is ever constructed.
        // The dense MLP uses a fixed 4x inner expansion.
        let use_experts = config.num_experts > 0;
        let moe = if use_experts {
            Some(MoeBlock::new(hidden, config.num_experts, vb.pp("moe"), threshold)?)
        } else {
            None
        };
        let mlp = if use_experts {
            None
        } else {
            Some(Mlp::new(hidden, hidden * 4, vb.pp("mlp"), threshold)?)
        };

        let norm = |name: &str| candle_nn::layer_norm(hidden, 1e-5, vb.pp(name));
        let ln1 = norm("ln1")?;
        let ln2 = norm("ln2")?;

        Ok(Self { attention, moe, mlp, ln1, ln2 })
    }

    /// Applies attention and then the feed-forward sublayer, each with a residual add.
    ///
    /// If neither feed-forward variant is present, the post-attention hidden state is returned
    /// as-is.
    ///
    /// # Errors
    /// Propagates any tensor-op or sub-module error.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Attention sublayer: normalise, attend, then add back onto the raw input.
        let attn_out = self.attention.forward(&self.ln1.forward(x)?)?;
        let hidden = (x + attn_out)?;

        // Feed-forward sublayer. The MoE path is checked first.
        let ff_out = match (&self.moe, &self.mlp) {
            (Some(moe), _) => moe.forward(&self.ln2.forward(&hidden)?)?,
            (None, Some(mlp)) => mlp.forward(&self.ln2.forward(&hidden)?)?,
            (None, None) => return Ok(hidden),
        };
        &hidden + ff_out
    }
}

pub struct Transformer {
    embedding: candle_nn::Embedding,
    pos_embedding: candle_nn::Embedding,
    blocks: Vec<Block>,
    ln_f: candle_nn::LayerNorm,
    lm_head: TernaryLinear,
    config: TransformerConfig,
}

impl Transformer {
    pub fn new(config: &TransformerConfig, vb: VarBuilder) -> Result<Self> {
        let embedding = candle_nn::embedding(config.vocab_size, config.hidden_size, vb.pp("embed"))?;
        let pos_embedding = candle_nn::embedding(config.max_seq_len, config.hidden_size, vb.pp("pos_embed"))?;
        
        let mut blocks = Vec::new();
        let vb_blocks = vb.pp("blocks");
        for i in 0..config.num_layers {
            let layer_threshold = config.layer_threshold(i);
            blocks.push(Block::new(config, vb_blocks.pp(i), layer_threshold)?);
        }

        let ln_f = candle_nn::layer_norm(config.hidden_size, 1e-5, vb.pp("ln_f"))?;
        let threshold = config.threshold;
        let lm_head = TernaryLinear::new(config.hidden_size, config.vocab_size, false, threshold, vb.pp("lm_head"))?;
        
        Ok(Self { 
            embedding, 
            pos_embedding,
            blocks, 
            ln_f, 
            lm_head,
            config: config.clone(),
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let seq_len = x.dims()[1];
        let dev = x.device();
        
        let x_embed = self.embedding.forward(x)?;
        
        let pos = Tensor::arange(0u32, seq_len as u32, dev)?.unsqueeze(0)?.to_dtype(candle_core::DType::U32)?;
        let pos_embed = self.pos_embedding.forward(&pos)?;
        
        let mut x = x_embed.broadcast_add(&pos_embed)?;
        
        for block in &self.blocks {
            x = block.forward(&x)?;
        }
        x = self.ln_f.forward(&x)?;
        self.lm_head.forward(&x)
    }

    pub fn generate(&self, tokenizer: &super::super::tokenizer::BpeTokenizer, prompt: &str, max_new_tokens: usize) -> String {
        let mut tokens = tokenizer.encode(prompt);
        if tokens.is_empty() { 
            return "".to_string(); 
        }

        let dev = candle_core::Device::Cpu;
        let mut rng = rand::thread_rng();

        for _ in 0..max_new_tokens {
            let start_idx = tokens.len().saturating_sub(self.config.max_seq_len);
            let context = &tokens[start_idx..];
            
            let input = Tensor::new(&context[..], &dev).unwrap().unsqueeze(0).unwrap().to_dtype(candle_core::DType::U32).unwrap();
            let logits = self.forward(&input).unwrap();
            
            let dims = logits.dims();
            let seq_len = dims[1];
            
            let last_logits = logits.i((0, seq_len - 1)).unwrap();
            
            let probs = candle_nn::ops::softmax(&last_logits, 0).unwrap();
            let pr = probs.to_vec1::<f32>().unwrap();
            
            use rand::distributions::{Distribution, WeightedIndex};
            let dist = WeightedIndex::new(&pr).unwrap();
            let next_token_idx = dist.sample(&mut rng) as u32;
            
            tokens.push(next_token_idx);
        }
        
        tokenizer.decode(&tokens)
    }
}