// tl-lang 0.4.6

// A differentiable programming language with tensor support for machine learning
// Sequence Reversal Task
// Input:  [d1, d2, d3, d4, SEP, d4, d3, d2, d1]
// The model predicts the next token at every position (targets are the input shifted left by one, padded with PAD).

struct Linear { W: Tensor<f32, 2>, b: Tensor<f32, 1> }
impl Linear {
    // Affine layer mapping the last dimension from i to o: y = x @ W + b.
    // W is randn scaled by 0.1 (not Xavier: the scale ignores fan-in/fan-out);
    // b starts at zero (randn * 0.0 is used as a zeros constructor).
    fn new(i: i64, o: i64) -> Linear {
        let weight = (Tensor::randn([i, o], true) * 0.1).detach();
        let bias = (Tensor::randn([o], true) * 0.0).detach();
        Linear(weight, bias)
    }
    // Projects x with W, then adds b (b is broadcast over the leading dims).
    fn forward(self, x: Tensor<f32, 3>) -> Tensor<f32, 3> {
        let projected = x.matmul(self.W);
        return projected + self.b;
    }
    // Plain SGD update: param -= grad * lr. Both gradients are read before
    // either parameter is modified.
    fn step(self, lr: f32) {
        let dW = self.W.grad();
        let db = self.b.grad();
        self.W -= (dW * lr);
        self.b -= (db * lr);
    }
}

struct Embedding { w: Tensor<f32, 2> }
impl Embedding {
    // Token lookup table with v rows of width d, initialised to randn * 0.1.
    fn new(v: i64, d: i64) -> Embedding {
        let table = (Tensor::randn([v, d], true) * 0.1).detach();
        return Embedding(table);
    }
    // Replaces every id in `i` with its row of the table (rank 2 -> rank 3).
    fn forward(self, i: Tensor<f32, 2>) -> Tensor<f32, 3> {
        return i.embedding(self.w);
    }
    // Plain SGD update on the table.
    fn step(self, lr: f32) {
        let grad_w = self.w.grad();
        self.w -= (grad_w * lr);
    }
}

struct LayerNorm { w: Tensor<f32, 1>, b: Tensor<f32, 1> }
impl LayerNorm {
    // Scale starts at ones and shift at zeros (randn * 0.0 [+ 1.0] is used
    // here as a ones/zeros constructor).
    fn new(d: i64) -> LayerNorm {
        let scale = (Tensor::randn([d], true) * 0.0 + 1.0).detach();
        let shift = (Tensor::randn([d], true) * 0.0).detach();
        LayerNorm(scale, shift)
    }
    // NOTE: despite the name, this performs NO mean/variance normalisation.
    // It is only a learned per-channel affine map, x * w + b. A real
    // LayerNorm would compute (x - mean) / sqrt(var + eps) * w + b over the
    // last dimension. TODO(review): switch to that once the DSL's last-dim
    // mean/var reductions are confirmed.
    fn forward(self, x: Tensor<f32, 3>) -> Tensor<f32, 3> {
        let scaled = x * self.w;
        scaled + self.b
    }
    // Plain SGD update on scale and shift.
    fn step(self, lr: f32) {
        let dw = self.w.grad();
        let db = self.b.grad();
        self.w -= (dw * lr);
        self.b -= (db * lr);
    }
}

