//! Ternary linear layer: weights quantized to {-1, 0, +1}.
//!
//! Training keeps full-precision "shadow" weights and quantizes them on the
//! fly each forward pass; gradients reach the shadow weights via a
//! straight-through estimator. For inference, the ternary weights can be
//! packed at 2 bits each (4 per byte) plus a single f32 scale, roughly a
//! 16x size reduction over f32 storage.

use std::any::Any;
use std::collections::HashMap;

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;

use crate::init::{kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

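/// Ternary weights packed at 2 bits per value (4 per byte), together with
/// the per-tensor scale needed to dequantize them.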
#[derive(Debug, Clone)]
pub struct PackedTernaryWeights {
    /// Packed 2-bit codes: `0b00` = 0, `0b01` = +1, `0b10` = -1.
    data: Vec<u8>,
    /// Number of logical weights (the last byte may be only partially used).
    num_weights: usize,
    /// Per-tensor dequantization scale.
    scale: f32,
}

impl PackedTernaryWeights {
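    /// Packs ternary values (-1, 0, +1) at 2 bits each, 4 per byte.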
    pub fn pack(ternary_values: &[i8], scale: f32) -> Self {
        let num_weights = ternary_values.len();
        let num_bytes = num_weights.div_ceil(4);
        let mut data = vec![0u8; num_bytes];

        for (i, &val) in ternary_values.iter().enumerate() {
            let byte_idx = i / 4;
            let bit_offset = (i % 4) * 2;
            let encoded = match val {
                0 => 0b00u8,
                1 => 0b01u8,
                -1 => 0b10u8,
                _ => 0b00u8, // out-of-range values fall back to 0
            };
            data[byte_idx] |= encoded << bit_offset;
        }

        Self {
            data,
            num_weights,
            scale,
        }
    }

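    /// Unpacks back to one `i8` value (-1, 0, or +1) per weight.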
    pub fn unpack(&self) -> Vec<i8> {
        let mut values = Vec::with_capacity(self.num_weights);
        for i in 0..self.num_weights {
            let byte_idx = i / 4;
            let bit_offset = (i % 4) * 2;
            let encoded = (self.data[byte_idx] >> bit_offset) & 0b11;
            let val = match encoded {
                0b00 => 0i8,
                0b01 => 1i8,
                0b10 => -1i8,
                _ => 0i8,
            };
            values.push(val);
        }
        values
    }

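    /// Returns the per-tensor dequantization scale.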
    pub fn scale(&self) -> f32 {
        self.scale
    }

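    /// Returns the packed storage size in bytes.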
    pub fn storage_bytes(&self) -> usize {
        self.data.len()
    }

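    /// Returns the number of logical weights stored.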
    pub fn num_weights(&self) -> usize {
        self.num_weights
    }

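    /// Counts stored weights equal to zero (useful for sparsity statistics).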
    pub fn count_zeros(&self) -> usize {
        let values = self.unpack();
        values.iter().filter(|&&v| v == 0).count()
    }
}

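/// Linear layer with ternary ({-1, 0, +1}) quantized weights.
///
/// Full-precision shadow weights are kept for training and quantized on the
/// fly in the forward pass. Call `quantize_for_inference` to pack the
/// ternary weights and switch to the 2-bit inference path.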
pub struct TernaryLinear {
    /// Full-precision shadow weights, shape `[out_features, in_features]`.
    pub shadow_weight: Parameter,
    /// Optional bias, shape `[out_features]`.
    pub bias: Option<Parameter>,
    /// Packed ternary weights, set by `quantize_for_inference`.
    packed_weights: Option<PackedTernaryWeights>,
    /// Input feature dimension.
    in_features: usize,
    /// Output feature dimension.
    out_features: usize,
    /// When true, `forward` uses the packed ternary weights.
    inference_mode: bool,
}

impl TernaryLinear {
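    /// Creates a ternary linear layer with a bias term.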
    pub fn new(in_features: usize, out_features: usize) -> Self {
        Self::with_bias(in_features, out_features, true)
    }

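    /// Creates a ternary linear layer with or without a bias term.
    /// Shadow weights use Kaiming-uniform init; the bias starts at zero.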
    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
        let weight_data = kaiming_uniform(out_features, in_features);
        let shadow_weight = Parameter::named("shadow_weight", weight_data, true);

        let bias_param = if bias {
            let bias_data = zeros(&[out_features]);
            Some(Parameter::named("bias", bias_data, true))
        } else {
            None
        };

        Self {
            shadow_weight,
            bias: bias_param,
            packed_weights: None,
            in_features,
            out_features,
            inference_mode: false,
        }
    }

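    /// Returns the input feature dimension.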
    pub fn in_features(&self) -> usize {
        self.in_features
    }

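    /// Returns the output feature dimension.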
    pub fn out_features(&self) -> usize {
        self.out_features
    }

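    /// Quantizes the shadow weights to ternary values with absmean scaling:
    /// `scale = mean(|W|)` and `w_q = sign(w) * min(round(|w| / scale), 1)`.
    /// Returns the ternary values and the scale.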
    pub fn quantize_weights(&self) -> (Vec<i8>, f32) {
        let w = self.shadow_weight.data();
        let w_vec = w.to_vec();
        let n = w_vec.len();

        let abs_mean: f32 = w_vec.iter().map(|v| v.abs()).sum::<f32>() / n as f32;
        let scale = abs_mean.max(1e-8); // guard against division by zero

        let ternary: Vec<i8> = w_vec
            .iter()
            .map(|&w| {
                let normalized = (w.abs() / scale).round().min(1.0);
                let sign = if w > 0.0 {
                    1i8
                } else if w < 0.0 {
                    -1i8
                } else {
                    0i8
                };
                sign * (normalized as i8)
            })
            .collect();

        (ternary, scale)
    }

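    /// Quantizes and packs the current shadow weights, then switches the
    /// layer to the packed inference path.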
    pub fn quantize_for_inference(&mut self) {
        let (ternary, scale) = self.quantize_weights();
        self.packed_weights = Some(PackedTernaryWeights::pack(&ternary, scale));
        self.inference_mode = true;
    }

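    /// Switches the layer back to the full-precision training path.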
    pub fn use_shadow_weights(&mut self) {
        self.inference_mode = false;
    }

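    /// Fraction of quantized weights that are exactly zero.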
    pub fn weight_sparsity(&self) -> f32 {
        let (ternary, _) = self.quantize_weights();
        let zeros = ternary.iter().filter(|&&v| v == 0).count();
        zeros as f32 / ternary.len() as f32
    }

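    /// Storage compression versus f32 weights; approaches 16x for large
    /// layers (2 bits per weight plus one f32 scale).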
    pub fn compression_ratio(&self) -> f32 {
        let fp32_bytes = self.in_features * self.out_features * 4;
        // 2 bits per weight, plus 4 bytes for the f32 scale.
        let ternary_bytes = (self.in_features * self.out_features).div_ceil(4) + 4;
        fp32_bytes as f32 / ternary_bytes as f32
    }

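    /// Returns the packed weights, if `quantize_for_inference` has been called.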
    pub fn packed_weights(&self) -> Option<&PackedTernaryWeights> {
        self.packed_weights.as_ref()
    }

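    /// Multiplication-free matmul for ternary weights: each output element is
    /// `scale * (sum of inputs where w == +1 - sum of inputs where w == -1)`.
    /// `ternary` is row-major with shape `[out_features, in_features]`.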
    fn ternary_matmul(
        input: &[f32],
        ternary: &[i8],
        scale: f32,
        batch_size: usize,
        in_features: usize,
        out_features: usize,
    ) -> Vec<f32> {
        let mut output = vec![0.0f32; batch_size * out_features];

        for b in 0..batch_size {
            let x_off = b * in_features;
            let y_off = b * out_features;

            for o in 0..out_features {
                let w_off = o * in_features;
                let mut sum_pos = 0.0f32;
                let mut sum_neg = 0.0f32;

                for j in 0..in_features {
                    let w = ternary[w_off + j];
                    let x = input[x_off + j];
                    if w == 1 {
                        sum_pos += x;
                    } else if w == -1 {
                        sum_neg += x;
                    }
                }

                output[y_off + o] = scale * (sum_pos - sum_neg);
            }
        }

        output
    }

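    /// Training-mode forward pass: quantizes the shadow weights on the fly
    /// and, when gradients are needed, records a backward node that routes
    /// gradients to the input, the shadow weights, and the bias.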
    fn forward_training(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let input_shape = input_data.shape();
        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
        let total_batch: usize = batch_dims.iter().product();

        let (ternary, scale) = self.quantize_weights();

        let input_vec = input_data.to_vec();

        let output_vec = Self::ternary_matmul(
            &input_vec,
            &ternary,
            scale,
            total_batch,
            self.in_features,
            self.out_features,
        );

        let mut out_shape = batch_dims.clone();
        out_shape.push(self.out_features);
        let output_tensor =
            Tensor::from_vec(output_vec, &out_shape).expect("tensor creation failed");

        let output_tensor = if let Some(ref bias) = self.bias {
            let bias_vec = bias.data().to_vec();
            let mut out = output_tensor.to_vec();
            for b in 0..total_batch {
                for o in 0..self.out_features {
                    out[b * self.out_features + o] += bias_vec[o];
                }
            }
            Tensor::from_vec(out, &out_shape).expect("tensor creation failed")
        } else {
            output_tensor
        };

        // Build the graph if either the input or the trainable shadow
        // weights require gradients; checking the input alone would starve
        // the weights of gradients whenever the input is a leaf.
        let requires_grad = (input.requires_grad()
            || self.shadow_weight.variable().requires_grad())
            && is_grad_enabled();
        if requires_grad {
            let saved_input = input_data.clone();
            let saved_ternary = ternary;
            let saved_scale = scale;
            let in_f = self.in_features;
            let out_f = self.out_features;
            let shadow_grad_fn = self.shadow_weight.variable().grad_fn().cloned();
            let bias_grad_fn = self
                .bias
                .as_ref()
                .and_then(|b| b.variable().grad_fn().cloned());

            // Gradient slots must match `TernaryLinearBackward::apply`:
            // input, shadow weight, then bias (if present).
            let mut next_fns = vec![input.grad_fn().cloned(), shadow_grad_fn];
            if bias_grad_fn.is_some() {
                next_fns.push(bias_grad_fn);
            }

            let grad_fn = GradFn::new(TernaryLinearBackward {
                next_fns,
                saved_input,
                saved_ternary,
                saved_scale,
                in_features: in_f,
                out_features: out_f,
                has_bias: self.bias.is_some(),
                total_batch,
            });
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }

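    /// Inference-mode forward pass using the packed 2-bit weights.
    ///
    /// Panics if `quantize_for_inference` has not been called first.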
    fn forward_inference(&self, input: &Variable) -> Variable {
        let packed = self
            .packed_weights
            .as_ref()
            .expect("Must call quantize_for_inference() before inference forward");

        let input_data = input.data();
        let input_shape = input_data.shape();
        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
        let total_batch: usize = batch_dims.iter().product();

        let ternary = packed.unpack();
        let scale = packed.scale();

        let input_vec = input_data.to_vec();
        let output_vec = Self::ternary_matmul(
            &input_vec,
            &ternary,
            scale,
            total_batch,
            self.in_features,
            self.out_features,
        );

        let mut out_shape = batch_dims;
        out_shape.push(self.out_features);
        let mut output_tensor =
            Tensor::from_vec(output_vec, &out_shape).expect("tensor creation failed");

        if let Some(ref bias) = self.bias {
            let bias_vec = bias.data().to_vec();
            let mut out = output_tensor.to_vec();
            for b in 0..total_batch {
                for o in 0..self.out_features {
                    out[b * self.out_features + o] += bias_vec[o];
                }
            }
            output_tensor = Tensor::from_vec(out, &out_shape).expect("tensor creation failed");
        }

        Variable::new(output_tensor, false)
    }
}

impl Module for TernaryLinear {
    fn forward(&self, input: &Variable) -> Variable {
        if self.inference_mode {
            self.forward_inference(input)
        } else {
            self.forward_training(input)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.shadow_weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("shadow_weight".to_string(), self.shadow_weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "TernaryLinear"
    }
}

impl std::fmt::Debug for TernaryLinear {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TernaryLinear")
            .field("in_features", &self.in_features)
            .field("out_features", &self.out_features)
            .field("bias", &self.bias.is_some())
            .field("inference_mode", &self.inference_mode)
            .finish()
    }
}

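/// Backward pass for `TernaryLinear`.
///
/// Implements a straight-through estimator: the input gradient uses the
/// quantized ternary weights, while the shadow-weight gradient is the dense
/// `grad_output^T @ input`, as if the quantization step were the identity.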
#[derive(Debug)]
struct TernaryLinearBackward {
    next_fns: Vec<Option<GradFn>>,
    saved_input: Tensor<f32>,
    saved_ternary: Vec<i8>,
    saved_scale: f32,
    in_features: usize,
    out_features: usize,
    has_bias: bool,
    total_batch: usize,
}

impl GradientFunction for TernaryLinearBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let g_vec = grad_output.to_vec();
        let x_vec = self.saved_input.to_vec();

        // Input gradient: grad_input = scale * (grad_output @ W_ternary).
        let mut grad_input = vec![0.0f32; self.total_batch * self.in_features];
        for b in 0..self.total_batch {
            let g_off = b * self.out_features;
            let gi_off = b * self.in_features;

            for j in 0..self.in_features {
                let mut sum = 0.0f32;
                for o in 0..self.out_features {
                    let w = self.saved_ternary[o * self.in_features + j];
                    if w == 1 {
                        sum += g_vec[g_off + o];
                    } else if w == -1 {
                        sum -= g_vec[g_off + o];
                    }
                }
                grad_input[gi_off + j] = self.saved_scale * sum;
            }
        }

        let gi_tensor =
            Tensor::from_vec(grad_input, self.saved_input.shape()).expect("tensor creation failed");

        // Shadow-weight gradient (straight-through estimator):
        // grad_weight = grad_output^T @ input.
        let mut grad_weight = vec![0.0f32; self.out_features * self.in_features];
        for b in 0..self.total_batch {
            let g_off = b * self.out_features;
            let x_off = b * self.in_features;

            for o in 0..self.out_features {
                let go = g_vec[g_off + o];
                let w_off = o * self.in_features;
                for j in 0..self.in_features {
                    grad_weight[w_off + j] += go * x_vec[x_off + j];
                }
            }
        }
        let gw_tensor = Tensor::from_vec(grad_weight, &[self.out_features, self.in_features])
            .expect("tensor creation failed");

        let mut results: Vec<Option<Tensor<f32>>> = vec![Some(gi_tensor), Some(gw_tensor)];

        if self.has_bias {
            // Bias gradient: sum grad_output over the batch dimension.
            let mut grad_bias = vec![0.0f32; self.out_features];
            for b in 0..self.total_batch {
                for o in 0..self.out_features {
                    grad_bias[o] += g_vec[b * self.out_features + o];
                }
            }
            let gb_tensor =
                Tensor::from_vec(grad_bias, &[self.out_features]).expect("tensor creation failed");
            results.push(Some(gb_tensor));
        }

        results
    }

    fn name(&self) -> &'static str {
        "TernaryLinearBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ternary_linear_creation() {
        let layer = TernaryLinear::new(64, 32);
        assert_eq!(layer.in_features(), 64);
        assert_eq!(layer.out_features(), 32);
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_ternary_linear_no_bias() {
        let layer = TernaryLinear::with_bias(64, 32, false);
        assert!(layer.bias.is_none());
    }

    #[test]
    fn test_ternary_linear_forward() {
        let layer = TernaryLinear::new(8, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[2, 8]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }

    #[test]
    fn test_ternary_quantization() {
        let layer = TernaryLinear::new(16, 8);
        let (ternary, scale) = layer.quantize_weights();

        for &v in &ternary {
            assert!(v == -1 || v == 0 || v == 1, "got {}", v);
        }

        assert!(scale > 0.0);

        assert_eq!(ternary.len(), 16 * 8);
    }

    #[test]
    fn test_packed_ternary_roundtrip() {
        let values: Vec<i8> = vec![1, 0, -1, 1, 0, 0, -1, -1, 1, 0];
        let packed = PackedTernaryWeights::pack(&values, 0.5);
        let unpacked = packed.unpack();
        assert_eq!(values, unpacked);
        assert_eq!(packed.scale(), 0.5);
    }

    #[test]
    fn test_packed_storage_compression() {
        let n = 1024;
        let values: Vec<i8> = (0..n).map(|i| ((i % 3) as i8) - 1).collect();
        let packed = PackedTernaryWeights::pack(&values, 1.0);
        // 1024 weights at 4 per byte = 256 bytes.
        assert_eq!(packed.storage_bytes(), 256);
    }

    #[test]
    fn test_ternary_matmul_simple() {
        // 2x3 weight matrix, row-major: row 0 = [1, -1, 0], row 1 = [0, 1, 1].
        let ternary = vec![1i8, -1, 0, 0, 1, 1];
        let scale = 1.0;
        let input = vec![2.0f32, 3.0, 5.0];

        let output = TernaryLinear::ternary_matmul(&input, &ternary, scale, 1, 3, 2);

        // Row 0: 2 - 3 = -1; row 1: 3 + 5 = 8.
        assert!((output[0] - (-1.0)).abs() < 1e-6);
        assert!((output[1] - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_ternary_linear_inference_mode() {
        let mut layer = TernaryLinear::new(8, 4);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 8], &[1, 8]).expect("tensor creation failed"),
            false,
        );

        let train_out = layer.forward(&input);

        layer.quantize_for_inference();
        let infer_out = layer.forward(&input);

        // Both paths apply the same quantized weights, so outputs must match.
        let train_vec = train_out.data().to_vec();
        let infer_vec = infer_out.data().to_vec();
        for (a, b) in train_vec.iter().zip(infer_vec.iter()) {
            assert!((a - b).abs() < 1e-5, "Training {} vs inference {}", a, b);
        }
    }

    #[test]
    fn test_ternary_linear_sparsity() {
        let layer = TernaryLinear::new(64, 32);
        let sparsity = layer.weight_sparsity();
        assert!((0.0..=1.0).contains(&sparsity));
    }

    #[test]
    fn test_ternary_linear_compression_ratio() {
        let layer = TernaryLinear::new(512, 512);
        let ratio = layer.compression_ratio();
        assert!(ratio > 14.0 && ratio < 17.0, "ratio = {}", ratio);
    }

    #[test]
    fn test_ternary_linear_parameters() {
        let layer = TernaryLinear::new(16, 8);
        let params = layer.parameters();
        assert_eq!(params.len(), 2); // shadow_weight + bias

        let layer_no_bias = TernaryLinear::with_bias(16, 8, false);
        assert_eq!(layer_no_bias.parameters().len(), 1);
    }

    #[test]
    fn test_ternary_linear_backward() {
        let layer = TernaryLinear::new(4, 2);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            true,
        );
        let output = layer.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some());
    }
}