// eryon_actors/engine/neural/engine.rs
/*
    Appellation: neural <module>
    Contrib: @FL03
*/
5use super::{NeuralController, NeuralFeatures};
6
7use crate::engine::{ComputationalEngine, Engine, RawEngine};
8use crate::mem::TopoLedger;
9use cnc::{ReLU, Sigmoid};
10use eryon::{Direction, Head, State, Tail};
11use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, ScalarOperand, s};
12use num_traits::{Float, FromPrimitive, NumAssign};
13
14fn _direction_to_idx(direction: Direction) -> usize {
15    match direction {
16        Direction::Left => 0,
17        Direction::Stay => 1,
18        Direction::Right => 2,
19    }
20}
21
/// A Neural Turing Machine (NTM) is a type of recurrent neural network that can learn to
/// perform algorithmic tasks by interacting with an external memory. The NTM consists of a
/// controller network that interacts with a memory matrix using attention mechanisms.
/// Internally, the controller is a shallow feed-forward neural network capable of processing
/// encoded inputs and producing outputs that determine the next state, symbol, and direction
/// of the machine.
#[derive(Clone, Debug)]
pub struct NeuralEngine<T = f32> {
    /// the machine's working alphabet of symbols
    pub(crate) alphabet: [usize; 3],
    /// the feed-forward controller network driving the machine
    pub(crate) controller: NeuralController<T>,
    /// the external memory matrix addressed via the attention weights
    pub(crate) memory: Array2<T>,
    /// current index of the head on the tape
    pub(crate) position: usize,
    /// current state of the machine
    pub(crate) state: State<usize>,
    /// the tape of symbols the machine reads from and writes to
    pub(crate) tape: Vec<usize>,
    /// read-head attention weights (one per memory row)
    pub(crate) read_weights: Array1<T>,
    /// write-head attention weights (one per memory row)
    pub(crate) write_weights: Array1<T>,
    /// base step size used when adapting the controller's weights
    pub(crate) learning_rate: T,
}
43
44impl<T> NeuralEngine<T> {
45    pub const STATES: usize = 2;
46    pub const SYMBOLS: usize = 3;
47
48    pub fn new(alphabet: [usize; 3], State(initial_state): State<usize>) -> Self
49    where
50        T: Clone + Default + FromPrimitive,
51    {
52        let controller = NeuralController::new();
53        let memory_size = controller.features().dim_memory();
54        Self {
55            alphabet,
56            controller,
57            memory: Array2::default(memory_size),
58            position: 0,
59            state: State(initial_state),
60            tape: Vec::new(),
61            read_weights: Array1::default(memory_size.0),
62            write_weights: Array1::default(memory_size.0),
63            learning_rate: T::from_f64(0.01).unwrap(),
64        }
65    }
66    /// returns an immutable reference to the alphabet
67    pub const fn alphabet(&self) -> &[usize; 3] {
68        &self.alphabet
69    }
70    /// returns an immutable reference to the controller
71    pub const fn controller(&self) -> &NeuralController<T> {
72        &self.controller
73    }
74    /// returns a mutable reference to the controller
75    pub fn controller_mut(&mut self) -> &mut NeuralController<T> {
76        &mut self.controller
77    }
78    /// returns a copy of the engine controller's features
79    pub const fn features(&self) -> NeuralFeatures {
80        self.controller().features()
81    }
82    /// returns a mutable reference to the engine controller's features
83    pub fn features_mut(&mut self) -> &mut NeuralFeatures {
84        self.controller_mut().features_mut()
85    }
86    /// returns an immutable reference to the memory of the machine
87    pub const fn memory(&self) -> &Array2<T> {
88        &self.memory
89    }
90    /// returns a mutable reference to the memory of the machine
91    pub fn memory_mut(&mut self) -> &mut Array2<T> {
92        &mut self.memory
93    }
94    /// returns the current position of the head of the machine
95    pub const fn position(&self) -> usize {
96        self.position
97    }
98    /// returns a mutable reference to the position of the head of the machine
99    pub fn position_mut(&mut self) -> &mut usize {
100        &mut self.position
101    }
102    /// returns a copy of the state of the machine
103    pub const fn state(&self) -> State<usize> {
104        self.state
105    }
106    /// returns a mutable reference to the state of the machine
107    pub fn state_mut(&mut self) -> &mut State<usize> {
108        &mut self.state
109    }
110    /// returns an immutable reference to the tape of the machine
111    pub const fn tape(&self) -> &Vec<usize> {
112        &self.tape
113    }
114    /// returns a mutable reference to the tape of the machine
115    pub fn tape_mut(&mut self) -> &mut Vec<usize> {
116        &mut self.tape
117    }
118    /// set the current position of the machine
119    pub fn set_position(&mut self, position: usize) {
120        self.position = position;
121    }
122    /// set the state of the machine
123    pub fn set_state(&mut self, state: State<usize>) {
124        self.state = state;
125    }
126    /// set the tape of the machine
127    pub fn set_tape<I>(&mut self, iter: I)
128    where
129        I: IntoIterator<Item = usize>,
130    {
131        self.tape = Vec::from_iter(iter);
132    }
133    /// returns the head of the machine
134    pub fn head(&self) -> Head<usize, usize> {
135        Head::new(self.state, self.tape[self.position])
136    }
137    /// inititalize the controller and return a new instance with the randomized parameters
138    #[cfg(feature = "rand")]
139    pub fn init(self) -> Self
140    where
141        T: Float + FromPrimitive + rand_distr::uniform::SampleUniform,
142        rand_distr::StandardNormal: rand_distr::Distribution<T>,
143    {
144        Self {
145            controller: self.controller.init(),
146            ..self
147        }
148    }
149    /// clear's the contents of the tape and resets the position of the head back to 0
150    pub fn reset(&mut self) {
151        self.tape.clear();
152        self.reset_position();
153    }
154    /// reset the position of the head of the machine to `0`
155    pub fn reset_position(&mut self) {
156        self.position = 0;
157    }
158}
159
160impl<T> NeuralEngine<T>
161where
162    T: Float + FromPrimitive + ScalarOperand,
163    NeuralEngine<T>: ComputationalEngine<usize, [usize; 3], Store = Vec<usize>>,
164{
    /// adapt the engine's weights to the target tail (pattern)
    ///
    /// Runs a forward pass for the symbol under the head, compares the controller's
    /// prediction against the one-hot encoding of `tail`, then performs one manual
    /// backpropagation step through the output, hidden, and input layers. The learning
    /// rate is boosted by `1.5x` when the error's L2 norm exceeds `0.5`, and the
    /// attention weights are refreshed when it exceeds `0.3`.
    pub fn adapt_to_target(&mut self, tail: Tail<usize, usize>) -> crate::Result<()>
    where
        T: NumAssign + core::iter::Sum,
    {
        // Read current symbol; out-of-range positions read as the empty symbol 0
        let cur_symbol = if self.position < self.tape.len() {
            self.tape[self.position]
        } else {
            0
        };

        // encode the symbol into input for the controller
        let controller_input = self.encode_input_symbol(cur_symbol);
        // forward the input through the controller to get a prediction
        let current_output = self.forward(&controller_input)?;
        // convert the tail to a target tensor
        let target_output = self.tail_to_targets(tail);
        // compute the error
        let error = &target_output - &current_output;
        // initialize a buffer to store layer activations
        // NOTE(review): this repeats the forward pass layer-by-layer to capture
        // activations; presumably `forward` above produces the same final output — confirm
        let mut activations = Vec::new();
        // forward the input through the first layer
        let mut fwd = self.controller().input().forward(&controller_input)?.relu();
        activations.push(fwd.clone());
        // forward the input through the hidden layer
        fwd = self.controller().hidden().forward(&fwd)?.relu();
        activations.push(fwd.clone());
        // forward the input through the output layer
        fwd = self.controller().output().forward(&fwd)?.sigmoid();
        activations.push(fwd.clone());

        // compute the magnitude of the error (L2 Norm)
        let error_magnitude = error.pow2().sum().sqrt();
        // adapt the learning rate based on the magnitude of the error
        let adaptive_lr = if error_magnitude > T::from(0.5).unwrap() {
            self.learning_rate * T::from(1.5).unwrap() // Boost learning for large errors
        } else {
            self.learning_rate
        };
        // Calculate error gradient for hidden layer - prepare it before updating weights
        let hidden_error = &error * self.controller().output_weights().t().relu_derivative();

        // use blocks to avoid referencing issues
        {
            // start with the output layer: delta = error ⊙ σ'(output activation)
            let delta = &error * activations.last().cloned().unwrap().sigmoid_derivative();
            self.controller_mut()
                .output_mut()
                .backward(&activations[1], &delta, adaptive_lr)?;
        }
        // backpropagate the hidden layer(s)
        {
            let delta = hidden_error.dot(&activations[1].relu_derivative());
            self.controller_mut()
                .hidden_mut()
                .backward(&activations[0], &delta, adaptive_lr)?;
        }
        // then backpropagate the input layer (at half the adaptive learning rate)
        {
            let lr = adaptive_lr * T::from(0.5).unwrap();
            // NOTE(review): this reuses the hidden layer's delta for the input layer
            // instead of propagating it back through the hidden weights — presumably a
            // deliberate approximation; confirm against the controller's training scheme
            let delta = hidden_error.dot(&activations[1].relu_derivative());
            self.controller_mut()
                .input_mut()
                .backward(&controller_input, &delta, lr)?;
        }

        // if necessary, update the attention mechanism
        if error_magnitude > T::from(0.3).unwrap() {
            self.update_attention(controller_input);
        }
        Ok(())
    }
238    /// extract a tail (next_state, next_symbol, direction) from controller output
239    pub fn decode_outputs_into_tail(&self, output: Array1<T>) -> Tail<usize, usize> {
240        let output_len = output.len();
241
242        // Split the output vector into segments for different actions
243
244        // State determination (first part of output)
245        // Use softmax-like approach to select a state
246        let num_states = 2; // Assuming 2 possible states
247        let state_end = core::cmp::min(num_states, output_len);
248        let state_probs = output.slice(s![0..state_end]).to_owned();
249
250        // Find the index with the highest activation for state
251        let next_state = state_probs
252            .iter()
253            .enumerate()
254            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal))
255            .map(|(i, _)| i)
256            .unwrap_or(0);
257
258        // Symbol determination (middle part of output)
259        let symbol_start = state_end;
260        let symbol_end = core::cmp::min(symbol_start + self.alphabet.len(), output_len);
261        let symbol_probs = if symbol_end > symbol_start {
262            output.slice(s![symbol_start..symbol_end]).to_owned()
263        } else {
264            Array1::<T>::zeros(1)
265        };
266
267        // Find the index with highest activation for symbol
268        let symbol_idx = symbol_probs
269            .iter()
270            .enumerate()
271            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal))
272            .map(|(i, _)| i)
273            .unwrap_or(0);
274
275        // Map to actual symbol in alphabet
276        let next_symbol = if symbol_idx < self.alphabet.len() {
277            self.alphabet[symbol_idx]
278        } else {
279            self.alphabet[0] // Default to first symbol
280        };
281
282        // Direction determination (last part of output)
283        // Use 3 values for Left, Stay, Right
284        let dir_start = symbol_end;
285        let dir_end = core::cmp::min(dir_start + 3, output_len);
286        let dir_probs = if dir_end > dir_start {
287            output.slice(s![dir_start..dir_end]).to_owned()
288        } else {
289            Array1::zeros(3)
290        };
291
292        // Find index with highest activation for direction
293        let dir_idx = dir_probs
294            .iter()
295            .enumerate()
296            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal))
297            .map(|(i, _)| i)
298            .unwrap_or(1); // Default to Stay (middle value)
299
300        let direction = match dir_idx {
301            0 => Direction::Left,
302            1 => Direction::Stay,
303            _ => Direction::Right,
304        };
305
306        Tail::new(direction, State(next_state), next_symbol)
307    }
308    /// determine the best direction toDetermine best direction to reach expected symbol based on current state and tape
309    pub fn determine_best_direction(&self, expected: usize) -> Direction {
310        // Check current position in the tape
311        let current_symbol = if self.position < self.tape.len() {
312            self.tape[self.position]
313        } else {
314            0 // Default empty symbol
315        };
316
317        // If we're already at the expected symbol, stay
318        if current_symbol == expected {
319            return Direction::Stay;
320        }
321
322        // Check left and right positions to find the expected symbol
323        let left_has_expected = self.position > 0
324            && self.position - 1 < self.tape.len()
325            && self.tape[self.position - 1] == expected;
326
327        let right_has_expected =
328            self.position + 1 < self.tape.len() && self.tape[self.position + 1] == expected;
329
330        // If expected symbol is to the left, go left
331        if left_has_expected {
332            return Direction::Left;
333        }
334
335        // If expected symbol is to the right, go right
336        if right_has_expected {
337            return Direction::Right;
338        }
339
340        // If we can't find the expected symbol, use a heuristic:
341        // - If we're closer to the left end, try moving right
342        // - If we're closer to the right end, try moving left
343        // - Otherwise, just stay
344        if self.position < self.tape.len() / 2 {
345            Direction::Right
346        } else if self.position > 0 {
347            Direction::Left
348        } else {
349            Direction::Stay
350        }
351    }
    /// Prepares the input for the controller network by combining state and symbol information
    ///
    /// Layout of the produced vector: a one-hot slot for the symbol at the front, the
    /// bit pattern of the current state starting at index 3, and the memory read vector
    /// packed at the tail end (skipped entirely if it does not fit).
    pub fn encode_input_symbol(&self, symbol: usize) -> Array1<T> {
        let mut input = Array1::<T>::zeros(self.controller.inputs());

        // One-hot encode the current symbol (assuming symbol values are within alphabet size)
        // unknown symbols fall back to index 0
        let symbol_idx = self.alphabet.iter().position(|&s| s == symbol).unwrap_or(0);

        // First part of the input is the one-hot encoded symbol
        if symbol_idx < self.controller.inputs() {
            input[symbol_idx] = T::one();
        }

        // Encode the current state as part of the input
        // Use bit representation of state value spread across input vector
        // NOTE(review): `inputs() - 3` underflows for controllers with fewer than 3
        // inputs — assumes the controller always has at least 3 inputs; confirm
        let state_val = *self.state();
        for i in 0..core::cmp::min(32, self.controller.inputs() - 3) {
            if (state_val & (1 << i)) != 0 {
                let idx = 3 + i;
                if idx < self.controller.inputs() {
                    input[idx] = T::one();
                }
            }
        }

        // Add memory read vector to the input, right-aligned at the end of the vector.
        // When the read vector is longer than the input, `wrapping_sub` overflows and
        // the `min` clamp pushes `memory_start` to `dim_inputs`, so every index below
        // fails the bounds check and the memory content is skipped entirely.
        let read_content = self.read_memory();
        let dim_inputs = self.controller.inputs();
        let memory_start = core::cmp::min(dim_inputs.wrapping_sub(read_content.len()), dim_inputs);
        for (i, &val) in read_content.iter().enumerate() {
            let idx = memory_start + i;
            if idx < self.controller.inputs() {
                input[idx] = val;
            }
        }

        input
    }
    /// Execute the NTM using learned rules from topological memory
    ///
    /// Looks the current head up in `memory`; when a stored rule matches, its tail is
    /// applied directly (write symbol, transition state, move head) and the controller's
    /// weights are reinforced toward that rule. When no rule matches, the engine falls
    /// back to a regular neural [`step`](Self::step).
    pub fn execute_with_stored(&mut self, memory: &TopoLedger<T>) -> crate::Result<()>
    where
        T: NumAssign + core::iter::Sum,
    {
        // Get current head (state and symbol)
        let head = self.head();

        // Try to find a matching rule in the rulespace
        // NOTE(review): `find_rule_by_head` is defined on `TopoLedger` elsewhere —
        // presumably an exact (state, symbol) match; confirm
        if let Some(rule) = memory.find_rule_by_head(head.view()) {
            // Found a relevant pattern in memory - use the learned rule
            let Tail {
                state: next_state,
                symbol: next_symbol,
                direction,
            } = rule.tail().copied();

            // Update tape (grow it with empty symbols if the head is past the end)
            if self.position >= self.tape.len() {
                self.tape.resize(self.position + 1, 0);
            }
            self.tape[self.position] = next_symbol;

            // Update state
            self.state = next_state;

            // Move head according to direction
            match direction {
                Direction::Left => {
                    if self.position > 0 {
                        self.position -= 1;
                    } else {
                        // at the left edge: extend the tape instead of moving
                        self.tape.insert(0, 0);
                    }
                }
                Direction::Right => {
                    self.position += 1;
                    if self.position >= self.tape.len() {
                        self.tape.push(0);
                    }
                }
                Direction::Stay => {}
            }

            // Create tail for adapting weights
            let tail = Tail::new(direction, next_state, next_symbol);

            // Update neural weights to reinforce this behavior
            // This uses the learning_rate to gradually adapt the neural model
            self.adapt_to_target(tail)?;

            return Ok(());
        }

        // No rule found in memory, fall back to neural controller
        self.step()
    }
