// Scraped-page residue removed; original file path retained for reference:
// axonml_nn/layers/ternary.rs

1//! Ternary Linear Layer - 1.58-bit Weight Quantization (BitNet b1.58)
2//!
3//! Implements TernaryLinear: a linear layer with ternary weights {-1, 0, +1}.
4//! Weights are stored as packed 2-bit integers (4 weights per byte) for inference,
5//! with full-precision shadow weights maintained during training.
6//!
7//! Forward pass uses absmean quantization:
8//!   w_ternary = sign(w) * round(|w| / mean(|w|))
9//!
10//! The ternary matmul reduces to addition/subtraction — no multiply needed:
11//!   y[i] = scale * (sum_{w=+1} x[j] - sum_{w=-1} x[j])
12//!
13//! # File
14//! `crates/axonml-nn/src/layers/ternary.rs`
15//!
16//! # Author
17//! Andrew Jewell Sr. — AutomataNexus LLC
18//! ORCID: 0009-0005-2158-7060
19//!
20//! # Updated
21//! April 14, 2026 11:15 PM EST
22//!
23//! # Disclaimer
24//! Use at own risk. This software is provided "as is", without warranty of any
25//! kind, express or implied. The author and AutomataNexus shall not be held
26//! liable for any damages arising from the use of this software.
27
28use std::any::Any;
29use std::collections::HashMap;
30
31use axonml_autograd::no_grad::is_grad_enabled;
32use axonml_autograd::{GradFn, GradientFunction, Variable};
33use axonml_tensor::Tensor;
34
35use crate::init::{kaiming_uniform, zeros};
36use crate::module::Module;
37use crate::parameter::Parameter;
38
39// =============================================================================
40// Packed Ternary Weights
41// =============================================================================
42
/// Packed ternary weight storage: 4 weights per byte using 2-bit encoding.
///
/// Encoding: 0b00 = 0, 0b01 = +1, 0b10 = -1 (0b11 unused)
#[derive(Debug, Clone)]
pub struct PackedTernaryWeights {
    /// Packed bytes (4 weights per byte)
    data: Vec<u8>,
    /// Number of actual weight values
    num_weights: usize,
    /// Absmean scale factor
    scale: f32,
}

impl PackedTernaryWeights {
    /// Pack ternary values {-1, 0, +1} into 2-bit representation.
    ///
    /// Values outside {-1, 0, +1} are clamped to zero. Trailing bits of the
    /// final byte (when the count is not a multiple of 4) stay 0b00 and are
    /// never read back, since every accessor is bounded by `num_weights`.
    pub fn pack(ternary_values: &[i8], scale: f32) -> Self {
        let num_weights = ternary_values.len();
        let mut data = vec![0u8; num_weights.div_ceil(4)];

        for (i, &val) in ternary_values.iter().enumerate() {
            let encoded = match val {
                1 => 0b01u8,
                -1 => 0b10u8,
                _ => 0b00u8, // 0 and invalid values both map to zero
            };
            data[i / 4] |= encoded << ((i % 4) * 2);
        }

        Self {
            data,
            num_weights,
            scale,
        }
    }

    /// Decode the 2-bit field at weight index `i` back to {-1, 0, +1}.
    ///
    /// Shared by `unpack` and `count_zeros` so the bit layout lives in one place.
    #[inline]
    fn decode(&self, i: usize) -> i8 {
        match (self.data[i / 4] >> ((i % 4) * 2)) & 0b11 {
            0b01 => 1,
            0b10 => -1,
            _ => 0, // 0b00 (zero) and the unused 0b11 pattern
        }
    }

    /// Unpack to dense ternary values {-1, 0, +1}.
    pub fn unpack(&self) -> Vec<i8> {
        (0..self.num_weights).map(|i| self.decode(i)).collect()
    }

    /// Returns the scale factor.
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Returns the packed storage size in bytes.
    pub fn storage_bytes(&self) -> usize {
        self.data.len()
    }

    /// Returns the number of weights.
    pub fn num_weights(&self) -> usize {
        self.num_weights
    }

