1use ternlang_core::trit::Trit;
19use serde::{Serialize, Deserialize};
20
21pub mod spectra_compat {
24 use super::*;
25
26 pub fn import_spectra_weights(raw_data: &[f32], rows: usize, cols: usize) -> TritMatrix {
29 println!("ternlang-ml: Annexing Spectra-1.1 weights (Scale: 1.2T tokens)...");
30 TritMatrix::from_f32(rows, cols, raw_data, 0.5)
32 }
33}
34
35pub mod coherence;
36pub mod qat;
37pub mod perplexity;
38
39pub fn quantize(weights: &[f32], threshold: f32) -> Vec<Trit> {
50 weights.iter().map(|&w| {
51 if w > threshold {
52 Trit::Affirm
53 } else if w < -threshold {
54 Trit::Reject
55 } else {
56 Trit::Tend
57 }
58 }).collect()
59}
60
/// Computes a quantization threshold as half of the mean absolute weight.
///
/// Returns `0.0` for an empty slice; dividing by a zero length would
/// otherwise yield NaN, and a NaN threshold makes `quantize` map every
/// weight to `Trit::Tend`.
pub fn bitnet_threshold(weights: &[f32]) -> f32 {
    if weights.is_empty() {
        return 0.0;
    }
    let mean_abs = weights.iter().map(|w| w.abs()).sum::<f32>() / weights.len() as f32;
    0.5 * mean_abs
}
66
/// Dense matrix of ternary values stored in row-major order.
///
/// Element `(row, col)` lives at `data[row * cols + col]` (see `get` / `set`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TritMatrix {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns; also the row stride into `data`.
    pub cols: usize,
    /// Flat row-major element storage; the constructors create it with `rows * cols` entries.
    pub data: Vec<Trit>,
}
76
77impl TritMatrix {
78 pub fn new(rows: usize, cols: usize) -> Self {
79 Self { rows, cols, data: vec![Trit::Tend; rows * cols] }
80 }
81
82 pub fn from_trits(rows: usize, cols: usize, data: Vec<Trit>) -> Self {
83 assert_eq!(data.len(), rows * cols);
84 Self { rows, cols, data }
85 }
86
87 pub fn from_f32(rows: usize, cols: usize, weights: &[f32], threshold: f32) -> Self {
88 Self::from_trits(rows, cols, quantize(weights, threshold))
89 }
90
91 #[inline]
92 pub fn get(&self, row: usize, col: usize) -> Trit {
93 self.data[row * self.cols + col]
94 }
95
96 #[inline]
97 pub fn set(&mut self, row: usize, col: usize, val: Trit) {
98 self.data[row * self.cols + col] = val;
99 }
100
101 pub fn sparsity(&self) -> f64 {
103 let zeros = self.data.iter().filter(|&&t| t == Trit::Tend).count();
104 zeros as f64 / self.data.len() as f64
105 }
106
107 pub fn nnz(&self) -> usize {
109 self.data.iter().filter(|&&t| t != Trit::Tend).count()
110 }
111
112 pub fn to_i8_vec(&self) -> Vec<i8> {
114 self.data.iter().map(|&t| match t {
115 Trit::Affirm => 1,
116 Trit::Reject => -1,
117 Trit::Tend => 0,
118 }).collect()
119 }
120}
121
122pub fn dense_matmul(a: &TritMatrix, b: &TritMatrix) -> TritMatrix {
128 assert_eq!(a.cols, b.rows, "matmul dimension mismatch: a.cols must equal b.rows");
129 let mut c = TritMatrix::new(a.rows, b.cols);
130 for row in 0..a.rows {
131 for col in 0..b.cols {
132 let mut acc = Trit::Tend;
133 for k in 0..a.cols {
134 let prod = a.get(row, k) * b.get(k, col);
135 let (sum, _carry) = acc + prod;
136 acc = sum;
137 }
138 c.set(row, col, acc);
139 }
140 }
141 c
142}
143
144pub fn sparse_matmul(a: &TritMatrix, b: &TritMatrix) -> (TritMatrix, usize) {
164 use rayon::prelude::*;
165
166 assert_eq!(a.cols, b.rows, "matmul dimension mismatch");
167
168 #[inline(always)]
169 fn t2i(t: Trit) -> i8 {
170 match t { Trit::Reject => -1, Trit::Tend => 0, Trit::Affirm => 1 }
171 }
172
173 let a_flat: Vec<i8> = a.data.iter().map(|&t| t2i(t)).collect();
175 let a_cols = a.cols;
176
177 let mut csc_offsets = vec![0usize; b.cols + 1];
182 for k in 0..b.rows {
184 for j in 0..b.cols {
185 if t2i(b.data[k * b.cols + j]) != 0 {
186 csc_offsets[j + 1] += 1;
187 }
188 }
189 }
190 for j in 0..b.cols {
192 csc_offsets[j + 1] += csc_offsets[j];
193 }
194 let nnz = csc_offsets[b.cols];
195 let mut csc_idx = vec![0u32; nnz];
196 let mut csc_val = vec![0i8; nnz];
197 let mut col_cursor = csc_offsets[..b.cols].to_vec(); for k in 0..b.rows {
199 for j in 0..b.cols {
200 let w = t2i(b.data[k * b.cols + j]);
201 if w != 0 {
202 let pos = col_cursor[j];
203 csc_idx[pos] = k as u32;
204 csc_val[pos] = w;
205 col_cursor[j] += 1;
206 }
207 }
208 }
209
210 let dense_ops = a.rows * b.cols * a.cols;
211 let active_ops = nnz * a.rows;
212 let skipped = dense_ops.saturating_sub(active_ops);
213
214 let mut out_flat = vec![0i8; a.rows * b.cols];
217
218 out_flat
219 .par_chunks_mut(b.cols)
220 .enumerate()
221 .for_each(|(row, row_out)| {
222 let a_row = &a_flat[row * a_cols..(row + 1) * a_cols];
223 for col in 0..b.cols {
224 let start = csc_offsets[col];
225 let end = csc_offsets[col + 1];
226 let mut acc: i32 = 0;
227 for i in start..end {
230 let k = unsafe { *csc_idx.get_unchecked(i) } as usize;
231 let w = unsafe { *csc_val.get_unchecked(i) } as i32;
232 let av = unsafe { *a_row.get_unchecked(k) } as i32;
233 acc += av * w;
234 }
235 row_out[col] = if acc > 0 { 1 } else if acc < 0 { -1 } else { 0 };
236 }
237 });
238
239 let c_data: Vec<Trit> = out_flat.into_iter().map(|v| Trit::from(v)).collect();
241 let c = TritMatrix { rows: a.rows, cols: b.cols, data: c_data };
242
243 (c, skipped)
244}
245
/// Linear layer: multiplies `input` by `weights` using [`sparse_matmul`].
///
/// Returns the product together with the skipped-operation count reported by
/// `sparse_matmul`.
pub fn linear(input: &TritMatrix, weights: &TritMatrix) -> (TritMatrix, usize) {
    sparse_matmul(input, weights)
}
256
/// Operation-count statistics for one sparse matrix multiplication.
pub struct BenchmarkResult {
    /// Multiply count of the equivalent dense product.
    pub dense_ops: usize,
    /// Multiplies remaining after skipping.
    pub sparse_ops: usize,
    /// Multiplies avoided by the sparse kernel.
    pub skipped_ops: usize,
    /// Skipped operations as a fraction of `dense_ops`.
    pub skip_rate: f64,
    /// Fraction of zero entries in the weight matrix.
    pub weight_sparsity: f64,
}