struct CausalSelfAttention { q_proj: Linear, k_proj: Linear, v_proj: Linear, o_proj: Linear }
impl CausalSelfAttention {
    // Multi-head causal self-attention with separate Q/K/V/output projections.
    // NOTE: forward() hard-codes batch 1, sequence length 9 and 4 heads of 48
    // dims. `d` must therefore be 192 (= 4 * 48), and inputs must be [1, 9, 192].
    fn new(d: i64) -> CausalSelfAttention {
        return CausalSelfAttention(Linear::new(d, d), Linear::new(d, d), Linear::new(d, d), Linear::new(d, d));
    }
    fn forward(self, x: Tensor<f32, 3>) -> Tensor<f32, 3> {
        // Fixed geometry (see note on new()).
        let seq_len = 9;
        let n_heads = 4;
        let head_dim = 48;
        let d_model = n_heads * head_dim;

        // x: [1, T, D] with D = d_model
        let Q = self.q_proj.forward(x); // [1, T, D]
        let K = self.k_proj.forward(x);
        let V = self.v_proj.forward(x);

        // Split heads: [1, T, D] -> [T, H, Dh] -> [H, T, Dh]
        let Q_split = Q.reshape([seq_len, n_heads, head_dim]);
        let Q_heads = Q_split.transpose(0, 1);

        let K_split = K.reshape([seq_len, n_heads, head_dim]);
        let K_heads = K_split.transpose(0, 1);

        let V_split = V.reshape([seq_len, n_heads, head_dim]);
        let V_heads = V_split.transpose(0, 1);

        // Scores: [H, T, Dh] x [H, Dh, T] -> [H, T, T],
        // scaled by 1/sqrt(head_dim) = 1/sqrt(48) ~= 0.144.
        let K_heads_T = K_heads.transpose(1, 2);
        let K_scaled = (K_heads_T * 0.144).contiguous();
        let Q_cont = Q_heads.contiguous();
        let logits = Q_cont.matmul(K_scaled);

        // Causal mask. Using tril(0) on the logits only ZEROES future scores,
        // and exp(0) = 1 still gives future keys non-zero softmax weight.
        // Position t could then read token t+1, which is its own target.
        // Instead, push strictly-upper-triangular logits to a large negative
        // value so softmax assigns them ~0 probability.
        let ones = (Tensor::randn([seq_len, seq_len], true) * 0.0 + 1.0).detach();
        let future = ones - ones.tril(0); // 1.0 where key index > query index
        let neg_big = -10000.0;
        // [T, T] mask is added to [H, T, T] scores via trailing-dim
        // broadcasting (as the bias add in Linear::forward relies on).
        let masked = logits + future * neg_big;

        let probs = masked.softmax(2); // softmax over the key axis

        // [H, T, T] x [H, T, Dh] -> [H, T, Dh]
        let probs_cont = probs.contiguous();
        let V_cont = V_heads.contiguous();
        let y_heads = probs_cont.matmul(V_cont);

        // Merge heads: [H, T, Dh] -> [T, H, Dh] -> [1, T, D]. Make the
        // transposed tensor contiguous first, so the merge does not depend on
        // reshape honouring transposed strides.
        let y_trans = y_heads.transpose(0, 1).contiguous();
        let y_out = y_trans.reshape([1, seq_len, d_model]);

        self.o_proj.forward(y_out)
    }
    // SGD update of all four projections.
    fn step(self, lr: f32) {
        self.q_proj.step(lr);
        self.k_proj.step(lr);
        self.v_proj.step(lr);
        self.o_proj.step(lr);
    }
}

struct MLP { f: Linear, p: Linear }
impl MLP {
    // Position-wise feed-forward network: d -> 4d -> d with a ReLU between.
    fn new(d: i64) -> MLP {
        let hidden = d * 4;
        let expand = Linear::new(d, hidden);
        let project = Linear::new(hidden, d);
        return MLP(expand, project);
    }
    // Expand, apply ReLU, project back to width d.
    fn forward(self, x: Tensor<f32, 3>) -> Tensor<f32, 3> {
        let h = self.f.forward(x);
        let activated = h.relu();
        return self.p.forward(activated);
    }
    // SGD update of both projections.
    fn step(self, lr: f32) {
        self.f.step(lr);
        self.p.step(lr);
    }
}

