grpcpulse 0.1.0

Benchmark and compare gRPC endpoints side by side — latency, throughput, and stream lag
Documentation
use crate::endpoint::Endpoint;
use crate::geyser::geyser_client::GeyserClient;
use crate::geyser::subscribe_update::UpdateOneof;
use crate::geyser::{
    SubscribeRequest, SubscribeRequestFilterSlots, SubscribeRequestPing,
};

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
use tonic::metadata::{Ascii, MetadataKey, MetadataValue};
use tonic::transport::{Channel, ClientTlsConfig};
use tonic::Request;

/// A single observation of a slot update arriving from one endpoint.
#[derive(Debug, Clone)]
pub struct SlotArrival {
    /// Name of the endpoint whose stream delivered the slot update.
    pub endpoint_name: String,
    /// Local monotonic timestamp recorded by the subscriber task for this update.
    pub arrived_at: Instant,
}

/// Per-endpoint summary of a stream-lag comparison.
#[derive(Debug)]
pub struct StreamResult {
    /// Endpoint name, as recorded in the slot arrivals.
    pub name: String,
    /// Number of counted slots for which this endpoint had the earliest arrival.
    pub first_count: u32,
    /// Mean delay, in milliseconds, between the earliest arrival of a slot and this
    /// endpoint's arrival (0 ms is contributed for slots this endpoint delivered first).
    pub avg_lead_ms: f64,
    /// Number of arrivals from this endpoint in slots that had at least two arrivals.
    pub slots_seen: u32,
}

/// Measures how quickly each endpoint delivers slot updates relative to the others.
///
/// Spawns one subscriber task per endpoint, all recording into a shared map of
/// `slot -> arrivals`. Every 500 ms the map is inspected; the run ends once at least
/// `num_slots` slots have more than one recorded arrival.
///
/// The run also ends early when every subscriber task has finished (for example because
/// all connections failed or the servers closed their streams): no further arrivals can
/// be recorded at that point, so waiting longer would hang forever. Whatever was
/// collected so far is reported.
///
/// All subscriber tasks are aborted before the results are computed.
///
/// # Panics
///
/// Panics if the shared arrivals mutex has been poisoned by a panicking subscriber.
pub async fn run_stream_lag(endpoints: Vec<Endpoint>, num_slots: u32) -> Vec<StreamResult> {
    let arrivals: Arc<Mutex<HashMap<u64, Vec<SlotArrival>>>> =
        Arc::new(Mutex::new(HashMap::new()));

    let handles: Vec<JoinHandle<()>> = endpoints
        .into_iter()
        .map(|endpoint| {
            let arrivals = Arc::clone(&arrivals);
            tokio::spawn(async move {
                subscribe_slots(endpoint, arrivals).await;
            })
        })
        .collect();

    loop {
        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

        // Sample task liveness *before* counting, so that when every task is done the
        // count below already reflects everything they recorded.
        let all_finished = handles.iter().all(|handle| handle.is_finished());

        // Keep the guard in its own scope so the lock is released before the next await.
        let complete = {
            let map = arrivals.lock().unwrap();
            map.values().filter(|v| v.len() > 1).count()
        };

        if complete >= num_slots as usize {
            break;
        }

        if all_finished {
            eprintln!(
                "[stream-lag] all subscriptions ended after {}/{} slots; reporting partial results",
                complete, num_slots
            );
            break;
        }
    }

    for handle in handles {
        handle.abort();
    }

    compute_results(&arrivals.lock().unwrap())
}

