Skip to main content

drasi_source_ris_live/
stream.rs

1// Copyright 2025 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! WebSocket stream processing for RIS Live.
16
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use anyhow::{anyhow, Context, Result};
21use futures_util::{SinkExt, StreamExt};
22use log::{debug, error, info, warn};
23use tokio::sync::{watch, RwLock};
24use tokio::time::sleep;
25use tokio_tungstenite::tungstenite::Message;
26use tokio_tungstenite::{connect_async, tungstenite};
27use url::Url;
28
29use drasi_lib::channels::{ChangeDispatcher, SourceEvent, SourceEventWrapper};
30use drasi_lib::profiling::{timestamp_ns, ProfilingMetadata};
31use drasi_lib::sources::base::SourceBase;
32use drasi_lib::state_store::StateStoreProvider;
33
34use crate::config::RisLiveSourceConfig;
35use crate::mapping::{GraphMapper, PersistedStreamState, StreamState};
36use crate::messages::{
37    message_timestamp_millis, RisErrorData, RisIncomingMessage, RisMessageData, RisSubscribeMessage,
38};
39
/// Key passed to the state store's `get`/`set`/`delete` calls for this
/// source's persisted stream state.
const STATE_KEY: &str = "ris-live.stream-state.v1";

/// Minimum interval between state persistence writes.
///
/// Checked after changes have been dispatched for a frame; shutdown paths
/// persist unconditionally.
const PERSIST_INTERVAL: Duration = Duration::from_secs(5);
44
45/// Runs the resilient streaming loop with reconnect behavior.
46pub async fn run_stream_loop(
47    source_id: String,
48    config: RisLiveSourceConfig,
49    dispatchers: Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
50    state_store: Option<Arc<dyn StateStoreProvider>>,
51    mut shutdown_rx: watch::Receiver<bool>,
52) -> Result<()> {
53    let initial_state = load_initial_state(&source_id, &config, &state_store).await?;
54    let mut mapper = GraphMapper::new(source_id.clone(), initial_state);
55    let mut last_persisted = Instant::now();
56
57    loop {
58        if *shutdown_rx.borrow() {
59            persist_state(&source_id, &state_store, mapper.state()).await?;
60            info!("[{source_id}] RIS stream shutdown requested");
61            return Ok(());
62        }
63
64        match run_single_connection(
65            &source_id,
66            &config,
67            &dispatchers,
68            &state_store,
69            &mut mapper,
70            &mut last_persisted,
71            &mut shutdown_rx,
72        )
73        .await
74        {
75            Ok(()) => {
76                if *shutdown_rx.borrow() {
77                    persist_state(&source_id, &state_store, mapper.state()).await?;
78                    return Ok(());
79                }
80                warn!("[{source_id}] RIS connection ended, reconnecting");
81            }
82            Err(error) => {
83                error!("[{source_id}] RIS streaming error: {error}");
84            }
85        }
86
87        tokio::select! {
88            _ = sleep(Duration::from_secs(config.reconnect_delay_secs())) => {}
89            _ = shutdown_rx.changed() => {
90                if *shutdown_rx.borrow() {
91                    persist_state(&source_id, &state_store, mapper.state()).await?;
92                    return Ok(());
93                }
94            }
95        }
96    }
97}
98
99async fn run_single_connection(
100    source_id: &str,
101    config: &RisLiveSourceConfig,
102    dispatchers: &Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
103    state_store: &Option<Arc<dyn StateStoreProvider>>,
104    mapper: &mut GraphMapper,
105    last_persisted: &mut Instant,
106    shutdown_rx: &mut watch::Receiver<bool>,
107) -> Result<()> {
108    ensure_crypto_provider();
109    let url = build_url(config)?;
110    info!("[{source_id}] Connecting to RIS Live: {url}");
111
112    let (mut socket, response) = connect_async(url.as_str())
113        .await
114        .with_context(|| format!("failed to connect to RIS Live at {url}"))?;
115    debug!(
116        "[{source_id}] Connected to RIS Live (status: {})",
117        response.status()
118    );
119
120    let subscribe = RisSubscribeMessage::from_config(config);
121    let payload =
122        serde_json::to_string(&subscribe).context("failed to serialize subscribe payload")?;
123    socket
124        .send(Message::Text(payload))
125        .await
126        .context("failed to send ris_subscribe")?;
127    info!("[{source_id}] Subscription message sent");
128
129    loop {
130        tokio::select! {
131            _ = shutdown_rx.changed() => {
132                if *shutdown_rx.borrow() {
133                    let _ = socket.close(None).await;
134                    return Ok(());
135                }
136            }
137            frame = socket.next() => {
138                match frame {
139                    Some(Ok(Message::Text(text))) => {
140                        process_text_frame(source_id, config, dispatchers, state_store, mapper, last_persisted, &text).await?;
141                    }
142                    Some(Ok(Message::Binary(_))) => {}
143                    Some(Ok(Message::Ping(payload))) => {
144                        socket.send(Message::Pong(payload)).await.context("failed to send pong")?;
145                    }
146                    Some(Ok(Message::Pong(_))) => {}
147                    Some(Ok(Message::Close(_))) => {
148                        info!("[{source_id}] RIS server closed the connection");
149                        return Ok(());
150                    }
151                    Some(Ok(Message::Frame(_))) => {}
152                    Some(Err(tungstenite::Error::ConnectionClosed)) => {
153                        info!("[{source_id}] RIS connection closed");
154                        return Ok(());
155                    }
156                    Some(Err(error)) => {
157                        return Err(anyhow!("websocket read error: {error}"));
158                    }
159                    None => {
160                        info!("[{source_id}] RIS stream ended");
161                        return Ok(());
162                    }
163                }
164            }
165        }
166    }
167}
168
169async fn process_text_frame(
170    source_id: &str,
171    config: &RisLiveSourceConfig,
172    dispatchers: &Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
173    state_store: &Option<Arc<dyn StateStoreProvider>>,
174    mapper: &mut GraphMapper,
175    last_persisted: &mut Instant,
176    text: &str,
177) -> Result<()> {
178    let incoming: RisIncomingMessage = serde_json::from_str(text).with_context(|| {
179        let truncated: String = text.chars().take(512).collect();
180        format!("failed to parse RIS message wrapper: {truncated}")
181    })?;
182
183    match incoming.msg_type.as_str() {
184        "ris_subscribe_ok" => {
185            info!("[{source_id}] Subscription acknowledged");
186            Ok(())
187        }
188        "ris_error" => {
189            if let Some(payload) = incoming.data {
190                let err: RisErrorData =
191                    serde_json::from_value(payload).context("failed to parse ris_error payload")?;
192                Err(anyhow!("RIS server error: {}", err.message))
193            } else {
194                Err(anyhow!("RIS server sent ris_error without payload"))
195            }
196        }
197        "ris_message" => {
198            let payload = incoming
199                .data
200                .ok_or_else(|| anyhow!("ris_message missing payload"))?;
201            let message: RisMessageData =
202                serde_json::from_value(payload).context("failed to parse ris_message payload")?;
203
204            if !config.should_process_timestamp(message_timestamp_millis(&message)) {
205                debug!("[{source_id}] Skipping message due to start_from timestamp");
206                return Ok(());
207            }
208
209            let mut changes = Vec::new();
210            match message.msg_type.as_deref() {
211                Some("UPDATE") => {
212                    changes.extend(mapper.process_announcements(&message));
213                    changes.extend(mapper.process_withdrawals(&message));
214                }
215                Some("RIS_PEER_STATE") if config.include_peer_state => {
216                    changes.extend(mapper.process_peer_state(&message));
217                }
218                Some("RIS_PEER_STATE") => {}
219                Some("OPEN") | Some("KEEPALIVE") | Some("NOTIFICATION") => {}
220                _ => {}
221            }
222
223            if !changes.is_empty() {
224                for change in changes {
225                    dispatch_change(source_id, dispatchers, change).await?;
226                }
227                if last_persisted.elapsed() >= PERSIST_INTERVAL {
228                    persist_state(source_id, state_store, mapper.state()).await?;
229                    *last_persisted = Instant::now();
230                }
231            }
232            Ok(())
233        }
234        "pong" => Ok(()),
235        "ris_rrc_list" => Ok(()),
236        other => {
237            debug!("[{source_id}] Ignoring unsupported message type: {other}");
238            Ok(())
239        }
240    }
241}
242
243async fn dispatch_change(
244    source_id: &str,
245    dispatchers: &Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
246    change: drasi_core::models::SourceChange,
247) -> Result<()> {
248    let mut profiling = ProfilingMetadata::new();
249    profiling.source_send_ns = Some(timestamp_ns());
250
251    let wrapper = SourceEventWrapper::with_profiling(
252        source_id.to_string(),
253        SourceEvent::Change(change),
254        chrono::Utc::now(),
255        profiling,
256    );
257
258    SourceBase::dispatch_from_task(dispatchers.clone(), wrapper, source_id).await
259}
260
261fn ensure_crypto_provider() {
262    let _ = rustls::crypto::ring::default_provider().install_default();
263}
264
265fn build_url(config: &RisLiveSourceConfig) -> Result<Url> {
266    let mut url = Url::parse(&config.websocket_url)
267        .with_context(|| format!("invalid websocket_url '{}'", config.websocket_url))?;
268
269    match url.scheme() {
270        "ws" | "wss" => {}
271        other => {
272            return Err(anyhow!(
273                "websocket_url scheme must be ws or wss, got: {other}"
274            ));
275        }
276    }
277
278    if let Some(client_name) = &config.client_name {
279        // Remove any existing `client` parameter before appending ours
280        let existing: Vec<(String, String)> = url
281            .query_pairs()
282            .filter(|(k, _)| k != "client")
283            .map(|(k, v)| (k.into_owned(), v.into_owned()))
284            .collect();
285        url.query_pairs_mut().clear().extend_pairs(existing);
286        url.query_pairs_mut().append_pair("client", client_name);
287    }
288
289    Ok(url)
290}
291
292async fn load_initial_state(
293    source_id: &str,
294    config: &RisLiveSourceConfig,
295    state_store: &Option<Arc<dyn StateStoreProvider>>,
296) -> Result<StreamState> {
297    if config.clear_state_on_start {
298        if let Some(store) = state_store {
299            store.delete(source_id, STATE_KEY).await.with_context(|| {
300                format!("failed to clear persisted state for source '{source_id}'")
301            })?;
302        }
303        return Ok(StreamState::default());
304    }
305
306    load_state(source_id, state_store).await
307}
308
309async fn load_state(
310    source_id: &str,
311    state_store: &Option<Arc<dyn StateStoreProvider>>,
312) -> Result<StreamState> {
313    let Some(store) = state_store else {
314        return Ok(StreamState::default());
315    };
316
317    let bytes = store
318        .get(source_id, STATE_KEY)
319        .await
320        .with_context(|| format!("failed to read state for source '{source_id}'"))?;
321
322    let Some(bytes) = bytes else {
323        return Ok(StreamState::default());
324    };
325
326    let persisted: PersistedStreamState = serde_json::from_slice(&bytes)
327        .with_context(|| format!("invalid persisted state payload for source '{source_id}'"))?;
328
329    Ok(persisted.into())
330}
331
332async fn persist_state(
333    source_id: &str,
334    state_store: &Option<Arc<dyn StateStoreProvider>>,
335    state: &StreamState,
336) -> Result<()> {
337    let Some(store) = state_store else {
338        return Ok(());
339    };
340
341    let payload =
342        serde_json::to_vec(&PersistedStreamState::from(state)).context("failed to encode state")?;
343    store
344        .set(source_id, STATE_KEY, payload)
345        .await
346        .with_context(|| format!("failed to persist state for source '{source_id}'"))?;
347    Ok(())
348}