umi-memory 0.1.0

Memory library for AI agents with deterministic simulation testing
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
//! SimNetwork - Simulated Network with Fault Injection
//!
//! TigerStyle: Configurable network conditions with explicit fault injection.
//! Supports partitions, delays, packet loss, and message reordering.

use bytes::Bytes;
use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::sync::RwLock;

use super::clock::SimClock;
use super::fault::{FaultInjector, FaultType};
use super::rng::DeterministicRng;
use crate::constants::{
    NETWORK_JITTER_MS_DEFAULT, NETWORK_LATENCY_MS_DEFAULT, NETWORK_LATENCY_MS_MAX,
};

/// A network message in flight between two simulated nodes.
///
/// Derives `PartialEq`/`Eq` so that messages can be compared directly in
/// assertions (all fields are plain comparable values).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NetworkMessage {
    /// Source node ID
    pub from: String,
    /// Destination node ID
    pub to: String,
    /// Message payload (opaque bytes; never inspected by the network)
    pub payload: Bytes,
    /// Simulated time, in milliseconds, from which the message is deliverable
    pub deliver_at_ms: u64,
}

/// Network errors.
#[derive(Debug, Clone, thiserror::Error)]
pub enum NetworkError {
    /// Message was dropped due to partition
    #[error("network partition between {from} and {to}")]
    Partitioned {
        /// Source node that tried to send
        from: String,
        /// Destination node that couldn't be reached
        to: String,
    },

    /// Message was dropped due to fault injection
    #[error("packet loss fault injected")]
    PacketLoss,

    /// Connection timed out
    #[error("connection timeout")]
    Timeout,

    /// Connection refused
    #[error("connection refused")]
    ConnectionRefused,
}

/// Simulated network for deterministic simulation testing (DST).
///
/// TigerStyle:
/// - Message delivery is delayed by a configurable base latency plus jitter
/// - Explicit partitions with heal/partition API
/// - Fault injection at send/receive boundaries
/// - Pending-message counts can be queried per node and in total
pub struct SimNetwork {
    /// Pending messages keyed by destination node ID, queued in send order
    messages: Arc<RwLock<HashMap<String, VecDeque<NetworkMessage>>>>,
    /// Partitioned node pairs. A stored `(a, b)` is treated as symmetric:
    /// it blocks traffic from `a` to `b` and from `b` to `a`.
    partitions: Arc<RwLock<Vec<(String, String)>>>,
    /// Simulation clock; supplies "now" for scheduling and delivery checks
    clock: SimClock,
    /// Fault injector (shared)
    fault_injector: Arc<FaultInjector>,
    /// RNG for latency jitter and fault-driven message reordering
    /// (RefCell for interior mutability behind `&self`)
    rng: RefCell<DeterministicRng>,
    /// Base latency in milliseconds, added to every message
    base_latency_ms: u64,
    /// Latency jitter in milliseconds; 0 disables jitter
    latency_jitter_ms: u64,
}

impl SimNetwork {
    /// Create a new simulated network using the default latency and jitter.
    ///
    /// TigerStyle: Takes shared fault injector for consistent fault injection.
    #[must_use]
    pub fn new(clock: SimClock, rng: DeterministicRng, fault_injector: Arc<FaultInjector>) -> Self {
        Self {
            messages: Arc::new(RwLock::new(HashMap::new())),
            partitions: Arc::new(RwLock::new(Vec::new())),
            clock,
            fault_injector,
            rng: RefCell::new(rng),
            base_latency_ms: NETWORK_LATENCY_MS_DEFAULT,
            latency_jitter_ms: NETWORK_JITTER_MS_DEFAULT,
        }
    }

    /// Set network latency parameters.
    ///
    /// `base_ms` is added to every message's delivery time; `jitter_ms` is the
    /// bound passed to the RNG for the per-message random extra delay
    /// (0 disables jitter).
    ///
    /// # Panics
    /// Panics if base_ms exceeds NETWORK_LATENCY_MS_MAX.
    #[must_use]
    pub fn with_latency(mut self, base_ms: u64, jitter_ms: u64) -> Self {
        // Precondition
        assert!(
            base_ms <= NETWORK_LATENCY_MS_MAX,
            "base_latency_ms {} exceeds max {}",
            base_ms,
            NETWORK_LATENCY_MS_MAX
        );

        self.base_latency_ms = base_ms;
        self.latency_jitter_ms = jitter_ms;
        self
    }

