use super::{
    CreateStreamResult, Message, NOTIFY_CHANNEL_CAPACITY, ProducerAppendResult, ProducerCheck,
    ProducerState, ReadResult, Storage, StreamConfig, StreamMetadata,
};
use crate::protocol::error::{Error, Result};
use crate::protocol::offset::Offset;
use crate::protocol::producer::ProducerHeaders;
use bytes::Bytes;
use chrono::Utc;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use tokio::sync::broadcast;

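// Pre-allocation hints for each stream's message buffer and producer map.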
const INITIAL_MESSAGES_CAPACITY: usize = 256;
const INITIAL_PRODUCERS_CAPACITY: usize = 8;

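/// A single in-memory stream: its configuration, buffered messages, offset
/// counters, per-producer state, and a broadcast channel used to wake
/// subscribers when new data arrives or the stream is closed.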
struct StreamEntry {
    config: StreamConfig,
    messages: Vec<Message>,
    closed: bool,
    next_read_seq: u64,
    next_byte_offset: u64,
    total_bytes: u64,
    created_at: chrono::DateTime<Utc>,
    producers: HashMap<String, ProducerState>,
    notify: broadcast::Sender<()>,
    last_seq: Option<String>,
}

impl StreamEntry {
    fn new(config: StreamConfig) -> Self {
        let (notify, _) = broadcast::channel(NOTIFY_CHANNEL_CAPACITY);
        Self {
            config,
            messages: Vec::with_capacity(INITIAL_MESSAGES_CAPACITY),
            closed: false,
            next_read_seq: 0,
            next_byte_offset: 0,
            total_bytes: 0,
            created_at: Utc::now(),
            producers: HashMap::with_capacity(INITIAL_PRODUCERS_CAPACITY),
            notify,
            last_seq: None,
        }
    }
}

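/// In-memory `Storage` backend.
///
/// Streams are kept in a map behind an outer `RwLock`; each entry is wrapped
/// in its own `RwLock` so operations on different streams do not contend once
/// the entry has been looked up. Total memory use across all streams is
/// tracked in an `AtomicU64` and bounded by `max_total_bytes`, with each
/// stream additionally bounded by `max_stream_bytes`.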
pub struct InMemoryStorage {
    streams: RwLock<HashMap<String, Arc<RwLock<StreamEntry>>>>,
    total_bytes: AtomicU64,
    max_total_bytes: u64,
    max_stream_bytes: u64,
}

impl InMemoryStorage {
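    /// Creates an empty storage with the given global and per-stream byte limits.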
    #[must_use]
    pub fn new(max_total_bytes: u64, max_stream_bytes: u64) -> Self {
        Self {
            streams: RwLock::new(HashMap::new()),
            total_bytes: AtomicU64::new(0),
            max_total_bytes,
            max_stream_bytes,
        }
    }

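    /// Returns the number of bytes currently held across all streams.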
    #[must_use]
    pub fn total_bytes(&self) -> u64 {
        self.total_bytes.load(Ordering::Acquire)
    }

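    /// Subtracts `bytes` from the global counter, saturating at zero.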
    fn saturating_sub_total_bytes(&self, bytes: u64) {
        self.total_bytes
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                Some(current.saturating_sub(bytes))
            })
            .ok();
    }

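    /// Looks up a stream by name, cloning the `Arc` so the outer map lock is
    /// released before the caller locks the entry itself.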
    fn get_stream(&self, name: &str) -> Option<Arc<RwLock<StreamEntry>>> {
        let streams = self.streams.read().expect("streams lock poisoned");
        streams.get(name).map(Arc::clone)
    }

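    /// Appends a batch to `stream`: reserves the batch size against the global
    /// limit first (rolling the reservation back if the per-stream limit would
    /// be exceeded), then assigns offsets, stores the messages, and notifies
    /// subscribers.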
    fn commit_messages(&self, stream: &mut StreamEntry, messages: Vec<Bytes>) -> Result<()> {
        if messages.is_empty() {
            return Ok(());
        }

        let mut total_batch_bytes = 0u64;
        let mut message_sizes = Vec::with_capacity(messages.len());
        for data in &messages {
            let byte_len = u64::try_from(data.len()).unwrap_or(u64::MAX);
            message_sizes.push(byte_len);
            total_batch_bytes += byte_len;
        }

        if self
            .total_bytes
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                current
                    .checked_add(total_batch_bytes)
                    .filter(|next| *next <= self.max_total_bytes)
            })
            .is_err()
        {
            return Err(Error::MemoryLimitExceeded);
        }
        if stream.total_bytes + total_batch_bytes > self.max_stream_bytes {
            self.saturating_sub_total_bytes(total_batch_bytes);
            return Err(Error::StreamSizeLimitExceeded);
        }

        for (data, byte_len) in messages.into_iter().zip(message_sizes) {
            let offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
            stream.next_read_seq += 1;
            stream.next_byte_offset += byte_len;
            stream.total_bytes += byte_len;
            let message = Message::new(offset, data);
            stream.messages.push(message);
        }

        let _ = stream.notify.send(());

        Ok(())
    }
}

impl Storage for InMemoryStorage {
    fn create_stream(&self, name: &str, config: StreamConfig) -> Result<CreateStreamResult> {
        let mut streams = self.streams.write().expect("streams lock poisoned");

        if let Some(stream_arc) = streams.get(name) {
            let stream = stream_arc.read().expect("stream lock poisoned");

            if super::is_stream_expired(&stream.config) {
                let stream_bytes = stream.total_bytes;
                drop(stream);
                streams.remove(name);
                self.saturating_sub_total_bytes(stream_bytes);
            } else if stream.config == config {
                return Ok(CreateStreamResult::AlreadyExists);
            } else {
                return Err(Error::ConfigMismatch);
            }
        }

        let entry = StreamEntry::new(config);
        streams.insert(name.to_string(), Arc::new(RwLock::new(entry)));

        Ok(CreateStreamResult::Created)
    }

    fn append(&self, name: &str, data: Bytes, content_type: &str) -> Result<Offset> {
        let stream_arc = self
            .get_stream(name)
            .ok_or_else(|| Error::NotFound(name.to_string()))?;

        let mut stream = stream_arc.write().expect("stream lock poisoned");

        if super::is_stream_expired(&stream.config) {
            return Err(Error::StreamExpired);
        }

        if stream.closed {
            return Err(Error::StreamClosed);
        }

        super::validate_content_type(&stream.config.content_type, content_type)?;

        let byte_len = u64::try_from(data.len()).unwrap_or(u64::MAX);

        // Reserve the bytes against the global limit before touching the stream.
        if self
            .total_bytes
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                current
                    .checked_add(byte_len)
                    .filter(|next| *next <= self.max_total_bytes)
            })
            .is_err()
        {
            return Err(Error::MemoryLimitExceeded);
        }

        if stream.total_bytes + byte_len > self.max_stream_bytes {
            // Per-stream limit exceeded: release the global reservation.
            self.saturating_sub_total_bytes(byte_len);
            return Err(Error::StreamSizeLimitExceeded);
        }

        let offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);

        stream.next_read_seq += 1;
        stream.next_byte_offset += byte_len;
        stream.total_bytes += byte_len;

        let message = Message::new(offset.clone(), data);
        stream.messages.push(message);

        // Wake subscribers, mirroring the batch path in `commit_messages`.
        let _ = stream.notify.send(());

        Ok(offset)
    }

    fn batch_append(
        &self,
        name: &str,
        messages: Vec<Bytes>,
        content_type: &str,
        seq: Option<&str>,
    ) -> Result<Offset> {
        if messages.is_empty() {
            return Err(Error::InvalidHeader {
                header: "Content-Length".to_string(),
                reason: "batch cannot be empty".to_string(),
            });
        }

        let stream_arc = self
            .get_stream(name)
            .ok_or_else(|| Error::NotFound(name.to_string()))?;

        let mut stream = stream_arc.write().expect("stream lock poisoned");

        if super::is_stream_expired(&stream.config) {
            return Err(Error::StreamExpired);
        }

        if stream.closed {
            return Err(Error::StreamClosed);
        }

        super::validate_content_type(&stream.config.content_type, content_type)?;

        let pending_seq = super::validate_seq(stream.last_seq.as_deref(), seq)?;

        self.commit_messages(&mut stream, messages)?;
        if let Some(new_seq) = pending_seq {
            stream.last_seq = Some(new_seq);
        }

        Ok(Offset::new(stream.next_read_seq, stream.next_byte_offset))
    }

    fn read(&self, name: &str, from_offset: &Offset) -> Result<ReadResult> {
        let stream_arc = self
            .get_stream(name)
            .ok_or_else(|| Error::NotFound(name.to_string()))?;

        let stream = stream_arc.read().expect("stream lock poisoned");

        if super::is_stream_expired(&stream.config) {
            return Err(Error::StreamExpired);
        }

        if from_offset.is_now() {
            let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
            return Ok(ReadResult {
                messages: Vec::new(),
                next_offset,
                at_tail: true,
                closed: stream.closed,
            });
        }

        let start_idx = if from_offset.is_start() {
            0
        } else {
            match stream
                .messages
                .binary_search_by(|m| m.offset.cmp(from_offset))
            {
                Ok(idx) | Err(idx) => idx,
            }
        };

        let messages: Vec<Bytes> = stream.messages[start_idx..]
            .iter()
            .map(|m| m.data.clone())
            .collect();

        let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);

        let at_tail = start_idx + messages.len() >= stream.messages.len();

        Ok(ReadResult {
            messages,
            next_offset,
            at_tail,
            closed: stream.closed,
        })
    }

    fn delete(&self, name: &str) -> Result<()> {
        let mut streams = self.streams.write().expect("streams lock poisoned");

        if let Some(stream_arc) = streams.remove(name) {
            let stream = stream_arc.read().expect("stream lock poisoned");
            self.saturating_sub_total_bytes(stream.total_bytes);
            Ok(())
        } else {
            Err(Error::NotFound(name.to_string()))
        }
    }

    fn head(&self, name: &str) -> Result<StreamMetadata> {
        let stream_arc = self
            .get_stream(name)
            .ok_or_else(|| Error::NotFound(name.to_string()))?;

        let stream = stream_arc.read().expect("stream lock poisoned");

        if super::is_stream_expired(&stream.config) {
            return Err(Error::StreamExpired);
        }

        Ok(StreamMetadata {
            config: stream.config.clone(),
            next_offset: Offset::new(stream.next_read_seq, stream.next_byte_offset),
            closed: stream.closed,
            total_bytes: stream.total_bytes,
            message_count: u64::try_from(stream.messages.len()).unwrap_or(u64::MAX),
            created_at: stream.created_at,
        })
    }

    fn close_stream(&self, name: &str) -> Result<()> {
        let stream_arc = self
            .get_stream(name)
            .ok_or_else(|| Error::NotFound(name.to_string()))?;

        let mut stream = stream_arc.write().expect("stream lock poisoned");

        if super::is_stream_expired(&stream.config) {
            return Err(Error::StreamExpired);
        }

        stream.closed = true;

        let _ = stream.notify.send(());

        Ok(())
    }

    fn append_with_producer(
        &self,
        name: &str,
        messages: Vec<Bytes>,
        content_type: &str,
        producer: &ProducerHeaders,
        should_close: bool,
        seq: Option<&str>,
    ) -> Result<ProducerAppendResult> {
        let stream_arc = self
            .get_stream(name)
            .ok_or_else(|| Error::NotFound(name.to_string()))?;

        let mut stream = stream_arc.write().expect("stream lock poisoned");

        if super::is_stream_expired(&stream.config) {
            return Err(Error::StreamExpired);
        }

        super::cleanup_stale_producers(&mut stream.producers);

        if !messages.is_empty() {
            super::validate_content_type(&stream.config.content_type, content_type)?;
        }

        let now = Utc::now();

        match super::check_producer(stream.producers.get(&producer.id), producer, stream.closed)? {
            ProducerCheck::Accept => {}
            ProducerCheck::Duplicate { epoch, seq } => {
                return Ok(ProducerAppendResult::Duplicate {
                    epoch,
                    seq,
                    next_offset: Offset::new(stream.next_read_seq, stream.next_byte_offset),
                    closed: stream.closed,
                });
            }
        }

        let pending_seq = super::validate_seq(stream.last_seq.as_deref(), seq)?;

        self.commit_messages(&mut stream, messages)?;
        if let Some(new_seq) = pending_seq {
            stream.last_seq = Some(new_seq);
        }

        if should_close {
            stream.closed = true;
        }

        stream.producers.insert(
            producer.id.clone(),
            ProducerState {
                epoch: producer.epoch,
                last_seq: producer.seq,
                updated_at: now,
            },
        );

        let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
        let closed = stream.closed;

        Ok(ProducerAppendResult::Accepted {
            epoch: producer.epoch,
            seq: producer.seq,
            next_offset,
            closed,
        })
    }

    fn create_stream_with_data(
        &self,
        name: &str,
        config: StreamConfig,
        messages: Vec<Bytes>,
        should_close: bool,
    ) -> Result<super::CreateWithDataResult> {
        let mut streams = self.streams.write().expect("streams lock poisoned");

        if let Some(stream_arc) = streams.get(name) {
            let stream = stream_arc.read().expect("stream lock poisoned");

            if super::is_stream_expired(&stream.config) {
                let stream_bytes = stream.total_bytes;
                drop(stream);
                streams.remove(name);
                self.saturating_sub_total_bytes(stream_bytes);
            } else if stream.config == config {
                let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
                let closed = stream.closed;
                return Ok(super::CreateWithDataResult {
                    status: CreateStreamResult::AlreadyExists,
                    next_offset,
                    closed,
                });
            } else {
                return Err(Error::ConfigMismatch);
            }
        }

        let mut entry = StreamEntry::new(config);

        if !messages.is_empty() {
            self.commit_messages(&mut entry, messages)?;
        }

        if should_close {
            entry.closed = true;
        }

        let next_offset = Offset::new(entry.next_read_seq, entry.next_byte_offset);
        let closed = entry.closed;

        streams.insert(name.to_string(), Arc::new(RwLock::new(entry)));

        Ok(super::CreateWithDataResult {
            status: CreateStreamResult::Created,
            next_offset,
            closed,
        })
    }

    fn exists(&self, name: &str) -> bool {
        let streams = self.streams.read().expect("streams lock poisoned");
        if let Some(stream_arc) = streams.get(name) {
            let stream = stream_arc.read().expect("stream lock poisoned");
            !super::is_stream_expired(&stream.config)
        } else {
            false
        }
    }

    fn subscribe(&self, name: &str) -> Option<broadcast::Receiver<()>> {
        let stream_arc = self.get_stream(name)?;
        let stream = stream_arc.read().expect("stream lock poisoned");

        if super::is_stream_expired(&stream.config) {
            return None;
        }

        Some(stream.notify.subscribe())
    }

    fn cleanup_expired_streams(&self) -> usize {
        let mut streams = self.streams.write().expect("streams lock poisoned");
        let mut expired = Vec::new();

        for (name, stream_arc) in streams.iter() {
            let stream = stream_arc.read().expect("stream lock poisoned");
            if super::is_stream_expired(&stream.config) {
                expired.push((name.clone(), stream.total_bytes));
            }
        }

        for (name, bytes) in &expired {
            streams.remove(name);
            self.saturating_sub_total_bytes(*bytes);
        }

        expired.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    fn test_storage() -> InMemoryStorage {
        InMemoryStorage::new(1024 * 1024, 100 * 1024)
    }

    fn producer(id: &str, epoch: u64, seq: u64) -> ProducerHeaders {
        ProducerHeaders {
            id: id.to_string(),
            epoch,
            seq,
        }
    }

    #[test]
    fn test_concurrent_producer_appends() {
        let storage = Arc::new(test_storage());
        let config = StreamConfig::new("text/plain".to_string());
        storage.create_stream("test", config).unwrap();

        let num_producers = 4;
        let seqs_per_producer = 50;

        let handles: Vec<_> = (0..num_producers)
            .map(|p| {
                let storage = Arc::clone(&storage);
                thread::spawn(move || {
                    let prod_id = format!("p{p}");
                    for seq in 0..seqs_per_producer {
                        let result = storage.append_with_producer(
                            "test",
                            vec![Bytes::from(format!("{prod_id}-{seq}"))],
                            "text/plain",
                            &producer(&prod_id, 0, seq),
                            false,
                            None,
                        );
                        assert!(
                            result.is_ok(),
                            "Producer {prod_id} seq {seq} failed: {result:?}"
                        );
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("thread panicked");
        }

        let metadata = storage.head("test").unwrap();
        assert_eq!(metadata.message_count, num_producers * seqs_per_producer);
    }
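
    // Exercises the byte accounting: an append that exceeds the per-stream
    // limit should be rejected with the global counter rolled back, and
    // deleting a stream should release its bytes.
    #[test]
    fn test_stream_size_limit_rollback_and_delete() {
        let storage = InMemoryStorage::new(1024 * 1024, 16);
        let config = StreamConfig::new("text/plain".to_string());
        storage.create_stream("test", config).unwrap();

        // 32 bytes is larger than the 16-byte per-stream limit.
        let result = storage.append("test", Bytes::from(vec![0u8; 32]), "text/plain");
        assert!(matches!(result, Err(Error::StreamSizeLimitExceeded)));
        assert_eq!(storage.total_bytes(), 0);

        // A small append fits and is counted.
        storage
            .append("test", Bytes::from(vec![0u8; 8]), "text/plain")
            .unwrap();
        assert_eq!(storage.total_bytes(), 8);

        // Deleting the stream returns its bytes to the global budget.
        storage.delete("test").unwrap();
        assert_eq!(storage.total_bytes(), 0);
    }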
}