impl BenchmarkResult {
    /// Prints a human-readable summary of the statistics to stdout.
    ///
    /// The "ops saved" ratio divides by `max(sparse_ops, 1)` so it stays finite
    /// when no multiplies were executed.
    pub fn print_summary(&self) {
        let sparsity_pct = self.weight_sparsity * 100.0;
        let skip_pct = self.skip_rate * 100.0;
        let ops_ratio = self.dense_ops as f64 / self.sparse_ops.max(1) as f64;

        println!("=== Ternary Sparse Matmul Benchmark ===");
        println!(" Weight sparsity: {:.1}% zeros", sparsity_pct);
        println!(" Dense ops: {}", self.dense_ops);
        println!(" Sparse ops: {}", self.sparse_ops);
        println!(" Skipped ops: {}", self.skipped_ops);
        println!(" Skip rate: {:.1}%", skip_pct);
        println!(" Ops saved: {:.1}x fewer multiplies", ops_ratio);
    }
}
279
280pub fn benchmark(a: &TritMatrix, b: &TritMatrix) -> BenchmarkResult {
281 let dense_ops = a.rows * a.cols * b.cols;
282 let (_result, skipped) = sparse_matmul(a, b);
283 let sparse_ops = dense_ops - skipped;
284 BenchmarkResult {
285 dense_ops,
286 sparse_ops,
287 skipped_ops: skipped,
288 skip_rate: skipped as f64 / dense_ops as f64,
289 weight_sparsity: b.sparsity(),
290 }
291}
292
/// Identity activation: returns its argument unchanged.
pub fn trit_activation(t: Trit) -> Trit { t }
299
300pub fn majority(trits: &[Trit]) -> Trit {
303 let sum: i32 = trits.iter().map(|&t| match t {
304 Trit::Affirm => 1,
305 Trit::Reject => -1,
306 Trit::Tend => 0,
307 }).sum();
308 match sum.signum() {
309 1 => Trit::Affirm,
310 -1 => Trit::Reject,
311 _ => Trit::Tend,
312 }
313}
314
315pub struct TernaryMLP {
325 pub w1: TritMatrix, pub w2: TritMatrix, pub in_features: usize,
328 pub hidden_size: usize,
329 pub out_features: usize,
330}
331
332impl TernaryMLP {
333 pub fn new(w1: TritMatrix, w2: TritMatrix) -> Self {
335 let in_features = w1.rows;
336 let hidden_size = w1.cols;
337 let out_features = w2.cols;
338 assert_eq!(w2.rows, hidden_size, "w1.cols must equal w2.rows");
339 Self { w1, w2, in_features, hidden_size, out_features }
340 }
341
342 pub fn from_f32(
344 in_features: usize, hidden_size: usize, out_features: usize,
345 w1_f32: &[f32], w2_f32: &[f32],
346 ) -> Self {
347 let tau1 = bitnet_threshold(w1_f32);
348 let tau2 = bitnet_threshold(w2_f32);
349 let w1 = TritMatrix::from_f32(in_features, hidden_size, w1_f32, tau1);
350 let w2 = TritMatrix::from_f32(hidden_size, out_features, w2_f32, tau2);
351 Self::new(w1, w2)
352 }
353
354 pub fn forward(&self, input: &TritMatrix) -> (TritMatrix, usize, usize) {
358 assert_eq!(input.cols, self.in_features,
359 "input width must match in_features");
360
361 let (hidden, skip1) = sparse_matmul(input, &self.w1);
363
364 let hidden_act = TritMatrix::from_trits(
366 hidden.rows, hidden.cols,
367 hidden.data.iter().map(|&t| trit_activation(t)).collect(),
368 );
369
370 let (output, skip2) = sparse_matmul(&hidden_act, &self.w2);
372
373 (output, skip1, skip2)
374 }
375
376 pub fn predict(&self, input: &TritMatrix) -> usize {
379 let (output, _, _) = self.forward(input);
380 let row = 0;
381 let mut best_col = 0;
382 let mut best_val: i8 = -2;
383 for col in 0..self.out_features {
384 let v = match output.get(row, col) {
385 Trit::Affirm => 1,
386 Trit::Tend => 0,
387 Trit::Reject => -1,
388 };
389 if v > best_val { best_val = v; best_col = col; }
390 }
391 best_col
392 }
393
394 pub fn layer1_sparsity(&self) -> f64 { self.w1.sparsity() }
395 pub fn layer2_sparsity(&self) -> f64 { self.w2.sparsity() }
396
397 pub fn forward_logits(&self, input: &[f32]) -> Vec<f32> {
404 assert_eq!(input.len(), self.in_features);
405 let (inf, hs, outf) = (self.in_features, self.hidden_size, self.out_features);
406
407 let w1_f: Vec<f32> = self.w1.to_i8_vec().iter().map(|&v| v as f32).collect();
409 let w2_f: Vec<f32> = self.w2.to_i8_vec().iter().map(|&v| v as f32).collect();
410
411 let mut hidden = vec![0.0f32; hs];
413 for j in 0..hs {
414 for i in 0..inf {
415 hidden[j] += input[i] * w1_f[i * hs + j];
416 }
417 }
418
419 let hidden_act: Vec<f32> = hidden.iter().map(|&h| {
421 if h > 0.0 { 1.0 } else if h < 0.0 { -1.0 } else { 0.0 }
422 }).collect();
423
424 let mut output = vec![0.0f32; outf];
426 for j in 0..outf {
427 for i in 0..hs {
428 output[j] += hidden_act[i] * w2_f[i * outf + j];
429 }
430 }
431 output
432 }
433}
434
/// One row of a timed dense-vs-sparse matmul benchmark.
#[derive(Debug)]
pub struct TimedResult {
    /// Side length `n` of the square `n x n` matrices.
    pub size: usize,
    /// Multiply count of the dense product.
    pub dense_ops: usize,
    /// Multiplies remaining after skipping.
    pub sparse_ops: usize,
    /// Multiplies avoided by the sparse kernel.
    pub skipped_ops: usize,
    /// Fraction of zero entries in the right-hand matrix.
    pub weight_sparsity: f64,
    /// Skipped operations as a fraction of `dense_ops`.
    pub skip_rate: f64,
    /// Dense time divided by sparse time (or an operation-count ratio as fallback).
    pub speedup: f64,
    /// Median dense runtime in microseconds.
    pub dense_us: u64,
    /// Median sparse runtime in microseconds.
    pub sparse_us: u64,
}
450
451pub fn timed_benchmark(sizes: &[usize], reps: usize) -> Vec<TimedResult> {
456 use std::time::Instant;
457
458 fn lcg_weights(n: usize, seed: u64) -> Vec<f32> {
460 let mut state = seed;
461 (0..n).map(|_| {
462 state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
463 let f = ((state >> 33) as f32) / (u32::MAX as f32) * 3.0 - 1.5;
466 f
467 }).collect()
468 }
469
470 fn median_us(mut times: Vec<u64>) -> u64 {
471 times.sort_unstable();
472 times[times.len() / 2]
473 }
474
475 sizes.iter().map(|&n| {
476 let weights_a = lcg_weights(n * n, 0xdeadbeef);
477 let weights_b = lcg_weights(n * n, 0xc0ffee42);
478 let tau_a = bitnet_threshold(&weights_a);
479 let tau_b = bitnet_threshold(&weights_b);
480 let a = TritMatrix::from_f32(n, n, &weights_a, tau_a);
481
482 let b = TritMatrix::from_f32(n, n, &weights_b, tau_b);
483
484 let sparsity = b.sparsity();
485 let dense_ops = n * n * n;
486 let (_, skipped) = sparse_matmul(&a, &b); let sparse_ops = dense_ops - skipped;
488
489 let dense_times: Vec<u64> = (0..reps).map(|_| {
491 let t = Instant::now();
492 let _ = dense_matmul(&a, &b);
493 t.elapsed().as_micros() as u64
494 }).collect();
495
496 let sparse_times: Vec<u64> = (0..reps).map(|_| {
498 let t = Instant::now();
499 let _ = sparse_matmul(&a, &b);
500 t.elapsed().as_micros() as u64
501 }).collect();
502
503 let dense_us = median_us(dense_times);
504 let sparse_us = median_us(sparse_times);
505 let speedup = if sparse_us > 0 {
506 dense_us as f64 / sparse_us as f64
507 } else { dense_ops as f64 / sparse_ops.max(1) as f64 };
508
509 TimedResult {
510 size: n, dense_ops, sparse_ops, skipped_ops: skipped,
511 weight_sparsity: sparsity, skip_rate: skipped as f64 / dense_ops as f64,
512 speedup, dense_us, sparse_us,
513 }
514 }).collect()
515}
516
/// Prints `results` to stdout as a box-drawn table, one row per entry.
///
/// Columns: matrix size, weight sparsity (%), dense and sparse times in
/// microseconds, speedup, and skip rate (%).
pub fn print_benchmark_table(results: &[TimedResult]) {
    // Header: title row followed by the column captions.
    println!("\n╔══════════════════════════════════════════════════════════════════════╗");
    println!( "║ Ternlang Sparse Matmul Benchmark — RFI-IRFOS TIS ║");
    println!( "╠════════╦══════════╦═══════════╦══════════╦══════════╦═════════════╣");
    println!( "║ Size ║ Sparsity ║ Dense μs ║ Sparse μs║ Speedup ║ Skip rate ║");
    println!( "╠════════╬══════════╬═══════════╬══════════╬══════════╬═════════════╣");
    for r in results {
        // Fractions are scaled to percentages for display.
        println!("║ {:>4}² ║ {:>5.1}% ║ {:>7} ║ {:>7} ║ {:>5.2}× ║ {:>6.1}% ║",
            r.size,
            r.weight_sparsity * 100.0,
            r.dense_us,
            r.sparse_us,
            r.speedup,
            r.skip_rate * 100.0,
        );
    }
    println!( "╚════════╩══════════╩═══════════╩══════════╩══════════╩═════════════╝");
}
536
537pub fn bitnet_matrix(rows: usize, cols: usize, seed: u64, target_sparsity: f64) -> TritMatrix {
543 let mut state = seed;
544 let n = rows * cols;
545 let mut data = Vec::with_capacity(n);
546 for _ in 0..n {
547 state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
548 let prob = (state >> 32) as f64 / (u32::MAX as f64 + 1.0);
549 if prob < target_sparsity {
550 data.push(Trit::Tend);
551 } else if (state & 1) == 0 {
552 data.push(Trit::Affirm);
553 } else {
554 data.push(Trit::Reject);
555 }
556 }
557 TritMatrix { rows, cols, data }
558}
559
/// Runs [`timed_benchmark_at_sparsity`] with a fixed target sparsity of `0.60`.
pub fn timed_benchmark_bitnet(sizes: &[usize], reps: usize) -> Vec<TimedResult> {
    timed_benchmark_at_sparsity(0.60, sizes, reps)
}
566
567pub fn timed_benchmark_at_sparsity(target_sparsity: f64, sizes: &[usize], reps: usize) -> Vec<TimedResult> {
569 use std::time::Instant;
570
571 let bitnet_sparsity: f64 = target_sparsity;
572
573 fn median_us(mut v: Vec<u64>) -> u64 {
574 v.sort_unstable();
575 v[v.len() / 2]
576 }
577
578 sizes.iter().map(|&n| {
579 let a = bitnet_matrix(n, n, 0xdeadbeef, bitnet_sparsity);
580 let b = bitnet_matrix(n, n, 0xc0ffee42, bitnet_sparsity);
581
582 let sparsity = b.sparsity();
583 let dense_ops = n * n * n;
584 let (_, skipped) = sparse_matmul(&a, &b);
585 let sparse_ops = dense_ops - skipped;
586 let speedup_ops = dense_ops as f64 / sparse_ops.max(1) as f64;
587
588 let dense_times: Vec<u64> = (0..reps).map(|_| {
589 let t = Instant::now();
590 let _ = dense_matmul(&a, &b);
591 t.elapsed().as_micros() as u64
592 }).collect();
593
594 let sparse_times: Vec<u64> = (0..reps).map(|_| {
595 let t = Instant::now();
596 let _ = sparse_matmul(&a, &b);
597 t.elapsed().as_micros() as u64
598 }).collect();
599
600 let dense_us = median_us(dense_times);
601 let sparse_us = median_us(sparse_times);
602 let speedup = if sparse_us > 0 {
603 dense_us as f64 / sparse_us as f64
604 } else { speedup_ops };
605
606 TimedResult {
607 size: n, dense_ops, sparse_ops, skipped_ops: skipped,
608 weight_sparsity: sparsity, skip_rate: skipped as f64 / dense_ops as f64,
609 speedup, dense_us, sparse_us,
610 }
611 }).collect()
612}
613
614pub fn xor_dataset() -> Vec<(TritMatrix, usize)> {
619 let inputs = vec![
620 (vec![Trit::Reject, Trit::Reject], 0usize), (vec![Trit::Reject, Trit::Affirm], 1usize), (vec![Trit::Affirm, Trit::Reject], 1usize), (vec![Trit::Affirm, Trit::Affirm], 0usize), ];
625 inputs.into_iter().map(|(row, label)| {
626 (TritMatrix::from_trits(1, 2, row), label)
627 }).collect()
628}
629
630pub fn parity_dataset() -> Vec<(TritMatrix, usize)> {
632 (0u8..8).map(|i| {
633 let bits = vec![
634 if i & 4 != 0 { Trit::Affirm } else { Trit::Reject },
635 if i & 2 != 0 { Trit::Affirm } else { Trit::Reject },
636 if i & 1 != 0 { Trit::Affirm } else { Trit::Reject },
637 ];
638 let parity = (i.count_ones() % 2) as usize;
639 (TritMatrix::from_trits(1, 3, bits), parity)
640 }).collect()
641}
642
643pub fn evaluate(mlp: &TernaryMLP, dataset: &[(TritMatrix, usize)]) -> (usize, usize, f64) {
646 let total = dataset.len();
647 let correct = dataset.iter()
648 .filter(|(input, label)| mlp.predict(input) == *label)
649 .count();
650 let accuracy = correct as f64 / total as f64;
651 (correct, total, accuracy)
652}
653
/// Half-width of the "tend" zone on the `[-1, 1]` scalar scale.
///
/// `TritScalar::trit` classifies values strictly above this as affirm,
/// strictly below its negation as reject, and everything else as tend.
pub const TEND_BOUNDARY: f32 = 1.0 / 3.0;
669
/// Continuous ternary value wrapping an `f32`.
///
/// `TritScalar::new` clamps the value to `[-1.0, 1.0]`; the tuple field is
/// public, so a value constructed directly is not clamped.
#[derive(Debug, Clone)]
pub struct TritScalar(pub f32);
673
674impl TritScalar {
675 pub fn new(v: f32) -> Self { TritScalar(v.clamp(-1.0, 1.0)) }
677
678 pub fn trit(&self) -> Trit {
680 if self.0 > TEND_BOUNDARY { Trit::Affirm }
681 else if self.0 < -TEND_BOUNDARY { Trit::Reject }
682 else { Trit::Tend }
683 }
684
685 pub fn label(&self) -> &'static str {
687 match self.trit() {
688 Trit::Affirm => "affirm",
689 Trit::Reject => "reject",
690 Trit::Tend => "tend",
691 }
692 }
693
694 pub fn confidence(&self) -> f32 {
699 let v = self.0.abs();
700 if v > TEND_BOUNDARY {
701 (v - TEND_BOUNDARY) / (1.0 - TEND_BOUNDARY)
702 } else {
703 1.0 - v / TEND_BOUNDARY
704 }
705 }
706
707 pub fn is_actionable(&self, min_confidence: f32) -> bool {
710 self.trit() != Trit::Tend && self.confidence() >= min_confidence
711 }
712
713 pub fn raw(&self) -> f32 { self.0 }
715
716 pub fn trit_i8(&self) -> i8 {
718 match self.trit() { Trit::Affirm => 1, Trit::Reject => -1, Trit::Tend => 0 }
719 }
720}
721
/// Named evidence values with per-dimension weights.
///
/// The three vectors are parallel: index `i` of each describes one dimension.
pub struct TritEvidenceVec {
    /// Name of each evidence dimension.
    pub dimensions: Vec<String>,
    /// Evidence value per dimension (clamped to `[-1.0, 1.0]` by `new`).
    pub values: Vec<f32>,
    /// Weight per dimension, used by the weighted mean in `aggregate`.
    pub weights: Vec<f32>,
}
740
741impl TritEvidenceVec {
742 pub fn new(dimensions: Vec<String>, values: Vec<f32>, weights: Vec<f32>) -> Self {
743 assert_eq!(dimensions.len(), values.len(), "dimensions and values must match");
744 assert_eq!(dimensions.len(), weights.len(), "dimensions and weights must match");
745 let values = values.iter().map(|&v| v.clamp(-1.0, 1.0)).collect();
746 TritEvidenceVec { dimensions, values, weights }
747 }
748
749 pub fn aggregate(&self) -> TritScalar {
751 let total_weight: f32 = self.weights.iter().sum();
752 if total_weight == 0.0 { return TritScalar::new(0.0); }
753 let weighted_sum: f32 = self.values.iter()
754 .zip(self.weights.iter())
755 .map(|(v, w)| v * w)
756 .sum();
757 TritScalar::new(weighted_sum / total_weight)
758 }
759
760 pub fn scalars(&self) -> Vec<TritScalar> {
762 self.values.iter().map(|&v| TritScalar::new(v)).collect()
763 }
764
765 pub fn dominant(&self) -> Option<(&str, TritScalar)> {
767 self.values.iter()
768 .enumerate()
769 .max_by(|(_, a), (_, b)| a.abs().partial_cmp(&b.abs()).unwrap_or(std::cmp::Ordering::Equal))
770 .map(|(i, &v)| (self.dimensions[i].as_str(), TritScalar::new(v)))
771 }
772}
773
// Unit tests for quantization, matmul kernels, the ternary MLP, the benchmark
// helpers, `TritScalar` and `TritEvidenceVec`.
#[cfg(test)]
mod tests {
    use super::*;

