1use super::{NeuralController, NeuralFeatures};
6
7use crate::engine::{ComputationalEngine, Engine, RawEngine};
8use crate::mem::TopoLedger;
9use cnc::{ReLU, Sigmoid};
10use eryon::{Direction, Head, State, Tail};
11use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, ScalarOperand, s};
12use num_traits::{Float, FromPrimitive, NumAssign};
13
14fn _direction_to_idx(direction: Direction) -> usize {
15 match direction {
16 Direction::Left => 0,
17 Direction::Stay => 1,
18 Direction::Right => 2,
19 }
20}
21
#[derive(Clone, Debug)]
pub struct NeuralEngine<T = f32> {
    /// The set of symbols that may appear on the tape.
    pub(crate) alphabet: [usize; 3],
    /// The neural network that decides each transition.
    pub(crate) controller: NeuralController<T>,
    /// External memory matrix; dimensions come from the controller's features.
    pub(crate) memory: Array2<T>,
    /// Current head position on the tape.
    pub(crate) position: usize,
    /// Current machine state.
    pub(crate) state: State<usize>,
    /// The tape of symbols being read and written.
    pub(crate) tape: Vec<usize>,
    /// Attention weights used when reading from `memory` (one per row).
    pub(crate) read_weights: Array1<T>,
    /// Attention weights used when writing to `memory` (one per row).
    pub(crate) write_weights: Array1<T>,
    /// Base step size for the online weight updates.
    pub(crate) learning_rate: T,
}
43
impl<T> NeuralEngine<T> {
    /// Number of machine states encoded in the controller's output vector.
    pub const STATES: usize = 2;
    /// Number of tape symbols the engine recognizes.
    pub const SYMBOLS: usize = 3;

