chainlink-data-streams-sdk 1.2.1

Chainlink Data Streams client SDK
Documentation
use super::{Stats, StreamError, WebSocketConnection};

use crate::{
    auth::generate_auth_headers,
    config::{Config, WebSocketHighAvailability},
    endpoints::API_V1_WS,
    stream::{DEFAULT_WS_CONNECT_TIMEOUT, MAX_WS_RECONNECT_INTERVAL, MIN_WS_RECONNECT_INTERVAL},
};

use chainlink_data_streams_report::feed_id::ID;

use std::{
    sync::{atomic::Ordering, Arc},
    time::{SystemTime, UNIX_EPOCH},
};
use tokio::{
    net::TcpStream,
    time::{sleep, timeout},
};
use tokio_tungstenite::{
    connect_async, tungstenite::client::IntoClientRequest, MaybeTlsStream,
    WebSocketStream as TungsteniteWebSocketStream,
};
use tracing::{error, info};

fn parse_origins(ws_url: &str) -> Vec<String> {
    ws_url
        .split(',')
        .map(|url| url.trim().to_string())
        .collect()
}

async fn connect_to_origin(
    config: &Config,
    origin: &str,
    feed_ids: &[ID],
) -> Result<TungsteniteWebSocketStream<MaybeTlsStream<TcpStream>>, StreamError> {
    let feed_ids: Vec<String> = feed_ids.iter().map(|id| id.to_hex_string()).collect();
    let feed_ids_joined = feed_ids.join(",");

    let method = "GET";
    let path = format!("{}?feedIDs={}", API_V1_WS, feed_ids_joined.as_str());
    let body = b"";
    let client_id = &config.api_key;
    let user_secret = &config.api_secret;
    let request_timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("System time error")
        .as_millis();

    let headers = generate_auth_headers(
        method,
        &path,
        body,
        client_id,
        user_secret,
        request_timestamp,
    )?;

    let url = format!("{}{}", origin, path);
    let mut request = url.into_client_request().map_err(|e| {
        StreamError::ConnectionError(format!("Failed to create client request: {}", e))
    })?;
    request.headers_mut().extend(headers);

    let connect_future = connect_async(request);

    let (ws_stream, ws_response) = timeout(DEFAULT_WS_CONNECT_TIMEOUT, connect_future)
        .await
        .map_err(|_| StreamError::ConnectionError("WebSocket connection timed out".to_string()))?
        .map_err(|e| StreamError::ConnectionError(format!("Failed to connect: {}", e)))?;

    info!("Connected to WebSocket: {:#?}", ws_response);

    Ok(ws_stream)
}

pub(crate) async fn connect(
    config: &Config,
    feed_ids: &[ID],
    stats: Arc<Stats>,
) -> Result<WebSocketConnection, StreamError> {
    let origins = parse_origins(&config.ws_url);

    if config.ws_ha == WebSocketHighAvailability::Enabled && origins.len() > 1 {
        let mut streams = Vec::new();

        for origin in origins {
            match connect_to_origin(config, &origin, feed_ids).await {
                Ok(stream) => {
                    streams.push(stream);
                    stats.configured_connections.fetch_add(1, Ordering::SeqCst);
                    stats.active_connections.fetch_add(1, Ordering::SeqCst);
                }
                Err(e) => {
                    error!("Failed to connect to origin {}: {:?}", origin, e);
                }
            }
        }

        if streams.is_empty() {
            return Err(StreamError::ConnectionError(
                "Failed to connect to any WebSocket origins".into(),
            ));
        }

        Ok(WebSocketConnection::Multiple(streams))
    } else {
        let origin = origins.first().ok_or_else(|| {
            StreamError::ConnectionError("No WebSocket origin found in config".into())
        })?;

        let stream = connect_to_origin(config, origin, feed_ids).await?;
        stats.configured_connections.fetch_add(1, Ordering::SeqCst);
        stats.active_connections.fetch_add(1, Ordering::SeqCst);

        Ok(WebSocketConnection::Single(stream))
    }
}

pub(crate) async fn try_to_reconnect(
    stats: Arc<Stats>,
    config: &Config,
    feed_ids: &[ID],
) -> Result<TungsteniteWebSocketStream<MaybeTlsStream<TcpStream>>, StreamError> {
    let mut reconnect_attempts = 0;
    let max_reconnect_attempts = config.ws_max_reconnect;
    let origin = config.ws_url.split(',').next().unwrap();
    let mut backoff = MIN_WS_RECONNECT_INTERVAL;

    loop {
        info!("Attempting to reconnect to origin: {}", origin);
        reconnect_attempts += 1;
        match connect_to_origin(config, origin, feed_ids).await {
            Ok(new_stream) => {
                stats.active_connections.fetch_add(1, Ordering::SeqCst);
                return Ok(new_stream);
            }
            Err(e) => {
                error!(
                    "Reconnection attempt {} failed: {:?}.",
                    reconnect_attempts, e
                );

                if reconnect_attempts >= max_reconnect_attempts {
                    error!("Max reconnect attempts reached. Exiting.");
                    return Err(StreamError::ConnectionError(
                        "Max reconnect attempts reached".to_string(),
                    ));
                }

                error!("Retrying in {:?}.", backoff);

                sleep(backoff).await;
                backoff = (backoff * 2).min(MAX_WS_RECONNECT_INTERVAL);
            }
        }
    }
}