use std::any::Any;
use std::collections::HashMap;

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;

use crate::init::{kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

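/// Ternary weights packed at 2 bits per weight, plus one f32 scale factor.
///
/// Four weights fit in each byte using the encoding `0b00 => 0`,
/// `0b01 => +1`, `0b10 => -1` (`0b11` is unused), so `n` weights occupy
/// `ceil(n / 4)` bytes -- roughly 16x less storage than the same weights
/// in f32.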
#[derive(Debug, Clone)]
pub struct PackedTernaryWeights {
    data: Vec<u8>,
    num_weights: usize,
    scale: f32,
}

impl PackedTernaryWeights {
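    /// Packs a slice of ternary values (-1, 0, +1) into 2-bit codes, four
    /// per byte. Any other value is encoded as 0.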
    pub fn pack(ternary_values: &[i8], scale: f32) -> Self {
        let num_weights = ternary_values.len();
        let num_bytes = num_weights.div_ceil(4);
        let mut data = vec![0u8; num_bytes];

        for (i, &val) in ternary_values.iter().enumerate() {
            let byte_idx = i / 4;
            let bit_offset = (i % 4) * 2;
            let encoded = match val {
                0 => 0b00u8,
                1 => 0b01u8,
                -1 => 0b10u8,
                _ => 0b00u8,
            };
            data[byte_idx] |= encoded << bit_offset;
        }

        Self {
            data,
            num_weights,
            scale,
        }
    }

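    /// Unpacks the 2-bit codes back into a vector of -1/0/+1 values.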
    pub fn unpack(&self) -> Vec<i8> {
        let mut values = Vec::with_capacity(self.num_weights);
        for i in 0..self.num_weights {
            let byte_idx = i / 4;
            let bit_offset = (i % 4) * 2;
            let encoded = (self.data[byte_idx] >> bit_offset) & 0b11;
            let val = match encoded {
                0b00 => 0i8,
                0b01 => 1i8,
                0b10 => -1i8,
                _ => 0i8,
            };
            values.push(val);
        }
        values
    }

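    /// Returns the scale factor stored alongside the packed weights.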
    pub fn scale(&self) -> f32 {
        self.scale
    }

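    /// Returns the number of bytes used by the packed representation.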
    pub fn storage_bytes(&self) -> usize {
        self.data.len()
    }

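    /// Returns the number of weights stored.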
    pub fn num_weights(&self) -> usize {
        self.num_weights
    }

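    /// Counts how many packed weights are zero (unpacks to do so).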
    pub fn count_zeros(&self) -> usize {
        let values = self.unpack();
        values.iter().filter(|&&v| v == 0).count()
    }
}

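/// A linear layer with ternary weights in {-1, 0, +1}, in the spirit of
/// BitNet b1.58-style quantization-aware training.
///
/// Full-precision "shadow" weights are kept for training and re-quantized
/// on every forward pass; `quantize_for_inference` freezes them into a
/// packed 2-bit representation for inference.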
pub struct TernaryLinear {
    pub shadow_weight: Parameter,
    pub bias: Option<Parameter>,
    packed_weights: Option<PackedTernaryWeights>,
    in_features: usize,
    out_features: usize,
    inference_mode: bool,
}

impl TernaryLinear {
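    /// Creates a ternary linear layer with a bias term.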
    pub fn new(in_features: usize, out_features: usize) -> Self {
        Self::with_bias(in_features, out_features, true)
    }

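    /// Creates a ternary linear layer, optionally without a bias term.
    /// Shadow weights use Kaiming-uniform initialization; the bias, if
    /// present, starts at zero.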
    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
        let weight_data = kaiming_uniform(out_features, in_features);
        let shadow_weight = Parameter::named("shadow_weight", weight_data, true);

        let bias_param = if bias {
            let bias_data = zeros(&[out_features]);
            Some(Parameter::named("bias", bias_data, true))
        } else {
            None
        };

        Self {
            shadow_weight,
            bias: bias_param,
            packed_weights: None,
            in_features,
            out_features,
            inference_mode: false,
        }
    }

    pub fn in_features(&self) -> usize {
        self.in_features
    }

    pub fn out_features(&self) -> usize {
        self.out_features
    }

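    /// Quantizes the shadow weights to ternary values via absmean
    /// quantization:
    ///
    /// ```text
    /// scale = max(mean(|W|), 1e-8)
    /// w_q   = sign(w) * min(round(|w| / scale), 1)
    /// ```
    ///
    /// Returns the ternary values together with the scale.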
    pub fn quantize_weights(&self) -> (Vec<i8>, f32) {
        let w = self.shadow_weight.data();
        let w_vec = w.to_vec();
        let n = w_vec.len();

        let abs_mean: f32 = w_vec.iter().map(|v| v.abs()).sum::<f32>() / n as f32;
        let scale = abs_mean.max(1e-8); // guard against division by zero for all-zero weights

        let ternary: Vec<i8> = w_vec
            .iter()
            .map(|&w| {
                let normalized = (w.abs() / scale).round().min(1.0);
                let sign = if w > 0.0 {
                    1i8
                } else if w < 0.0 {
                    -1i8
                } else {
                    0i8
                };
                sign * (normalized as i8)
            })
            .collect();

        (ternary, scale)
    }

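    /// Quantizes and packs the shadow weights, then switches the layer to
    /// inference mode so `forward` uses the packed weights.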
    pub fn quantize_for_inference(&mut self) {
        let (ternary, scale) = self.quantize_weights();
        self.packed_weights = Some(PackedTernaryWeights::pack(&ternary, scale));
        self.inference_mode = true;
    }

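    /// Switches back to training mode; `forward` re-quantizes the shadow
    /// weights on each call.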
    pub fn use_shadow_weights(&mut self) {
        self.inference_mode = false;
    }

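    /// Returns the fraction of quantized weights that are zero.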
    pub fn weight_sparsity(&self) -> f32 {
        let (ternary, _) = self.quantize_weights();
        let zeros = ternary.iter().filter(|&&v| v == 0).count();
        zeros as f32 / ternary.len() as f32
    }

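    /// Returns the storage compression ratio versus f32 weights
    /// (2 bits per weight plus 4 bytes for the scale), approaching 16x
    /// for large layers.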
    pub fn compression_ratio(&self) -> f32 {
        let fp32_bytes = self.in_features * self.out_features * 4;
        // 2 bits per weight, plus 4 bytes for the stored f32 scale.
        let ternary_bytes = (self.in_features * self.out_features).div_ceil(4) + 4;
        fp32_bytes as f32 / ternary_bytes as f32
    }

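    /// Returns the packed weights, if `quantize_for_inference` has been
    /// called.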
    pub fn packed_weights(&self) -> Option<&PackedTernaryWeights> {
        self.packed_weights.as_ref()
    }

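    /// Multiplication-free ternary matmul: `y = scale * (x @ W^T)`, where
    /// `W` holds row-major ternary weights of shape
    /// `[out_features, in_features]`. Since each weight is -1, 0, or +1,
    /// every output element reduces to `scale * (sum of inputs where
    /// w == +1  -  sum of inputs where w == -1)`, so the inner loop needs
    /// no multiplications.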
    fn ternary_matmul(
        input: &[f32],
        ternary: &[i8],
        scale: f32,
        batch_size: usize,
        in_features: usize,
        out_features: usize,
    ) -> Vec<f32> {
        let mut output = vec![0.0f32; batch_size * out_features];

        for b in 0..batch_size {
            let x_off = b * in_features;
            let y_off = b * out_features;

            for o in 0..out_features {
                let w_off = o * in_features;
                let mut sum_pos = 0.0f32;
                let mut sum_neg = 0.0f32;

                for j in 0..in_features {
                    let w = ternary[w_off + j];
                    let x = input[x_off + j];
                    if w == 1 {
                        sum_pos += x;
                    } else if w == -1 {
                        sum_neg += x;
                    }
                }

                output[y_off + o] = scale * (sum_pos - sum_neg);
            }
        }

        output
    }

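    /// Training-mode forward: re-quantizes the shadow weights, runs the
    /// ternary matmul, and records a backward node that routes gradients to
    /// the full-precision shadow weights as if quantization were the
    /// identity (a straight-through estimator).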
    fn forward_training(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let input_shape = input_data.shape();
        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
        let total_batch: usize = batch_dims.iter().product();

        let (ternary, scale) = self.quantize_weights();

        let input_vec = input_data.to_vec();

        let output_vec = Self::ternary_matmul(
            &input_vec,
            &ternary,
            scale,
            total_batch,
            self.in_features,
            self.out_features,
        );

        let mut out_shape = batch_dims.clone();
        out_shape.push(self.out_features);
        let output_tensor =
            Tensor::from_vec(output_vec, &out_shape).expect("tensor creation failed");

        // Add the bias, if present.
        let output_tensor = if let Some(ref bias) = self.bias {
            let bias_vec = bias.data().to_vec();
            let mut out = output_tensor.to_vec();
            for b in 0..total_batch {
                for o in 0..self.out_features {
                    out[b * self.out_features + o] += bias_vec[o];
                }
            }
            Tensor::from_vec(out, &out_shape).expect("tensor creation failed")
        } else {
            output_tensor
        };

        let requires_grad = input.requires_grad() && is_grad_enabled();
        if requires_grad {
            // Save everything the backward pass needs.
            let saved_input = input_data.clone();
            let saved_ternary = ternary;
            let saved_scale = scale;
            let in_f = self.in_features;
            let out_f = self.out_features;
            let shadow_grad_fn = self.shadow_weight.variable().grad_fn().cloned();
            let bias_grad_fn = self
                .bias
                .as_ref()
                .and_then(|b| b.variable().grad_fn().cloned());

            let mut next_fns = vec![input.grad_fn().cloned(), shadow_grad_fn];
            if bias_grad_fn.is_some() {
                next_fns.push(bias_grad_fn);
            }

            let grad_fn = GradFn::new(TernaryLinearBackward {
                next_fns,
                saved_input,
                saved_ternary,
                saved_scale,
                in_features: in_f,
                out_features: out_f,
                has_bias: self.bias.is_some(),
                total_batch,
            });
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }

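    /// Inference-mode forward using the packed weights.
    ///
    /// Panics if `quantize_for_inference` has not been called first.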
    fn forward_inference(&self, input: &Variable) -> Variable {
        let packed = self
            .packed_weights
            .as_ref()
            .expect("Must call quantize_for_inference() before inference forward");

        let input_data = input.data();
        let input_shape = input_data.shape();
        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
        let total_batch: usize = batch_dims.iter().product();

        let ternary = packed.unpack();
        let scale = packed.scale();

        let input_vec = input_data.to_vec();
        let output_vec = Self::ternary_matmul(
            &input_vec,
            &ternary,
            scale,
            total_batch,
            self.in_features,
            self.out_features,
        );

        let mut out_shape = batch_dims;
        out_shape.push(self.out_features);
        let mut output_tensor =
            Tensor::from_vec(output_vec, &out_shape).expect("tensor creation failed");

        // Add the bias, if present.
        if let Some(ref bias) = self.bias {
            let bias_vec = bias.data().to_vec();
            let mut out = output_tensor.to_vec();
            for b in 0..total_batch {
                for o in 0..self.out_features {
                    out[b * self.out_features + o] += bias_vec[o];
                }
            }
            output_tensor = Tensor::from_vec(out, &out_shape).expect("tensor creation failed");
        }

        Variable::new(output_tensor, false)
    }
}

impl Module for TernaryLinear {
    fn forward(&self, input: &Variable) -> Variable {
        if self.inference_mode {
            self.forward_inference(input)
        } else {
            self.forward_training(input)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.shadow_weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("shadow_weight".to_string(), self.shadow_weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "TernaryLinear"
    }
}

impl std::fmt::Debug for TernaryLinear {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TernaryLinear")
            .field("in_features", &self.in_features)
            .field("out_features", &self.out_features)
            .field("bias", &self.bias.is_some())
            .field("inference_mode", &self.inference_mode)
            .finish()
    }
}

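/// Backward node for `TernaryLinear`:
/// - `grad_input  = scale * (grad_output @ W_ternary)`
/// - `grad_weight = grad_output^T @ input` (straight-through: computed
///   against the full-precision shadow weights)
/// - `grad_bias   = grad_output` summed over the batch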
#[derive(Debug)]
struct TernaryLinearBackward {
    next_fns: Vec<Option<GradFn>>,
    saved_input: Tensor<f32>,
    saved_ternary: Vec<i8>,
    saved_scale: f32,
    in_features: usize,
    out_features: usize,
    has_bias: bool,
    total_batch: usize,
}

impl GradientFunction for TernaryLinearBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let g_vec = grad_output.to_vec();
        let x_vec = self.saved_input.to_vec();

        // Gradient w.r.t. the input: scale * (grad_output @ W_ternary).
        let mut grad_input = vec![0.0f32; self.total_batch * self.in_features];
        for b in 0..self.total_batch {
            let g_off = b * self.out_features;
            let gi_off = b * self.in_features;

            for j in 0..self.in_features {
                let mut sum = 0.0f32;
                for o in 0..self.out_features {
                    let w = self.saved_ternary[o * self.in_features + j];
                    if w == 1 {
                        sum += g_vec[g_off + o];
                    } else if w == -1 {
                        sum -= g_vec[g_off + o];
                    }
                }
                grad_input[gi_off + j] = self.saved_scale * sum;
            }
        }

        let gi_tensor = Tensor::from_vec(grad_input, self.saved_input.shape()).unwrap();

        // Gradient w.r.t. the shadow weights (straight-through estimator):
        // grad_output^T @ input, as if quantization were the identity.
        let mut grad_weight = vec![0.0f32; self.out_features * self.in_features];
        for b in 0..self.total_batch {
            let g_off = b * self.out_features;
            let x_off = b * self.in_features;

            for o in 0..self.out_features {
                let go = g_vec[g_off + o];
                let w_off = o * self.in_features;
                for j in 0..self.in_features {
                    grad_weight[w_off + j] += go * x_vec[x_off + j];
                }
            }
        }
        let gw_tensor = Tensor::from_vec(grad_weight, &[self.out_features, self.in_features])
            .expect("tensor creation failed");

        let mut results: Vec<Option<Tensor<f32>>> = vec![Some(gi_tensor), Some(gw_tensor)];

        if self.has_bias {
            // Gradient w.r.t. the bias: grad_output summed over the batch.
            let mut grad_bias = vec![0.0f32; self.out_features];
            for b in 0..self.total_batch {
                for o in 0..self.out_features {
                    grad_bias[o] += g_vec[b * self.out_features + o];
                }
            }
            let gb_tensor =
                Tensor::from_vec(grad_bias, &[self.out_features]).expect("tensor creation failed");
            results.push(Some(gb_tensor));
        }

        results
    }

    fn name(&self) -> &'static str {
        "TernaryLinearBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ternary_linear_creation() {
        let layer = TernaryLinear::new(64, 32);
        assert_eq!(layer.in_features(), 64);
        assert_eq!(layer.out_features(), 32);
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_ternary_linear_no_bias() {
        let layer = TernaryLinear::with_bias(64, 32, false);
        assert!(layer.bias.is_none());
    }

    #[test]
    fn test_ternary_linear_forward() {
        let layer = TernaryLinear::new(8, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[2, 8]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }

    #[test]
    fn test_ternary_quantization() {
        let layer = TernaryLinear::new(16, 8);
        let (ternary, scale) = layer.quantize_weights();

        for &v in &ternary {
            assert!(v == -1 || v == 0 || v == 1, "got {}", v);
        }

        assert!(scale > 0.0);

        assert_eq!(ternary.len(), 16 * 8);
    }

    #[test]
    fn test_packed_ternary_roundtrip() {
        let values: Vec<i8> = vec![1, 0, -1, 1, 0, 0, -1, -1, 1, 0];
        let packed = PackedTernaryWeights::pack(&values, 0.5);
        let unpacked = packed.unpack();
        assert_eq!(values, unpacked);
        assert_eq!(packed.scale(), 0.5);
    }

    #[test]
    fn test_packed_storage_compression() {
        let n = 1024;
        let values: Vec<i8> = (0..n).map(|i| ((i % 3) as i8) - 1).collect();
        let packed = PackedTernaryWeights::pack(&values, 1.0);
        assert_eq!(packed.storage_bytes(), 256);
    }

    #[test]
    fn test_ternary_matmul_simple() {
        // 2x3 weight matrix, row-major: row 0 = [1, -1, 0], row 1 = [0, 1, 1].
        let ternary = vec![1i8, -1, 0, 0, 1, 1];
        let scale = 1.0;
        let input = vec![2.0f32, 3.0, 5.0];

        let output = TernaryLinear::ternary_matmul(&input, &ternary, scale, 1, 3, 2);

        // Row 0: 2 - 3 = -1; row 1: 3 + 5 = 8.
        assert!((output[0] - (-1.0)).abs() < 1e-6);
        assert!((output[1] - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_ternary_linear_inference_mode() {
        let mut layer = TernaryLinear::new(8, 4);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 8], &[1, 8]).expect("tensor creation failed"),
            false,
        );

        let train_out = layer.forward(&input);

        layer.quantize_for_inference();
        let infer_out = layer.forward(&input);

        let train_vec = train_out.data().to_vec();
        let infer_vec = infer_out.data().to_vec();
        for (a, b) in train_vec.iter().zip(infer_vec.iter()) {
            assert!((a - b).abs() < 1e-5, "Training {} vs inference {}", a, b);
        }
    }

    #[test]
    fn test_ternary_linear_sparsity() {
        let layer = TernaryLinear::new(64, 32);
        let sparsity = layer.weight_sparsity();
        assert!((0.0..=1.0).contains(&sparsity));
    }

    #[test]
    fn test_ternary_linear_compression_ratio() {
        let layer = TernaryLinear::new(512, 512);
        let ratio = layer.compression_ratio();
        assert!(ratio > 14.0 && ratio < 17.0, "ratio = {}", ratio);
    }

    #[test]
    fn test_ternary_linear_parameters() {
        let layer = TernaryLinear::new(16, 8);
        let params = layer.parameters();
        assert_eq!(params.len(), 2); // shadow weight + bias

        let layer_no_bias = TernaryLinear::with_bias(16, 8, false);
        assert_eq!(layer_no_bias.parameters().len(), 1);
    }

    #[test]
    fn test_ternary_linear_backward() {
        let layer = TernaryLinear::new(4, 2);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            true,
        );
        let output = layer.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some());
    }
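
    // Extra sketch tests (not part of the original suite), exercising only
    // the APIs defined in this file.

    // `count_zeros` should agree with a direct count over the unpacked
    // values, tying together the pack/unpack round trip and the 0b00
    // encoding of zero.
    #[test]
    fn test_packed_count_zeros() {
        let values: Vec<i8> = vec![0, 1, 0, -1, 0, 0, 1, -1];
        let packed = PackedTernaryWeights::pack(&values, 1.0);
        assert_eq!(packed.count_zeros(), 4);
        assert_eq!(
            packed.unpack().iter().filter(|&&v| v == 0).count(),
            packed.count_zeros()
        );
    }

    // After `quantize_for_inference`, an 8x4 layer's 32 weights should pack
    // into ceil(32 / 4) = 8 bytes.
    #[test]
    fn test_quantize_for_inference_storage() {
        let mut layer = TernaryLinear::new(8, 4);
        layer.quantize_for_inference();
        let packed = layer
            .packed_weights()
            .expect("packed weights after quantize_for_inference");
        assert_eq!(packed.num_weights(), 32);
        assert_eq!(packed.storage_bytes(), 8);
    }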
}