// axonml_nn/layers/ternary.rs
//! Ternary Linear Layer - 1.58-bit Weight Quantization (BitNet b1.58)
//!
//! Implements TernaryLinear: a linear layer with ternary weights {-1, 0, +1}.
//! Weights are stored as packed 2-bit integers (4 weights per byte) for inference,
//! with full-precision shadow weights maintained during training.
//!
//! Forward pass uses absmean quantization:
//!   w_ternary = sign(w) * round(|w| / mean(|w|))
//!
//! The ternary matmul reduces to addition/subtraction — no multiply needed:
//!   y[i] = scale * (sum_{w=+1} x[j] - sum_{w=-1} x[j])
//!
//! # File
//! `crates/axonml-nn/src/layers/ternary.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 19, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

27use std::any::Any;
28use std::collections::HashMap;
29
30use axonml_autograd::no_grad::is_grad_enabled;
31use axonml_autograd::{GradFn, GradientFunction, Variable};
32use axonml_tensor::Tensor;
33
34use crate::init::{kaiming_uniform, zeros};
35use crate::module::Module;
36use crate::parameter::Parameter;
37
38// =============================================================================
39// Packed Ternary Weights
40// =============================================================================
41
/// Packed ternary weight storage: 4 weights per byte using 2-bit encoding.
///
/// Encoding: 0b00 = 0, 0b01 = +1, 0b10 = -1 (0b11 unused)
#[derive(Debug, Clone)]
pub struct PackedTernaryWeights {
    /// Packed bytes (4 weights per byte, little-end-first within each byte)
    data: Vec<u8>,
    /// Number of actual weight values (the final byte may be partially used)
    num_weights: usize,
    /// Absmean scale factor applied after the ternary matmul
    scale: f32,
}

impl PackedTernaryWeights {
    /// Pack ternary values {-1, 0, +1} into 2-bit representation.
    ///
    /// Any value outside {-1, 0, +1} is clamped to 0 rather than rejected.
    pub fn pack(ternary_values: &[i8], scale: f32) -> Self {
        let num_weights = ternary_values.len();
        // Four 2-bit codes per byte; round up for a trailing partial byte.
        let num_bytes = num_weights.div_ceil(4);
        let mut data = vec![0u8; num_bytes];

        for (i, &val) in ternary_values.iter().enumerate() {
            let byte_idx = i / 4;
            let bit_offset = (i % 4) * 2;
            let encoded = match val {
                0 => 0b00u8,
                1 => 0b01u8,
                -1 => 0b10u8,
                _ => 0b00u8, // Clamp invalid values to zero
            };
            data[byte_idx] |= encoded << bit_offset;
        }

        Self {
            data,
            num_weights,
            scale,
        }
    }

    /// Unpack to dense ternary values {-1, 0, +1}.
    pub fn unpack(&self) -> Vec<i8> {
        let mut values = Vec::with_capacity(self.num_weights);
        for i in 0..self.num_weights {
            let byte_idx = i / 4;
            let bit_offset = (i % 4) * 2;
            let encoded = (self.data[byte_idx] >> bit_offset) & 0b11;
            let val = match encoded {
                0b00 => 0i8,
                0b01 => 1i8,
                0b10 => -1i8,
                _ => 0i8, // 0b11 is never written by `pack`; decode defensively as 0
            };
            values.push(val);
        }
        values
    }

    /// Returns the scale factor.
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Returns the packed storage size in bytes.
    pub fn storage_bytes(&self) -> usize {
        self.data.len()
    }

    /// Returns the number of weights.
    pub fn num_weights(&self) -> usize {
        self.num_weights
    }

    /// Count zeros (sparsity).
    pub fn count_zeros(&self) -> usize {
        let values = self.unpack();
        values.iter().filter(|&&v| v == 0).count()
    }
}