/// Subscribes to slot updates on one endpoint and records when each slot arrives.
///
/// Connects to `endpoint.url` (configuring TLS with native roots for `https://` URLs),
/// attaches `endpoint.headers` as request metadata, opens the bidirectional `subscribe`
/// stream and sends a slot subscription. For every slot update received, a
/// [`SlotArrival`] is appended to `arrivals` under that slot number. Server pings are
/// answered with a ping request carrying an incrementing id.
///
/// Setup failures (invalid URL, TLS configuration, invalid header, connect or subscribe
/// error) are logged to stderr and end the task without panicking. Otherwise the
/// function returns when the server closes the stream or a stream error occurs.
///
/// # Panics
///
/// Panics if the shared arrivals mutex has been poisoned.
async fn subscribe_slots(
    endpoint: Endpoint,
    arrivals: Arc<Mutex<HashMap<u64, Vec<SlotArrival>>>>,
) {
    let mut channel_builder = match Channel::from_shared(endpoint.url.clone()) {
        Ok(builder) => builder,
        Err(e) => {
            eprintln!("[{}] invalid url: {}", endpoint.name, e);
            return;
        }
    };

    if endpoint.url.starts_with("https://") {
        channel_builder =
            match channel_builder.tls_config(ClientTlsConfig::new().with_native_roots()) {
                Ok(builder) => builder,
                Err(e) => {
                    eprintln!("[{}] tls config failed: {}", endpoint.name, e);
                    return;
                }
            };
    }

    let channel = match channel_builder.connect().await {
        Ok(ch) => ch,
        Err(e) => {
            eprintln!("[{}] connection failed: {}", endpoint.name, e);
            return;
        }
    };

    let mut client = GeyserClient::new(channel);

    // bidirectional channel: we send requests, server sends updates
    let (tx, rx) = mpsc::channel::<SubscribeRequest>(32);
    let request_stream = ReceiverStream::new(rx);

    let mut req = Request::new(request_stream);
    for (key, value) in &endpoint.headers {
        let header_name = match MetadataKey::<Ascii>::from_bytes(key.as_bytes()) {
            Ok(name) => name,
            Err(e) => {
                eprintln!("[{}] invalid header name {:?}: {}", endpoint.name, key, e);
                return;
            }
        };
        let header_value = match MetadataValue::<Ascii>::try_from(value.as_str()) {
            Ok(v) => v,
            Err(e) => {
                // The value itself is not printed: headers commonly carry credentials.
                eprintln!(
                    "[{}] invalid header value for {:?}: {}",
                    endpoint.name, key, e
                );
                return;
            }
        };
        req.metadata_mut().insert(header_name, header_value);
    }

    let mut stream = match client.subscribe(req).await {
        Ok(resp) => resp.into_inner(),
        Err(e) => {
            eprintln!("[{}] subscribe failed: {}", endpoint.name, e);
            return;
        }
    };

    // send initial slot subscription
    let mut filter = HashMap::new();
    filter.insert(
        "slots".to_string(),
        SubscribeRequestFilterSlots {
            filter_by_commitment: Some(true),
            interslot_updates: Some(false),
        },
    );
    // Kept so ping replies can repeat the same filter.
    let slot_filter = filter.clone();

    if tx
        .send(SubscribeRequest {
            slots: filter,
            commitment: Some(0), // 0 = processed, most frequent
            ping: None,
        })
        .await
        .is_err()
    {
        return;
    }

    let mut ping_id = 0i32;

    loop {
        match stream.message().await {
            Ok(Some(update)) => match update.update_oneof {
                Some(UpdateOneof::Slot(slot_update)) => {
                    // Take the timestamp before any logging so stderr I/O latency does
                    // not leak into the measured arrival time.
                    let arrived_at = Instant::now();
                    eprintln!("[slot] {} got slot {}", endpoint.name, slot_update.slot);
                    let arrival = SlotArrival {
                        endpoint_name: endpoint.name.clone(),
                        arrived_at,
                    };
                    arrivals
                        .lock()
                        .unwrap()
                        .entry(slot_update.slot)
                        .or_default()
                        .push(arrival);
                }
                Some(UpdateOneof::Ping(_)) => {
                    ping_id += 1;
                    // Best effort: the send result is deliberately ignored here.
                    let _ = tx
                        .send(SubscribeRequest {
                            slots: slot_filter.clone(),
                            commitment: Some(0),
                            ping: Some(SubscribeRequestPing { id: ping_id }),
                        })
                        .await;
                }
                _ => {}
            },
            Ok(None) => break,
            Err(e) => {
                eprintln!("[{}] stream error: {}", endpoint.name, e);
                break;
            }
        }
    }
}

/// Aggregates per-slot arrival records into one [`StreamResult`] per endpoint.
///
/// Only slots with at least two recorded arrivals are counted. For each such slot the
/// earliest arrival scores a "win" for its endpoint, and every arrival contributes its
/// delay behind that earliest arrival (in milliseconds; 0 for the winner) to its
/// endpoint's average. Endpoints that never win are still included, with a
/// `first_count` of 0; endpoints that only appear in uncounted slots are omitted.
///
/// The returned list is ordered by wins (descending). Ties are broken by lower average
/// lag and then by name, so the ordering is deterministic between runs.
fn compute_results(map: &HashMap<u64, Vec<SlotArrival>>) -> Vec<StreamResult> {
    // endpoint name -> (slots won, lag in ms behind the earliest arrival for each slot seen)
    let mut stats: HashMap<&str, (u32, Vec<f64>)> = HashMap::new();

    for arrivals in map.values().filter(|arrivals| arrivals.len() >= 2) {
        let earliest = match arrivals.iter().min_by_key(|a| a.arrived_at) {
            Some(arrival) => arrival,
            // Unreachable given the length filter above; skip rather than panic.
            None => continue,
        };
        stats.entry(earliest.endpoint_name.as_str()).or_default().0 += 1;

        for arrival in arrivals {
            let lag_ms = arrival
                .arrived_at
                .duration_since(earliest.arrived_at)
                .as_secs_f64()
                * 1000.0;
            stats
                .entry(arrival.endpoint_name.as_str())
                .or_default()
                .1
                .push(lag_ms);
        }
    }

    let mut results: Vec<StreamResult> = stats
        .into_iter()
        .map(|(name, (first_count, lags))| {
            let avg_lead_ms = if lags.is_empty() {
                0.0
            } else {
                lags.iter().sum::<f64>() / lags.len() as f64
            };
            StreamResult {
                name: name.to_string(),
                first_count,
                avg_lead_ms,
                slots_seen: lags.len() as u32,
            }
        })
        .collect();

    // Hash-map iteration order is arbitrary, so ties on wins need explicit tie-breakers
    // to keep the report stable.
    results.sort_by(|a, b| {
        b.first_count
            .cmp(&a.first_count)
            .then_with(|| a.avg_lead_ms.total_cmp(&b.avg_lead_ms))
            .then_with(|| a.name.cmp(&b.name))
    });
    results
}