1use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::task::{Context as TaskContext, Poll};
9
10use bytes::Bytes;
11use connectrpc::{Chain, ConnectError, ConnectRpcService, Context, Limits};
12use exoware_proto::common::KvEntry;
13use exoware_proto::compact::{
14 PruneResponse, Service as CompactApi, ServiceServer as CompactServiceServer,
15};
16use exoware_proto::google::rpc::{ErrorInfo, RetryInfo};
17use exoware_proto::ingest::{
18 PutResponse as ProtoPutResponse, Service as IngestApi, ServiceServer as IngestServiceServer,
19};
20use exoware_proto::query::{
21 Detail, GetManyEntry, GetManyFrame, GetResponse, RangeFrame, ReduceResponse,
22 Service as QueryApi, ServiceServer as QueryServiceServer,
23};
24use exoware_proto::store::stream::v1::{
25 GetRequestView, GetResponse as StreamGetResponse, Service as StreamApi,
26 ServiceServer as StreamServiceServer, SubscribeRequestView, SubscribeResponse,
27};
28use exoware_proto::stream_filter::{BytesFilter, StreamFilter};
29use exoware_proto::{
30 connect_compression_registry, encode_query_detail_header_value,
31 parse_range_traversal_direction, to_domain_reduce_request_from_view,
32 to_proto_optional_reduced_value, to_proto_reduced_value, with_error_info_detail,
33 with_query_detail, with_retry_info_detail, RangeTraversalDirection,
34 QUERY_DETAIL_RESPONSE_HEADER,
35};
36use exoware_sdk as exoware_proto;
37use exoware_sdk::keys::Key;
38use exoware_sdk::match_key::MatchKey;
39use exoware_sdk::store::common::v1::bytes_filter::KindView as ProtoBytesFilterKindView;
40use futures::{stream as stream_util, Stream};
41use http::header::HeaderValue;
42use http::HeaderName;
43use tokio::sync::futures::OwnedNotified;
44use tokio::sync::Notify;
45
46use crate::reduce::reduce_over_rows;
47use crate::stream::StreamHub;
48use crate::validate;
49use crate::StoreEngine;
50
// Upper bound applied to both the request body and per-message size for every
// Connect service in this stack (256 MiB); see `connect_limits`.
const MAX_CONNECTRPC_BODY_BYTES: usize = 256 * 1024 * 1024;
52
/// Total wire size of a slice of key/value rows, in bytes (key bytes plus
/// value bytes, summed over all rows).
///
/// Used to populate the `read_bytes` statistic reported back to query clients.
fn read_bytes_for_kv_rows<K: AsRef<[u8]>, V: AsRef<[u8]>>(entries: &[(K, V)]) -> u64 {
    let mut total = 0u64;
    for (key, value) in entries {
        total += key.as_ref().len() as u64;
        total += value.as_ref().len() as u64;
    }
    total
}
60
/// Builds the standard query `read_stats` map: a single `read_bytes` entry
/// equal to the total key + value bytes across `entries`.
fn read_stats_read_bytes<K: AsRef<[u8]>, V: AsRef<[u8]>>(
    entries: &[(K, V)],
) -> HashMap<String, u64> {
    let total: u64 = entries
        .iter()
        .map(|(k, v)| k.as_ref().len() as u64 + v.as_ref().len() as u64)
        .sum();
    let mut stats = HashMap::with_capacity(1);
    stats.insert("read_bytes".to_string(), total);
    stats
}
68
/// Shared per-process state handed to every Connect service implementation.
#[derive(Clone)]
pub struct AppState {
    // Storage backend all RPC handlers read from and write to.
    pub engine: Arc<dyn StoreEngine>,
    // Readiness flag; ingest rejects writes while this is false.
    pub ready: Arc<AtomicBool>,
    // Fan-out hub that notifies live subscribers of newly published batches.
    pub stream: Arc<StreamHub>,
}
80
81impl AppState {
82 pub fn new(engine: Arc<dyn StoreEngine>) -> Self {
83 let current_sequence = engine.current_sequence();
84 Self {
85 engine,
86 ready: Arc::new(AtomicBool::new(true)),
87 stream: Arc::new(StreamHub::new(current_sequence)),
88 }
89 }
90}
91
/// Connect handler implementing the ingest (write) service.
#[derive(Clone)]
pub struct IngestConnect {
    // Shared application state (engine, readiness flag, stream hub).
    state: AppState,
}
96
97impl IngestConnect {
98 pub fn new(state: AppState) -> Self {
99 Self { state }
100 }
101}
102
103impl IngestApi for IngestConnect {
104 async fn put(
105 &self,
106 _ctx: Context,
107 request: buffa::view::OwnedView<exoware_proto::store::ingest::v1::PutRequestView<'static>>,
108 ) -> Result<(ProtoPutResponse, Context), ConnectError> {
109 if !self.state.ready.load(Ordering::SeqCst) {
110 return Err(with_error_info_detail(
111 ConnectError::unavailable("ingest is not ready"),
112 ErrorInfo {
113 reason: "WORKER_NOT_READY".to_string(),
114 domain: "store.ingest".to_string(),
115 ..Default::default()
116 },
117 ));
118 }
119
120 validate::validate_put_request(&request)?;
121
122 let wire = request.bytes();
123 let mut batch = Vec::new();
124 for kv in request.kvs.iter() {
125 let key: Key = wire.slice_ref(kv.key);
126 let value = wire.slice_ref(kv.value);
127 batch.push((key, value));
128 }
129
130 let seq = self
131 .state
132 .engine
133 .put_batch(&batch)
134 .map_err(ConnectError::internal)?;
135
136 self.state.stream.publish(seq);
140
141 Ok((
142 ProtoPutResponse {
143 sequence_number: seq,
144 ..Default::default()
145 },
146 Context::default(),
147 ))
148 }
149}
150
/// Connect handler implementing the query (read) service.
#[derive(Clone)]
pub struct QueryConnect {
    // Shared application state (engine, readiness flag, stream hub).
    state: AppState,
}
155
156impl QueryConnect {
157 pub fn new(state: AppState) -> Self {
158 Self { state }
159 }
160
161 fn current_sequence_number(&self) -> u64 {
162 self.state.engine.current_sequence()
163 }
164
165 fn error_detail(&self) -> Detail {
166 Detail {
167 sequence_number: self.current_sequence_number(),
168 read_stats: HashMap::new(),
169 ..Default::default()
170 }
171 }
172
173 fn consistency_not_ready_error(&self, required: u64, current: u64) -> ConnectError {
174 let err = with_retry_info_detail(
175 ConnectError::aborted("minimum consistency token is not yet visible"),
176 RetryInfo {
177 retry_delay: Some(buffa_types::google::protobuf::Duration::from(
178 std::time::Duration::from_secs(1),
179 ))
180 .into(),
181 ..Default::default()
182 },
183 );
184 with_query_detail(
185 with_error_info_detail(
186 err,
187 ErrorInfo {
188 reason: "CONSISTENCY_NOT_READY".to_string(),
189 domain: "store.query".to_string(),
190 metadata: [
191 ("required_sequence_number".to_string(), required.to_string()),
192 ("current_sequence_number".to_string(), current.to_string()),
193 ]
194 .into_iter()
195 .collect(),
196 ..Default::default()
197 },
198 ),
199 self.error_detail(),
200 )
201 }
202
203 fn ensure_min_sequence_number(&self, required: Option<u64>) -> Result<u64, ConnectError> {
204 let current = self.current_sequence_number();
205 if let Some(required) = required {
206 if current < required {
207 return Err(self.consistency_not_ready_error(required, current));
208 }
209 }
210 Ok(current)
211 }
212
213 fn apply_query_detail_header(ctx: &mut Context, detail: &Detail) {
214 if let Ok(value) = HeaderValue::from_str(&encode_query_detail_header_value(detail)) {
215 if let Ok(name) = HeaderName::from_bytes(QUERY_DETAIL_RESPONSE_HEADER.as_bytes()) {
216 ctx.response_headers.insert(name, value);
217 }
218 }
219 }
220
221 fn apply_query_detail_trailer(ctx: &mut Context, detail: &Detail) {
222 if let Ok(value) = HeaderValue::from_str(&encode_query_detail_header_value(detail)) {
223 if let Ok(name) = HeaderName::from_bytes(QUERY_DETAIL_RESPONSE_HEADER.as_bytes()) {
224 ctx.set_trailer(name, value);
225 }
226 }
227 }
228}
229
impl QueryApi for QueryConnect {
    /// Unary point lookup. Enforces the caller's minimum consistency token,
    /// reads a single key, and reports `read_bytes` via the query-detail
    /// response header.
    async fn get(
        &self,
        mut ctx: Context,
        request: buffa::view::OwnedView<exoware_proto::store::query::v1::GetRequestView<'static>>,
    ) -> Result<(GetResponse, Context), ConnectError> {
        validate::validate_get_request(&request)?;
        let token = self.ensure_min_sequence_number(request.min_sequence_number)?;
        // Zero-copy slice of the key out of the request's backing buffer.
        let wire = request.bytes();
        let key: Key = wire.slice_ref(request.key);
        let value = self
            .state
            .engine
            .get(key.as_ref())
            .map_err(ConnectError::internal)?;
        // read_bytes counts the key plus the value (0 when the key is absent).
        let read_bytes =
            key.as_ref().len() as u64 + value.as_ref().map_or(0u64, |v| v.len() as u64);
        let detail = Detail {
            sequence_number: token,
            read_stats: [("read_bytes".to_string(), read_bytes)]
                .into_iter()
                .collect(),
            ..Default::default()
        };
        // Unary RPC: the detail travels in a response header.
        Self::apply_query_detail_header(&mut ctx, &detail);
        Ok((
            GetResponse {
                value,
                ..Default::default()
            },
            ctx,
        ))
    }