    // Weights inside [-0.5, 0.5] become Tend; the outer two become Reject / Affirm.
    #[test]
    fn test_quantize_basic() {
        let weights = vec![-0.9f32, -0.2, 0.0, 0.3, 0.8];
        let threshold = 0.5;
        let trits = quantize(&weights, threshold);
        assert_eq!(trits, vec![Trit::Reject, Trit::Tend, Trit::Tend, Trit::Tend, Trit::Affirm]);
    }

    // mean(|w|) = 0.75, so the threshold is 0.5 * 0.75 = 0.375.
    #[test]
    fn test_bitnet_threshold() {
        let weights = vec![1.0f32, -1.0, 0.5, -0.5];
        let tau = bitnet_threshold(&weights);
        assert!((tau - 0.375).abs() < 1e-6);
    }
    // A 2x2 identity multiplied by itself stays the identity.
    #[test]
    fn test_dense_matmul_identity() {
        let mut id = TritMatrix::new(2, 2);
        id.set(0, 0, Trit::Affirm);
        id.set(1, 1, Trit::Affirm);

        let result = dense_matmul(&id, &id);
        assert_eq!(result.get(0, 0), Trit::Affirm);
        assert_eq!(result.get(0, 1), Trit::Tend);
        assert_eq!(result.get(1, 0), Trit::Tend);
        assert_eq!(result.get(1, 1), Trit::Affirm);
    }

    // The input is diagonal, so every output cell has at most one non-zero product.
    #[test]
    fn test_sparse_matmul_matches_dense() {
        let weights = vec![0.9f32, -0.1, 0.05, -0.8, 0.0, 0.7, -0.6, 0.2, 0.0];
        let threshold = 0.5;
        let w = TritMatrix::from_f32(3, 3, &weights, threshold);
        let mut input = TritMatrix::new(3, 3);
        input.set(0, 0, Trit::Affirm);
        input.set(1, 1, Trit::Reject);
        input.set(2, 2, Trit::Affirm);

        let dense = dense_matmul(&input, &w);
        let (sparse, skipped) = sparse_matmul(&input, &w);

        for r in 0..3 {
            for c in 0..3 {
                assert_eq!(dense.get(r, c), sparse.get(r, c),
                    "mismatch at ({}, {})", r, c);
            }
        }
        assert!(skipped > 0, "expected skips for a sparse weight matrix");
    }

    // One of three weights falls inside the dead zone: sparsity 1/3, two non-zeros.
    #[test]
    fn test_sparsity_measurement() {
        let weights = vec![0.9f32, 0.1, -0.9]; let threshold = 0.5;
        let m = TritMatrix::from_f32(1, 3, &weights, threshold);
        assert!((m.sparsity() - 1.0/3.0).abs() < 1e-9);
        assert_eq!(m.nnz(), 2);
    }

    // Majority follows the sign of the vote balance; ties give Tend.
    #[test]
    fn test_majority_vote() {
        assert_eq!(majority(&[Trit::Affirm, Trit::Affirm, Trit::Reject]), Trit::Affirm);
        assert_eq!(majority(&[Trit::Reject, Trit::Reject, Trit::Affirm]), Trit::Reject);
        assert_eq!(majority(&[Trit::Affirm, Trit::Reject]), Trit::Tend);
        assert_eq!(majority(&[Trit::Tend, Trit::Tend]), Trit::Tend);
    }

