use crate::json::{JsonError, JsonValue, parse};
use crate::leader_entry::PubkeyError;
use crate::tls::{TlsClientConfig, connect_tcp_stream, load_tls_client_config, wrap_client_stream};
use crate::url::{ParsedUrl, UrlError, parse_http_url};
use crate::{LeaderEntry, LeaderPubkey, SLOTS_PER_LEADER, ScheduleSnapshot};
use std::collections::BTreeMap;
use std::fmt;
use std::io::{Read, Write};
use std::net::SocketAddr;
use std::time::Duration;
/// QUIC endpoints a cluster node advertises through `getClusterNodes`.
///
/// Either address may be absent; `Default` yields a value with both set to `None`.
#[derive(Clone, Copy, Debug, Default)]
struct TpuAddresses {
    /// Parsed from the node's `tpuQuic` member.
    tpu_quic: Option<SocketAddr>,
    /// Parsed from the node's `tpuForwardsQuic` member.
    tpu_quic_fwd: Option<SocketAddr>,
}
/// Blocking client that pulls the current leader schedule from a JSON-RPC endpoint.
#[derive(Clone, Debug)]
pub struct ScheduleFetcher {
    /// Parsed endpoint URL; every request is POSTed to its path.
    rpc_url: ParsedUrl,
    /// Timeout handed to `connect_tcp_stream` for each request.
    timeout: Duration,
    /// Client TLS configuration, present only when `rpc_url` uses TLS.
    tls_config: Option<TlsClientConfig>,
}
impl ScheduleFetcher {
    /// Creates a fetcher for the JSON-RPC endpoint at `rpc_url`.
    ///
    /// `timeout` is passed as every timeout argument of `connect_tcp_stream` on each request
    /// (presumably connect, read and write timeouts — confirm against `crate::tls`). When the URL
    /// uses TLS, the client TLS configuration is loaded once here and reused by every request.
    ///
    /// # Errors
    ///
    /// Returns [`FetcherError::Url`] if `parse_http_url` rejects `rpc_url`, and
    /// [`FetcherError::Io`] if the TLS client configuration cannot be loaded.
    pub fn new(rpc_url: impl AsRef<str>, timeout: Duration) -> Result<Self, FetcherError> {
        let rpc_url = parse_http_url(rpc_url.as_ref())?;
        // Only TLS URLs need a client config; `transpose` turns `Option<io::Result<_>>` into
        // `io::Result<Option<_>>` so a load failure is propagated.
        let tls_config = rpc_url
            .uses_tls()
            .then(load_tls_client_config)
            .transpose()
            .map_err(FetcherError::Io)?;
        Ok(Self {
            rpc_url,
            timeout,
            tls_config,
        })
    }