    /// Multi-key lookup streamed back in frames of at most `batch_size`
    /// entries. The query detail (sequence number, read stats) is set as a
    /// trailer since headers are already sent when the stream starts.
    async fn get_many(
        &self,
        mut ctx: Context,
        request: buffa::view::OwnedView<
            exoware_proto::store::query::v1::GetManyRequestView<'static>,
        >,
    ) -> Result<
        (
            Pin<Box<dyn Stream<Item = Result<GetManyFrame, ConnectError>> + Send>>,
            Context,
        ),
        ConnectError,
    > {
        validate::validate_get_many_request(&request)?;
        let sequence_number = self.ensure_min_sequence_number(request.min_sequence_number)?;

        let key_refs: Vec<&[u8]> = request.keys.iter().copied().collect();
        let entries = self
            .state
            .engine
            .get_many(&key_refs)
            .map_err(ConnectError::internal)?;
        // Missing keys contribute only their key length to read_bytes.
        let read_bytes: u64 = entries
            .iter()
            .map(|(k, v)| k.len() as u64 + v.as_ref().map_or(0u64, |v| v.len() as u64))
            .sum();
        let detail = Detail {
            sequence_number,
            read_stats: [("read_bytes".to_string(), read_bytes)]
                .into_iter()
                .collect(),
            ..Default::default()
        };
        Self::apply_query_detail_trailer(&mut ctx, &detail);

        // Chunk results into frames. Assumes batch_size >= 1 (presumably
        // enforced by validate_get_many_request) — TODO confirm.
        let batch_size = request.batch_size as usize;
        let mut frames = Vec::new();
        let mut chunk = Vec::new();
        for (key, value) in entries {
            chunk.push(GetManyEntry {
                key,
                value,
                ..Default::default()
            });
            if chunk.len() >= batch_size {
                frames.push(Ok(GetManyFrame {
                    results: std::mem::take(&mut chunk),
                    ..Default::default()
                }));
            }
        }
        // Flush the final partial frame, if any.
        if !chunk.is_empty() {
            frames.push(Ok(GetManyFrame {
                results: chunk,
                ..Default::default()
            }));
        }

        Ok((Box::pin(stream_util::iter(frames)), ctx))
    }

