1use super::{
2 CreateStreamResult, Message, NOTIFY_CHANNEL_CAPACITY, ProducerAppendResult, ProducerCheck,
3 ProducerState, ReadResult, Storage, StreamConfig, StreamMetadata,
4};
5use crate::protocol::error::{Error, Result};
6use crate::protocol::offset::Offset;
7use crate::protocol::producer::ProducerHeaders;
8use bytes::Bytes;
9use chrono::Utc;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::{Arc, RwLock};
13use tokio::sync::broadcast;
14
// Pre-allocated capacity for a new stream's message vector.
const INITIAL_MESSAGES_CAPACITY: usize = 256;
// Pre-allocated capacity for a new stream's per-producer state map.
const INITIAL_PRODUCERS_CAPACITY: usize = 8;
17
/// In-memory state of a single stream. Always wrapped in
/// `Arc<RwLock<StreamEntry>>`; all fields are guarded by that per-stream lock.
struct StreamEntry {
    /// Configuration captured when the stream was created.
    config: StreamConfig,
    /// Messages appended so far, in offset order.
    messages: Vec<Message>,
    /// True once the stream was closed to further appends.
    closed: bool,
    /// Sequence number the next appended message will receive.
    next_read_seq: u64,
    /// Byte offset at which the next appended message will start.
    next_byte_offset: u64,
    /// Total payload bytes currently held by this stream.
    total_bytes: u64,
    /// Creation timestamp, reported via `head()` metadata.
    created_at: chrono::DateTime<Utc>,
    /// Per-producer dedup state, keyed by producer id.
    producers: HashMap<String, ProducerState>,
    /// Broadcast sender used to wake subscribers on new data or close.
    notify: broadcast::Sender<()>,
    /// Most recently accepted `seq` token; checked by `super::validate_seq`.
    last_seq: Option<String>,
}
34
35impl StreamEntry {
36 fn new(config: StreamConfig) -> Self {
37 let (notify, _) = broadcast::channel(NOTIFY_CHANNEL_CAPACITY);
40 Self {
41 config,
42 messages: Vec::with_capacity(INITIAL_MESSAGES_CAPACITY),
43 closed: false,
44 next_read_seq: 0,
45 next_byte_offset: 0,
46 total_bytes: 0,
47 created_at: Utc::now(),
48 producers: HashMap::with_capacity(INITIAL_PRODUCERS_CAPACITY),
49 notify,
50 last_seq: None,
51 }
52 }
53}
54
/// Volatile, process-local [`Storage`] implementation backed by a locked map
/// of independently locked streams.
pub struct InMemoryStorage {
    /// Stream registry: the outer lock guards the map itself, each entry
    /// carries its own lock so appends to different streams don't contend.
    streams: RwLock<HashMap<String, Arc<RwLock<StreamEntry>>>>,
    /// Running total of payload bytes across all streams.
    total_bytes: AtomicU64,
    /// Global cap on `total_bytes`; exceeding it yields `MemoryLimitExceeded`.
    max_total_bytes: u64,
    /// Per-stream cap; exceeding it yields `StreamSizeLimitExceeded`.
    max_stream_bytes: u64,
}
73
74impl InMemoryStorage {
75 #[must_use]
77 pub fn new(max_total_bytes: u64, max_stream_bytes: u64) -> Self {
78 Self {
79 streams: RwLock::new(HashMap::new()),
80 total_bytes: AtomicU64::new(0),
81 max_total_bytes,
82 max_stream_bytes,
83 }
84 }
85
    /// Current global payload byte count across all streams.
    #[must_use]
    pub fn total_bytes(&self) -> u64 {
        self.total_bytes.load(Ordering::Acquire)
    }
95
    /// Decrements the global byte counter, clamping at zero so a release can
    /// never underflow the `u64`.
    fn saturating_sub_total_bytes(&self, bytes: u64) {
        self.total_bytes
            // The closure always returns Some, so fetch_update cannot fail;
            // .ok() just discards the Result either way.
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                Some(current.saturating_sub(bytes))
            })
            .ok();
    }
