use crate::tls::{
TlsClientConfig, connect_tcp_stream, load_tls_client_config, shared_crypto_provider,
wrap_client_stream,
};
use crate::url::{ParsedUrl, UrlError, parse_ws_url};
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use cpu::{cpu_pause, set_current_thread_affinity};
use sha1::{Digest, Sha1};
use std::fmt;
use std::io::{self, Read, Write};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::thread;
use std::time::{Duration, Instant};
/// Largest incoming frame payload `read_frame` will buffer; bigger frames are rejected with
/// `SlotSubscriberError::FrameTooLarge`.
const MAX_WEBSOCKET_FRAME_PAYLOAD: usize = 1 << 20;
/// Upper bound on the size of the HTTP upgrade response head read during the handshake.
const MAX_WEBSOCKET_HEADER_BYTES: usize = 16 * 1024;
/// Read timeout handed to `connect_tcp_stream` (presumably applied as the socket read timeout —
/// confirm in `crate::tls`). It bounds how long a single blocking read waits, which is also how
/// often the worker gets a chance to observe a cleared `running` flag.
const WEBSOCKET_READ_TIMEOUT: Duration = Duration::from_millis(250);
/// Number of back-to-back read timeouts tolerated before a connection is treated as dead.
const MAX_CONSECUTIVE_TIMEOUTS: usize = 20;
/// Fixed GUID appended to `Sec-WebSocket-Key` when computing `Sec-WebSocket-Accept` (RFC 6455).
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Selects which upstream feed a [`SlotSubscriber`] follows.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SlotSourceConfig {
    /// JSON-RPC `slotSubscribe` over a WebSocket endpoint.
    WebSocket {
        /// Endpoint URL; parsed with `parse_ws_url` when the subscriber is constructed.
        url: String,
    },
    /// gRPC slot stream (only available with the `slot-grpc` feature).
    #[cfg(feature = "slot-grpc")]
    Grpc {
        /// Endpoint URL handed to the gRPC connector.
        url: String,
        /// Commitment level requested from the connector and used to pick which tracked slot
        /// `load_slot` reports.
        commitment: GrpcCommitment,
    },
}
/// Commitment level for the gRPC slot source.
///
/// It is translated to `connector::CommitmentLevel` for the subscription, and it selects which
/// of the tracker's per-commitment slots `GrpcCursor::load_slot` returns.
#[cfg(feature = "slot-grpc")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GrpcCommitment {
    Processed,
    Confirmed,
    Finalized,
}
/// Latest observed slot plus a connection flag.
///
/// The WebSocket worker thread writes it; readers share it through an `Arc`.
#[derive(Debug, Default)]
pub struct SlotState {
    /// Last slot parsed from a notification; `0` until the first one arrives.
    slot: AtomicU64,
    /// Set once the handshake succeeds and `slotSubscribe` has been sent. Cleared before every
    /// (re)connect attempt and on shutdown.
    connected: AtomicBool,
}
impl SlotState {
pub fn new() -> Self {
Self::default()
}
#[inline]
pub fn load(&self) -> u64 {
self.slot.load(Ordering::Acquire)
}
#[inline]
pub fn is_connected(&self) -> bool {
self.connected.load(Ordering::Acquire)
}
}
/// Cheap, cloneable read handle onto whichever source is active.
///
/// Only the shared state is held, never the worker itself.
#[derive(Clone)]
pub(crate) enum SlotCursor {
    /// Shared state written by the WebSocket worker thread.
    WebSocket(Arc<SlotState>),
    /// Shared gRPC tracker plus the commitment level to report.
    #[cfg(feature = "slot-grpc")]
    Grpc(Arc<GrpcCursor>),
}
impl SlotCursor {
    /// Latest slot from the underlying source.
    ///
    /// For gRPC, this is the slot at the configured commitment level.
    #[inline]
    pub(crate) fn load_slot(&self) -> u64 {
        match self {
            Self::WebSocket(state) => state.load(),
            #[cfg(feature = "slot-grpc")]
            Self::Grpc(cursor) => cursor.load_slot(),
        }
    }
    /// Whether the underlying source currently reports a connection.
    #[inline]
    pub(crate) fn is_connected(&self) -> bool {
        match self {
            Self::WebSocket(state) => state.is_connected(),
            #[cfg(feature = "slot-grpc")]
            Self::Grpc(cursor) => cursor.is_connected(),
        }
    }
}
/// Owning handle to the active source; this is what gets started, stopped and awaited.
enum SlotSourceHandle {
    WebSocket(WebSocketSlotSubscriber),
    #[cfg(feature = "slot-grpc")]
    Grpc(GrpcSlotSource),
}
impl SlotSourceHandle {
    /// Starts the wrapped source.
    ///
    /// The gRPC source's `start` returns nothing, so that arm always yields `Ok(())`.
    fn start(&mut self) -> Result<(), SlotSubscriberError> {
        match self {
            Self::WebSocket(source) => source.start(),
            #[cfg(feature = "slot-grpc")]
            Self::Grpc(source) => {
                source.start();
                Ok(())
            }
        }
    }
    /// Stops the wrapped source.
    fn stop(&mut self) {
        match self {
            Self::WebSocket(source) => source.stop(),
            #[cfg(feature = "slot-grpc")]
            Self::Grpc(source) => source.stop(),
        }
    }
    /// Waits up to `timeout` for the wrapped source to report a connection.
    fn wait_for_connection(&self, timeout: Duration) -> bool {
        match self {
            Self::WebSocket(source) => source.wait_for_connection(timeout),
            #[cfg(feature = "slot-grpc")]
            Self::Grpc(source) => source.wait_for_connection(timeout),
        }
    }
}
/// Abstract slot feed, implemented in this file by [`SlotSubscriber`].
pub trait SlotSource: Send {
    /// Begins streaming slot updates.
    fn start(&mut self) -> Result<(), SlotSubscriberError>;
    /// Stops streaming. The WebSocket implementation also joins its worker thread.
    fn stop(&mut self);
    /// Most recently observed slot.
    fn latest_slot(&self) -> u64;
    /// Whether the feed currently reports an active connection.
    fn is_connected(&self) -> bool;
    /// Waits up to `timeout` for the feed to connect; returns `true` if it did.
    fn wait_for_connection(&self, timeout: Duration) -> bool;
}
/// Public slot subscriber.
///
/// It pairs a lock-free read cursor with the owning handle that drives the source. Dropping it
/// stops the source.
pub struct SlotSubscriber {
    /// Read side, shared with the worker.
    cursor: SlotCursor,
    /// Owning side; started and stopped through this.
    handle: SlotSourceHandle,
}
impl SlotSubscriber {
    /// Builds a subscriber for `config` without starting it; call [`SlotSubscriber::start`] to
    /// begin streaming.
    ///
    /// `cpu_core` is forwarded to the source. The WebSocket worker pins its thread to that core
    /// with `set_current_thread_affinity`. `connect_timeout` and `reconnect_backoff` are only used
    /// by the WebSocket source; the gRPC branch does not take them.
    ///
    /// # Errors
    ///
    /// For a WebSocket config:
    /// - [`SlotSubscriberError::Url`] if the URL fails to parse;
    /// - [`SlotSubscriberError::Io`] if TLS client configuration cannot be loaded.
    ///
    /// The gRPC branch does not fail here.
    pub fn new(
        config: SlotSourceConfig,
        cpu_core: Option<usize>,
        connect_timeout: Duration,
        reconnect_backoff: Duration,
    ) -> Result<Self, SlotSubscriberError> {
        match config {
            SlotSourceConfig::WebSocket { url } => {
                let source = WebSocketSlotSubscriber::new(
                    url,
                    cpu_core,
                    connect_timeout,
                    reconnect_backoff,
                )?;
                // The cursor shares the worker's state so reads never touch the handle.
                let cursor = SlotCursor::WebSocket(Arc::clone(&source.state));
                Ok(Self {
                    cursor,
                    handle: SlotSourceHandle::WebSocket(source),
                })
            }
            #[cfg(feature = "slot-grpc")]
            SlotSourceConfig::Grpc { url, commitment } => {
                let source = GrpcSlotSource::new(url, commitment, cpu_core);
                let cursor = SlotCursor::Grpc(Arc::clone(&source.cursor));
                Ok(Self {
                    cursor,
                    handle: SlotSourceHandle::Grpc(source),
                })
            }
        }
    }
    /// Starts the source.
    ///
    /// For WebSocket, this spawns the worker thread. A second call while it is running is a
    /// no-op.
    pub fn start(&mut self) -> Result<(), SlotSubscriberError> {
        self.handle.start()
    }
    /// Stops the source. For WebSocket, this joins the worker thread. `Drop` also calls it.
    pub fn stop(&mut self) {
        self.handle.stop();
    }
    /// Latest slot reported by the source.
    ///
    /// For WebSocket, this is `0` until the first notification arrives.
    #[inline]
    pub fn load_slot(&self) -> u64 {
        self.cursor.load_slot()
    }
    /// Whether the source currently reports a connection.
    #[inline]
    pub fn is_connected(&self) -> bool {
        self.cursor.is_connected()
    }
    /// Waits up to `timeout` for the source to connect; returns `true` if it did.
    pub fn wait_for_connection(&self, timeout: Duration) -> bool {
        self.handle.wait_for_connection(timeout)
    }
    /// Clone of the read cursor for other components in the crate (`managed` feature).
    #[cfg(feature = "managed")]
    pub(crate) fn cursor(&self) -> SlotCursor {
        self.cursor.clone()
    }
}
/// Trait view of [`SlotSubscriber`]: each method forwards to the cursor (reads) or to the
/// owning handle (lifecycle).
impl SlotSource for SlotSubscriber {
    #[inline]
    fn start(&mut self) -> Result<(), SlotSubscriberError> {
        self.handle.start()
    }