    /// Range scan streamed back in frames of at most `batch_size` entries,
    /// in forward or reverse key order. Detail is reported via a trailer.
    async fn range(
        &self,
        mut ctx: Context,
        request: buffa::view::OwnedView<exoware_proto::store::query::v1::RangeRequestView<'static>>,
    ) -> Result<
        (
            Pin<Box<dyn Stream<Item = Result<RangeFrame, ConnectError>> + Send>>,
            Context,
        ),
        ConnectError,
    > {
        validate::validate_range_request(&request)?;
        let sequence_number = self.ensure_min_sequence_number(request.min_sequence_number)?;
        let wire = request.bytes();
        let start_key: Key = wire.slice_ref(request.start);
        let end_key: Key = wire.slice_ref(request.end);
        // No explicit limit means "scan everything in range".
        let limit = request.limit.map(|v| v as usize).unwrap_or(usize::MAX);
        let batch_size = request.batch_size as usize;
        let forward = match parse_range_traversal_direction(request.mode) {
            Ok(RangeTraversalDirection::Forward) => true,
            Ok(RangeTraversalDirection::Reverse) => false,
            Err(e) => return Err(ConnectError::internal(format!("traversal mode: {e:?}"))),
        };

        let entries = self
            .state
            .engine
            .range_scan(start_key.as_ref(), end_key.as_ref(), limit, forward)
            .map_err(ConnectError::internal)?;
        let detail = Detail {
            sequence_number,
            read_stats: read_stats_read_bytes(&entries),
            ..Default::default()
        };
        Self::apply_query_detail_trailer(&mut ctx, &detail);

        // Chunk into frames, converting rows to KvEntry at frame-build time.
        let mut frames = Vec::new();
        let mut chunk = Vec::new();
        for (key, value) in entries {
            chunk.push((key, value));
            if chunk.len() >= batch_size {
                frames.push(Ok(RangeFrame {
                    results: chunk
                        .drain(..)
                        .map(|(k, v)| KvEntry {
                            key: k.into(),
                            value: v.into(),
                            ..Default::default()
                        })
                        .collect(),
                    ..Default::default()
                }));
            }
        }
        // Flush the final partial frame, if any.
        if !chunk.is_empty() {
            frames.push(Ok(RangeFrame {
                results: chunk
                    .into_iter()
                    .map(|(k, v)| KvEntry {
                        key: k.into(),
                        value: v.into(),
                        ..Default::default()
                    })
                    .collect(),
                ..Default::default()
            }));
        }

        Ok((Box::pin(stream_util::iter(frames)), ctx))
    }