    /// Create a fresh engine with the given alphabet and initial state.
    ///
    /// Memory and attention weights start zeroed, the tape starts empty, and
    /// the learning rate defaults to `0.01`.
    pub fn new(alphabet: [usize; 3], State(initial_state): State<usize>) -> Self
    where
        T: Clone + Default + FromPrimitive,
    {
        let controller = NeuralController::new();
        // (rows, cols) of the external memory, as dictated by the controller.
        let memory_size = controller.features().dim_memory();
        Self {
            alphabet,
            controller,
            memory: Array2::default(memory_size),
            position: 0,
            state: State(initial_state),
            tape: Vec::new(),
            // One attention weight per memory row.
            read_weights: Array1::default(memory_size.0),
            write_weights: Array1::default(memory_size.0),
            learning_rate: T::from_f64(0.01).unwrap(),
        }
    }
    /// The engine's symbol alphabet.
    pub const fn alphabet(&self) -> &[usize; 3] {
        &self.alphabet
    }
    /// Immutable access to the neural controller.
    pub const fn controller(&self) -> &NeuralController<T> {
        &self.controller
    }
    /// Mutable access to the neural controller.
    pub fn controller_mut(&mut self) -> &mut NeuralController<T> {
        &mut self.controller
    }
    /// The controller's feature configuration (layer/memory dimensions).
    pub const fn features(&self) -> NeuralFeatures {
        self.controller().features()
    }
    /// Mutable access to the controller's feature configuration.
    pub fn features_mut(&mut self) -> &mut NeuralFeatures {
        self.controller_mut().features_mut()
    }
    /// The external memory matrix.
    pub const fn memory(&self) -> &Array2<T> {
        &self.memory
    }
    /// Mutable access to the external memory matrix.
    pub fn memory_mut(&mut self) -> &mut Array2<T> {
        &mut self.memory
    }
    /// The head's current tape position.
    pub const fn position(&self) -> usize {
        self.position
    }
    /// Mutable access to the head position.
    pub fn position_mut(&mut self) -> &mut usize {
        &mut self.position
    }
    /// The machine's current state.
    pub const fn state(&self) -> State<usize> {
        self.state
    }
    /// Mutable access to the current state.
    pub fn state_mut(&mut self) -> &mut State<usize> {
        &mut self.state
    }
    /// The tape contents.
    pub const fn tape(&self) -> &Vec<usize> {
        &self.tape
    }
    /// Mutable access to the tape contents.
    pub fn tape_mut(&mut self) -> &mut Vec<usize> {
        &mut self.tape
    }
    /// Move the head to `position`; no bounds check against the tape is made.
    pub fn set_position(&mut self, position: usize) {
        self.position = position;
    }
    /// Overwrite the current state.
    pub fn set_state(&mut self, state: State<usize>) {
        self.state = state;
    }
    /// Replace the tape with the symbols yielded by `iter`.
    pub fn set_tape<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = usize>,
    {
        self.tape = Vec::from_iter(iter);
    }
    /// The current head: state plus the symbol under the head.
    ///
    /// NOTE(review): indexes the tape directly — this panics when `position`
    /// is past the end of the tape; confirm callers guarantee otherwise.
    pub fn head(&self) -> Head<usize, usize> {
        Head::new(self.state, self.tape[self.position])
    }
    /// Randomly initialize the controller's weights (builder style).
    #[cfg(feature = "rand")]
    pub fn init(self) -> Self
    where
        T: Float + FromPrimitive + rand_distr::uniform::SampleUniform,
        rand_distr::StandardNormal: rand_distr::Distribution<T>,
    {
        Self {
            controller: self.controller.init(),
            ..self
        }
    }
    /// Clear the tape and return the head to position 0.
    pub fn reset(&mut self) {
        self.tape.clear();
        self.reset_position();
    }
    /// Return the head to position 0 without touching the tape.
    pub fn reset_position(&mut self) {
        self.position = 0;
    }
}
159
160impl<T> NeuralEngine<T>
161where
162 T: Float + FromPrimitive + ScalarOperand,
163 NeuralEngine<T>: ComputationalEngine<usize, [usize; 3], Store = Vec<usize>>,
164{
165 pub fn adapt_to_target(&mut self, tail: Tail<usize, usize>) -> crate::Result<()>
167 where
168 T: NumAssign + core::iter::Sum,
169 {
170 let cur_symbol = if self.position < self.tape.len() {
172 self.tape[self.position]
173 } else {
174 0
175 };
176
177 let controller_input = self.encode_input_symbol(cur_symbol);
179 let current_output = self.forward(&controller_input)?;
181 let target_output = self.tail_to_targets(tail);
183 let error = &target_output - ¤t_output;
185 let mut activations = Vec::new();
187 let mut fwd = self.controller().input().forward(&controller_input)?.relu();
189 activations.push(fwd.clone());
190 fwd = self.controller().hidden().forward(&fwd)?.relu();
192 activations.push(fwd.clone());
193 fwd = self.controller().output().forward(&fwd)?.sigmoid();
195 activations.push(fwd.clone());
196
197 let error_magnitude = error.pow2().sum().sqrt();
199 let adaptive_lr = if error_magnitude > T::from(0.5).unwrap() {
201 self.learning_rate * T::from(1.5).unwrap() } else {
203 self.learning_rate
204 };
205 let hidden_error = &error * self.controller().output_weights().t().relu_derivative();
207
208 {
210 let delta = &error * activations.last().cloned().unwrap().sigmoid_derivative();
212 self.controller_mut()
213 .output_mut()
214 .backward(&activations[1], &delta, adaptive_lr)?;
215 }
216 {
218 let delta = hidden_error.dot(&activations[1].relu_derivative());
219 self.controller_mut()
220 .hidden_mut()
221 .backward(&activations[0], &delta, adaptive_lr)?;
222 }
223 {
225 let lr = adaptive_lr * T::from(0.5).unwrap();
226 let delta = hidden_error.dot(&activations[1].relu_derivative());
227 self.controller_mut()
228 .input_mut()
229 .backward(&controller_input, &delta, lr)?;
230 }
231
232 if error_magnitude > T::from(0.3).unwrap() {
234 self.update_attention(controller_input);
235 }
236 Ok(())
237 }
238 pub fn decode_outputs_into_tail(&self, output: Array1<T>) -> Tail<usize, usize> {
240 let output_len = output.len();
241
242 let num_states = 2; let state_end = core::cmp::min(num_states, output_len);
248 let state_probs = output.slice(s![0..state_end]).to_owned();
249
250 let next_state = state_probs
252 .iter()
253 .enumerate()
254 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal))
255 .map(|(i, _)| i)
256 .unwrap_or(0);
257
258 let symbol_start = state_end;
260 let symbol_end = core::cmp::min(symbol_start + self.alphabet.len(), output_len);
261 let symbol_probs = if symbol_end > symbol_start {
262 output.slice(s![symbol_start..symbol_end]).to_owned()
263 } else {
264 Array1::<T>::zeros(1)
265 };
266
267 let symbol_idx = symbol_probs
269 .iter()
270 .enumerate()
271 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal))
272 .map(|(i, _)| i)
273 .unwrap_or(0);
274
275 let next_symbol = if symbol_idx < self.alphabet.len() {
277 self.alphabet[symbol_idx]
278 } else {
279 self.alphabet[0] };
281
282 let dir_start = symbol_end;
285 let dir_end = core::cmp::min(dir_start + 3, output_len);
286 let dir_probs = if dir_end > dir_start {
287 output.slice(s![dir_start..dir_end]).to_owned()
288 } else {
289 Array1::zeros(3)
290 };
291
292 let dir_idx = dir_probs
294 .iter()
295 .enumerate()
296 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal))
297 .map(|(i, _)| i)
298 .unwrap_or(1); let direction = match dir_idx {
301 0 => Direction::Left,
302 1 => Direction::Stay,
303 _ => Direction::Right,
304 };
305
306 Tail::new(direction, State(next_state), next_symbol)
307 }
308 pub fn determine_best_direction(&self, expected: usize) -> Direction {
310 let current_symbol = if self.position < self.tape.len() {
312 self.tape[self.position]
313 } else {
314 0 };
316
317 if current_symbol == expected {
319 return Direction::Stay;
320 }
321
322 let left_has_expected = self.position > 0
324 && self.position - 1 < self.tape.len()
325 && self.tape[self.position - 1] == expected;
326
327 let right_has_expected =
328 self.position + 1 < self.tape.len() && self.tape[self.position + 1] == expected;
329
330 if left_has_expected {
332 return Direction::Left;
333 }
334
335 if right_has_expected {
337 return Direction::Right;
338 }
339
340 if self.position < self.tape.len() / 2 {
345 Direction::Right
346 } else if self.position > 0 {
347 Direction::Left
348 } else {
349 Direction::Stay
350 }
351 }
352 pub fn encode_input_symbol(&self, symbol: usize) -> Array1<T> {
354 let mut input = Array1::<T>::zeros(self.controller.inputs());
355
356 let symbol_idx = self.alphabet.iter().position(|&s| s == symbol).unwrap_or(0);
358
359 if symbol_idx < self.controller.inputs() {
361 input[symbol_idx] = T::one();
362 }
363
364 let state_val = *self.state();
367 for i in 0..core::cmp::min(32, self.controller.inputs() - 3) {
368 if (state_val & (1 << i)) != 0 {
369 let idx = 3 + i;
370 if idx < self.controller.inputs() {
371 input[idx] = T::one();
372 }
373 }
374 }
375
376 let read_content = self.read_memory();
378 let dim_inputs = self.controller.inputs();
379 let memory_start = core::cmp::min(dim_inputs.wrapping_sub(read_content.len()), dim_inputs);
380 for (i, &val) in read_content.iter().enumerate() {
381 let idx = memory_start + i;
382 if idx < self.controller.inputs() {
383 input[idx] = val;
384 }
385 }
386
387 input
388 }
389 pub fn execute_with_stored(&mut self, memory: &TopoLedger<T>) -> crate::Result<()>
391 where
392 T: NumAssign + core::iter::Sum,
393 {
394 let head = self.head();
396
397 if let Some(rule) = memory.find_rule_by_head(head.view()) {
399 let Tail {
401 state: next_state,
402 symbol: next_symbol,
403 direction,
404 } = rule.tail().copied();
405
406 if self.position >= self.tape.len() {
408 self.tape.resize(self.position + 1, 0);
409 }
410 self.tape[self.position] = next_symbol;
411
412 self.state = next_state;
414
415 match direction {
417 Direction::Left => {
418 if self.position > 0 {
419 self.position -= 1;
420 } else {
421 self.tape.insert(0, 0);
422 }
423 }
424 Direction::Right => {
425 self.position += 1;
426 if self.position >= self.tape.len() {
427 self.tape.push(0);
428 }
429 }
430 Direction::Stay => {}
431 }
432
433 let tail = Tail::new(direction, next_state, next_symbol);
435
436 self.adapt_to_target(tail)?;
439
440 return Ok(());
441 }
442
443 self.step()
445 }
446
447 pub fn forward(&self, input: &Array1<T>) -> cnc::Result<Array1<T>> {
449 self.controller.forward(input)
450 }
451 pub fn learn_sequence<U, V>(
453 &mut self,
454 inputs: &ArrayBase<U, Ix1>,
455 targets: &ArrayBase<V, Ix1>,
456 ) -> crate::Result<T>
457 where
458 U: Data<Elem = usize>,
459 V: Data<Elem = usize>,
460 T: NumAssign + core::iter::Sum,
461 NeuralEngine<T>: ComputationalEngine<usize, [usize; 3], Store = Vec<usize>>,
462 {
463 if inputs.len() != targets.len() {
464 return Err(crate::ActorError::InvalidShape(format!(
465 "The input and targets must have the same shape: {:?} != {:?}",
466 inputs.len(),
467 targets.len()
468 )));
469 }
470 let snapshot = self.snapshot().cloned();
472
473 let max_pos = snapshot.position() + inputs.len();
475 if self.tape.len() < max_pos {
476 self.tape.resize(max_pos, 0);
477 }
478
479 let mut total_error = T::zero();
480
481 for (&input_symbol, &expected) in inputs.iter().zip(targets.iter()) {
483 if self.position >= self.tape.len() {
485 self.tape.push(input_symbol);
486 } else {
487 self.tape[self.position] = input_symbol;
488 }
489
490 self.step()?;
492
493 if self.tape[self.position] != expected {
495 let target = Tail::new(
497 self.determine_best_direction(expected),
498 self.state,
499 expected,
500 );
501
502 self.adapt_to_target(target)?;
504
505 total_error += T::one();
507 }
508 }
509
510 self.restore_from_snapshot(snapshot);
512
513 Ok(total_error)
514 }
515 pub fn learn_sequence_for<U, V>(
517 &mut self,
518 inputs: &ArrayBase<U, Ix1>,
519 targets: &ArrayBase<V, Ix1>,
520 epochs: usize,
521 ) -> crate::Result<T>
522 where
523 U: Data<Elem = usize>,
524 V: Data<Elem = usize>,
525 T: NumAssign + core::iter::Sum,
526 {
527 if inputs.len() != targets.len() {
528 return Err(crate::ActorError::InvalidShape(
529 "Input and output sequences must have the same length".to_string(),
530 ));
531 }
532
533 let snapshot = self.snapshot().cloned();
534
535 let mut total_error = T::zero();
536
537 for _ in 0..epochs {
539 self.restore_from_snapshot(snapshot.clone());
541
542 total_error += self.learn_sequence(inputs, targets)?;
543 }
544
545 Ok(total_error / T::from_usize(epochs).unwrap())
546 }
547 pub fn predict_sequence(&mut self, n: usize) -> Vec<usize> {
549 let snapshot = self.snapshot().cloned();
550
551 let mut predictions = Vec::with_capacity(n);
552
553 for _ in 0..n {
555 if let Ok(()) = self.step() {
556 predictions.push(self.tape[snapshot.position()]);
557 } else {
558 break;
559 }
560 }
561
562 self.restore_from_snapshot(snapshot);
564
565 predictions
566 }
567 pub fn learn_sequences_for<U, V>(
569 &mut self,
570 inputs: &ArrayBase<U, Ix2>,
571 targets: &ArrayBase<V, Ix2>,
572 epochs: usize,
573 ) -> crate::Result<T>
574 where
575 U: Data<Elem = usize>,
576 V: Data<Elem = usize>,
577 T: NumAssign + core::iter::Sum,
578 {
579 if inputs.nrows() != targets.nrows() {
580 return Err(crate::ActorError::InvalidShape(format!(
581 "the number of input and target rows must be the same for batching: {:?} != {:?}",
582 inputs.nrows(),
583 targets.nrows()
584 )));
585 }
586 let batch_size = inputs.nrows();
587 let mut error = T::zero();
588 for _e in 0..epochs {
589 for (x, tgt) in inputs.rows().into_iter().zip(targets.rows()) {
590 error += self.learn_sequence(&x, &tgt)?;
591 }
592 }
593 Ok(error / T::from(epochs * batch_size).unwrap())
594 }
595 pub fn read_memory(&self) -> Array1<T>
597 where
598 for<'a> ndarray::ArrayView2<'a, T>: ndarray::linalg::Dot<Array1<T>, Output = Array1<T>>,
599 {
600 self.memory().t().dot(&self.read_weights)
602 }
603 pub fn step(&mut self) -> crate::Result<()> {
605 let symbol = if self.position < self.tape.len() {
607 self.tape[self.position]
608 } else {
609 0 };
611
612 let controller_input = self.encode_input_symbol(symbol);
614
615 let controller_output = self.forward(&controller_input)?;
617
618 let Tail {
620 state: next_state,
621 symbol: next_symbol,
622 direction,
623 } = self.decode_outputs_into_tail(controller_output);
624
625 if self.position >= self.tape.len() {
627 self.tape.resize(self.position + 1, 0);
628 }
629 self.tape[self.position] = next_symbol;
630
631 self.state = next_state;
633
634 match direction {
636 Direction::Left => {
637 if self.position > 0 {
638 self.position -= 1;
639 } else {
640 self.tape.insert(0, 0);
641 }
642 }
643 Direction::Right => {
644 self.position += 1;
645 if self.position >= self.tape.len() {
646 self.tape.push(0);
647 }
648 }
649 Direction::Stay => {}
650 }
651
652 Ok(())
653 }
654 pub fn update_attention(&mut self, inputs: Array1<T>) {
656 let key = self.controller().attention().dot(&inputs);
658
659 let d_memory = self.features().codex();
661 let mut similarities = Array1::<T>::zeros(d_memory);
662 let key_norm = key.dot(&key).sqrt();
663
664 if key_norm > T::zero() {
665 for i in 0..d_memory {
667 let memory_row = self.memory.row(i).to_owned();
668 let mem_norm = memory_row.dot(&memory_row).sqrt();
669
670 if mem_norm > T::zero() {
671 let similarity = memory_row.dot(&key) / (key_norm * mem_norm);
672 similarities[i] = similarity;
673 }
674 }
675 }
676
677 let max_sim = similarities.fold(T::neg_infinity(), |m, &x| if x > m { x } else { m });
679 let mut exp_similarities = similarities.mapv(|x| (x - max_sim).exp());
680 let sum_exp = exp_similarities.sum();
681
682 if sum_exp > T::zero() {
683 exp_similarities.mapv_inplace(|x| x / sum_exp);
684 } else {
685 exp_similarities.fill(T::one() / T::from(d_memory).unwrap());
687 }
688
689 let attention_size = inputs.len();
691 let read_size = attention_size / 2;
692
693 self.read_weights = exp_similarities.clone();
695
696 if attention_size > read_size {
698 let write_intensity = inputs[read_size];
699 let write_gate = T::one() / (T::one() + (-write_intensity).exp()); for i in 0..d_memory {
703 self.write_weights[i] = (T::one() - write_gate) * self.write_weights[i]
704 + write_gate * exp_similarities[i];
705 }
706 }
707 }
708 pub fn write_memory(&mut self, erase: &Array1<T>, add: &Array1<T>) {
710 let dim_memory = self.features().dim_memory();
711 for i in 0..dim_memory.0 {
713 let w = self.write_weights[i];
714 if w > T::zero() {
715 for j in 0..dim_memory.1 {
716 self.memory[[i, j]] =
717 self.memory[[i, j]] * (T::one() - w * erase[j]) + w * add[j];
718 }
719 }
720 }
721 }
722}
723
724impl<T> NeuralEngine<T> {
725 pub(crate) fn tail_to_targets(&self, tail: Tail<usize, usize>) -> Array1<T>
726 where
727 T: Clone + num_traits::One + num_traits::Zero,
728 {
729 let n_states = NeuralEngine::<T>::STATES;
730 let Tail {
732 direction: tgt_direction,
733 state: tgt_state,
734 symbol: tgt_symbol,
735 } = tail;
736
737 let mut output = Array1::<T>::zeros(self.features().outputs());
739 if *tgt_state < n_states && *tgt_state < output.len() {
741 output[*tgt_state] = T::one();
742 }
743
744 let symbol_start = n_states;
746 let symbol_idx = self
747 .alphabet
748 .iter()
749 .position(|&s| s == tgt_symbol)
750 .unwrap_or(0);
751 if symbol_start + symbol_idx < output.len() {
752 output[symbol_start + symbol_idx] = T::one();
753 }
754
755 let dir_start = symbol_start + self.alphabet.len();
757 let dir_idx = match tgt_direction {
758 Direction::Left => 0,
759 Direction::Stay => 1,
760 Direction::Right => 2,
761 };
762
763 if dir_start + dir_idx < output.len() {
764 output[dir_start + dir_idx] = T::one();
765 }
766 output
767 }
768}
769
770impl<T> Default for NeuralEngine<T>
771where
772 T: Clone + Default + FromPrimitive,
773{
774 fn default() -> Self {
775 Self::new([0, 4, 7], State(0))
776 }
777}
778
// NeuralEngine acts as a raw engine whose backing store is a plain symbol
// vector. `seal!()` is a crate-local macro from the engine module —
// presumably the sealed-trait boilerplate; confirm in `crate::engine`.
impl<T> RawEngine for NeuralEngine<T>
where
    T: Send + Sync + core::fmt::Debug,
{
    type Store = Vec<usize>;

    seal!();
}
787
// Marker impl: NeuralEngine is a full Engine whenever T satisfies the
// engine bounds. `seal!()` is the crate-local sealing macro — see
// `crate::engine` for its definition.
impl<T> Engine for NeuralEngine<T>
where
    T: Clone + Default + Send + Sync + core::fmt::Debug + FromPrimitive,
{
    seal!();
}
794
795impl<T> ComputationalEngine<usize, [usize; 3]> for NeuralEngine<T>
796where
797 T: Default + Float + FromPrimitive + ScalarOperand + Send + Sync + core::fmt::Debug,
798{
799 fn new(alphabet: [usize; 3], initial_state: State<usize>) -> Self {
800 Self::new(alphabet, initial_state)
801 }
802
803 fn alphabet(&self) -> &[usize; 3] {
804 &self.alphabet
805 }
806
807 fn alphabet_mut(&mut self) -> &mut [usize; 3] {
808 &mut self.alphabet
809 }
810
811 fn head(&self) -> Head<&usize, &usize> {
812 Head::new(self.state.view(), &self.tape[self.position])
813 }
814
815 fn head_mut(&mut self) -> Head<&mut usize, &mut usize> {
816 Head::new(self.state.view_mut(), &mut self.tape[self.position])
817 }
818
819 fn position(&self) -> usize {
820 self.position
821 }
822
823 fn position_mut(&mut self) -> &mut usize {
824 &mut self.position
825 }
826
827 fn state(&self) -> State<&usize> {
828 self.state.view()
829 }
830
831 fn state_mut(&mut self) -> State<&mut usize> {
832 self.state.view_mut()
833 }
834
835 fn tape(&self) -> &Vec<usize> {
836 &self.tape
837 }
838
839 fn tape_mut(&mut self) -> &mut Vec<usize> {
840 &mut self.tape
841 }
842
843 fn step(&mut self) -> crate::Result<()> {
844 self.step()
845 }
846
847 fn set_alphabet(&mut self, alphabet: [usize; 3]) {
848 self.alphabet = alphabet;
849 }
850
851 fn set_position(&mut self, position: usize) {
852 self.position = position;
853 }
854
855 fn set_state(&mut self, state: State<usize>) {
856 self.state = state;
857 }
858
859 fn set_tape<I>(&mut self, tape: I)
860 where
861 I: IntoIterator<Item = usize>,
862 {
863 self.tape = Vec::from_iter(tape);
864 }
865}