// oxirs_embed/memory_nets_tests.rs

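//! Tests for memory-augmented network components (controller network, read/write
//! heads, DNC, memory networks, episodic memory and sparse-access memory).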
#[cfg(test)]
mod tests {
    use scirs2_core::ndarray_ext::{Array1, Array2};

    use crate::memory_nets_controller::{
        ControllerNetwork, DNCConfig, DifferentiableNeuralComputer, ReadHead, WriteHead,
    };
    use crate::memory_nets_ops::{
        EpisodicConfig, EpisodicMemory, MemoryAugmentedNetwork, MemoryConfig, MemoryNetworks,
        MemoryNetworksConfig, SparseAccessMemory, SparseConfig,
    };

    // None of these tests await anything, so they all use the plain `#[test]`
    // harness; `#[tokio::test]` would build a Tokio runtime per test for no
    // benefit. Fallible calls use `expect` so that a failure reports the
    // underlying error instead of just "assertion failed: result.is_ok()".

    /// A `MemoryAugmentedNetwork` can be constructed from the default config.
    #[test]
    fn test_memory_augmented_network_creation() {
        MemoryAugmentedNetwork::new(MemoryConfig::default())
            .expect("MemoryAugmentedNetwork::new should accept the default config");
    }

    /// One forward step of a default-configured DNC on a 64-element zero input
    /// succeeds.
    #[test]
    fn test_dnc_forward_pass() {
        let mut dnc = DifferentiableNeuralComputer::new(DNCConfig::default());
        // NOTE(review): 64 presumably matches DNCConfig::default()'s input width — confirm.
        let input = Array1::zeros(64);

        dnc.forward(&input).expect("DNC forward pass should succeed on a zero input");
    }

    /// Storing one memory and then querying with the same embedding both succeed.
    #[test]
    fn test_memory_networks_store_and_query() {
        let mut memory_net = MemoryNetworks::new(MemoryNetworksConfig::default());
        let embedding = Array1::ones(128);

        // `store_memory` takes the embedding by value; keep a copy to query with.
        memory_net
            .store_memory("test content".to_string(), embedding.clone())
            .expect("storing a memory should succeed");
        memory_net
            .query(&embedding)
            .expect("querying with the stored embedding should succeed");
    }

    /// Start an episode, add one state with reward 1.0, then end it; each step
    /// that returns a `Result` must succeed.
    #[test]
    fn test_episodic_memory() {
        let mut episodic = EpisodicMemory::new(EpisodicConfig::default());
        episodic.start_episode("test".to_string());

        let state = Array1::ones(128);
        episodic
            .add_state(state, 1.0)
            .expect("adding a state to the open episode should succeed");
        episodic
            .end_episode(true)
            .expect("ending the open episode should succeed");
    }

    /// A value stored under id 123 can be retrieved by id and is found by a
    /// similarity search for itself.
    #[test]
    fn test_sparse_memory() {
        let mut sparse = SparseAccessMemory::new(SparseConfig::default());
        // NOTE(review): 512 presumably matches SparseConfig::default()'s value width — confirm.
        let value = Array1::ones(512);

        sparse
            .store(123, value.clone())
            .expect("storing a value under id 123 should succeed");
        assert!(
            sparse.retrieve(123).is_some(),
            "value stored under id 123 should be retrievable"
        );

        let similar = sparse.find_similar(&value, 1);
        assert_eq!(
            similar.len(),
            1,
            "find_similar(&value, 1) should return exactly one result"
        );
    }

    /// The controller maps a 100-element input to a 128-element output.
    #[test]
    fn test_controller_network() {
        // NOTE(review): arguments look like (input, hidden, output) widths — confirm.
        let mut controller = ControllerNetwork::new(100, 256, 128);
        let input = Array1::zeros(100);

        let output = controller.forward(&input);
        assert_eq!(output.len(), 128, "controller output should have 128 elements");
    }

    /// A read head over a 128x64 all-zero memory yields 128 weights whose sum
    /// is either 1 or 0.
    #[test]
    fn test_read_head_weighting() {
        let read_head = ReadHead::new(64);
        let memory = Array2::zeros((128, 64));
        let link_matrix = Array2::zeros((128, 128));
        let prev_weighting = Array1::zeros(128);

        let weighting = read_head.generate_weighting(&memory, &link_matrix, &prev_weighting);
        assert_eq!(weighting.len(), 128, "expected one read weight per memory row");

        // Both accepted outcomes (normalized, or all-zero) are checked with the
        // same tolerance rather than exact float equality.
        let sum = weighting.sum();
        assert!(
            (sum - 1.0).abs() < 1e-6 || sum.abs() < 1e-6,
            "read weighting should sum to 1 or 0, got {sum}"
        );
    }

    /// A write head over a 128x64 all-zero memory with zero usage yields 128 weights.
    #[test]
    fn test_write_head_operations() {
        let write_head = WriteHead::new(64);
        let memory = Array2::zeros((128, 64));
        let usage_vector = Array1::zeros(128);

        let weighting = write_head.generate_weighting(&memory, &usage_vector);
        assert_eq!(weighting.len(), 128, "expected one write weight per memory row");
    }
}