use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{anyhow, Context, Result};
use futures_util::{SinkExt, StreamExt};
use log::{debug, error, info, warn};
use tokio::sync::{watch, RwLock};
use tokio::time::sleep;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, tungstenite};
use url::Url;
use drasi_lib::channels::{ChangeDispatcher, SourceEvent, SourceEventWrapper};
use drasi_lib::profiling::{timestamp_ns, ProfilingMetadata};
use drasi_lib::sources::base::SourceBase;
use drasi_lib::state_store::StateStoreProvider;
use crate::config::RisLiveSourceConfig;
use crate::mapping::{GraphMapper, PersistedStreamState, StreamState};
use crate::messages::{
message_timestamp_millis, RisErrorData, RisIncomingMessage, RisMessageData, RisSubscribeMessage,
};
/// State-store key under which this source's serialized stream state is read,
/// written and cleared.
const STATE_KEY: &str = "ris-live.stream-state.v1";
/// Minimum time between two periodic state saves while messages are streaming.
const PERSIST_INTERVAL: Duration = Duration::from_millis(5_000);
pub async fn run_stream_loop(
source_id: String,
config: RisLiveSourceConfig,
dispatchers: Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
state_store: Option<Arc<dyn StateStoreProvider>>,
mut shutdown_rx: watch::Receiver<bool>,
) -> Result<()> {
let initial_state = load_initial_state(&source_id, &config, &state_store).await?;
let mut mapper = GraphMapper::new(source_id.clone(), initial_state);
let mut last_persisted = Instant::now();
loop {
if *shutdown_rx.borrow() {
persist_state(&source_id, &state_store, mapper.state()).await?;
info!("[{source_id}] RIS stream shutdown requested");
return Ok(());
}
match run_single_connection(
&source_id,
&config,
&dispatchers,
&state_store,
&mut mapper,
&mut last_persisted,
&mut shutdown_rx,
)
.await
{
Ok(()) => {
if *shutdown_rx.borrow() {
persist_state(&source_id, &state_store, mapper.state()).await?;
return Ok(());
}
warn!("[{source_id}] RIS connection ended, reconnecting");
}
Err(error) => {
error!("[{source_id}] RIS streaming error: {error}");
}
}
tokio::select! {
_ = sleep(Duration::from_secs(config.reconnect_delay_secs())) => {}
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
persist_state(&source_id, &state_store, mapper.state()).await?;
return Ok(());
}
}
}
}
}
/// Runs one RIS Live websocket session.
///
/// Connects to the URL produced by [`build_url`] and sends the `ris_subscribe`
/// payload built from `config`. It then processes incoming frames until one of
/// the following happens: the server closes the connection, the stream ends,
/// or shutdown is signalled through `shutdown_rx`. On shutdown a close frame
/// is sent on a best-effort basis.
///
/// # Errors
///
/// Returns an error if:
/// - the URL is invalid;
/// - connecting or subscribing fails;
/// - a websocket read error occurs, or a pong cannot be sent;
/// - [`process_text_frame`] fails for a frame, including when the server
///   sends `ris_error`.
async fn run_single_connection(
    source_id: &str,
    config: &RisLiveSourceConfig,
    dispatchers: &Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
    state_store: &Option<Arc<dyn StateStoreProvider>>,
    mapper: &mut GraphMapper,
    last_persisted: &mut Instant,
    shutdown_rx: &mut watch::Receiver<bool>,
) -> Result<()> {
    ensure_crypto_provider();
    let url = build_url(config)?;
    info!("[{source_id}] Connecting to RIS Live: {url}");
    let (mut socket, response) = connect_async(url.as_str())
        .await
        .with_context(|| format!("failed to connect to RIS Live at {url}"))?;
    debug!(
        "[{source_id}] Connected to RIS Live (status: {})",
        response.status()
    );
    let subscribe = RisSubscribeMessage::from_config(config);
    let payload =
        serde_json::to_string(&subscribe).context("failed to serialize subscribe payload")?;
    socket
        .send(Message::Text(payload))
        .await
        .context("failed to send ris_subscribe")?;
    info!("[{source_id}] Subscription message sent");
    // Once every shutdown sender is dropped, `changed()` resolves immediately
    // with `Err` on each call. Polling it after that would make the select
    // below spin, so the branch is disabled and the session keeps streaming
    // until the connection itself ends.
    let mut shutdown_signal_open = true;
    loop {
        tokio::select! {
            changed = shutdown_rx.changed(), if shutdown_signal_open => {
                if changed.is_err() {
                    shutdown_signal_open = false;
                }
                if *shutdown_rx.borrow() {
                    // Best-effort close: we leave regardless of the outcome.
                    let _ = socket.close(None).await;
                    return Ok(());
                }
            }
            frame = socket.next() => {
                match frame {
                    Some(Ok(Message::Text(text))) => {
                        process_text_frame(source_id, config, dispatchers, state_store, mapper, last_persisted, &text).await?;
                    }
                    Some(Ok(Message::Ping(payload))) => {
                        socket.send(Message::Pong(payload)).await.context("failed to send pong")?;
                    }
                    // Only text frames are parsed; other frame kinds are ignored.
                    Some(Ok(Message::Binary(_) | Message::Pong(_) | Message::Frame(_))) => {}
                    Some(Ok(Message::Close(_))) => {
                        info!("[{source_id}] RIS server closed the connection");
                        return Ok(());
                    }
                    Some(Err(tungstenite::Error::ConnectionClosed)) => {
                        info!("[{source_id}] RIS connection closed");
                        return Ok(());
                    }
                    Some(Err(error)) => {
                        return Err(anyhow!("websocket read error: {error}"));
                    }
                    None => {
                        info!("[{source_id}] RIS stream ended");
                        return Ok(());
                    }
                }
            }
        }
    }
}
/// Handles one text frame received from RIS Live.
///
/// The frame is parsed as a [`RisIncomingMessage`] wrapper and handled by its
/// `msg_type`:
/// - `ris_subscribe_ok`: logged.
/// - `ris_error`: turned into an error, which ends the current connection.
/// - `ris_message`: decoded as [`RisMessageData`]. Messages rejected by
///   `config.should_process_timestamp` are skipped. Otherwise BGP `UPDATE`s
///   are mapped into graph changes and dispatched, as is `RIS_PEER_STATE`
///   when `config.include_peer_state` is set. After dispatching, the mapper
///   state is persisted if at least [`PERSIST_INTERVAL`] has elapsed since
///   `last_persisted`.
/// - `pong`, `ris_rrc_list` and any other type: ignored.
///
/// # Errors
///
/// Returns an error if the frame or its payload cannot be parsed, if the
/// server sent `ris_error`, or if dispatching or persisting fails.
async fn process_text_frame(
    source_id: &str,
    config: &RisLiveSourceConfig,
    dispatchers: &Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
    state_store: &Option<Arc<dyn StateStoreProvider>>,
    mapper: &mut GraphMapper,
    last_persisted: &mut Instant,
    text: &str,
) -> Result<()> {
    let incoming: RisIncomingMessage = serde_json::from_str(text).with_context(|| {
        // Cap the echoed frame so a huge malformed message does not flood logs.
        let truncated: String = text.chars().take(512).collect();
        format!("failed to parse RIS message wrapper: {truncated}")
    })?;
    match incoming.msg_type.as_str() {
        "ris_subscribe_ok" => {
            info!("[{source_id}] Subscription acknowledged");
            Ok(())
        }
        "ris_error" => {
            if let Some(payload) = incoming.data {
                let err: RisErrorData =
                    serde_json::from_value(payload).context("failed to parse ris_error payload")?;
                Err(anyhow!("RIS server error: {}", err.message))
            } else {
                Err(anyhow!("RIS server sent ris_error without payload"))
            }
        }
        "ris_message" => {
            let payload = incoming
                .data
                .ok_or_else(|| anyhow!("ris_message missing payload"))?;
            let message: RisMessageData =
                serde_json::from_value(payload).context("failed to parse ris_message payload")?;
            if !config.should_process_timestamp(message_timestamp_millis(&message)) {
                debug!("[{source_id}] Skipping message due to start_from timestamp");
                return Ok(());
            }
            let mut changes = Vec::new();
            match message.msg_type.as_deref() {
                Some("UPDATE") => {
                    // RFC 4271 §4.3: a prefix present in both the withdrawn
                    // routes and the NLRI of one UPDATE must be treated as
                    // announced. Applying withdrawals before announcements
                    // makes the announcement the final state (this is also the
                    // field order on the wire).
                    changes.extend(mapper.process_withdrawals(&message));
                    changes.extend(mapper.process_announcements(&message));
                }
                Some("RIS_PEER_STATE") if config.include_peer_state => {
                    changes.extend(mapper.process_peer_state(&message));
                }
                // Nothing is mapped from these types:
                // - OPEN, KEEPALIVE and NOTIFICATION;
                // - peer-state messages when they are disabled;
                // - unknown types.
                _ => {}
            }
            if changes.is_empty() {
                return Ok(());
            }
            for change in changes {
                dispatch_change(source_id, dispatchers, change).await?;
            }
            // Throttle state writes; shutdown persists the final state anyway.
            if last_persisted.elapsed() >= PERSIST_INTERVAL {
                persist_state(source_id, state_store, mapper.state()).await?;
                *last_persisted = Instant::now();
            }
            Ok(())
        }
        "pong" | "ris_rrc_list" => Ok(()),
        other => {
            debug!("[{source_id}] Ignoring unsupported message type: {other}");
            Ok(())
        }
    }
}
/// Wraps a single source change in a profiled [`SourceEventWrapper`] and hands
/// it to every registered dispatcher via [`SourceBase::dispatch_from_task`].
///
/// The wrapper carries the send timestamp (`source_send_ns`) and the current
/// UTC time.
async fn dispatch_change(
    source_id: &str,
    dispatchers: &Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
    change: drasi_core::models::SourceChange,
) -> Result<()> {
    let event = SourceEvent::Change(change);
    let mut metadata = ProfilingMetadata::new();
    metadata.source_send_ns = Some(timestamp_ns());
    let stamped = chrono::Utc::now();
    let wrapped = SourceEventWrapper::with_profiling(source_id.to_owned(), event, stamped, metadata);
    SourceBase::dispatch_from_task(Arc::clone(dispatchers), wrapped, source_id).await
}
/// Installs rustls' `ring` crypto provider as the process-wide default.
///
/// This runs before every connection attempt. `install_default` reports an
/// error when a default provider is already installed, which is expected on
/// later calls, so the outcome is intentionally discarded.
fn ensure_crypto_provider() {
    let ring_provider = rustls::crypto::ring::default_provider();
    let _already_installed = ring_provider.install_default();
}
/// Parses `config.websocket_url` and prepares it for connecting.
///
/// Only `ws` and `wss` schemes are accepted. When `config.client_name` is set,
/// any existing `client` query parameters are replaced by a single
/// `client=<client_name>` pair, appended after the remaining parameters.
///
/// # Errors
///
/// Returns an error if the URL cannot be parsed or uses another scheme.
fn build_url(config: &RisLiveSourceConfig) -> Result<Url> {
    let mut url = Url::parse(&config.websocket_url)
        .with_context(|| format!("invalid websocket_url '{}'", config.websocket_url))?;
    let scheme_ok = matches!(url.scheme(), "ws" | "wss");
    if !scheme_ok {
        let other = url.scheme();
        return Err(anyhow!(
            "websocket_url scheme must be ws or wss, got: {other}"
        ));
    }
    let Some(client_name) = &config.client_name else {
        return Ok(url);
    };
    let retained: Vec<(String, String)> = url
        .query_pairs()
        .filter(|(key, _)| key != "client")
        .map(|(key, value)| (key.into_owned(), value.into_owned()))
        .collect();
    {
        // A single serializer rewrites the query: drop everything, re-add the
        // retained pairs, then append the configured client name.
        let mut query = url.query_pairs_mut();
        query.clear();
        query.extend_pairs(retained);
        query.append_pair("client", client_name);
    }
    Ok(url)
}
/// Produces the stream state the mapper should start from.
///
/// When `config.clear_state_on_start` is set, any persisted state for
/// `source_id` is deleted from the store (if one is configured) and a default
/// state is returned. Otherwise the persisted state is loaded via
/// [`load_state`].
///
/// # Errors
///
/// Returns an error if clearing or reading the persisted state fails.
async fn load_initial_state(
    source_id: &str,
    config: &RisLiveSourceConfig,
    state_store: &Option<Arc<dyn StateStoreProvider>>,
) -> Result<StreamState> {
    if !config.clear_state_on_start {
        return load_state(source_id, state_store).await;
    }
    if let Some(store) = state_store.as_ref() {
        store
            .delete(source_id, STATE_KEY)
            .await
            .with_context(|| format!("failed to clear persisted state for source '{source_id}'"))?;
    }
    Ok(StreamState::default())
}
/// Reads the persisted stream state for `source_id`.
///
/// Falls back to a default state when no store is configured or nothing has
/// been saved under [`STATE_KEY`] yet.
///
/// # Errors
///
/// Returns an error if the store read fails or the stored bytes are not a
/// valid JSON-encoded [`PersistedStreamState`].
async fn load_state(
    source_id: &str,
    state_store: &Option<Arc<dyn StateStoreProvider>>,
) -> Result<StreamState> {
    let store = match state_store {
        Some(store) => store,
        None => return Ok(StreamState::default()),
    };
    let stored = store
        .get(source_id, STATE_KEY)
        .await
        .with_context(|| format!("failed to read state for source '{source_id}'"))?;
    match stored {
        None => Ok(StreamState::default()),
        Some(raw) => {
            let persisted: PersistedStreamState = serde_json::from_slice(&raw).with_context(|| {
                format!("invalid persisted state payload for source '{source_id}'")
            })?;
            Ok(persisted.into())
        }
    }
}
/// Saves `state` for `source_id` under [`STATE_KEY`] as JSON.
///
/// Does nothing when no state store is configured.
///
/// # Errors
///
/// Returns an error if encoding the state or writing it to the store fails.
async fn persist_state(
    source_id: &str,
    state_store: &Option<Arc<dyn StateStoreProvider>>,
    state: &StreamState,
) -> Result<()> {
    if let Some(store) = state_store {
        let encoded = serde_json::to_vec(&PersistedStreamState::from(state))
            .context("failed to encode state")?;
        store
            .set(source_id, STATE_KEY, encoded)
            .await
            .with_context(|| format!("failed to persist state for source '{source_id}'"))?;
    }
    Ok(())
}