446
447    /// forward input through the controller
448    pub fn forward(&self, input: &Array1<T>) -> cnc::Result<Array1<T>> {
449        self.controller.forward(input)
450    }
    /// train the model to learn the given dataset for a single pass
    ///
    /// Writes each input symbol at the head, steps the machine, and — whenever the cell
    /// under the head does not hold the expected target — adapts the weights toward a
    /// corrective tail. Returns the number of mismatched cells as the error. The engine's
    /// configuration is snapshotted before and restored after training, so only the
    /// learned weights persist.
    ///
    /// # Errors
    ///
    /// Returns `InvalidShape` when `inputs` and `targets` differ in length; propagates
    /// any error from stepping or adapting.
    pub fn learn_sequence<U, V>(
        &mut self,
        inputs: &ArrayBase<U, Ix1>,
        targets: &ArrayBase<V, Ix1>,
    ) -> crate::Result<T>
    where
        U: Data<Elem = usize>,
        V: Data<Elem = usize>,
        T: NumAssign + core::iter::Sum,
        NeuralEngine<T>: ComputationalEngine<usize, [usize; 3], Store = Vec<usize>>,
    {
        if inputs.len() != targets.len() {
            return Err(crate::ActorError::InvalidShape(format!(
                "The input and targets must have the same shape: {:?} != {:?}",
                inputs.len(),
                targets.len()
            )));
        }
        // take a snapshot of the machine so that we can restore it after training
        // NOTE(review): `snapshot`/`restore_from_snapshot` come from the
        // `ComputationalEngine` bound — presumably they capture tape/position/state
        let snapshot = self.snapshot().cloned();

        // Determine required tape length
        let max_pos = snapshot.position() + inputs.len();
        if self.tape.len() < max_pos {
            self.tape.resize(max_pos, 0);
        }

        let mut total_error = T::zero();

        // Process each input in the sequence
        for (&input_symbol, &expected) in inputs.iter().zip(targets.iter()) {
            // Write input to current position
            if self.position >= self.tape.len() {
                self.tape.push(input_symbol);
            } else {
                self.tape[self.position] = input_symbol;
            }

            // Step the machine
            self.step()?;

            // Calculate error (if output differs from expected)
            if self.tape[self.position] != expected {
                // Create target for training
                let target = Tail::new(
                    self.determine_best_direction(expected),
                    self.state,
                    expected,
                );

                // Adapt weights to the target pattern
                self.adapt_to_target(target)?;

                // Track error (one unit per mismatched cell)
                total_error += T::one();
            }
        }

        // Restore original state
        self.restore_from_snapshot(snapshot);

        Ok(total_error)
    }
