1use anyhow::{anyhow, Result};
12use scirs2_core::ndarray::concatenate as ndarray_concatenate;
13use scirs2_core::ndarray_ext::{s, Array1, Array2, Axis};
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct DNCConfig {
19 pub memory_size: usize,
21 pub memory_width: usize,
23 pub num_read_heads: usize,
25 pub controller_size: usize,
27 pub output_size: usize,
29 pub memory_learning_rate: f32,
31 pub memory_decay: f32,
33}
34
35impl Default for DNCConfig {
36 fn default() -> Self {
37 Self {
38 memory_size: 256,
39 memory_width: 64,
40 num_read_heads: 4,
41 controller_size: 512,
42 output_size: 256,
43 memory_learning_rate: 0.001,
44 memory_decay: 0.95,
45 }
46 }
47}
48
49pub struct ControllerNetwork {
51 pub(crate) w_ih: Array2<f32>,
53 pub(crate) w_hh: Array2<f32>,
55 pub(crate) w_ho: Array2<f32>,
57 pub(crate) bias_h: Array1<f32>,
59 pub(crate) bias_o: Array1<f32>,
60 pub(crate) hidden_state: Array1<f32>,
62 pub(crate) cell_state: Array1<f32>,
64}
65
66impl ControllerNetwork {
67 pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
68 use scirs2_core::random::Random;
69 let mut rng = Random::default();
70
71 let w_ih =
72 Array2::from_shape_fn((hidden_size, input_size), |_| rng.random_range(-0.1..0.1));
73 let w_hh =
74 Array2::from_shape_fn((hidden_size, hidden_size), |_| rng.random_range(-0.1..0.1));
75 let w_ho =
76 Array2::from_shape_fn((output_size, hidden_size), |_| rng.random_range(-0.1..0.1));
77 let bias_h = Array1::zeros(hidden_size);
78 let bias_o = Array1::zeros(output_size);
79 let hidden_state = Array1::zeros(hidden_size);
80 let cell_state = Array1::zeros(hidden_size);
81
82 Self {
83 w_ih,
84 w_hh,
85 w_ho,
86 bias_h,
87 bias_o,
88 hidden_state,
89 cell_state,
90 }
91 }
92
93 pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
95 let input_gate = self
96 .sigmoid(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
97 let forget_gate = self
98 .sigmoid(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
99 let cell_gate =
100 self.tanh(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
101 let output_gate = self
102 .sigmoid(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
103
104 self.cell_state = &forget_gate * &self.cell_state + &input_gate * &cell_gate;
105 self.hidden_state = &output_gate * &self.tanh(&self.cell_state);
106
107 self.w_ho.dot(&self.hidden_state) + &self.bias_o
108 }
109
110 fn sigmoid(&self, x: &Array1<f32>) -> Array1<f32> {
111 x.map(|&v| 1.0 / (1.0 + (-v).exp()))
112 }
113
114 fn tanh(&self, x: &Array1<f32>) -> Array1<f32> {
115 x.map(|&v| v.tanh())
116 }
117}
118
119pub struct ReadHead {
121 pub(crate) key: Array1<f32>,
123 pub(crate) key_strength: f32,
125 pub(crate) free_gates: Array1<f32>,
127 pub(crate) read_modes: Array1<f32>,
129}
130
131impl ReadHead {
132 pub fn new(memory_width: usize) -> Self {
133 Self {
134 key: Array1::zeros(memory_width),
135 key_strength: 1.0,
136 free_gates: Array1::zeros(memory_width),
137 read_modes: Array1::from_vec(vec![1.0, 0.0, 0.0]),
138 }
139 }
140
141 pub fn generate_weighting(
143 &self,
144 memory: &Array2<f32>,
145 link_matrix: &Array2<f32>,
146 prev_read_weighting: &Array1<f32>,
147 ) -> Array1<f32> {
148 let content_weighting = self.content_lookup(memory);
149 let forward_weighting = link_matrix.dot(prev_read_weighting);
150 let backward_weighting = link_matrix.t().dot(prev_read_weighting);
151
152 let combined_weighting = self.read_modes[0] * &backward_weighting
153 + self.read_modes[1] * &content_weighting
154 + self.read_modes[2] * &forward_weighting;
155
156 let sum = combined_weighting.sum();
157 if sum > 0.0 {
158 combined_weighting / sum
159 } else {
160 Array1::zeros(memory.nrows())
161 }
162 }
163
164 fn content_lookup(&self, memory: &Array2<f32>) -> Array1<f32> {
165 let mut similarities = Array1::zeros(memory.nrows());
166 for (i, memory_row) in memory.axis_iter(Axis(0)).enumerate() {
167 similarities[i] = cosine_similarity(&self.key, &memory_row.to_owned());
168 }
169 let scaled = similarities.map(|&x| (x * self.key_strength).exp());
170 let sum = scaled.sum();
171 if sum > 0.0 {
172 scaled / sum
173 } else {
174 Array1::zeros(memory.nrows())
175 }
176 }
177}
178
179pub struct WriteHead {
181 pub(crate) key: Array1<f32>,
182 pub(crate) key_strength: f32,
183 pub(crate) erase_vector: Array1<f32>,
184 pub(crate) write_vector: Array1<f32>,
185 pub(crate) allocation_gate: f32,
186 pub(crate) write_gate: f32,
187}
188
189impl WriteHead {
190 pub fn new(memory_width: usize) -> Self {
191 Self {
192 key: Array1::zeros(memory_width),
193 key_strength: 1.0,
194 erase_vector: Array1::zeros(memory_width),
195 write_vector: Array1::zeros(memory_width),
196 allocation_gate: 0.0,
197 write_gate: 1.0,
198 }
199 }
200
201 pub fn generate_weighting(
203 &self,
204 memory: &Array2<f32>,
205 usage_vector: &Array1<f32>,
206 ) -> Array1<f32> {
207 let content_weighting = self.content_lookup(memory);
208 let allocation_weighting = self.allocation_lookup(usage_vector);
209
210 self.write_gate
211 * (self.allocation_gate * allocation_weighting
212 + (1.0 - self.allocation_gate) * content_weighting)
213 }
214
215 fn content_lookup(&self, memory: &Array2<f32>) -> Array1<f32> {
216 let mut similarities = Array1::zeros(memory.nrows());
217 for (i, memory_row) in memory.axis_iter(Axis(0)).enumerate() {
218 similarities[i] = cosine_similarity(&self.key, &memory_row.to_owned());
219 }
220 let scaled = similarities.map(|&x| (x * self.key_strength).exp());
221 let sum = scaled.sum();
222 if sum > 0.0 {
223 scaled / sum
224 } else {
225 Array1::zeros(memory.nrows())
226 }
227 }
228
229 fn allocation_lookup(&self, usage_vector: &Array1<f32>) -> Array1<f32> {
230 let mut indices: Vec<usize> = (0..usage_vector.len()).collect();
231 indices.sort_by(|&a, &b| {
232 usage_vector[a]
233 .partial_cmp(&usage_vector[b])
234 .unwrap_or(std::cmp::Ordering::Equal)
235 });
236
237 let mut allocation = Array1::zeros(usage_vector.len());
238 for (rank, &idx) in indices.iter().enumerate() {
239 allocation[idx] = 1.0 / (rank as f32 + 1.0);
240 }
241 let sum = allocation.sum();
242 if sum > 0.0 {
243 allocation / sum
244 } else {
245 Array1::zeros(usage_vector.len())
246 }
247 }
248
249 pub fn write_to_memory(&self, memory: &mut Array2<f32>, weighting: &Array1<f32>) {
251 for i in 0..memory.nrows() {
252 for j in 0..memory.ncols() {
253 memory[[i, j]] *= 1.0 - weighting[i] * self.erase_vector[j];
254 }
255 }
256 for i in 0..memory.nrows() {
257 for j in 0..memory.ncols() {
258 memory[[i, j]] += weighting[i] * self.write_vector[j];
259 }
260 }
261 }
262}
263
264pub struct UsageTracker {
266 pub(crate) usage: Array1<f32>,
267 pub(crate) memory_size: usize,
268}
269
270impl UsageTracker {
271 pub fn new(memory_size: usize) -> Self {
272 Self {
273 usage: Array1::zeros(memory_size),
274 memory_size,
275 }
276 }
277
278 pub fn update(&mut self, write_weighting: &Array1<f32>, free_gates: &Array1<f32>) {
279 for i in 0..self.memory_size {
280 self.usage[i] = (self.usage[i] + write_weighting[i] - self.usage[i] * free_gates[i])
281 .clamp(0.0, 1.0);
282 }
283 }
284
285 pub fn get_allocation_weighting(&self, _allocation_gate: f32) -> Array1<f32> {
286 let mut sorted_indices: Vec<usize> = (0..self.memory_size).collect();
287 sorted_indices.sort_by(|&a, &b| {
288 self.usage[a]
289 .partial_cmp(&self.usage[b])
290 .unwrap_or(std::cmp::Ordering::Equal)
291 });
292
293 let mut weights = Array1::zeros(self.memory_size);
294 for (rank, &idx) in sorted_indices.iter().enumerate() {
295 weights[idx] = 1.0 / (rank as f32 + 1.0);
296 }
297 let sum = weights.sum();
298 if sum > 0.0 {
299 weights / sum
300 } else {
301 Array1::zeros(self.memory_size)
302 }
303 }
304}
305
306pub struct AllocationMechanism {
308 pub(crate) usage_tracker: UsageTracker,
309}
310
311impl AllocationMechanism {
312 pub fn new(memory_size: usize) -> Self {
313 Self {
314 usage_tracker: UsageTracker::new(memory_size),
315 }
316 }
317
318 pub fn allocate(&mut self, allocation_gate: f32) -> Array1<f32> {
319 self.usage_tracker.get_allocation_weighting(allocation_gate)
320 }
321
322 pub fn update_usage(&mut self, write_weighting: &Array1<f32>, free_gates: &Array1<f32>) {
323 self.usage_tracker.update(write_weighting, free_gates);
324 }
325}
326
327pub struct TemporalLinkage {
329 pub(crate) link_matrix: Array2<f32>,
330 pub(crate) precedence_weighting: Array1<f32>,
331}
332
333impl TemporalLinkage {
334 pub fn new(memory_size: usize) -> Self {
335 Self {
336 link_matrix: Array2::zeros((memory_size, memory_size)),
337 precedence_weighting: Array1::zeros(memory_size),
338 }
339 }
340
341 pub fn update(&mut self, write_weighting: &Array1<f32>) {
342 let sum = write_weighting.sum();
343 if sum > 0.0 {
344 self.precedence_weighting = (1.0 - sum) * &self.precedence_weighting + write_weighting;
345 }
346 for i in 0..self.link_matrix.nrows() {
347 for j in 0..self.link_matrix.ncols() {
348 if i != j {
349 self.link_matrix[[i, j]] = (1.0 - write_weighting[i] - write_weighting[j])
350 * self.link_matrix[[i, j]]
351 + write_weighting[i] * self.precedence_weighting[j];
352 }
353 }
354 }
355 }
356
357 pub fn get_link_matrix(&self) -> &Array2<f32> {
358 &self.link_matrix
359 }
360}
361
362pub struct MemoryAddressing {
364 pub(crate) allocation_mechanism: AllocationMechanism,
365 pub(crate) temporal_linkage: TemporalLinkage,
366}
367
368pub struct DifferentiableNeuralComputer {
370 pub(crate) config: DNCConfig,
371 pub(crate) controller: ControllerNetwork,
372 pub(crate) memory_matrix: Array2<f32>,
373 pub(crate) read_heads: Vec<ReadHead>,
374 pub(crate) write_head: WriteHead,
375 pub(crate) memory_addressing: MemoryAddressing,
376 pub(crate) usage_vector: Array1<f32>,
377 pub(crate) precedence_weights: Array1<f32>,
378 pub(crate) link_matrix: Array2<f32>,
379 pub(crate) read_weightings: Array2<f32>,
380 pub(crate) write_weighting: Array1<f32>,
381}
382
383impl DifferentiableNeuralComputer {
384 pub fn new(config: DNCConfig) -> Self {
386 let memory_matrix = Array2::zeros((config.memory_size, config.memory_width));
387 let usage_vector = Array1::zeros(config.memory_size);
388 let precedence_weights = Array1::zeros(config.memory_size);
389 let link_matrix = Array2::zeros((config.memory_size, config.memory_size));
390 let read_weightings = Array2::zeros((config.num_read_heads, config.memory_size));
391 let write_weighting = Array1::zeros(config.memory_size);
392
393 let controller = ControllerNetwork::new(
394 config.memory_width + config.num_read_heads * config.memory_width,
395 config.controller_size,
396 config.output_size
397 + config.memory_width * (config.num_read_heads + 1)
398 + 3 * config.num_read_heads
399 + 5,
400 );
401
402 let read_heads = (0..config.num_read_heads)
403 .map(|_| ReadHead::new(config.memory_width))
404 .collect();
405
406 let write_head = WriteHead::new(config.memory_width);
407
408 let memory_addressing = MemoryAddressing {
409 allocation_mechanism: AllocationMechanism::new(config.memory_size),
410 temporal_linkage: TemporalLinkage::new(config.memory_size),
411 };
412
413 Self {
414 config,
415 controller,
416 memory_matrix,
417 read_heads,
418 write_head,
419 memory_addressing,
420 usage_vector,
421 precedence_weights,
422 link_matrix,
423 read_weightings,
424 write_weighting,
425 }
426 }
427
428 pub fn forward(&mut self, input: &Array1<f32>) -> Result<Array1<f32>> {
430 let mut read_vectors = Vec::new();
431 for (i, read_head) in self.read_heads.iter().enumerate() {
432 let read_weighting = read_head.generate_weighting(
433 &self.memory_matrix,
434 &self.link_matrix,
435 &self.read_weightings.row(i).to_owned(),
436 );
437 let read_vector = self.memory_matrix.t().dot(&read_weighting);
438 read_vectors.push(read_vector);
439 }
440
441 let mut controller_input = input.clone();
442 for read_vector in &read_vectors {
443 let views: &[_] = &[controller_input.view(), read_vector.view()];
444 controller_input = ndarray_concatenate(Axis(0), views)
445 .map_err(|e| anyhow!("concatenate failed: {}", e))?;
446 }
447
448 let controller_output = self.controller.forward(&controller_input);
449 let (output, _interface_vector) = self.parse_controller_output(&controller_output)?;
450
451 let write_weighting = self
452 .write_head
453 .generate_weighting(&self.memory_matrix, &self.usage_vector);
454 self.write_head
455 .write_to_memory(&mut self.memory_matrix, &write_weighting);
456
457 let free_gates = Array1::ones(self.config.memory_size);
458 self.memory_addressing
459 .allocation_mechanism
460 .update_usage(&write_weighting, &free_gates);
461 self.memory_addressing
462 .temporal_linkage
463 .update(&write_weighting);
464
465 self.write_weighting = write_weighting;
466 self.link_matrix = self
467 .memory_addressing
468 .temporal_linkage
469 .get_link_matrix()
470 .clone();
471
472 Ok(output)
473 }
474
475 fn parse_controller_output(&self, output: &Array1<f32>) -> Result<(Array1<f32>, Array1<f32>)> {
476 if output.len() < self.config.output_size {
477 return Err(anyhow!("Controller output too short"));
478 }
479 let network_output = output.slice(s![..self.config.output_size]).to_owned();
480 let interface_vector = output.slice(s![self.config.output_size..]).to_owned();
481 Ok((network_output, interface_vector))
482 }
483
484 pub fn reset(&mut self) {
486 self.memory_matrix.fill(0.0);
487 self.usage_vector.fill(0.0);
488 self.precedence_weights.fill(0.0);
489 self.link_matrix.fill(0.0);
490 self.read_weightings.fill(0.0);
491 self.write_weighting.fill(0.0);
492 }
493
494 pub fn get_memory_utilization(&self) -> f32 {
496 self.usage_vector.sum() / self.usage_vector.len() as f32
497 }
498}
499
500#[derive(Debug, Clone, Serialize, Deserialize)]
502pub struct NTMConfig {
503 pub memory_size: usize,
504 pub memory_width: usize,
505 pub num_heads: usize,
506 pub controller_size: usize,
507 pub shift_range: usize,
508}
509
510impl Default for NTMConfig {
511 fn default() -> Self {
512 Self {
513 memory_size: 128,
514 memory_width: 32,
515 num_heads: 2,
516 controller_size: 256,
517 shift_range: 3,
518 }
519 }
520}
521
522pub struct NTMHead {
524 pub(crate) key: Array1<f32>,
525 pub(crate) key_strength: f32,
526 pub(crate) gate: f32,
527 pub(crate) shift_weights: Array1<f32>,
528 pub(crate) gamma: f32,
529 pub(crate) prev_weighting: Array1<f32>,
530}
531
532impl NTMHead {
533 pub fn new(memory_width: usize, memory_size: usize, shift_range: usize) -> Self {
534 Self {
535 key: Array1::zeros(memory_width),
536 key_strength: 1.0,
537 gate: 0.5,
538 shift_weights: Array1::zeros(2 * shift_range + 1),
539 gamma: 1.0,
540 prev_weighting: Array1::zeros(memory_size),
541 }
542 }
543
544 pub fn address(&mut self, memory: &Array2<f32>) -> Array1<f32> {
546 let content_weights = self.content_addressing(memory);
547 let gated_weights = self.gate * &content_weights + (1.0 - self.gate) * &self.prev_weighting;
548 let shifted_weights = self.shift_addressing(&gated_weights);
549 let final_weights = self.sharpen_addressing(&shifted_weights);
550 self.prev_weighting = final_weights.clone();
551 final_weights
552 }
553
554 fn content_addressing(&self, memory: &Array2<f32>) -> Array1<f32> {
555 let mut similarities = Array1::zeros(memory.nrows());
556 for (i, memory_row) in memory.axis_iter(Axis(0)).enumerate() {
557 similarities[i] = cosine_similarity(&self.key, &memory_row.to_owned());
558 }
559 let scaled = similarities.map(|&x| (x * self.key_strength).exp());
560 let sum = scaled.sum();
561 if sum > 0.0 {
562 scaled / sum
563 } else {
564 Array1::zeros(memory.nrows())
565 }
566 }
567
568 fn shift_addressing(&self, weights: &Array1<f32>) -> Array1<f32> {
569 let memory_size = weights.len();
570 let shift_range = (self.shift_weights.len() - 1) / 2;
571 let mut shifted = Array1::zeros(memory_size);
572
573 for i in 0..memory_size {
574 for (j, &shift_weight) in self.shift_weights.iter().enumerate() {
575 let shift = j as i32 - shift_range as i32;
576 let shifted_idx = ((i as i32 + shift) % memory_size as i32 + memory_size as i32)
577 % memory_size as i32;
578 shifted[shifted_idx as usize] += weights[i] * shift_weight;
579 }
580 }
581 shifted
582 }
583
584 fn sharpen_addressing(&self, weights: &Array1<f32>) -> Array1<f32> {
585 let sharpened = weights.map(|&x| x.powf(self.gamma));
586 let sum = sharpened.sum();
587 if sum > 0.0 {
588 sharpened / sum
589 } else {
590 Array1::zeros(weights.len())
591 }
592 }
593}
594
595pub struct NeuralTuringMachine {
597 pub(crate) config: NTMConfig,
598 pub(crate) controller: ControllerNetwork,
599 pub(crate) memory: Array2<f32>,
600 pub(crate) read_heads: Vec<NTMHead>,
601 pub(crate) write_heads: Vec<NTMHead>,
602}
603
604impl NeuralTuringMachine {
605 pub fn new(config: NTMConfig) -> Self {
606 let memory = Array2::zeros((config.memory_size, config.memory_width));
607 let controller = ControllerNetwork::new(
608 config.memory_width + config.num_heads * config.memory_width,
609 config.controller_size,
610 config.memory_width
611 + config.num_heads * (config.memory_width + 3 + 2 * config.shift_range + 1),
612 );
613
614 let read_heads = (0..config.num_heads)
615 .map(|_| NTMHead::new(config.memory_width, config.memory_size, config.shift_range))
616 .collect();
617
618 let write_heads = (0..config.num_heads)
619 .map(|_| NTMHead::new(config.memory_width, config.memory_size, config.shift_range))
620 .collect();
621
622 Self {
623 config,
624 controller,
625 memory,
626 read_heads,
627 write_heads,
628 }
629 }
630
631 pub fn forward(&mut self, input: &Array1<f32>) -> Result<Array1<f32>> {
633 let mut read_vectors = Vec::new();
634 for read_head in &mut self.read_heads {
635 let weighting = read_head.address(&self.memory);
636 let read_vector = self.memory.t().dot(&weighting);
637 read_vectors.push(read_vector);
638 }
639
640 let mut controller_input = input.clone();
641 for read_vector in &read_vectors {
642 let views: &[_] = &[controller_input.view(), read_vector.view()];
643 controller_input = ndarray_concatenate(Axis(0), views)
644 .map_err(|e| anyhow!("concatenate failed: {}", e))?;
645 }
646
647 let controller_output = self.controller.forward(&controller_input);
648 Ok(controller_output)
649 }
650}
651
652pub(crate) fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
654 let dot_product = a.dot(b);
655 let norm_a = a.mapv(|x| x * x).sum().sqrt();
656 let norm_b = b.mapv(|x| x * x).sum().sqrt();
657 if norm_a > 0.0 && norm_b > 0.0 {
658 dot_product / (norm_a * norm_b)
659 } else {
660 0.0
661 }
662}