1use std::any::Any;
28use std::collections::HashMap;
29
30use axonml_autograd::no_grad::is_grad_enabled;
31use axonml_autograd::{GradFn, GradientFunction, Variable};
32use axonml_tensor::Tensor;
33
34use crate::init::{kaiming_uniform, zeros};
35use crate::module::Module;
36use crate::parameter::Parameter;
37
/// Compact storage for ternary ({-1, 0, +1}) weights: four weights per byte.
///
/// Each weight occupies 2 bits (LSB-first within a byte): `0b00` = 0,
/// `0b01` = +1, `0b10` = -1. A single `f32` scale is stored alongside the
/// packed codes so the real-valued weights can be reconstructed as
/// `scale * ternary`.
#[derive(Debug, Clone)]
pub struct PackedTernaryWeights {
    // Packed 2-bit codes; the last byte may carry unused zero padding bits.
    data: Vec<u8>,
    // Logical weight count (distinguishes real codes from padding).
    num_weights: usize,
    // Shared magnitude applied to every unpacked weight.
    scale: f32,
}

impl PackedTernaryWeights {
    /// Packs a slice of ternary values (expected in {-1, 0, +1}) with the
    /// given scale. Any out-of-range value is silently encoded as 0.
    pub fn pack(ternary_values: &[i8], scale: f32) -> Self {
        // Fold each group of up to four values into one byte, placing the
        // i-th value of the group at bit offset i * 2.
        let data: Vec<u8> = ternary_values
            .chunks(4)
            .map(|group| {
                group.iter().enumerate().fold(0u8, |byte, (slot, &v)| {
                    let code = match v {
                        1 => 0b01u8,
                        -1 => 0b10u8,
                        _ => 0b00u8, // 0 and any invalid value
                    };
                    byte | (code << (slot * 2))
                })
            })
            .collect();

        Self {
            data,
            num_weights: ternary_values.len(),
            scale,
        }
    }

    /// Reconstructs the original ternary values (padding bits are ignored).
    pub fn unpack(&self) -> Vec<i8> {
        (0..self.num_weights)
            .map(|i| {
                let code = (self.data[i / 4] >> ((i % 4) * 2)) & 0b11;
                match code {
                    0b01 => 1i8,
                    0b10 => -1i8,
                    _ => 0i8, // 0b00, plus the unused 0b11 pattern
                }
            })
            .collect()
    }

    /// Scale factor shared by all weights.
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Bytes consumed by the packed codes (excludes the scale itself).
    pub fn storage_bytes(&self) -> usize {
        self.data.len()
    }

    /// Number of logical weights stored.
    pub fn num_weights(&self) -> usize {
        self.num_weights
    }

