//! In-memory `Storage` backend (`durable_streams_server/storage/memory.rs`).

1use super::{
2    CreateStreamResult, Message, NOTIFY_CHANNEL_CAPACITY, ProducerAppendResult, ProducerCheck,
3    ProducerState, ReadResult, Storage, StreamConfig, StreamMetadata,
4};
5use crate::protocol::error::{Error, Result};
6use crate::protocol::offset::Offset;
7use crate::protocol::producer::ProducerHeaders;
8use bytes::Bytes;
9use chrono::Utc;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::{Arc, RwLock};
13use tokio::sync::broadcast;
14
15const INITIAL_MESSAGES_CAPACITY: usize = 256;
16const INITIAL_PRODUCERS_CAPACITY: usize = 8;
17
/// Internal stream entry: all mutable state for one named stream.
///
/// Instances live behind a per-stream `RwLock`, so fields need no
/// internal synchronization of their own.
struct StreamEntry {
    /// Configuration captured at creation time (content type, expiry, …).
    config: StreamConfig,
    /// Messages in append order; offsets are assigned monotonically.
    messages: Vec<Message>,
    /// Once true, further appends are rejected with `StreamClosed`.
    closed: bool,
    /// Sequence number the next appended message will receive.
    next_read_seq: u64,
    /// Byte offset at which the next appended message will start.
    next_byte_offset: u64,
    /// Sum of payload bytes currently held by this stream.
    total_bytes: u64,
    /// UTC timestamp of stream creation.
    created_at: chrono::DateTime<Utc>,
    /// Per-producer state for idempotent producer support
    producers: HashMap<String, ProducerState>,
    /// Broadcast sender for notifying long-poll/SSE subscribers
    notify: broadcast::Sender<()>,
    /// Last Stream-Seq value received (lexicographic ordering)
    last_seq: Option<String>,
}
34
35impl StreamEntry {
36    fn new(config: StreamConfig) -> Self {
37        // Stream starts open; the handler closes it after any initial appends.
38        // The `created_closed` flag in config is stored for idempotent checks only.
39        let (notify, _) = broadcast::channel(NOTIFY_CHANNEL_CAPACITY);
40        Self {
41            config,
42            messages: Vec::with_capacity(INITIAL_MESSAGES_CAPACITY),
43            closed: false,
44            next_read_seq: 0,
45            next_byte_offset: 0,
46            total_bytes: 0,
47            created_at: Utc::now(),
48            producers: HashMap::with_capacity(INITIAL_PRODUCERS_CAPACITY),
49            notify,
50            last_seq: None,
51        }
52    }
53}
54
/// In-memory storage implementation
///
/// Thread-safe storage with:
/// - `RwLock<HashMap>` for stream lookup (concurrent reads)
/// - Per-stream `RwLock` for exclusive write access (offset monotonicity)
/// - Memory limit enforcement (global and per-stream)
///
/// # Concurrency Model
///
/// Multiple readers can access different streams concurrently.
/// Appends to the same stream are serialized via `RwLock::write()`.
/// Appends to different streams can proceed concurrently.
pub struct InMemoryStorage {
    /// Stream name -> shared entry. The outer lock guards only the map,
    /// not the entries, which carry their own per-stream locks.
    streams: RwLock<HashMap<String, Arc<RwLock<StreamEntry>>>>,
    /// Running total of payload bytes across all streams.
    total_bytes: AtomicU64,
    /// Global cap enforced against `total_bytes`.
    max_total_bytes: u64,
    /// Cap enforced against any single stream's byte total.
    max_stream_bytes: u64,
}
73
74impl InMemoryStorage {
75    /// Create a new in-memory storage with memory limits
76    #[must_use]
77    pub fn new(max_total_bytes: u64, max_stream_bytes: u64) -> Self {
78        Self {
79            streams: RwLock::new(HashMap::new()),
80            total_bytes: AtomicU64::new(0),
81            max_total_bytes,
82            max_stream_bytes,
83        }
84    }
85
86    /// Get current total memory usage
87    ///
88    /// # Panics
89    ///
90    /// Panics if the `total_bytes` lock is poisoned (which indicates a panic while holding the lock).
91    #[must_use]
92    pub fn total_bytes(&self) -> u64 {
93        self.total_bytes.load(Ordering::Acquire)
94    }
95
96    fn saturating_sub_total_bytes(&self, bytes: u64) {
97        self.total_bytes
98            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
99                Some(current.saturating_sub(bytes))
100            })
101            .ok();
102    }
103
104    fn get_stream(&self, name: &str) -> Option<Arc<RwLock<StreamEntry>>> {
105        let streams = self.streams.read().expect("streams lock poisoned");
106        streams.get(name).map(Arc::clone)
107    }
108
109    /// Commit messages to a stream, checking memory limits first.
110    ///
111    /// Caller must hold the stream write lock. Updates both stream-level
112    /// and global memory counters atomically.
113    fn commit_messages(&self, stream: &mut StreamEntry, messages: Vec<Bytes>) -> Result<()> {
114        if messages.is_empty() {
115            return Ok(());
116        }
117
118        let mut total_batch_bytes = 0u64;
119        let mut message_sizes = Vec::with_capacity(messages.len());
120        for data in &messages {
121            let byte_len = u64::try_from(data.len()).unwrap_or(u64::MAX);
122            message_sizes.push(byte_len);
123            total_batch_bytes += byte_len;
124        }
125
126        // Reserve global bytes atomically (global precedence before per-stream).
127        if self
128            .total_bytes
129            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
130                current
131                    .checked_add(total_batch_bytes)
132                    .filter(|next| *next <= self.max_total_bytes)
133            })
134            .is_err()
135        {
136            return Err(Error::MemoryLimitExceeded);
137        }
138        if stream.total_bytes + total_batch_bytes > self.max_stream_bytes {
139            self.saturating_sub_total_bytes(total_batch_bytes);
140            return Err(Error::StreamSizeLimitExceeded);
141        }
142
143        for (data, byte_len) in messages.into_iter().zip(message_sizes) {
144            let offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
145            stream.next_read_seq += 1;
146            stream.next_byte_offset += byte_len;
147            stream.total_bytes += byte_len;
148            let message = Message::new(offset, data);
149            stream.messages.push(message);
150        }
151
152        // Notify long-poll/SSE subscribers that new data is available.
153        // Ignore errors (no active receivers is fine).
154        let _ = stream.notify.send(());
155
156        Ok(())
157    }
158}
159
160impl Storage for InMemoryStorage {
161    fn create_stream(&self, name: &str, config: StreamConfig) -> Result<CreateStreamResult> {
162        let mut streams = self.streams.write().expect("streams lock poisoned");
163
164        if let Some(stream_arc) = streams.get(name) {
165            let stream = stream_arc.read().expect("stream lock poisoned");
166
167            if super::is_stream_expired(&stream.config) {
168                let stream_bytes = stream.total_bytes;
169                drop(stream);
170                streams.remove(name);
171
172                self.total_bytes
173                    .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
174                        Some(current.saturating_sub(stream_bytes))
175                    })
176                    .ok();
177            } else {
178                if stream.config == config {
179                    return Ok(CreateStreamResult::AlreadyExists);
180                }
181                return Err(Error::ConfigMismatch);
182            }
183        }
184
185        let entry = StreamEntry::new(config);
186        streams.insert(name.to_string(), Arc::new(RwLock::new(entry)));
187
188        Ok(CreateStreamResult::Created)
189    }
190
191    fn append(&self, name: &str, data: Bytes, content_type: &str) -> Result<Offset> {
192        let stream_arc = self
193            .get_stream(name)
194            .ok_or_else(|| Error::NotFound(name.to_string()))?;
195
196        let mut stream = stream_arc.write().expect("stream lock poisoned");
197
198        if super::is_stream_expired(&stream.config) {
199            return Err(Error::StreamExpired);
200        }
201
202        if stream.closed {
203            return Err(Error::StreamClosed);
204        }
205
206        super::validate_content_type(&stream.config.content_type, content_type)?;
207
208        let byte_len = u64::try_from(data.len()).unwrap_or(u64::MAX);
209
210        if self
211            .total_bytes
212            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
213                current
214                    .checked_add(byte_len)
215                    .filter(|next| *next <= self.max_total_bytes)
216            })
217            .is_err()
218        {
219            return Err(Error::MemoryLimitExceeded);
220        }
221
222        if stream.total_bytes + byte_len > self.max_stream_bytes {
223            self.saturating_sub_total_bytes(byte_len);
224            return Err(Error::StreamSizeLimitExceeded);
225        }
226
227        let offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
228
229        stream.next_read_seq += 1;
230        stream.next_byte_offset += byte_len;
231        stream.total_bytes += byte_len;
232
233        let message = Message::new(offset.clone(), data);
234        stream.messages.push(message);
235
236        Ok(offset)
237    }
238
239    fn batch_append(
240        &self,
241        name: &str,
242        messages: Vec<Bytes>,
243        content_type: &str,
244        seq: Option<&str>,
245    ) -> Result<Offset> {
246        if messages.is_empty() {
247            return Err(Error::InvalidHeader {
248                header: "Content-Length".to_string(),
249                reason: "batch cannot be empty".to_string(),
250            });
251        }
252
253        let stream_arc = self
254            .get_stream(name)
255            .ok_or_else(|| Error::NotFound(name.to_string()))?;
256
257        let mut stream = stream_arc.write().expect("stream lock poisoned");
258
259        if super::is_stream_expired(&stream.config) {
260            return Err(Error::StreamExpired);
261        }
262
263        if stream.closed {
264            return Err(Error::StreamClosed);
265        }
266
267        super::validate_content_type(&stream.config.content_type, content_type)?;
268
269        let pending_seq = super::validate_seq(stream.last_seq.as_deref(), seq)?;
270
271        self.commit_messages(&mut stream, messages)?;
272        if let Some(new_seq) = pending_seq {
273            stream.last_seq = Some(new_seq);
274        }
275
276        Ok(Offset::new(stream.next_read_seq, stream.next_byte_offset))
277    }
278
279    fn read(&self, name: &str, from_offset: &Offset) -> Result<ReadResult> {
280        let stream_arc = self
281            .get_stream(name)
282            .ok_or_else(|| Error::NotFound(name.to_string()))?;
283
284        let stream = stream_arc.read().expect("stream lock poisoned");
285
286        if super::is_stream_expired(&stream.config) {
287            return Err(Error::StreamExpired);
288        }
289
290        if from_offset.is_now() {
291            let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
292            return Ok(ReadResult {
293                messages: Vec::new(),
294                next_offset,
295                at_tail: true,
296                closed: stream.closed,
297            });
298        }
299
300        let start_idx = if from_offset.is_start() {
301            0
302        } else {
303            match stream
304                .messages
305                .binary_search_by(|m| m.offset.cmp(from_offset))
306            {
307                Ok(idx) | Err(idx) => idx,
308            }
309        };
310
311        let messages: Vec<Bytes> = stream.messages[start_idx..]
312            .iter()
313            .map(|m| m.data.clone())
314            .collect();
315
316        let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
317
318        let at_tail = start_idx + messages.len() >= stream.messages.len();
319
320        Ok(ReadResult {
321            messages,
322            next_offset,
323            at_tail,
324            closed: stream.closed,
325        })
326    }
327
328    fn delete(&self, name: &str) -> Result<()> {
329        let mut streams = self.streams.write().expect("streams lock poisoned");
330
331        if let Some(stream_arc) = streams.remove(name) {
332            let stream = stream_arc.read().expect("stream lock poisoned");
333            self.saturating_sub_total_bytes(stream.total_bytes);
334            Ok(())
335        } else {
336            Err(Error::NotFound(name.to_string()))
337        }
338    }
339
340    fn head(&self, name: &str) -> Result<StreamMetadata> {
341        let stream_arc = self
342            .get_stream(name)
343            .ok_or_else(|| Error::NotFound(name.to_string()))?;
344
345        let stream = stream_arc.read().expect("stream lock poisoned");
346
347        if super::is_stream_expired(&stream.config) {
348            return Err(Error::StreamExpired);
349        }
350
351        Ok(StreamMetadata {
352            config: stream.config.clone(),
353            next_offset: Offset::new(stream.next_read_seq, stream.next_byte_offset),
354            closed: stream.closed,
355            total_bytes: stream.total_bytes,
356            message_count: u64::try_from(stream.messages.len()).unwrap_or(u64::MAX),
357            created_at: stream.created_at,
358        })
359    }
360
361    fn close_stream(&self, name: &str) -> Result<()> {
362        let stream_arc = self
363            .get_stream(name)
364            .ok_or_else(|| Error::NotFound(name.to_string()))?;
365
366        let mut stream = stream_arc.write().expect("stream lock poisoned");
367
368        if super::is_stream_expired(&stream.config) {
369            return Err(Error::StreamExpired);
370        }
371
372        stream.closed = true;
373
374        let _ = stream.notify.send(());
375
376        Ok(())
377    }
378
379    fn append_with_producer(
380        &self,
381        name: &str,
382        messages: Vec<Bytes>,
383        content_type: &str,
384        producer: &ProducerHeaders,
385        should_close: bool,
386        seq: Option<&str>,
387    ) -> Result<ProducerAppendResult> {
388        let stream_arc = self
389            .get_stream(name)
390            .ok_or_else(|| Error::NotFound(name.to_string()))?;
391
392        let mut stream = stream_arc.write().expect("stream lock poisoned");
393
394        if super::is_stream_expired(&stream.config) {
395            return Err(Error::StreamExpired);
396        }
397
398        super::cleanup_stale_producers(&mut stream.producers);
399
400        if !messages.is_empty() {
401            super::validate_content_type(&stream.config.content_type, content_type)?;
402        }
403
404        let now = Utc::now();
405
406        match super::check_producer(stream.producers.get(&producer.id), producer, stream.closed)? {
407            ProducerCheck::Accept => {}
408            ProducerCheck::Duplicate { epoch, seq } => {
409                return Ok(ProducerAppendResult::Duplicate {
410                    epoch,
411                    seq,
412                    next_offset: Offset::new(stream.next_read_seq, stream.next_byte_offset),
413                    closed: stream.closed,
414                });
415            }
416        }
417
418        let pending_seq = super::validate_seq(stream.last_seq.as_deref(), seq)?;
419
420        self.commit_messages(&mut stream, messages)?;
421        if let Some(new_seq) = pending_seq {
422            stream.last_seq = Some(new_seq);
423        }
424
425        if should_close {
426            stream.closed = true;
427        }
428
429        stream.producers.insert(
430            producer.id.clone(),
431            ProducerState {
432                epoch: producer.epoch,
433                last_seq: producer.seq,
434                updated_at: now,
435            },
436        );
437
438        let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
439        let closed = stream.closed;
440
441        Ok(ProducerAppendResult::Accepted {
442            epoch: producer.epoch,
443            seq: producer.seq,
444            next_offset,
445            closed,
446        })
447    }
448
449    fn create_stream_with_data(
450        &self,
451        name: &str,
452        config: StreamConfig,
453        messages: Vec<Bytes>,
454        should_close: bool,
455    ) -> Result<super::CreateWithDataResult> {
456        let mut streams = self.streams.write().expect("streams lock poisoned");
457
458        if let Some(stream_arc) = streams.get(name) {
459            let stream = stream_arc.read().expect("stream lock poisoned");
460
461            if super::is_stream_expired(&stream.config) {
462                let stream_bytes = stream.total_bytes;
463                drop(stream);
464                streams.remove(name);
465                self.saturating_sub_total_bytes(stream_bytes);
466            } else if stream.config == config {
467                let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
468                let closed = stream.closed;
469                return Ok(super::CreateWithDataResult {
470                    status: CreateStreamResult::AlreadyExists,
471                    next_offset,
472                    closed,
473                });
474            } else {
475                return Err(Error::ConfigMismatch);
476            }
477        }
478
479        let mut entry = StreamEntry::new(config);
480
481        if !messages.is_empty() {
482            self.commit_messages(&mut entry, messages)?;
483        }
484
485        if should_close {
486            entry.closed = true;
487        }
488
489        let next_offset = Offset::new(entry.next_read_seq, entry.next_byte_offset);
490        let closed = entry.closed;
491
492        streams.insert(name.to_string(), Arc::new(RwLock::new(entry)));
493
494        Ok(super::CreateWithDataResult {
495            status: CreateStreamResult::Created,
496            next_offset,
497            closed,
498        })
499    }
500
501    fn exists(&self, name: &str) -> bool {
502        let streams = self.streams.read().expect("streams lock poisoned");
503        if let Some(stream_arc) = streams.get(name) {
504            let stream = stream_arc.read().expect("stream lock poisoned");
505            !super::is_stream_expired(&stream.config)
506        } else {
507            false
508        }
509    }
510
511    fn subscribe(&self, name: &str) -> Option<broadcast::Receiver<()>> {
512        let stream_arc = self.get_stream(name)?;
513        let stream = stream_arc.read().expect("stream lock poisoned");
514
515        if super::is_stream_expired(&stream.config) {
516            return None;
517        }
518
519        Some(stream.notify.subscribe())
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Storage with a 1 MiB global cap and a 100 KiB per-stream cap.
    fn test_storage() -> InMemoryStorage {
        InMemoryStorage::new(1024 * 1024, 100 * 1024)
    }

    /// Convenience constructor for producer headers.
    fn producer(id: &str, epoch: u64, seq: u64) -> ProducerHeaders {
        ProducerHeaders {
            id: id.to_string(),
            epoch,
            seq,
        }
    }

    /// Several producers append concurrently; every append must succeed
    /// and the final message count must equal the total submitted.
    #[test]
    fn test_concurrent_producer_appends() {
        let storage = Arc::new(test_storage());
        storage
            .create_stream("test", StreamConfig::new("text/plain".to_string()))
            .unwrap();

        let num_producers = 4;
        let seqs_per_producer = 50;

        let mut handles = Vec::new();
        for p in 0..num_producers {
            let storage = Arc::clone(&storage);
            handles.push(thread::spawn(move || {
                let prod_id = format!("p{p}");
                for seq in 0..seqs_per_producer {
                    let result = storage.append_with_producer(
                        "test",
                        vec![Bytes::from(format!("{prod_id}-{seq}"))],
                        "text/plain",
                        &producer(&prod_id, 0, seq),
                        false,
                        None,
                    );
                    assert!(
                        result.is_ok(),
                        "Producer {prod_id} seq {seq} failed: {result:?}"
                    );
                }
            }));
        }

        for handle in handles {
            handle.join().expect("thread panicked");
        }

        let metadata = storage.head("test").unwrap();
        assert_eq!(metadata.message_count, num_producers * seqs_per_producer);
    }
}