solana-leader 0.4.0

solana leader library
Documentation
use crate::json::{JsonError, JsonValue, parse};
use crate::leader_entry::PubkeyError;
use crate::tls::{TlsClientConfig, connect_tcp_stream, load_tls_client_config, wrap_client_stream};
use crate::url::{ParsedUrl, UrlError, parse_http_url};
use crate::{LeaderEntry, LeaderPubkey, SLOTS_PER_LEADER, ScheduleSnapshot};
use std::collections::BTreeMap;
use std::fmt;
use std::io::{Read, Write};
use std::net::SocketAddr;
use std::time::Duration;

/// QUIC TPU addresses reported for one node by `getClusterNodes`.
///
/// Either address is `None` when the node's field is absent or `null`.
#[derive(Debug, Clone, Copy, Default)]
struct TpuAddresses {
    /// Parsed from the node's `tpuQuic` field.
    tpu_quic: Option<SocketAddr>,
    /// Parsed from the node's `tpuForwardsQuic` field.
    tpu_quic_fwd: Option<SocketAddr>,
}

/// Blocking client that fetches the current epoch's leader schedule from a
/// JSON-RPC endpoint over plain HTTP or TLS.
///
/// Every request opens a fresh connection; nothing is pooled between calls.
#[derive(Debug, Clone)]
pub struct ScheduleFetcher {
    /// Parsed endpoint; its host, port, path and authority shape each request.
    rpc_url: ParsedUrl,
    /// Passed as every timeout argument to `connect_tcp_stream` on each request.
    timeout: Duration,
    /// Loaded in `new` only when the URL uses TLS; `None` for plain HTTP.
    tls_config: Option<TlsClientConfig>,
}

impl ScheduleFetcher {
    pub fn new(rpc_url: impl AsRef<str>, timeout: Duration) -> Result<Self, FetcherError> {
        let rpc_url = parse_http_url(rpc_url.as_ref())?;
        let tls_config = rpc_url
            .uses_tls()
            .then(load_tls_client_config)
            .transpose()
            .map_err(FetcherError::Io)?;

        Ok(Self {
            rpc_url,
            timeout,
            tls_config,
        })
    }

    pub fn fetch_current(&self) -> Result<ScheduleSnapshot, FetcherError> {
        let epoch_info = self.fetch_epoch_info()?;
        let cluster_nodes = self.fetch_cluster_nodes()?;
        let schedule = self.fetch_leader_schedule(epoch_info.epoch_start_slot)?;

        let leader_count = epoch_info.slots_in_epoch.div_ceil(SLOTS_PER_LEADER) as usize;
        let mut leaders = vec![LeaderEntry::EMPTY; leader_count].into_boxed_slice();

        for (pubkey_text, slots) in schedule {
            let pubkey = LeaderPubkey::from_base58(&pubkey_text)?;
            let tpu = cluster_nodes.get(&pubkey).copied().unwrap_or_default();
            let entry = match (tpu.tpu_quic, tpu.tpu_quic_fwd) {
                (Some(tpu_quic), Some(tpu_quic_fwd)) => {
                    match LeaderEntry::new(pubkey, tpu_quic, tpu_quic_fwd) {
                        Ok(entry) => entry,
                        Err(_) => continue,
                    }
                }
                _ => continue,
            };

            let Some(slot_values) = slots.as_array() else {
                return Err(FetcherError::InvalidField("leader schedule slots"));
            };

            for slot in slot_values {
                let Some(relative_slot) = slot.as_u64() else {
                    return Err(FetcherError::InvalidField("leader schedule relative slot"));
                };

                let leader_offset = (relative_slot / SLOTS_PER_LEADER) as usize;
                if relative_slot % SLOTS_PER_LEADER != 0 || leader_offset >= leaders.len() {
                    continue;
                }
                leaders[leader_offset] = entry;
            }
        }

        Ok(ScheduleSnapshot::new(
            epoch_info.epoch,
            epoch_info.epoch_start_slot,
            leaders,
        ))
    }

    fn fetch_epoch_info(&self) -> Result<EpochInfo, FetcherError> {
        let json = self.rpc_request(
            r#"{"jsonrpc":"2.0","id":1,"method":"getEpochInfo","params":[{"commitment":"processed"}]}"#,
        )?;
        let value = parse(&json)?;
        let result = rpc_result(&value)?;

        Ok(EpochInfo {
            epoch: required_u64(result, "epoch")?,
            epoch_start_slot: required_u64(result, "absoluteSlot")?
                .saturating_sub(required_u64(result, "slotIndex")?),
            slots_in_epoch: required_u64(result, "slotsInEpoch")?,
        })
    }