    /// Unary reduction over a key range: scans all rows in `[start, end)`
    /// forward, folds them through `reduce_over_rows`, and converts the
    /// domain results/groups back to proto form. Detail goes in a header.
    async fn reduce(
        &self,
        mut ctx: Context,
        request: buffa::view::OwnedView<
            exoware_proto::store::query::v1::ReduceRequestView<'static>,
        >,
    ) -> Result<(ReduceResponse, Context), ConnectError> {
        // Validate, then enforce the caller's consistency token.
        validate::validate_reduce_request(&request)?; let token = self.ensure_min_sequence_number(request.min_sequence_number)?;
        let wire = request.bytes();
        let start_key: Key = wire.slice_ref(request.start);
        let end_key: Key = wire.slice_ref(request.end);
        let domain = to_domain_reduce_request_from_view(&request.params)
            .map_err(validate::reduce_params_error)?;

        // Reductions always scan the full range, forward, with no limit.
        let rows = self
            .state
            .engine
            .range_scan(start_key.as_ref(), end_key.as_ref(), usize::MAX, true)
            .map_err(ConnectError::internal)?;

        let response = reduce_over_rows(&rows, &domain)
            .map_err(|e: crate::RangeError| ConnectError::internal(e.to_string()))?;
        let detail = Detail {
            sequence_number: token,
            read_stats: read_stats_read_bytes(&rows),
            ..Default::default()
        };
        Self::apply_query_detail_header(&mut ctx, &detail);

        Ok((
            ReduceResponse {
                results: response
                    .results
                    .into_iter()
                    .map(|result| exoware_proto::query::RangeReduceResult {
                        value: result.value.map(to_proto_reduced_value).into(),
                        ..Default::default()
                    })
                    .collect(),
                groups: response
                    .groups
                    .into_iter()
                    .map(|group| {
                        // Presence flags are captured before the group values
                        // are consumed by the conversion below.
                        let group_values_present =
                            group.group_values.iter().map(Option::is_some).collect();
                        exoware_proto::query::RangeReduceGroup {
                            group_values: group
                                .group_values
                                .into_iter()
                                .map(to_proto_optional_reduced_value)
                                .collect(),
                            group_values_present,
                            results: group
                                .results
                                .into_iter()
                                .map(|result| exoware_proto::query::RangeReduceResult {
                                    value: result.value.map(to_proto_reduced_value).into(),
                                    ..Default::default()
                                })
                                .collect(),
                            ..Default::default()
                        }
                    })
                    .collect(),
                ..Default::default()
            },
            ctx,
        ))
    }
}
467
/// Connect handler implementing the compaction (prune) service.
#[derive(Clone)]
pub struct CompactConnect {
    // Shared application state (engine, readiness flag, stream hub).
    state: AppState,
}
472
473impl CompactConnect {
474 pub fn new(state: AppState) -> Self {
475 Self { state }
476 }
477}
478
479impl CompactApi for CompactConnect {
480 async fn prune(
481 &self,
482 ctx: Context,
483 request: buffa::view::OwnedView<
484 exoware_proto::store::compact::v1::PruneRequestView<'static>,
485 >,
486 ) -> Result<(PruneResponse, Context), ConnectError> {
487 validate::validate_prune_request(&request)?;
488 let document = exoware_proto::prune_policy_document_from_prune_request_view(&request)
489 .map_err(|e| ConnectError::invalid_argument(e.to_string()))?;
490 crate::prune::execute_prune(&self.state.engine, &document)
491 .map_err(|e| ConnectError::internal(e.to_string()))?;
492 Ok((PruneResponse::default(), ctx))
493 }
494}
495
/// Connect handler implementing the change-stream (subscribe/get) service.
#[derive(Clone)]
pub struct StreamConnect {
    // Shared application state (engine, readiness flag, stream hub).
    state: AppState,
}
500
501impl StreamConnect {
502 pub fn new(state: AppState) -> Self {
503 Self { state }
504 }
505
506 fn batch_evicted_connect_error(oldest_retained: Option<u64>) -> ConnectError {
507 let mut metadata = HashMap::new();
508 if let Some(v) = oldest_retained {
509 metadata.insert(
510 crate::stream::METADATA_OLDEST_RETAINED.to_string(),
511 v.to_string(),
512 );
513 }
514 with_error_info_detail(
515 ConnectError::out_of_range("batch has been evicted from the log"),
516 ErrorInfo {
517 reason: crate::stream::REASON_BATCH_EVICTED.to_string(),
518 domain: crate::stream::STREAM_ERROR_DOMAIN.to_string(),
519 metadata,
520 ..Default::default()
521 },
522 )
523 }
524
525 fn batch_evicted_error(&self, oldest_retained: Option<u64>) -> ConnectError {
526 Self::batch_evicted_connect_error(oldest_retained)
527 }
528
529 fn batch_not_found_error(&self) -> ConnectError {
530 with_error_info_detail(
531 ConnectError::not_found("batch not found"),
532 ErrorInfo {
533 reason: crate::stream::REASON_BATCH_NOT_FOUND.to_string(),
534 domain: crate::stream::STREAM_ERROR_DOMAIN.to_string(),
535 ..Default::default()
536 },
537 )
538 }
539}
540
541fn filtered_subscribe_response(
542 seq: u64,
543 kvs: &[(Bytes, Bytes)],
544 matchers: &crate::stream::CompiledMatchers,
545) -> Option<SubscribeResponse> {
546 let entries = crate::stream::apply_filter(matchers, kvs);
547 (!entries.is_empty()).then_some(SubscribeResponse {
548 sequence_number: seq,
549 entries,
550 ..Default::default()
551 })
552}
553
/// Cursor over the historical (replay) portion of a subscription.
struct ReplayState {
    // Next historical sequence number to read from the engine.
    next_sequence: u64,
    // Last sequence number (inclusive) that belongs to the replay phase.
    bound: u64,
    // Batch for the first replayed sequence, prefetched at subscribe time.
    first_batch: Option<Vec<(Bytes, Bytes)>>,
}
559
/// Outcome of a single replay step in `SubscriptionStream`.
enum ReplayProgress {
    // A frame passed the filter and should be yielded to the client.
    Frame(SubscribeResponse),
    // The batch was filtered out entirely; cursor advanced, poll again.
    Advanced,
    // Replay is finished (or was never active); switch to live tailing.
    Done,
}
565
/// Outcome of a single live-tail step in `SubscriptionStream`.
enum LiveProgress {
    // A frame passed the filter and should be yielded to the client.
    Frame(SubscribeResponse),
    // The batch was filtered out entirely; cursor advanced, poll again.
    Advanced,
    // Caught up with the last published sequence; wait for a notification.
    NeedWait,
}
571
/// Server-streaming body for `Subscribe`: replays historical batches first,
/// then tails live publishes, filtering every batch through the matchers.
struct SubscriptionStream {
    state: AppState,
    // Compiled subscription filter applied to every batch.
    matchers: crate::stream::CompiledMatchers,
    // Historical cursor; `None` once replay is complete (or never requested).
    replay: Option<ReplayState>,
    // Next sequence number to read in the live phase.
    next_live_sequence: u64,
    // Hub notification handle used to wait for new publishes.
    live_notify: Arc<Notify>,
    // In-flight wait on `live_notify`, created lazily while caught up.
    live_wait: Option<Pin<Box<OwnedNotified>>>,
    // Error to surface on the next poll, after which the stream terminates.
    terminal_error: Option<ConnectError>,
    // Set once the stream has yielded its final item.
    terminated: bool,
}
582
impl SubscriptionStream {
    /// Builds a subscription stream. `replay` is `Some` when the caller asked
    /// to start from a historical sequence number; `next_live_sequence` is
    /// where live tailing begins once replay is exhausted.
    fn new(
        state: AppState,
        matchers: crate::stream::CompiledMatchers,
        replay: Option<ReplayState>,
        next_live_sequence: u64,
        live_notify: Arc<Notify>,
    ) -> Self {
        Self {
            state,
            matchers,
            replay,
            next_live_sequence,
            live_notify,
            live_wait: None,
            terminal_error: None,
            terminated: false,
        }
    }

    /// Advances the replay cursor by one batch.
    ///
    /// Returns `Done` when replay is inactive, `Advanced` when the batch was
    /// fully filtered out, a `Frame` otherwise, and a batch-evicted error when
    /// the engine no longer retains the batch.
    fn next_replay_frame(&mut self) -> Result<ReplayProgress, ConnectError> {
        let Some(replay) = &mut self.replay else {
            return Ok(ReplayProgress::Done);
        };
        let seq = replay.next_sequence;
        // The first batch was prefetched at subscribe time; later batches are
        // fetched from the engine here.
        let kvs = if let Some(first_batch) = replay.first_batch.take() {
            Some(first_batch)
        } else {
            self.state
                .engine
                .get_batch(seq)
                .map_err(ConnectError::internal)?
        };
        // Advance (and possibly end) the cursor before inspecting the batch,
        // so replay makes progress even when this batch is filtered out.
        replay.next_sequence += 1;
        if replay.next_sequence > replay.bound {
            self.replay = None;
        }
        let Some(kvs) = kvs else {
            // Hole in the log: the batch was pruned. Report the oldest
            // retained sequence number so the client can resubscribe.
            let oldest = self
                .state
                .engine
                .oldest_retained_batch()
                .map_err(ConnectError::internal)?;
            return Err(StreamConnect::batch_evicted_connect_error(oldest));
        };
        Ok(
            match filtered_subscribe_response(seq, &kvs, &self.matchers) {
                Some(frame) => ReplayProgress::Frame(frame),
                None => ReplayProgress::Advanced,
            },
        )
    }