    /// Send a message from one node to another.
    ///
    /// Returns true if message was queued, false if dropped (partition/fault).
    ///
    /// # Panics
    /// Panics if `from` or `to` is empty.
    pub async fn send(&self, from: &str, to: &str, payload: Bytes) -> bool {
        // Preconditions
        assert!(!from.is_empty(), "from node ID cannot be empty");
        assert!(!to.is_empty(), "to node ID cannot be empty");

        // Check for network partition. The read guard is scoped so it is
        // released before the message queue lock is taken below.
        {
            let partitions = self.partitions.read().await;
            if partitions
                .iter()
                .any(|pair| Self::pair_matches(pair, from, to))
            {
                tracing::debug!(from = from, to = to, "Message dropped: network partition");
                return false;
            }
        }

        // Check for an injected send fault. Only timeout / refused / reset
        // drop the message; any other injected fault type is ignored here.
        if let Some(fault) = self.fault_injector.should_inject("network_send") {
            let drops_message = matches!(
                fault,
                FaultType::NetworkTimeout
                    | FaultType::NetworkConnectionRefused
                    | FaultType::NetworkReset
            );
            if drops_message {
                tracing::debug!(from = from, to = to, fault = ?fault, "Message dropped: fault");
                return false;
            }
        }

        // Calculate delivery time with latency. Saturate rather than overflow
        // if the simulated clock is already near the end of the u64 range.
        let latency = self.calculate_latency();
        let deliver_at_ms = self.clock.now_ms().saturating_add(latency);

        let message = NetworkMessage {
            from: from.to_string(),
            to: to.to_string(),
            payload,
            deliver_at_ms,
        };

        // Queue the message
        let mut messages = self.messages.write().await;
        messages
            .entry(to.to_string())
            .or_default()
            .push_back(message);

        true
    }

    /// Receive messages for a node.
    ///
    /// Returns all messages that have arrived (delivery time <= current time),
    /// in queue order unless a reordering fault is injected. Messages that are
    /// not yet due stay queued in their original relative order.
    ///
    /// # Panics
    /// Panics if `node_id` is empty.
    pub async fn receive(&self, node_id: &str) -> Vec<NetworkMessage> {
        // Precondition
        assert!(!node_id.is_empty(), "node_id cannot be empty");

        let current_time = self.clock.now_ms();
        let mut messages = self.messages.write().await;

        let queue = match messages.get_mut(node_id) {
            Some(q) => q,
            None => return Vec::new(),
        };

        // Split the queue into due / not-yet-due messages. `partition` keeps
        // relative order on both sides; the not-yet-due messages are put back
        // into the same queue so its allocation is reused.
        let (mut ready, later): (Vec<NetworkMessage>, Vec<NetworkMessage>) = queue
            .drain(..)
            .partition(|msg| msg.deliver_at_ms <= current_time);
        queue.extend(later);

        // Check for message reordering fault. The injector is consulted only
        // when something is being delivered.
        if !ready.is_empty() {
            if let Some(FaultType::NetworkPartialWrite) =
                self.fault_injector.should_inject("network_receive")
            {
                self.rng.borrow_mut().shuffle(&mut ready);
                tracing::debug!(node_id = node_id, "Messages reordered by fault");
            }
        }

        ready
    }

    /// Create a network partition between two nodes.
    ///
    /// Messages between these nodes (in either direction) will be dropped.
    /// Partitioning an already-partitioned pair does not add a second entry.
    ///
    /// # Panics
    /// Panics if either node ID is empty or if both IDs are equal.
    pub async fn partition(&self, node_a: &str, node_b: &str) {
        // Preconditions
        assert!(!node_a.is_empty(), "node_a cannot be empty");
        assert!(!node_b.is_empty(), "node_b cannot be empty");
        assert_ne!(node_a, node_b, "cannot partition node with itself");

        let mut partitions = self.partitions.write().await;
        let already_partitioned = partitions
            .iter()
            .any(|pair| Self::pair_matches(pair, node_a, node_b));
        if !already_partitioned {
            partitions.push((node_a.to_string(), node_b.to_string()));
        }

        tracing::info!(
            node_a = node_a,
            node_b = node_b,
            "Network partition created"
        );
    }

    /// Heal a network partition between two nodes.
    ///
    /// Removes the pair regardless of the order in which it was created.
    /// Healing a pair that is not partitioned is a no-op.
    pub async fn heal(&self, node_a: &str, node_b: &str) {
        let mut partitions = self.partitions.write().await;
        partitions.retain(|pair| !Self::pair_matches(pair, node_a, node_b));

        tracing::info!(node_a = node_a, node_b = node_b, "Network partition healed");
    }

    /// Heal all network partitions.
    pub async fn heal_all(&self) {
        let mut partitions = self.partitions.write().await;
        partitions.clear();

        tracing::info!("All network partitions healed");
    }