    /// Fetches the current epoch's leader schedule and resolves each leader's QUIC endpoints.
    ///
    /// Issues `getEpochInfo`, `getClusterNodes` and `getLeaderSchedule` (for the epoch's first
    /// slot). It then builds a [`ScheduleSnapshot`] holding one [`LeaderEntry`] per
    /// `SLOTS_PER_LEADER` slots of the epoch.
    ///
    /// A window stays `LeaderEntry::EMPTY` when its leader lacks either `tpuQuic` or
    /// `tpuForwardsQuic`, or when `LeaderEntry::new` rejects the leader. Only relative slots that
    /// are multiples of `SLOTS_PER_LEADER` and fall inside the epoch are recorded.
    ///
    /// # Errors
    ///
    /// Propagates any connection, HTTP, JSON or RPC failure. Returns
    /// [`FetcherError::InvalidPubkey`] for an undecodable leader key, and
    /// [`FetcherError::InvalidField`] for a malformed slot list or an epoch too long to index
    /// on this target. A leader's slot list is only validated when that leader has both QUIC
    /// addresses.
    pub fn fetch_current(&self) -> Result<ScheduleSnapshot, FetcherError> {
        let epoch_info = self.fetch_epoch_info()?;
        let cluster_nodes = self.fetch_cluster_nodes()?;
        let schedule = self.fetch_leader_schedule(epoch_info.epoch_start_slot)?;
        // Checked conversion: `as usize` would silently truncate on 32-bit targets.
        let leader_count = usize::try_from(epoch_info.slots_in_epoch.div_ceil(SLOTS_PER_LEADER))
            .map_err(|_| FetcherError::InvalidField("slotsInEpoch"))?;
        let mut leaders = vec![LeaderEntry::EMPTY; leader_count].into_boxed_slice();
        for (pubkey_text, slots) in schedule {
            let pubkey = LeaderPubkey::from_base58(&pubkey_text)?;
            let tpu = cluster_nodes.get(&pubkey).copied().unwrap_or_default();
            // Leaders that cannot be reached over QUIC keep their windows as `EMPTY`.
            let (Some(tpu_quic), Some(tpu_quic_fwd)) = (tpu.tpu_quic, tpu.tpu_quic_fwd) else {
                continue;
            };
            let Ok(entry) = LeaderEntry::new(pubkey, tpu_quic, tpu_quic_fwd) else {
                continue;
            };
            let Some(slot_values) = slots.as_array() else {
                return Err(FetcherError::InvalidField("leader schedule slots"));
            };
            for slot in slot_values {
                let Some(relative_slot) = slot.as_u64() else {
                    return Err(FetcherError::InvalidField("leader schedule relative slot"));
                };
                // Only the first slot of each leader window is recorded.
                if relative_slot % SLOTS_PER_LEADER != 0 {
                    continue;
                }
                // Windows beyond the epoch, or not addressable as an index on this target, are
                // ignored rather than truncated into a wrong (in-range) position.
                let Ok(leader_offset) = usize::try_from(relative_slot / SLOTS_PER_LEADER) else {
                    continue;
                };
                if let Some(leader) = leaders.get_mut(leader_offset) {
                    *leader = entry;
                }
            }
        }
        Ok(ScheduleSnapshot::new(
            epoch_info.epoch,
            epoch_info.epoch_start_slot,
            leaders,
        ))
    }

    /// Calls `getEpochInfo` at `processed` commitment.
    ///
    /// The epoch's first slot is derived as `absoluteSlot - slotIndex` (saturating at zero).
    fn fetch_epoch_info(&self) -> Result<EpochInfo, FetcherError> {
        let json = self.rpc_request(
            r#"{"jsonrpc":"2.0","id":1,"method":"getEpochInfo","params":[{"commitment":"processed"}]}"#,
        )?;
        let value = parse(&json)?;
        let result = rpc_result(&value)?;
        Ok(EpochInfo {
            epoch: required_u64(result, "epoch")?,
            epoch_start_slot: required_u64(result, "absoluteSlot")?
                .saturating_sub(required_u64(result, "slotIndex")?),
            slots_in_epoch: required_u64(result, "slotsInEpoch")?,
        })
    }

    /// Calls `getLeaderSchedule` for the epoch containing `epoch_start_slot`.
    ///
    /// Returns the members of the result object as `(base58 pubkey, slot list)` pairs.
    ///
    /// # Errors
    ///
    /// [`FetcherError::NoSchedule`] when `result` is not a JSON object (e.g. `null`).
    fn fetch_leader_schedule(
        &self,
        epoch_start_slot: u64,
    ) -> Result<Vec<(String, JsonValue)>, FetcherError> {
        let request = format!(
            r#"{{"jsonrpc":"2.0","id":1,"method":"getLeaderSchedule","params":[{epoch_start_slot}]}}"#
        );
        let json = self.rpc_request(&request)?;
        let value = parse(&json)?;
        let result = rpc_result(&value)?;
        let Some(entries) = result.as_object() else {
            return Err(FetcherError::NoSchedule);
        };
        Ok(entries.to_vec())
    }