    /// Advances the live cursor by one batch, or reports `NeedWait` when the
    /// cursor has caught up with the hub's last published sequence number.
    fn next_live_frame(&mut self) -> Result<LiveProgress, ConnectError> {
        let current = self.state.stream.current_sequence();
        if self.next_live_sequence > current {
            return Ok(LiveProgress::NeedWait);
        }
        let seq = self.next_live_sequence;
        self.next_live_sequence += 1;
        let kvs = self
            .state
            .engine
            .get_batch(seq)
            .map_err(ConnectError::internal)?;
        let Some(kvs) = kvs else {
            // A published batch missing from the engine means it was pruned
            // before this subscriber could read it.
            let oldest = self
                .state
                .engine
                .oldest_retained_batch()
                .map_err(ConnectError::internal)?;
            return Err(StreamConnect::batch_evicted_connect_error(oldest));
        };
        Ok(
            match filtered_subscribe_response(seq, &kvs, &self.matchers) {
                Some(frame) => LiveProgress::Frame(frame),
                None => LiveProgress::Advanced,
            },
        )
    }
}
664
impl Stream for SubscriptionStream {
    type Item = Result<SubscribeResponse, ConnectError>;

    /// Drives the subscription: surfaces any stored terminal error (then
    /// ends), drains the replay phase, then tails live batches, parking on
    /// the hub's notify handle when fully caught up.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // A stored error is yielded exactly once; the next poll ends the
            // stream.
            if let Some(err) = self.terminal_error.take() {
                self.terminated = true;
                return Poll::Ready(Some(Err(err)));
            }
            if self.terminated {
                return Poll::Ready(None);
            }

            // Replay phase: drain historical batches before touching live.
            if self.replay.is_some() {
                match self.next_replay_frame() {
                    Ok(ReplayProgress::Frame(frame)) => return Poll::Ready(Some(Ok(frame))),
                    Ok(ReplayProgress::Advanced) => continue,
                    Ok(ReplayProgress::Done) => {}
                    Err(err) => {
                        self.terminal_error = Some(err);
                        continue;
                    }
                }
            }

            // Live phase.
            match self.next_live_frame() {
                Ok(LiveProgress::Frame(frame)) => return Poll::Ready(Some(Ok(frame))),
                Ok(LiveProgress::Advanced) => continue,
                Ok(LiveProgress::NeedWait) => {
                    // Register interest on the notify handle *before*
                    // re-checking the sequence so a publish that lands in
                    // between cannot be missed.
                    if self.live_wait.is_none() {
                        self.live_wait = Some(Box::pin(self.live_notify.clone().notified_owned()));
                    }
                    if self.next_live_sequence <= self.state.stream.current_sequence() {
                        self.live_wait = None;
                        continue;
                    }
                    match self
                        .live_wait
                        .as_mut()
                        .expect("wait future")
                        .as_mut()
                        .poll(cx)
                    {
                        Poll::Ready(()) => {
                            self.live_wait = None;
                            continue;
                        }
                        Poll::Pending => return Poll::Pending,
                    }
                }
                Err(err) => {
                    self.terminal_error = Some(err);
                    continue;
                }
            }
        }
    }
}
723
724fn domain_filter_from_subscribe_view(
725 req: &SubscribeRequestView<'_>,
726) -> Result<StreamFilter, ConnectError> {
727 let mut match_keys = Vec::with_capacity(req.match_keys.len());
728 for mk in req.match_keys.iter() {
729 let reserved_bits = u8::try_from(mk.reserved_bits).map_err(|_| {
730 ConnectError::invalid_argument(format!(
731 "match_key.reserved_bits {} does not fit in u8",
732 mk.reserved_bits
733 ))
734 })?;
735 let prefix = u16::try_from(mk.prefix).map_err(|_| {
736 ConnectError::invalid_argument(format!(
737 "match_key.prefix {} does not fit in u16",
738 mk.prefix
739 ))
740 })?;
741 match_keys.push(MatchKey {
742 reserved_bits,
743 prefix,
744 payload_regex: exoware_sdk::kv_codec::Utf8::from(mk.payload_regex),
745 });
746 }
747 let mut value_filters = Vec::with_capacity(req.value_filters.len());
748 for vf in req.value_filters.iter() {
749 value_filters.push(match vf.kind {
750 Some(ProtoBytesFilterKindView::Exact(bytes)) => BytesFilter::Exact(bytes.to_vec()),
751 Some(ProtoBytesFilterKindView::Prefix(bytes)) => BytesFilter::Prefix(bytes.to_vec()),
752 Some(ProtoBytesFilterKindView::Regex(pattern)) => {
753 BytesFilter::Regex(pattern.to_string())
754 }
755 None => {
756 return Err(ConnectError::invalid_argument(
757 "each value_filter must set exactly one of exact, prefix, or regex",
758 ))
759 }
760 });
761 }
762 Ok(StreamFilter {
763 match_keys,
764 value_filters,
765 })
766}
767
impl StreamApi for StreamConnect {
    /// Opens a change-stream subscription. When `since_sequence_number` is
    /// given and falls inside the retained window (`0 < s <= replay_bound`),
    /// historical batches are replayed first; otherwise the stream starts at
    /// the first batch after `replay_bound`.
    async fn subscribe(
        &self,
        ctx: Context,
        request: buffa::view::OwnedView<SubscribeRequestView<'static>>,
    ) -> Result<
        (
            Pin<Box<dyn Stream<Item = Result<SubscribeResponse, ConnectError>> + Send>>,
            Context,
        ),
        ConnectError,
    > {
        let filter = domain_filter_from_subscribe_view(&request)?;
        let since = request.since_sequence_number;

        // Registering with the hub fixes the replay/live boundary and yields
        // the notify handle for live publishes.
        let (matchers, replay_bound, live_notify) = self.state.stream.subscribe(filter)?;

        let replay = match since {
            Some(s) if s <= replay_bound && s > 0 => {
                // Prefetch the first batch eagerly so an already-evicted start
                // position fails the subscribe call itself, not the stream.
                let first_batch = self
                    .state
                    .engine
                    .get_batch(s)
                    .map_err(ConnectError::internal)?;
                let Some(first_batch) = first_batch else {
                    let oldest = self
                        .state
                        .engine
                        .oldest_retained_batch()
                        .map_err(ConnectError::internal)?;
                    return Err(self.batch_evicted_error(oldest));
                };
                Some(ReplayState {
                    next_sequence: s,
                    bound: replay_bound,
                    first_batch: Some(first_batch),
                })
            }
            _ => None,
        };
        // Live tailing always begins just past the replay boundary.
        let next_live_sequence = replay_bound.saturating_add(1);

        Ok((
            Box::pin(SubscriptionStream::new(
                self.state.clone(),
                matchers,
                replay,
                next_live_sequence,
                live_notify,
            )),
            ctx,
        ))
    }