    /// Counts stored weights equal to zero (padding is not counted, since
    /// only `num_weights` codes are unpacked).
    pub fn count_zeros(&self) -> usize {
        self.unpack().into_iter().filter(|&v| v == 0).count()
    }
}
120
/// Linear layer with ternary ({-1, 0, +1}) quantized weights.
///
/// Training keeps full-precision "shadow" weights that are re-quantized on
/// every forward pass; `quantize_for_inference()` freezes them into packed
/// 2-bit storage for a multiplication-free inference path.
pub struct TernaryLinear {
    /// Full-precision master weights, shape `[out_features, in_features]`.
    pub shadow_weight: Parameter,
    /// Optional bias, shape `[out_features]`.
    pub bias: Option<Parameter>,
    /// Packed ternary weights; `Some` only after `quantize_for_inference()`.
    packed_weights: Option<PackedTernaryWeights>,
    /// Input feature dimension.
    in_features: usize,
    /// Output feature dimension.
    out_features: usize,
    /// When true, `forward` uses `packed_weights` instead of shadow weights.
    inference_mode: bool,
}
164
165impl TernaryLinear {
166 pub fn new(in_features: usize, out_features: usize) -> Self {
168 Self::with_bias(in_features, out_features, true)
169 }
170
171 pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
173 let weight_data = kaiming_uniform(out_features, in_features);
174 let shadow_weight = Parameter::named("shadow_weight", weight_data, true);
175
176 let bias_param = if bias {
177 let bias_data = zeros(&[out_features]);
178 Some(Parameter::named("bias", bias_data, true))
179 } else {
180 None
181 };
182
183 Self {
184 shadow_weight,
185 bias: bias_param,
186 packed_weights: None,
187 in_features,
188 out_features,
189 inference_mode: false,
190 }
191 }
192
193 pub fn in_features(&self) -> usize {
195 self.in_features
196 }
197
198 pub fn out_features(&self) -> usize {
200 self.out_features
201 }
202
203 pub fn quantize_weights(&self) -> (Vec<i8>, f32) {
209 let w = self.shadow_weight.data();
210 let w_vec = w.to_vec();
211 let n = w_vec.len();
212
213 let abs_mean: f32 = w_vec.iter().map(|v| v.abs()).sum::<f32>() / n as f32;
215 let scale = abs_mean.max(1e-8); let ternary: Vec<i8> = w_vec
219 .iter()
220 .map(|&w| {
221 let normalized = (w.abs() / scale).round().min(1.0);
222 let sign = if w > 0.0 {
223 1i8
224 } else if w < 0.0 {
225 -1i8
226 } else {
227 0i8
228 };
229 sign * (normalized as i8)
230 })
231 .collect();
232
233 (ternary, scale)
234 }
235
236 pub fn quantize_for_inference(&mut self) {
238 let (ternary, scale) = self.quantize_weights();
239 self.packed_weights = Some(PackedTernaryWeights::pack(&ternary, scale));
240 self.inference_mode = true;
241 }
242
243 pub fn use_shadow_weights(&mut self) {
245 self.inference_mode = false;
246 }
247
248 pub fn weight_sparsity(&self) -> f32 {
250 let (ternary, _) = self.quantize_weights();
251 let zeros = ternary.iter().filter(|&&v| v == 0).count();
252 zeros as f32 / ternary.len() as f32
253 }
254
255 pub fn compression_ratio(&self) -> f32 {
257 let fp32_bytes = self.in_features * self.out_features * 4;
258 let ternary_bytes = (self.in_features * self.out_features).div_ceil(4) + 4; fp32_bytes as f32 / ternary_bytes as f32
260 }
261
262 pub fn packed_weights(&self) -> Option<&PackedTernaryWeights> {
264 self.packed_weights.as_ref()
265 }
266
267 fn ternary_matmul(
273 input: &[f32],
274 ternary: &[i8],
275 scale: f32,
276 batch_size: usize,
277 in_features: usize,
278 out_features: usize,
279 ) -> Vec<f32> {
280 let mut output = vec![0.0f32; batch_size * out_features];
281
282 for b in 0..batch_size {
283 let x_off = b * in_features;
284 let y_off = b * out_features;
285
286 for o in 0..out_features {
287 let w_off = o * in_features;
288 let mut sum_pos = 0.0f32;
289 let mut sum_neg = 0.0f32;
290
291 for j in 0..in_features {
292 let w = ternary[w_off + j];
293 let x = input[x_off + j];
294 if w == 1 {
295 sum_pos += x;
296 } else if w == -1 {
297 sum_neg += x;
298 }
299 }
301
302 output[y_off + o] = scale * (sum_pos - sum_neg);
303 }
304 }
305
306 output
307 }
308
309 fn forward_training(&self, input: &Variable) -> Variable {
311 let input_data = input.data();
312 let input_shape = input_data.shape();
313 let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
314 let total_batch: usize = batch_dims.iter().product();
315
316 let (ternary, scale) = self.quantize_weights();
318
319 let input_vec = input_data.to_vec();
321
322 let output_vec = Self::ternary_matmul(
324 &input_vec,
325 &ternary,
326 scale,
327 total_batch,
328 self.in_features,
329 self.out_features,
330 );
331
332 let mut out_shape = batch_dims.clone();
334 out_shape.push(self.out_features);
335 let output_tensor = Tensor::from_vec(output_vec, &out_shape).unwrap();
336
337 let output_tensor = if let Some(ref bias) = self.bias {
339 let bias_vec = bias.data().to_vec();
340 let mut out = output_tensor.to_vec();
341 for b in 0..total_batch {
342 for o in 0..self.out_features {
343 out[b * self.out_features + o] += bias_vec[o];
344 }
345 }
346 Tensor::from_vec(out, &out_shape).unwrap()
347 } else {
348 output_tensor
349 };
350
351 let requires_grad = input.requires_grad() && is_grad_enabled();
352 if requires_grad {
353 let saved_input = input_data.clone();
357 let saved_ternary = ternary;
358 let saved_scale = scale;
359 let in_f = self.in_features;
360 let out_f = self.out_features;
361 let shadow_grad_fn = self.shadow_weight.variable().grad_fn().cloned();
362 let bias_grad_fn = self
363 .bias
364 .as_ref()
365 .and_then(|b| b.variable().grad_fn().cloned());
366
367 let mut next_fns = vec![input.grad_fn().cloned(), shadow_grad_fn];
368 if bias_grad_fn.is_some() {
369 next_fns.push(bias_grad_fn);
370 }
371
372 let grad_fn = GradFn::new(TernaryLinearBackward {
373 next_fns,
374 saved_input,
375 saved_ternary,
376 saved_scale,
377 in_features: in_f,
378 out_features: out_f,
379 has_bias: self.bias.is_some(),
380 total_batch,
381 });
382 Variable::from_operation(output_tensor, grad_fn, true)
383 } else {
384 Variable::new(output_tensor, false)
385 }
386 }
387
388 fn forward_inference(&self, input: &Variable) -> Variable {
390 let packed = self
391 .packed_weights
392 .as_ref()
393 .expect("Must call quantize_for_inference() before inference forward");
394
395 let input_data = input.data();
396 let input_shape = input_data.shape();
397 let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
398 let total_batch: usize = batch_dims.iter().product();
399
400 let ternary = packed.unpack();
402 let scale = packed.scale();
403
404 let input_vec = input_data.to_vec();
405 let output_vec = Self::ternary_matmul(
406 &input_vec,
407 &ternary,
408 scale,
409 total_batch,
410 self.in_features,
411 self.out_features,
412 );
413
414 let mut out_shape = batch_dims;
415 out_shape.push(self.out_features);
416 let mut output_tensor = Tensor::from_vec(output_vec, &out_shape).unwrap();
417
418 if let Some(ref bias) = self.bias {
420 let bias_vec = bias.data().to_vec();
421 let mut out = output_tensor.to_vec();
422 for b in 0..total_batch {
423 for o in 0..self.out_features {
424 out[b * self.out_features + o] += bias_vec[o];
425 }
426 }
427 output_tensor = Tensor::from_vec(out, &out_shape).unwrap();
428 }
429
430 Variable::new(output_tensor, false)
431 }
432}
433
434impl Module for TernaryLinear {
435 fn forward(&self, input: &Variable) -> Variable {
436 if self.inference_mode {
437 self.forward_inference(input)
438 } else {
439 self.forward_training(input)
440 }
441 }
442
443 fn parameters(&self) -> Vec<Parameter> {
444 let mut params = vec![self.shadow_weight.clone()];
445 if let Some(ref bias) = self.bias {
446 params.push(bias.clone());
447 }
448 params
449 }
450
451 fn named_parameters(&self) -> HashMap<String, Parameter> {
452 let mut params = HashMap::new();
453 params.insert("shadow_weight".to_string(), self.shadow_weight.clone());
454 if let Some(ref bias) = self.bias {
455 params.insert("bias".to_string(), bias.clone());
456 }
457 params
458 }
459
460 fn name(&self) -> &'static str {
461 "TernaryLinear"
462 }
463}
464
/// Manual `Debug`: summarizes the layer configuration without dumping
/// parameter tensors or packed weight bytes.
impl std::fmt::Debug for TernaryLinear {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TernaryLinear")
            .field("in_features", &self.in_features)
            .field("out_features", &self.out_features)
            // Only bias *presence* is reported, not its values.
            .field("bias", &self.bias.is_some())
            .field("inference_mode", &self.inference_mode)
            .finish()
    }
}
475
/// Backward node for `TernaryLinear::forward_training`.
///
/// Saves everything needed to compute gradients for the input, the shadow
/// weights, and (optionally) the bias.
#[derive(Debug)]
struct TernaryLinearBackward {
    /// Downstream grad functions, aligned with `apply`'s output order:
    /// [input, shadow_weight, bias (only when `has_bias`)].
    next_fns: Vec<Option<GradFn>>,
    /// Forward input activations, saved for the weight gradient.
    saved_input: Tensor<f32>,
    /// Ternary weights used in the forward pass, row-major
    /// `[out_features, in_features]`.
    saved_ternary: Vec<i8>,
    /// Absmean scale used in the forward pass.
    saved_scale: f32,
    /// Input feature dimension.
    in_features: usize,
    /// Output feature dimension.
    out_features: usize,
    /// Whether a bias gradient must be produced.
    has_bias: bool,
    /// Product of all leading (batch) dims of the forward input.
    total_batch: usize,
}
501
impl GradientFunction for TernaryLinearBackward {
    /// Computes gradients for the forward `y = scale * (x @ T^T) + b`.
    ///
    /// Returns `[grad_input, grad_weight]` plus `grad_bias` when the layer
    /// has a bias — the order must match `next_fns`.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let g_vec = grad_output.to_vec();
        let x_vec = self.saved_input.to_vec();