    /// Check if two nodes are partitioned (symmetric in its arguments).
    pub async fn is_partitioned(&self, node_a: &str, node_b: &str) -> bool {
        let partitions = self.partitions.read().await;
        partitions
            .iter()
            .any(|pair| Self::pair_matches(pair, node_a, node_b))
    }

    /// Get count of pending messages for a node (0 for an unknown node).
    pub async fn pending_count(&self, node_id: &str) -> usize {
        let messages = self.messages.read().await;
        messages.get(node_id).map_or(0, VecDeque::len)
    }

    /// Get total pending messages across all nodes.
    pub async fn total_pending(&self) -> usize {
        let messages = self.messages.read().await;
        messages.values().map(VecDeque::len).sum()
    }

    /// Clear all pending messages. Partitions are left untouched.
    pub async fn clear(&self) {
        let mut messages = self.messages.write().await;
        messages.clear();
    }

    /// Get the clock.
    #[must_use]
    pub fn clock(&self) -> &SimClock {
        &self.clock
    }

    /// Calculate latency with jitter.
    ///
    /// The jitter is drawn from the RNG with bounds `(0, latency_jitter_ms)`;
    /// the RNG is not consulted at all when jitter is 0.
    // NOTE(review): whether the upper bound is inclusive depends on
    // `DeterministicRng::next_usize` — confirm against its documentation.
    fn calculate_latency(&self) -> u64 {
        let jitter = if self.latency_jitter_ms > 0 {
            // Clamp before narrowing so a large jitter is not silently
            // truncated on targets where usize is narrower than u64.
            let jitter_bound = self.latency_jitter_ms.min(usize::MAX as u64) as usize;
            self.rng.borrow_mut().next_usize(0, jitter_bound) as u64
        } else {
            0
        };
        self.base_latency_ms.saturating_add(jitter)
    }

    /// Whether a stored partition pair links the two nodes, in either order.
    fn pair_matches(pair: &(String, String), node_a: &str, node_b: &str) -> bool {
        let (a, b) = pair;
        (a == node_a && b == node_b) || (a == node_b && b == node_a)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dst::fault::FaultInjectorBuilder;

    /// Build a network with the given base latency and no jitter, returning a
    /// clock handle alongside it so tests can advance simulated time.
    ///
    /// The fault injector is built without registering any faults.
    fn create_network_with_latency(base_ms: u64) -> (SimClock, SimNetwork) {
        let clock = SimClock::new();
        let mut rng = DeterministicRng::new(42);
        let fault_injector = Arc::new(FaultInjectorBuilder::new(rng.fork()).build());
        let network = SimNetwork::new(clock.clone(), rng, fault_injector).with_latency(base_ms, 0);
        (clock, network)
    }

    /// Build a zero-latency network so sent messages are receivable at once.
    fn create_network() -> SimNetwork {
        create_network_with_latency(0).1
    }

    #[tokio::test]
    async fn test_send_and_receive() {
        let network = create_network();

        // Send message
        let sent = network.send("node-1", "node-2", Bytes::from("hello")).await;
        assert!(sent);

        // Receive message
        let messages = network.receive("node-2").await;
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].payload, Bytes::from("hello"));
        assert_eq!(messages[0].from, "node-1");
        assert_eq!(messages[0].to, "node-2");