    // Shape-only check of a 2 -> 4 -> 2 forward pass.
    #[test]
    fn test_mlp_forward_runs() {
        let w1_f32: Vec<f32> = vec![
            0.9, -0.8, 0.7, -0.6,
            -0.7, 0.9, -0.5, 0.8,
        ];
        let w2_f32: Vec<f32> = vec![
            0.9, -0.9,
            -0.8, 0.8,
            0.7, -0.7,
            -0.6, 0.6,
        ];
        let mlp = TernaryMLP::from_f32(2, 4, 2, &w1_f32, &w2_f32);
        let input = TritMatrix::from_trits(1, 2, vec![Trit::Affirm, Trit::Reject]);
        let (out, s1, s2) = mlp.forward(&input);
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, 2);
        let _ = (s1, s2);
    }

    // The predicted class index must lie within the two output columns.
    #[test]
    fn test_mlp_predict_returns_valid_class() {
        let w1_f32: Vec<f32> = vec![0.9, -0.8, -0.7, 0.9];
        let w2_f32: Vec<f32> = vec![0.9, -0.9, -0.8, 0.8];
        let mlp = TernaryMLP::from_f32(2, 2, 2, &w1_f32, &w2_f32);
        let input = TritMatrix::from_trits(1, 2, vec![Trit::Affirm, Trit::Reject]);
        let pred = mlp.predict(&input);
        assert!(pred < 2, "prediction must be a valid class index");
    }

    // Four 1x2 samples with binary labels.
    #[test]
    fn test_xor_dataset_shape() {
        let ds = xor_dataset();
        assert_eq!(ds.len(), 4);
        for (input, label) in &ds {
            assert_eq!(input.rows, 1);
            assert_eq!(input.cols, 2);
            assert!(*label < 2);
        }
    }

    // Eight 3-wide samples with binary labels.
    #[test]
    fn test_parity_dataset_shape() {
        let ds = parity_dataset();
        assert_eq!(ds.len(), 8);
        for (input, label) in &ds {
            assert_eq!(input.cols, 3);
            assert!(*label < 2);
        }
    }

    // Hand-picked weights; only requires at least half of XOR to be classified correctly.
    #[test]
    fn test_xor_mlp_with_known_weights() {
        let w1_f32 = vec![
            1.0, -1.0,
            -1.0, 1.0,
        ];
        let w2_f32 = vec![
            -1.0, 1.0,
            -1.0, 1.0,
        ];
        let mlp = TernaryMLP::from_f32(2, 2, 2, &w1_f32, &w2_f32);
        let ds = xor_dataset();
        let (correct, total, acc) = evaluate(&mlp, &ds);
        println!("XOR MLP: {}/{} = {:.0}%", correct, total, acc * 100.0);
        assert!(correct >= 2, "MLP should get at least half of XOR correct");
    }

    // Smoke test: small sizes, checks that the reported fractions are within [0, 1].
    #[test]
    fn test_timed_benchmark_small() {
        let results = timed_benchmark(&[8, 16], 3);
        assert_eq!(results.len(), 2);
        for r in &results {
            assert!(r.dense_ops > 0);
            assert!(r.weight_sparsity >= 0.0 && r.weight_sparsity <= 1.0);
            assert!(r.skip_rate >= 0.0 && r.skip_rate <= 1.0);
        }
        print_benchmark_table(&results);
    }

    // The weight matrix contains dead-zone entries, so some operations must be skipped.
    #[test]
    fn test_benchmark_reports_skips() {
        let weights: Vec<f32> = vec![
            0.9, 0.1, -0.9, 0.0,
            0.1, 0.8, 0.0, -0.7,
            0.0, 0.1, 0.6, 0.2,
            -0.8, 0.0, 0.1, 0.9,
        ];
        let threshold = 0.5;
        let w = TritMatrix::from_f32(4, 4, &weights, threshold);
        let input = TritMatrix::new(4, 4); let result = benchmark(&input, &w);
        assert!(result.skipped_ops > 0);
        assert!(result.skip_rate > 0.0 && result.skip_rate <= 1.0);
        result.print_summary();
    }

    // Runs the timed benchmark up to 512x512 and prints the table.
    #[test]
    fn test_full_benchmark() {
        let results = timed_benchmark(&[32, 64, 128, 256, 512], 5);
        assert_eq!(results.len(), 5);
        print_benchmark_table(&results);
    }

    // 60% target sparsity: expects at least half of the operations to be skipped.
    #[test]
    fn test_bitnet_benchmark() {
        let results = timed_benchmark_bitnet(&[32, 64, 128, 256, 512], 5);
        assert_eq!(results.len(), 5);
        println!("\n╔══════════════════════════════════════════════════════════════════════╗");
        println!( "║ BitNet b1.58 Realistic Benchmark — 60% Sparsity — RFI-IRFOS TIS ║");
        println!( "╠════════╦══════════╦═══════════╦══════════╦══════════╦═════════════╣");
        println!( "║ Size ║ Sparsity ║ Dense μs ║ Sparse μs║ Speedup ║ Skip rate ║");
        println!( "╠════════╬══════════╬═══════════╬══════════╬══════════╬═════════════╣");
        for r in &results {
            println!("║ {:>4}² ║ {:>5.1}% ║ {:>7} ║ {:>7} ║ {:>5.2}× ║ {:>6.1}% ║",
                r.size,
                r.weight_sparsity * 100.0,
                r.dense_us,
                r.sparse_us,
                r.speedup,
                r.skip_rate * 100.0,
            );
        }
        println!( "╚════════╩══════════╩═══════════╩══════════╩══════════╩═════════════╝");
        for r in &results {
            assert!(r.skip_rate >= 0.50, "Expected ≥50% skip rate at 60% sparsity, got {:.1}%", r.skip_rate * 100.0);
        }
    }

    // 99% target sparsity: expects at least 95% of the operations to be skipped.
    #[test]
    fn test_extreme_sparsity_99() {
        let results = timed_benchmark_at_sparsity(0.99, &[32, 64, 128, 256, 512], 5);
        assert_eq!(results.len(), 5);
        println!("\n╔══════════════════════════════════════════════════════════════════════╗");
        println!( "║ EXTREME SPARSITY — 99% Zeros — What Happens? ║");
        println!( "╠════════╦══════════╦═══════════╦══════════╦══════════╦═════════════╣");
        println!( "║ Size ║ Sparsity ║ Dense μs ║ Sparse μs║ Speedup ║ Skip rate ║");
        println!( "╠════════╬══════════╬═══════════╬══════════╬══════════╬═════════════╣");
        for r in &results {
            println!("║ {:>4}² ║ {:>5.1}% ║ {:>7} ║ {:>7} ║ {:>6.1}× ║ {:>6.1}% ║",
                r.size,
                r.weight_sparsity * 100.0,
                r.dense_us,
                r.sparse_us,
                r.speedup,
                r.skip_rate * 100.0,
            );
        }
        println!( "╚════════╩══════════╩═══════════╩══════════╩══════════╩═════════════╝");
        for r in &results {
            assert!(r.skip_rate >= 0.95, "Expected ≥95% skip rate at 99% sparsity");
        }
    }

    // Sweeps sparsity x size, prints a speedup grid plus the peak and best-average rows.
    // NOTE(review): the final assertion depends on wall-clock timings, so it may be
    // sensitive to machine load — confirm it is stable in CI.
    #[test]
    fn test_sparsity_sweep() {
        let sparsities: &[f64] = &[0.25, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.99];
        let sizes: &[usize] = &[32, 64, 128, 256, 512];

        // grid[i][j] is the speedup at sparsities[i] and sizes[j].
        let mut grid: Vec<Vec<f64>> = Vec::new();
        for &sp in sparsities {
            let row: Vec<f64> = timed_benchmark_at_sparsity(sp, sizes, 3)
                .into_iter().map(|r| r.speedup).collect();
            grid.push(row);
        }

        println!();
        println!("╔══════════════ SPARSITY GOLDILOCKS SWEEP ══════════════════════════╗");
        println!("║ Speedup (sparse / dense) across sparsity × matrix size ║");
        println!("╠══════════╦═══════╦═══════╦════════╦════════╦════════╣");
        print!( "║ Sparsity ║");
        for &n in sizes { print!(" {:>4}² ║", n); }
        println!();
        println!("╠══════════╬═══════╬═══════╬════════╬════════╬════════╣");

        let mut peak_speedup = 0f64;
        let mut peak_sp = 0f64;
        let mut peak_n = 0usize;

        // Print each row while tracking the single best cell.
        for (i, &sp) in sparsities.iter().enumerate() {
            print!("║ {:>5.1}% ║", sp * 100.0);
            for (j, &speedup) in grid[i].iter().enumerate() {
                if speedup > peak_speedup {
                    peak_speedup = speedup;
                    peak_sp = sp;
                    peak_n = sizes[j];
                }
                print!(" {:>5.1}× ║", speedup);
            }
            println!();
        }

        println!("╚══════════╩═══════╩═══════╩════════╩════════╩════════╝");
        println!();
        println!(" ★ Peak: {:.1}× at {:.0}% sparsity, {}×{} matrix", peak_speedup, peak_sp * 100.0, peak_n, peak_n);

        // Average speedup per sparsity level across all sizes.
        let avg_speedups: Vec<(f64, f64)> = sparsities.iter().zip(grid.iter())
            .map(|(&sp, row)| (sp, row.iter().sum::<f64>() / row.len() as f64))
            .collect();
        let (best_sp, best_avg) = avg_speedups.iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .copied().unwrap();
        println!(" ◆ Goldilocks zone: {:.0}% sparsity → {:.1}× average across all sizes", best_sp * 100.0, best_avg);
        println!();

        // The smallest size (column 0) is excluded from the speedup requirement.
        for row in &grid {
            for &s in &row[1..] { assert!(s >= 1.0, "Speedup dropped below 1× — something is wrong");
            }
        }
    }

    // Zone labels, including values just either side of the 1/3 boundary.
    #[test]
    fn test_trit_scalar_zones() {
        assert_eq!(TritScalar::new(0.9).label(), "affirm");
        assert_eq!(TritScalar::new(-0.9).label(), "reject");
        assert_eq!(TritScalar::new(0.0).label(), "tend");
        assert_eq!(TritScalar::new(0.33).label(), "tend"); assert_eq!(TritScalar::new(0.34).label(), "affirm"); }

    // Confidence is 1 at the centre of the tend zone and at +1, and near 0 just past the boundary.
    #[test]
    fn test_trit_scalar_confidence() {
        let s = TritScalar::new(0.0);
        assert_eq!(s.label(), "tend");
        assert!((s.confidence() - 1.0).abs() < 0.01);

        let s = TritScalar::new(1.0);
        assert_eq!(s.label(), "affirm");
        assert!((s.confidence() - 1.0).abs() < 0.01);

        let s = TritScalar::new(TEND_BOUNDARY + 0.001);
        assert_eq!(s.label(), "affirm");
        assert!(s.confidence() < 0.01);
    }

    // Actionable requires a non-tend zone and enough confidence.
    #[test]
    fn test_trit_scalar_actionable() {
        assert!(TritScalar::new(0.9).is_actionable(0.5));
        assert!(!TritScalar::new(0.35).is_actionable(0.8));
        assert!(!TritScalar::new(0.0).is_actionable(0.0));
    }

    // Out-of-range inputs are clamped to [-1, 1].
    #[test]
    fn test_trit_scalar_clamp() {
        assert!((TritScalar::new(5.0).raw() - 1.0).abs() < 0.001);
        assert!((TritScalar::new(-5.0).raw() + 1.0).abs() < 0.001);
    }

    // Equal weights: the aggregate is the plain mean (0.8), well inside affirm.
    #[test]
    fn test_evidence_vec_aggregate_uniform() {
        let ev = TritEvidenceVec::new(
            vec!["a".into(), "b".into(), "c".into()],
            vec![0.8, 0.9, 0.7],
            vec![1.0, 1.0, 1.0],
        );
        let agg = ev.aggregate();
        assert_eq!(agg.label(), "affirm");
        assert!(agg.confidence() > 0.5);
    }

    // Mean of -0.9 and 0.1 is -0.4, which lies in the reject zone.
    #[test]
    fn test_evidence_vec_mixed_signals() {
        let ev = TritEvidenceVec::new(
            vec!["strong_reject".into(), "weak_affirm".into()],
            vec![-0.9, 0.1],
            vec![1.0, 1.0],
        );
        let agg = ev.aggregate();
        assert_eq!(agg.label(), "reject");
    }

    // Weighted mean (-0.4 * 10 + 0.9 * 1) / 11 is about -0.28, inside the tend zone.
    #[test]
    fn test_evidence_vec_weighted_override() {
        let ev = TritEvidenceVec::new(
            vec!["weak_reject".into(), "strong_affirm".into()],
            vec![-0.4, 0.9],
            vec![10.0, 1.0], );
        let agg = ev.aggregate();
        assert_eq!(agg.label(), "tend");
    }

    // The dominant dimension is chosen by absolute value, regardless of sign.
    #[test]
    fn test_evidence_vec_dominant() {
        let ev = TritEvidenceVec::new(
            vec!["low".into(), "high".into(), "mid".into()],
            vec![0.2, -0.95, 0.5],
            vec![1.0, 1.0, 1.0],
        );
        let (label, scalar) = ev.dominant().unwrap();
        assert_eq!(label, "high");
        assert_eq!(scalar.label(), "reject");
    }
}
1184
1185#[derive(Debug, Clone)]
1203pub struct DeliberationRound {
1204 pub round: usize,
1205 pub new_evidence: Vec<f32>, pub cumulative_mean: f32, pub scalar: TritScalar,
1208 pub converged: bool, }
1210
/// Final outcome of a [`DeliberationEngine`] run.
#[derive(Debug, Clone)]
pub struct DeliberationResult {
    /// Final zone as an integer (`1`, `0` or `-1`).
    pub final_trit: i8,
    /// Final zone label as produced by `TritScalar::label`.
    pub final_label: String,
    /// Confidence of the final scalar.
    pub final_confidence: f32,
    /// Whether the last round met the target confidence.
    pub converged: bool,
    /// Number of rounds reported as used.
    pub rounds_used: usize,
    /// Per-round snapshots, in execution order.
    pub trace: Vec<DeliberationRound>,
    /// Human-readable explanation of why the run stopped.
    pub convergence_reason: String,
}
1222
/// Multi-round evidence accumulator that stops once a confidence target is met.
pub struct DeliberationEngine {
    /// Confidence a round's scalar must reach for the run to count as converged.
    pub target_confidence: f32,
    /// Upper bound on the number of rounds executed.
    pub max_rounds: usize,
    /// Smoothing factor of the exponential moving average (weight of the newest round).
    pub alpha: f32,
}
1239
1240impl DeliberationEngine {
1241 pub fn new(target_confidence: f32, max_rounds: usize) -> Self {
1242 Self { target_confidence, max_rounds, alpha: 0.4 }
1243 }
1244
1245 pub fn with_alpha(mut self, alpha: f32) -> Self { self.alpha = alpha.clamp(0.01, 1.0); self }
1246
1247 pub fn run(&self, rounds_evidence: Vec<Vec<f32>>) -> DeliberationResult {
1250 let mut ema: f32 = 0.0; let mut initialized = false;
1252 let mut trace = Vec::new();
1253
1254 let rounds_to_run = self.max_rounds.min(
1255 if rounds_evidence.is_empty() { self.max_rounds } else { rounds_evidence.len() }
1256 );
1257
1258 for round in 0..rounds_to_run {
1259 let new_ev: Vec<f32> = rounds_evidence.get(round).cloned().unwrap_or_default();
1260
1261 if !new_ev.is_empty() {
1263 let round_mean = new_ev.iter().sum::<f32>() / new_ev.len() as f32;
1264 ema = if !initialized {
1265 initialized = true;
1266 round_mean
1267 } else {
1268 self.alpha * round_mean + (1.0 - self.alpha) * ema
1269 };
1270 }
1271
1272 let scalar = TritScalar::new(ema);
1273 let converged = scalar.confidence() >= self.target_confidence;
1274
1275 trace.push(DeliberationRound {
1276 round,
1277 new_evidence: new_ev,
1278 cumulative_mean: ema,
1279 scalar: scalar.clone(),
1280 converged,
1281 });
1282
1283 if converged { break; }
1284 }
1285
1286 let last = trace.last().cloned().unwrap_or_else(|| DeliberationRound {
1287 round: 0, new_evidence: vec![], cumulative_mean: 0.0,
1288 scalar: TritScalar::new(0.0), converged: false,
1289 });
1290
1291 let convergence_reason = if last.converged {
1292 format!("confidence {:.1}% ≥ target {:.1}% after {} round(s)",
1293 last.scalar.confidence() * 100.0,
1294 self.target_confidence * 100.0,
1295 last.round + 1)
1296 } else {
1297 format!("max rounds ({}) reached — confidence {:.1}% below target {:.1}%",
1298 self.max_rounds,
1299 last.scalar.confidence() * 100.0,
1300 self.target_confidence * 100.0)
1301 };
1302
1303 DeliberationResult {
1304 final_trit: last.scalar.trit_i8(),
1305 final_label: last.scalar.label().to_string(),
1306 final_confidence: last.scalar.confidence(),
1307 converged: last.converged,
1308 rounds_used: last.round + 1,
1309 trace,
1310 convergence_reason,
1311 }
1312 }
1313}
1314
/// One voter in a coalition vote.
#[derive(Debug, Clone)]
pub struct CoalitionMember {
    /// Display name of the member.
    pub label: String,
    /// The member's vote: `-1`, `0` or `1`.
    pub trit: i8,
    /// Confidence in the vote, within `[0.0, 1.0]`.
    pub confidence: f32,
    /// Non-negative voting weight.
    pub weight: f32,
}