    /// Count zeros (sparsity) without materializing the unpacked vector.
    pub fn count_zeros(&self) -> usize {
        (0..self.num_weights).filter(|&i| self.decode(i) == 0).count()
    }
}
121
122// =============================================================================
123// TernaryLinear
124// =============================================================================
125
126/// A linear layer with 1.58-bit ternary weights (BitNet b1.58).
127///
128/// During training, full-precision shadow weights are maintained and quantized
129/// to ternary {-1, 0, +1} on each forward pass using absmean quantization.
130/// Gradients flow through the quantization via the Straight-Through Estimator (STE).
131///
132/// During inference, pre-quantized packed weights are used for efficient
133/// addition/subtraction-only matmul.
134///
135/// # Architecture
136/// - Shadow weights: fp32 (out_features x in_features), used during training
137/// - Ternary weights: packed 2-bit (4 per byte), used during inference
138/// - Scale factor: mean(|w|), applied after ternary matmul
139/// - Bias: optional fp32
140///
141/// # Shape
142/// - Input: (*, in_features)
143/// - Output: (*, out_features)
144///
145/// # Example
146/// ```ignore
147/// let layer = TernaryLinear::new(512, 512);
148/// let input = Variable::new(Tensor::randn(&[2, 512]), true);
149/// let output = layer.forward(&input);  // Shape: [2, 512]
150/// ```
pub struct TernaryLinear {
    /// Shadow weight (fp32) for training — holds the latent continuous weights
    /// (out_features x in_features, per `kaiming_uniform` in `with_bias`),
    /// re-quantized to ternary on every training forward.
    pub shadow_weight: Parameter,
    /// Optional bias (fp32), length out_features.
    pub bias: Option<Parameter>,
    /// Pre-quantized packed weights for inference; `None` until
    /// `quantize_for_inference()` is called.
    packed_weights: Option<PackedTernaryWeights>,
    /// Input features.
    in_features: usize,
    /// Output features.
    out_features: usize,
    /// Whether `forward` dispatches to the packed inference path.
    inference_mode: bool,
}
165
166impl TernaryLinear {
    /// Creates a new TernaryLinear layer with bias.
    ///
    /// Convenience wrapper for `with_bias(in_features, out_features, true)`.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        Self::with_bias(in_features, out_features, true)
    }
171
172    /// Creates a new TernaryLinear layer with optional bias.
173    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
174        let weight_data = kaiming_uniform(out_features, in_features);
175        let shadow_weight = Parameter::named("shadow_weight", weight_data, true);
176
177        let bias_param = if bias {
178            let bias_data = zeros(&[out_features]);
179            Some(Parameter::named("bias", bias_data, true))
180        } else {
181            None
182        };
183
184        Self {
185            shadow_weight,
186            bias: bias_param,
187            packed_weights: None,
188            in_features,
189            out_features,
190            inference_mode: false,
191        }
192    }
193
    /// Returns the input feature dimension (last axis of the expected input).
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Returns the output feature dimension (last axis of the produced output).
    pub fn out_features(&self) -> usize {
        self.out_features
    }
203
204    /// Quantize shadow weights to ternary using absmean quantization.
205    ///
206    /// w_ternary = sign(w) * round(|w| / mean(|w|))
207    ///
208    /// Returns (ternary values as i8, scale factor).
209    pub fn quantize_weights(&self) -> (Vec<i8>, f32) {
210        let w = self.shadow_weight.data();
211        let w_vec = w.to_vec();
212        let n = w_vec.len();
213
214        // Compute absmean scale
215        let abs_mean: f32 = w_vec.iter().map(|v| v.abs()).sum::<f32>() / n as f32;
216        let scale = abs_mean.max(1e-8); // Avoid division by zero
217
218        // Quantize: sign(w) * round(|w| / scale)
219        let ternary: Vec<i8> = w_vec
220            .iter()
221            .map(|&w| {
222                let normalized = (w.abs() / scale).round().min(1.0);
223                let sign = if w > 0.0 {
224                    1i8
225                } else if w < 0.0 {
226                    -1i8
227                } else {
228                    0i8
229                };
230                sign * (normalized as i8)
231            })
232            .collect();
233
234        (ternary, scale)
235    }
236
    /// Pre-quantize weights for inference (pack to 2-bit representation).
    ///
    /// After this call `forward` uses the packed weights until
    /// `use_shadow_weights()` switches the mode back.
    pub fn quantize_for_inference(&mut self) {
        let (ternary, scale) = self.quantize_weights();
        self.packed_weights = Some(PackedTernaryWeights::pack(&ternary, scale));
        self.inference_mode = true;
    }

    /// Switch back to training mode (use shadow weights).
    ///
    /// Only flips the mode flag; any previously packed weights are kept but
    /// not consulted while training.
    pub fn use_shadow_weights(&mut self) {
        self.inference_mode = false;
    }
