// durable_streams_server/storage/memory.rs

1use super::{
2    CreateStreamResult, Message, NOTIFY_CHANNEL_CAPACITY, ProducerAppendResult, ProducerCheck,
3    ProducerState, ReadResult, Storage, StreamConfig, StreamMetadata,
4};
5use crate::protocol::error::{Error, Result};
6use crate::protocol::offset::Offset;
7use crate::protocol::producer::ProducerHeaders;
8use bytes::Bytes;
9use chrono::Utc;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::{Arc, RwLock};
13use tokio::sync::broadcast;
14
/// Initial capacity for a stream's message vector (avoids early reallocations).
const INITIAL_MESSAGES_CAPACITY: usize = 256;
/// Initial capacity for a stream's per-producer state map.
const INITIAL_PRODUCERS_CAPACITY: usize = 8;
17
/// Internal stream entry
struct StreamEntry {
    /// Configuration supplied at creation time (includes TTL / content type).
    config: StreamConfig,
    /// Appended messages in offset order (offsets are strictly monotonic).
    messages: Vec<Message>,
    /// True once the stream is closed; further appends are rejected.
    closed: bool,
    /// Sequence number the next appended message will receive.
    next_read_seq: u64,
    /// Byte offset at which the next appended message will start.
    next_byte_offset: u64,
    /// Total payload bytes currently held by this stream.
    total_bytes: u64,
    /// Creation timestamp (UTC).
    created_at: chrono::DateTime<Utc>,
    /// Per-producer state for idempotent producer support
    producers: HashMap<String, ProducerState>,
    /// Broadcast sender for notifying long-poll/SSE subscribers
    notify: broadcast::Sender<()>,
    /// Last Stream-Seq value received (lexicographic ordering)
    last_seq: Option<String>,
}
34
35impl StreamEntry {
36    fn new(config: StreamConfig) -> Self {
37        // Stream starts open; the handler closes it after any initial appends.
38        // The `created_closed` flag in config is stored for idempotent checks only.
39        let (notify, _) = broadcast::channel(NOTIFY_CHANNEL_CAPACITY);
40        Self {
41            config,
42            messages: Vec::with_capacity(INITIAL_MESSAGES_CAPACITY),
43            closed: false,
44            next_read_seq: 0,
45            next_byte_offset: 0,
46            total_bytes: 0,
47            created_at: Utc::now(),
48            producers: HashMap::with_capacity(INITIAL_PRODUCERS_CAPACITY),
49            notify,
50            last_seq: None,
51        }
52    }
53}
54
/// In-memory storage implementation
///
/// Thread-safe storage with:
/// - `RwLock<HashMap>` for stream lookup (concurrent reads)
/// - Per-stream `RwLock` for exclusive write access (offset monotonicity)
/// - Memory limit enforcement (global and per-stream)
///
/// # Concurrency Model
///
/// Multiple readers can access different streams concurrently.
/// Appends to the same stream are serialized via `RwLock::write()`.
/// Appends to different streams can proceed concurrently.
pub struct InMemoryStorage {
    /// Stream name -> shared, individually locked stream entry.
    streams: RwLock<HashMap<String, Arc<RwLock<StreamEntry>>>>,
    /// Payload bytes currently held across all streams.
    total_bytes: AtomicU64,
    /// Global cap enforced against `total_bytes`.
    max_total_bytes: u64,
    /// Cap on any single stream's payload bytes.
    max_stream_bytes: u64,
}
73
74impl InMemoryStorage {
75    /// Create a new in-memory storage with memory limits
76    #[must_use]
77    pub fn new(max_total_bytes: u64, max_stream_bytes: u64) -> Self {
78        Self {
79            streams: RwLock::new(HashMap::new()),
80            total_bytes: AtomicU64::new(0),
81            max_total_bytes,
82            max_stream_bytes,
83        }
84    }
85
    /// Return the currently tracked total payload bytes across all streams.
    #[must_use]
    pub fn total_bytes(&self) -> u64 {
        // Acquire pairs with the AcqRel `fetch_update`s in the append paths.
        self.total_bytes.load(Ordering::Acquire)
    }
91
    /// Subtract `bytes` from the global counter, clamping at zero.
    ///
    /// `fetch_update` with `saturating_sub` ensures a double-release can
    /// never underflow the counter. The closure always returns `Some`, so
    /// `fetch_update` cannot fail; `.ok()` just discards the previous value.
    fn saturating_sub_total_bytes(&self, bytes: u64) {
        self.total_bytes
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                Some(current.saturating_sub(bytes))
            })
            .ok();
    }