    /// Unary fetch of a single batch by sequence number. A missing batch is
    /// reported as NOT_FOUND when it lies ahead of the current sequence and
    /// as batch-evicted (OUT_OF_RANGE) when it was pruned from the log.
    async fn get(
        &self,
        ctx: Context,
        request: buffa::view::OwnedView<GetRequestView<'static>>,
    ) -> Result<(StreamGetResponse, Context), ConnectError> {
        let seq = request.sequence_number;
        match self
            .state
            .engine
            .get_batch(seq)
            .map_err(ConnectError::internal)?
        {
            Some(kvs) => {
                let entries = kvs
                    .into_iter()
                    .map(|(k, v)| KvEntry {
                        key: k.to_vec(),
                        value: v.to_vec(),
                        ..Default::default()
                    })
                    .collect();
                Ok((
                    StreamGetResponse {
                        sequence_number: seq,
                        entries,
                        ..Default::default()
                    },
                    ctx,
                ))
            }
            None => {
                // Disambiguate "not yet written" from "already pruned".
                let current = self.state.engine.current_sequence();
                if seq > current {
                    Err(self.batch_not_found_error())
                } else {
                    let oldest = self
                        .state
                        .engine
                        .oldest_retained_batch()
                        .map_err(ConnectError::internal)?;
                    Err(self.batch_evicted_error(oldest))
                }
            }
        }
    }
}
877
878fn connect_limits() -> Limits {
879 Limits::default()
880 .max_request_body_size(MAX_CONNECTRPC_BODY_BYTES)
881 .max_message_size(MAX_CONNECTRPC_BODY_BYTES)
882}
883
884pub fn connect_stack(
885 state: AppState,
886) -> ConnectRpcService<
887 Chain<
888 IngestServiceServer<IngestConnect>,
889 Chain<
890 QueryServiceServer<QueryConnect>,
891 Chain<CompactServiceServer<CompactConnect>, StreamServiceServer<StreamConnect>>,
892 >,
893 >,
894> {
895 ConnectRpcService::new(Chain(
896 IngestServiceServer::new(IngestConnect::new(state.clone())),
897 Chain(
898 QueryServiceServer::new(QueryConnect::new(state.clone())),
899 Chain(
900 CompactServiceServer::new(CompactConnect::new(state.clone())),
901 StreamServiceServer::new(StreamConnect::new(state)),
902 ),
903 ),
904 ))
905 .with_limits(connect_limits())
906 .with_compression(connect_compression_registry())
907}
908
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Mutex;
    use std::time::Duration;

    use buffa::Message;
    use exoware_proto::store::common::v1::MatchKey as ProtoMatchKey;
    use exoware_proto::store::stream::v1::SubscribeRequest;
    use exoware_sdk::decode_connect_error;
    use exoware_sdk::keys::KeyCodec;
    use futures::StreamExt;

    // Key-codec parameters shared by all test keys and subscribe filters.
    const TEST_RESERVED_BITS: u8 = 4;
    const TEST_PREFIX: u16 = 1;

    // Instructs the fake engine to publish a live batch on every get_batch
    // call, used to race replay reads against concurrent live traffic.
    #[derive(Clone)]
    struct PublishDuringReplay {
        hub: Arc<StreamHub>,
        // Live batches are published at `sequence_offset + replayed seq`.
        sequence_offset: u64,
        kvs: Vec<(Bytes, Bytes)>,
    }

    // Mutable interior of the fake engine.
    #[derive(Default)]
    struct FakeEngineState {
        current_sequence: u64,
        // `None` value models a batch that exists in the index but was pruned.
        batches: BTreeMap<u64, Option<Vec<(Bytes, Bytes)>>>,
        oldest_retained: Option<u64>,
        publish_on_get_batch: Option<PublishDuringReplay>,
    }

    // In-memory StoreEngine double driven entirely by the tests.
    #[derive(Default)]
    struct FakeEngine {
        state: Mutex<FakeEngineState>,
    }

    impl FakeEngine {
        fn set_current_sequence(&self, sequence_number: u64) {
            self.state.lock().expect("lock").current_sequence = sequence_number;
        }

        fn set_batch(&self, sequence_number: u64, kvs: Option<Vec<(Bytes, Bytes)>>) {
            self.state
                .lock()
                .expect("lock")
                .batches
                .insert(sequence_number, kvs);
        }

        fn set_oldest_retained(&self, oldest_retained: Option<u64>) {
            self.state.lock().expect("lock").oldest_retained = oldest_retained;
        }

        // Stores a batch and notifies the hub, mimicking a live put.
        fn publish_live(
            &self,
            hub: Arc<StreamHub>,
            sequence_number: u64,
            kvs: Vec<(Bytes, Bytes)>,
        ) {
            let mut state = self.state.lock().expect("lock");
            state.current_sequence = state.current_sequence.max(sequence_number);
            state.batches.insert(sequence_number, Some(kvs.clone()));
            // Release the lock before publishing so subscribers woken by the
            // hub can immediately read the batch.
            drop(state);
            hub.publish(sequence_number);
        }

        // Arms the publish-during-replay behavior (see PublishDuringReplay).
        fn publish_on_every_get_batch(
            &self,
            hub: Arc<StreamHub>,
            sequence_offset: u64,
            kvs: Vec<(Bytes, Bytes)>,
        ) {
            self.state.lock().expect("lock").publish_on_get_batch = Some(PublishDuringReplay {
                hub,
                sequence_offset,
                kvs,
            });
        }
    }

    impl StoreEngine for FakeEngine {
        fn put_batch(&self, kvs: &[(Bytes, Bytes)]) -> Result<u64, String> {
            let mut state = self.state.lock().map_err(|e| e.to_string())?;
            state.current_sequence += 1;
            let seq = state.current_sequence;
            state.batches.insert(seq, Some(kvs.to_vec()));
            Ok(seq)
        }

        fn get(&self, _key: &[u8]) -> Result<Option<Vec<u8>>, String> {
            Ok(None)
        }

        fn range_scan(
            &self,
            _start: &[u8],
            _end: &[u8],
            _limit: usize,
            _forward: bool,
        ) -> Result<Vec<(Bytes, Bytes)>, String> {
            Ok(Vec::new())
        }

        fn delete_batch(&self, _keys: &[&[u8]]) -> Result<u64, String> {
            let mut state = self.state.lock().map_err(|e| e.to_string())?;
            state.current_sequence += 1;
            Ok(state.current_sequence)
        }

        fn current_sequence(&self) -> u64 {
            self.state.lock().expect("lock").current_sequence
        }

        fn get_batch(&self, sequence_number: u64) -> Result<Option<Vec<(Bytes, Bytes)>>, String> {
            let (publish, batch) = {
                let mut state = self.state.lock().map_err(|e| e.to_string())?;
                let publish = state.publish_on_get_batch.clone();
                if let Some(publish) = publish.as_ref() {
                    // Insert the racing live batch while still holding the
                    // lock so it is visible before the hub is notified below.
                    let live_sequence = publish.sequence_offset + sequence_number;
                    state.current_sequence = state.current_sequence.max(live_sequence);
                    state
                        .batches
                        .entry(live_sequence)
                        .or_insert_with(|| Some(publish.kvs.clone()));
                }
                (
                    publish,
                    state.batches.get(&sequence_number).cloned().unwrap_or(None),
                )
            };
            // Notify outside the lock to avoid deadlocking subscribers.
            if let Some(publish) = publish {
                publish
                    .hub
                    .publish(publish.sequence_offset + sequence_number);
            }
            Ok(batch)
        }

        fn oldest_retained_batch(&self) -> Result<Option<u64>, String> {
            Ok(self
                .state
                .lock()
                .map_err(|e| e.to_string())?
                .oldest_retained)
        }

        fn prune_batch_log(&self, _cutoff_exclusive: u64) -> Result<u64, String> {
            Ok(0)
        }
    }

    // Encodes a key that matches the test subscribe filter, paired with value.
    fn matching_kv(payload: &[u8], value: &[u8]) -> (Bytes, Bytes) {
        let codec = KeyCodec::new(TEST_RESERVED_BITS, TEST_PREFIX);
        let key = codec.encode(payload).expect("encode key");
        (
            Bytes::copy_from_slice(key.as_ref()),
            Bytes::copy_from_slice(value),
        )
    }

    // Wire-encodes a subscribe request with a match-everything key filter.
    fn subscribe_request_bytes(since_sequence_number: Option<u64>) -> Vec<u8> {
        SubscribeRequest {
            match_keys: vec![ProtoMatchKey {
                reserved_bits: u32::from(TEST_RESERVED_BITS),
                prefix: u32::from(TEST_PREFIX),
                payload_regex: "(?s).*".to_string(),
                ..Default::default()
            }],
            since_sequence_number,
            ..Default::default()
        }
        .encode_to_vec()
    }

    // Calls StreamApi::subscribe end-to-end through the wire encoding.
    async fn subscribe_stream(
        connect: &StreamConnect,
        since_sequence_number: Option<u64>,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<SubscribeResponse, ConnectError>> + Send>>,
        ConnectError,
    > {
        let bytes = subscribe_request_bytes(since_sequence_number);
        let request = buffa::view::OwnedView::<SubscribeRequestView<'static>>::decode(bytes.into())
            .expect("decode subscribe request");
        let (stream, _ctx) = StreamApi::subscribe(connect, Context::default(), request).await?;
        Ok(stream)
    }

    // A subscriber with no replay cursor should see the next live publish.
    #[tokio::test]
    async fn subscribe_without_replay_reads_the_next_live_batch() {
        let engine = Arc::new(FakeEngine::default());
        let state = AppState::new(engine.clone());
        let connect = StreamConnect::new(state.clone());
        let mut stream = subscribe_stream(&connect, None).await.expect("subscribe");
        engine.publish_live(state.stream.clone(), 1, vec![matching_kv(b"hit", b"v1")]);
        let frame = tokio::time::timeout(Duration::from_secs(1), stream.next())
            .await
            .expect("stream should yield")
            .expect("frame should exist")
            .expect("frame should be ok");
        assert_eq!(frame.sequence_number, 1);
        assert_eq!(frame.entries.len(), 1);
        assert_eq!(frame.entries[0].value.as_slice(), b"v1");
    }

    // A `since` beyond the current sequence must not replay anything and must
    // deliver only batches published after subscription.
    #[tokio::test]
    async fn subscribe_past_end_reads_only_future_live_batches() {
        let engine = Arc::new(FakeEngine::default());
        engine.set_current_sequence(5);
        for seq in 1..=5 {
            engine.set_batch(seq, Some(vec![matching_kv(b"seed", b"v")]));
        }
        let state = AppState::new(engine.clone());
        let connect = StreamConnect::new(state.clone());
        let mut stream = subscribe_stream(&connect, Some(15))
            .await
            .expect("subscribe");

        assert!(
            tokio::time::timeout(Duration::from_millis(200), stream.next())
                .await
                .is_err(),
            "past-end cursor should not replay synthetic or historical frames",
        );

        engine.publish_live(state.stream.clone(), 6, vec![matching_kv(b"live", b"n")]);
        let frame = tokio::time::timeout(Duration::from_secs(1), stream.next())
            .await
            .expect("stream should yield")
            .expect("frame should exist")
            .expect("frame should be ok");
        assert_eq!(frame.sequence_number, 6);
        assert_eq!(frame.entries.len(), 1);
        assert_eq!(frame.entries[0].value.as_slice(), b"n");
    }

    // A pruned batch inside the replay window must surface BATCH_EVICTED and
    // then terminate the stream.
    #[tokio::test]
    async fn replay_hole_returns_batch_evicted_error_instead_of_empty_frame() {
        let engine = Arc::new(FakeEngine::default());
        engine.set_current_sequence(3);
        engine.set_oldest_retained(Some(2));
        engine.set_batch(2, Some(vec![matching_kv(b"replay", b"v2")]));

        let state = AppState::new(engine);
        let connect = StreamConnect::new(state);
        let mut stream = subscribe_stream(&connect, Some(2))
            .await
            .expect("subscribe");

        let first = tokio::time::timeout(Duration::from_secs(1), stream.next())
            .await
            .expect("stream should yield")
            .expect("first replay frame should exist")
            .expect("first replay frame should be ok");
        assert_eq!(first.sequence_number, 2);
        assert_eq!(first.entries.len(), 1);

        // Batch 3 was never stored: replaying it must error, not skip.
        let err = tokio::time::timeout(Duration::from_secs(1), stream.next())
            .await
            .expect("stream should yield error")
            .expect("error item should exist")
            .expect_err("replay hole must be surfaced as an error");
        let decoded = decode_connect_error(&err).expect("decode connect error");
        assert_eq!(
            decoded.error_info.expect("error info").reason,
            crate::stream::REASON_BATCH_EVICTED,
        );
        assert!(
            tokio::time::timeout(Duration::from_secs(1), stream.next())
                .await
                .expect("stream should terminate")
                .is_none(),
            "stream must terminate after surfacing the replay hole",
        );
    }

    // Live publishes racing the replay phase must still be delivered strictly
    // in sequence order (small burst).
    #[tokio::test]
    async fn replay_with_live_burst_under_capacity_still_delivers_in_order() {
        const REPLAY_BATCHES: u64 = 100;

        let engine = Arc::new(FakeEngine::default());
        engine.set_current_sequence(REPLAY_BATCHES);
        engine.set_oldest_retained(Some(1));
        for seq in 1..=REPLAY_BATCHES {
            engine.set_batch(seq, Some(vec![matching_kv(b"replay", b"v")]));
        }

        let state = AppState::new(engine.clone());
        // Every replay read publishes a live batch at seq + REPLAY_BATCHES.
        engine.publish_on_every_get_batch(
            state.stream.clone(),
            REPLAY_BATCHES,
            vec![matching_kv(b"live", b"tail")],
        );

        let connect = StreamConnect::new(state);
        let mut stream = subscribe_stream(&connect, Some(1))
            .await
            .expect("subscribe");
        let mut sequence_numbers = Vec::with_capacity((REPLAY_BATCHES * 2) as usize);
        while sequence_numbers.len() < (REPLAY_BATCHES * 2) as usize {
            let frame = tokio::time::timeout(Duration::from_secs(2), stream.next())
                .await
                .expect("stream should keep yielding")
                .expect("frame should exist")
                .expect("frame should be ok");
            sequence_numbers.push(frame.sequence_number);
        }

        let expected: Vec<u64> = (1..=(REPLAY_BATCHES * 2)).collect();
        assert_eq!(sequence_numbers, expected);
    }

    // Same race as above but with a larger burst, verifying delivery is paced
    // by client reads rather than buffered out of order.
    #[tokio::test]
    async fn replay_large_live_burst_is_paced_by_client_reads() {
        const REPLAY_BATCHES: u64 = 300;

        let engine = Arc::new(FakeEngine::default());
        engine.set_current_sequence(REPLAY_BATCHES);
        engine.set_oldest_retained(Some(1));
        for seq in 1..=REPLAY_BATCHES {
            engine.set_batch(seq, Some(vec![matching_kv(b"replay", b"v")]));
        }

        let state = AppState::new(engine.clone());
        engine.publish_on_every_get_batch(
            state.stream.clone(),
            REPLAY_BATCHES,
            vec![matching_kv(b"live", b"tail")],
        );

        let connect = StreamConnect::new(state);
        let mut stream = subscribe_stream(&connect, Some(1))
            .await
            .expect("subscribe");
        let mut sequence_numbers = Vec::with_capacity((REPLAY_BATCHES * 2) as usize);
        while sequence_numbers.len() < (REPLAY_BATCHES * 2) as usize {
            let frame = tokio::time::timeout(Duration::from_secs(2), stream.next())
                .await
                .expect("stream should keep yielding")
                .expect("frame should exist")
                .expect("frame should be ok");
            sequence_numbers.push(frame.sequence_number);
        }
        let expected: Vec<u64> = (1..=(REPLAY_BATCHES * 2)).collect();
        assert_eq!(sequence_numbers, expected);
    }
}