103
104 fn get_stream(&self, name: &str) -> Option<Arc<RwLock<StreamEntry>>> {
105 let streams = self.streams.read().expect("streams lock poisoned");
106 streams.get(name).map(Arc::clone)
107 }
108
109 fn commit_messages(&self, stream: &mut StreamEntry, messages: Vec<Bytes>) -> Result<()> {
114 if messages.is_empty() {
115 return Ok(());
116 }
117
118 let mut total_batch_bytes = 0u64;
119 let mut message_sizes = Vec::with_capacity(messages.len());
120 for data in &messages {
121 let byte_len = u64::try_from(data.len()).unwrap_or(u64::MAX);
122 message_sizes.push(byte_len);
123 total_batch_bytes += byte_len;
124 }
125
126 if self
128 .total_bytes
129 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
130 current
131 .checked_add(total_batch_bytes)
132 .filter(|next| *next <= self.max_total_bytes)
133 })
134 .is_err()
135 {
136 return Err(Error::MemoryLimitExceeded);
137 }
138 if stream.total_bytes + total_batch_bytes > self.max_stream_bytes {
139 self.saturating_sub_total_bytes(total_batch_bytes);
140 return Err(Error::StreamSizeLimitExceeded);
141 }
142
143 for (data, byte_len) in messages.into_iter().zip(message_sizes) {
144 let offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
145 stream.next_read_seq += 1;
146 stream.next_byte_offset += byte_len;
147 stream.total_bytes += byte_len;
148 let message = Message::new(offset, data);
149 stream.messages.push(message);
150 }
151
152 let _ = stream.notify.send(());
155
156 Ok(())
157 }
158}
159
160impl Storage for InMemoryStorage {
161 fn create_stream(&self, name: &str, config: StreamConfig) -> Result<CreateStreamResult> {
162 let mut streams = self.streams.write().expect("streams lock poisoned");
163
164 if let Some(stream_arc) = streams.get(name) {
165 let stream = stream_arc.read().expect("stream lock poisoned");
166
167 if super::is_stream_expired(&stream.config) {
168 let stream_bytes = stream.total_bytes;
169 drop(stream);
170 streams.remove(name);
171
172 self.total_bytes
173 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
174 Some(current.saturating_sub(stream_bytes))
175 })
176 .ok();
177 } else {
178 if stream.config == config {
179 return Ok(CreateStreamResult::AlreadyExists);
180 }
181 return Err(Error::ConfigMismatch);
182 }
183 }
184
185 let entry = StreamEntry::new(config);
186 streams.insert(name.to_string(), Arc::new(RwLock::new(entry)));
187
188 Ok(CreateStreamResult::Created)
189 }
190
191 fn append(&self, name: &str, data: Bytes, content_type: &str) -> Result<Offset> {
192 let stream_arc = self
193 .get_stream(name)
194 .ok_or_else(|| Error::NotFound(name.to_string()))?;
195
196 let mut stream = stream_arc.write().expect("stream lock poisoned");
197
198 if super::is_stream_expired(&stream.config) {
199 return Err(Error::StreamExpired);
200 }
201
202 if stream.closed {
203 return Err(Error::StreamClosed);
204 }
205
206 super::validate_content_type(&stream.config.content_type, content_type)?;
207
208 let byte_len = u64::try_from(data.len()).unwrap_or(u64::MAX);
209
210 if self
211 .total_bytes
212 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
213 current
214 .checked_add(byte_len)
215 .filter(|next| *next <= self.max_total_bytes)
216 })
217 .is_err()
218 {
219 return Err(Error::MemoryLimitExceeded);
220 }
221
222 if stream.total_bytes + byte_len > self.max_stream_bytes {
223 self.saturating_sub_total_bytes(byte_len);
224 return Err(Error::StreamSizeLimitExceeded);
225 }
226
227 let offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
228
229 stream.next_read_seq += 1;
230 stream.next_byte_offset += byte_len;
231 stream.total_bytes += byte_len;
232
233 let message = Message::new(offset.clone(), data);
234 stream.messages.push(message);
235
236 Ok(offset)
237 }
238
239 fn batch_append(
240 &self,
241 name: &str,
242 messages: Vec<Bytes>,
243 content_type: &str,
244 seq: Option<&str>,
245 ) -> Result<Offset> {
246 if messages.is_empty() {
247 return Err(Error::InvalidHeader {
248 header: "Content-Length".to_string(),
249 reason: "batch cannot be empty".to_string(),
250 });
251 }
252
253 let stream_arc = self
254 .get_stream(name)
255 .ok_or_else(|| Error::NotFound(name.to_string()))?;
256
257 let mut stream = stream_arc.write().expect("stream lock poisoned");
258
259 if super::is_stream_expired(&stream.config) {
260 return Err(Error::StreamExpired);
261 }
262
263 if stream.closed {
264 return Err(Error::StreamClosed);
265 }
266
267 super::validate_content_type(&stream.config.content_type, content_type)?;
268
269 let pending_seq = super::validate_seq(stream.last_seq.as_deref(), seq)?;
270
271 self.commit_messages(&mut stream, messages)?;
272 if let Some(new_seq) = pending_seq {
273 stream.last_seq = Some(new_seq);
274 }
275
276 Ok(Offset::new(stream.next_read_seq, stream.next_byte_offset))
277 }
278
279 fn read(&self, name: &str, from_offset: &Offset) -> Result<ReadResult> {
280 let stream_arc = self
281 .get_stream(name)
282 .ok_or_else(|| Error::NotFound(name.to_string()))?;
283
284 let stream = stream_arc.read().expect("stream lock poisoned");
285
286 if super::is_stream_expired(&stream.config) {
287 return Err(Error::StreamExpired);
288 }
289
290 if from_offset.is_now() {
291 let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
292 return Ok(ReadResult {
293 messages: Vec::new(),
294 next_offset,
295 at_tail: true,
296 closed: stream.closed,
297 });
298 }
299
300 let start_idx = if from_offset.is_start() {
301 0
302 } else {
303 match stream
304 .messages
305 .binary_search_by(|m| m.offset.cmp(from_offset))
306 {
307 Ok(idx) | Err(idx) => idx,
308 }
309 };
310
311 let messages: Vec<Bytes> = stream.messages[start_idx..]
312 .iter()
313 .map(|m| m.data.clone())
314 .collect();
315
316 let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
317
318 let at_tail = start_idx + messages.len() >= stream.messages.len();
319
320 Ok(ReadResult {
321 messages,
322 next_offset,
323 at_tail,
324 closed: stream.closed,
325 })
326 }
327
328 fn delete(&self, name: &str) -> Result<()> {
329 let mut streams = self.streams.write().expect("streams lock poisoned");
330
331 if let Some(stream_arc) = streams.remove(name) {
332 let stream = stream_arc.read().expect("stream lock poisoned");
333 self.saturating_sub_total_bytes(stream.total_bytes);
334 Ok(())
335 } else {
336 Err(Error::NotFound(name.to_string()))
337 }
338 }
339
340 fn head(&self, name: &str) -> Result<StreamMetadata> {
341 let stream_arc = self
342 .get_stream(name)
343 .ok_or_else(|| Error::NotFound(name.to_string()))?;
344
345 let stream = stream_arc.read().expect("stream lock poisoned");
346
347 if super::is_stream_expired(&stream.config) {
348 return Err(Error::StreamExpired);
349 }
350
351 Ok(StreamMetadata {
352 config: stream.config.clone(),
353 next_offset: Offset::new(stream.next_read_seq, stream.next_byte_offset),
354 closed: stream.closed,
355 total_bytes: stream.total_bytes,
356 message_count: u64::try_from(stream.messages.len()).unwrap_or(u64::MAX),
357 created_at: stream.created_at,
358 })
359 }
360
361 fn close_stream(&self, name: &str) -> Result<()> {
362 let stream_arc = self
363 .get_stream(name)
364 .ok_or_else(|| Error::NotFound(name.to_string()))?;
365
366 let mut stream = stream_arc.write().expect("stream lock poisoned");
367
368 if super::is_stream_expired(&stream.config) {
369 return Err(Error::StreamExpired);
370 }
371
372 stream.closed = true;
373
374 let _ = stream.notify.send(());
375
376 Ok(())
377 }
378
    /// Idempotent append for an identified producer (id/epoch/seq), with an
    /// optional close of the stream afterwards.
    ///
    /// Duplicate submissions are acknowledged without re-appending; the exact
    /// acceptance rules live in `super::check_producer`. Order of operations
    /// matters here: stale-producer cleanup, then dedup check, then seq
    /// validation, then commit, then producer-state recording.
    fn append_with_producer(
        &self,
        name: &str,
        messages: Vec<Bytes>,
        content_type: &str,
        producer: &ProducerHeaders,
        should_close: bool,
        seq: Option<&str>,
    ) -> Result<ProducerAppendResult> {
        let stream_arc = self
            .get_stream(name)
            .ok_or_else(|| Error::NotFound(name.to_string()))?;

        let mut stream = stream_arc.write().expect("stream lock poisoned");

        if super::is_stream_expired(&stream.config) {
            return Err(Error::StreamExpired);
        }

        // Evict outdated producer state (criteria defined by the helper).
        super::cleanup_stale_producers(&mut stream.producers);

        // Content type is only validated when there is payload to append —
        // an empty batch (presumably a close-only request) skips the check.
        if !messages.is_empty() {
            super::validate_content_type(&stream.config.content_type, content_type)?;
        }

        let now = Utc::now();

        match super::check_producer(stream.producers.get(&producer.id), producer, stream.closed)? {
            ProducerCheck::Accept => {}
            ProducerCheck::Duplicate { epoch, seq } => {
                // Already applied: acknowledge with the current tail offset
                // instead of appending again.
                return Ok(ProducerAppendResult::Duplicate {
                    epoch,
                    seq,
                    next_offset: Offset::new(stream.next_read_seq, stream.next_byte_offset),
                    closed: stream.closed,
                });
            }
        }

        let pending_seq = super::validate_seq(stream.last_seq.as_deref(), seq)?;

        self.commit_messages(&mut stream, messages)?;
        // Record the batch seq token only after a successful commit; on a
        // limit error above, `?` returns before any state is updated.
        if let Some(new_seq) = pending_seq {
            stream.last_seq = Some(new_seq);
        }

        if should_close {
            stream.closed = true;
        }

        // Remember this producer's latest (epoch, seq) for future dedup.
        stream.producers.insert(
            producer.id.clone(),
            ProducerState {
                epoch: producer.epoch,
                last_seq: producer.seq,
                updated_at: now,
            },
        );

        let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
        let closed = stream.closed;

        Ok(ProducerAppendResult::Accepted {
            epoch: producer.epoch,
            seq: producer.seq,
            next_offset,
            closed,
        })
    }