struct Block { l1: LayerNorm, a: CausalSelfAttention, l2: LayerNorm, m: MLP }
impl Block {
    // Pre-norm residual block: h = x + attn(l1(x)); out = h + mlp(l2(h)).
    fn new(d: i64) -> Block {
        let norm1 = LayerNorm::new(d);
        let attn = CausalSelfAttention::new(d);
        let norm2 = LayerNorm::new(d);
        let ff = MLP::new(d);
        return Block(norm1, attn, norm2, ff);
    }
    fn forward(self, x: Tensor<f32, 3>) -> Tensor<f32, 3> {
        let attn_out = self.a.forward(self.l1.forward(x));
        let h = x + attn_out;
        let ff_out = self.m.forward(self.l2.forward(h));
        return h + ff_out;
    }
    // SGD update of every sub-layer, in declaration order.
    fn step(self, lr: f32) {
        self.l1.step(lr);
        self.a.step(lr);
        self.l2.step(lr);
        self.m.step(lr);
    }
}

struct PositionalEmbedding { w: Tensor<f32, 2> }
impl PositionalEmbedding {
    // Learned absolute-position table with max_len rows of width d,
    // initialised to randn * 0.02.
    fn new(max_len: i64, d: i64) -> PositionalEmbedding {
        PositionalEmbedding((Tensor::randn([max_len, d], true)*0.02).detach())
    }
    // Replaces every position id in `p` with its row of the table.
    fn forward(self, p: Tensor<f32, 2>) -> Tensor<f32, 3> {
        p.embedding(self.w)
    }
    // Plain SGD update. Uses the in-place `-=` like every other layer here.
    // The previous `self.w = self.w - g * lr` rebinds w to a newly computed
    // tensor, which may no longer be the original leaf parameter and may keep
    // autograd history alive across steps.
    fn step(self, lr: f32) {
        let g = self.w.grad();
        self.w -= (g * lr);
    }
}

struct Transformer { w: Embedding, p: PositionalEmbedding, b1: Block, b2: Block, l: LayerNorm, h: Linear }
impl Transformer {
    // Two-block decoder with the following layers, in order:
    //   - token embedding plus learned positional embedding;
    //   - two residual blocks;
    //   - a final LayerNorm;
    //   - a linear head onto the vocabulary.
    // v = vocabulary size, max_len = rows in the positional table, d = width.
    fn new(v: i64, max_len: i64, d: i64) -> Transformer {
        let tok = Embedding::new(v, d);
        let pos = PositionalEmbedding::new(max_len, d);
        let first = Block::new(d);
        let second = Block::new(d);
        let final_norm = LayerNorm::new(d);
        let head = Linear::new(d, v);
        return Transformer(tok, pos, first, second, final_norm, head);
    }
    // i: token ids, p: position ids (train_epoch passes both as [1, 9]).
    // Returns per-position vocabulary logits.
    fn forward(self, i: Tensor<f32, 2>, p: Tensor<f32, 2>) -> Tensor<f32, 3> {
        let token_vecs = self.w.forward(i);
        let position_vecs = self.p.forward(p);
        let hidden = self.b1.forward(token_vecs + position_vecs);
        let hidden = self.b2.forward(hidden);
        let normed = self.l.forward(hidden);
        return self.h.forward(normed);
    }
    // SGD update of every parameterised component, in declaration order.
    fn step(self, lr: f32) {
        self.w.step(lr);
        self.p.step(lr);
        self.b1.step(lr);
        self.b2.step(lr);
        self.l.step(lr);
        self.h.step(lr);
    }
}

// Samples a digit in 0..9 (as f32) by binning one randn draw at the
// standard-normal deciles. If randn is N(0, 1), each digit has probability
// ~0.1. Thresholds are the N(0,1) quantiles at 10%, 20%, ..., 90%
// (the previous 0.52/0.25 were rounded and skewed the bins slightly).
fn get_random_digit() -> f32 {
    let r = Tensor::randn([1], true)[0];
    // Bins are tested in ascending order, so each test only needs an upper bound.
    if r < -1.2816 { return 0.0; }
    if r < -0.8416 { return 1.0; }
    if r < -0.5244 { return 2.0; }
    if r < -0.2533 { return 3.0; }
    if r < 0.0     { return 4.0; }
    if r < 0.2533  { return 5.0; }
    if r < 0.5244  { return 6.0; }
    if r < 0.8416  { return 7.0; }
    if r < 1.2816  { return 8.0; }
    // r >= 1.2816, or NaN, where every comparison above is false.
    return 9.0;
}