515    /// train the model to learn the given dataset for a number of epochs
516    pub fn learn_sequence_for<U, V>(
517        &mut self,
518        inputs: &ArrayBase<U, Ix1>,
519        targets: &ArrayBase<V, Ix1>,
520        epochs: usize,
521    ) -> crate::Result<T>
522    where
523        U: Data<Elem = usize>,
524        V: Data<Elem = usize>,
525        T: NumAssign + core::iter::Sum,
526    {
527        if inputs.len() != targets.len() {
528            return Err(crate::ActorError::InvalidShape(
529                "Input and output sequences must have the same length".to_string(),
530            ));
531        }
532
533        let snapshot = self.snapshot().cloned();
534
535        let mut total_error = T::zero();
536
537        // Training loop
538        for _ in 0..epochs {
539            // restore the original configuration before each epoch
540            self.restore_from_snapshot(snapshot.clone());
541
542            total_error += self.learn_sequence(inputs, targets)?;
543        }
544
545        Ok(total_error / T::from_usize(epochs).unwrap())
546    }
547    /// Predict the next n symbols in the sequence
548    pub fn predict_sequence(&mut self, n: usize) -> Vec<usize> {
549        let snapshot = self.snapshot().cloned();
550
551        let mut predictions = Vec::with_capacity(n);
552
553        // Run the machine forward to generate predictions
554        for _ in 0..n {
555            if let Ok(()) = self.step() {
556                predictions.push(self.tape[snapshot.position()]);
557            } else {
558                break;
559            }
560        }
561
562        // Restore original configuration
563        self.restore_from_snapshot(snapshot);
564
565        predictions
566    }
567    /// train the engine on a sequence of inputs with their expected targets
568    pub fn learn_sequences_for<U, V>(
569        &mut self,
570        inputs: &ArrayBase<U, Ix2>,
571        targets: &ArrayBase<V, Ix2>,
572        epochs: usize,
573    ) -> crate::Result<T>
574    where
575        U: Data<Elem = usize>,
576        V: Data<Elem = usize>,
577        T: NumAssign + core::iter::Sum,
578    {
579        if inputs.nrows() != targets.nrows() {
580            return Err(crate::ActorError::InvalidShape(format!(
581                "the number of input and target rows must be the same for batching: {:?} != {:?}",
582                inputs.nrows(),
583                targets.nrows()
584            )));
585        }
586        let batch_size = inputs.nrows();
587        let mut error = T::zero();
588        for _e in 0..epochs {
589            for (x, tgt) in inputs.rows().into_iter().zip(targets.rows()) {
590                error += self.learn_sequence(&x, &tgt)?;
591            }
592        }
593        Ok(error / T::from(epochs * batch_size).unwrap())
594    }
595    /// read from memory using attention mechanism
596    pub fn read_memory(&self) -> Array1<T>
597    where
598        for<'a> ndarray::ArrayView2<'a, T>: ndarray::linalg::Dot<Array1<T>, Output = Array1<T>>,
599    {
600        // Read from memory using attention mechanism
601        self.memory().t().dot(&self.read_weights)
602    }
603    /// Perform a single step of the Turing machine
604    pub fn step(&mut self) -> crate::Result<()> {
605        // Read current symbol
606        let symbol = if self.position < self.tape.len() {
607            self.tape[self.position]
608        } else {
609            0 // Default empty symbol
610        };
611
612        // Create controller input from state and symbol
613        let controller_input = self.encode_input_symbol(symbol);
614
615        // Process input through controller (neural network)
616        let controller_output = self.forward(&controller_input)?;
617
618        // Extract actions from controller output
619        let Tail {
620            state: next_state,
621            symbol: next_symbol,
622            direction,
623        } = self.decode_outputs_into_tail(controller_output);
624
625        // Update tape
626        if self.position >= self.tape.len() {
627            self.tape.resize(self.position + 1, 0);
628        }
629        self.tape[self.position] = next_symbol;
630
631        // Update state
632        self.state = next_state;
633
634        // Move head according to direction
635        match direction {
636            Direction::Left => {
637                if self.position > 0 {
638                    self.position -= 1;
639                } else {
640                    self.tape.insert(0, 0);
641                }
642            }
643            Direction::Right => {
644                self.position += 1;
645                if self.position >= self.tape.len() {
646                    self.tape.push(0);
647                }
648            }
649            Direction::Stay => {}
650        }
651
652        Ok(())
653    }
    /// update the attention mechanism using controller output
    ///
    /// Produces content-based attention: a key is generated from `inputs`, compared to
    /// every memory row by cosine similarity, and the softmax of the similarities
    /// becomes the new read weights. The write weights are blended toward the same
    /// distribution, gated by a sigmoid of one input value.
    pub fn update_attention(&mut self, inputs: Array1<T>) {
        // Generate key vectors for content-based addressing
        let key = self.controller().attention().dot(&inputs);

        // Calculate content-based attention using cosine similarity
        // NOTE(review): `codex()` is used as the number of memory rows here —
        // assumes it matches `dim_memory().0`; confirm
        let d_memory = self.features().codex();
        let mut similarities = Array1::<T>::zeros(d_memory);
        let key_norm = key.dot(&key).sqrt();

        if key_norm > T::zero() {
            // Calculate similarity with each memory row
            for i in 0..d_memory {
                let memory_row = self.memory.row(i).to_owned();
                let mem_norm = memory_row.dot(&memory_row).sqrt();

                // zero-norm rows are left at similarity 0
                if mem_norm > T::zero() {
                    let similarity = memory_row.dot(&key) / (key_norm * mem_norm);
                    similarities[i] = similarity;
                }
            }
        }

        // Apply softmax to get normalized attention weights
        // (subtracting the max for numerical stability)
        let max_sim = similarities.fold(T::neg_infinity(), |m, &x| if x > m { x } else { m });
        let mut exp_similarities = similarities.mapv(|x| (x - max_sim).exp());
        let sum_exp = exp_similarities.sum();

        if sum_exp > T::zero() {
            exp_similarities.mapv_inplace(|x| x / sum_exp);
        } else {
            // If all similarities are extremely negative, use uniform weights
            exp_similarities.fill(T::one() / T::from(d_memory).unwrap());
        }

        // Split attention for read and write weights (first half for reading, second half for writing)
        // NOTE(review): only the single value at `inputs[read_size]` is actually read
        // below, as a scalar write gate — the "second half" wording overstates this
        let attention_size = inputs.len();
        let read_size = attention_size / 2;

        // Update read weights
        self.read_weights = exp_similarities.clone();

        // Get write gate from second half of input - determines how much to update write weights
        if attention_size > read_size {
            let write_intensity = inputs[read_size];
            let write_gate = T::one() / (T::one() + (-write_intensity).exp()); // sigmoid

            // Blend previous and new write weights
            for i in 0..d_memory {
                self.write_weights[i] = (T::one() - write_gate) * self.write_weights[i]
                    + write_gate * exp_similarities[i];
            }
        }
    }