impl CoalitionMember {
    /// Creates a member, normalizing the inputs into their valid ranges.
    ///
    /// `trit` is reduced to its sign, `confidence` is clamped to
    /// `[0.0, 1.0]`, and a negative `weight` is raised to `0.0`.
    pub fn new(label: impl Into<String>, trit: i8, confidence: f32, weight: f32) -> Self {
        let label: String = label.into();
        // For integers, the sign is identical to clamping into -1..=1.
        let trit = trit.signum();
        let confidence = confidence.clamp(0.0, 1.0);
        let weight = weight.max(0.0);
        Self { label, trit, confidence, weight }
    }
}
1336
/// Outcome of [`coalition_vote`].
#[derive(Debug, Clone)]
pub struct CoalitionResult {
    /// Ternary verdict derived from `aggregate_score`.
    pub trit: i8,
    /// Textual label of the verdict (`"tend"` for an empty coalition).
    pub label: String,
    /// Sum of `trit * confidence * weight` over all members, divided by the total weight.
    pub aggregate_score: f32,
    /// Share of the total weight held by members whose trit is non-zero.
    pub quorum: f32,
    /// Share of the total weight held by non-zero voters whose sign differs from the verdict.
    pub dissent_rate: f32,
    /// `1.0 - quorum`; `1.0` for an empty coalition.
    pub abstain_rate: f32,
    /// Number of members that were passed in.
    pub member_count: usize,
    /// Un-normalised weight of the members whose trit is non-zero.
    pub effective_weight: f32,
    /// Per member: `(label, trit, contribution / total weight)`.
    pub breakdown: Vec<(String, i8, f32)>,
}
1350
1351pub fn coalition_vote(members: &[CoalitionMember]) -> CoalitionResult {
1356 if members.is_empty() {
1357 return CoalitionResult {
1358 trit: 0, label: "tend".into(), aggregate_score: 0.0,
1359 quorum: 0.0, dissent_rate: 0.0, abstain_rate: 1.0,
1360 member_count: 0, effective_weight: 0.0, breakdown: vec![],
1361 };
1362 }
1363
1364 let total_weight: f32 = members.iter().map(|m| m.weight).sum();
1365 let total_weight = if total_weight == 0.0 { 1.0 } else { total_weight };
1366
1367 let mut weighted_sum: f32 = 0.0;
1368 let mut non_zero_weight: f32 = 0.0;
1369 let mut breakdown = Vec::new();
1370
1371 for m in members {
1372 let contribution = (m.trit as f32) * m.confidence * m.weight;
1373 weighted_sum += contribution;
1374 if m.trit != 0 { non_zero_weight += m.weight; }
1375 breakdown.push((m.label.clone(), m.trit, contribution / total_weight));
1376 }
1377
1378 let aggregate_score = weighted_sum / total_weight;
1379 let scalar = TritScalar::new(aggregate_score);
1380 let result_trit: i8 = scalar.trit_i8();
1381
1382 let quorum = non_zero_weight / total_weight;
1383 let abstain_rate = 1.0 - quorum;
1384 let dissent_rate = members.iter()
1385 .filter(|m| m.trit != 0 && m.trit.signum() != result_trit.signum())
1386 .map(|m| m.weight)
1387 .sum::<f32>() / total_weight;
1388
1389 CoalitionResult {
1390 trit: result_trit,
1391 label: scalar.label().to_string(),
1392 aggregate_score,
1393 quorum,
1394 dissent_rate,
1395 abstain_rate,
1396 member_count: members.len(),
1397 effective_weight: non_zero_weight,
1398 breakdown,
1399 }
1400}
1401
/// One weighted criterion evaluated by [`action_gate`].
#[derive(Debug, Clone)]
pub struct GateDimension {
    /// Name of the criterion; reported back when it blocks.
    pub name: String,
    /// Signed evidence for this criterion.
    pub evidence: f32,
    /// Weight of this criterion in the aggregate mean.
    pub weight: f32,
    /// When `true`, a rejecting evidence value on this criterion blocks the action outright.
    pub hard_block: bool,
}