448
    /// Creates a stream, seeds it with `messages`, and optionally closes it —
    /// all under the registry write lock, so concurrent callers see either
    /// the fully initialized stream or none at all.
    ///
    /// A live stream with an identical config yields `AlreadyExists` plus its
    /// current tail state; a differing config is a `ConfigMismatch`. An
    /// expired entry with the same name is evicted and replaced.
    fn create_stream_with_data(
        &self,
        name: &str,
        config: StreamConfig,
        messages: Vec<Bytes>,
        should_close: bool,
    ) -> Result<super::CreateWithDataResult> {
        let mut streams = self.streams.write().expect("streams lock poisoned");

        if let Some(stream_arc) = streams.get(name) {
            let stream = stream_arc.read().expect("stream lock poisoned");

            if super::is_stream_expired(&stream.config) {
                // Evict the expired entry and release its byte accounting,
                // then fall through to create a fresh one below.
                let stream_bytes = stream.total_bytes;
                drop(stream);
                streams.remove(name);
                self.saturating_sub_total_bytes(stream_bytes);
            } else if stream.config == config {
                let next_offset = Offset::new(stream.next_read_seq, stream.next_byte_offset);
                let closed = stream.closed;
                return Ok(super::CreateWithDataResult {
                    status: CreateStreamResult::AlreadyExists,
                    next_offset,
                    closed,
                });
            } else {
                return Err(Error::ConfigMismatch);
            }
        }

        let mut entry = StreamEntry::new(config);

        // Commit into the not-yet-registered entry: if a limit error occurs,
        // `?` returns before the insert, the entry is dropped, and
        // commit_messages has already rolled back the global counter.
        if !messages.is_empty() {
            self.commit_messages(&mut entry, messages)?;
        }

        if should_close {
            entry.closed = true;
        }

        let next_offset = Offset::new(entry.next_read_seq, entry.next_byte_offset);
        let closed = entry.closed;

        streams.insert(name.to_string(), Arc::new(RwLock::new(entry)));

        Ok(super::CreateWithDataResult {
            status: CreateStreamResult::Created,
            next_offset,
            closed,
        })
    }