// Runs `steps` SGD updates on freshly sampled reversal sequences.
// Returns the MEAN cross-entropy loss over those steps, or 0.0 when steps <= 0.
// (Previously it returned only the last step's loss.)
//
// Each sample is 4 random digits, SEP, then the digits reversed:
//   X = [d1, d2, d3, d4, SEP, d4, d3, d2, d1]
//   Y = [d2, d3, d4, SEP, d4, d3, d2, d1, PAD]   (X shifted left by one)
// Targets at positions 0..2 are fresh random digits, so they cannot be
// predicted and the loss has an irreducible floor. Positions 3..8 are learnable.
//
// `seed_offset` is currently unused (sampling uses the global randn stream).
// It is kept so existing callers keep working.
// NOTE(review): there is no explicit zero-grad between steps. Confirm that
// backward() overwrites gradients rather than accumulating them in tl.
fn train_epoch(model: Transformer, lr: f32, steps: i64, seed_offset: i64) -> f32 {
    let mut total_loss = 0.0;
    let mut steps_done = 0.0;

    let SEP = 10.0;
    let PAD = 11.0;

    // Position indices 0..8 for the 9-token sequence.
    let p0 = 0.0; let p1 = 1.0; let p2 = 2.0; let p3 = 3.0; let p4 = 4.0;
    let p5 = 5.0; let p6 = 6.0; let p7 = 7.0; let p8 = 8.0;
    let P_arr = [p0, p1, p2, p3, p4, p5, p6, p7, p8];
    let P_t = P_arr.reshape([1, 9]);

    for s in range(0, steps) {
        // Digits come from binned normal draws (see get_random_digit).
        let d1 = get_random_digit();
        let d2 = get_random_digit();
        let d3 = get_random_digit();
        let d4 = get_random_digit();

        // Input: 4 digits, SEP, the same 4 digits reversed.
        // Target: the next token at every position, PAD after the last one.
        let X_arr = [d1, d2, d3, d4, SEP, d4, d3, d2, d1];
        let Y_arr = [d2, d3, d4, SEP, d4, d3, d2, d1, PAD];

        let X = X_arr.reshape([1, 9]);
        let Y = Y_arr.reshape([1, 9]);

        let logits = model.forward(X, P_t);
        // 12 = vocabulary size (digits 0-9, SEP, PAD); must match main().
        let logits_flat = logits.reshape([9, 12]);
        let Y_flat = Y.reshape([9]);

        let loss = logits_flat.cross_entropy(Y_flat);

        loss.backward();
        model.step(lr);

        total_loss = total_loss + loss.item();
        steps_done = steps_done + 1.0;
    }
    if steps_done > 0.0 {
        return total_loss / steps_done;
    }
    return 0.0;
}

// Trains the reversal transformer and writes its weights to
// "reverse_model.safetensors". Progress is printed every 10 epochs.
fn main() {
    // Vocabulary: digits 0-9, SEP (10), PAD (11) -> 12 tokens. This must match
    // the [9, 12] logits reshape in train_epoch.
    let vocab_size = 12;
    let max_len = 16;
    // Must equal 4 heads * 48 dims, which CausalSelfAttention::forward hard-codes.
    let d_model = 192;
    let model = Transformer::new(vocab_size, max_len, d_model);

    let lr = 0.0003;
    let epochs = 200;
    let steps_per_epoch = 300;

    print("Training Sequence Reversal (Improved)...");

    for epoch in range(0, epochs) {
        let loss = train_epoch(model, lr, steps_per_epoch, epoch * 100);
        // Integer arithmetic: true exactly when epoch is a multiple of 10.
        if epoch == (epoch / 10) * 10 {
            print("Epoch:");
            print(epoch);
            print("Loss:");
            print(loss);
        }
    }

    print("Training Complete.");

    print("Saving struct weights...");
    Param::save(model, "reverse_model.safetensors");
    print("Saved.");
}