impl GateDimension {
    /// Creates a soft (non-blocking) dimension.
    pub fn new(name: impl Into<String>, evidence: f32, weight: f32) -> Self {
        let name = name.into();
        Self { hard_block: false, name, evidence, weight }
    }

    /// Builder-style switch that turns this dimension into a hard constraint.
    pub fn hard(self) -> Self {
        Self { hard_block: true, ..self }
    }
}
1421
/// Final decision produced by [`action_gate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GateVerdict {
    /// The action may go ahead.
    Proceed,
    /// The evidence is inconclusive.
    Hold,
    /// The action must not go ahead.
    Block,
}

impl GateVerdict {
    /// Lower-case name of the verdict: `"proceed"`, `"hold"` or `"block"`.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Block => "block",
            Self::Hold => "hold",
            Self::Proceed => "proceed",
        }
    }
}
1442
/// Outcome of [`action_gate`].
#[derive(Debug, Clone)]
pub struct GateResult {
    /// The decision.
    pub verdict: GateVerdict,
    /// Scalar built from the weighted mean evidence, or from `-1.0` when a hard constraint fired.
    pub aggregate: TritScalar,
    /// Names of the hard-block dimensions whose evidence classified as reject (empty if none).
    pub hard_blocked_by: Vec<String>,
    /// Per dimension: `(name, scalar built from its evidence, hard_block flag)`.
    pub dim_results: Vec<(String, TritScalar, bool)>,
    /// Human-readable justification of the verdict.
    pub explanation: String,
}
1452
1453pub fn action_gate(dimensions: &[GateDimension]) -> GateResult {
1460 let mut hard_blocked_by = Vec::new();
1461 let mut dim_results = Vec::new();
1462 let mut weighted_sum = 0.0f32;
1463 let mut total_weight = 0.0f32;
1464
1465 for dim in dimensions {
1466 let scalar = TritScalar::new(dim.evidence);
1467 let is_neg = matches!(scalar.trit(), Trit::Reject);
1468
1469 if dim.hard_block && is_neg {
1470 hard_blocked_by.push(dim.name.clone());
1471 }
1472
1473 weighted_sum += dim.evidence * dim.weight;
1474 total_weight += dim.weight;
1475 dim_results.push((dim.name.clone(), scalar, dim.hard_block));
1476 }
1477
1478 if !hard_blocked_by.is_empty() {
1480 let explanation = format!(
1481 "BLOCKED — hard constraint(s) violated: {}",
1482 hard_blocked_by.join(", ")
1483 );
1484 return GateResult {
1485 verdict: GateVerdict::Block,
1486 aggregate: TritScalar::new(-1.0),
1487 hard_blocked_by,
1488 dim_results,
1489 explanation,
1490 };
1491 }
1492
1493 let agg_score = if total_weight > 0.0 { weighted_sum / total_weight } else { 0.0 };
1494 let aggregate = TritScalar::new(agg_score);
1495
1496 let verdict = match aggregate.trit() {
1497 Trit::Affirm => GateVerdict::Proceed,
1498 Trit::Tend => GateVerdict::Hold,
1499 Trit::Reject => GateVerdict::Block,
1500 };
1501
1502 let explanation = match &verdict {
1503 GateVerdict::Proceed => format!(
1504 "PROCEED — all dimensions pass (aggregate confidence {:.0}%)",
1505 aggregate.confidence() * 100.0
1506 ),
1507 GateVerdict::Hold => format!(
1508 "HOLD — insufficient evidence (aggregate {:.3} within deliberation zone)",
1509 aggregate.raw()
1510 ),
1511 GateVerdict::Block => format!(
1512 "BLOCK — weighted aggregate {:.3} below threshold (confidence {:.0}%)",
1513 aggregate.raw(), aggregate.confidence() * 100.0
1514 ),
1515 };
1516
1517 GateResult { verdict, aggregate, hard_blocked_by, dim_results, explanation }
1518}
1519
/// Sampling-temperature recommendation produced by [`scalar_temperature`].
#[derive(Debug, Clone)]
pub struct ScalarTemperature {
    /// Ternary classification of the input scalar, as an `i8`.
    pub trit: i8,
    /// Confidence reported by the input scalar.
    pub confidence: f32,
    /// Recommended temperature, rounded to three decimals.
    pub temperature: f32,
    /// Short explanation of why this temperature was chosen.
    pub reasoning: String,
    /// Instruction text matching the classification.
    pub prompt_hint: String,
}
1543
1544pub fn scalar_temperature(scalar: &TritScalar) -> ScalarTemperature {
1545 let t = scalar.trit();
1546 let c = scalar.confidence(); let (temp, reasoning, prompt_hint) = match t {
1549 Trit::Affirm => {
1550 let temp = 0.3 - (c * 0.25); (
1553 temp.max(0.05),
1554 format!("Affirm (confidence {:.0}%) — execute precisely, minimal exploration", c * 100.0),
1555 "Be concise and direct. Evidence is clear. Do not hedge.".to_string(),
1556 )
1557 }
1558 Trit::Reject => {
1559 let temp = 0.15 - (c * 0.10); (
1562 temp.max(0.05),
1563 format!("Reject (confidence {:.0}%) — decline firmly, minimal hedging", c * 100.0),
1564 "Decline clearly. Do not offer alternatives unless explicitly asked. Evidence is against.".to_string(),
1565 )
1566 }
1567 Trit::Tend => {
1568 let temp = 0.7 + ((1.0 - c) * 0.3); (
1571 temp.min(1.0),
1572 format!("Tend (confidence {:.0}%) — evidence is conflicted, explore broadly", c * 100.0),
1573 "You are in deliberation. Present multiple perspectives. Ask clarifying questions. Do not commit.".to_string(),
1574 )
1575 }
1576 };
1577
1578 ScalarTemperature {
1579 trit: scalar.trit_i8(),
1580 confidence: c,
1581 temperature: (temp * 1000.0).round() / 1000.0,
1582 reasoning,
1583 prompt_hint,
1584 }
1585}
1586
/// Consistency assessment produced by [`hallucination_score`].
#[derive(Debug, Clone)]
pub struct HallucinationScore {
    /// Ternary trust classification, as an `i8` (`0` when no signals were given).
    pub trust_trit: i8,
    /// Textual label of the trust classification.
    pub trust_label: String,
    /// Arithmetic mean of the signals.
    pub mean: f32,
    /// Population variance of the signals (divides by the signal count).
    pub variance: f32,
    /// `1.0 - min(variance, 1.0)`; `0.0` when no signals were given.
    pub consistency: f32,
    /// Number of signals that were analysed.
    pub signal_count: usize,
    /// Human-readable summary of the assessment.
    pub explanation: String,
}
1608
1609pub fn hallucination_score(signals: &[f32]) -> HallucinationScore {
1610 if signals.is_empty() {
1611 return HallucinationScore {
1612 trust_trit: 0, trust_label: "tend".into(), mean: 0.0,
1613 variance: 0.0, consistency: 0.0, signal_count: 0,
1614 explanation: "No signals provided — cannot assess consistency.".into(),
1615 };
1616 }
1617
1618 let n = signals.len() as f32;
1619 let mean = signals.iter().sum::<f32>() / n;
1620 let variance = signals.iter().map(|&s| (s - mean).powi(2)).sum::<f32>() / n;
1621
1622 let norm_variance = variance.min(1.0);
1624 let consistency = 1.0 - norm_variance;
1625
1626 let trust_evidence = (consistency * 2.0 - 1.0) * mean.abs(); let trust = TritScalar::new(trust_evidence);
1631
1632 let explanation = if trust.trit() == Trit::Affirm {
1633 format!(
1634 "Consistent signals (variance {:.3}, consistency {:.0}%) — evidence coheres around {:.3}",
1635 variance, consistency * 100.0, mean
1636 )
1637 } else if trust.trit() == Trit::Reject {
1638 format!(
1639 "HIGH VARIANCE (variance {:.3}) — signals are internally contradictory. Possible hallucination or conflated sources.",
1640 variance
1641 )
1642 } else {
1643 format!(
1644 "Mixed consistency (variance {:.3}, mean {:.3}) — gather more evidence before relying on this claim.",
1645 variance, mean
1646 )
1647 };
1648
1649 HallucinationScore {
1650 trust_trit: trust.trit_i8(),
1651 trust_label: trust.label().to_string(),
1652 mean,
1653 variance,
1654 consistency,
1655 signal_count: signals.len(),
1656 explanation,
1657 }
1658}
1659
// Unit tests for the reasoning helpers: deliberation, coalition voting, the action gate,
// temperature mapping and hallucination scoring.
#[cfg(test)]
mod reasoning_tests {
    use super::*;

