use crate::endpoint::Endpoint;
use crate::geyser::geyser_client::GeyserClient;
use crate::geyser::subscribe_update::UpdateOneof;
use crate::geyser::{
SubscribeRequest, SubscribeRequestFilterSlots, SubscribeRequestPing,
};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
use tonic::metadata::{Ascii, MetadataKey, MetadataValue};
use tonic::transport::{Channel, ClientTlsConfig};
use tonic::Request;
/// One endpoint's sighting of a slot update.
#[derive(Clone, Debug)]
pub struct SlotArrival {
    /// Name of the endpoint that delivered the update.
    pub endpoint_name: String,
    /// Local monotonic timestamp taken while handling the update.
    pub arrived_at: Instant,
}
/// Per-endpoint summary produced from the recorded slot arrivals.
#[derive(Debug)]
pub struct StreamResult {
    /// Endpoint name, as carried by the arrivals it recorded.
    pub name: String,
    /// Number of compared slots for which this endpoint had the earliest arrival.
    pub first_count: u32,
    /// Mean delay, in milliseconds, behind the earliest arrival of each compared slot.
    ///
    /// Despite the field name this is a lag: `0.0` means the endpoint was always first.
    pub avg_lead_ms: f64,
    /// Number of lag samples averaged into `avg_lead_ms`.
    pub slots_seen: u32,
}
/// Subscribes to slot updates on every endpoint concurrently and compares arrival times.
///
/// One task per endpoint runs `subscribe_slots`, all recording into a shared
/// slot -> arrivals map. The map is polled every 500 ms until either
///
/// * at least `num_slots` slots have more than one recorded arrival, or
/// * every subscriber task has exited (failed to connect, stream closed, ...), in
///   which case no further data can arrive and whatever was collected is reported.
///
/// Any tasks still running are then aborted and the collected arrivals are
/// summarised by `compute_results`.
///
/// Note: while at least one subscriber keeps streaming but slots never collect a
/// second arrival (e.g. only one endpoint is healthy), this keeps waiting.
///
/// # Panics
///
/// Panics if the shared mutex was poisoned by a panicking subscriber task.
pub async fn run_stream_lag(endpoints: Vec<Endpoint>, num_slots: u32) -> Vec<StreamResult> {
    let arrivals: Arc<Mutex<HashMap<u64, Vec<SlotArrival>>>> =
        Arc::new(Mutex::new(HashMap::new()));

    let handles: Vec<JoinHandle<()>> = endpoints
        .into_iter()
        .map(|endpoint| tokio::spawn(subscribe_slots(endpoint, Arc::clone(&arrivals))))
        .collect();

    let target = num_slots as usize;
    loop {
        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

        // The guard is a temporary of this statement, so the lock is released
        // before the next `.await`.
        let complete = arrivals
            .lock()
            .unwrap()
            .values()
            .filter(|seen| seen.len() > 1)
            .count();
        if complete >= target {
            break;
        }

        // Without this check the loop would spin forever once every subscriber
        // has given up (or when there were no endpoints to begin with).
        if handles.iter().all(|handle| handle.is_finished()) {
            eprintln!(
                "all streams ended after {} of {} slots; reporting partial results",
                complete, num_slots
            );
            break;
        }
    }

    for handle in &handles {
        handle.abort();
    }

    compute_results(&arrivals.lock().unwrap())
}
/// Streams slot updates from one endpoint and records when each slot arrived.
///
/// Connects to `endpoint.url` (with TLS using native roots when the URL starts with
/// `https://`), attaches `endpoint.headers` as request metadata, subscribes with a
/// single `"slots"` filter and then loops over the response stream:
///
/// * slot updates are timestamped and appended to `arrivals` under their slot number;
/// * server pings are answered with a request carrying an incrementing ping id.
///
/// Every failure (bad URL, TLS setup, invalid header, connect/subscribe error, stream
/// error) is logged to stderr and ends this task only; the function returns when the
/// stream ends or errors.
///
/// # Panics
///
/// Panics if the `arrivals` mutex is poisoned.
async fn subscribe_slots(
    endpoint: Endpoint,
    arrivals: Arc<Mutex<HashMap<u64, Vec<SlotArrival>>>>,
) {
    let mut channel_builder = match Channel::from_shared(endpoint.url.clone()) {
        Ok(builder) => builder,
        Err(e) => {
            eprintln!("[{}] invalid url: {}", endpoint.name, e);
            return;
        }
    };
    if endpoint.url.starts_with("https://") {
        channel_builder =
            match channel_builder.tls_config(ClientTlsConfig::new().with_native_roots()) {
                Ok(builder) => builder,
                Err(e) => {
                    eprintln!("[{}] tls config failed: {}", endpoint.name, e);
                    return;
                }
            };
    }
    let channel = match channel_builder.connect().await {
        Ok(ch) => ch,
        Err(e) => {
            eprintln!("[{}] connection failed: {}", endpoint.name, e);
            return;
        }
    };
    let mut client = GeyserClient::new(channel);

    // Outgoing half of the bidirectional stream: requests pushed into `tx` are
    // delivered to the server through `req`.
    let (tx, rx) = mpsc::channel::<SubscribeRequest>(32);
    let mut req = Request::new(ReceiverStream::new(rx));
    for (key, value) in &endpoint.headers {
        let name = match MetadataKey::<Ascii>::from_bytes(key.as_bytes()) {
            Ok(name) => name,
            Err(e) => {
                eprintln!(
                    "[{}] invalid header name {:?}: {}",
                    endpoint.name,
                    String::from_utf8_lossy(key.as_bytes()),
                    e
                );
                return;
            }
        };
        let value = match MetadataValue::<Ascii>::try_from(value.as_str()) {
            Ok(value) => value,
            Err(e) => {
                // The value itself is not printed: headers commonly carry credentials.
                eprintln!(
                    "[{}] invalid header value for {:?}: {}",
                    endpoint.name,
                    String::from_utf8_lossy(key.as_bytes()),
                    e
                );
                return;
            }
        };
        req.metadata_mut().insert(name, value);
    }

    let mut filter = HashMap::new();
    filter.insert(
        "slots".to_string(),
        SubscribeRequestFilterSlots {
            filter_by_commitment: Some(true),
            interslot_updates: Some(false),
        },
    );
    let slot_filter = filter.clone();

    // Queue the initial subscription BEFORE opening the stream: `subscribe` only
    // resolves once the server answers, and a server or proxy that waits for the
    // first request message would otherwise never answer. The channel is buffered,
    // so this send completes immediately.
    if tx
        .send(SubscribeRequest {
            slots: filter,
            // NOTE(review): 0 is presumably the "processed" commitment level in the
            // Geyser proto — confirm against the generated enum.
            commitment: Some(0),
            ping: None,
        })
        .await
        .is_err()
    {
        return;
    }

    let mut stream = match client.subscribe(req).await {
        Ok(resp) => resp.into_inner(),
        Err(e) => {
            eprintln!("[{}] subscribe failed: {}", endpoint.name, e);
            return;
        }
    };

    let mut ping_id = 0i32;
    loop {
        let update = match stream.message().await {
            Ok(Some(update)) => update,
            Ok(None) => break,
            Err(e) => {
                eprintln!("[{}] stream error: {}", endpoint.name, e);
                break;
            }
        };
        // Stamp the arrival first so logging and lock contention do not skew the
        // latency comparison between endpoints.
        let received_at = Instant::now();

        match update.update_oneof {
            Some(UpdateOneof::Slot(slot_update)) => {
                arrivals
                    .lock()
                    .unwrap()
                    .entry(slot_update.slot)
                    .or_default()
                    .push(SlotArrival {
                        endpoint_name: endpoint.name.clone(),
                        arrived_at: received_at,
                    });
                eprintln!("[slot] {} got slot {}", endpoint.name, slot_update.slot);
            }
            Some(UpdateOneof::Ping(_)) => {
                ping_id += 1;
                // The reply repeats the slot filter alongside the ping. A failed send
                // means the request stream is gone; the next `message()` call will
                // surface that, so the error is deliberately ignored here.
                let _ = tx
                    .send(SubscribeRequest {
                        slots: slot_filter.clone(),
                        commitment: Some(0),
                        ping: Some(SubscribeRequestPing { id: ping_id }),
                    })
                    .await;
            }
            _ => {}
        }
    }
}
fn compute_results(map: &HashMap<u64, Vec<SlotArrival>>) -> Vec<StreamResult> {
let mut first_counts: HashMap<String, u32> = HashMap::new();
let mut lead_times: HashMap<String, Vec<f64>> = HashMap::new();
let mut all_endpoints: std::collections::HashSet<String> = std::collections::HashSet::new();
for arrivals in map.values() {
if arrivals.len() < 2 {
continue;
}
for arrival in arrivals {
all_endpoints.insert(arrival.endpoint_name.clone());
}
let earliest = arrivals.iter().min_by_key(|a| a.arrived_at).unwrap();
*first_counts.entry(earliest.endpoint_name.clone()).or_insert(0) += 1;
for arrival in arrivals {
let lag_ms = arrival
.arrived_at
.duration_since(earliest.arrived_at)
.as_secs_f64()
* 1000.0;
lead_times
.entry(arrival.endpoint_name.clone())
.or_default()
.push(lag_ms);
}
}
let mut results: Vec<StreamResult> = all_endpoints
.iter()
.map(|name| {
let times = lead_times.get(name).map(|v| v.as_slice()).unwrap_or(&[]);
let avg = if times.is_empty() {
0.0
} else {
times.iter().sum::<f64>() / times.len() as f64
};
StreamResult {
name: name.clone(),
first_count: *first_counts.get(name).unwrap_or(&0),
avg_lead_ms: avg,
slots_seen: times.len() as u32,
}
})
.collect();
results.sort_by(|a, b| b.first_count.cmp(&a.first_count));
results
}