// aethershell/neural.rs

//! Neural Network Primitives for AetherShell
//!
//! Provides lightweight neural network building blocks for in-shell learning,
//! particularly useful for evolving multi-agent communication protocols and
//! consensus mechanisms.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};

/// Activation functions for neural network layers
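///
/// # Example
///
/// A minimal sketch of applying an activation and its derivative (the
/// `aethershell::neural` import path is assumed):
///
/// ```ignore
/// use aethershell::neural::Activation;
///
/// let relu = Activation::ReLU;
/// assert_eq!(relu.apply(-2.0), 0.0); // negative inputs clamp to zero
/// assert_eq!(relu.apply(3.0), 3.0);
/// assert_eq!(relu.derivative(3.0), 1.0);
///
/// let leaky = Activation::LeakyReLU(0.01);
/// assert_eq!(leaky.apply(-2.0), -0.02); // small negative slope
/// ```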
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum Activation {
    ReLU,
    Sigmoid,
    Tanh,
    Softmax,
    Linear,
    LeakyReLU(f64),
    Swish,
}

impl Activation {
    pub fn apply(&self, x: f64) -> f64 {
        match self {
            Activation::ReLU => x.max(0.0),
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Tanh => x.tanh(),
            Activation::Softmax => x, // Identity here; normalized over the layer in DenseLayer::forward
            Activation::Linear => x,
            Activation::LeakyReLU(alpha) => {
                if x > 0.0 {
                    x
                } else {
                    alpha * x
                }
            }
            Activation::Swish => x * (1.0 / (1.0 + (-x).exp())),
        }
    }

    pub fn derivative(&self, x: f64) -> f64 {
        match self {
            Activation::ReLU => {
                if x > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            Activation::Sigmoid => {
                let s = self.apply(x);
                s * (1.0 - s)
            }
            Activation::Tanh => 1.0 - x.tanh().powi(2),
            Activation::Softmax => 1.0, // Handled specially in backprop
            Activation::Linear => 1.0,
            Activation::LeakyReLU(alpha) => {
                if x > 0.0 {
                    1.0
                } else {
                    *alpha
                }
            }
            Activation::Swish => {
                let sig = 1.0 / (1.0 + (-x).exp());
                sig + x * sig * (1.0 - sig)
            }
        }
    }

    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "relu" => Some(Activation::ReLU),
            "sigmoid" => Some(Activation::Sigmoid),
            "tanh" => Some(Activation::Tanh),
            "softmax" => Some(Activation::Softmax),
            "linear" => Some(Activation::Linear),
            "leaky_relu" => Some(Activation::LeakyReLU(0.01)),
            "swish" => Some(Activation::Swish),
            _ => None,
        }
    }
}

/// A dense (fully connected) layer
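///
/// # Example
///
/// A minimal sketch (import path assumed; weights are seeded for
/// reproducibility):
///
/// ```ignore
/// use aethershell::neural::{Activation, DenseLayer, seed_rng};
///
/// seed_rng(42);
/// let layer = DenseLayer::new(3, 2, Activation::ReLU);
/// let out = layer.forward(&[1.0, 0.5, -0.5]);
/// assert_eq!(out.len(), 2);
/// assert_eq!(layer.param_count(), 3 * 2 + 2); // weights + biases
/// ```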
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DenseLayer {
    pub weights: Vec<Vec<f64>>,
    pub biases: Vec<f64>,
    pub activation: Activation,
    pub input_size: usize,
    pub output_size: usize,
}

impl DenseLayer {
    pub fn new(input_size: usize, output_size: usize, activation: Activation) -> Self {
        // Xavier-style initialization: uniform weights scaled by
        // sqrt(2 / (fan_in + fan_out)); biases start at zero.
        let scale = (2.0 / (input_size + output_size) as f64).sqrt();
        let mut weights = vec![vec![0.0; input_size]; output_size];
        let biases = vec![0.0; output_size];

        for i in 0..output_size {
            for j in 0..input_size {
                weights[i][j] = (rand_f64() * 2.0 - 1.0) * scale;
            }
        }

        Self {
            weights,
            biases,
            activation,
            input_size,
            output_size,
        }
    }

    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        let mut output = vec![0.0; self.output_size];

        for i in 0..self.output_size {
            let mut sum = self.biases[i];
            for j in 0..self.input_size {
                // Missing inputs are treated as 0.0.
                sum += self.weights[i][j] * input.get(j).unwrap_or(&0.0);
            }
            output[i] = self.activation.apply(sum);
        }

        // Softmax pre-activations pass through `apply` unchanged, so the
        // numerically stable (max-shifted) normalization happens here.
        if matches!(self.activation, Activation::Softmax) {
            let max = output.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_sum: f64 = output.iter().map(|x| (x - max).exp()).sum();
            for x in &mut output {
                *x = (*x - max).exp() / exp_sum;
            }
        }

        output
    }

    pub fn param_count(&self) -> usize {
        self.input_size * self.output_size + self.output_size
    }

    /// Get flattened parameters (weights + biases)
    pub fn get_params(&self) -> Vec<f64> {
        let mut params = Vec::with_capacity(self.param_count());
        for row in &self.weights {
            params.extend(row);
        }
        params.extend(&self.biases);
        params
    }

    /// Set parameters from a flattened vector (weights first, then biases)
    pub fn set_params(&mut self, params: &[f64]) {
        debug_assert_eq!(params.len(), self.param_count());
        let mut idx = 0;
        for i in 0..self.output_size {
            for j in 0..self.input_size {
                self.weights[i][j] = params[idx];
                idx += 1;
            }
        }
        for i in 0..self.output_size {
            self.biases[i] = params[idx];
            idx += 1;
        }
    }
}

/// A complete neural network
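///
/// # Example
///
/// A sketch of building and running a feedforward classifier (import path
/// assumed):
///
/// ```ignore
/// use aethershell::neural::{Activation, NeuralNetwork, seed_rng};
///
/// seed_rng(42);
/// // 4 inputs -> 8 hidden (ReLU) -> 2 outputs (Softmax)
/// let net = NeuralNetwork::feedforward(
///     "classifier", &[4, 8, 2], Activation::ReLU, Activation::Softmax,
/// );
/// let probs = net.forward(&[1.0, 0.0, 1.0, 0.0]);
/// assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-9);
/// ```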
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuralNetwork {
    pub layers: Vec<DenseLayer>,
    pub name: String,
}

impl NeuralNetwork {
    pub fn new(name: &str) -> Self {
        Self {
            layers: Vec::new(),
            name: name.to_string(),
        }
    }

    /// Create a network from layer specifications
    /// Format: [(input, output, activation), ...]
    pub fn from_spec(name: &str, spec: &[(usize, usize, Activation)]) -> Self {
        let mut net = Self::new(name);
        for (input, output, activation) in spec {
            net.layers
                .push(DenseLayer::new(*input, *output, *activation));
        }
        net
    }

    /// Create a simple feedforward network from consecutive layer sizes
    pub fn feedforward(
        name: &str,
        layer_sizes: &[usize],
        hidden_activation: Activation,
        output_activation: Activation,
    ) -> Self {
        let mut net = Self::new(name);
        // saturating_sub avoids an underflow panic on an empty size list
        for i in 0..layer_sizes.len().saturating_sub(1) {
            let activation = if i == layer_sizes.len() - 2 {
                output_activation
            } else {
                hidden_activation
            };
            net.layers.push(DenseLayer::new(
                layer_sizes[i],
                layer_sizes[i + 1],
                activation,
            ));
        }
        net
    }

    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        let mut current = input.to_vec();
        for layer in &self.layers {
            current = layer.forward(&current);
        }
        current
    }

    pub fn param_count(&self) -> usize {
        self.layers.iter().map(|l| l.param_count()).sum()
    }

    /// Get all parameters as a flat vector
    pub fn get_params(&self) -> Vec<f64> {
        let mut params = Vec::with_capacity(self.param_count());
        for layer in &self.layers {
            params.extend(layer.get_params());
        }
        params
    }

    /// Set all parameters from a flat vector
    pub fn set_params(&mut self, params: &[f64]) {
        let mut idx = 0;
        for layer in &mut self.layers {
            let count = layer.param_count();
            layer.set_params(&params[idx..idx + count]);
            idx += count;
        }
    }

    /// Clone with mutated parameters
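    ///
    /// A minimal evolutionary step, sketched on the assumption that fitness
    /// evaluation happens elsewhere:
    ///
    /// ```ignore
    /// let parent = NeuralNetwork::feedforward(
    ///     "parent", &[2, 4, 1], Activation::Tanh, Activation::Linear,
    /// );
    /// // Perturb roughly 10% of parameters by up to +/-0.05.
    /// let child = parent.mutate(0.1, 0.05);
    /// assert_eq!(child.param_count(), parent.param_count());
    /// ```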
    pub fn mutate(&self, mutation_rate: f64, mutation_strength: f64) -> Self {
        let mut new_net = self.clone();
        let mut params = new_net.get_params();

        for p in &mut params {
            if rand_f64() < mutation_rate {
                *p += (rand_f64() * 2.0 - 1.0) * mutation_strength;
            }
        }

        new_net.set_params(&params);
        new_net
    }

    /// Crossover with another network (uniform, parameter-wise).
    /// Assumes both networks share the same architecture.
    pub fn crossover(&self, other: &NeuralNetwork) -> Self {
        let mut new_net = self.clone();
        let params1 = self.get_params();
        let params2 = other.get_params();

        let mut new_params = Vec::with_capacity(params1.len());
        for (p1, p2) in params1.iter().zip(params2.iter()) {
            new_params.push(if rand_f64() < 0.5 { *p1 } else { *p2 });
        }

        new_net.set_params(&new_params);
        new_net
    }
}

/// A recurrent neural network cell (simple RNN)
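///
/// # Example
///
/// A sketch of stepping the cell through a short sequence (import path
/// assumed):
///
/// ```ignore
/// use aethershell::neural::{RNNCell, seed_rng};
///
/// seed_rng(42);
/// let mut cell = RNNCell::new(2, 4);
/// // The hidden state carries across steps, so repeating an input
/// // generally yields a different output.
/// let h1 = cell.forward(&[1.0, 0.0]);
/// let h2 = cell.forward(&[1.0, 0.0]);
/// assert_eq!(h1.len(), 4);
/// cell.reset(); // clear state before the next sequence
/// ```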
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RNNCell {
    pub input_weights: Vec<Vec<f64>>,
    pub hidden_weights: Vec<Vec<f64>>,
    pub biases: Vec<f64>,
    pub hidden_size: usize,
    pub input_size: usize,
    pub hidden_state: Vec<f64>,
}

impl RNNCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();

        let mut input_weights = vec![vec![0.0; input_size]; hidden_size];
        let mut hidden_weights = vec![vec![0.0; hidden_size]; hidden_size];
        let biases = vec![0.0; hidden_size];

        for i in 0..hidden_size {
            for j in 0..input_size {
                input_weights[i][j] = (rand_f64() * 2.0 - 1.0) * scale;
            }
            for j in 0..hidden_size {
                hidden_weights[i][j] = (rand_f64() * 2.0 - 1.0) * scale;
            }
        }

        Self {
            input_weights,
            hidden_weights,
            biases,
            hidden_size,
            input_size,
            hidden_state: vec![0.0; hidden_size],
        }
    }

    pub fn forward(&mut self, input: &[f64]) -> Vec<f64> {
        let mut new_hidden = vec![0.0; self.hidden_size];

        for i in 0..self.hidden_size {
            let mut sum = self.biases[i];

            // Input contribution
            for j in 0..self.input_size {
                sum += self.input_weights[i][j] * input.get(j).unwrap_or(&0.0);
            }

            // Hidden state contribution
            for j in 0..self.hidden_size {
                sum += self.hidden_weights[i][j] * self.hidden_state[j];
            }

            new_hidden[i] = sum.tanh();
        }

        self.hidden_state = new_hidden.clone();
        new_hidden
    }

    pub fn reset(&mut self) {
        self.hidden_state = vec![0.0; self.hidden_size];
    }

    pub fn param_count(&self) -> usize {
        self.input_size * self.hidden_size + self.hidden_size * self.hidden_size + self.hidden_size
    }

    pub fn get_params(&self) -> Vec<f64> {
        let mut params = Vec::with_capacity(self.param_count());
        for row in &self.input_weights {
            params.extend(row);
        }
        for row in &self.hidden_weights {
            params.extend(row);
        }
        params.extend(&self.biases);
        params
    }

    pub fn set_params(&mut self, params: &[f64]) {
        let mut idx = 0;
        for i in 0..self.hidden_size {
            for j in 0..self.input_size {
                self.input_weights[i][j] = params[idx];
                idx += 1;
            }
        }
        for i in 0..self.hidden_size {
            for j in 0..self.hidden_size {
                self.hidden_weights[i][j] = params[idx];
                idx += 1;
            }
        }
        for i in 0..self.hidden_size {
            self.biases[i] = params[idx];
            idx += 1;
        }
    }
}

/// Attention mechanism for agent communication
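///
/// Implements single-head scaled dot-product attention:
/// `attend(q, K, V) = softmax(qW_q . (KW_k)^T / sqrt(d)) . VW_v`.
///
/// # Example
///
/// A sketch with paired keys and values (import path assumed):
///
/// ```ignore
/// use aethershell::neural::{AttentionLayer, seed_rng};
///
/// seed_rng(42);
/// let attn = AttentionLayer::new(4);
/// let query = vec![1.0, 0.0, 0.0, 0.0];
/// let keys = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
/// let values = vec![vec![1.0, 1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0, 1.0]];
/// let ctx = attn.attend(&query, &keys, &values);
/// assert_eq!(ctx.len(), 4); // convex mixture of transformed values
/// ```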
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionLayer {
    pub query_weights: Vec<Vec<f64>>,
    pub key_weights: Vec<Vec<f64>>,
    pub value_weights: Vec<Vec<f64>>,
    pub dim: usize,
}

impl AttentionLayer {
    pub fn new(dim: usize) -> Self {
        let scale = (1.0 / dim as f64).sqrt();

        let init_weights = || {
            let mut w = vec![vec![0.0; dim]; dim];
            for i in 0..dim {
                for j in 0..dim {
                    w[i][j] = (rand_f64() * 2.0 - 1.0) * scale;
                }
            }
            w
        };

        Self {
            query_weights: init_weights(),
            key_weights: init_weights(),
            value_weights: init_weights(),
            dim,
        }
    }

    /// Compute attention over a set of messages.
    /// `keys` and `values` are paired per message and should have equal length.
    pub fn attend(&self, query: &[f64], keys: &[Vec<f64>], values: &[Vec<f64>]) -> Vec<f64> {
        if keys.is_empty() {
            return vec![0.0; self.dim];
        }

        // Transform query
        let q = self.transform(query, &self.query_weights);

        // Compute scaled dot-product attention scores
        let mut scores: Vec<f64> = keys
            .iter()
            .map(|k| {
                let k_transformed = self.transform(k, &self.key_weights);
                let dot: f64 = q.iter().zip(k_transformed.iter()).map(|(a, b)| a * b).sum();
                dot / (self.dim as f64).sqrt()
            })
            .collect();

        // Softmax (max-shifted for numerical stability)
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp_sum: f64 = scores.iter().map(|s| (s - max).exp()).sum();
        for s in &mut scores {
            *s = (*s - max).exp() / exp_sum;
        }

        // Weighted sum of values; zip pairs each score with its value and
        // avoids an out-of-bounds panic if the slices differ in length
        let mut output = vec![0.0; self.dim];
        for (score, v) in scores.iter().zip(values.iter()) {
            let v_transformed = self.transform(v, &self.value_weights);
            for (j, val) in v_transformed.iter().enumerate() {
                output[j] += score * val;
            }
        }

        output
    }

    fn transform(&self, input: &[f64], weights: &[Vec<f64>]) -> Vec<f64> {
        let mut output = vec![0.0; self.dim];
        for i in 0..self.dim {
            for j in 0..self.dim.min(input.len()) {
                output[i] += weights[i][j] * input[j];
            }
        }
        output
    }
}

/// Message encoder/decoder for agent communication
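///
/// # Example
///
/// A lossy autoencoder-style round trip, sketched with arbitrary
/// dimensions (import path assumed):
///
/// ```ignore
/// use aethershell::neural::{MessageCodec, seed_rng};
///
/// seed_rng(42);
/// let codec = MessageCodec::new(16, 4); // 16-dim messages, 4-dim latent
/// let message = vec![0.5; 16];
/// let latent = codec.encode(&message);
/// let restored = codec.decode(&latent);
/// assert_eq!(latent.len(), 4);
/// assert_eq!(restored.len(), 16);
/// ```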
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageCodec {
    pub encoder: NeuralNetwork,
    pub decoder: NeuralNetwork,
    pub latent_dim: usize,
}

impl MessageCodec {
    pub fn new(message_dim: usize, latent_dim: usize) -> Self {
        let encoder = NeuralNetwork::feedforward(
            "encoder",
            &[message_dim, (message_dim + latent_dim) / 2, latent_dim],
            Activation::ReLU,
            Activation::Tanh,
        );

        let decoder = NeuralNetwork::feedforward(
            "decoder",
            &[latent_dim, (message_dim + latent_dim) / 2, message_dim],
            Activation::ReLU,
            Activation::Sigmoid,
        );

        Self {
            encoder,
            decoder,
            latent_dim,
        }
    }

    pub fn encode(&self, message: &[f64]) -> Vec<f64> {
        self.encoder.forward(message)
    }

    pub fn decode(&self, latent: &[f64]) -> Vec<f64> {
        self.decoder.forward(latent)
    }
}

/// Consensus network for distributed decision making
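///
/// # Example
///
/// A sketch of one consensus pass over three agents (dimensions chosen for
/// illustration; import path assumed):
///
/// ```ignore
/// use aethershell::neural::{ConsensusNetwork, seed_rng};
///
/// seed_rng(42);
/// // 3 agents, 4-dim local state, 8-dim messages, 3-way decision
/// let net = ConsensusNetwork::new("vote", 3, 4, 8, 3);
/// let inputs = vec![vec![1.0, 0.0, 0.0, 0.0]; 3];
/// let (individual, agreed) = net.consensus(&inputs);
/// assert_eq!(individual.len(), 3);
/// assert_eq!(agreed.len(), 3); // mean of per-agent softmax decisions
/// ```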
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsensusNetwork {
    pub name: String,
    pub local_encoder: NeuralNetwork,
    pub message_processor: NeuralNetwork,
    pub decision_network: NeuralNetwork,
    pub state_dim: usize,
    pub message_dim: usize,
    pub num_rounds: usize,
    /// Per-agent networks reserved for individual specialization. Note that
    /// `consensus` currently encodes every agent with the shared
    /// `local_encoder`, and these networks are excluded from
    /// `param_count`/`get_params`/`set_params`.
    pub agent_networks: Vec<NeuralNetwork>,
}

impl ConsensusNetwork {
    pub fn new(
        name: &str,
        agent_count: usize,
        state_dim: usize,
        message_dim: usize,
        decision_dim: usize,
    ) -> Self {
        let local_encoder = NeuralNetwork::feedforward(
            "local_encoder",
            &[state_dim, state_dim * 2, message_dim],
            Activation::ReLU,
            Activation::Tanh,
        );

        let message_processor = NeuralNetwork::feedforward(
            "message_processor",
            &[message_dim * 2, message_dim * 2, message_dim],
            Activation::ReLU,
            Activation::Tanh,
        );

        let decision_network = NeuralNetwork::feedforward(
            "decision_network",
            &[
                message_dim + state_dim,
                (message_dim + decision_dim) / 2,
                decision_dim,
            ],
            Activation::ReLU,
            Activation::Softmax,
        );

        // Create per-agent networks for individual processing
        let agent_networks: Vec<NeuralNetwork> = (0..agent_count)
            .map(|i| {
                NeuralNetwork::feedforward(
                    &format!("agent_{}", i),
                    &[state_dim, state_dim, message_dim],
                    Activation::ReLU,
                    Activation::Tanh,
                )
            })
            .collect();

        Self {
            name: name.to_string(),
            local_encoder,
            message_processor,
            decision_network,
            state_dim,
            message_dim,
            num_rounds: 3,
            agent_networks,
        }
    }

    /// Run consensus across all agents.
    /// Returns (individual_decisions, consensus_output).
    pub fn consensus(&self, agent_inputs: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<f64>) {
        let num_agents = agent_inputs.len().min(self.agent_networks.len());

        // Encode each agent's state
        let mut messages: Vec<Vec<f64>> = agent_inputs
            .iter()
            .take(num_agents)
            .map(|input| self.local_encoder.forward(input))
            .collect();

        // Run message passing rounds
        for _ in 0..self.num_rounds {
            let mut new_messages = Vec::with_capacity(num_agents);
            for i in 0..num_agents {
                // Aggregate other agents' messages (simple mean)
                let mut aggregated = vec![0.0; self.message_dim];
                let mut count = 0;
                for (j, msg) in messages.iter().enumerate() {
                    if j != i {
                        for (k, v) in msg.iter().enumerate() {
                            if k < aggregated.len() {
                                aggregated[k] += v;
                            }
                        }
                        count += 1;
                    }
                }
                if count > 0 {
                    for v in &mut aggregated {
                        *v /= count as f64;
                    }
                }
                new_messages.push(self.process_messages(&messages[i], &aggregated));
            }
            messages = new_messages;
        }

        // Compute individual decisions
        let decisions: Vec<Vec<f64>> = messages
            .iter()
            .zip(agent_inputs.iter().take(num_agents))
            .map(|(msg, state)| self.decide(msg, state))
            .collect();

        // Compute consensus as the mean of all decisions
        let decision_dim = decisions.first().map(|d| d.len()).unwrap_or(0);
        let mut consensus = vec![0.0; decision_dim];
        for decision in &decisions {
            for (i, v) in decision.iter().enumerate() {
                if i < consensus.len() {
                    consensus[i] += v;
                }
            }
        }
        let num_decisions = decisions.len() as f64;
        if num_decisions > 0.0 {
            for v in &mut consensus {
                *v /= num_decisions;
            }
        }

        (decisions, consensus)
    }

    /// Generate the initial message from local state
    pub fn encode_state(&self, state: &[f64]) -> Vec<f64> {
        self.local_encoder.forward(state)
    }

    /// Process incoming messages and own state to produce a new message
    pub fn process_messages(&self, own_message: &[f64], aggregated_messages: &[f64]) -> Vec<f64> {
        let mut input = own_message.to_vec();
        input.extend(aggregated_messages);
        self.message_processor.forward(&input)
    }

    /// Make the final decision based on converged messages and local state
    pub fn decide(&self, final_message: &[f64], local_state: &[f64]) -> Vec<f64> {
        let mut input = final_message.to_vec();
        input.extend(local_state);
        self.decision_network.forward(&input)
    }

    pub fn param_count(&self) -> usize {
        self.local_encoder.param_count()
            + self.message_processor.param_count()
            + self.decision_network.param_count()
    }

    pub fn get_params(&self) -> Vec<f64> {
        let mut params = self.local_encoder.get_params();
        params.extend(self.message_processor.get_params());
        params.extend(self.decision_network.get_params());
        params
    }

    pub fn set_params(&mut self, params: &[f64]) {
        let mut idx = 0;

        let enc_count = self.local_encoder.param_count();
        self.local_encoder.set_params(&params[idx..idx + enc_count]);
        idx += enc_count;

        let proc_count = self.message_processor.param_count();
        self.message_processor
            .set_params(&params[idx..idx + proc_count]);
        idx += proc_count;

        let dec_count = self.decision_network.param_count();
        self.decision_network
            .set_params(&params[idx..idx + dec_count]);
    }
}

/// Registry for managing neural networks in the shell
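///
/// # Example
///
/// Typical registration and lookup, sketched (import path assumed):
///
/// ```ignore
/// use aethershell::neural::{Activation, NetworkRegistry, NeuralNetwork};
///
/// let mut registry = NetworkRegistry::new();
/// registry.register_network(NeuralNetwork::feedforward(
///     "router", &[8, 8, 4], Activation::ReLU, Activation::Softmax,
/// ));
/// assert!(registry.get_network("router").is_some());
/// ```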
#[derive(Debug, Default)]
pub struct NetworkRegistry {
    pub networks: HashMap<String, NeuralNetwork>,
    pub consensus_networks: HashMap<String, ConsensusNetwork>,
    pub rnn_cells: HashMap<String, RNNCell>,
}

impl NetworkRegistry {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn register_network(&mut self, net: NeuralNetwork) {
        self.networks.insert(net.name.clone(), net);
    }

    pub fn get_network(&self, name: &str) -> Option<&NeuralNetwork> {
        self.networks.get(name)
    }

    pub fn get_network_mut(&mut self, name: &str) -> Option<&mut NeuralNetwork> {
        self.networks.get_mut(name)
    }
}

// Simple deterministic pseudo-random number generator (a 64-bit LCG using
// Knuth's MMIX multiplier). The atomic keeps access safe without `unsafe`;
// concurrent callers may interleave the sequence but cannot corrupt it.
static RNG_STATE: AtomicU64 = AtomicU64::new(12345);

fn rand_f64() -> f64 {
    let next = RNG_STATE
        .load(Ordering::Relaxed)
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1);
    RNG_STATE.store(next, Ordering::Relaxed);
    // Take the top 31 bits and scale into [0, 1).
    (next >> 33) as f64 / (1u64 << 31) as f64
}

pub fn seed_rng(seed: u64) {
    RNG_STATE.store(seed, Ordering::Relaxed);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_layer_forward() {
        seed_rng(42);
        let layer = DenseLayer::new(3, 2, Activation::ReLU);
        let output = layer.forward(&[1.0, 2.0, 3.0]);
        assert_eq!(output.len(), 2);
    }

    #[test]
    fn test_neural_network_forward() {
        seed_rng(42);
        let net =
            NeuralNetwork::feedforward("test", &[4, 8, 2], Activation::ReLU, Activation::Softmax);
        let output = net.forward(&[1.0, 0.0, 1.0, 0.0]);
        assert_eq!(output.len(), 2);
        let sum: f64 = output.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6); // Softmax sums to 1
    }

    #[test]
    fn test_network_mutation() {
        seed_rng(42);
        let net =
            NeuralNetwork::feedforward("test", &[2, 4, 2], Activation::ReLU, Activation::Linear);
        let mutated = net.mutate(0.5, 0.1);
        assert_ne!(net.get_params(), mutated.get_params());
    }

    #[test]
    fn test_rnn_cell() {
        seed_rng(42);
        let mut rnn = RNNCell::new(4, 8);
        let out1 = rnn.forward(&[1.0, 0.0, 0.0, 0.0]);
        let out2 = rnn.forward(&[0.0, 1.0, 0.0, 0.0]);
        assert_eq!(out1.len(), 8);
        assert_ne!(out1, out2); // Inputs differ and the hidden state has evolved
    }

    #[test]
    fn test_consensus_network() {
        seed_rng(42);
        let consensus = ConsensusNetwork::new("test", 3, 4, 8, 3);
        let agent_inputs = vec![
            vec![1.0, 0.5, 0.0, 0.5],
            vec![0.5, 1.0, 0.5, 0.0],
            vec![0.0, 0.5, 1.0, 0.5],
        ];

        // Test individual encoding
        let message = consensus.encode_state(&agent_inputs[0]);
        assert_eq!(message.len(), 8);

        // Test consensus
        let (decisions, consensus_output) = consensus.consensus(&agent_inputs);
        assert_eq!(decisions.len(), 3);
        for decision in &decisions {
            assert_eq!(decision.len(), 3);
        }
        assert_eq!(consensus_output.len(), 3);
    }

    #[test]
    fn test_attention_layer() {
        seed_rng(42);
        let attention = AttentionLayer::new(4);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let keys = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
        let values = vec![vec![1.0, 1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0, 1.0]];
        let output = attention.attend(&query, &keys, &values);
        assert_eq!(output.len(), 4);
    }
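
    // Illustrative determinism check (an added sketch): reseeding the LCG
    // must reproduce the exact same sample sequence.
    #[test]
    fn test_rng_determinism() {
        seed_rng(7);
        let a: Vec<f64> = (0..4).map(|_| rand_f64()).collect();
        seed_rng(7);
        let b: Vec<f64> = (0..4).map(|_| rand_f64()).collect();
        assert_eq!(a, b);
    }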
}