248
249    /// Get weight sparsity (fraction of zeros in ternary representation).
250    pub fn weight_sparsity(&self) -> f32 {
251        let (ternary, _) = self.quantize_weights();
252        let zeros = ternary.iter().filter(|&&v| v == 0).count();
253        zeros as f32 / ternary.len() as f32
254    }
255
256    /// Get compression ratio vs fp32.
257    pub fn compression_ratio(&self) -> f32 {
258        let fp32_bytes = self.in_features * self.out_features * 4;
259        let ternary_bytes = (self.in_features * self.out_features).div_ceil(4) + 4; // +4 for scale
260        fp32_bytes as f32 / ternary_bytes as f32
261    }
262
    /// Get the packed weight storage if quantized.
    ///
    /// Returns `None` until `quantize_for_inference()` has been called.
    pub fn packed_weights(&self) -> Option<&PackedTernaryWeights> {
        self.packed_weights.as_ref()
    }
267
268    /// Perform ternary matmul: y = scale * (sum_positive - sum_negative).
269    ///
270    /// For each output element, we sum input values where the ternary weight is +1,
271    /// subtract input values where the ternary weight is -1, and multiply by scale.
272    /// This is pure addition/subtraction — no floating-point multiply for the matmul itself.
273    fn ternary_matmul(
274        input: &[f32],
275        ternary: &[i8],
276        scale: f32,
277        batch_size: usize,
278        in_features: usize,
279        out_features: usize,
280    ) -> Vec<f32> {
281        let mut output = vec![0.0f32; batch_size * out_features];
282
283        for b in 0..batch_size {
284            let x_off = b * in_features;
285            let y_off = b * out_features;
286
287            for o in 0..out_features {
288                let w_off = o * in_features;
289                let mut sum_pos = 0.0f32;
290                let mut sum_neg = 0.0f32;
291
292                for j in 0..in_features {
293                    let w = ternary[w_off + j];
294                    let x = input[x_off + j];
295                    if w == 1 {
296                        sum_pos += x;
297                    } else if w == -1 {
298                        sum_neg += x;
299                    }
300                    // w == 0: skip (zero contribution)
301                }
302
303                output[y_off + o] = scale * (sum_pos - sum_neg);
304            }
305        }
306
307        output
308    }
309
    /// Forward pass during training: quantize-on-the-fly with STE backward.
    ///
    /// Quantizes the shadow weights to ternary, runs the add/sub-only matmul,
    /// adds the bias, and (when grad is enabled) attaches a
    /// `TernaryLinearBackward` node whose `next_fns` order
    /// [input, shadow_weight, bias?] matches the order of gradients returned
    /// by that node's `apply`.
    fn forward_training(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let input_shape = input_data.shape();
        // All leading dims are treated as batch; only the last dim is features.
        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
        let total_batch: usize = batch_dims.iter().product();

        // Quantize shadow weights to ternary
        let (ternary, scale) = self.quantize_weights();

        // Flatten input to 2D
        let input_vec = input_data.to_vec();

        // Ternary matmul
        let output_vec = Self::ternary_matmul(
            &input_vec,
            &ternary,
            scale,
            total_batch,
            self.in_features,
            self.out_features,
        );

        // Build output tensor
        let mut out_shape = batch_dims.clone();
        out_shape.push(self.out_features);
        let output_tensor =
            Tensor::from_vec(output_vec, &out_shape).expect("tensor creation failed");

        // Add bias (broadcast over every batch row)
        let output_tensor = if let Some(ref bias) = self.bias {
            let bias_vec = bias.data().to_vec();
            let mut out = output_tensor.to_vec();
            for b in 0..total_batch {
                for o in 0..self.out_features {
                    out[b * self.out_features + o] += bias_vec[o];
                }
            }
            Tensor::from_vec(out, &out_shape).expect("tensor creation failed")
        } else {
            output_tensor
        };

        // NOTE(review): grad is gated only on input.requires_grad(); if the
        // input does not require grad but the shadow weights do, no backward
        // node is created and the weights receive no gradient — confirm this
        // is intended.
        let requires_grad = input.requires_grad() && is_grad_enabled();
        if requires_grad {
            // STE backward: gradients pass through quantization as if it were identity.
            // The gradient w.r.t. shadow_weight is computed as if the ternary quantization
            // were not there (straight-through estimator).
            let saved_input = input_data.clone();
            let saved_ternary = ternary;
            let saved_scale = scale;
            let in_f = self.in_features;
            let out_f = self.out_features;
            let shadow_grad_fn = self.shadow_weight.variable().grad_fn().cloned();
            let bias_grad_fn = self
                .bias
                .as_ref()
                .and_then(|b| b.variable().grad_fn().cloned());

            // Order matters: must line up with TernaryLinearBackward::apply's
            // result order [grad_input, grad_weight, grad_bias?].
            let mut next_fns = vec![input.grad_fn().cloned(), shadow_grad_fn];
            if bias_grad_fn.is_some() {
                next_fns.push(bias_grad_fn);
            }

            let grad_fn = GradFn::new(TernaryLinearBackward {
                next_fns,
                saved_input,
                saved_ternary,
                saved_scale,
                in_features: in_f,
                out_features: out_f,
                has_bias: self.bias.is_some(),
                total_batch,
            });
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }
389
390    /// Forward pass during inference: use pre-quantized packed weights.
391    fn forward_inference(&self, input: &Variable) -> Variable {
392        let packed = self
393            .packed_weights
394            .as_ref()
395            .expect("Must call quantize_for_inference() before inference forward");
396
397        let input_data = input.data();
398        let input_shape = input_data.shape();
399        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
400        let total_batch: usize = batch_dims.iter().product();
401
402        // Unpack ternary weights
403        let ternary = packed.unpack();
404        let scale = packed.scale();
405
406        let input_vec = input_data.to_vec();
407        let output_vec = Self::ternary_matmul(
408            &input_vec,
409            &ternary,
410            scale,
411            total_batch,
412            self.in_features,
413            self.out_features,
414        );
415
416        let mut out_shape = batch_dims;
417        out_shape.push(self.out_features);
418        let mut output_tensor =
419            Tensor::from_vec(output_vec, &out_shape).expect("tensor creation failed");
420
421        // Add bias
422        if let Some(ref bias) = self.bias {
423            let bias_vec = bias.data().to_vec();
424            let mut out = output_tensor.to_vec();
425            for b in 0..total_batch {
426                for o in 0..self.out_features {
427                    out[b * self.out_features + o] += bias_vec[o];
428                }
429            }
430            output_tensor = Tensor::from_vec(out, &out_shape).expect("tensor creation failed");
431        }
432
433        Variable::new(output_tensor, false)
434    }
435}
436
437impl Module for TernaryLinear {
438    fn forward(&self, input: &Variable) -> Variable {
439        if self.inference_mode {
440            self.forward_inference(input)
441        } else {
442            self.forward_training(input)
443        }
444    }
445
446    fn parameters(&self) -> Vec<Parameter> {
447        let mut params = vec![self.shadow_weight.clone()];
448        if let Some(ref bias) = self.bias {
449            params.push(bias.clone());
450        }
451        params
452    }
453
454    fn named_parameters(&self) -> HashMap<String, Parameter> {
455        let mut params = HashMap::new();
456        params.insert("shadow_weight".to_string(), self.shadow_weight.clone());
457        if let Some(ref bias) = self.bias {
458            params.insert("bias".to_string(), bias.clone());
459        }
460        params
461    }
462
463    fn name(&self) -> &'static str {
464        "TernaryLinear"
465    }
466}
467
// Manual Debug: reports only the layer configuration (not the tensor-valued
// fields), keeping debug output compact.
impl std::fmt::Debug for TernaryLinear {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TernaryLinear")
            .field("in_features", &self.in_features)
            .field("out_features", &self.out_features)
            .field("bias", &self.bias.is_some())
            .field("inference_mode", &self.inference_mode)
            .finish()
    }
}
478
479// =============================================================================
480// TernaryLinearBackward (Straight-Through Estimator)
481// =============================================================================
482
483/// Gradient function for TernaryLinear using the Straight-Through Estimator.
484///
485/// The STE passes gradients through the ternary quantization as if it were
486/// an identity function. This allows training the shadow weights with standard
487/// gradient-based optimizers.
488///
489/// For the forward y = scale * T(W) @ x where T is ternary quantization:
490/// - grad_input = scale * T(W)^T @ grad_output  (ternary transpose matmul)
491/// - grad_weight = grad_output^T @ x             (STE: treat T as identity)
492/// - grad_bias = sum(grad_output, dim=0)
#[derive(Debug)]
struct TernaryLinearBackward {
    // Upstream grad functions in the order [input, shadow_weight, bias?] —
    // must match the order of tensors returned by `apply`.
    next_fns: Vec<Option<GradFn>>,
    // Input saved from the forward pass; needed for grad_weight.
    saved_input: Tensor<f32>,
    // Ternary weights from the forward pass; needed for grad_input.
    saved_ternary: Vec<i8>,
    // Absmean scale from the forward pass.
    saved_scale: f32,
    // Layer input width.
    in_features: usize,
    // Layer output width.
    out_features: usize,
    // Whether a bias gradient must be appended to the results.
    has_bias: bool,
    // Product of all leading (batch) dimensions of the saved input.
    total_batch: usize,
}
504
impl GradientFunction for TernaryLinearBackward {
    /// Computes [grad_input, grad_weight, grad_bias?] for the ternary linear op.
    ///
    /// grad_weight deliberately ignores the quantization (straight-through
    /// estimator): it is the ordinary dense-linear gradient w.r.t. the shadow
    /// weights. grad_input, by contrast, uses the actual ternary weights and
    /// scale from the forward pass.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let g_vec = grad_output.to_vec();
        let x_vec = self.saved_input.to_vec();