    // ── Deliberation ────────────────────────────────────────────────────────

    // Three rounds of strongly positive evidence must reach the target confidence.
    // NOTE(review): `new(0.7, 10)` presumably is (target_confidence, max_rounds) — confirm
    // against the constructor, which is outside this section.
    #[test]
    fn test_deliberation_converges_on_strong_evidence() {
        let engine = DeliberationEngine::new(0.7, 10).with_alpha(0.7);
        let rounds = vec![
            vec![0.85, 0.9],
            vec![0.9, 0.95],
            vec![0.92, 0.95, 0.98],
        ];
        let result = engine.run(rounds);
        assert!(result.converged, "should converge on strong positive evidence (got confidence {:.2})", result.final_confidence);
        assert_eq!(result.final_trit, 1, "should be +1 (affirm)");
        assert!(result.rounds_used <= 3);
    }

    // Small evidence values of mixed sign must exhaust all rounds without converging.
    #[test]
    fn test_deliberation_holds_on_weak_evidence() {
        let engine = DeliberationEngine::new(0.95, 3);
        let rounds = vec![
            vec![0.1f32],
            vec![-0.05],
            vec![0.15],
        ];
        let result = engine.run(rounds);
        assert!(!result.converged, "should not converge on weak conflicting evidence");
        assert_eq!(result.final_trit, 0, "should stay at hold/tend");
        assert_eq!(result.rounds_used, 3);
    }

    // Convergence works symmetrically for strongly negative evidence.
    #[test]
    fn test_deliberation_negative_convergence() {
        let engine = DeliberationEngine::new(0.8, 10);
        let rounds = vec![
            vec![-0.9f32, -0.85],
            vec![-0.95, -0.99],
        ];
        let result = engine.run(rounds);
        assert!(result.converged);
        assert_eq!(result.final_trit, -1);
    }

    // ── Coalition vote ──────────────────────────────────────────────────────

    // Every member votes +1: full quorum, no dissent.
    #[test]
    fn test_coalition_unanimous_affirm() {
        let members = vec![
            CoalitionMember::new("safety", 1, 0.9, 3.0),
            CoalitionMember::new("utility", 1, 0.8, 1.0),
            CoalitionMember::new("alignment", 1, 0.95, 2.0),
        ];
        let result = coalition_vote(&members);
        assert_eq!(result.trit, 1);
        assert_eq!(result.label, "affirm");
        assert!(result.quorum > 0.99, "all voted");
        assert!(result.dissent_rate < 0.01);
    }

    // Equal and opposite votes cancel to an aggregate of 0; with a zero verdict every
    // non-zero voter counts as a dissenter.
    #[test]
    fn test_coalition_split_vote_tends_to_hold() {
        let members = vec![
            CoalitionMember::new("agent_a", 1, 0.8, 1.0),
            CoalitionMember::new("agent_b", -1, 0.8, 1.0),
            CoalitionMember::new("agent_c", 0, 0.5, 1.0),
        ];
        let result = coalition_vote(&members);
        assert_eq!(result.trit, 0);
        assert!(result.dissent_rate > 0.0, "there is dissent");
    }

    // One heavily weighted affirm outweighs two lightly weighted rejects.
    #[test]
    fn test_coalition_high_weight_overrides() {
        let members = vec![
            CoalitionMember::new("expert", 1, 0.95, 10.0),
            CoalitionMember::new("novice_a", -1, 0.5, 1.0),
            CoalitionMember::new("novice_b", -1, 0.5, 1.0),
        ];
        let result = coalition_vote(&members);
        assert_eq!(result.trit, 1, "high-weight expert should dominate");
    }

    // ── Action gate ─────────────────────────────────────────────────────────

    // All dimensions carry clearly positive evidence.
    #[test]
    fn test_gate_all_positive_proceeds() {
        let dims = vec![
            GateDimension::new("safety", 0.8, 3.0),
            GateDimension::new("utility", 0.7, 1.0),
            GateDimension::new("legality", 0.9, 2.0),
        ];
        let result = action_gate(&dims);
        assert_eq!(result.verdict, GateVerdict::Proceed);
    }

    // A rejecting hard dimension blocks even though the other dimensions are positive.
    #[test]
    fn test_gate_hard_block_fires() {
        let dims = vec![
            GateDimension::new("utility", 0.9, 1.0),
            GateDimension::new("safety", -0.8, 3.0).hard(),
            GateDimension::new("legality", 0.7, 1.0),
        ];
        let result = action_gate(&dims);
        assert_eq!(result.verdict, GateVerdict::Block);
        assert!(result.hard_blocked_by.contains(&"safety".to_string()));
    }

    // Soft dimensions only: the weighted mean is (0.8 - 0.7) / 2 = 0.05, so the negative
    // dimension must not block on its own. Only "not Block" is asserted here.
    #[test]
    fn test_gate_mixed_soft_dims_holds() {
        let dims = vec![
            GateDimension::new("utility", 0.8, 1.0),
            GateDimension::new("risk", -0.7, 1.0),
        ];
        let result = action_gate(&dims);
        assert_ne!(result.verdict, GateVerdict::Block);
    }

    // ── Scalar temperature ──────────────────────────────────────────────────

    // Affirm temperatures are at most 0.3 and shrink as confidence grows.
    #[test]
    fn test_temperature_affirm_is_low() {
        let sc = TritScalar::new(0.9);
        let temp = scalar_temperature(&sc);
        assert_eq!(temp.trit, 1);
        assert!(temp.temperature < 0.3, "affirm → low temperature");
    }

    // Tend temperatures are 0.7 or higher.
    #[test]
    fn test_temperature_tend_is_high() {
        let sc = TritScalar::new(0.05);
        let temp = scalar_temperature(&sc);
        assert_eq!(temp.trit, 0);
        assert!(temp.temperature >= 0.7, "tend → high temperature for exploration");
    }

    // Reject temperatures are at most 0.15 and shrink as confidence grows.
    #[test]
    fn test_temperature_reject_is_low() {
        let sc = TritScalar::new(-0.9);
        let temp = scalar_temperature(&sc);
        assert_eq!(temp.trit, -1);
        assert!(temp.temperature < 0.15, "reject → low temperature, firm");
    }

    // ── Hallucination score ─────────────────────────────────────────────────

    // Tightly clustered signals with a large mean are trusted.
    #[test]
    fn test_hallucination_consistent_signals_trusted() {
        let signals = vec![0.8, 0.82, 0.79, 0.81, 0.83];
        let score = hallucination_score(&signals);
        assert_eq!(score.trust_trit, 1, "consistent signals should be trusted");
        assert!(score.variance < 0.01);
        assert!(score.consistency > 0.99);
    }

    // Signals that flip sign have high variance and must not be trusted.
    #[test]
    fn test_hallucination_chaotic_signals_flagged() {
        let signals = vec![0.9, -0.9, 0.8, -0.8, 0.95, -0.7];
        let score = hallucination_score(&signals);
        assert!(score.variance > 0.5, "should have high variance");
        assert!(score.trust_trit <= 0, "chaotic signals should not be trusted");
    }

    // No signals: neutral score.
    #[test]
    fn test_hallucination_empty_returns_hold() {
        let score = hallucination_score(&[]);
        assert_eq!(score.trust_trit, 0);
        assert_eq!(score.signal_count, 0);
    }
}
1842
1843use std::collections::HashMap;
1858use crate::coherence::ModelCoherence;
1859
/// Hyper-parameters describing the shape of a [`TritTransformer`].
pub struct TritTransformerConfig {
    /// Width of the hidden state; also the number of trits stored per token embedding.
    pub dim: usize,
    /// Number of transformer blocks to load.
    pub n_layers: usize,
    /// Number of query heads; `dim / n_heads` is the per-head size used for the rotary table.
    pub n_heads: usize,
    /// Number of key/value heads (presumably for grouped-query attention — confirm against the loader).
    pub n_kv_heads: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Rounding granularity for the feed-forward width (not read by the code in this section).
    pub multiple_of: usize,
    /// Optional scale for the feed-forward width (not read by the code in this section).
    pub ffn_dim_multiplier: Option<f64>,
    /// Epsilon added under the square root in RMS normalisation.
    pub norm_eps: f32,
    /// Number of positions covered by the precomputed rotary table.
    pub max_seq_len: usize,
}