    /// Calls `getClusterNodes` and indexes each node's QUIC addresses by identity pubkey.
    ///
    /// If the same pubkey appears more than once, the last occurrence wins. Any malformed node
    /// (bad pubkey, non-string address, unparsable address) fails the whole call.
    fn fetch_cluster_nodes(&self) -> Result<BTreeMap<LeaderPubkey, TpuAddresses>, FetcherError> {
        let json = self.rpc_request(r#"{"jsonrpc":"2.0","id":1,"method":"getClusterNodes"}"#)?;
        let value = parse(&json)?;
        let result = rpc_result(&value)?;
        let nodes = result
            .as_array()
            .ok_or(FetcherError::InvalidField("cluster nodes"))?;
        let mut output = BTreeMap::new();
        for node in nodes {
            let pubkey = LeaderPubkey::from_base58(required_str(node, "pubkey")?)?;
            let tpu_quic = optional_socket_addr(node.get("tpuQuic"))?;
            let tpu_quic_fwd = optional_socket_addr(node.get("tpuForwardsQuic"))?;
            output.insert(
                pubkey,
                TpuAddresses {
                    tpu_quic,
                    tpu_quic_fwd,
                },
            );
        }
        Ok(output)
    }

    /// Sends `body` as a JSON `POST` to the configured endpoint and returns the response body.
    ///
    /// Each call opens a new TCP connection (TLS-wrapped when configured) and sends
    /// `Connection: close`, so the response is read until EOF before being parsed.
    ///
    /// NOTE(review): the read is unbounded; a misbehaving endpoint can make this buffer an
    /// arbitrarily large response.
    ///
    /// # Errors
    ///
    /// [`FetcherError::NoAddress`] when the host resolves to nothing, [`FetcherError::Io`] for
    /// socket/TLS failures, and the errors of `parse_http_response`.
    fn rpc_request(&self, body: &str) -> Result<String, FetcherError> {
        let Some(stream) = connect_tcp_stream(
            self.rpc_url.host.as_str(),
            self.rpc_url.port,
            self.timeout,
            self.timeout,
            self.timeout,
        )
        .map_err(FetcherError::Io)?
        else {
            return Err(FetcherError::NoAddress);
        };
        let mut stream = wrap_client_stream(&self.rpc_url, stream, self.tls_config.as_ref())
            .map_err(FetcherError::Io)?;
        // `body.len()` is a byte count, which is what `Content-Length` requires.
        let request = format!(
            "POST {} HTTP/1.1\r\nHost: {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
            self.rpc_url.path,
            self.rpc_url.authority(),
            body.len(),
            body,
        );
        stream
            .write_all(request.as_bytes())
            .map_err(FetcherError::Io)?;
        stream.flush().map_err(FetcherError::Io)?;
        let mut response = Vec::with_capacity(16 * 1024);
        stream
            .read_to_end(&mut response)
            .map_err(FetcherError::Io)?;
        parse_http_response(&response)
    }
}
/// The subset of a `getEpochInfo` result the fetcher needs.
#[derive(Clone, Copy, Debug)]
struct EpochInfo {
    /// The `epoch` member.
    epoch: u64,
    /// First slot of the epoch, computed as `absoluteSlot - slotIndex`.
    epoch_start_slot: u64,
    /// The `slotsInEpoch` member.
    slots_in_epoch: u64,
}
/// Everything that can go wrong while fetching a leader schedule.
#[derive(Debug)]
#[non_exhaustive]
pub enum FetcherError {
    /// The RPC URL could not be parsed.
    Url(UrlError),
    /// Socket, TLS or TLS-config I/O failed.
    Io(std::io::Error),
    /// The HTTP response, or its chunked body, was malformed or not UTF-8.
    InvalidHttpResponse,
    /// The server answered with a status other than 200.
    HttpStatus(u16),
    /// The response body was not valid JSON.
    Json(JsonError),
    /// The JSON-RPC response carried an `error` member.
    Rpc(String),
    /// A required JSON member was absent.
    MissingField(&'static str),
    /// A JSON member had an unexpected type or value.
    InvalidField(&'static str),
    /// A socket-address string did not parse.
    InvalidSocketAddr,
    /// A base58 pubkey could not be decoded.
    InvalidPubkey(PubkeyError),
    /// Resolving the RPC host produced no addresses.
    NoAddress,
    /// `getLeaderSchedule` returned no schedule object.
    NoSchedule,
}
impl fmt::Display for FetcherError {
    /// Writes a single-line, human-readable description of the error.
    ///
    /// Variants wrapping another error print that error's `Display` text with no prefix. The
    /// other variants print a fixed or lightly formatted message.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Url(inner) => return write!(out, "{inner}"),
            Self::Io(inner) => return write!(out, "{inner}"),
            Self::Json(inner) => return write!(out, "{inner}"),
            Self::InvalidPubkey(inner) => return write!(out, "{inner}"),
            Self::HttpStatus(code) => return write!(out, "HTTP status {code}"),
            Self::Rpc(text) => return write!(out, "RPC error: {text}"),
            Self::MissingField(name) => return write!(out, "missing field `{name}`"),
            Self::InvalidField(name) => return write!(out, "invalid field `{name}`"),
            Self::InvalidHttpResponse => "invalid HTTP response",
            Self::InvalidSocketAddr => "invalid socket address",
            Self::NoAddress => "RPC URL resolved to no addresses",
            Self::NoSchedule => "no leader schedule returned",
        };
        out.write_str(message)
    }
}

// `source()` keeps its default (`None`): wrapped errors are already rendered by `Display` above,
// so error reporters that also walk `source()` would otherwise print them twice.
impl std::error::Error for FetcherError {}
impl From<UrlError> for FetcherError {
fn from(error: UrlError) -> Self {
Self::Url(error)
}
}
impl From<JsonError> for FetcherError {
fn from(error: JsonError) -> Self {
Self::Json(error)
}
}
impl From<PubkeyError> for FetcherError {
fn from(error: PubkeyError) -> Self {
Self::InvalidPubkey(error)
}
}
/// Splits a raw HTTP/1.1 response into head and body, requires status `200`, and returns the
/// body as UTF-8 text.
///
/// The body is de-chunked first when a `Transfer-Encoding` header (name matched
/// case-insensitively) mentions `chunked`; otherwise everything after the blank line is the body.
///
/// # Errors
///
/// [`FetcherError::HttpStatus`] for a non-200 status. [`FetcherError::InvalidHttpResponse`] for
/// a missing header terminator, a non-UTF-8 head or body, an unparsable status code, or a
/// malformed chunked body.
fn parse_http_response(response: &[u8]) -> Result<String, FetcherError> {
    const HEAD_END: &[u8] = b"\r\n\r\n";
    let Some(head_len) = response
        .windows(HEAD_END.len())
        .position(|candidate| candidate == HEAD_END)
    else {
        return Err(FetcherError::InvalidHttpResponse);
    };
    let (head_bytes, remainder) = response.split_at(head_len);
    let body = &remainder[HEAD_END.len()..];
    let Ok(head) = std::str::from_utf8(head_bytes) else {
        return Err(FetcherError::InvalidHttpResponse);
    };

    let mut head_lines = head.split("\r\n");
    let status_line = head_lines.next().ok_or(FetcherError::InvalidHttpResponse)?;
    // The status code is the second whitespace-separated token, e.g. `HTTP/1.1 200 OK`.
    let status = match status_line.split_whitespace().nth(1).map(str::parse::<u16>) {
        Some(Ok(code)) => code,
        _ => return Err(FetcherError::InvalidHttpResponse),
    };
    if status != 200 {
        return Err(FetcherError::HttpStatus(status));
    }

    let is_chunked = head_lines.any(|header| match header.split_once(':') {
        Some((name, value)) => {
            name.eq_ignore_ascii_case("transfer-encoding")
                && value.to_ascii_lowercase().contains("chunked")
        }
        None => false,
    });
    let payload = match is_chunked {
        true => decode_chunked(body)?,
        false => body.to_vec(),
    };
    String::from_utf8(payload).map_err(|_| FetcherError::InvalidHttpResponse)
}
/// Decodes an HTTP/1.1 `chunked` transfer-coded body (RFC 9112 §7.1).
///
/// Each chunk is `<hex size>[;extensions]\r\n<data>\r\n`. Decoding stops at the first
/// zero-size chunk; anything after it (trailer fields) is ignored.
///
/// # Errors
///
/// Returns [`FetcherError::InvalidHttpResponse`] when:
/// - a size line is missing or not hex;
/// - a chunk's data is truncated or not followed by CRLF;
/// - a size is too large to address.
///
/// Input comes from the network, so malformed sizes must produce an error, not a panic.
fn decode_chunked(mut body: &[u8]) -> Result<Vec<u8>, FetcherError> {
    let mut output = Vec::with_capacity(body.len());
    loop {
        let line_end = body
            .windows(2)
            .position(|window| window == b"\r\n")
            .ok_or(FetcherError::InvalidHttpResponse)?;
        let size_line = std::str::from_utf8(&body[..line_end])
            .map_err(|_| FetcherError::InvalidHttpResponse)?;
        // Chunk extensions (`;name=value`) may follow the size and carry nothing we need.
        let size_text = size_line
            .split_once(';')
            .map_or(size_line, |(size, _)| size)
            .trim();
        let size = usize::from_str_radix(size_text, 16)
            .map_err(|_| FetcherError::InvalidHttpResponse)?;
        body = &body[line_end + 2..];
        if size == 0 {
            return Ok(output);
        }
        // `size + 2` must not overflow for an adversarial size such as `ffffffffffffffff`.
        let chunk_end = size
            .checked_add(2)
            .ok_or(FetcherError::InvalidHttpResponse)?;
        if body.len() < chunk_end || &body[size..chunk_end] != b"\r\n" {
            return Err(FetcherError::InvalidHttpResponse);
        }
        output.extend_from_slice(&body[..size]);
        body = &body[chunk_end..];
    }
}
fn rpc_result(value: &JsonValue) -> Result<&JsonValue, FetcherError> {
if let Some(error) = value.get("error") {
return Err(FetcherError::Rpc(format!("{error:?}")));
}
value
.get("result")
.ok_or(FetcherError::MissingField("result"))
}
/// Reads the unsigned-integer member `key` of the JSON object `value`.
///
/// # Errors
///
/// [`FetcherError::MissingField`] when `key` is absent, and [`FetcherError::InvalidField`] when
/// it is present but not an unsigned integer. Previously both cases were reported as missing,
/// which misdescribed type errors.
fn required_u64(value: &JsonValue, key: &'static str) -> Result<u64, FetcherError> {
    value
        .get(key)
        .ok_or(FetcherError::MissingField(key))?
        .as_u64()
        .ok_or(FetcherError::InvalidField(key))
}
/// Borrows the string member `key` of the JSON object `value`.
///
/// # Errors
///
/// [`FetcherError::MissingField`] when `key` is absent, and [`FetcherError::InvalidField`] when
/// it is present but not a string. Previously both cases were reported as missing.
fn required_str<'a>(value: &'a JsonValue, key: &'static str) -> Result<&'a str, FetcherError> {
    value
        .get(key)
        .ok_or(FetcherError::MissingField(key))?
        .as_str()
        .ok_or(FetcherError::InvalidField(key))
}
/// Parses an optional socket-address member such as `tpuQuic`.
///
/// An absent member and JSON `null` both yield `Ok(None)`. A string must parse as a
/// [`SocketAddr`], otherwise [`FetcherError::InvalidSocketAddr`] is returned. Any other JSON type
/// yields [`FetcherError::InvalidField`].
fn optional_socket_addr(value: Option<&JsonValue>) -> Result<Option<SocketAddr>, FetcherError> {
    let text = match value {
        Some(JsonValue::String(text)) => text,
        Some(JsonValue::Null) | None => return Ok(None),
        Some(_) => return Err(FetcherError::InvalidField("socket address")),
    };
    match text.parse::<SocketAddr>() {
        Ok(address) => Ok(Some(address)),
        Err(_) => Err(FetcherError::InvalidSocketAddr),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_chunked_response() {
        let response =
            b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n0\r\n\r\n";
        let body = parse_http_response(response).unwrap();
        assert_eq!(body, "test");
    }

    #[test]
    fn parse_multi_chunk_response_with_mixed_case_header() {
        let response = b"HTTP/1.1 200 OK\r\ntransfer-encoding: CHUNKED\r\n\r\n4\r\ntest\r\n3\r\nabc\r\n0\r\n\r\n";
        assert_eq!(parse_http_response(response).unwrap(), "testabc");
    }

    #[test]
    fn parse_identity_response() {
        let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\n{}";
        assert_eq!(parse_http_response(response).unwrap(), "{}");
    }

    #[test]
    fn non_200_status_is_reported() {
        let response = b"HTTP/1.1 503 Service Unavailable\r\n\r\n";
        assert!(matches!(
            parse_http_response(response),
            Err(FetcherError::HttpStatus(503))
        ));
    }

    #[test]
    fn missing_header_terminator_is_rejected() {
        assert!(matches!(
            parse_http_response(b"HTTP/1.1 200 OK\r\n"),
            Err(FetcherError::InvalidHttpResponse)
        ));
    }

    #[test]
    fn truncated_chunk_is_rejected() {
        let response = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nte";
        assert!(matches!(
            parse_http_response(response),
            Err(FetcherError::InvalidHttpResponse)
        ));
    }
}