500
501 fn exists(&self, name: &str) -> bool {
502 let streams = self.streams.read().expect("streams lock poisoned");
503 if let Some(stream_arc) = streams.get(name) {
504 let stream = stream_arc.read().expect("stream lock poisoned");
505 !super::is_stream_expired(&stream.config)
506 } else {
507 false
508 }
509 }
510
511 fn subscribe(&self, name: &str) -> Option<broadcast::Receiver<()>> {
512 let stream_arc = self.get_stream(name)?;
513 let stream = stream_arc.read().expect("stream lock poisoned");
514
515 if super::is_stream_expired(&stream.config) {
516 return None;
517 }
518
519 Some(stream.notify.subscribe())
520 }
521}
522
523#[cfg(test)]
524mod tests {
525 use super::*;
526 use std::sync::Arc;
527 use std::thread;
528
    // 1 MiB global cap, 100 KiB per-stream cap — roomy enough for all tests.
    fn test_storage() -> InMemoryStorage {
        InMemoryStorage::new(1024 * 1024, 100 * 1024)
    }
532
    // Convenience constructor for producer headers used in tests.
    fn producer(id: &str, epoch: u64, seq: u64) -> ProducerHeaders {
        ProducerHeaders {
            id: id.to_string(),
            epoch,
            seq,
        }
    }
540
    /// Spawns several producer threads appending concurrently and checks that
    /// every append lands exactly once (no losses, no duplicates).
    #[test]
    fn test_concurrent_producer_appends() {
        let storage = Arc::new(test_storage());
        let config = StreamConfig::new("text/plain".to_string());
        storage.create_stream("test", config).unwrap();

        let num_producers = 4;
        let seqs_per_producer = 50;

        // Each thread acts as a distinct producer with its own seq counter.
        let handles: Vec<_> = (0..num_producers)
            .map(|p| {
                let storage = Arc::clone(&storage);
                thread::spawn(move || {
                    let prod_id = format!("p{p}");
                    for seq in 0..seqs_per_producer {
                        let result = storage.append_with_producer(
                            "test",
                            vec![Bytes::from(format!("{prod_id}-{seq}"))],
                            "text/plain",
                            &producer(&prod_id, 0, seq),
                            false,
                            None,
                        );
                        assert!(
                            result.is_ok(),
                            "Producer {prod_id} seq {seq} failed: {result:?}"
                        );
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("thread panicked");
        }

        // Final count must equal total submissions across all producers.
        let metadata = storage.head("test").unwrap();
        assert_eq!(metadata.message_count, num_producers * seqs_per_producer);
    }
580}