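//! Stream domain module (`ralph_api/stream_domain/mod.rs`): subscription bookkeeping, bounded
//! event history for replay, and live fan-out over a broadcast channel.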

1mod filters;
2mod rpc_side_effects;
3
4use std::collections::{HashMap, HashSet, VecDeque};
5use std::sync::{Arc, Mutex};
6
7use chrono::Utc;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use tokio::sync::broadcast;
11
12use crate::errors::ApiError;
13use crate::loop_support::now_ts;
14use crate::protocol::{API_VERSION, STREAM_NAME, STREAM_TOPICS};
15
16use self::filters::{
17    SubscriptionFilters, cursor_is_older, cursor_sequence, normalize_topics, validate_cursor,
18};
19
/// Interval, in milliseconds, advertised in keepalive events (15 seconds).
pub const KEEPALIVE_INTERVAL_MS: u64 = 15 * 1_000;

/// Replay limit applied when a subscriber does not request one.
const DEFAULT_REPLAY_LIMIT: usize = 200;
/// Maximum number of published events retained for replay; the oldest is evicted first.
const HISTORY_LIMIT: usize = 2_048;
/// Capacity of the live broadcast channel.
const LIVE_BUFFER_CAPACITY: usize = 256;
25
26#[derive(Debug, Clone, Deserialize)]
27#[serde(rename_all = "camelCase")]
28pub struct StreamSubscribeParams {
29    pub topics: Vec<String>,
30    pub cursor: Option<String>,
31    pub replay_limit: Option<u16>,
32    pub filters: Option<Value>,
33}
34
35#[derive(Debug, Clone, Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct StreamUnsubscribeParams {
38    pub subscription_id: String,
39}
40
41#[derive(Debug, Clone, Deserialize)]
42#[serde(rename_all = "camelCase")]
43pub struct StreamAckParams {
44    pub subscription_id: String,
45    pub cursor: String,
46}
47
48#[derive(Debug, Clone, Serialize)]
49#[serde(rename_all = "camelCase")]
50pub struct StreamSubscribeResult {
51    pub subscription_id: String,
52    pub accepted_topics: Vec<String>,
53    pub cursor: String,
54}
55
56#[derive(Debug, Clone)]
57pub struct ReplayBatch {
58    pub events: Vec<StreamEventEnvelope>,
59    pub dropped_count: usize,
60}
61
62#[derive(Debug, Clone, Serialize)]
63#[serde(rename_all = "camelCase")]
64pub struct StreamEventEnvelope {
65    pub api_version: String,
66    pub stream: String,
67    pub topic: String,
68    pub cursor: String,
69    pub sequence: u64,
70    pub ts: String,
71    pub resource: StreamResource,
72    pub replay: StreamReplay,
73    pub payload: Value,
74}
75
76#[derive(Debug, Clone, Serialize)]
77pub struct StreamResource {
78    #[serde(rename = "type")]
79    pub kind: String,
80    pub id: String,
81}
82
83#[derive(Debug, Clone, Serialize)]
84#[serde(rename_all = "camelCase")]
85pub struct StreamReplay {
86    pub mode: String,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub requested_cursor: Option<String>,
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub batch: Option<u64>,
91}
92
93#[derive(Clone)]
94pub struct StreamDomain {
95    state: Arc<Mutex<StreamState>>,
96    live_tx: broadcast::Sender<StreamEventEnvelope>,
97}
98
99#[derive(Debug, Clone)]
100struct SubscriptionRecord {
101    topics: HashSet<String>,
102    filters: SubscriptionFilters,
103    cursor: String,
104    replay_limit: usize,
105    explicit_cursor: bool,
106    principal: String,
107}
108
109struct StreamState {
110    sequence: u64,
111    subscription_counter: u64,
112    history: VecDeque<StreamEventEnvelope>,
113    subscriptions: HashMap<String, SubscriptionRecord>,
114}
115
116impl StreamDomain {
117    pub fn new() -> Self {
118        let (live_tx, _) = broadcast::channel(LIVE_BUFFER_CAPACITY);
119        Self {
120            state: Arc::new(Mutex::new(StreamState {
121                sequence: 1,
122                subscription_counter: 0,
123                history: VecDeque::with_capacity(HISTORY_LIMIT),
124                subscriptions: HashMap::new(),
125            })),
126            live_tx,
127        }
128    }
129
130    pub fn subscribe(
131        &self,
132        params: StreamSubscribeParams,
133        principal: &str,
134    ) -> Result<StreamSubscribeResult, ApiError> {
135        let accepted_topics = normalize_topics(&params.topics, STREAM_TOPICS)?;
136        let cursor = if let Some(cursor) = &params.cursor {
137            validate_cursor(cursor)?;
138            cursor.clone()
139        } else {
140            self.latest_cursor_or_now()?
141        };
142
143        let replay_limit = usize::from(params.replay_limit.unwrap_or(DEFAULT_REPLAY_LIMIT as u16));
144        let filters = SubscriptionFilters::from_json(params.filters)?;
145
146        let mut state = self.lock_state()?;
147        state.subscription_counter = state.subscription_counter.saturating_add(1);
148        let subscription_id = format!(
149            "sub-{}-{:04x}",
150            Utc::now().timestamp_millis(),
151            state.subscription_counter
152        );
153
154        let topics = accepted_topics.iter().cloned().collect::<HashSet<_>>();
155        state.subscriptions.insert(
156            subscription_id.clone(),
157            SubscriptionRecord {
158                topics,
159                filters,
160                cursor: cursor.clone(),
161                replay_limit,
162                explicit_cursor: params.cursor.is_some(),
163                principal: principal.to_string(),
164            },
165        );
166
167        Ok(StreamSubscribeResult {
168            subscription_id,
169            accepted_topics,
170            cursor,
171        })
172    }
173
174    pub fn get_subscription_principal(&self, subscription_id: &str) -> Option<String> {
175        let state = self.lock_state().ok()?;
176        state
177            .subscriptions
178            .get(subscription_id)
179            .map(|s| s.principal.clone())
180    }
181
182    pub fn unsubscribe(&self, params: StreamUnsubscribeParams) -> Result<(), ApiError> {
183        let mut state = self.lock_state()?;
184        let removed = state.subscriptions.remove(&params.subscription_id);
185        if removed.is_none() {
186            return Err(ApiError::not_found(format!(
187                "subscription '{}' not found",
188                params.subscription_id
189            ))
190            .with_details(json!({ "subscriptionId": params.subscription_id })));
191        }
192
193        Ok(())
194    }
195
196    pub fn ack(&self, params: StreamAckParams) -> Result<(), ApiError> {
197        validate_cursor(&params.cursor)?;
198
199        let mut state = self.lock_state()?;
200        let Some(subscription) = state.subscriptions.get_mut(&params.subscription_id) else {
201            return Err(ApiError::not_found(format!(
202                "subscription '{}' not found",
203                params.subscription_id
204            ))
205            .with_details(json!({ "subscriptionId": params.subscription_id })));
206        };
207
208        if cursor_is_older(&params.cursor, &subscription.cursor)? {
209            return Err(ApiError::precondition_failed(
210                "stream.ack cursor is older than the subscription checkpoint",
211            )
212            .with_details(json!({
213                "subscriptionId": params.subscription_id,
214                "cursor": params.cursor,
215                "currentCursor": subscription.cursor
216            })));
217        }
218
219        subscription.cursor = params.cursor;
220        subscription.explicit_cursor = true;
221        Ok(())
222    }
223
224    pub fn live_receiver(&self) -> broadcast::Receiver<StreamEventEnvelope> {
225        self.live_tx.subscribe()
226    }
227
228    pub fn has_subscription(&self, subscription_id: &str) -> bool {
229        self.state
230            .lock()
231            .ok()
232            .is_some_and(|state| state.subscriptions.contains_key(subscription_id))
233    }
234
235    pub fn matches_subscription(&self, subscription_id: &str, event: &StreamEventEnvelope) -> bool {
236        let Ok(state) = self.state.lock() else {
237            return false;
238        };
239
240        let Some(subscription) = state.subscriptions.get(subscription_id) else {
241            return false;
242        };
243
244        subscription.matches(event)
245    }
246
247    pub fn subscription_cursor_sequence(&self, subscription_id: &str) -> Result<u64, ApiError> {
248        let state = self.lock_state()?;
249        let Some(subscription) = state.subscriptions.get(subscription_id) else {
250            return Err(ApiError::not_found(format!(
251                "subscription '{}' not found",
252                subscription_id
253            ))
254            .with_details(json!({ "subscriptionId": subscription_id })));
255        };
256
257        cursor_sequence(&subscription.cursor)
258    }
259
260    pub fn subscription_cursor(&self, subscription_id: &str) -> Result<String, ApiError> {
261        let state = self.lock_state()?;
262        let Some(subscription) = state.subscriptions.get(subscription_id) else {
263            return Err(ApiError::not_found(format!(
264                "subscription '{}' not found",
265                subscription_id
266            ))
267            .with_details(json!({ "subscriptionId": subscription_id })));
268        };
269
270        Ok(subscription.cursor.clone())
271    }
272
273    pub fn replay_for_subscription(&self, subscription_id: &str) -> Result<ReplayBatch, ApiError> {
274        let state = self.lock_state()?;
275        let Some(subscription) = state.subscriptions.get(subscription_id) else {
276            return Err(ApiError::not_found(format!(
277                "subscription '{}' not found",
278                subscription_id
279            ))
280            .with_details(json!({ "subscriptionId": subscription_id })));
281        };
282
283        let cursor_sequence = cursor_sequence(&subscription.cursor)?;
284        let current_cursor = subscription.cursor.clone();
285        let mut events = state
286            .history
287            .iter()
288            .filter(|event| {
289                event.sequence > cursor_sequence
290                    || (event.sequence == cursor_sequence && event.cursor != current_cursor)
291            })
292            .filter(|event| subscription.matches(event))
293            .cloned()
294            .collect::<Vec<_>>();
295
296        let dropped_count = events.len().saturating_sub(subscription.replay_limit);
297        if dropped_count > 0 {
298            events = events.split_off(dropped_count);
299        }
300
301        if !events.is_empty() {
302            let replay_mode = if subscription.explicit_cursor {
303                "resume"
304            } else {
305                "replay"
306            };
307
308            let batch = u64::try_from(events.len()).unwrap_or(u64::MAX);
309            for event in &mut events {
310                event.replay.mode = replay_mode.to_string();
311                event.replay.requested_cursor = Some(subscription.cursor.clone());
312                event.replay.batch = Some(batch);
313            }
314        }
315
316        Ok(ReplayBatch {
317            events,
318            dropped_count,
319        })
320    }
321
322    pub fn keepalive_event(&self, subscription_id: &str, interval_ms: u64) -> StreamEventEnvelope {
323        self.ephemeral_event(
324            "stream.keepalive",
325            "stream",
326            subscription_id,
327            json!({ "intervalMs": interval_ms }),
328            "live",
329            None,
330            None,
331        )
332    }
333
334    pub fn backpressure_event(
335        &self,
336        subscription_id: &str,
337        dropped_count: usize,
338    ) -> StreamEventEnvelope {
339        self.ephemeral_event(
340            "error.raised",
341            "stream",
342            subscription_id,
343            json!({
344                "code": "BACKPRESSURE_DROPPED",
345                "message": format!(
346                    "subscription '{}' dropped {} event(s) due to backpressure",
347                    subscription_id,
348                    dropped_count
349                ),
350                "retryable": true
351            }),
352            "live",
353            None,
354            None,
355        )
356    }
357
358    pub fn publish(&self, topic: &str, resource_type: &str, resource_id: &str, payload: Value) {
359        if !STREAM_TOPICS.contains(&topic) {
360            return;
361        }
362
363        let Ok(mut state) = self.lock_state() else {
364            return;
365        };
366
367        let event = next_event(
368            &mut state,
369            topic,
370            resource_type,
371            resource_id,
372            payload,
373            "live",
374            None,
375            None,
376        );
377
378        if state.history.len() >= HISTORY_LIMIT {
379            state.history.pop_front();
380        }
381        state.history.push_back(event.clone());
382        let _ = self.live_tx.send(event);
383    }
384
385    pub fn publish_rpc_side_effect(&self, method: &str, params: &Value, result: &Value) {
386        rpc_side_effects::publish_rpc_side_effect(self, method, params, result);
387    }
388
389    fn latest_cursor_or_now(&self) -> Result<String, ApiError> {
390        let state = self.lock_state()?;
391        Ok(state
392            .history
393            .back()
394            .map(|event| event.cursor.clone())
395            .unwrap_or_else(|| format!("{}-0", Utc::now().timestamp_millis())))
396    }
397
398    fn ephemeral_event(
399        &self,
400        topic: &str,
401        resource_type: &str,
402        resource_id: &str,
403        payload: Value,
404        mode: &str,
405        requested_cursor: Option<String>,
406        batch: Option<u64>,
407    ) -> StreamEventEnvelope {
408        let Ok(mut state) = self.lock_state() else {
409            return StreamEventEnvelope {
410                api_version: API_VERSION.to_string(),
411                stream: STREAM_NAME.to_string(),
412                topic: topic.to_string(),
413                cursor: format!("{}-0", Utc::now().timestamp_millis()),
414                sequence: 0,
415                ts: now_ts(),
416                resource: StreamResource {
417                    kind: resource_type.to_string(),
418                    id: resource_id.to_string(),
419                },
420                replay: StreamReplay {
421                    mode: mode.to_string(),
422                    requested_cursor,
423                    batch,
424                },
425                payload,
426            };
427        };
428
429        next_event(
430            &mut state,
431            topic,
432            resource_type,
433            resource_id,
434            payload,
435            mode,
436            requested_cursor,
437            batch,
438        )
439    }
440
441    fn lock_state(&self) -> Result<std::sync::MutexGuard<'_, StreamState>, ApiError> {
442        self.state
443            .lock()
444            .map_err(|_| ApiError::internal("stream state lock poisoned"))
445    }
446}
447
448impl Default for StreamDomain {
449    fn default() -> Self {
450        Self::new()
451    }
452}
453
454impl SubscriptionRecord {
455    fn matches(&self, event: &StreamEventEnvelope) -> bool {
456        self.topics.contains(&event.topic) && self.filters.matches(event)
457    }
458}
459
460fn next_event(
461    state: &mut StreamState,
462    topic: &str,
463    resource_type: &str,
464    resource_id: &str,
465    payload: Value,
466    mode: &str,
467    requested_cursor: Option<String>,
468    batch: Option<u64>,
469) -> StreamEventEnvelope {
470    let sequence = state.sequence;
471    state.sequence = state.sequence.saturating_add(1);
472
473    StreamEventEnvelope {
474        api_version: API_VERSION.to_string(),
475        stream: STREAM_NAME.to_string(),
476        topic: topic.to_string(),
477        cursor: format!("{}-{sequence}", Utc::now().timestamp_millis()),
478        sequence,
479        ts: now_ts(),
480        resource: StreamResource {
481            kind: resource_type.to_string(),
482            id: resource_id.to_string(),
483        },
484        replay: StreamReplay {
485            mode: mode.to_string(),
486            requested_cursor,
487            batch,
488        },
489        payload,
490    }
491}