708    /// write to memory using attention mechanism
709    pub fn write_memory(&mut self, erase: &Array1<T>, add: &Array1<T>) {
710        let dim_memory = self.features().dim_memory();
711        // Erase then add pattern with attention mechanism
712        for i in 0..dim_memory.0 {
713            let w = self.write_weights[i];
714            if w > T::zero() {
715                for j in 0..dim_memory.1 {
716                    self.memory[[i, j]] =
717                        self.memory[[i, j]] * (T::one() - w * erase[j]) + w * add[j];
718                }
719            }
720        }
721    }
722}
723
724impl<T> NeuralEngine<T> {
725    pub(crate) fn tail_to_targets(&self, tail: Tail<usize, usize>) -> Array1<T>
726    where
727        T: Clone + num_traits::One + num_traits::Zero,
728    {
729        let n_states = NeuralEngine::<T>::STATES;
730        // deconstruct the tail into its constituents
731        let Tail {
732            direction: tgt_direction,
733            state: tgt_state,
734            symbol: tgt_symbol,
735        } = tail;
736
737        // Create target output (what we want the network to produce)
738        let mut output = Array1::<T>::zeros(self.features().outputs());
739        // Target state (first part of output)
740        if *tgt_state < n_states && *tgt_state < output.len() {
741            output[*tgt_state] = T::one();
742        }
743
744        // Target symbol (middle part)
745        let symbol_start = n_states;
746        let symbol_idx = self
747            .alphabet
748            .iter()
749            .position(|&s| s == tgt_symbol)
750            .unwrap_or(0);
751        if symbol_start + symbol_idx < output.len() {
752            output[symbol_start + symbol_idx] = T::one();
753        }
754
755        // Target direction (last part)
756        let dir_start = symbol_start + self.alphabet.len();
757        let dir_idx = match tgt_direction {
758            Direction::Left => 0,
759            Direction::Stay => 1,
760            Direction::Right => 2,
761        };
762
763        if dir_start + dir_idx < output.len() {
764            output[dir_start + dir_idx] = T::one();
765        }
766        output
767    }
768}
769
/// The default engine starts in state `0`.
impl<T> Default for NeuralEngine<T>
where
    T: Clone + Default + FromPrimitive,
{
    fn default() -> Self {
        // NOTE(review): the alphabet [0, 4, 7] appears to be a project convention
        // for the default symbol set — confirm against callers
        Self::new([0, 4, 7], State(0))
    }
}
778
/// Wires the neural engine into the crate's sealed [`RawEngine`] trait.
impl<T> RawEngine for NeuralEngine<T>
where
    T: Send + Sync + core::fmt::Debug,
{
    /// the engine stores its tape as a plain vector of symbols
    type Store = Vec<usize>;

    // NOTE(review): `seal!()` presumably expands to the private sealing item required
    // by `RawEngine` — macro defined elsewhere in the crate
    seal!();
}
787
/// Marks the neural engine as a full, sealed [`Engine`].
impl<T> Engine for NeuralEngine<T>
where
    T: Clone + Default + Send + Sync + core::fmt::Debug + FromPrimitive,
{
    // NOTE(review): `seal!()` presumably expands to the private sealing item required
    // by `Engine` — macro defined elsewhere in the crate
    seal!();
}
794
/// Thin delegating implementation: every trait method forwards to the engine's own
/// inherent accessors and `step` logic.
impl<T> ComputationalEngine<usize, [usize; 3]> for NeuralEngine<T>
where
    T: Default + Float + FromPrimitive + ScalarOperand + Send + Sync + core::fmt::Debug,
{
    /// delegates to the inherent constructor
    fn new(alphabet: [usize; 3], initial_state: State<usize>) -> Self {
        Self::new(alphabet, initial_state)
    }

    fn alphabet(&self) -> &[usize; 3] {
        &self.alphabet
    }

    fn alphabet_mut(&mut self) -> &mut [usize; 3] {
        &mut self.alphabet
    }

    /// NOTE(review): unlike the inherent `head`, this indexes the tape directly and
    /// panics when the head is past the end of the tape — confirm callers guarantee
    /// the position is in range
    fn head(&self) -> Head<&usize, &usize> {
        Head::new(self.state.view(), &self.tape[self.position])
    }

    /// NOTE(review): same in-range assumption as `head` above
    fn head_mut(&mut self) -> Head<&mut usize, &mut usize> {
        Head::new(self.state.view_mut(), &mut self.tape[self.position])
    }

    fn position(&self) -> usize {
        self.position
    }

    fn position_mut(&mut self) -> &mut usize {
        &mut self.position
    }

    fn state(&self) -> State<&usize> {
        self.state.view()
    }

    fn state_mut(&mut self) -> State<&mut usize> {
        self.state.view_mut()
    }

    fn tape(&self) -> &Vec<usize> {
        &self.tape
    }

    fn tape_mut(&mut self) -> &mut Vec<usize> {
        &mut self.tape
    }

    /// delegates to the inherent neural `step`
    fn step(&mut self) -> crate::Result<()> {
        self.step()
    }

    fn set_alphabet(&mut self, alphabet: [usize; 3]) {
        self.alphabet = alphabet;
    }

    fn set_position(&mut self, position: usize) {
        self.position = position;
    }

    fn set_state(&mut self, state: State<usize>) {
        self.state = state;
    }

    fn set_tape<I>(&mut self, tape: I)
    where
        I: IntoIterator<Item = usize>,
    {
        self.tape = Vec::from_iter(tape);
    }
}
865}