    fn fetch_leader_schedule(
        &self,
        epoch_start_slot: u64,
    ) -> Result<Vec<(String, JsonValue)>, FetcherError> {
        let request = format!(
            r#"{{"jsonrpc":"2.0","id":1,"method":"getLeaderSchedule","params":[{epoch_start_slot}]}}"#
        );
        let json = self.rpc_request(&request)?;
        let value = parse(&json)?;
        let result = rpc_result(&value)?;
        let Some(entries) = result.as_object() else {
            return Err(FetcherError::NoSchedule);
        };
        Ok(entries.to_vec())
    }

    fn fetch_cluster_nodes(&self) -> Result<BTreeMap<LeaderPubkey, TpuAddresses>, FetcherError> {
        let json = self.rpc_request(r#"{"jsonrpc":"2.0","id":1,"method":"getClusterNodes"}"#)?;
        let value = parse(&json)?;
        let result = rpc_result(&value)?;
        let nodes = result
            .as_array()
            .ok_or(FetcherError::InvalidField("cluster nodes"))?;

        let mut output = BTreeMap::new();
        for node in nodes {
            let pubkey = LeaderPubkey::from_base58(required_str(node, "pubkey")?)?;
            let tpu_quic = optional_socket_addr(node.get("tpuQuic"))?;
            let tpu_quic_fwd = optional_socket_addr(node.get("tpuForwardsQuic"))?;
            output.insert(
                pubkey,
                TpuAddresses {
                    tpu_quic,
                    tpu_quic_fwd,
                },
            );
        }

        Ok(output)
    }

    fn rpc_request(&self, body: &str) -> Result<String, FetcherError> {
        let Some(stream) = connect_tcp_stream(
            self.rpc_url.host.as_str(),
            self.rpc_url.port,
            self.timeout,
            self.timeout,
            self.timeout,
        )
        .map_err(FetcherError::Io)?
        else {
            return Err(FetcherError::NoAddress);
        };
        let mut stream = wrap_client_stream(&self.rpc_url, stream, self.tls_config.as_ref())
            .map_err(FetcherError::Io)?;

        let request = format!(
            "POST {} HTTP/1.1\r\nHost: {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
            self.rpc_url.path,
            self.rpc_url.authority(),
            body.len(),
            body,
        );
        stream
            .write_all(request.as_bytes())
            .map_err(FetcherError::Io)?;
        stream.flush().map_err(FetcherError::Io)?;

        let mut response = Vec::with_capacity(16 * 1024);
        stream
            .read_to_end(&mut response)
            .map_err(FetcherError::Io)?;

        parse_http_response(&response)
    }
}

/// Subset of the `getEpochInfo` result needed to size and anchor a schedule.
#[derive(Debug, Clone, Copy)]
struct EpochInfo {
    /// Epoch number (`epoch`).
    epoch: u64,
    /// Absolute slot of the epoch's first slot: `absoluteSlot - slotIndex`,
    /// saturating at zero.
    epoch_start_slot: u64,
    /// Number of slots in the epoch (`slotsInEpoch`).
    slots_in_epoch: u64,
}

/// Errors produced while fetching or decoding a leader schedule.
#[derive(Debug)]
#[non_exhaustive]
pub enum FetcherError {
    /// The RPC URL could not be parsed.
    Url(UrlError),
    /// I/O failure while loading TLS configuration, connecting, or exchanging
    /// the request and response.
    Io(std::io::Error),
    /// The response was not a well-formed HTTP/1.x response (bad status line,
    /// missing header terminator, bad chunk framing, or non-UTF-8 text).
    InvalidHttpResponse,
    /// The server answered with a status other than 200.
    HttpStatus(u16),
    /// The response body was not valid JSON.
    Json(JsonError),
    /// The JSON-RPC response carried an `error` member; holds its `Debug` rendering.
    Rpc(String),
    /// A required JSON field could not be read.
    MissingField(&'static str),
    /// A JSON value had an unexpected shape or type; names the field or what it
    /// should have been.
    InvalidField(&'static str),
    /// A TPU address string could not be parsed as a socket address.
    InvalidSocketAddr,
    /// A base58 public key could not be decoded.
    InvalidPubkey(PubkeyError),
    /// The RPC host resolved to no addresses.
    NoAddress,
    /// `getLeaderSchedule` did not return a schedule object.
    NoSchedule,
}

impl fmt::Display for FetcherError {
    /// Renders a human-readable description.
    ///
    /// Wrapped errors print only their own message. Payload-carrying variants
    /// format inline; unit variants share a single `write_str` of a static message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Url(inner) => return write!(f, "{inner}"),
            Self::Io(inner) => return write!(f, "{inner}"),
            Self::Json(inner) => return write!(f, "{inner}"),
            Self::InvalidPubkey(inner) => return write!(f, "{inner}"),
            Self::HttpStatus(code) => return write!(f, "HTTP status {code}"),
            Self::Rpc(detail) => return write!(f, "RPC error: {detail}"),
            Self::MissingField(name) => return write!(f, "missing field `{name}`"),
            Self::InvalidField(name) => return write!(f, "invalid field `{name}`"),
            Self::InvalidHttpResponse => "invalid HTTP response",
            Self::InvalidSocketAddr => "invalid socket address",
            Self::NoAddress => "RPC URL resolved to no addresses",
            Self::NoSchedule => "no leader schedule returned",
        };
        f.write_str(message)
    }
}

impl std::error::Error for FetcherError {}

impl From<UrlError> for FetcherError {
    fn from(error: UrlError) -> Self {
        Self::Url(error)
    }
}

impl From<JsonError> for FetcherError {
    fn from(error: JsonError) -> Self {
        Self::Json(error)
    }
}

impl From<PubkeyError> for FetcherError {
    fn from(error: PubkeyError) -> Self {
        Self::InvalidPubkey(error)
    }
}

/// Splits a raw HTTP/1.x response into status line, headers and body, and returns
/// the body as UTF-8 text.
///
/// Only status 200 is accepted. If any `Transfer-Encoding` header mentions
/// `chunked` (case-insensitively), the body is de-chunked. Otherwise the bytes
/// after the blank line are used as-is.
///
/// # Errors
///
/// Returns [`FetcherError::HttpStatus`] for any non-200 status. Returns
/// [`FetcherError::InvalidHttpResponse`] for a missing header terminator, a
/// malformed status line, non-UTF-8 headers or body, or bad chunk framing.
fn parse_http_response(response: &[u8]) -> Result<String, FetcherError> {
    const HEADER_END: &[u8] = b"\r\n\r\n";

    let Some(header_len) = response
        .windows(HEADER_END.len())
        .position(|candidate| candidate == HEADER_END)
    else {
        return Err(FetcherError::InvalidHttpResponse);
    };
    let (head, rest) = response.split_at(header_len);
    let body = &rest[HEADER_END.len()..];

    let Ok(head) = std::str::from_utf8(head) else {
        return Err(FetcherError::InvalidHttpResponse);
    };
    let mut header_lines = head.split("\r\n");

    // Status line looks like `HTTP/1.1 200 OK`; the code is the second token.
    let Some(status_code) = header_lines
        .next()
        .and_then(|line| line.split_whitespace().nth(1))
        .and_then(|code| code.parse::<u16>().ok())
    else {
        return Err(FetcherError::InvalidHttpResponse);
    };
    if status_code != 200 {
        return Err(FetcherError::HttpStatus(status_code));
    }

    let is_chunked = header_lines
        .map(str::to_ascii_lowercase)
        .any(|header| header.starts_with("transfer-encoding:") && header.contains("chunked"));

    let payload = if is_chunked {
        decode_chunked(body)?
    } else {
        body.to_vec()
    };
    String::from_utf8(payload).map_err(|_utf8_error| FetcherError::InvalidHttpResponse)
}

/// Decodes an HTTP/1.1 `Transfer-Encoding: chunked` body into its payload bytes.
///
/// Each chunk is `<hex size>[;extensions]\r\n<data>\r\n`. Decoding stops at the
/// first zero-size chunk, and anything after its size line (the trailer section)
/// is ignored.
///
/// # Errors
///
/// Returns [`FetcherError::InvalidHttpResponse`] in any of these cases:
/// - a size line lacks its CRLF or is not valid hex;
/// - a chunk announces more data than is present;
/// - a chunk's data is not followed by CRLF.
fn decode_chunked(mut body: &[u8]) -> Result<Vec<u8>, FetcherError> {
    const CRLF: &[u8] = b"\r\n";
    let mut output = Vec::with_capacity(body.len());

    loop {
        let line_end = body
            .windows(CRLF.len())
            .position(|window| window == CRLF)
            .ok_or(FetcherError::InvalidHttpResponse)?;
        let size_line = std::str::from_utf8(&body[..line_end])
            .map_err(|_| FetcherError::InvalidHttpResponse)?;
        // Chunk extensions (`;name=value`) are legal and carry nothing we need;
        // only the text before the first `;` is the size.
        let size_text = size_line
            .split_once(';')
            .map_or(size_line, |(size, _extensions)| size)
            .trim();
        let size = usize::from_str_radix(size_text, 16)
            .map_err(|_| FetcherError::InvalidHttpResponse)?;
        body = &body[line_end + CRLF.len()..];

        if size == 0 {
            return Ok(output);
        }

        // Checked slicing: a huge announced size must be rejected, not allowed to
        // overflow `size + 2` or panic on an out-of-range slice.
        let data = body.get(..size).ok_or(FetcherError::InvalidHttpResponse)?;
        let Some(rest) = body[size..].strip_prefix(CRLF) else {
            return Err(FetcherError::InvalidHttpResponse);
        };
        output.extend_from_slice(data);
        body = rest;
    }
}

fn rpc_result(value: &JsonValue) -> Result<&JsonValue, FetcherError> {
    if let Some(error) = value.get("error") {
        return Err(FetcherError::Rpc(format!("{error:?}")));
    }
    value
        .get("result")
        .ok_or(FetcherError::MissingField("result"))
}

/// Reads `key` from `value` as an unsigned integer.
///
/// # Errors
///
/// Returns [`FetcherError::MissingField`] when `key` is absent or `null`.
/// Returns [`FetcherError::InvalidField`] when it is present but not a `u64`.
fn required_u64(value: &JsonValue, key: &'static str) -> Result<u64, FetcherError> {
    required_field(value, key)?
        .as_u64()
        .ok_or(FetcherError::InvalidField(key))
}

/// Reads `key` from `value` as a string slice borrowed from `value`.
///
/// # Errors
///
/// Returns [`FetcherError::MissingField`] when `key` is absent or `null`.
/// Returns [`FetcherError::InvalidField`] when it is present but not a string.
fn required_str<'a>(value: &'a JsonValue, key: &'static str) -> Result<&'a str, FetcherError> {
    required_field(value, key)?
        .as_str()
        .ok_or(FetcherError::InvalidField(key))
}

/// Looks up `key`, treating an absent key and an explicit `null` alike as missing.
///
/// Keeps "missing" distinct from "wrong type" so error messages point at the real
/// problem.
fn required_field<'a>(
    value: &'a JsonValue,
    key: &'static str,
) -> Result<&'a JsonValue, FetcherError> {
    match value.get(key) {
        None | Some(JsonValue::Null) => Err(FetcherError::MissingField(key)),
        Some(field) => Ok(field),
    }
}