        // 1. grad_input = scale * ternary_W^T @ grad_output
        //    For each batch element and input dimension:
        //    grad_input[b,j] = scale * sum_o(ternary[o,j] * grad_output[b,o])
        //    Like the forward, the ternary weights reduce the inner product to
        //    additions/subtractions.
        let mut grad_input = vec![0.0f32; self.total_batch * self.in_features];
        for b in 0..self.total_batch {
            let g_off = b * self.out_features;
            let gi_off = b * self.in_features;

            for j in 0..self.in_features {
                let mut sum = 0.0f32;
                for o in 0..self.out_features {
                    let w = self.saved_ternary[o * self.in_features + j];
                    if w == 1 {
                        sum += g_vec[g_off + o];
                    } else if w == -1 {
                        sum -= g_vec[g_off + o];
                    }
                }
                grad_input[gi_off + j] = self.saved_scale * sum;
            }
        }

        // grad_input keeps the original (possibly >2-D) input shape.
        let gi_tensor = Tensor::from_vec(grad_input, self.saved_input.shape()).unwrap();

        // 2. grad_weight (STE): grad_output^T @ input
        //    grad_weight[o,j] = sum_b(grad_output[b,o] * input[b,j])
        //    Note: no scale factor here — STE treats quantization as identity.
        let mut grad_weight = vec![0.0f32; self.out_features * self.in_features];
        for b in 0..self.total_batch {
            let g_off = b * self.out_features;
            let x_off = b * self.in_features;

            for o in 0..self.out_features {
                let go = g_vec[g_off + o];
                let w_off = o * self.in_features;
                for j in 0..self.in_features {
                    grad_weight[w_off + j] += go * x_vec[x_off + j];
                }
            }
        }
        let gw_tensor = Tensor::from_vec(grad_weight, &[self.out_features, self.in_features])
            .expect("tensor creation failed");