impl Default for TritTransformerConfig {
    /// Defaults: 2048-wide, 16 layers, 32 heads / 8 KV heads, 128256-entry vocabulary,
    /// 2048-position context.
    fn default() -> Self {
        TritTransformerConfig {
            vocab_size: 128_256,
            max_seq_len: 2048,
            dim: 2048,
            n_layers: 16,
            n_heads: 32,
            n_kv_heads: 8,
            multiple_of: 256,
            ffn_dim_multiplier: None,
            norm_eps: 1e-5,
        }
    }
}
1887
/// Weights of a single transformer block.
pub struct TritBlock {
    /// Query projection.
    pub wq: TritMatrix,
    /// Key projection.
    pub wk: TritMatrix,
    /// Value projection.
    pub wv: TritMatrix,
    /// Attention output projection.
    pub wo: TritMatrix,
    /// Feed-forward projection applied to the block input (`feed_forward.w1`).
    pub w1: TritMatrix,
    /// Feed-forward projection applied to the hidden activation (`feed_forward.w2`).
    pub w2: TritMatrix,
    /// Second feed-forward projection applied to the block input (`feed_forward.w3`).
    pub w3: TritMatrix,
    /// RMSNorm gains applied before attention.
    pub attention_norm: Vec<f32>,
    /// RMSNorm gains applied before the feed-forward network.
    pub ffn_norm: Vec<f32>,
}
1900
/// A transformer whose weight matrices are ternary ([`TritMatrix`]).
pub struct TritTransformer {
    /// Shape hyper-parameters.
    pub config: TritTransformerConfig,
    /// Token embedding table, read as `config.dim` consecutive trits per token.
    pub tok_embeddings: TritMatrix,
    /// Transformer blocks, applied in order.
    pub layers: Vec<TritBlock>,
    /// RMSNorm gains applied after the last block.
    pub norm: Vec<f32>,
    /// Final projection producing the output vector.
    pub output: TritMatrix,
    /// Rotary table of `(cos, sin)` pairs, laid out position-major.
    pub freq_cis: Vec<(f32, f32)>,
}
1910
1911impl TritTransformer {
1912 pub fn from_coherence(coherence: ModelCoherence, config: TritTransformerConfig) -> Self {
1914 println!("ternlang-ml: Building TritTransformer (Layers: {})...", config.n_layers);
1915
1916 let mut layers = Vec::with_capacity(config.n_layers);
1917 let mut layer_map: HashMap<String, TritMatrix> = HashMap::new();
1918
1919 for layer in coherence.layers {
1920 layer_map.insert(layer.name.clone(), layer.to_trit_matrix());
1921 }
1922
1923 let mut get = |name: &str| {
1925 layer_map.remove(name).unwrap_or_else(|| panic!("Missing layer: {}", name))
1926 };
1927
1928 let tok_embeddings = get("token_embd.weight");
1929 let output = get("output.weight");
1930
1931 let norm = vec![1.0; config.dim];
1936
1937 for i in 0..config.n_layers {
1938 layers.push(TritBlock {
1939 wq: get(&format!("layers.{}.attention.wq.weight", i)),
1940 wk: get(&format!("layers.{}.attention.wk.weight", i)),
1941 wv: get(&format!("layers.{}.attention.wv.weight", i)),
1942 wo: get(&format!("layers.{}.attention.wo.weight", i)),
1943 w1: get(&format!("layers.{}.feed_forward.w1.weight", i)),
1944 w2: get(&format!("layers.{}.feed_forward.w2.weight", i)),
1945 w3: get(&format!("layers.{}.feed_forward.w3.weight", i)),
1946 attention_norm: vec![1.0; config.dim],
1947 ffn_norm: vec![1.0; config.dim],
1948 });
1949 }
1950
1951 let freq_cis = precompute_freqs_cis(config.dim / config.n_heads, config.max_seq_len);
1953
1954 Self {
1955 config,
1956 tok_embeddings,
1957 layers,
1958 norm,
1959 output,
1960 freq_cis,
1961 }
1962 }
1963
1964 pub fn forward(&self, token: usize, pos: usize) -> Vec<f32> {
1967 let mut h = self.get_embedding(token);
1968
1969 for layer in &self.layers {
1970 let h_norm = rms_norm(&h, &layer.attention_norm, self.config.norm_eps);
1972 let attn_out = self.attention(layer, &h_norm, pos);
1973 for i in 0..h.len() { h[i] += attn_out[i]; }
1974
1975 let h_norm = rms_norm(&h, &layer.ffn_norm, self.config.norm_eps);
1977 let ffn_out = self.feed_forward(layer, &h_norm);
1978 for i in 0..h.len() { h[i] += ffn_out[i]; }
1979 }
1980
1981 let h = rms_norm(&h, &self.norm, self.config.norm_eps);
1982 self.project_output(&h)
1983 }
1984
1985 fn get_embedding(&self, token: usize) -> Vec<f32> {
1986 let start = token * self.config.dim;
1987 let mut embd = Vec::with_capacity(self.config.dim);
1988 for i in 0..self.config.dim {
1989 embd.push(trit_to_f32(self.tok_embeddings.data[start + i]));
1990 }
1991 embd
1992 }
1993
1994 fn attention(&self, layer: &TritBlock, x: &[f32], pos: usize) -> Vec<f32> {
1995 let x_trit = TritMatrix::from_trits(1, x.len(), x.iter().map(|&v| trit_from_f32_approx(v)).collect());
1998
1999 let (q_trit, _) = sparse_matmul(&x_trit, &layer.wq);
2000 let (k_trit, _) = sparse_matmul(&x_trit, &layer.wk);
2001 let (v_trit, _) = sparse_matmul(&x_trit, &layer.wv);
2002
2003 let mut q = q_trit.data.iter().map(|&t| trit_to_f32(t)).collect::<Vec<_>>();
2004 let mut k = k_trit.data.iter().map(|&t| trit_to_f32(t)).collect::<Vec<_>>();
2005 let v = v_trit.data.iter().map(|&t| trit_to_f32(t)).collect::<Vec<_>>();
2006
2007 apply_rope(&mut q, pos, &self.freq_cis, self.config.n_heads);
2009 apply_rope(&mut k, pos, &self.freq_cis, self.config.n_heads);
2010
2011 let v_trit = TritMatrix::from_trits(1, v.len(), v.iter().map(|&val| trit_from_f32_approx(val)).collect());
2016 let (out, _) = sparse_matmul(&v_trit, &layer.wo);
2017 out.data.iter().map(|&t| trit_to_f32(t)).collect()
2018 }
2019
2020 fn feed_forward(&self, layer: &TritBlock, x: &[f32]) -> Vec<f32> {
2021 let x_trit = TritMatrix::from_trits(1, x.len(), x.iter().map(|&v| trit_from_f32_approx(v)).collect());
2022
2023 let (w1_x, _) = sparse_matmul(&x_trit, &layer.w1);
2025 let (w3_x, _) = sparse_matmul(&x_trit, &layer.w3);
2026
2027 let mut hidden = Vec::with_capacity(w1_x.data.len());
2028 for i in 0..w1_x.data.len() {
2029 let v1 = trit_to_f32(w1_x.data[i]);
2030 let v3 = trit_to_f32(w3_x.data[i]);
2031 let silu_v3 = v3 / (1.0 + (-v3).exp());
2033 hidden.push(v1 * silu_v3);
2034 }
2035
2036 let hidden_trit = TritMatrix::from_trits(1, hidden.len(), hidden.iter().map(|&v| trit_from_f32_approx(v)).collect());
2037 let (out, _) = sparse_matmul(&hidden_trit, &layer.w2);
2038 out.data.iter().map(|&t| trit_to_f32(t)).collect()
2039 }
2040
2041 fn project_output(&self, x: &[f32]) -> Vec<f32> {
2042 let x_trit = TritMatrix::from_trits(1, x.len(), x.iter().map(|&v| trit_from_f32_approx(v)).collect());
2043 let (logits, _) = sparse_matmul(&x_trit, &self.output);
2044 logits.data.iter().map(|&t| trit_to_f32(t)).collect()
2045 }
2046}
2047
/// Root-mean-square normalisation with a per-element gain.
///
/// Each output is `x[i] / sqrt(mean(x²) + eps) * weight[i]`. The result has as many elements
/// as the shorter of `x` and `weight`.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|&v| v * v).sum::<f32>() / x.len() as f32;
    let scale = 1.0 / (mean_sq + eps).sqrt();

    let mut out = Vec::with_capacity(x.len().min(weight.len()));
    for (&value, &gain) in x.iter().zip(weight) {
        out.push(value * scale * gain);
    }
    out
}
2055
/// Builds the rotary-embedding table for `end` positions and a per-head size of `dim`.
///
/// The table is position-major: entry `pos * (dim / 2) + i` holds `(cos θ, sin θ)` with
/// `θ = pos / 10000^(2i / dim)`. It is empty when `dim < 2` or `end == 0`.
fn precompute_freqs_cis(dim: usize, end: usize) -> Vec<(f32, f32)> {
    let half = dim / 2;

    // The inverse frequencies depend only on `i`, so they are computed once rather than once
    // per position.
    let inv_freqs: Vec<f32> = (0..half)
        .map(|i| 1.0 / 10000.0f32.powf((i * 2) as f32 / dim as f32))
        .collect();

    let mut freqs_cis = Vec::with_capacity(end * half);
    for pos in 0..end {
        let p = pos as f32;
        freqs_cis.extend(inv_freqs.iter().map(|&freq| {
            let angle = p * freq;
            (angle.cos(), angle.sin())
        }));
    }
    freqs_cis
}
2067
/// Applies rotary position embedding in place.
///
/// `x` is treated as `n_heads` consecutive heads of `x.len() / n_heads` elements. Within each
/// head, element `i` of the first half is rotated together with element `i` of the second
/// half, using the `(cos, sin)` pair at `freq_cis[pos * half + i]`. Elements not covered by a
/// pair (an odd trailing element per head, or a remainder after the last head) are untouched.
///
/// Panics if `n_heads` is `0` or if the table has no entry for a required index.
fn apply_rope(x: &mut [f32], pos: usize, freq_cis: &[(f32, f32)], n_heads: usize) {
    let head_dim = x.len() / n_heads;
    let half = head_dim / 2;
    if half == 0 {
        // Heads of fewer than two elements contain no pair to rotate.
        return;
    }

    let table_base = pos * half;
    for head in x.chunks_exact_mut(head_dim).take(n_heads) {
        let (front, back) = head.split_at_mut(half);
        for (i, (a, b)) in front.iter_mut().zip(back.iter_mut()).enumerate() {
            let (cos, sin) = freq_cis[table_base + i];
            let (re, im) = (*a, *b);
            *a = re * cos - im * sin;
            *b = re * sin + im * cos;
        }
    }
}
2081
2082pub fn trit_to_f32(t: Trit) -> f32 {
2083 match t {
2084 Trit::Affirm => 1.0,
2085 Trit::Reject => -1.0,
2086 Trit::Tend => 0.0,
2087 }
2088}
2089
2090pub fn trit_from_f32_approx(v: f32) -> Trit {
2091 if v > 0.5 { Trit::Affirm }
2092 else if v < -0.5 { Trit::Reject }
2093 else { Trit::Tend }
2094}