121// =============================================================================
122// TernaryLinear
123// =============================================================================
124
125/// A linear layer with 1.58-bit ternary weights (BitNet b1.58).
126///
127/// During training, full-precision shadow weights are maintained and quantized
128/// to ternary {-1, 0, +1} on each forward pass using absmean quantization.
129/// Gradients flow through the quantization via the Straight-Through Estimator (STE).
130///
131/// During inference, pre-quantized packed weights are used for efficient
132/// addition/subtraction-only matmul.
133///
134/// # Architecture
135/// - Shadow weights: fp32 (out_features x in_features), used during training
136/// - Ternary weights: packed 2-bit (4 per byte), used during inference
137/// - Scale factor: mean(|w|), applied after ternary matmul
138/// - Bias: optional fp32
139///
140/// # Shape
141/// - Input: (*, in_features)
142/// - Output: (*, out_features)
143///
144/// # Example
145/// ```ignore
146/// let layer = TernaryLinear::new(512, 512);
147/// let input = Variable::new(Tensor::randn(&[2, 512]), true);
148/// let output = layer.forward(&input);  // Shape: [2, 512]
149/// ```
150pub struct TernaryLinear {
151    /// Shadow weight (fp32) for training — holds the latent continuous weights.
152    pub shadow_weight: Parameter,
153    /// Optional bias (fp32).
154    pub bias: Option<Parameter>,
155    /// Pre-quantized packed weights for inference.
156    packed_weights: Option<PackedTernaryWeights>,
157    /// Input features.
158    in_features: usize,
159    /// Output features.
160    out_features: usize,
161    /// Whether to use packed inference mode.
162    inference_mode: bool,
163}
164
165impl TernaryLinear {
166    /// Creates a new TernaryLinear layer with bias.
167    pub fn new(in_features: usize, out_features: usize) -> Self {
168        Self::with_bias(in_features, out_features, true)
169    }
170
171    /// Creates a new TernaryLinear layer with optional bias.
172    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
173        let weight_data = kaiming_uniform(out_features, in_features);
174        let shadow_weight = Parameter::named("shadow_weight", weight_data, true);
175
176        let bias_param = if bias {
177            let bias_data = zeros(&[out_features]);
178            Some(Parameter::named("bias", bias_data, true))
179        } else {
180            None
181        };
182
183        Self {
184            shadow_weight,
185            bias: bias_param,
186            packed_weights: None,
187            in_features,
188            out_features,
189            inference_mode: false,
190        }
191    }
192
193    /// Returns the input feature dimension.
194    pub fn in_features(&self) -> usize {
195        self.in_features
196    }
197
198    /// Returns the output feature dimension.
199    pub fn out_features(&self) -> usize {
200        self.out_features
201    }
202
203    /// Quantize shadow weights to ternary using absmean quantization.
204    ///
205    /// w_ternary = sign(w) * round(|w| / mean(|w|))
206    ///
207    /// Returns (ternary values as i8, scale factor).
208    pub fn quantize_weights(&self) -> (Vec<i8>, f32) {
209        let w = self.shadow_weight.data();
210        let w_vec = w.to_vec();
211        let n = w_vec.len();
212
213        // Compute absmean scale
214        let abs_mean: f32 = w_vec.iter().map(|v| v.abs()).sum::<f32>() / n as f32;
215        let scale = abs_mean.max(1e-8); // Avoid division by zero
216
217        // Quantize: sign(w) * round(|w| / scale)
218        let ternary: Vec<i8> = w_vec
219            .iter()
220            .map(|&w| {
221                let normalized = (w.abs() / scale).round().min(1.0);
222                let sign = if w > 0.0 {
223                    1i8
224                } else if w < 0.0 {
225                    -1i8
226                } else {
227                    0i8
228                };
229                sign * (normalized as i8)
230            })
231            .collect();
232
233        (ternary, scale)
234    }
235
236    /// Pre-quantize weights for inference (pack to 2-bit representation).
237    pub fn quantize_for_inference(&mut self) {
238        let (ternary, scale) = self.quantize_weights();
239        self.packed_weights = Some(PackedTernaryWeights::pack(&ternary, scale));
240        self.inference_mode = true;
241    }
242
243    /// Switch back to training mode (use shadow weights).
244    pub fn use_shadow_weights(&mut self) {
245        self.inference_mode = false;
246    }
247
248    /// Get weight sparsity (fraction of zeros in ternary representation).
249    pub fn weight_sparsity(&self) -> f32 {
250        let (ternary, _) = self.quantize_weights();
251        let zeros = ternary.iter().filter(|&&v| v == 0).count();
252        zeros as f32 / ternary.len() as f32
253    }
254
255    /// Get compression ratio vs fp32.
256    pub fn compression_ratio(&self) -> f32 {
257        let fp32_bytes = self.in_features * self.out_features * 4;
258        let ternary_bytes = (self.in_features * self.out_features).div_ceil(4) + 4; // +4 for scale
259        fp32_bytes as f32 / ternary_bytes as f32
260    }
261
262    /// Get the packed weight storage if quantized.
263    pub fn packed_weights(&self) -> Option<&PackedTernaryWeights> {
264        self.packed_weights.as_ref()
265    }
266
267    /// Perform ternary matmul: y = scale * (sum_positive - sum_negative).
268    ///
269    /// For each output element, we sum input values where the ternary weight is +1,
270    /// subtract input values where the ternary weight is -1, and multiply by scale.
271    /// This is pure addition/subtraction — no floating-point multiply for the matmul itself.
272    fn ternary_matmul(
273        input: &[f32],
274        ternary: &[i8],
275        scale: f32,
276        batch_size: usize,
277        in_features: usize,
278        out_features: usize,
279    ) -> Vec<f32> {
280        let mut output = vec![0.0f32; batch_size * out_features];
281
282        for b in 0..batch_size {
283            let x_off = b * in_features;
284            let y_off = b * out_features;
285
286            for o in 0..out_features {
287                let w_off = o * in_features;
288                let mut sum_pos = 0.0f32;
289                let mut sum_neg = 0.0f32;
290
291                for j in 0..in_features {
292                    let w = ternary[w_off + j];
293                    let x = input[x_off + j];
294                    if w == 1 {
295                        sum_pos += x;
296                    } else if w == -1 {
297                        sum_neg += x;
298                    }
299                    // w == 0: skip (zero contribution)
300                }
301
302                output[y_off + o] = scale * (sum_pos - sum_neg);
303            }
304        }
305
306        output
307    }
308
309    /// Forward pass during training: quantize-on-the-fly with STE backward.
310    fn forward_training(&self, input: &Variable) -> Variable {
311        let input_data = input.data();
312        let input_shape = input_data.shape();
313        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
314        let total_batch: usize = batch_dims.iter().product();
315
316        // Quantize shadow weights to ternary
317        let (ternary, scale) = self.quantize_weights();
318
319        // Flatten input to 2D
320        let input_vec = input_data.to_vec();
321
322        // Ternary matmul
323        let output_vec = Self::ternary_matmul(
324            &input_vec,
325            &ternary,
326            scale,
327            total_batch,
328            self.in_features,
329            self.out_features,
330        );
331
332        // Build output tensor
333        let mut out_shape = batch_dims.clone();
334        out_shape.push(self.out_features);
335        let output_tensor =
336            Tensor::from_vec(output_vec, &out_shape).expect("tensor creation failed");
337
338        // Add bias
339        let output_tensor = if let Some(ref bias) = self.bias {
340            let bias_vec = bias.data().to_vec();
341            let mut out = output_tensor.to_vec();
342            for b in 0..total_batch {
343                for o in 0..self.out_features {
344                    out[b * self.out_features + o] += bias_vec[o];
345                }
346            }
347            Tensor::from_vec(out, &out_shape).expect("tensor creation failed")
348        } else {
349            output_tensor
350        };
351
352        let requires_grad = input.requires_grad() && is_grad_enabled();
353        if requires_grad {
354            // STE backward: gradients pass through quantization as if it were identity.
355            // The gradient w.r.t. shadow_weight is computed as if the ternary quantization
356            // were not there (straight-through estimator).
357            let saved_input = input_data.clone();
358            let saved_ternary = ternary;
359            let saved_scale = scale;
360            let in_f = self.in_features;
361            let out_f = self.out_features;
362            let shadow_grad_fn = self.shadow_weight.variable().grad_fn().cloned();
363            let bias_grad_fn = self
364                .bias
365                .as_ref()
366                .and_then(|b| b.variable().grad_fn().cloned());
367
368            let mut next_fns = vec![input.grad_fn().cloned(), shadow_grad_fn];
369            if bias_grad_fn.is_some() {
370                next_fns.push(bias_grad_fn);
371            }
372
373            let grad_fn = GradFn::new(TernaryLinearBackward {
374                next_fns,
375                saved_input,
376                saved_ternary,
377                saved_scale,
378                in_features: in_f,
379                out_features: out_f,
380                has_bias: self.bias.is_some(),
381                total_batch,
382            });
383            Variable::from_operation(output_tensor, grad_fn, true)
384        } else {
385            Variable::new(output_tensor, false)
386        }
387    }
388
389    /// Forward pass during inference: use pre-quantized packed weights.
390    fn forward_inference(&self, input: &Variable) -> Variable {
391        let packed = self
392            .packed_weights
393            .as_ref()
394            .expect("Must call quantize_for_inference() before inference forward");
395
396        let input_data = input.data();
397        let input_shape = input_data.shape();
398        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
399        let total_batch: usize = batch_dims.iter().product();
400
401        // Unpack ternary weights
402        let ternary = packed.unpack();
403        let scale = packed.scale();
404
405        let input_vec = input_data.to_vec();
406        let output_vec = Self::ternary_matmul(
407            &input_vec,
408            &ternary,
409            scale,
410            total_batch,
411            self.in_features,
412            self.out_features,
413        );
414
415        let mut out_shape = batch_dims;
416        out_shape.push(self.out_features);
417        let mut output_tensor =
418            Tensor::from_vec(output_vec, &out_shape).expect("tensor creation failed");
419
420        // Add bias
421        if let Some(ref bias) = self.bias {
422            let bias_vec = bias.data().to_vec();
423            let mut out = output_tensor.to_vec();
424            for b in 0..total_batch {
425                for o in 0..self.out_features {
426                    out[b * self.out_features + o] += bias_vec[o];
427                }
428            }
429            output_tensor = Tensor::from_vec(out, &out_shape).expect("tensor creation failed");
430        }
431
432        Variable::new(output_tensor, false)
433    }
434}
435
436impl Module for TernaryLinear {
437    fn forward(&self, input: &Variable) -> Variable {
438        if self.inference_mode {
439            self.forward_inference(input)
440        } else {
441            self.forward_training(input)
442        }
443    }
444
445    fn parameters(&self) -> Vec<Parameter> {
446        let mut params = vec![self.shadow_weight.clone()];
447        if let Some(ref bias) = self.bias {
448            params.push(bias.clone());
449        }
450        params
451    }
452
453    fn named_parameters(&self) -> HashMap<String, Parameter> {
454        let mut params = HashMap::new();
455        params.insert("shadow_weight".to_string(), self.shadow_weight.clone());
456        if let Some(ref bias) = self.bias {
457            params.insert("bias".to_string(), bias.clone());
458        }
459        params
460    }
461
462    fn name(&self) -> &'static str {
463        "TernaryLinear"
464    }
465}
466
467impl std::fmt::Debug for TernaryLinear {
468    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469        f.debug_struct("TernaryLinear")
470            .field("in_features", &self.in_features)
471            .field("out_features", &self.out_features)
472            .field("bias", &self.bias.is_some())
473            .field("inference_mode", &self.inference_mode)
474            .finish()
475    }
476}
477
478// =============================================================================
479// TernaryLinearBackward (Straight-Through Estimator)
480// =============================================================================
481
482/// Gradient function for TernaryLinear using the Straight-Through Estimator.
483///
484/// The STE passes gradients through the ternary quantization as if it were
485/// an identity function. This allows training the shadow weights with standard
486/// gradient-based optimizers.
487///
488/// For the forward y = scale * T(W) @ x where T is ternary quantization:
489/// - grad_input = scale * T(W)^T @ grad_output  (ternary transpose matmul)
490/// - grad_weight = grad_output^T @ x             (STE: treat T as identity)
491/// - grad_bias = sum(grad_output, dim=0)
492#[derive(Debug)]
493struct TernaryLinearBackward {
494    next_fns: Vec<Option<GradFn>>,
495    saved_input: Tensor<f32>,
496    saved_ternary: Vec<i8>,
497    saved_scale: f32,
498    in_features: usize,
499    out_features: usize,
500    has_bias: bool,
501    total_batch: usize,
502}
503
504impl GradientFunction for TernaryLinearBackward {
505    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
506        let g_vec = grad_output.to_vec();
507        let x_vec = self.saved_input.to_vec();
508
509        // 1. grad_input = scale * ternary_W^T @ grad_output
510        //    For each batch element and input dimension:
511        //    grad_input[b,j] = scale * sum_o(ternary[o,j] * grad_output[b,o])
512        let mut grad_input = vec![0.0f32; self.total_batch * self.in_features];
513        for b in 0..self.total_batch {
514            let g_off = b * self.out_features;
515            let gi_off = b * self.in_features;
516
517            for j in 0..self.in_features {
518                let mut sum = 0.0f32;
519                for o in 0..self.out_features {
520                    let w = self.saved_ternary[o * self.in_features + j];
521                    if w == 1 {
522                        sum += g_vec[g_off + o];
523                    } else if w == -1 {
524                        sum -= g_vec[g_off + o];
525                    }
526                }
527                grad_input[gi_off + j] = self.saved_scale * sum;
528            }
529        }
530
531        let gi_tensor = Tensor::from_vec(grad_input, self.saved_input.shape()).unwrap();
532
533        // 2. grad_weight (STE): grad_output^T @ input
534        //    grad_weight[o,j] = sum_b(grad_output[b,o] * input[b,j])
535        let mut grad_weight = vec![0.0f32; self.out_features * self.in_features];
536        for b in 0..self.total_batch {
537            let g_off = b * self.out_features;
538            let x_off = b * self.in_features;
539
540            for o in 0..self.out_features {
541                let go = g_vec[g_off + o];
542                let w_off = o * self.in_features;
543                for j in 0..self.in_features {
544                    grad_weight[w_off + j] += go * x_vec[x_off + j];
545                }
546            }
547        }
548        let gw_tensor = Tensor::from_vec(grad_weight, &[self.out_features, self.in_features])
549            .expect("tensor creation failed");
550
551        let mut results: Vec<Option<Tensor<f32>>> = vec![Some(gi_tensor), Some(gw_tensor)];
552
553        // 3. grad_bias = sum(grad_output, dim=0)
554        if self.has_bias {
555            let mut grad_bias = vec![0.0f32; self.out_features];
556            for b in 0..self.total_batch {
557                for o in 0..self.out_features {
558                    grad_bias[o] += g_vec[b * self.out_features + o];
559                }
560            }
561            let gb_tensor =
562                Tensor::from_vec(grad_bias, &[self.out_features]).expect("tensor creation failed");
563            results.push(Some(gb_tensor));
564        }
565
566        results
567    }
568
569    fn name(&self) -> &'static str {
570        "TernaryLinearBackward"
571    }
572
573    fn next_functions(&self) -> &[Option<GradFn>] {
574        &self.next_fns
575    }
576
577    fn as_any(&self) -> &dyn Any {
578        self
579    }
580}
581
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ternary_linear_creation() {
        let layer = TernaryLinear::new(64, 32);
        assert_eq!(layer.in_features(), 64);
        assert_eq!(layer.out_features(), 32);
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_ternary_linear_no_bias() {
        let layer = TernaryLinear::with_bias(64, 32, false);
        assert!(layer.bias.is_none());
    }

    #[test]
    fn test_ternary_linear_forward() {
        let layer = TernaryLinear::new(8, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[2, 8]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }

    #[test]
    fn test_ternary_quantization() {
        let layer = TernaryLinear::new(16, 8);
        let (ternary, scale) = layer.quantize_weights();

        // All values should be in {-1, 0, +1}
        for &v in &ternary {
            assert!(v == -1 || v == 0 || v == 1, "got {}", v);
        }

        // Scale should be positive
        assert!(scale > 0.0);

        // Should have the right number of values
        assert_eq!(ternary.len(), 16 * 8);
    }

    #[test]
    fn test_packed_ternary_roundtrip() {
        let values: Vec<i8> = vec![1, 0, -1, 1, 0, 0, -1, -1, 1, 0];
        let packed = PackedTernaryWeights::pack(&values, 0.5);
        let unpacked = packed.unpack();
        assert_eq!(values, unpacked);
        assert_eq!(packed.scale(), 0.5);
    }

    #[test]
    fn test_packed_storage_compression() {
        let n = 1024;
        let values: Vec<i8> = (0..n).map(|i| ((i % 3) as i8) - 1).collect();
        let packed = PackedTernaryWeights::pack(&values, 1.0);
        // 1024 weights / 4 per byte = 256 bytes (vs 1024 * 4 = 4096 bytes fp32)
        assert_eq!(packed.storage_bytes(), 256);
    }

    #[test]
    fn test_ternary_matmul_simple() {
        // 2x3 ternary weight: [[1, -1, 0], [0, 1, 1]]
        let ternary = vec![1i8, -1, 0, 0, 1, 1];
        let scale = 1.0;
        let input = vec![2.0f32, 3.0, 5.0]; // 1x3

        let output = TernaryLinear::ternary_matmul(&input, &ternary, scale, 1, 3, 2);

        // y[0] = 1.0 * (2.0 - 3.0) = -1.0
        // y[1] = 1.0 * (3.0 + 5.0) = 8.0
        assert!((output[0] - (-1.0)).abs() < 1e-6);
        assert!((output[1] - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_ternary_linear_inference_mode() {
        let mut layer = TernaryLinear::new(8, 4);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 8], &[1, 8]).expect("tensor creation failed"),
            false,
        );

        // Training forward
        let train_out = layer.forward(&input);

        // Quantize and switch to inference
        layer.quantize_for_inference();
        let infer_out = layer.forward(&input);

        // Should produce the same result
        let train_vec = train_out.data().to_vec();
        let infer_vec = infer_out.data().to_vec();
        for (a, b) in train_vec.iter().zip(infer_vec.iter()) {
            assert!((a - b).abs() < 1e-5, "Training {} vs inference {}", a, b);
        }
    }

    #[test]
    fn test_ternary_linear_sparsity() {
        let layer = TernaryLinear::new(64, 32);
        let sparsity = layer.weight_sparsity();
        // Sparsity should be between 0 and 1
        assert!((0.0..=1.0).contains(&sparsity));
    }

    #[test]
    fn test_ternary_linear_compression_ratio() {
        let layer = TernaryLinear::new(512, 512);
        let ratio = layer.compression_ratio();
        // Should be close to 16x (32 bits / 2 bits)
        assert!(ratio > 14.0 && ratio < 17.0, "ratio = {}", ratio);
    }

    #[test]
    fn test_ternary_linear_parameters() {
        let layer = TernaryLinear::new(16, 8);
        let params = layer.parameters();
        assert_eq!(params.len(), 2); // shadow_weight + bias

        let layer_no_bias = TernaryLinear::with_bias(16, 8, false);
        assert_eq!(layer_no_bias.parameters().len(), 1);
    }

    #[test]
    fn test_ternary_linear_backward() {
        let layer = TernaryLinear::new(4, 2);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            true,
        );
        let output = layer.forward(&input);
        let loss = output.sum();
        loss.backward();

        // Gradients should exist
        assert!(input.grad().is_some());
    }
}