// axonml_nn/layers/ternary.rs
//! Ternary Linear Layer - 1.58-bit Weight Quantization (BitNet b1.58)
//!
//! Implements TernaryLinear: a linear layer with ternary weights {-1, 0, +1}.
//! Weights are stored as packed 2-bit integers (4 weights per byte) for inference,
//! with full-precision shadow weights maintained during training.
//!
//! Forward pass uses absmean quantization:
//!   w_ternary = sign(w) * round(|w| / mean(|w|))
//!
//! The ternary matmul reduces to addition/subtraction — no multiply needed:
//!   y[i] = scale * (sum_{w=+1} x[j] - sum_{w=-1} x[j])
//!
//! # File
//! `crates/axonml-nn/src/layers/ternary.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 19, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::any::Any;
use std::collections::HashMap;

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;

use crate::init::{kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// Packed Ternary Weights
// =============================================================================

/// Packed ternary weight storage: 4 weights per byte using 2-bit encoding.
///
/// Encoding: 0b00 = 0, 0b01 = +1, 0b10 = -1 (0b11 unused)
#[derive(Debug, Clone)]
pub struct PackedTernaryWeights {
    /// Packed bytes (4 weights per byte)
    data: Vec<u8>,
    /// Number of actual weight values
    num_weights: usize,
    /// Absmean scale factor
    scale: f32,
}

impl PackedTernaryWeights {
    /// Pack ternary values {-1, 0, +1} into 2-bit representation.
    ///
    /// Any value outside {-1, 0, +1} is clamped to zero.
    pub fn pack(ternary_values: &[i8], scale: f32) -> Self {
        let num_weights = ternary_values.len();
        // 4 two-bit values per byte; round up for a trailing partial byte.
        let mut data = vec![0u8; num_weights.div_ceil(4)];

        for (i, &val) in ternary_values.iter().enumerate() {
            let encoded = match val {
                1 => 0b01u8,
                -1 => 0b10u8,
                _ => 0b00u8, // 0 and any invalid value encode as zero
            };
            data[i / 4] |= encoded << ((i % 4) * 2);
        }

        Self {
            data,
            num_weights,
            scale,
        }
    }

    /// Decode the 2-bit field at weight index `i` back to {-1, 0, +1}.
    ///
    /// The unused 0b11 pattern decodes to zero.
    fn decode(&self, i: usize) -> i8 {
        match (self.data[i / 4] >> ((i % 4) * 2)) & 0b11 {
            0b01 => 1i8,
            0b10 => -1i8,
            _ => 0i8,
        }
    }

    /// Unpack to dense ternary values {-1, 0, +1}.
    pub fn unpack(&self) -> Vec<i8> {
        (0..self.num_weights).map(|i| self.decode(i)).collect()
    }

    /// Returns the scale factor.
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Returns the packed storage size in bytes.
    pub fn storage_bytes(&self) -> usize {
        self.data.len()
    }

    /// Returns the number of weights.
    pub fn num_weights(&self) -> usize {
        self.num_weights
    }

    /// Count zeros (sparsity) directly from the packed bytes.
    ///
    /// Avoids materializing the full unpacked `Vec<i8>` that the previous
    /// implementation allocated just to count.
    pub fn count_zeros(&self) -> usize {
        (0..self.num_weights)
            .filter(|&i| self.decode(i) == 0)
            .count()
    }
}

// =============================================================================
// TernaryLinear
// =============================================================================

125/// A linear layer with 1.58-bit ternary weights (BitNet b1.58).
126///
127/// During training, full-precision shadow weights are maintained and quantized
128/// to ternary {-1, 0, +1} on each forward pass using absmean quantization.
129/// Gradients flow through the quantization via the Straight-Through Estimator (STE).
130///
131/// During inference, pre-quantized packed weights are used for efficient
132/// addition/subtraction-only matmul.
133///
134/// # Architecture
135/// - Shadow weights: fp32 (out_features x in_features), used during training
136/// - Ternary weights: packed 2-bit (4 per byte), used during inference
137/// - Scale factor: mean(|w|), applied after ternary matmul
138/// - Bias: optional fp32
139///
140/// # Shape
141/// - Input: (*, in_features)
142/// - Output: (*, out_features)
143///
144/// # Example
145/// ```ignore
146/// let layer = TernaryLinear::new(512, 512);
147/// let input = Variable::new(Tensor::randn(&[2, 512]), true);
148/// let output = layer.forward(&input);  // Shape: [2, 512]
149/// ```
150pub struct TernaryLinear {
151    /// Shadow weight (fp32) for training — holds the latent continuous weights.
152    pub shadow_weight: Parameter,
153    /// Optional bias (fp32).
154    pub bias: Option<Parameter>,
155    /// Pre-quantized packed weights for inference.
156    packed_weights: Option<PackedTernaryWeights>,
157    /// Input features.
158    in_features: usize,
159    /// Output features.
160    out_features: usize,
161    /// Whether to use packed inference mode.
162    inference_mode: bool,
163}
164
165impl TernaryLinear {
166    /// Creates a new TernaryLinear layer with bias.
167    pub fn new(in_features: usize, out_features: usize) -> Self {
168        Self::with_bias(in_features, out_features, true)
169    }
170
171    /// Creates a new TernaryLinear layer with optional bias.
172    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
173        let weight_data = kaiming_uniform(out_features, in_features);
174        let shadow_weight = Parameter::named("shadow_weight", weight_data, true);
175
176        let bias_param = if bias {
177            let bias_data = zeros(&[out_features]);
178            Some(Parameter::named("bias", bias_data, true))
179        } else {
180            None
181        };
182
183        Self {
184            shadow_weight,
185            bias: bias_param,
186            packed_weights: None,
187            in_features,
188            out_features,
189            inference_mode: false,
190        }
191    }
192
193    /// Returns the input feature dimension.
194    pub fn in_features(&self) -> usize {
195        self.in_features
196    }
197
198    /// Returns the output feature dimension.
199    pub fn out_features(&self) -> usize {
200        self.out_features
201    }
202
203    /// Quantize shadow weights to ternary using absmean quantization.
204    ///
205    /// w_ternary = sign(w) * round(|w| / mean(|w|))
206    ///
207    /// Returns (ternary values as i8, scale factor).
208    pub fn quantize_weights(&self) -> (Vec<i8>, f32) {
209        let w = self.shadow_weight.data();
210        let w_vec = w.to_vec();
211        let n = w_vec.len();
212
213        // Compute absmean scale
214        let abs_mean: f32 = w_vec.iter().map(|v| v.abs()).sum::<f32>() / n as f32;
215        let scale = abs_mean.max(1e-8); // Avoid division by zero
216
217        // Quantize: sign(w) * round(|w| / scale)
218        let ternary: Vec<i8> = w_vec
219            .iter()
220            .map(|&w| {
221                let normalized = (w.abs() / scale).round().min(1.0);
222                let sign = if w > 0.0 {
223                    1i8
224                } else if w < 0.0 {
225                    -1i8
226                } else {
227                    0i8
228                };
229                sign * (normalized as i8)
230            })
231            .collect();
232
233        (ternary, scale)
234    }
235
236    /// Pre-quantize weights for inference (pack to 2-bit representation).
237    pub fn quantize_for_inference(&mut self) {
238        let (ternary, scale) = self.quantize_weights();
239        self.packed_weights = Some(PackedTernaryWeights::pack(&ternary, scale));
240        self.inference_mode = true;
241    }
242
243    /// Switch back to training mode (use shadow weights).
244    pub fn use_shadow_weights(&mut self) {
245        self.inference_mode = false;
246    }
247
248    /// Get weight sparsity (fraction of zeros in ternary representation).
249    pub fn weight_sparsity(&self) -> f32 {
250        let (ternary, _) = self.quantize_weights();
251        let zeros = ternary.iter().filter(|&&v| v == 0).count();
252        zeros as f32 / ternary.len() as f32
253    }
254
255    /// Get compression ratio vs fp32.
256    pub fn compression_ratio(&self) -> f32 {
257        let fp32_bytes = self.in_features * self.out_features * 4;
258        let ternary_bytes = (self.in_features * self.out_features).div_ceil(4) + 4; // +4 for scale
259        fp32_bytes as f32 / ternary_bytes as f32
260    }
261
262    /// Get the packed weight storage if quantized.
263    pub fn packed_weights(&self) -> Option<&PackedTernaryWeights> {
264        self.packed_weights.as_ref()
265    }
266
267    /// Perform ternary matmul: y = scale * (sum_positive - sum_negative).
268    ///
269    /// For each output element, we sum input values where the ternary weight is +1,
270    /// subtract input values where the ternary weight is -1, and multiply by scale.
271    /// This is pure addition/subtraction — no floating-point multiply for the matmul itself.
272    fn ternary_matmul(
273        input: &[f32],
274        ternary: &[i8],
275        scale: f32,
276        batch_size: usize,
277        in_features: usize,
278        out_features: usize,
279    ) -> Vec<f32> {
280        let mut output = vec![0.0f32; batch_size * out_features];
281
282        for b in 0..batch_size {
283            let x_off = b * in_features;
284            let y_off = b * out_features;
285
286            for o in 0..out_features {
287                let w_off = o * in_features;
288                let mut sum_pos = 0.0f32;
289                let mut sum_neg = 0.0f32;
290
291                for j in 0..in_features {
292                    let w = ternary[w_off + j];
293                    let x = input[x_off + j];
294                    if w == 1 {
295                        sum_pos += x;
296                    } else if w == -1 {
297                        sum_neg += x;
298                    }
299                    // w == 0: skip (zero contribution)
300                }
301
302                output[y_off + o] = scale * (sum_pos - sum_neg);
303            }
304        }
305
306        output
307    }
308
309    /// Forward pass during training: quantize-on-the-fly with STE backward.
310    fn forward_training(&self, input: &Variable) -> Variable {
311        let input_data = input.data();
312        let input_shape = input_data.shape();
313        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
314        let total_batch: usize = batch_dims.iter().product();
315
316        // Quantize shadow weights to ternary
317        let (ternary, scale) = self.quantize_weights();
318
319        // Flatten input to 2D
320        let input_vec = input_data.to_vec();
321
322        // Ternary matmul
323        let output_vec = Self::ternary_matmul(
324            &input_vec,
325            &ternary,
326            scale,
327            total_batch,
328            self.in_features,
329            self.out_features,
330        );
331
332        // Build output tensor
333        let mut out_shape = batch_dims.clone();
334        out_shape.push(self.out_features);
335        let output_tensor = Tensor::from_vec(output_vec, &out_shape).unwrap();
336
337        // Add bias
338        let output_tensor = if let Some(ref bias) = self.bias {
339            let bias_vec = bias.data().to_vec();
340            let mut out = output_tensor.to_vec();
341            for b in 0..total_batch {
342                for o in 0..self.out_features {
343                    out[b * self.out_features + o] += bias_vec[o];
344                }
345            }
346            Tensor::from_vec(out, &out_shape).unwrap()
347        } else {
348            output_tensor
349        };
350
351        let requires_grad = input.requires_grad() && is_grad_enabled();
352        if requires_grad {
353            // STE backward: gradients pass through quantization as if it were identity.
354            // The gradient w.r.t. shadow_weight is computed as if the ternary quantization
355            // were not there (straight-through estimator).
356            let saved_input = input_data.clone();
357            let saved_ternary = ternary;
358            let saved_scale = scale;
359            let in_f = self.in_features;
360            let out_f = self.out_features;
361            let shadow_grad_fn = self.shadow_weight.variable().grad_fn().cloned();
362            let bias_grad_fn = self
363                .bias
364                .as_ref()
365                .and_then(|b| b.variable().grad_fn().cloned());
366
367            let mut next_fns = vec![input.grad_fn().cloned(), shadow_grad_fn];
368            if bias_grad_fn.is_some() {
369                next_fns.push(bias_grad_fn);
370            }
371
372            let grad_fn = GradFn::new(TernaryLinearBackward {
373                next_fns,
374                saved_input,
375                saved_ternary,
376                saved_scale,
377                in_features: in_f,
378                out_features: out_f,
379                has_bias: self.bias.is_some(),
380                total_batch,
381            });
382            Variable::from_operation(output_tensor, grad_fn, true)
383        } else {
384            Variable::new(output_tensor, false)
385        }
386    }
387
388    /// Forward pass during inference: use pre-quantized packed weights.
389    fn forward_inference(&self, input: &Variable) -> Variable {
390        let packed = self
391            .packed_weights
392            .as_ref()
393            .expect("Must call quantize_for_inference() before inference forward");
394
395        let input_data = input.data();
396        let input_shape = input_data.shape();
397        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
398        let total_batch: usize = batch_dims.iter().product();
399
400        // Unpack ternary weights
401        let ternary = packed.unpack();
402        let scale = packed.scale();
403
404        let input_vec = input_data.to_vec();
405        let output_vec = Self::ternary_matmul(
406            &input_vec,
407            &ternary,
408            scale,
409            total_batch,
410            self.in_features,
411            self.out_features,
412        );
413
414        let mut out_shape = batch_dims;
415        out_shape.push(self.out_features);
416        let mut output_tensor = Tensor::from_vec(output_vec, &out_shape).unwrap();
417
418        // Add bias
419        if let Some(ref bias) = self.bias {
420            let bias_vec = bias.data().to_vec();
421            let mut out = output_tensor.to_vec();
422            for b in 0..total_batch {
423                for o in 0..self.out_features {
424                    out[b * self.out_features + o] += bias_vec[o];
425                }
426            }
427            output_tensor = Tensor::from_vec(out, &out_shape).unwrap();
428        }
429
430        Variable::new(output_tensor, false)
431    }
432}
433
434impl Module for TernaryLinear {
435    fn forward(&self, input: &Variable) -> Variable {
436        if self.inference_mode {
437            self.forward_inference(input)
438        } else {
439            self.forward_training(input)
440        }
441    }
442
443    fn parameters(&self) -> Vec<Parameter> {
444        let mut params = vec![self.shadow_weight.clone()];
445        if let Some(ref bias) = self.bias {
446            params.push(bias.clone());
447        }
448        params
449    }
450
451    fn named_parameters(&self) -> HashMap<String, Parameter> {
452        let mut params = HashMap::new();
453        params.insert("shadow_weight".to_string(), self.shadow_weight.clone());
454        if let Some(ref bias) = self.bias {
455            params.insert("bias".to_string(), bias.clone());
456        }
457        params
458    }
459
460    fn name(&self) -> &'static str {
461        "TernaryLinear"
462    }
463}
464
465impl std::fmt::Debug for TernaryLinear {
466    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467        f.debug_struct("TernaryLinear")
468            .field("in_features", &self.in_features)
469            .field("out_features", &self.out_features)
470            .field("bias", &self.bias.is_some())
471            .field("inference_mode", &self.inference_mode)
472            .finish()
473    }
474}
475
// =============================================================================
// TernaryLinearBackward (Straight-Through Estimator)
// =============================================================================

