1use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use anyhow::{anyhow, Context, Result};
21use futures_util::{SinkExt, StreamExt};
22use log::{debug, error, info, warn};
23use tokio::sync::{watch, RwLock};
24use tokio::time::sleep;
25use tokio_tungstenite::tungstenite::Message;
26use tokio_tungstenite::{connect_async, tungstenite};
27use url::Url;
28
29use drasi_lib::channels::{ChangeDispatcher, SourceEvent, SourceEventWrapper};
30use drasi_lib::profiling::{timestamp_ns, ProfilingMetadata};
31use drasi_lib::sources::base::SourceBase;
32use drasi_lib::state_store::StateStoreProvider;
33
34use crate::config::RisLiveSourceConfig;
35use crate::mapping::{GraphMapper, PersistedStreamState, StreamState};
36use crate::messages::{
37 message_timestamp_millis, RisErrorData, RisIncomingMessage, RisMessageData, RisSubscribeMessage,
38};
39
/// Key under which the serialized stream state is stored; the state store scopes
/// it further by source id (see the `get`/`set`/`delete` calls below).
const STATE_KEY: &str = "ris-live.stream-state.v1";

/// Minimum wall-clock time between periodic state snapshots taken while changes flow.
const PERSIST_INTERVAL: Duration = Duration::from_millis(5_000);
44
45pub async fn run_stream_loop(
47 source_id: String,
48 config: RisLiveSourceConfig,
49 dispatchers: Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
50 state_store: Option<Arc<dyn StateStoreProvider>>,
51 mut shutdown_rx: watch::Receiver<bool>,
52) -> Result<()> {
53 let initial_state = load_initial_state(&source_id, &config, &state_store).await?;
54 let mut mapper = GraphMapper::new(source_id.clone(), initial_state);
55 let mut last_persisted = Instant::now();
56
57 loop {
58 if *shutdown_rx.borrow() {
59 persist_state(&source_id, &state_store, mapper.state()).await?;
60 info!("[{source_id}] RIS stream shutdown requested");
61 return Ok(());
62 }
63
64 match run_single_connection(
65 &source_id,
66 &config,
67 &dispatchers,
68 &state_store,
69 &mut mapper,
70 &mut last_persisted,
71 &mut shutdown_rx,
72 )
73 .await
74 {
75 Ok(()) => {
76 if *shutdown_rx.borrow() {
77 persist_state(&source_id, &state_store, mapper.state()).await?;
78 return Ok(());
79 }
80 warn!("[{source_id}] RIS connection ended, reconnecting");
81 }
82 Err(error) => {
83 error!("[{source_id}] RIS streaming error: {error}");
84 }
85 }
86
87 tokio::select! {
88 _ = sleep(Duration::from_secs(config.reconnect_delay_secs())) => {}
89 _ = shutdown_rx.changed() => {
90 if *shutdown_rx.borrow() {
91 persist_state(&source_id, &state_store, mapper.state()).await?;
92 return Ok(());
93 }
94 }
95 }
96 }
97}
98
99async fn run_single_connection(
100 source_id: &str,
101 config: &RisLiveSourceConfig,
102 dispatchers: &Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
103 state_store: &Option<Arc<dyn StateStoreProvider>>,
104 mapper: &mut GraphMapper,
105 last_persisted: &mut Instant,
106 shutdown_rx: &mut watch::Receiver<bool>,
107) -> Result<()> {
108 ensure_crypto_provider();
109 let url = build_url(config)?;
110 info!("[{source_id}] Connecting to RIS Live: {url}");
111
112 let (mut socket, response) = connect_async(url.as_str())
113 .await
114 .with_context(|| format!("failed to connect to RIS Live at {url}"))?;
115 debug!(
116 "[{source_id}] Connected to RIS Live (status: {})",
117 response.status()
118 );
119
120 let subscribe = RisSubscribeMessage::from_config(config);
121 let payload =
122 serde_json::to_string(&subscribe).context("failed to serialize subscribe payload")?;
123 socket
124 .send(Message::Text(payload))
125 .await
126 .context("failed to send ris_subscribe")?;
127 info!("[{source_id}] Subscription message sent");
128
129 loop {
130 tokio::select! {
131 _ = shutdown_rx.changed() => {
132 if *shutdown_rx.borrow() {
133 let _ = socket.close(None).await;
134 return Ok(());
135 }
136 }
137 frame = socket.next() => {
138 match frame {
139 Some(Ok(Message::Text(text))) => {
140 process_text_frame(source_id, config, dispatchers, state_store, mapper, last_persisted, &text).await?;
141 }
142 Some(Ok(Message::Binary(_))) => {}
143 Some(Ok(Message::Ping(payload))) => {
144 socket.send(Message::Pong(payload)).await.context("failed to send pong")?;
145 }
146 Some(Ok(Message::Pong(_))) => {}
147 Some(Ok(Message::Close(_))) => {
148 info!("[{source_id}] RIS server closed the connection");
149 return Ok(());
150 }
151 Some(Ok(Message::Frame(_))) => {}
152 Some(Err(tungstenite::Error::ConnectionClosed)) => {
153 info!("[{source_id}] RIS connection closed");
154 return Ok(());
155 }
156 Some(Err(error)) => {
157 return Err(anyhow!("websocket read error: {error}"));
158 }
159 None => {
160 info!("[{source_id}] RIS stream ended");
161 return Ok(());
162 }
163 }
164 }
165 }
166 }
167}
168
169async fn process_text_frame(
170 source_id: &str,
171 config: &RisLiveSourceConfig,
172 dispatchers: &Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
173 state_store: &Option<Arc<dyn StateStoreProvider>>,
174 mapper: &mut GraphMapper,
175 last_persisted: &mut Instant,
176 text: &str,
177) -> Result<()> {
178 let incoming: RisIncomingMessage = serde_json::from_str(text).with_context(|| {
179 let truncated: String = text.chars().take(512).collect();
180 format!("failed to parse RIS message wrapper: {truncated}")
181 })?;
182
183 match incoming.msg_type.as_str() {
184 "ris_subscribe_ok" => {
185 info!("[{source_id}] Subscription acknowledged");
186 Ok(())
187 }
188 "ris_error" => {
189 if let Some(payload) = incoming.data {
190 let err: RisErrorData =
191 serde_json::from_value(payload).context("failed to parse ris_error payload")?;
192 Err(anyhow!("RIS server error: {}", err.message))
193 } else {
194 Err(anyhow!("RIS server sent ris_error without payload"))
195 }
196 }
197 "ris_message" => {
198 let payload = incoming
199 .data
200 .ok_or_else(|| anyhow!("ris_message missing payload"))?;
201 let message: RisMessageData =
202 serde_json::from_value(payload).context("failed to parse ris_message payload")?;
203
204 if !config.should_process_timestamp(message_timestamp_millis(&message)) {
205 debug!("[{source_id}] Skipping message due to start_from timestamp");
206 return Ok(());
207 }
208
209 let mut changes = Vec::new();
210 match message.msg_type.as_deref() {
211 Some("UPDATE") => {
212 changes.extend(mapper.process_announcements(&message));
213 changes.extend(mapper.process_withdrawals(&message));
214 }
215 Some("RIS_PEER_STATE") if config.include_peer_state => {
216 changes.extend(mapper.process_peer_state(&message));
217 }
218 Some("RIS_PEER_STATE") => {}
219 Some("OPEN") | Some("KEEPALIVE") | Some("NOTIFICATION") => {}
220 _ => {}
221 }
222
223 if !changes.is_empty() {
224 for change in changes {
225 dispatch_change(source_id, dispatchers, change).await?;
226 }
227 if last_persisted.elapsed() >= PERSIST_INTERVAL {
228 persist_state(source_id, state_store, mapper.state()).await?;
229 *last_persisted = Instant::now();
230 }
231 }
232 Ok(())
233 }
234 "pong" => Ok(()),
235 "ris_rrc_list" => Ok(()),
236 other => {
237 debug!("[{source_id}] Ignoring unsupported message type: {other}");
238 Ok(())
239 }
240 }
241}
242
243async fn dispatch_change(
244 source_id: &str,
245 dispatchers: &Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
246 change: drasi_core::models::SourceChange,
247) -> Result<()> {
248 let mut profiling = ProfilingMetadata::new();
249 profiling.source_send_ns = Some(timestamp_ns());
250
251 let wrapper = SourceEventWrapper::with_profiling(
252 source_id.to_string(),
253 SourceEvent::Change(change),
254 chrono::Utc::now(),
255 profiling,
256 );
257
258 SourceBase::dispatch_from_task(dispatchers.clone(), wrapper, source_id).await
259}
260
261fn ensure_crypto_provider() {
262 let _ = rustls::crypto::ring::default_provider().install_default();
263}
264
265fn build_url(config: &RisLiveSourceConfig) -> Result<Url> {
266 let mut url = Url::parse(&config.websocket_url)
267 .with_context(|| format!("invalid websocket_url '{}'", config.websocket_url))?;
268
269 match url.scheme() {
270 "ws" | "wss" => {}
271 other => {
272 return Err(anyhow!(
273 "websocket_url scheme must be ws or wss, got: {other}"
274 ));
275 }
276 }
277
278 if let Some(client_name) = &config.client_name {
279 let existing: Vec<(String, String)> = url
281 .query_pairs()
282 .filter(|(k, _)| k != "client")
283 .map(|(k, v)| (k.into_owned(), v.into_owned()))
284 .collect();
285 url.query_pairs_mut().clear().extend_pairs(existing);
286 url.query_pairs_mut().append_pair("client", client_name);
287 }
288
289 Ok(url)
290}
291
292async fn load_initial_state(
293 source_id: &str,
294 config: &RisLiveSourceConfig,
295 state_store: &Option<Arc<dyn StateStoreProvider>>,
296) -> Result<StreamState> {
297 if config.clear_state_on_start {
298 if let Some(store) = state_store {
299 store.delete(source_id, STATE_KEY).await.with_context(|| {
300 format!("failed to clear persisted state for source '{source_id}'")
301 })?;
302 }
303 return Ok(StreamState::default());
304 }
305
306 load_state(source_id, state_store).await
307}
308
309async fn load_state(
310 source_id: &str,
311 state_store: &Option<Arc<dyn StateStoreProvider>>,
312) -> Result<StreamState> {
313 let Some(store) = state_store else {
314 return Ok(StreamState::default());
315 };
316
317 let bytes = store
318 .get(source_id, STATE_KEY)
319 .await
320 .with_context(|| format!("failed to read state for source '{source_id}'"))?;
321
322 let Some(bytes) = bytes else {
323 return Ok(StreamState::default());
324 };
325
326 let persisted: PersistedStreamState = serde_json::from_slice(&bytes)
327 .with_context(|| format!("invalid persisted state payload for source '{source_id}'"))?;
328
329 Ok(persisted.into())
330}
331
332async fn persist_state(
333 source_id: &str,
334 state_store: &Option<Arc<dyn StateStoreProvider>>,
335 state: &StreamState,
336) -> Result<()> {
337 let Some(store) = state_store else {
338 return Ok(());
339 };
340
341 let payload =
342 serde_json::to_vec(&PersistedStreamState::from(state)).context("failed to encode state")?;
343 store
344 .set(source_id, STATE_KEY, payload)
345 .await
346 .with_context(|| format!("failed to persist state for source '{source_id}'"))?;
347 Ok(())
348}