Skip to main content

pot_o_mining/
neural_path.rs

1use ai3_lib::tensor::Tensor;
2use pot_o_core::TribeResult;
3use sha2::{Digest, Sha256};
4
/// Validates the neural inference path of a tensor computation.
///
/// Models the tensor operation as a small feedforward network. The "path"
/// is the binary activation pattern at each layer (ReLU > 0 = 1, else 0),
/// with the bit of the `k`-th neuron overall XOR-ed with bit `k % 64` of the
/// miner's nonce. The miner must find a nonce that makes the actual path match
/// an expected path signature (derived from the challenge) within a Hamming
/// distance.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NeuralPathValidator {
    /// Layer widths for the feedforward network, first layer first.
    /// A path contains one bit per neuron, i.e. `layer_widths.iter().sum()` bits.
    pub layer_widths: Vec<usize>,
}
15
16impl Default for NeuralPathValidator {
17    fn default() -> Self {
18        Self {
19            layer_widths: vec![32, 16, 8],
20        }
21    }
22}
23
24impl NeuralPathValidator {
25    /// Derive the expected path signature from the challenge hash.
26    /// Returns a bit vector (as Vec<u8> of 0/1 values) representing the expected activations.
27    pub fn expected_path_signature(&self, challenge_hash: &str) -> Vec<u8> {
28        let hash_bytes = hex::decode(challenge_hash).unwrap_or_default();
29        let total_neurons: usize = self.layer_widths.iter().sum();
30        let mut sig = Vec::with_capacity(total_neurons);
31
32        let mut hasher = Sha256::new();
33        hasher.update(&hash_bytes);
34        let mut seed = hasher.finalize().to_vec();
35
36        for &width in &self.layer_widths {
37            for i in 0..width {
38                let byte_idx = i % seed.len();
39                let bit = (seed[byte_idx] >> (i % 8)) & 1;
40                sig.push(bit);
41            }
42            // Re-hash seed for next layer so each layer has different expected bits
43            let mut h = Sha256::new();
44            h.update(&seed);
45            seed = h.finalize().to_vec();
46        }
47
48        sig
49    }
50
51    /// Compute the actual activation path for a tensor with a given nonce.
52    /// Simulates a feedforward pass: input -> (linear + ReLU) per layer.
53    /// Returns the binary activation pattern.
54    pub fn compute_actual_path(&self, tensor: &Tensor, nonce: u64) -> TribeResult<Vec<u8>> {
55        let mut activations = tensor.data.as_f32();
56        let mut path_bits = Vec::new();
57        let mut bit_idx: u32 = 0;
58
59        for &width in &self.layer_widths {
60            let mut layer_output = vec![0.0f32; width];
61
62            // Simplified linear: each output neuron sums a stride of the input
63            let stride = (activations.len() / width).max(1);
64            for (j, out) in layer_output.iter_mut().enumerate() {
65                if j >= width {
66                    break;
67                }
68                let start = j * stride;
69                let end = (start + stride).min(activations.len());
70                let sum: f32 = activations[start..end].iter().sum();
71                // ReLU
72                let relu = sum.max(0.0);
73                *out = relu;
74
75                let base_bit = if relu > 0.0 { 1u8 } else { 0u8 };
76                let shift = (bit_idx % 64) as u64;
77                let nonce_bit = ((nonce >> shift) & 1) as u8;
78                let bit = base_bit ^ nonce_bit;
79
80                path_bits.push(bit);
81                bit_idx = bit_idx.wrapping_add(1);
82            }
83
84            activations = layer_output;
85        }
86
87        Ok(path_bits)
88    }
89
90    /// Compute Hamming distance between two bit vectors.
91    pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
92        a.iter()
93            .zip(b.iter())
94            .map(|(&x, &y)| if x != y { 1u32 } else { 0u32 })
95            .sum()
96    }
97
98    /// Validate that the actual path is close enough to the expected path.
99    pub fn validate(&self, actual_path: &[u8], challenge_hash: &str, max_distance: u32) -> bool {
100        let expected = self.expected_path_signature(challenge_hash);
101        let min_len = actual_path.len().min(expected.len());
102        let distance = Self::hamming_distance(&actual_path[..min_len], &expected[..min_len]);
103        distance <= max_distance
104    }
105
106    /// Encode path bits as a compact hex string for on-chain storage.
107    pub fn path_to_hex(path: &[u8]) -> String {
108        let mut bytes = Vec::with_capacity(path.len().div_ceil(8));
109        for chunk in path.chunks(8) {
110            let mut byte = 0u8;
111            for (i, &bit) in chunk.iter().enumerate() {
112                if bit != 0 {
113                    byte |= 1 << i;
114                }
115            }
116            bytes.push(byte);
117        }
118        hex::encode(bytes)
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use ai3_lib::tensor::{TensorData, TensorShape};

    const HASH: &str = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";

    #[test]
    fn test_expected_path_deterministic() {
        let v = NeuralPathValidator::default();
        let p1 = v.expected_path_signature(HASH);
        let p2 = v.expected_path_signature(HASH);
        assert_eq!(p1, p2);
        // One entry per neuron: 32 + 16 + 8.
        assert_eq!(p1.len(), 56);
        assert!(p1.iter().all(|&b| b <= 1));
    }

    #[test]
    fn test_hamming_distance() {
        assert_eq!(
            NeuralPathValidator::hamming_distance(&[0, 1, 0], &[0, 1, 0]),
            0
        );
        assert_eq!(
            NeuralPathValidator::hamming_distance(&[0, 1, 0], &[1, 0, 1]),
            3
        );
        assert_eq!(
            NeuralPathValidator::hamming_distance(&[1, 1, 1], &[0, 1, 0]),
            2
        );
        // Only the overlapping prefix is compared.
        assert_eq!(
            NeuralPathValidator::hamming_distance(&[1, 0, 1], &[1, 1]),
            1
        );
    }

    #[test]
    fn test_actual_path_varies_with_nonce() {
        let v = NeuralPathValidator::default();
        let t = Tensor::new(TensorShape::new(vec![64]), TensorData::F32(vec![0.5; 64])).unwrap();
        let p1 = v.compute_actual_path(&t, 0).unwrap();
        let p2 = v.compute_actual_path(&t, 999_999).unwrap();
        assert_eq!(p1.len(), 56);
        assert_eq!(p2.len(), 56);
        // The nonce is XOR-ed into the path bit-by-bit, so for the same tensor the
        // two paths differ exactly where the nonces differ (bit k % 64 for neuron k).
        for (k, (&a, &b)) in p1.iter().zip(p2.iter()).enumerate() {
            assert_eq!(a ^ b, ((999_999u64 >> (k % 64)) & 1) as u8);
        }
        assert_ne!(p1, p2);
    }

    #[test]
    fn test_path_to_hex_packs_lsb_first() {
        let path = vec![1, 0, 1, 1, 0, 0, 1, 0, 1];
        // First byte: bits 0, 2, 3, 6 -> 0x4d; second byte: bit 0 -> 0x01.
        assert_eq!(NeuralPathValidator::path_to_hex(&path), "4d01");
        assert_eq!(NeuralPathValidator::path_to_hex(&[]), "");
    }

    #[test]
    fn test_validate_exact_and_one_bit_off() {
        let v = NeuralPathValidator::default();
        let expected = v.expected_path_signature(HASH);
        assert!(v.validate(&expected, HASH, 0));

        let mut flipped = expected.clone();
        flipped[0] ^= 1;
        assert!(!v.validate(&flipped, HASH, 0));
        assert!(v.validate(&flipped, HASH, 1));
    }
}
164        let path = vec![1, 0, 1, 1, 0, 0, 1, 0, 1];
165        let hex_str = NeuralPathValidator::path_to_hex(&path);
166        assert!(!hex_str.is_empty());
167    }
168}