/// Parses an optional `"ip:port"` JSON string into a socket address.
///
/// An absent value or JSON `null` yields `Ok(None)`.
///
/// # Errors
///
/// Returns [`FetcherError::InvalidSocketAddr`] when the string does not parse.
/// Returns [`FetcherError::InvalidField`] when the value is neither a string nor
/// `null`.
fn optional_socket_addr(value: Option<&JsonValue>) -> Result<Option<SocketAddr>, FetcherError> {
    let text = match value {
        None | Some(JsonValue::Null) => return Ok(None),
        Some(JsonValue::String(text)) => text,
        Some(_) => return Err(FetcherError::InvalidField("socket address")),
    };
    match text.parse::<SocketAddr>() {
        Ok(address) => Ok(Some(address)),
        Err(_) => Err(FetcherError::InvalidSocketAddr),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_chunked_response() {
        let response =
            b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n0\r\n\r\n";
        let body = parse_http_response(response).unwrap();
        assert_eq!(body, "test");
    }

    /// A body without chunked encoding is returned verbatim.
    #[test]
    fn parse_identity_response() {
        let response =
            b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{\"ok\":true}";
        assert_eq!(parse_http_response(response).unwrap(), "{\"ok\":true}");
    }

    /// Any status other than 200 surfaces as `HttpStatus` with the code.
    #[test]
    fn reject_non_200_status() {
        let response = b"HTTP/1.1 429 Too Many Requests\r\n\r\n";
        assert!(matches!(
            parse_http_response(response),
            Err(FetcherError::HttpStatus(429))
        ));
    }

    /// A response cut off before the blank line is not a valid HTTP response.
    #[test]
    fn reject_missing_header_terminator() {
        assert!(matches!(
            parse_http_response(b"HTTP/1.1 200 OK\r\n"),
            Err(FetcherError::InvalidHttpResponse)
        ));
    }

    /// A chunk announcing more bytes than were received is rejected.
    #[test]
    fn reject_truncated_chunk() {
        let response = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\na\r\nshort\r\n";
        assert!(matches!(
            parse_http_response(response),
            Err(FetcherError::InvalidHttpResponse)
        ));
    }
}