        // grad_input = scale * (grad_output @ T), using the same ternary
        // weights as the forward pass, so only sign-dependent adds/subs.
        let mut grad_input = vec![0.0f32; self.total_batch * self.in_features];
        for b in 0..self.total_batch {
            let g_off = b * self.out_features;
            let gi_off = b * self.in_features;

            for j in 0..self.in_features {
                let mut sum = 0.0f32;
                for o in 0..self.out_features {
                    let w = self.saved_ternary[o * self.in_features + j];
                    if w == 1 {
                        sum += g_vec[g_off + o];
                    } else if w == -1 {
                        sum -= g_vec[g_off + o];
                    }
                }
                grad_input[gi_off + j] = self.saved_scale * sum;
            }
        }

        let gi_tensor = Tensor::from_vec(grad_input, self.saved_input.shape()).unwrap();

        // grad_weight = grad_output^T @ x, accumulated over the batch and
        // routed to the full-precision shadow weights (straight-through
        // estimation: the quantizer is treated as identity). NOTE(review):
        // `saved_scale` is intentionally not applied here — confirm this
        // matches the intended STE formulation.
        let mut grad_weight = vec![0.0f32; self.out_features * self.in_features];
        for b in 0..self.total_batch {
            let g_off = b * self.out_features;
            let x_off = b * self.in_features;

            for o in 0..self.out_features {
                let go = g_vec[g_off + o];
                let w_off = o * self.in_features;
                for j in 0..self.in_features {
                    grad_weight[w_off + j] += go * x_vec[x_off + j];
                }
            }
        }
        let gw_tensor =
            Tensor::from_vec(grad_weight, &[self.out_features, self.in_features]).unwrap();

