// proof_engine/graph/neural_viz.rs

1use glam::{Vec2, Vec4};
2use std::collections::HashMap;
3use super::graph_core::{Graph, GraphKind, NodeId, EdgeId};
4
/// Per-neuron payload stored on each node of a [`NeuralNetGraph`].
#[derive(Debug, Clone, PartialEq)]
pub struct NeuronNode {
    /// Zero-based index of the layer this neuron belongs to.
    pub layer: usize,
    /// Zero-based position of this neuron within its layer.
    pub index: usize,
    /// Current output value; written by `set_activations` and `forward`.
    pub activation: f32,
    /// Bias added to the weighted input sum during `forward`.
    pub bias: f32,
}

/// Per-synapse payload stored on each directed edge of a [`NeuralNetGraph`].
#[derive(Debug, Clone, PartialEq)]
pub struct SynapseEdge {
    /// Connection weight; multiplies the source neuron's activation in `forward`.
    pub weight: f32,
    /// Gradient slot for this weight. `build_feedforward` initialises it to 0.0;
    /// presumably filled in by training code elsewhere (TODO confirm).
    pub gradient: f32,
}

19#[derive(Debug, Clone)]
20pub struct NeuralNetGraph {
21    pub graph: Graph<NeuronNode, SynapseEdge>,
22    /// layer_index -> Vec<NodeId> of neurons in that layer
23    pub layers: Vec<Vec<NodeId>>,
24    pub layer_count: usize,
25}
26
27impl NeuralNetGraph {
28    pub fn neuron_count(&self) -> usize {
29        self.graph.node_count()
30    }
31
32    pub fn synapse_count(&self) -> usize {
33        self.graph.edge_count()
34    }
35
36    pub fn get_neuron(&self, id: NodeId) -> Option<&NeuronNode> {
37        self.graph.get_node(id).map(|nd| &nd.data)
38    }
39
40    pub fn get_synapse(&self, id: EdgeId) -> Option<&SynapseEdge> {
41        self.graph.get_edge(id).map(|ed| &ed.data)
42    }
43
44    /// Set activations for a specific layer.
45    pub fn set_activations(&mut self, layer: usize, values: &[f32]) {
46        if layer >= self.layers.len() { return; }
47        let layer_nodes = &self.layers[layer];
48        for (i, &nid) in layer_nodes.iter().enumerate() {
49            if i < values.len() {
50                if let Some(nd) = self.graph.get_node_mut(nid) {
51                    nd.data.activation = values[i];
52                }
53            }
54        }
55    }
56
57    /// Set all weights for edges between two layers.
58    pub fn set_weights(&mut self, from_layer: usize, weights: &[Vec<f32>]) {
59        if from_layer + 1 >= self.layers.len() { return; }
60        let from_nodes = &self.layers[from_layer].clone();
61        let to_nodes = &self.layers[from_layer + 1].clone();
62
63        for (i, &from_nid) in from_nodes.iter().enumerate() {
64            if i >= weights.len() { break; }
65            for (j, &to_nid) in to_nodes.iter().enumerate() {
66                if j >= weights[i].len() { break; }
67                if let Some(eid) = self.graph.find_edge(from_nid, to_nid) {
68                    if let Some(ed) = self.graph.get_edge_mut(eid) {
69                        ed.data.weight = weights[i][j];
70                    }
71                }
72            }
73        }
74    }
75
76    /// Forward pass: compute activations using sigmoid.
77    pub fn forward(&mut self, inputs: &[f32]) {
78        self.set_activations(0, inputs);
79
80        for l in 1..self.layer_count {
81            let prev_layer = self.layers[l - 1].clone();
82            let curr_layer = self.layers[l].clone();
83
84            for &to_nid in &curr_layer {
85                let bias = self.graph.get_node(to_nid).map(|n| n.data.bias).unwrap_or(0.0);
86                let mut sum = bias;
87                for &from_nid in &prev_layer {
88                    if let Some(eid) = self.graph.find_edge(from_nid, to_nid) {
89                        let w = self.graph.get_edge(eid).map(|e| e.data.weight).unwrap_or(0.0);
90                        let a = self.graph.get_node(from_nid).map(|n| n.data.activation).unwrap_or(0.0);
91                        sum += w * a;
92                    }
93                }
94                // Sigmoid activation
95                let activation = 1.0 / (1.0 + (-sum).exp());
96                if let Some(nd) = self.graph.get_node_mut(to_nid) {
97                    nd.data.activation = activation;
98                }
99            }
100        }
101    }
102
103    /// Get output activations (last layer).
104    pub fn outputs(&self) -> Vec<f32> {
105        self.layers.last()
106            .map(|layer| {
107                layer.iter()
108                    .map(|&nid| self.graph.get_node(nid).map(|n| n.data.activation).unwrap_or(0.0))
109                    .collect()
110            })
111            .unwrap_or_default()
112    }
113
114    /// Generate layout: neurons in vertical columns per layer.
115    pub fn compute_layout(&self, bounds: Vec2) -> HashMap<NodeId, Vec2> {
116        let mut positions = HashMap::new();
117        if self.layer_count == 0 { return positions; }
118
119        let layer_spacing = bounds.x / (self.layer_count + 1) as f32;
120
121        for (l, layer_nodes) in self.layers.iter().enumerate() {
122            let n = layer_nodes.len();
123            let neuron_spacing = bounds.y / (n + 1) as f32;
124            let x = layer_spacing * (l + 1) as f32;
125            for (i, &nid) in layer_nodes.iter().enumerate() {
126                let y = neuron_spacing * (i + 1) as f32;
127                positions.insert(nid, Vec2::new(x, y));
128            }
129        }
130        positions
131    }
132
133    /// Generate rendering data: node brightness = activation, edge thickness = |weight|, edge color by sign.
134    pub fn render_data(&self) -> NeuralRenderData {
135        let mut node_data = Vec::new();
136        for layer in &self.layers {
137            for &nid in layer {
138                if let Some(nd) = self.graph.get_node(nid) {
139                    let brightness = nd.data.activation;
140                    node_data.push(NeuronRender {
141                        node_id: nid,
142                        position: nd.position,
143                        brightness,
144                        radius: 8.0 + brightness * 4.0,
145                        color: Vec4::new(brightness, brightness, 1.0, 1.0),
146                    });
147                }
148            }
149        }
150
151        let mut edge_data = Vec::new();
152        for edge in self.graph.edges() {
153            let w = edge.data.weight;
154            let thickness = w.abs().min(5.0);
155            let color = if w >= 0.0 {
156                Vec4::new(0.2, 0.8, 0.2, 0.8) // green = positive
157            } else {
158                Vec4::new(0.8, 0.2, 0.2, 0.8) // red = negative
159            };
160            edge_data.push(SynapseRender {
161                edge_id: edge.id,
162                from: edge.from,
163                to: edge.to,
164                thickness,
165                color,
166                weight: w,
167            });
168        }
169
170        NeuralRenderData { neurons: node_data, synapses: edge_data }
171    }
172}
173
174#[derive(Debug, Clone)]
175pub struct NeuronRender {
176    pub node_id: NodeId,
177    pub position: Vec2,
178    pub brightness: f32,
179    pub radius: f32,
180    pub color: Vec4,
181}
182
183#[derive(Debug, Clone)]
184pub struct SynapseRender {
185    pub edge_id: EdgeId,
186    pub from: NodeId,
187    pub to: NodeId,
188    pub thickness: f32,
189    pub color: Vec4,
190    pub weight: f32,
191}
192
193#[derive(Debug, Clone)]
194pub struct NeuralRenderData {
195    pub neurons: Vec<NeuronRender>,
196    pub synapses: Vec<SynapseRender>,
197}
198
199/// Build a feedforward neural network graph.
200/// `layer_sizes`: number of neurons in each layer, e.g. [3, 4, 2] = 3 inputs, 4 hidden, 2 outputs.
201pub fn build_feedforward(layer_sizes: &[usize]) -> NeuralNetGraph {
202    let mut graph = Graph::new(GraphKind::Directed);
203    let mut layers: Vec<Vec<NodeId>> = Vec::new();
204
205    let bounds = Vec2::new(800.0, 600.0);
206    let layer_count = layer_sizes.len();
207    let layer_spacing = bounds.x / (layer_count + 1) as f32;
208
209    for (l, &size) in layer_sizes.iter().enumerate() {
210        let mut layer_nodes = Vec::new();
211        let neuron_spacing = bounds.y / (size + 1) as f32;
212        let x = layer_spacing * (l + 1) as f32;
213
214        for i in 0..size {
215            let y = neuron_spacing * (i + 1) as f32;
216            let neuron = NeuronNode {
217                layer: l,
218                index: i,
219                activation: 0.0,
220                bias: 0.0,
221            };
222            let nid = graph.add_node_with_pos(neuron, Vec2::new(x, y));
223            layer_nodes.push(nid);
224        }
225        layers.push(layer_nodes);
226    }
227
228    // Connect consecutive layers (fully connected)
229    let mut seed_counter: u64 = 42;
230    for l in 0..(layer_count - 1) {
231        let from_layer = layers[l].clone();
232        let to_layer = layers[l + 1].clone();
233        for &from in &from_layer {
234            for &to in &to_layer {
235                // Initialize with small random weights
236                let w = (pseudo_random(seed_counter, 0) as f32 - 0.5) * 0.5;
237                seed_counter += 1;
238                let synapse = SynapseEdge { weight: w, gradient: 0.0 };
239                graph.add_edge(from, to, synapse);
240            }
241        }
242    }
243
244    NeuralNetGraph {
245        graph,
246        layers,
247        layer_count,
248    }
249}
250
/// Deterministic hash of `(seed, i)` mapped to a float in `[0.0, 1.0]`.
///
/// The two inputs are combined with wrapping multiply/add. The result is then
/// scrambled with two xor-shift-by-33 steps around a wrapping multiply, and
/// finally divided by `u64::MAX`.
fn pseudo_random(seed: u64, i: u64) -> f64 {
    const SEED_MUL: u64 = 6364136223846793005;
    const INDEX_MUL: u64 = 1442695040888963407;
    const SCRAMBLE_MUL: u64 = 0xff51afd7ed558ccd;

    let h = seed
        .wrapping_mul(SEED_MUL)
        .wrapping_add(i.wrapping_mul(INDEX_MUL));
    let h = h ^ (h >> 33);
    let h = h.wrapping_mul(SCRAMBLE_MUL);
    let h = h ^ (h >> 33);
    h as f64 / u64::MAX as f64
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Fails with a descriptive message unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (tolerance {}), got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_build_feedforward_structure() {
        let nn = build_feedforward(&[3, 4, 2]);
        assert_eq!(nn.layer_count, 3);
        let sizes: Vec<usize> = nn.layers.iter().map(Vec::len).collect();
        assert_eq!(sizes, vec![3, 4, 2]);
        assert_eq!(nn.neuron_count(), 3 + 4 + 2);
        // Fully connected between consecutive layers: 3*4 + 4*2 = 20.
        assert_eq!(nn.synapse_count(), 3 * 4 + 4 * 2);
    }

    #[test]
    fn test_set_activations() {
        let mut nn = build_feedforward(&[2, 3, 1]);
        let inputs = [0.5_f32, 0.8];
        nn.set_activations(0, &inputs);
        assert_eq!(nn.layers[0].len(), inputs.len());
        for (&nid, &expected) in nn.layers[0].iter().zip(inputs.iter()) {
            let neuron = nn.get_neuron(nid).expect("input neuron exists");
            assert_close(neuron.activation, expected, 1e-6);
        }
    }

    #[test]
    fn test_forward_pass() {
        let mut nn = build_feedforward(&[2, 2, 1]);
        nn.forward(&[1.0, 1.0]);
        let outputs = nn.outputs();
        assert_eq!(outputs.len(), 1);
        // A sigmoid output can never leave the unit interval.
        assert!((0.0..=1.0).contains(&outputs[0]));
    }

    #[test]
    fn test_layout() {
        let bounds = Vec2::new(800.0, 600.0);
        let layout = build_feedforward(&[3, 4, 2]).compute_layout(bounds);
        assert_eq!(layout.len(), 9);
        assert!(layout
            .values()
            .all(|p| p.x > 0.0 && p.x < bounds.x && p.y > 0.0 && p.y < bounds.y));
    }

    #[test]
    fn test_render_data() {
        let mut nn = build_feedforward(&[2, 3, 1]);
        nn.forward(&[0.5, 0.5]);
        let render = nn.render_data();
        assert_eq!(render.neurons.len(), 2 + 3 + 1);
        assert_eq!(render.synapses.len(), 2 * 3 + 3 * 1);
        assert!(render
            .neurons
            .iter()
            .all(|n| (0.0..=1.0).contains(&n.brightness)));
    }

    #[test]
    fn test_single_layer() {
        let nn = build_feedforward(&[5]);
        assert_eq!(
            (nn.layer_count, nn.neuron_count(), nn.synapse_count()),
            (1, 5, 0)
        );
    }

    #[test]
    fn test_deep_network() {
        let sizes: [usize; 5] = [4, 8, 8, 4, 2];
        let nn = build_feedforward(&sizes);
        assert_eq!(nn.layer_count, sizes.len());
        assert_eq!(nn.neuron_count(), sizes.iter().sum::<usize>());
        let expected_edges: usize = sizes.windows(2).map(|w| w[0] * w[1]).sum();
        assert_eq!(nn.synapse_count(), expected_edges);
    }

    #[test]
    fn test_set_weights() {
        let mut nn = build_feedforward(&[2, 2]);
        // Row i holds the weights leaving input neuron i.
        nn.set_weights(0, &[vec![1.0, -1.0], vec![0.5, 0.5]]);
        nn.forward(&[1.0, 0.0]);
        let outputs = nn.outputs();
        // Input [1, 0] with zero biases means only row 0 contributes:
        //   out0 = sigmoid( 1.0) ~ 0.731, out1 = sigmoid(-1.0) ~ 0.269
        assert_close(outputs[0], 0.731, 0.01);
        assert_close(outputs[1], 0.269, 0.01);
    }

    #[test]
    fn test_neuron_node_fields() {
        let nn = build_feedforward(&[3, 2]);
        let n = nn.get_neuron(nn.layers[0][1]).unwrap();
        assert_eq!((n.layer, n.index), (0, 1));
        assert_eq!(n.activation, 0.0);
    }

    #[test]
    fn test_render_edge_colors() {
        let mut nn = build_feedforward(&[1, 1]);
        let first_color = |net: &NeuralNetGraph| net.render_data().synapses[0].color;

        nn.set_weights(0, &[vec![1.0]]);
        let c = first_color(&nn);
        assert!(c.y > c.x, "positive weight should render green");

        nn.set_weights(0, &[vec![-1.0]]);
        let c = first_color(&nn);
        assert!(c.x > c.y, "negative weight should render red");
    }
}
364        let render = nn.render_data();
365        // Negative weight -> red
366        assert!(render.synapses[0].color.x > render.synapses[0].color.y);
367    }
368}