480/// Gradient function for TernaryLinear using the Straight-Through Estimator.
481///
482/// The STE passes gradients through the ternary quantization as if it were
483/// an identity function. This allows training the shadow weights with standard
484/// gradient-based optimizers.
485///
486/// For the forward y = scale * T(W) @ x where T is ternary quantization:
487/// - grad_input = scale * T(W)^T @ grad_output  (ternary transpose matmul)
488/// - grad_weight = grad_output^T @ x             (STE: treat T as identity)
489/// - grad_bias = sum(grad_output, dim=0)
490#[derive(Debug)]
491struct TernaryLinearBackward {
492    next_fns: Vec<Option<GradFn>>,
493    saved_input: Tensor<f32>,
494    saved_ternary: Vec<i8>,
495    saved_scale: f32,
496    in_features: usize,
497    out_features: usize,
498    has_bias: bool,
499    total_batch: usize,
500}
501
502impl GradientFunction for TernaryLinearBackward {
503    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
504        let g_vec = grad_output.to_vec();
505        let x_vec = self.saved_input.to_vec();
506
507        // 1. grad_input = scale * ternary_W^T @ grad_output
508        //    For each batch element and input dimension:
509        //    grad_input[b,j] = scale * sum_o(ternary[o,j] * grad_output[b,o])
510        let mut grad_input = vec![0.0f32; self.total_batch * self.in_features];
511        for b in 0..self.total_batch {
512            let g_off = b * self.out_features;
513            let gi_off = b * self.in_features;
514
515            for j in 0..self.in_features {
516                let mut sum = 0.0f32;
517                for o in 0..self.out_features {
518                    let w = self.saved_ternary[o * self.in_features + j];
519                    if w == 1 {
520                        sum += g_vec[g_off + o];
521                    } else if w == -1 {
522                        sum -= g_vec[g_off + o];
523                    }
524                }
525                grad_input[gi_off + j] = self.saved_scale * sum;
526            }
527        }
528
529        let gi_tensor = Tensor::from_vec(grad_input, self.saved_input.shape()).unwrap();
530
531        // 2. grad_weight (STE): grad_output^T @ input
532        //    grad_weight[o,j] = sum_b(grad_output[b,o] * input[b,j])
533        let mut grad_weight = vec![0.0f32; self.out_features * self.in_features];
534        for b in 0..self.total_batch {
535            let g_off = b * self.out_features;
536            let x_off = b * self.in_features;
537
538            for o in 0..self.out_features {
539                let go = g_vec[g_off + o];
540                let w_off = o * self.in_features;
541                for j in 0..self.in_features {
542                    grad_weight[w_off + j] += go * x_vec[x_off + j];
543                }
544            }
545        }
546        let gw_tensor =
547            Tensor::from_vec(grad_weight, &[self.out_features, self.in_features]).unwrap();
548
549        let mut results: Vec<Option<Tensor<f32>>> = vec![Some(gi_tensor), Some(gw_tensor)];
550
551        // 3. grad_bias = sum(grad_output, dim=0)
552        if self.has_bias {
553            let mut grad_bias = vec![0.0f32; self.out_features];
554            for b in 0..self.total_batch {
555                for o in 0..self.out_features {
556                    grad_bias[o] += g_vec[b * self.out_features + o];
557                }
558            }
559            let gb_tensor = Tensor::from_vec(grad_bias, &[self.out_features]).unwrap();
560            results.push(Some(gb_tensor));
561        }
562
563        results
564    }
565
566    fn name(&self) -> &'static str {
567        "TernaryLinearBackward"
568    }
569
570    fn next_functions(&self) -> &[Option<GradFn>] {
571        &self.next_fns
572    }
573
574    fn as_any(&self) -> &dyn Any {
575        self
576    }
577}
578
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ternary_linear_creation() {
        let layer = TernaryLinear::new(64, 32);
        assert_eq!(layer.in_features(), 64);
        assert_eq!(layer.out_features(), 32);
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_ternary_linear_no_bias() {
        let layer = TernaryLinear::with_bias(64, 32, false);
        assert!(layer.bias.is_none());
    }

    #[test]
    fn test_ternary_linear_forward() {
        let layer = TernaryLinear::new(8, 4);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 16], &[2, 8]).unwrap(), false);
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }

    #[test]
    fn test_ternary_quantization() {
        let layer = TernaryLinear::new(16, 8);
        let (ternary, scale) = layer.quantize_weights();

        // All values should be in {-1, 0, +1}
        for &v in &ternary {
            assert!(v == -1 || v == 0 || v == 1, "got {}", v);
        }

        // Scale should be positive
        assert!(scale > 0.0);

        // Should have the right number of values
        assert_eq!(ternary.len(), 16 * 8);
    }

    #[test]
    fn test_packed_ternary_roundtrip() {
        let values: Vec<i8> = vec![1, 0, -1, 1, 0, 0, -1, -1, 1, 0];
        let packed = PackedTernaryWeights::pack(&values, 0.5);
        let unpacked = packed.unpack();
        assert_eq!(values, unpacked);
        assert_eq!(packed.scale(), 0.5);
    }

    #[test]
    fn test_packed_storage_compression() {
        let n = 1024;
        let values: Vec<i8> = (0..n).map(|i| ((i % 3) as i8) - 1).collect();
        let packed = PackedTernaryWeights::pack(&values, 1.0);
        // 1024 weights / 4 per byte = 256 bytes (vs 1024 * 4 = 4096 bytes fp32)
        assert_eq!(packed.storage_bytes(), 256);
    }

    #[test]
    fn test_ternary_matmul_simple() {
        // 2x3 ternary weight: [[1, -1, 0], [0, 1, 1]]
        let ternary = vec![1i8, -1, 0, 0, 1, 1];
        let scale = 1.0;
        let input = vec![2.0f32, 3.0, 5.0]; // 1x3

        let output = TernaryLinear::ternary_matmul(&input, &ternary, scale, 1, 3, 2);

        // y[0] = 1.0 * (2.0 - 3.0) = -1.0
        // y[1] = 1.0 * (3.0 + 5.0) = 8.0
        assert!((output[0] - (-1.0)).abs() < 1e-6);
        assert!((output[1] - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_ternary_linear_inference_mode() {
        let mut layer = TernaryLinear::new(8, 4);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 8], &[1, 8]).unwrap(), false);

        // Training forward
        let train_out = layer.forward(&input);

        // Quantize and switch to inference
        layer.quantize_for_inference();
        let infer_out = layer.forward(&input);

        // Should produce the same result
        let train_vec = train_out.data().to_vec();
        let infer_vec = infer_out.data().to_vec();
        for (a, b) in train_vec.iter().zip(infer_vec.iter()) {
            assert!((a - b).abs() < 1e-5, "Training {} vs inference {}", a, b);
        }
    }

    #[test]
    fn test_ternary_linear_sparsity() {
        let layer = TernaryLinear::new(64, 32);
        let sparsity = layer.weight_sparsity();
        // Sparsity should be between 0 and 1
        assert!(sparsity >= 0.0 && sparsity <= 1.0);
    }

    #[test]
    fn test_ternary_linear_compression_ratio() {
        let layer = TernaryLinear::new(512, 512);
        let ratio = layer.compression_ratio();
        // Should be close to 16x (32 bits / 2 bits)
        assert!(ratio > 14.0 && ratio < 17.0, "ratio = {}", ratio);
    }

    #[test]
    fn test_ternary_linear_parameters() {
        let layer = TernaryLinear::new(16, 8);
        let params = layer.parameters();
        assert_eq!(params.len(), 2); // shadow_weight + bias

        let layer_no_bias = TernaryLinear::with_bias(16, 8, false);
        assert_eq!(layer_no_bias.parameters().len(), 1);
    }

    #[test]
    fn test_ternary_linear_backward() {
        let layer = TernaryLinear::new(4, 2);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
            true,
        );
        let output = layer.forward(&input);
        let loss = output.sum();
        loss.backward();

        // Gradients should exist
        assert!(input.grad().is_some());
    }
}