    #[inline]
    fn stop(&mut self) {
        self.handle.stop();
    }

    #[inline]
    fn latest_slot(&self) -> u64 {
        self.cursor.load_slot()
    }

    #[inline]
    fn is_connected(&self) -> bool {
        self.cursor.is_connected()
    }

    #[inline]
    fn wait_for_connection(&self, timeout: Duration) -> bool {
        self.handle.wait_for_connection(timeout)
    }
}
impl Drop for SlotSubscriber {
    /// Stops the source when the subscriber goes away, so the WebSocket worker is joined rather
    /// than left running detached.
    fn drop(&mut self) {
        self.handle.stop();
    }
}
/// WebSocket-backed slot source.
///
/// It owns the parsed endpoint and a background worker thread that keeps a `slotSubscribe`
/// subscription alive.
struct WebSocketSlotSubscriber {
    /// Endpoint parsed by `parse_ws_url`.
    url: ParsedUrl,
    /// TLS client config; present only when the URL uses TLS.
    tls_config: Option<TlsClientConfig>,
    /// Slot/connected state shared with the worker and with readers.
    state: Arc<SlotState>,
    /// Run flag shared with the worker; `stop()` clears it to request shutdown.
    running: Arc<AtomicBool>,
    /// Passed to `connect_tcp_stream` on every connection attempt.
    connect_timeout: Duration,
    /// Pause between failed connection attempts.
    reconnect_backoff: Duration,
    /// Core the worker thread pins itself to, if any.
    cpu_core: Option<usize>,
    /// Worker thread handle while running; `stop()` takes and joins it.
    thread_handle: Option<thread::JoinHandle<()>>,
}
impl WebSocketSlotSubscriber {
fn new(
url: String,
cpu_core: Option<usize>,
connect_timeout: Duration,
reconnect_backoff: Duration,
) -> Result<Self, SlotSubscriberError> {
let url = parse_ws_url(&url)?;
let tls_config = url
.uses_tls()
.then(load_tls_client_config)
.transpose()
.map_err(SlotSubscriberError::Io)?;
Ok(Self {
url,
tls_config,
state: Arc::new(SlotState::new()),
running: Arc::new(AtomicBool::new(false)),
connect_timeout,
reconnect_backoff,
cpu_core,
thread_handle: None,
})
}
fn start(&mut self) -> Result<(), SlotSubscriberError> {
if self.running.swap(true, Ordering::AcqRel) {
return Ok(());
}
let url = self.url.clone();
let state = Arc::clone(&self.state);
let running = Arc::clone(&self.running);
let connect_timeout = self.connect_timeout;
let reconnect_backoff = self.reconnect_backoff;
let cpu_core = self.cpu_core;
let tls_config = self.tls_config.clone();
let handle = thread::Builder::new()
.name("slot-websocket".to_string())
.spawn(move || {
if let Some(core) = cpu_core {
let _ = set_current_thread_affinity([core]);
}
run_websocket_loop(
url,
tls_config,
state,
running,
connect_timeout,
reconnect_backoff,
);
})
.map_err(SlotSubscriberError::ThreadSpawn)?;
self.thread_handle = Some(handle);
Ok(())
}
fn stop(&mut self) {
self.running.store(false, Ordering::Release);
self.state.connected.store(false, Ordering::Release);
if let Some(handle) = self.thread_handle.take() {
let _ = handle.join();
}
}
fn wait_for_connection(&self, timeout: Duration) -> bool {
wait_for_connection(&self.state.connected, timeout)
}
}
fn run_websocket_loop(
url: ParsedUrl,
tls_config: Option<TlsClientConfig>,
state: Arc<SlotState>,
running: Arc<AtomicBool>,
connect_timeout: Duration,
reconnect_backoff: Duration,
) {
let mut payload = Vec::with_capacity(1024);
while running.load(Ordering::Acquire) {
state.connected.store(false, Ordering::Release);
let result = connect_and_subscribe(
&url,
tls_config.as_ref(),
&state,
&running,
connect_timeout,
&mut payload,
);
if !running.load(Ordering::Acquire) {
break;
}
if result.is_err() {
thread::sleep(reconnect_backoff);
}
}
state.connected.store(false, Ordering::Release);
}
/// Runs one connection lifetime.
///
/// Steps: TCP connect, optional TLS wrap, WebSocket upgrade, send `slotSubscribe`, then pump
/// frames until `running` is cleared.
///
/// Each parsed slot is published to `state.slot`. Pings are answered with a pong that echoes
/// the payload. A server Close is echoed and reported as
/// [`SlotSubscriberError::ConnectionClosed`]. `MAX_CONSECUTIVE_TIMEOUTS` idle read timeouts in a
/// row end the connection with [`SlotSubscriberError::Timeout`].
///
/// Returns `Ok(())` only when shutdown was requested; every other exit is an error that
/// triggers a reconnect.
fn connect_and_subscribe(
    url: &ParsedUrl,
    tls_config: Option<&TlsClientConfig>,
    state: &SlotState,
    running: &AtomicBool,
    connect_timeout: Duration,
    payload: &mut Vec<u8>,
) -> Result<(), SlotSubscriberError> {
    let tcp = connect_tcp_stream(
        url.host.as_str(),
        url.port,
        connect_timeout,
        WEBSOCKET_READ_TIMEOUT,
        connect_timeout,
    )
    .map_err(SlotSubscriberError::Io)?
    .ok_or(SlotSubscriberError::NoAddress)?;
    let mut stream =
        wrap_client_stream(url, tcp, tls_config).map_err(SlotSubscriberError::Io)?;
    perform_websocket_handshake(&mut stream, url)?;
    write_masked_text_frame(
        &mut stream,
        r#"{"jsonrpc":"2.0","id":1,"method":"slotSubscribe"}"#,
    )?;
    state.connected.store(true, Ordering::Release);

    let mut idle_reads = 0usize;
    while running.load(Ordering::Acquire) {
        let frame = match read_frame(&mut stream, payload) {
            Ok(frame) => frame,
            Err(SlotSubscriberError::Timeout) => {
                idle_reads += 1;
                if idle_reads >= MAX_CONSECUTIVE_TIMEOUTS {
                    return Err(SlotSubscriberError::Timeout);
                }
                continue;
            }
            Err(other) => return Err(other),
        };
        // Any frame at all proves the peer is alive.
        idle_reads = 0;
        match frame {
            Frame::Text | Frame::Binary => {
                if let Some(slot) = parse_slot_notification(payload) {
                    state.slot.store(slot, Ordering::Release);
                }
            }
            Frame::Ping => write_masked_frame(&mut stream, 0xA, payload)?,
            Frame::Pong => {}
            Frame::Close => {
                write_masked_frame(&mut stream, 0x8, payload)?;
                return Err(SlotSubscriberError::ConnectionClosed);
            }
        }
    }
    Ok(())
}
/// Polls `flag` until it is set or `timeout` elapses; returns whether it was set.
///
/// The wait escalates in stages:
/// - busy-spin with `cpu_pause` in batches that double from 16;
/// - once a batch reaches 1024, yield the thread between checks.
///
/// The wait never sleeps, so it keeps a core busy for up to `timeout`.
fn wait_for_connection(flag: &AtomicBool, timeout: Duration) -> bool {
    const SPIN_LIMIT: u32 = 1024;
    let began = Instant::now();
    let mut batch = 16u32;
    while !flag.load(Ordering::Acquire) {
        if began.elapsed() >= timeout {
            return false;
        }
        if batch >= SPIN_LIMIT {
            thread::yield_now();
            continue;
        }
        (0..batch).for_each(|_| cpu_pause());
        batch = batch.saturating_mul(2);
    }
    true
}
/// Sends the HTTP/1.1 upgrade request for `url` and validates the server's 101 response.
///
/// A fresh random `Sec-WebSocket-Key` is generated for each attempt, and the response must
/// echo the matching `Sec-WebSocket-Accept`.
fn perform_websocket_handshake(
    stream: &mut impl ReadWrite,
    url: &ParsedUrl,
) -> Result<(), SlotSubscriberError> {
    let request_key = generate_websocket_key()?;
    let path = &url.path;
    let host = url.authority();
    let request = format!(
        "GET {path} HTTP/1.1\r\nHost: {host}\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {request_key}\r\nSec-WebSocket-Version: 13\r\n\r\n"
    );
    stream
        .write_all(request.as_bytes())
        .and_then(|()| stream.flush())
        .map_err(SlotSubscriberError::Io)?;
    let response = read_http_headers(stream)?;
    validate_websocket_handshake(&response, &request_key)
}
/// Shorthand bound for a duplex byte stream (plain TCP or TLS-wrapped) used by the helpers below.
trait ReadWrite: Read + Write {}
impl<T> ReadWrite for T where T: Read + Write {}
fn read_http_headers(stream: &mut impl ReadWrite) -> Result<String, SlotSubscriberError> {
let mut buffer = Vec::with_capacity(1024);
let mut byte = [0u8; 1];
while !buffer.ends_with(b"\r\n\r\n") {
if buffer.len() >= MAX_WEBSOCKET_HEADER_BYTES {
return Err(SlotSubscriberError::InvalidHandshake);
}
match stream.read(&mut byte) {
Ok(0) => return Err(SlotSubscriberError::InvalidHandshake),
Ok(_) => buffer.push(byte[0]),
Err(error)
if matches!(
error.kind(),
std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
) =>
{
return Err(SlotSubscriberError::Timeout);
}
Err(error) => return Err(SlotSubscriberError::Io(error)),
}
}
String::from_utf8(buffer).map_err(|_| SlotSubscriberError::InvalidHandshake)
}
/// Checks the upgrade response.
///
/// Requirements:
/// - the status line is `HTTP/1.1 101` or `HTTP/1.0 101`;
/// - the last `Sec-WebSocket-Accept` header (name compared case-insensitively, value trimmed)
///   equals the value derived from `request_key`.
///
/// Any other outcome is [`SlotSubscriberError::InvalidHandshake`].
fn validate_websocket_handshake(
    response: &str,
    request_key: &str,
) -> Result<(), SlotSubscriberError> {
    let mut lines = response.split("\r\n");
    let switching_protocols = lines.next().is_some_and(|status| {
        ["HTTP/1.1 101", "HTTP/1.0 101"]
            .iter()
            .any(|&prefix| status.starts_with(prefix))
    });
    if !switching_protocols {
        return Err(SlotSubscriberError::InvalidHandshake);
    }
    let accept = lines
        .filter_map(|line| line.split_once(':'))
        .filter(|(name, _)| name.eq_ignore_ascii_case("sec-websocket-accept"))
        .map(|(_, value)| value.trim())
        .last();
    let expected = expected_websocket_accept(request_key);
    if accept == Some(expected.as_str()) {
        Ok(())
    } else {
        Err(SlotSubscriberError::InvalidHandshake)
    }
}
/// Produces a fresh `Sec-WebSocket-Key`: 16 random bytes, base64-encoded.
fn generate_websocket_key() -> Result<String, SlotSubscriberError> {
    let mut nonce = [0u8; 16];
    fill_websocket_random(&mut nonce).map(|()| BASE64_STANDARD.encode(nonce))
}
/// Computes the `Sec-WebSocket-Accept` value a server must return for `request_key`.
///
/// The value is `base64(SHA-1(key || WEBSOCKET_GUID))`.
fn expected_websocket_accept(request_key: &str) -> String {
    let mut sha1 = Sha1::new();
    for part in [request_key, WEBSOCKET_GUID] {
        sha1.update(part.as_bytes());
    }
    let digest = sha1.finalize();
    BASE64_STANDARD.encode(digest)
}
/// Sends `text` as a single masked text frame (opcode `0x1`).
fn write_masked_text_frame(
    stream: &mut impl ReadWrite,
    text: &str,
) -> Result<(), SlotSubscriberError> {
    const TEXT_OPCODE: u8 = 0x1;
    let bytes = text.as_bytes();
    write_masked_frame(stream, TEXT_OPCODE, bytes)
}
/// Writes one final (FIN set), client-masked frame with `opcode` and flushes the stream.
///
/// Lengths up to 125 use the 7-bit form; longer payloads use the 16-bit extended form. Payloads
/// longer than `u16::MAX` are refused with [`SlotSubscriberError::FrameTooLarge`] before any
/// randomness is drawn. A fresh random 4-byte mask is used per frame.
fn write_masked_frame(
    stream: &mut impl ReadWrite,
    opcode: u8,
    payload: &[u8],
) -> Result<(), SlotSubscriberError> {
    let Ok(length) = u16::try_from(payload.len()) else {
        return Err(SlotSubscriberError::FrameTooLarge);
    };
    let mut mask = [0u8; 4];
    fill_websocket_random(&mut mask)?;

    // Header (max 2 + 2 bytes) + mask key (4 bytes) + payload.
    let mut frame = Vec::with_capacity(payload.len() + 8);
    frame.push(0x80 | opcode);
    match length {
        0..=125 => frame.push(0x80 | length as u8),
        _ => {
            frame.push(0x80 | 126);
            frame.extend_from_slice(&length.to_be_bytes());
        }
    }
    frame.extend_from_slice(&mask);
    frame.extend(
        payload
            .iter()
            .zip(mask.iter().cycle())
            .map(|(byte, key)| byte ^ key),
    );

    stream.write_all(&frame).map_err(SlotSubscriberError::Io)?;
    stream.flush().map_err(SlotSubscriberError::Io)
}
/// Kind of a complete frame returned by `read_frame`. The payload is left in the caller's buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Frame {
    /// Opcode `0x1`.
    Text,
    /// Opcode `0x2`.
    Binary,
    /// Opcode `0x9`; answered with a pong echoing the payload.
    Ping,
    /// Opcode `0xA`.
    Pong,
    /// Opcode `0x8`; the peer is closing the connection.
    Close,
}
fn read_frame(
stream: &mut impl ReadWrite,
payload: &mut Vec<u8>,
) -> Result<Frame, SlotSubscriberError> {
let mut header = [0u8; 2];
read_exact_timeout(stream, &mut header)?;
let fin = header[0] & 0x80 != 0;
if !fin {
return Err(SlotSubscriberError::InvalidFrame);
}
if header[0] & 0x70 != 0 {
return Err(SlotSubscriberError::InvalidFrame);
}
let opcode = header[0] & 0x0F;
let masked = header[1] & 0x80 != 0;
if masked {
return Err(SlotSubscriberError::InvalidFrame);
}
let mut len = (header[1] & 0x7F) as usize;
if len == 126 {
let mut extended = [0u8; 2];
read_exact_timeout(stream, &mut extended)?;
len = u16::from_be_bytes(extended) as usize;
} else if len == 127 {
let mut extended = [0u8; 8];
read_exact_timeout(stream, &mut extended)?;
if extended[0] & 0x80 != 0 {
return Err(SlotSubscriberError::InvalidFrame);
}
let extended_len = u64::from_be_bytes(extended);
if extended_len > usize::MAX as u64 {
return Err(SlotSubscriberError::FrameTooLarge);
}
len = extended_len as usize;
}
if matches!(opcode, 0x8..=0xA) && len > 125 {
return Err(SlotSubscriberError::InvalidFrame);
}
if len > MAX_WEBSOCKET_FRAME_PAYLOAD {
return Err(SlotSubscriberError::FrameTooLarge);
}
payload.clear();
payload.resize(len, 0);
read_exact_timeout(stream, payload.as_mut_slice())?;
match opcode {
0x1 => Ok(Frame::Text),
0x2 => Ok(Frame::Binary),
0x8 => Ok(Frame::Close),
0x9 => Ok(Frame::Ping),
0xA => Ok(Frame::Pong),
_ => Err(SlotSubscriberError::InvalidFrame),
}
}
fn read_exact_timeout(
stream: &mut impl ReadWrite,
buffer: &mut [u8],
) -> Result<(), SlotSubscriberError> {
match stream.read_exact(buffer) {
Ok(()) => Ok(()),
Err(error)
if matches!(
error.kind(),
std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
) =>
{
Err(SlotSubscriberError::Timeout)
}
Err(error) => Err(SlotSubscriberError::Io(error)),
}
}
/// Extracts the slot number from a `slotNotification` payload without a full JSON parse.
///
/// The scan looks for `"params"`, then a `"result"` key after it, then a `"slot"` key after
/// that. It then expects optional whitespace, a colon, optional whitespace and an unsigned
/// integer.
///
/// Returns `None` when:
/// - any key is missing;
/// - the colon is absent;
/// - no digits follow;
/// - the number overflows `u64`.
fn parse_slot_notification(payload: &[u8]) -> Option<u64> {
    const QUOTED_SLOT: &[u8] = br#""slot""#;

    fn is_json_space(byte: u8) -> bool {
        matches!(byte, b' ' | b'\t' | b'\r' | b'\n')
    }

    let params_at = find_json_key(payload, b"params")?;
    let result_at = params_at + find_json_key(&payload[params_at..], b"result")?;
    let slot_at = result_at + find_json_key(&payload[result_at..], b"slot")?;

    let after_key = &payload[slot_at + QUOTED_SLOT.len()..];
    let leading = after_key.iter().take_while(|&&b| is_json_space(b)).count();
    let after_colon = match after_key[leading..].split_first() {
        Some((b':', rest)) => rest,
        _ => return None,
    };
    let gap = after_colon.iter().take_while(|&&b| is_json_space(b)).count();
    let number = &after_colon[gap..];
    let digit_count = number.iter().take_while(|b| b.is_ascii_digit()).count();
    if digit_count == 0 {
        return None;
    }
    number[..digit_count].iter().try_fold(0u64, |acc, &digit| {
        acc.checked_mul(10)?.checked_add(u64::from(digit - b'0'))
    })
}
/// Returns the offset of the first occurrence of `"key"` (with the surrounding quotes) in
/// `payload`, or `None` if it never appears.
///
/// This is a purely lexical search: it does not skip escaped quotes or distinguish keys from
/// string values.
fn find_json_key(payload: &[u8], key: &[u8]) -> Option<usize> {
    let quoted_len = key.len() + 2;
    payload
        .windows(quoted_len)
        .position(|candidate| match candidate {
            [b'"', inner @ .., b'"'] => inner == key,
            _ => false,
        })
}
fn fill_websocket_random(buffer: &mut [u8]) -> Result<(), SlotSubscriberError> {
shared_crypto_provider()
.secure_random
.fill(buffer)
.map_err(|_| {
SlotSubscriberError::Io(io::Error::other("failed to generate websocket randomness"))
})
}
/// Errors produced while constructing, starting or running a slot source.
#[derive(Debug)]
#[non_exhaustive]
pub enum SlotSubscriberError {
    /// The endpoint URL could not be parsed.
    Url(UrlError),
    /// Underlying I/O failure: TLS config loading, socket connect/read/write, or RNG failure.
    Io(std::io::Error),
    /// The OS refused to spawn the worker thread.
    ThreadSpawn(std::io::Error),
    /// The upgrade response was not accepted. Possible causes:
    /// - the status was not 101;
    /// - the accept key was missing or mismatched;
    /// - the head was oversized or not UTF-8;
    /// - the peer closed the stream mid-handshake.
    InvalidHandshake,
    /// A server frame violated what this client accepts. Possible causes:
    /// - it was fragmented;
    /// - RSV bits were set;
    /// - it was masked;
    /// - it was an oversized control frame;
    /// - it had an unknown opcode.
    InvalidFrame,
    /// An incoming payload exceeded the frame cap, or an outgoing payload exceeded `u16::MAX`.
    FrameTooLarge,
    /// The server sent a Close frame.
    ConnectionClosed,
    /// A read timed out; in the frame loop, raised after too many consecutive idle timeouts.
    Timeout,
    /// `connect_tcp_stream` produced no stream (presumably the host resolved to no addresses).
    NoAddress,
}
impl fmt::Display for SlotSubscriberError {
    /// Wrapped errors are rendered transparently; protocol-level variants use fixed text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::Url(inner) => return write!(f, "{inner}"),
            Self::Io(inner) | Self::ThreadSpawn(inner) => return write!(f, "{inner}"),
            Self::InvalidHandshake => "invalid websocket handshake",
            Self::InvalidFrame => "invalid websocket frame",
            Self::FrameTooLarge => "websocket frame too large",
            Self::ConnectionClosed => "websocket peer closed the connection",
            Self::Timeout => "slot source timed out",
            Self::NoAddress => "slot source URL resolved to no addresses",
        };
        f.write_str(text)
    }
}
impl std::error::Error for SlotSubscriberError {}
impl From<UrlError> for SlotSubscriberError {
fn from(error: UrlError) -> Self {
Self::Url(error)
}
}
/// Read handle for the gRPC source: the shared tracker plus the commitment level to report.
#[cfg(feature = "slot-grpc")]
pub(crate) struct GrpcCursor {
    /// Tracker updated by the connector's subscriber (owned by `GrpcSlotSource`).
    tracker: Arc<connector::SlotTracker>,
    /// Selects which per-commitment slot `load_slot` returns.
    commitment: GrpcCommitment,
}
#[cfg(feature = "slot-grpc")]
impl GrpcCursor {
    /// Slot the tracker holds for the configured commitment level.
    #[inline]
    fn load_slot(&self) -> u64 {
        match self.commitment {
            GrpcCommitment::Processed => self.tracker.load_processed(),
            GrpcCommitment::Confirmed => self.tracker.load_confirmed(),
            GrpcCommitment::Finalized => self.tracker.load_finalized(),
        }
    }
    /// Connection flag as reported by the tracker.
    #[inline]
    fn is_connected(&self) -> bool {
        self.tracker.is_connected()
    }
}
/// gRPC slot source: the connector's subscriber plus the cursor readers use.
#[cfg(feature = "slot-grpc")]
struct GrpcSlotSource {
    /// Shared read handle; cloned into `SlotCursor::Grpc`.
    cursor: Arc<GrpcCursor>,
    /// Connector-side subscriber that feeds the tracker.
    subscriber: connector::GrpcSlotSubscriber,
}
#[cfg(feature = "slot-grpc")]
impl GrpcSlotSource {
    /// Creates a tracker shared between the connector subscriber and a new cursor.
    ///
    /// The local commitment enum is mapped to `connector::CommitmentLevel`. Nothing is started
    /// here. NOTE(review): connection/threading behaviour lives in `connector` and is not
    /// visible from this file.
    fn new(url: String, commitment: GrpcCommitment, cpu_core: Option<usize>) -> Self {
        let tracker = Arc::new(connector::SlotTracker::new());
        let subscriber = connector::GrpcSlotSubscriber::new(
            url,
            Arc::clone(&tracker),
            match commitment {
                GrpcCommitment::Processed => connector::CommitmentLevel::Processed,
                GrpcCommitment::Confirmed => connector::CommitmentLevel::Confirmed,
                GrpcCommitment::Finalized => connector::CommitmentLevel::Finalized,
            },
            cpu_core,
        );
        Self {
            cursor: Arc::new(GrpcCursor {
                tracker,
                commitment,
            }),
            subscriber,
        }
    }
    /// Starts the connector subscriber.
    fn start(&mut self) {
        self.subscriber.start();
    }
    /// Stops the connector subscriber.
    fn stop(&mut self) {
        self.subscriber.stop();
    }
    /// Delegates the connection wait to the connector subscriber.
    fn wait_for_connection(&self, timeout: Duration) -> bool {
        self.subscriber.wait_for_connection(timeout)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn parse_slot_from_notification() {
        let payload = br#"{"jsonrpc":"2.0","method":"slotNotification","params":{"result":{"parent":1,"root":2,"slot":123456}}}"#;
        assert_eq!(parse_slot_notification(payload), Some(123456));
    }

    #[test]
    fn ignore_unrelated_slot_keys_before_result_slot() {
        let payload = br#"{"jsonrpc":"2.0","method":"slotNotification","params":{"slotIndex":7,"result":{"parent":1,"root":2,"slot":123456}}}"#;
        assert_eq!(parse_slot_notification(payload), Some(123456));
    }

    #[test]
    fn ignore_subscription_confirmation() {
        // The reply to `slotSubscribe` carries only the subscription id.
        let payload = br#"{"jsonrpc":"2.0","result":42,"id":1}"#;
        assert_eq!(parse_slot_notification(payload), None);
    }

    #[test]
    fn reject_slot_value_that_overflows_u64() {
        let payload = br#"{"params":{"result":{"slot":99999999999999999999}}}"#;
        assert_eq!(parse_slot_notification(payload), None);
    }

    #[test]
    fn find_json_key_requires_exact_quoted_key() {
        assert_eq!(
            find_json_key(br#"{"slotIndex":1,"slot":2}"#, b"slot"),
            Some(15)
        );
    }

    #[test]
    fn websocket_accept_matches_rfc_example() {
        assert_eq!(
            expected_websocket_accept("dGhlIHNhbXBsZSBub25jZQ=="),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }

    #[test]
    fn accept_valid_handshake_and_reject_non_101() {
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let accepted = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n";
        assert!(validate_websocket_handshake(accepted, key).is_ok());
        let refused = "HTTP/1.1 200 OK\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n";
        assert!(matches!(
            validate_websocket_handshake(refused, key),
            Err(SlotSubscriberError::InvalidHandshake)
        ));
    }

    #[test]
    fn read_unmasked_text_frame() {
        let mut payload = Vec::new();
        let mut stream = Cursor::new(vec![0x81, 0x02, b'h', b'i']);
        assert_eq!(read_frame(&mut stream, &mut payload).unwrap(), Frame::Text);
        assert_eq!(payload, b"hi".to_vec());
    }

    #[test]
    fn read_frame_with_16_bit_length() {
        let mut bytes = vec![0x82, 126, 0x00, 0x80];
        bytes.extend(std::iter::repeat(0xAB).take(128));
        let mut payload = Vec::new();
        let mut stream = Cursor::new(bytes);
        assert_eq!(read_frame(&mut stream, &mut payload).unwrap(), Frame::Binary);
        assert_eq!(payload.len(), 128);
        assert!(payload.iter().all(|&byte| byte == 0xAB));
    }

    #[test]
    fn reject_frame_with_rsv_bits() {
        let mut payload = Vec::new();
        let mut stream = Cursor::new(vec![0xC1, 0x00]);
        assert!(matches!(
            read_frame(&mut stream, &mut payload),
            Err(SlotSubscriberError::InvalidFrame)
        ));
    }

    #[test]
    fn reject_fragmented_frame() {
        let mut payload = Vec::new();
        let mut stream = Cursor::new(vec![0x01, 0x00]);
        assert!(matches!(
            read_frame(&mut stream, &mut payload),
            Err(SlotSubscriberError::InvalidFrame)
        ));
    }

    #[test]
    fn reject_masked_server_frame() {
        let mut payload = Vec::new();
        let mut stream = Cursor::new(vec![0x81, 0x80, 0x00, 0x00, 0x00, 0x00]);
        assert!(matches!(
            read_frame(&mut stream, &mut payload),
            Err(SlotSubscriberError::InvalidFrame)
        ));
    }

    #[test]
    fn reject_oversized_control_frame() {
        let mut payload = Vec::new();
        // Ping with a 126-byte extended length: control frames are capped at 125.
        let mut stream = Cursor::new(vec![0x89, 126, 0x00, 0x7E]);
        assert!(matches!(
            read_frame(&mut stream, &mut payload),
            Err(SlotSubscriberError::InvalidFrame)
        ));
    }

    #[test]
    fn truncated_frame_is_an_io_error_not_a_timeout() {
        let mut payload = Vec::new();
        let mut stream = Cursor::new(vec![0x81, 0x05, b'h']);
        assert!(matches!(
            read_frame(&mut stream, &mut payload),
            Err(SlotSubscriberError::Io(_))
        ));
    }

    #[test]
    fn wait_for_connection_times_out() {
        let flag = AtomicBool::new(false);
        assert!(!wait_for_connection(&flag, Duration::from_millis(1)));
    }
}