        // Result order must match next_fns: [input, shadow_weight, bias?].
        let mut results: Vec<Option<Tensor<f32>>> = vec![Some(gi_tensor), Some(gw_tensor)];

        // 3. grad_bias = sum(grad_output, dim=0) — bias was broadcast over batch.
        if self.has_bias {
            let mut grad_bias = vec![0.0f32; self.out_features];
            for b in 0..self.total_batch {
                for o in 0..self.out_features {
                    grad_bias[o] += g_vec[b * self.out_features + o];
                }
            }
            let gb_tensor =
                Tensor::from_vec(grad_bias, &[self.out_features]).expect("tensor creation failed");
            results.push(Some(gb_tensor));
        }

        results
    }

    fn name(&self) -> &'static str {
        "TernaryLinearBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
582
583// =============================================================================
584// Tests
585// =============================================================================
586
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ternary_linear_creation() {
        let layer = TernaryLinear::new(64, 32);
        assert_eq!(layer.in_features(), 64);
        assert_eq!(layer.out_features(), 32);
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_ternary_linear_no_bias() {
        let layer = TernaryLinear::with_bias(64, 32, false);
        assert!(layer.bias.is_none());
    }

    #[test]
    fn test_ternary_linear_forward() {
        let layer = TernaryLinear::new(8, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[2, 8]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        // Batch dims preserved; last dim becomes out_features.
        assert_eq!(output.shape(), vec![2, 4]);
    }

    #[test]
    fn test_ternary_quantization() {
        let layer = TernaryLinear::new(16, 8);
        let (ternary, scale) = layer.quantize_weights();

        // All values should be in {-1, 0, +1}
        for &v in &ternary {
            assert!(v == -1 || v == 0 || v == 1, "got {}", v);
        }

        // Scale should be positive
        assert!(scale > 0.0);

        // Should have the right number of values
        assert_eq!(ternary.len(), 16 * 8);
    }

    #[test]
    fn test_packed_ternary_roundtrip() {
        // 10 values exercises the partial final byte (10 % 4 != 0).
        let values: Vec<i8> = vec![1, 0, -1, 1, 0, 0, -1, -1, 1, 0];
        let packed = PackedTernaryWeights::pack(&values, 0.5);
        let unpacked = packed.unpack();
        assert_eq!(values, unpacked);
        assert_eq!(packed.scale(), 0.5);
    }

    #[test]
    fn test_packed_storage_compression() {
        let n = 1024;
        let values: Vec<i8> = (0..n).map(|i| ((i % 3) as i8) - 1).collect();
        let packed = PackedTernaryWeights::pack(&values, 1.0);
        // 1024 weights / 4 per byte = 256 bytes (vs 1024 * 4 = 4096 bytes fp32)
        assert_eq!(packed.storage_bytes(), 256);
    }

    #[test]
    fn test_ternary_matmul_simple() {
        // 2x3 ternary weight: [[1, -1, 0], [0, 1, 1]]
        let ternary = vec![1i8, -1, 0, 0, 1, 1];
        let scale = 1.0;
        let input = vec![2.0f32, 3.0, 5.0]; // 1x3

        let output = TernaryLinear::ternary_matmul(&input, &ternary, scale, 1, 3, 2);

        // y[0] = 1.0 * (2.0 - 3.0) = -1.0
        // y[1] = 1.0 * (3.0 + 5.0) = 8.0
        assert!((output[0] - (-1.0)).abs() < 1e-6);
        assert!((output[1] - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_ternary_linear_inference_mode() {
        let mut layer = TernaryLinear::new(8, 4);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 8], &[1, 8]).expect("tensor creation failed"),
            false,
        );

        // Training forward
        let train_out = layer.forward(&input);

        // Quantize and switch to inference
        layer.quantize_for_inference();
        let infer_out = layer.forward(&input);

        // Should produce the same result — both paths use the same ternary
        // weights and matmul; only the storage format differs.
        let train_vec = train_out.data().to_vec();
        let infer_vec = infer_out.data().to_vec();
        for (a, b) in train_vec.iter().zip(infer_vec.iter()) {
            assert!((a - b).abs() < 1e-5, "Training {} vs inference {}", a, b);
        }
    }

    #[test]
    fn test_ternary_linear_sparsity() {
        let layer = TernaryLinear::new(64, 32);
        let sparsity = layer.weight_sparsity();
        // Sparsity should be between 0 and 1
        assert!((0.0..=1.0).contains(&sparsity));
    }

    #[test]
    fn test_ternary_linear_compression_ratio() {
        let layer = TernaryLinear::new(512, 512);
        let ratio = layer.compression_ratio();
        // Should be close to 16x (32 bits / 2 bits)
        assert!(ratio > 14.0 && ratio < 17.0, "ratio = {}", ratio);
    }

    #[test]
    fn test_ternary_linear_parameters() {
        let layer = TernaryLinear::new(16, 8);
        let params = layer.parameters();
        assert_eq!(params.len(), 2); // shadow_weight + bias

        let layer_no_bias = TernaryLinear::with_bias(16, 8, false);
        assert_eq!(layer_no_bias.parameters().len(), 1);
    }

    #[test]
    fn test_ternary_linear_backward() {
        let layer = TernaryLinear::new(4, 2);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            true,
        );
        let output = layer.forward(&input);
        let loss = output.sum();
        loss.backward();

        // Gradients should exist
        assert!(input.grad().is_some());
    }
}