Skip to main content

hyperi_rustlib/transport/memory/
mod.rs

1// Project:   hyperi-rustlib
2// File:      src/transport/memory/mod.rs
3// Purpose:   In-memory transport using tokio channels
4// Language:  Rust
5//
6// License:   BUSL-1.1
7// Copyright: (c) 2026 HYPERI PTY LIMITED
8
//! # Memory Transport
//!
//! In-memory transport using tokio channels for unit testing.
//! No persistence, same-process only.
//!
//! ## Example
//!
//! ```rust,ignore
//! use hyperi_rustlib::transport::{MemoryTransport, MemoryConfig, Transport};
//!
//! let config = MemoryConfig::default();
//! let transport = MemoryTransport::new(&config).expect("memory transport with valid config must construct");
//!
//! // In tests, you can also get a sender handle (key is optional).
//! let sender = transport.sender();
//! sender.send(None, b"test payload".to_vec()).await?;
//!
//! let messages = transport.recv(10).await?.messages;
//! assert_eq!(messages.len(), 1);
//! ```
29
30mod token;
31
32pub use token::MemoryToken;
33
34use super::error::{TransportError, TransportResult};
35use super::traits::{RecvBatch, TransportBase, TransportReceiver, TransportSender};
36use super::types::{Message, PayloadFormat, SendResult};
37use serde::{Deserialize, Serialize};
38use std::sync::Arc;
39use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
40use tokio::sync::mpsc;
41
42/// Configuration for memory transport.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct MemoryConfig {
45    /// Channel buffer size.
46    #[serde(default = "default_buffer_size")]
47    pub buffer_size: usize,
48
49    /// Receive timeout in milliseconds (0 = no wait, return immediately).
50    #[serde(default)]
51    pub recv_timeout_ms: u64,
52
53    /// Inbound message filters (applied on recv before caller sees messages).
54    #[serde(default)]
55    pub filters_in: Vec<super::filter::FilterRule>,
56
57    /// Outbound message filters (applied on send before transport dispatches).
58    #[serde(default)]
59    pub filters_out: Vec<super::filter::FilterRule>,
60}
61
/// Channel capacity used when `buffer_size` is absent from the config.
fn default_buffer_size() -> usize {
    const DEFAULT_BUFFER_SIZE: usize = 1_000;
    DEFAULT_BUFFER_SIZE
}
65
66impl Default for MemoryConfig {
67    fn default() -> Self {
68        Self {
69            buffer_size: default_buffer_size(),
70            recv_timeout_ms: 0,
71            filters_in: Vec::new(),
72            filters_out: Vec::new(),
73        }
74    }
75}
76
/// Entry stored in the channel between a send/inject and the matching `recv`,
/// where it is converted into a public [`Message`].
struct InternalMessage {
    /// Routing key, if the producer supplied one.
    key: Option<Arc<str>>,
    /// Raw payload bytes.
    payload: Vec<u8>,
    /// Value drawn from the transport's shared sequence counter.
    seq: u64,
    /// Wall-clock time at enqueue, in milliseconds since the Unix epoch.
    timestamp_ms: i64,
}
84
85/// In-memory transport using tokio channels.
86///
87/// Primarily for unit testing - no persistence, same-process only.
88pub struct MemoryTransport {
89    sender: mpsc::Sender<InternalMessage>,
90    receiver: tokio::sync::Mutex<mpsc::Receiver<InternalMessage>>,
91    sequence: AtomicU64,
92    committed_seq: AtomicU64,
93    closed: AtomicBool,
94    recv_timeout_ms: u64,
95    filter_engine: super::filter::TransportFilterEngine,
96}
97
98impl MemoryTransport {
99    /// Create a new memory transport.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`TransportError`] when any inbound/outbound filter rule
104    /// fails to compile. Previously this produced a `tracing::warn!` and
105    /// silently substituted an empty filter engine; that fail-open
106    /// behaviour hid real misconfiguration (a filter that should have
107    /// blocked traffic would instead let every message through), so the
108    /// constructor now propagates the error to the caller.
109    pub fn new(config: &MemoryConfig) -> super::error::TransportResult<Self> {
110        let (sender, receiver) = mpsc::channel(config.buffer_size);
111        let filter_engine = super::filter::TransportFilterEngine::new(
112            &config.filters_in,
113            &config.filters_out,
114            &crate::transport::filter::TransportFilterTierConfig::from_cascade(),
115        )?;
116        Ok(Self {
117            sender,
118            receiver: tokio::sync::Mutex::new(receiver),
119            sequence: AtomicU64::new(0),
120            committed_seq: AtomicU64::new(0),
121            closed: AtomicBool::new(false),
122            recv_timeout_ms: config.recv_timeout_ms,
123            filter_engine,
124        })
125    }
126
127    /// Get a sender handle for injecting test messages.
128    ///
129    /// This is useful in tests to send messages without going through
130    /// the Transport trait.
131    #[must_use]
132    pub fn sender(&self) -> MemorySender<'_> {
133        MemorySender {
134            sender: self.sender.clone(),
135            sequence: &self.sequence,
136        }
137    }
138
139    /// Send a message directly (bypasses Transport trait).
140    ///
141    /// # Errors
142    ///
143    /// Returns error if the channel is full or closed.
144    pub async fn inject(&self, key: Option<&str>, payload: Vec<u8>) -> TransportResult<()> {
145        if self.closed.load(Ordering::Relaxed) {
146            return Err(TransportError::Closed);
147        }
148
149        let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
150        let timestamp_ms = chrono::Utc::now().timestamp_millis();
151
152        let msg = InternalMessage {
153            key: key.map(Arc::from),
154            payload,
155            seq,
156            timestamp_ms,
157        };
158
159        self.sender
160            .send(msg)
161            .await
162            .map_err(|_| TransportError::Send("channel closed".into()))
163    }
164
165    /// Get the current committed sequence number.
166    #[must_use]
167    pub fn committed_sequence(&self) -> u64 {
168        self.committed_seq.load(Ordering::Relaxed)
169    }
170}
171
172/// Sender handle for injecting test messages.
173pub struct MemorySender<'a> {
174    sender: mpsc::Sender<InternalMessage>,
175    sequence: &'a AtomicU64,
176}
177
178impl MemorySender<'_> {
179    /// Send a payload with optional key.
180    ///
181    /// # Errors
182    ///
183    /// Returns error if the channel is full or closed.
184    pub async fn send(&self, key: Option<&str>, payload: Vec<u8>) -> TransportResult<()> {
185        let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
186        let timestamp_ms = chrono::Utc::now().timestamp_millis();
187
188        let msg = InternalMessage {
189            key: key.map(Arc::from),
190            payload,
191            seq,
192            timestamp_ms,
193        };
194
195        self.sender
196            .send(msg)
197            .await
198            .map_err(|_| TransportError::Send("channel closed".into()))
199    }
200}
201
202impl TransportBase for MemoryTransport {
203    async fn close(&self) -> TransportResult<()> {
204        self.closed.store(true, Ordering::Relaxed);
205        Ok(())
206    }
207
208    fn is_healthy(&self) -> bool {
209        !self.closed.load(Ordering::Relaxed)
210    }
211
212    fn name(&self) -> &'static str {
213        "memory"
214    }
215}
216
217impl TransportSender for MemoryTransport {
218    async fn send(&self, key: &str, payload: bytes::Bytes) -> SendResult {
219        if self.closed.load(Ordering::Relaxed) {
220            return SendResult::Fatal(TransportError::Closed);
221        }
222
223        // Outbound filter check
224        if self.filter_engine.has_outbound_filters() {
225            match self.filter_engine.apply_outbound(&payload) {
226                super::filter::FilterDisposition::Pass => {}
227                super::filter::FilterDisposition::Drop => return SendResult::Ok,
228                super::filter::FilterDisposition::Dlq => return SendResult::FilteredDlq,
229            }
230        }
231
232        let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
233        let timestamp_ms = chrono::Utc::now().timestamp_millis();
234
235        let msg = InternalMessage {
236            key: Some(Arc::from(key)),
237            payload: payload.to_vec(),
238            seq,
239            timestamp_ms,
240        };
241
242        match self.sender.try_send(msg) {
243            Ok(()) => SendResult::Ok,
244            Err(mpsc::error::TrySendError::Full(_)) => SendResult::Backpressured,
245            Err(mpsc::error::TrySendError::Closed(_)) => SendResult::Fatal(TransportError::Closed),
246        }
247    }
248}
249
250impl TransportReceiver for MemoryTransport {
251    type Token = MemoryToken;
252
253    async fn recv(&self, max: usize) -> TransportResult<RecvBatch<Self::Token>> {
254        if self.closed.load(Ordering::Relaxed) {
255            return Err(TransportError::Closed);
256        }
257
258        let mut receiver = self.receiver.lock().await;
259        let mut messages = Vec::with_capacity(max.min(100));
260
261        for _ in 0..max {
262            let result = if self.recv_timeout_ms == 0 {
263                match receiver.try_recv() {
264                    Ok(msg) => Some(msg),
265                    Err(mpsc::error::TryRecvError::Empty) => break,
266                    Err(mpsc::error::TryRecvError::Disconnected) => {
267                        return Err(TransportError::Closed);
268                    }
269                }
270            } else if messages.is_empty() {
271                match tokio::time::timeout(
272                    std::time::Duration::from_millis(self.recv_timeout_ms),
273                    receiver.recv(),
274                )
275                .await
276                {
277                    Ok(Some(msg)) => Some(msg),
278                    Ok(None) => return Err(TransportError::Closed),
279                    Err(_) => break,
280                }
281            } else {
282                match receiver.try_recv() {
283                    Ok(msg) => Some(msg),
284                    Err(_) => break,
285                }
286            };
287
288            if let Some(internal) = result {
289                let format = PayloadFormat::detect(&internal.payload);
290                messages.push(Message {
291                    key: internal.key,
292                    payload: internal.payload,
293                    token: MemoryToken { seq: internal.seq },
294                    timestamp_ms: Some(internal.timestamp_ms),
295                    format,
296                });
297            }
298        }
299
300        // Apply inbound filters via the shared partition helper; DLQ entries
301        // are returned in the RecvBatch for the caller to route onward.
302        let batch = self.filter_engine.partition_batch(
303            messages,
304            |m| m.payload.as_slice(),
305            |m| m.key.clone(),
306        );
307        let messages = batch.messages;
308        let dlq_entries = batch.dlq_entries;
309
310        Ok(RecvBatch {
311            messages,
312            dlq_entries,
313        })
314    }
315
316    async fn commit(&self, tokens: &[Self::Token]) -> TransportResult<()> {
317        if let Some(max_seq) = tokens.iter().map(|t| t.seq).max() {
318            let _ = self.committed_seq.fetch_max(max_seq, Ordering::Relaxed);
319        }
320        Ok(())
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a transport, treating construction failure as a test bug.
    fn transport_with(config: &MemoryConfig) -> MemoryTransport {
        MemoryTransport::new(config).expect("memory transport with valid config must construct")
    }

    #[tokio::test]
    async fn send_and_receive() {
        let transport = transport_with(&MemoryConfig::default());

        let result = transport
            .send("test-key", bytes::Bytes::from_static(b"hello world"))
            .await;
        assert!(result.is_ok());

        let messages = transport.recv(10).await.unwrap().messages;
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].key.as_deref(), Some("test-key"));
        assert_eq!(messages[0].payload, b"hello world");
    }

    #[tokio::test]
    async fn inject_messages() {
        let transport = transport_with(&MemoryConfig::default());

        transport
            .inject(Some("key1"), b"msg1".to_vec())
            .await
            .unwrap();
        transport
            .inject(Some("key2"), b"msg2".to_vec())
            .await
            .unwrap();

        let messages = transport.recv(10).await.unwrap().messages;
        assert_eq!(messages.len(), 2);
    }

    #[tokio::test]
    async fn sender_handle_injects_messages() {
        let transport = transport_with(&MemoryConfig::default());

        transport
            .sender()
            .send(Some("handle-key"), b"via handle".to_vec())
            .await
            .unwrap();

        let messages = transport.recv(10).await.unwrap().messages;
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].key.as_deref(), Some("handle-key"));
        assert_eq!(messages[0].payload, b"via handle");
    }

    #[tokio::test]
    async fn recv_respects_max() {
        let transport = transport_with(&MemoryConfig::default());
        for payload in [b"a", b"b", b"c"] {
            transport.inject(None, payload.to_vec()).await.unwrap();
        }

        let first = transport.recv(2).await.unwrap().messages;
        assert_eq!(first.len(), 2);

        let rest = transport.recv(10).await.unwrap().messages;
        assert_eq!(rest.len(), 1);
        assert_eq!(rest[0].payload, b"c");
    }

    #[tokio::test]
    async fn recv_on_empty_channel_returns_empty_batch() {
        let transport = transport_with(&MemoryConfig::default());
        let messages = transport.recv(10).await.unwrap().messages;
        assert!(messages.is_empty());
    }

    #[tokio::test]
    async fn recv_with_timeout_returns_empty_batch_when_idle() {
        let config = MemoryConfig {
            recv_timeout_ms: 5,
            ..Default::default()
        };
        let transport = transport_with(&config);
        let messages = transport.recv(10).await.unwrap().messages;
        assert!(messages.is_empty());
    }

    #[tokio::test]
    async fn commit_advances_sequence() {
        let transport = transport_with(&MemoryConfig::default());
        transport.inject(None, b"first".to_vec()).await.unwrap();
        transport.inject(None, b"second".to_vec()).await.unwrap();
        let messages = transport.recv(10).await.unwrap().messages;
        assert_eq!(messages.len(), 2);

        // Sequence numbers start at 0, so committing both messages must move
        // the committed sequence to 1, which differs from its initial value.
        let tokens: Vec<_> = messages.iter().map(|m| m.token).collect();
        transport.commit(&tokens).await.unwrap();
        assert_eq!(transport.committed_sequence(), 1);

        // Re-committing an older token must not move the sequence backwards.
        transport.commit(&tokens[..1]).await.unwrap();
        assert_eq!(transport.committed_sequence(), 1);
    }

    #[tokio::test]
    async fn close_prevents_operations() {
        let transport = transport_with(&MemoryConfig::default());
        transport.close().await.unwrap();
        assert!(!transport.is_healthy());

        // Send should fail
        let result = transport
            .send("key", bytes::Bytes::from_static(b"data"))
            .await;
        assert!(result.is_fatal());

        // Inject should fail
        assert!(transport.inject(None, b"data".to_vec()).await.is_err());

        // Recv should fail
        let result = transport.recv(1).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn backpressure_on_full_channel() {
        let config = MemoryConfig {
            buffer_size: 1,
            recv_timeout_ms: 0,
            ..Default::default()
        };
        let transport = transport_with(&config);

        // Fill the channel
        let result1 = transport
            .send("key", bytes::Bytes::from_static(b"msg1"))
            .await;
        assert!(result1.is_ok());

        // Next send should backpressure
        let result2 = transport
            .send("key", bytes::Bytes::from_static(b"msg2"))
            .await;
        assert!(result2.is_backpressured());
    }
}