// oxirs_embed/memory_nets_tests.rs
#[cfg(test)]
mod tests {
    //! Smoke tests for the memory-network components: the memory-augmented
    //! network, the differentiable neural computer (DNC) with its controller
    //! and read/write heads, and the key/value, episodic and sparse memories.
    //!
    //! Fallible calls are unwrapped with `expect` instead of being checked with
    //! `assert!(x.is_ok())`, so a failing test reports the underlying error
    //! rather than only "assertion failed: result.is_ok()".

    use scirs2_core::ndarray_ext::{Array1, Array2};

    use crate::memory_nets_controller::{
        ControllerNetwork, DNCConfig, DifferentiableNeuralComputer, ReadHead, WriteHead,
    };
    use crate::memory_nets_ops::{
        EpisodicConfig, EpisodicMemory, MemoryAugmentedNetwork, MemoryConfig, MemoryNetworks,
        MemoryNetworksConfig, SparseAccessMemory, SparseConfig,
    };

    /// A memory-augmented network can be constructed from the default config.
    #[tokio::test]
    async fn test_memory_augmented_network_creation() {
        let _network = MemoryAugmentedNetwork::new(MemoryConfig::default())
            .expect("MemoryAugmentedNetwork::new should accept the default config");
    }

    /// A DNC built from the default config runs a forward pass on a zero input.
    #[tokio::test]
    async fn test_dnc_forward_pass() {
        let mut dnc = DifferentiableNeuralComputer::new(DNCConfig::default());
        // NOTE(review): assumes the default DNC config expects a 64-dim input — confirm.
        let input = Array1::zeros(64);

        let _output = dnc
            .forward(&input)
            .expect("DNC forward pass on a zero input should succeed");
    }

    /// A stored memory can be queried back using the same embedding.
    #[tokio::test]
    async fn test_memory_networks_store_and_query() {
        let mut memory_net = MemoryNetworks::new(MemoryNetworksConfig::default());
        // NOTE(review): assumes the default config uses 128-dim embeddings — confirm.
        let embedding = Array1::ones(128);

        // Cloned because `embedding` is reused below as the query key.
        memory_net
            .store_memory("test content".to_string(), embedding.clone())
            .expect("storing a memory should succeed");

        memory_net
            .query(&embedding)
            .expect("querying with the stored embedding should succeed");
    }

    /// An episode can be started, given one state with a reward, and ended.
    #[tokio::test]
    async fn test_episodic_memory() {
        let mut episodic = EpisodicMemory::new(EpisodicConfig::default());
        episodic.start_episode("test".to_string());

        let state = Array1::ones(128);
        episodic
            .add_state(state, 1.0)
            .expect("adding a state to the started episode should succeed");

        episodic
            .end_episode(true)
            .expect("ending the started episode should succeed");
    }

    /// A value stored under a key can be retrieved by key and found by similarity.
    #[tokio::test]
    async fn test_sparse_memory() {
        let mut sparse = SparseAccessMemory::new(SparseConfig::default());
        let value = Array1::ones(512);

        // Cloned because `value` is reused below as the similarity query.
        sparse
            .store(123, value.clone())
            .expect("storing a value under key 123 should succeed");

        assert!(
            sparse.retrieve(123).is_some(),
            "value stored under key 123 should be retrievable"
        );

        let similar = sparse.find_similar(&value, 1);
        assert_eq!(
            similar.len(),
            1,
            "find_similar with a limit of 1 should return exactly one match"
        );
    }

    /// The controller maps a 100-dim input to a 128-dim output.
    #[test]
    fn test_controller_network() {
        // NOTE(review): arguments presumably are (input, hidden, output) sizes — confirm.
        let mut controller = ControllerNetwork::new(100, 256, 128);
        let input = Array1::zeros(100);

        let output = controller.forward(&input);
        assert_eq!(output.len(), 128, "controller output length mismatch");
    }

    /// A read head yields one weight per memory row, summing to 1 (or to 0).
    #[test]
    fn test_read_head_weighting() {
        let read_head = ReadHead::new(64);
        // 128 memory rows of width 64, an all-zero 128x128 link matrix and
        // an all-zero previous read weighting.
        let memory = Array2::zeros((128, 64));
        let link_matrix = Array2::zeros((128, 128));
        let prev_weighting = Array1::zeros(128);

        let weighting = read_head.generate_weighting(&memory, &link_matrix, &prev_weighting);
        assert_eq!(weighting.len(), 128, "expected one read weight per memory row");

        // Accept a normalized weighting, or one that sums to exactly zero.
        let sum = weighting.sum();
        assert!(
            (sum - 1.0).abs() < 1e-6 || sum == 0.0,
            "read weighting should sum to 1 (or to 0), got {}",
            sum
        );
    }

    /// A write head yields one weight per memory row for an unused memory.
    #[test]
    fn test_write_head_operations() {
        let write_head = WriteHead::new(64);
        let memory = Array2::zeros((128, 64));
        let usage_vector = Array1::zeros(128);

        let weighting = write_head.generate_weighting(&memory, &usage_vector);
        assert_eq!(weighting.len(), 128, "expected one write weight per memory row");
    }
}