        // A delivered message is not delivered twice
        assert!(network.receive("node-2").await.is_empty());
    }

    #[tokio::test]
    async fn test_receive_unknown_node() {
        let network = create_network();

        // A node that never had anything sent to it has nothing to receive
        assert!(network.receive("node-9").await.is_empty());
        assert_eq!(network.pending_count("node-9").await, 0);
    }

    #[tokio::test]
    async fn test_partition() {
        let network = create_network();

        // Create partition
        network.partition("node-1", "node-2").await;
        assert!(network.is_partitioned("node-1", "node-2").await);
        assert!(network.is_partitioned("node-2", "node-1").await); // Symmetric

        // Messages should be dropped in both directions and never queued
        let sent = network.send("node-1", "node-2", Bytes::from("hello")).await;
        assert!(!sent);
        let sent = network.send("node-2", "node-1", Bytes::from("hello")).await;
        assert!(!sent);
        assert_eq!(network.total_pending().await, 0);

        // Heal partition
        network.heal("node-1", "node-2").await;
        assert!(!network.is_partitioned("node-1", "node-2").await);
        assert!(!network.is_partitioned("node-2", "node-1").await);

        // Message should go through
        let sent = network.send("node-1", "node-2", Bytes::from("hello")).await;
        assert!(sent);
        assert_eq!(network.pending_count("node-2").await, 1);
    }

    #[tokio::test]
    async fn test_duplicate_partition_healed_once() {
        let network = create_network();

        // Partitioning the same pair repeatedly (in either order) ...
        network.partition("node-1", "node-2").await;
        network.partition("node-1", "node-2").await;
        network.partition("node-2", "node-1").await;
        assert!(network.is_partitioned("node-1", "node-2").await);

        // ... must still be undone by a single heal
        network.heal("node-2", "node-1").await;
        assert!(!network.is_partitioned("node-1", "node-2").await);
    }

    #[tokio::test]
    async fn test_latency() {
        let (clock, network) = create_network_with_latency(100);

        // Send message
        let sent = network.send("node-1", "node-2", Bytes::from("hello")).await;
        assert!(sent);

        // Should not be delivered yet, but must stay queued
        let messages = network.receive("node-2").await;
        assert!(messages.is_empty());
        assert_eq!(network.pending_count("node-2").await, 1);

        // Advance time
        clock.advance_ms(100);

        // Now should be delivered
        let messages = network.receive("node-2").await;
        assert_eq!(messages.len(), 1);
        assert_eq!(network.pending_count("node-2").await, 0);
    }

    #[tokio::test]
    async fn test_receive_keeps_undelivered_messages() {
        let (clock, network) = create_network_with_latency(100);

        // Two messages sent 50ms apart become due 50ms apart
        assert!(network.send("node-1", "node-2", Bytes::from("early")).await);
        clock.advance_ms(50);
        assert!(network.send("node-1", "node-2", Bytes::from("late")).await);
        clock.advance_ms(50);

        // Only the first one is due; the second stays pending
        let messages = network.receive("node-2").await;
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].payload, Bytes::from("early"));
        assert_eq!(network.pending_count("node-2").await, 1);

        // After the remaining delay the second one arrives
        clock.advance_ms(50);
        let messages = network.receive("node-2").await;
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].payload, Bytes::from("late"));
        assert_eq!(network.total_pending().await, 0);
    }

    #[tokio::test]
    async fn test_multiple_messages() {
        let network = create_network();

        // Send multiple messages
        assert!(network.send("node-1", "node-2", Bytes::from("msg1")).await);
        assert!(network.send("node-1", "node-2", Bytes::from("msg2")).await);
        assert!(network.send("node-3", "node-2", Bytes::from("msg3")).await);

        assert_eq!(network.pending_count("node-2").await, 3);
        assert_eq!(network.pending_count("node-1").await, 0);
        assert_eq!(network.total_pending().await, 3);

        // Receive all, in send order (no reordering fault is configured)
        let messages = network.receive("node-2").await;
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].payload, Bytes::from("msg1"));
        assert_eq!(messages[1].payload, Bytes::from("msg2"));
        assert_eq!(messages[2].payload, Bytes::from("msg3"));
        assert_eq!(network.pending_count("node-2").await, 0);
    }

    #[tokio::test]
    async fn test_heal_all() {
        let network = create_network();

        // Create multiple partitions
        network.partition("node-1", "node-2").await;
        network.partition("node-2", "node-3").await;
        network.partition("node-1", "node-3").await;

        assert!(network.is_partitioned("node-1", "node-2").await);
        assert!(network.is_partitioned("node-2", "node-3").await);
        assert!(network.is_partitioned("node-1", "node-3").await);

        // Heal all
        network.heal_all().await;

        assert!(!network.is_partitioned("node-1", "node-2").await);
        assert!(!network.is_partitioned("node-2", "node-3").await);
        assert!(!network.is_partitioned("node-1", "node-3").await);
    }

    #[tokio::test]
    async fn test_clear() {
        let network = create_network();

        assert!(network.send("node-1", "node-2", Bytes::from("msg1")).await);
        assert!(network.send("node-1", "node-2", Bytes::from("msg2")).await);

        assert_eq!(network.total_pending().await, 2);

        network.clear().await;

        assert_eq!(network.total_pending().await, 0);
        assert!(network.receive("node-2").await.is_empty());
    }

    #[test]
    #[should_panic(expected = "from node ID cannot be empty")]
    fn test_send_empty_from() {
        let network = create_network();
        let _ = tokio_test::block_on(network.send("", "node-2", Bytes::from("hello")));
    }

    #[test]
    #[should_panic(expected = "to node ID cannot be empty")]
    fn test_send_empty_to() {
        let network = create_network();
        let _ = tokio_test::block_on(network.send("node-1", "", Bytes::from("hello")));
    }

    #[test]
    #[should_panic(expected = "cannot partition node with itself")]
    fn test_partition_self() {
        let network = create_network();
        let _ = tokio_test::block_on(network.partition("node-1", "node-1"));
    }

    #[test]
    #[should_panic(expected = "exceeds max")]
    fn test_with_latency_rejects_excessive_base() {
        let _ = create_network().with_latency(NETWORK_LATENCY_MS_MAX + 1, 0);
    }
}