Skip to main content

oxirs_embed/
memory_nets_controller.rs

//! Controller networks for Memory-Augmented Networks (DNC and NTM).
//!
//! This module contains:
//! - DNC configuration and implementation
//! - NTM configuration and implementation
//! - Controller network (LSTM-style)
//! - Read and write heads for DNC
//! - Memory addressing sub-systems (allocation, temporal linkage, usage tracker)
//! - NTM heads with content/shift/sharpen addressing
10
11use anyhow::{anyhow, Result};
12use scirs2_core::ndarray::concatenate as ndarray_concatenate;
13use scirs2_core::ndarray_ext::{s, Array1, Array2, Axis};
14use serde::{Deserialize, Serialize};
15
16/// DNC Configuration
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct DNCConfig {
19    /// Number of memory slots
20    pub memory_size: usize,
21    /// Size of each memory slot
22    pub memory_width: usize,
23    /// Number of read heads
24    pub num_read_heads: usize,
25    /// Controller network size
26    pub controller_size: usize,
27    /// Output size
28    pub output_size: usize,
29    /// Learning rate for memory operations
30    pub memory_learning_rate: f32,
31    /// Memory decay factor
32    pub memory_decay: f32,
33}
34
35impl Default for DNCConfig {
36    fn default() -> Self {
37        Self {
38            memory_size: 256,
39            memory_width: 64,
40            num_read_heads: 4,
41            controller_size: 512,
42            output_size: 256,
43            memory_learning_rate: 0.001,
44            memory_decay: 0.95,
45        }
46    }
47}
48
49/// Controller network for DNC (LSTM-style)
50pub struct ControllerNetwork {
51    /// Input to hidden weights
52    pub(crate) w_ih: Array2<f32>,
53    /// Hidden to hidden weights
54    pub(crate) w_hh: Array2<f32>,
55    /// Hidden to output weights
56    pub(crate) w_ho: Array2<f32>,
57    /// Bias vectors
58    pub(crate) bias_h: Array1<f32>,
59    pub(crate) bias_o: Array1<f32>,
60    /// Hidden state
61    pub(crate) hidden_state: Array1<f32>,
62    /// Cell state (for LSTM)
63    pub(crate) cell_state: Array1<f32>,
64}
65
66impl ControllerNetwork {
67    pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
68        use scirs2_core::random::Random;
69        let mut rng = Random::default();
70
71        let w_ih =
72            Array2::from_shape_fn((hidden_size, input_size), |_| rng.random_range(-0.1..0.1));
73        let w_hh =
74            Array2::from_shape_fn((hidden_size, hidden_size), |_| rng.random_range(-0.1..0.1));
75        let w_ho =
76            Array2::from_shape_fn((output_size, hidden_size), |_| rng.random_range(-0.1..0.1));
77        let bias_h = Array1::zeros(hidden_size);
78        let bias_o = Array1::zeros(output_size);
79        let hidden_state = Array1::zeros(hidden_size);
80        let cell_state = Array1::zeros(hidden_size);
81
82        Self {
83            w_ih,
84            w_hh,
85            w_ho,
86            bias_h,
87            bias_o,
88            hidden_state,
89            cell_state,
90        }
91    }
92
93    /// Forward pass through controller (LSTM-style computation)
94    pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
95        let input_gate = self
96            .sigmoid(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
97        let forget_gate = self
98            .sigmoid(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
99        let cell_gate =
100            self.tanh(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
101        let output_gate = self
102            .sigmoid(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
103
104        self.cell_state = &forget_gate * &self.cell_state + &input_gate * &cell_gate;
105        self.hidden_state = &output_gate * &self.tanh(&self.cell_state);
106
107        self.w_ho.dot(&self.hidden_state) + &self.bias_o
108    }
109
110    fn sigmoid(&self, x: &Array1<f32>) -> Array1<f32> {
111        x.map(|&v| 1.0 / (1.0 + (-v).exp()))
112    }
113
114    fn tanh(&self, x: &Array1<f32>) -> Array1<f32> {
115        x.map(|&v| v.tanh())
116    }
117}
118
119/// Read head for DNC
120pub struct ReadHead {
121    /// Key vector for content-based addressing
122    pub(crate) key: Array1<f32>,
123    /// Key strength
124    pub(crate) key_strength: f32,
125    /// Free gates for memory deallocation
126    pub(crate) free_gates: Array1<f32>,
127    /// Read modes (backward, forward, content lookup)
128    pub(crate) read_modes: Array1<f32>,
129}
130
131impl ReadHead {
132    pub fn new(memory_width: usize) -> Self {
133        Self {
134            key: Array1::zeros(memory_width),
135            key_strength: 1.0,
136            free_gates: Array1::zeros(memory_width),
137            read_modes: Array1::from_vec(vec![1.0, 0.0, 0.0]),
138        }
139    }
140
141    /// Generate read weighting using content-based + temporal addressing
142    pub fn generate_weighting(
143        &self,
144        memory: &Array2<f32>,
145        link_matrix: &Array2<f32>,
146        prev_read_weighting: &Array1<f32>,
147    ) -> Array1<f32> {
148        let content_weighting = self.content_lookup(memory);
149        let forward_weighting = link_matrix.dot(prev_read_weighting);
150        let backward_weighting = link_matrix.t().dot(prev_read_weighting);
151
152        let combined_weighting = self.read_modes[0] * &backward_weighting
153            + self.read_modes[1] * &content_weighting
154            + self.read_modes[2] * &forward_weighting;
155
156        let sum = combined_weighting.sum();
157        if sum > 0.0 {
158            combined_weighting / sum
159        } else {
160            Array1::zeros(memory.nrows())
161        }
162    }
163
164    fn content_lookup(&self, memory: &Array2<f32>) -> Array1<f32> {
165        let mut similarities = Array1::zeros(memory.nrows());
166        for (i, memory_row) in memory.axis_iter(Axis(0)).enumerate() {
167            similarities[i] = cosine_similarity(&self.key, &memory_row.to_owned());
168        }
169        let scaled = similarities.map(|&x| (x * self.key_strength).exp());
170        let sum = scaled.sum();
171        if sum > 0.0 {
172            scaled / sum
173        } else {
174            Array1::zeros(memory.nrows())
175        }
176    }
177}
178
179/// Write head for DNC
180pub struct WriteHead {
181    pub(crate) key: Array1<f32>,
182    pub(crate) key_strength: f32,
183    pub(crate) erase_vector: Array1<f32>,
184    pub(crate) write_vector: Array1<f32>,
185    pub(crate) allocation_gate: f32,
186    pub(crate) write_gate: f32,
187}
188
189impl WriteHead {
190    pub fn new(memory_width: usize) -> Self {
191        Self {
192            key: Array1::zeros(memory_width),
193            key_strength: 1.0,
194            erase_vector: Array1::zeros(memory_width),
195            write_vector: Array1::zeros(memory_width),
196            allocation_gate: 0.0,
197            write_gate: 1.0,
198        }
199    }
200
201    /// Generate write weighting combining content-based and allocation
202    pub fn generate_weighting(
203        &self,
204        memory: &Array2<f32>,
205        usage_vector: &Array1<f32>,
206    ) -> Array1<f32> {
207        let content_weighting = self.content_lookup(memory);
208        let allocation_weighting = self.allocation_lookup(usage_vector);
209
210        self.write_gate
211            * (self.allocation_gate * allocation_weighting
212                + (1.0 - self.allocation_gate) * content_weighting)
213    }
214
215    fn content_lookup(&self, memory: &Array2<f32>) -> Array1<f32> {
216        let mut similarities = Array1::zeros(memory.nrows());
217        for (i, memory_row) in memory.axis_iter(Axis(0)).enumerate() {
218            similarities[i] = cosine_similarity(&self.key, &memory_row.to_owned());
219        }
220        let scaled = similarities.map(|&x| (x * self.key_strength).exp());
221        let sum = scaled.sum();
222        if sum > 0.0 {
223            scaled / sum
224        } else {
225            Array1::zeros(memory.nrows())
226        }
227    }
228
229    fn allocation_lookup(&self, usage_vector: &Array1<f32>) -> Array1<f32> {
230        let mut indices: Vec<usize> = (0..usage_vector.len()).collect();
231        indices.sort_by(|&a, &b| {
232            usage_vector[a]
233                .partial_cmp(&usage_vector[b])
234                .unwrap_or(std::cmp::Ordering::Equal)
235        });
236
237        let mut allocation = Array1::zeros(usage_vector.len());
238        for (rank, &idx) in indices.iter().enumerate() {
239            allocation[idx] = 1.0 / (rank as f32 + 1.0);
240        }
241        let sum = allocation.sum();
242        if sum > 0.0 {
243            allocation / sum
244        } else {
245            Array1::zeros(usage_vector.len())
246        }
247    }
248
249    /// Erase-then-write to memory
250    pub fn write_to_memory(&self, memory: &mut Array2<f32>, weighting: &Array1<f32>) {
251        for i in 0..memory.nrows() {
252            for j in 0..memory.ncols() {
253                memory[[i, j]] *= 1.0 - weighting[i] * self.erase_vector[j];
254            }
255        }
256        for i in 0..memory.nrows() {
257            for j in 0..memory.ncols() {
258                memory[[i, j]] += weighting[i] * self.write_vector[j];
259            }
260        }
261    }
262}
263
264/// Usage tracking for memory allocation
265pub struct UsageTracker {
266    pub(crate) usage: Array1<f32>,
267    pub(crate) memory_size: usize,
268}
269
270impl UsageTracker {
271    pub fn new(memory_size: usize) -> Self {
272        Self {
273            usage: Array1::zeros(memory_size),
274            memory_size,
275        }
276    }
277
278    pub fn update(&mut self, write_weighting: &Array1<f32>, free_gates: &Array1<f32>) {
279        for i in 0..self.memory_size {
280            self.usage[i] = (self.usage[i] + write_weighting[i] - self.usage[i] * free_gates[i])
281                .clamp(0.0, 1.0);
282        }
283    }
284
285    pub fn get_allocation_weighting(&self, _allocation_gate: f32) -> Array1<f32> {
286        let mut sorted_indices: Vec<usize> = (0..self.memory_size).collect();
287        sorted_indices.sort_by(|&a, &b| {
288            self.usage[a]
289                .partial_cmp(&self.usage[b])
290                .unwrap_or(std::cmp::Ordering::Equal)
291        });
292
293        let mut weights = Array1::zeros(self.memory_size);
294        for (rank, &idx) in sorted_indices.iter().enumerate() {
295            weights[idx] = 1.0 / (rank as f32 + 1.0);
296        }
297        let sum = weights.sum();
298        if sum > 0.0 {
299            weights / sum
300        } else {
301            Array1::zeros(self.memory_size)
302        }
303    }
304}
305
306/// Allocation mechanism for finding free memory
307pub struct AllocationMechanism {
308    pub(crate) usage_tracker: UsageTracker,
309}
310
311impl AllocationMechanism {
312    pub fn new(memory_size: usize) -> Self {
313        Self {
314            usage_tracker: UsageTracker::new(memory_size),
315        }
316    }
317
318    pub fn allocate(&mut self, allocation_gate: f32) -> Array1<f32> {
319        self.usage_tracker.get_allocation_weighting(allocation_gate)
320    }
321
322    pub fn update_usage(&mut self, write_weighting: &Array1<f32>, free_gates: &Array1<f32>) {
323        self.usage_tracker.update(write_weighting, free_gates);
324    }
325}
326
327/// Temporal linkage for sequential memory access
328pub struct TemporalLinkage {
329    pub(crate) link_matrix: Array2<f32>,
330    pub(crate) precedence_weighting: Array1<f32>,
331}
332
333impl TemporalLinkage {
334    pub fn new(memory_size: usize) -> Self {
335        Self {
336            link_matrix: Array2::zeros((memory_size, memory_size)),
337            precedence_weighting: Array1::zeros(memory_size),
338        }
339    }
340
341    pub fn update(&mut self, write_weighting: &Array1<f32>) {
342        let sum = write_weighting.sum();
343        if sum > 0.0 {
344            self.precedence_weighting = (1.0 - sum) * &self.precedence_weighting + write_weighting;
345        }
346        for i in 0..self.link_matrix.nrows() {
347            for j in 0..self.link_matrix.ncols() {
348                if i != j {
349                    self.link_matrix[[i, j]] = (1.0 - write_weighting[i] - write_weighting[j])
350                        * self.link_matrix[[i, j]]
351                        + write_weighting[i] * self.precedence_weighting[j];
352                }
353            }
354        }
355    }
356
357    pub fn get_link_matrix(&self) -> &Array2<f32> {
358        &self.link_matrix
359    }
360}
361
362/// Memory addressing system
363pub struct MemoryAddressing {
364    pub(crate) allocation_mechanism: AllocationMechanism,
365    pub(crate) temporal_linkage: TemporalLinkage,
366}
367
368/// Differentiable Neural Computer implementation
369pub struct DifferentiableNeuralComputer {
370    pub(crate) config: DNCConfig,
371    pub(crate) controller: ControllerNetwork,
372    pub(crate) memory_matrix: Array2<f32>,
373    pub(crate) read_heads: Vec<ReadHead>,
374    pub(crate) write_head: WriteHead,
375    pub(crate) memory_addressing: MemoryAddressing,
376    pub(crate) usage_vector: Array1<f32>,
377    pub(crate) precedence_weights: Array1<f32>,
378    pub(crate) link_matrix: Array2<f32>,
379    pub(crate) read_weightings: Array2<f32>,
380    pub(crate) write_weighting: Array1<f32>,
381}
382
383impl DifferentiableNeuralComputer {
384    /// Create new DNC
385    pub fn new(config: DNCConfig) -> Self {
386        let memory_matrix = Array2::zeros((config.memory_size, config.memory_width));
387        let usage_vector = Array1::zeros(config.memory_size);
388        let precedence_weights = Array1::zeros(config.memory_size);
389        let link_matrix = Array2::zeros((config.memory_size, config.memory_size));
390        let read_weightings = Array2::zeros((config.num_read_heads, config.memory_size));
391        let write_weighting = Array1::zeros(config.memory_size);
392
393        let controller = ControllerNetwork::new(
394            config.memory_width + config.num_read_heads * config.memory_width,
395            config.controller_size,
396            config.output_size
397                + config.memory_width * (config.num_read_heads + 1)
398                + 3 * config.num_read_heads
399                + 5,
400        );
401
402        let read_heads = (0..config.num_read_heads)
403            .map(|_| ReadHead::new(config.memory_width))
404            .collect();
405
406        let write_head = WriteHead::new(config.memory_width);
407
408        let memory_addressing = MemoryAddressing {
409            allocation_mechanism: AllocationMechanism::new(config.memory_size),
410            temporal_linkage: TemporalLinkage::new(config.memory_size),
411        };
412
413        Self {
414            config,
415            controller,
416            memory_matrix,
417            read_heads,
418            write_head,
419            memory_addressing,
420            usage_vector,
421            precedence_weights,
422            link_matrix,
423            read_weightings,
424            write_weighting,
425        }
426    }
427
428    /// Forward pass through DNC
429    pub fn forward(&mut self, input: &Array1<f32>) -> Result<Array1<f32>> {
430        let mut read_vectors = Vec::new();
431        for (i, read_head) in self.read_heads.iter().enumerate() {
432            let read_weighting = read_head.generate_weighting(
433                &self.memory_matrix,
434                &self.link_matrix,
435                &self.read_weightings.row(i).to_owned(),
436            );
437            let read_vector = self.memory_matrix.t().dot(&read_weighting);
438            read_vectors.push(read_vector);
439        }
440
441        let mut controller_input = input.clone();
442        for read_vector in &read_vectors {
443            let views: &[_] = &[controller_input.view(), read_vector.view()];
444            controller_input = ndarray_concatenate(Axis(0), views)
445                .map_err(|e| anyhow!("concatenate failed: {}", e))?;
446        }
447
448        let controller_output = self.controller.forward(&controller_input);
449        let (output, _interface_vector) = self.parse_controller_output(&controller_output)?;
450
451        let write_weighting = self
452            .write_head
453            .generate_weighting(&self.memory_matrix, &self.usage_vector);
454        self.write_head
455            .write_to_memory(&mut self.memory_matrix, &write_weighting);
456
457        let free_gates = Array1::ones(self.config.memory_size);
458        self.memory_addressing
459            .allocation_mechanism
460            .update_usage(&write_weighting, &free_gates);
461        self.memory_addressing
462            .temporal_linkage
463            .update(&write_weighting);
464
465        self.write_weighting = write_weighting;
466        self.link_matrix = self
467            .memory_addressing
468            .temporal_linkage
469            .get_link_matrix()
470            .clone();
471
472        Ok(output)
473    }
474
475    fn parse_controller_output(&self, output: &Array1<f32>) -> Result<(Array1<f32>, Array1<f32>)> {
476        if output.len() < self.config.output_size {
477            return Err(anyhow!("Controller output too short"));
478        }
479        let network_output = output.slice(s![..self.config.output_size]).to_owned();
480        let interface_vector = output.slice(s![self.config.output_size..]).to_owned();
481        Ok((network_output, interface_vector))
482    }
483
484    /// Reset memory state
485    pub fn reset(&mut self) {
486        self.memory_matrix.fill(0.0);
487        self.usage_vector.fill(0.0);
488        self.precedence_weights.fill(0.0);
489        self.link_matrix.fill(0.0);
490        self.read_weightings.fill(0.0);
491        self.write_weighting.fill(0.0);
492    }
493
494    /// Get memory utilization
495    pub fn get_memory_utilization(&self) -> f32 {
496        self.usage_vector.sum() / self.usage_vector.len() as f32
497    }
498}
499
500/// Neural Turing Machine configuration
501#[derive(Debug, Clone, Serialize, Deserialize)]
502pub struct NTMConfig {
503    pub memory_size: usize,
504    pub memory_width: usize,
505    pub num_heads: usize,
506    pub controller_size: usize,
507    pub shift_range: usize,
508}
509
510impl Default for NTMConfig {
511    fn default() -> Self {
512        Self {
513            memory_size: 128,
514            memory_width: 32,
515            num_heads: 2,
516            controller_size: 256,
517            shift_range: 3,
518        }
519    }
520}
521
522/// NTM Head (read or write)
523pub struct NTMHead {
524    pub(crate) key: Array1<f32>,
525    pub(crate) key_strength: f32,
526    pub(crate) gate: f32,
527    pub(crate) shift_weights: Array1<f32>,
528    pub(crate) gamma: f32,
529    pub(crate) prev_weighting: Array1<f32>,
530}
531
532impl NTMHead {
533    pub fn new(memory_width: usize, memory_size: usize, shift_range: usize) -> Self {
534        Self {
535            key: Array1::zeros(memory_width),
536            key_strength: 1.0,
537            gate: 0.5,
538            shift_weights: Array1::zeros(2 * shift_range + 1),
539            gamma: 1.0,
540            prev_weighting: Array1::zeros(memory_size),
541        }
542    }
543
544    /// Generate addressing weighting via content + gate + shift + sharpen
545    pub fn address(&mut self, memory: &Array2<f32>) -> Array1<f32> {
546        let content_weights = self.content_addressing(memory);
547        let gated_weights = self.gate * &content_weights + (1.0 - self.gate) * &self.prev_weighting;
548        let shifted_weights = self.shift_addressing(&gated_weights);
549        let final_weights = self.sharpen_addressing(&shifted_weights);
550        self.prev_weighting = final_weights.clone();
551        final_weights
552    }
553
554    fn content_addressing(&self, memory: &Array2<f32>) -> Array1<f32> {
555        let mut similarities = Array1::zeros(memory.nrows());
556        for (i, memory_row) in memory.axis_iter(Axis(0)).enumerate() {
557            similarities[i] = cosine_similarity(&self.key, &memory_row.to_owned());
558        }
559        let scaled = similarities.map(|&x| (x * self.key_strength).exp());
560        let sum = scaled.sum();
561        if sum > 0.0 {
562            scaled / sum
563        } else {
564            Array1::zeros(memory.nrows())
565        }
566    }
567
568    fn shift_addressing(&self, weights: &Array1<f32>) -> Array1<f32> {
569        let memory_size = weights.len();
570        let shift_range = (self.shift_weights.len() - 1) / 2;
571        let mut shifted = Array1::zeros(memory_size);
572
573        for i in 0..memory_size {
574            for (j, &shift_weight) in self.shift_weights.iter().enumerate() {
575                let shift = j as i32 - shift_range as i32;
576                let shifted_idx = ((i as i32 + shift) % memory_size as i32 + memory_size as i32)
577                    % memory_size as i32;
578                shifted[shifted_idx as usize] += weights[i] * shift_weight;
579            }
580        }
581        shifted
582    }
583
584    fn sharpen_addressing(&self, weights: &Array1<f32>) -> Array1<f32> {
585        let sharpened = weights.map(|&x| x.powf(self.gamma));
586        let sum = sharpened.sum();
587        if sum > 0.0 {
588            sharpened / sum
589        } else {
590            Array1::zeros(weights.len())
591        }
592    }
593}
594
595/// Neural Turing Machine implementation
596pub struct NeuralTuringMachine {
597    pub(crate) config: NTMConfig,
598    pub(crate) controller: ControllerNetwork,
599    pub(crate) memory: Array2<f32>,
600    pub(crate) read_heads: Vec<NTMHead>,
601    pub(crate) write_heads: Vec<NTMHead>,
602}
603
604impl NeuralTuringMachine {
605    pub fn new(config: NTMConfig) -> Self {
606        let memory = Array2::zeros((config.memory_size, config.memory_width));
607        let controller = ControllerNetwork::new(
608            config.memory_width + config.num_heads * config.memory_width,
609            config.controller_size,
610            config.memory_width
611                + config.num_heads * (config.memory_width + 3 + 2 * config.shift_range + 1),
612        );
613
614        let read_heads = (0..config.num_heads)
615            .map(|_| NTMHead::new(config.memory_width, config.memory_size, config.shift_range))
616            .collect();
617
618        let write_heads = (0..config.num_heads)
619            .map(|_| NTMHead::new(config.memory_width, config.memory_size, config.shift_range))
620            .collect();
621
622        Self {
623            config,
624            controller,
625            memory,
626            read_heads,
627            write_heads,
628        }
629    }
630
631    /// Forward pass through NTM
632    pub fn forward(&mut self, input: &Array1<f32>) -> Result<Array1<f32>> {
633        let mut read_vectors = Vec::new();
634        for read_head in &mut self.read_heads {
635            let weighting = read_head.address(&self.memory);
636            let read_vector = self.memory.t().dot(&weighting);
637            read_vectors.push(read_vector);
638        }
639
640        let mut controller_input = input.clone();
641        for read_vector in &read_vectors {
642            let views: &[_] = &[controller_input.view(), read_vector.view()];
643            controller_input = ndarray_concatenate(Axis(0), views)
644                .map_err(|e| anyhow!("concatenate failed: {}", e))?;
645        }
646
647        let controller_output = self.controller.forward(&controller_input);
648        Ok(controller_output)
649    }
650}
651
652/// Shared cosine similarity helper
653pub(crate) fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
654    let dot_product = a.dot(b);
655    let norm_a = a.mapv(|x| x * x).sum().sqrt();
656    let norm_b = b.mapv(|x| x * x).sum().sqrt();
657    if norm_a > 0.0 && norm_b > 0.0 {
658        dot_product / (norm_a * norm_b)
659    } else {
660        0.0
661    }
662}