        let mut results: Vec<Option<Tensor<f32>>> = vec![Some(gi_tensor), Some(gw_tensor)];

        // grad_bias = grad_output summed over the batch dimension.
        if self.has_bias {
            let mut grad_bias = vec![0.0f32; self.out_features];
            for b in 0..self.total_batch {
                for o in 0..self.out_features {
                    grad_bias[o] += g_vec[b * self.out_features + o];
                }
            }
            let gb_tensor = Tensor::from_vec(grad_bias, &[self.out_features]).unwrap();
            results.push(Some(gb_tensor));
        }

        results
    }

    fn name(&self) -> &'static str {
        "TernaryLinearBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
578
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor wires dimensions correctly and defaults to having a bias.
    #[test]
    fn test_ternary_linear_creation() {
        let layer = TernaryLinear::new(64, 32);
        assert_eq!(layer.in_features(), 64);
        assert_eq!(layer.out_features(), 32);
        assert!(layer.bias.is_some());
    }

    // `with_bias(.., false)` produces a bias-free layer.
    #[test]
    fn test_ternary_linear_no_bias() {
        let layer = TernaryLinear::with_bias(64, 32, false);
        assert!(layer.bias.is_none());
    }

    // Forward pass preserves batch dims and maps the feature dim 8 -> 4.
    #[test]
    fn test_ternary_linear_forward() {
        let layer = TernaryLinear::new(8, 4);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 16], &[2, 8]).unwrap(), false);
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }

    // Quantized values are strictly ternary with a positive scale.
    #[test]
    fn test_ternary_quantization() {
        let layer = TernaryLinear::new(16, 8);
        let (ternary, scale) = layer.quantize_weights();

        for &v in &ternary {
            assert!(v == -1 || v == 0 || v == 1, "got {}", v);
        }

        assert!(scale > 0.0);

        assert_eq!(ternary.len(), 16 * 8);
    }

    // pack -> unpack is lossless, including a length not divisible by 4.
    #[test]
    fn test_packed_ternary_roundtrip() {
        let values: Vec<i8> = vec![1, 0, -1, 1, 0, 0, -1, -1, 1, 0];
        let packed = PackedTernaryWeights::pack(&values, 0.5);
        let unpacked = packed.unpack();
        assert_eq!(values, unpacked);
        assert_eq!(packed.scale(), 0.5);
    }

    // 1024 weights at 2 bits each fit in exactly 256 bytes.
    #[test]
    fn test_packed_storage_compression() {
        let n = 1024;
        let values: Vec<i8> = (0..n).map(|i| ((i % 3) as i8) - 1).collect();
        let packed = PackedTernaryWeights::pack(&values, 1.0);
        assert_eq!(packed.storage_bytes(), 256);
    }

    // Hand-checked 1x3 @ 3x2: row0 = [1,-1,0], row1 = [0,1,1].
    #[test]
    fn test_ternary_matmul_simple() {
        let ternary = vec![1i8, -1, 0, 0, 1, 1];
        let scale = 1.0;
        let input = vec![2.0f32, 3.0, 5.0];
        let output = TernaryLinear::ternary_matmul(&input, &ternary, scale, 1, 3, 2);

        // row0: 2 - 3 = -1; row1: 3 + 5 = 8.
        assert!((output[0] - (-1.0)).abs() < 1e-6);
        assert!((output[1] - 8.0).abs() < 1e-6);
    }

    // Training-path and packed-inference-path outputs must agree.
    #[test]
    fn test_ternary_linear_inference_mode() {
        let mut layer = TernaryLinear::new(8, 4);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 8], &[1, 8]).unwrap(), false);

        let train_out = layer.forward(&input);

        layer.quantize_for_inference();
        let infer_out = layer.forward(&input);

        let train_vec = train_out.data().to_vec();
        let infer_vec = infer_out.data().to_vec();
        for (a, b) in train_vec.iter().zip(infer_vec.iter()) {
            assert!((a - b).abs() < 1e-5, "Training {} vs inference {}", a, b);
        }
    }

    // Sparsity is a valid fraction in [0, 1].
    #[test]
    fn test_ternary_linear_sparsity() {
        let layer = TernaryLinear::new(64, 32);
        let sparsity = layer.weight_sparsity();
        assert!(sparsity >= 0.0 && sparsity <= 1.0);
    }

    // 512x512: 2 bits/weight + 4-byte scale gives ~16x compression.
    #[test]
    fn test_ternary_linear_compression_ratio() {
        let layer = TernaryLinear::new(512, 512);
        let ratio = layer.compression_ratio();
        assert!(ratio > 14.0 && ratio < 17.0, "ratio = {}", ratio);
    }

    // Parameter count reflects bias presence.
    #[test]
    fn test_ternary_linear_parameters() {
        let layer = TernaryLinear::new(16, 8);
        let params = layer.parameters();
        assert_eq!(params.len(), 2);
        let layer_no_bias = TernaryLinear::with_bias(16, 8, false);
        assert_eq!(layer_no_bias.parameters().len(), 1);
    }

    // End-to-end autograd: backward through the layer reaches the input.
    #[test]
    fn test_ternary_linear_backward() {
        let layer = TernaryLinear::new(4, 2);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
            true,
        );
        let output = layer.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some());
    }
}