99
100    fn get_stream(&self, name: &str) -> Option<Arc<RwLock<StreamEntry>>> {
101        let streams = self.streams.read().expect("streams lock poisoned");
102        streams.get(name).map(Arc::clone)
103    }
104
105    /// Commit messages to a stream, checking memory limits first.
106    ///
107    /// Caller must hold the stream write lock. Updates both stream-level
108    /// and global memory counters atomically.
109    fn commit_messages(&self, stream: &mut StreamEntry, messages: Vec<Bytes>) -> Result<()> {
110        if messages.is_empty() {
111            return Ok(());
112        }
113
114        let mut total_batch_bytes = 0u64;
115        let mut message_sizes = Vec::with_capacity(messages.len());
116        for data in &messages {
117            let byte_len = u64::try_from(data.len()).unwrap_or(u64::MAX);
118            message_sizes.push(byte_len);
119            total_batch_bytes += byte_len;
120        }
121
122        // Reserve global bytes atomically (global precedence before per-stream).
123        if self
124            .total_bytes
125            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
126                current
127                    .checked_add(total_batch_bytes)
128                    .filter(|next| *next <= self.max_total_bytes)
129            })
130            .is_err()
131        {
132            return Err(Error::MemoryLimitExceeded);
133        }
134        if stream.total_bytes + total_batch_bytes > self.max_stream_bytes {
135            self.saturating_sub_total_bytes(total_batch_bytes);
136            return Err(Error::StreamSizeLimitExceeded);
137        }
138
139        for (data, byte_len) in messages.into_iter().zip(message_sizes) {
140            let offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
141            stream.next_read_seq += 1;
142            stream.next_byte_offset += byte_len;
143            stream.total_bytes += byte_len;
144            let message = Message::new(offset, data);
145            stream.messages.push(message);
146        }
147
148        // Notify long-poll/SSE subscribers that new data is available.
149        // Ignore errors (no active receivers is fine).
150        let _ = stream.notify.send(());
151
152        Ok(())
153    }
154}
155
156impl Storage for InMemoryStorage {
157    fn create_stream(&self, name: &str, config: StreamConfig) -> Result<CreateStreamResult> {
158        let mut streams = self.streams.write().expect("streams lock poisoned");
159
160        if let Some(stream_arc) = streams.get(name) {
161            let stream = stream_arc.read().expect("stream lock poisoned");
162
163            if super::is_stream_expired(&stream.config) {
164                let stream_bytes = stream.total_bytes;
165                drop(stream);
166                streams.remove(name);
167
168                self.total_bytes
169                    .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
170                        Some(current.saturating_sub(stream_bytes))
171                    })
172                    .ok();
173            } else {
174                if stream.config == config {
175                    return Ok(CreateStreamResult::AlreadyExists);
176                }
177                return Err(Error::ConfigMismatch);
178            }
179        }
180
181        let entry = StreamEntry::new(config);
182        streams.insert(name.to_string(), Arc::new(RwLock::new(entry)));
183
184        Ok(CreateStreamResult::Created)
185    }
186
187    fn append(&self, name: &str, data: Bytes, content_type: &str) -> Result<Offset> {
188        let stream_arc = self
189            .get_stream(name)
190            .ok_or_else(|| Error::NotFound(name.to_string()))?;
191
192        let mut stream = stream_arc.write().expect("stream lock poisoned");
193
194        if super::is_stream_expired(&stream.config) {
195            return Err(Error::StreamExpired);
196        }
197
198        if stream.closed {
199            return Err(Error::StreamClosed);
200        }
201
202        super::validate_content_type(&stream.config.content_type, content_type)?;
203
204        let byte_len = u64::try_from(data.len()).unwrap_or(u64::MAX);
205
206        if self
207            .total_bytes
208            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
209                current
210                    .checked_add(byte_len)
211                    .filter(|next| *next <= self.max_total_bytes)
212            })
213            .is_err()
214        {
215            return Err(Error::MemoryLimitExceeded);
216        }
217
218        if stream.total_bytes + byte_len > self.max_stream_bytes {
219            self.saturating_sub_total_bytes(byte_len);
220            return Err(Error::StreamSizeLimitExceeded);
221        }
222
223        let offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
224
225        stream.next_read_seq += 1;
226        stream.next_byte_offset += byte_len;
227        stream.total_bytes += byte_len;
228
229        let message = Message::new(offset.clone(), data);
230        stream.messages.push(message);
231
232        Ok(offset)
233    }
234
235    fn batch_append(
236        &self,
237        name: &str,
238        messages: Vec<Bytes>,
239        content_type: &str,
240        seq: Option<&str>,
241    ) -> Result<Offset> {
242        if messages.is_empty() {
243            return Err(Error::InvalidHeader {
244                header: "Content-Length".to_string(),
245                reason: "batch cannot be empty".to_string(),
246            });
247        }
248
249        let stream_arc = self
250            .get_stream(name)
251            .ok_or_else(|| Error::NotFound(name.to_string()))?;
252
253        let mut stream = stream_arc.write().expect("stream lock poisoned");
254
255        if super::is_stream_expired(&stream.config) {
256            return Err(Error::StreamExpired);
257        }
258
259        if stream.closed {
260            return Err(Error::StreamClosed);
261        }
262
263        super::validate_content_type(&stream.config.content_type, content_type)?;
264
265        let pending_seq = super::validate_seq(stream.last_seq.as_deref(), seq)?;
266
267        self.commit_messages(&mut stream, messages)?;
268        if let Some(new_seq) = pending_seq {
269            stream.last_seq = Some(new_seq);
270        }
271
272        Ok(Offset::new(stream.next_read_seq, stream.next_byte_offset))
273    }
274
275    fn read(&self, name: &str, from_offset: &Offset) -> Result<ReadResult> {
276        let stream_arc = self
277            .get_stream(name)
278            .ok_or_else(|| Error::NotFound(name.to_string()))?;
279
280        let stream = stream_arc.read().expect("stream lock poisoned");
281
282        if super::is_stream_expired(&stream.config) {
283            return Err(Error::StreamExpired);
284        }
285
286        if from_offset.is_now() {
287            let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
288            return Ok(ReadResult {
289                messages: Vec::new(),
290                next_offset,
291                at_tail: true,
292                closed: stream.closed,
293            });
294        }
295
296        let start_idx = if from_offset.is_start() {
297            0
298        } else {
299            match stream
300                .messages
301                .binary_search_by(|m| m.offset.cmp(from_offset))
302            {
303                Ok(idx) | Err(idx) => idx,
304            }
305        };
306
307        let messages: Vec<Bytes> = stream.messages[start_idx..]
308            .iter()
309            .map(|m| m.data.clone())
310            .collect();
311
312        let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
313
314        let at_tail = start_idx + messages.len() >= stream.messages.len();
315
316        Ok(ReadResult {
317            messages,
318            next_offset,
319            at_tail,
320            closed: stream.closed,
321        })
322    }
323
324    fn delete(&self, name: &str) -> Result<()> {
325        let mut streams = self.streams.write().expect("streams lock poisoned");
326
327        if let Some(stream_arc) = streams.remove(name) {
328            let stream = stream_arc.read().expect("stream lock poisoned");
329            self.saturating_sub_total_bytes(stream.total_bytes);
330            Ok(())
331        } else {
332            Err(Error::NotFound(name.to_string()))
333        }
334    }
335
336    fn head(&self, name: &str) -> Result<StreamMetadata> {
337        let stream_arc = self
338            .get_stream(name)
339            .ok_or_else(|| Error::NotFound(name.to_string()))?;
340
341        let stream = stream_arc.read().expect("stream lock poisoned");
342
343        if super::is_stream_expired(&stream.config) {
344            return Err(Error::StreamExpired);
345        }
346
347        Ok(StreamMetadata {
348            config: stream.config.clone(),
349            next_offset: Offset::new(stream.next_read_seq, stream.next_byte_offset),
350            closed: stream.closed,
351            total_bytes: stream.total_bytes,
352            message_count: u64::try_from(stream.messages.len()).unwrap_or(u64::MAX),
353            created_at: stream.created_at,
354        })
355    }
356
357    fn close_stream(&self, name: &str) -> Result<()> {
358        let stream_arc = self
359            .get_stream(name)
360            .ok_or_else(|| Error::NotFound(name.to_string()))?;
361
362        let mut stream = stream_arc.write().expect("stream lock poisoned");
363
364        if super::is_stream_expired(&stream.config) {
365            return Err(Error::StreamExpired);
366        }
367
368        stream.closed = true;
369
370        let _ = stream.notify.send(());
371
372        Ok(())
373    }
374
375    fn append_with_producer(
376        &self,
377        name: &str,
378        messages: Vec<Bytes>,
379        content_type: &str,
380        producer: &ProducerHeaders,
381        should_close: bool,
382        seq: Option<&str>,
383    ) -> Result<ProducerAppendResult> {
384        let stream_arc = self
385            .get_stream(name)
386            .ok_or_else(|| Error::NotFound(name.to_string()))?;
387
388        let mut stream = stream_arc.write().expect("stream lock poisoned");
389
390        if super::is_stream_expired(&stream.config) {
391            return Err(Error::StreamExpired);
392        }
393
394        super::cleanup_stale_producers(&mut stream.producers);
395
396        if !messages.is_empty() {
397            super::validate_content_type(&stream.config.content_type, content_type)?;
398        }
399
400        let now = Utc::now();
401
402        match super::check_producer(stream.producers.get(&producer.id), producer, stream.closed)? {
403            ProducerCheck::Accept => {}
404            ProducerCheck::Duplicate { epoch, seq } => {
405                return Ok(ProducerAppendResult::Duplicate {
406                    epoch,
407                    seq,
408                    next_offset: Offset::new(stream.next_read_seq, stream.next_byte_offset),
409                    closed: stream.closed,
410                });
411            }
412        }
413
414        let pending_seq = super::validate_seq(stream.last_seq.as_deref(), seq)?;
415
416        self.commit_messages(&mut stream, messages)?;
417        if let Some(new_seq) = pending_seq {
418            stream.last_seq = Some(new_seq);
419        }
420
421        if should_close {
422            stream.closed = true;
423        }
424
425        stream.producers.insert(
426            producer.id.clone(),
427            ProducerState {
428                epoch: producer.epoch,
429                last_seq: producer.seq,
430                updated_at: now,
431            },
432        );
433
434        let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
435        let closed = stream.closed;
436
437        Ok(ProducerAppendResult::Accepted {
438            epoch: producer.epoch,
439            seq: producer.seq,
440            next_offset,
441            closed,
442        })
443    }
444
    /// Create a stream and atomically seed it with initial messages.
    ///
    /// Holds the map write lock for the whole operation so creation,
    /// initial commit, close, and insertion are atomic with respect to
    /// other map operations. An expired stream under the same name is
    /// reaped and recreated; a live stream with an identical config is
    /// reported as `AlreadyExists` (with its current tail); a live stream
    /// with a different config is `Error::ConfigMismatch`.
    ///
    /// No subscriber notification is needed here: the entry is inserted
    /// into the map only after the commit, so nothing can have subscribed
    /// yet.
    fn create_stream_with_data(
        &self,
        name: &str,
        config: StreamConfig,
        messages: Vec<Bytes>,
        should_close: bool,
    ) -> Result<super::CreateWithDataResult> {
        let mut streams = self.streams.write().expect("streams lock poisoned");

        if let Some(stream_arc) = streams.get(name) {
            let stream = stream_arc.read().expect("stream lock poisoned");

            if super::is_stream_expired(&stream.config) {
                // Lazily reap the expired stream and fall through to
                // recreate it below.
                let stream_bytes = stream.total_bytes;
                drop(stream);
                streams.remove(name);
                self.saturating_sub_total_bytes(stream_bytes);
            } else if stream.config == config {
                // Idempotent create: report the existing stream's state.
                let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
                let closed = stream.closed;
                return Ok(super::CreateWithDataResult {
                    status: CreateStreamResult::AlreadyExists,
                    next_offset,
                    closed,
                });
            } else {
                return Err(Error::ConfigMismatch);
            }
        }

        let mut entry = StreamEntry::new(config);

        // Commit before inserting: if the limits reject the batch, the
        // stream is never created (and commit_messages has already rolled
        // back any global reservation).
        if !messages.is_empty() {
            self.commit_messages(&mut entry, messages)?;
        }

        if should_close {
            entry.closed = true;
        }

        let next_offset = Offset::new(entry.next_read_seq, entry.next_byte_offset);
        let closed = entry.closed;

        streams.insert(name.to_string(), Arc::new(RwLock::new(entry)));

        Ok(super::CreateWithDataResult {
            status: CreateStreamResult::Created,
            next_offset,
            closed,
        })
    }
496
497    fn exists(&self, name: &str) -> bool {
498        let streams = self.streams.read().expect("streams lock poisoned");
499        if let Some(stream_arc) = streams.get(name) {
500            let stream = stream_arc.read().expect("stream lock poisoned");
501            !super::is_stream_expired(&stream.config)
502        } else {
503            false
504        }
505    }
506
507    fn subscribe(&self, name: &str) -> Option<broadcast::Receiver<()>> {
508        let stream_arc = self.get_stream(name)?;
509        let stream = stream_arc.read().expect("stream lock poisoned");
510
511        if super::is_stream_expired(&stream.config) {
512            return None;
513        }
514
515        Some(stream.notify.subscribe())
516    }
517
518    fn cleanup_expired_streams(&self) -> usize {
519        let mut streams = self.streams.write().expect("streams lock poisoned");
520        let mut expired = Vec::new();
521
522        for (name, stream_arc) in streams.iter() {
523            let stream = stream_arc.read().expect("stream lock poisoned");
524            if super::is_stream_expired(&stream.config) {
525                expired.push((name.clone(), stream.total_bytes));
526            }
527        }
528
529        for (name, bytes) in &expired {
530            streams.remove(name);
531            self.saturating_sub_total_bytes(*bytes);
532        }
533
534        expired.len()
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    // Storage with a 1 MiB global limit and a 100 KiB per-stream limit.
    fn test_storage() -> InMemoryStorage {
        InMemoryStorage::new(1024 * 1024, 100 * 1024)
    }

    // Convenience constructor for producer headers.
    fn producer(id: &str, epoch: u64, seq: u64) -> ProducerHeaders {
        ProducerHeaders {
            id: id.to_string(),
            epoch,
            seq,
        }
    }

    // Appends from several producers racing on one stream must all succeed
    // and every message must be retained: appends to a single stream are
    // serialized by the per-stream write lock, and distinct producer ids
    // never collide in the idempotency check.
    #[test]
    fn test_concurrent_producer_appends() {
        let storage = Arc::new(test_storage());
        let config = StreamConfig::new("text/plain".to_string());
        storage.create_stream("test", config).unwrap();

        let num_producers = 4;
        let seqs_per_producer = 50;

        let handles: Vec<_> = (0..num_producers)
            .map(|p| {
                let storage = Arc::clone(&storage);
                thread::spawn(move || {
                    let prod_id = format!("p{p}");
                    for seq in 0..seqs_per_producer {
                        let result = storage.append_with_producer(
                            "test",
                            vec![Bytes::from(format!("{prod_id}-{seq}"))],
                            "text/plain",
                            &producer(&prod_id, 0, seq),
                            false,
                            None,
                        );
                        assert!(
                            result.is_ok(),
                            "Producer {prod_id} seq {seq} failed: {result:?}"
                        );
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("thread panicked");
        }

        // Every append from every producer must be visible in the metadata.
        let metadata = storage.head("test").unwrap();
        assert_eq!(metadata.message_count, num_producers * seqs_per_producer);
    }
}
595}