use std::{collections::HashMap, fmt::Debug, sync::Arc};
use linera_base::time::{Duration, Instant};
use tokio::sync::broadcast;
use super::{
cache::SubsumingKey,
request::{RequestKey, RequestResult},
};
use crate::node::NodeError;
/// Registry of requests that are currently being executed, keyed by [`RequestKey`].
///
/// Callers can subscribe to the eventual result of a running request (see
/// `try_subscribe` / `complete_and_broadcast`). Cloning a tracker is cheap and the
/// clones share the same underlying map, because the map is held behind an [`Arc`].
#[derive(Debug, Clone)]
pub(super) struct InFlightTracker<N> {
    /// In-flight entries by request key; `N` is the type of the peers recorded per entry.
    entries: Arc<tokio::sync::RwLock<HashMap<RequestKey, InFlightEntry<N>>>>,
    /// Maximum age an entry may have and still be offered to new subscribers.
    timeout: Duration,
}
impl<N: Clone> InFlightTracker<N> {
pub(super) fn new(timeout: Duration) -> Self {
Self {
entries: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
timeout,
}
}
pub(super) async fn try_subscribe(&self, key: &RequestKey) -> Option<InFlightMatch> {
let in_flight = self.entries.read().await;
if let Some(entry) = in_flight.get(key) {
let elapsed = Instant::now().duration_since(entry.started_at);
if elapsed <= self.timeout {
return Some(InFlightMatch::Exact(Subscribed(entry.sender.subscribe())));
}
}
for (in_flight_key, entry) in in_flight.iter() {
if in_flight_key.subsumes(key) {
let elapsed = Instant::now().duration_since(entry.started_at);
if elapsed <= self.timeout {
return Some(InFlightMatch::Subsuming {
key: in_flight_key.clone(),
outcome: Subscribed(entry.sender.subscribe()),
});
}
}
}
None
}
pub(super) async fn insert_new(&self, key: RequestKey) {
let (sender, _receiver) = broadcast::channel(1);
let mut in_flight = self.entries.write().await;
in_flight.insert(
key,
InFlightEntry {
sender,
started_at: Instant::now(),
alternative_peers: Arc::new(tokio::sync::RwLock::new(Vec::new())),
},
);
}
pub(super) async fn complete_and_broadcast(
&self,
key: &RequestKey,
result: Arc<Result<RequestResult, NodeError>>,
) -> usize {
let mut in_flight = self.entries.write().await;
if let Some(entry) = in_flight.remove(key) {
let waiter_count = entry.sender.receiver_count();
tracing::trace!(
key = ?key,
waiters = waiter_count,
"request completed; broadcasting result to waiters",
);
if waiter_count != 0 {
if let Err(err) = entry.sender.send(result) {
tracing::warn!(
key = ?key,
error = ?err,
"failed to broadcast result to waiters"
);
}
}
return waiter_count;
}
0
}
pub(super) async fn add_alternative_peer(&self, key: &RequestKey, peer: N)
where
N: PartialEq + Eq,
{
if let Some(entry) = self.entries.read().await.get(key) {
{
let mut alt_peers = entry.alternative_peers.write().await;
if !alt_peers.contains(&peer) {
alt_peers.push(peer);
}
}
}
}
pub(super) async fn get_alternative_peers(&self, key: &RequestKey) -> Option<Vec<N>> {
let in_flight = self.entries.read().await;
let entry = in_flight.get(key)?;
let peers = entry.alternative_peers.read().await;
Some(peers.clone())
}
pub(super) async fn remove_alternative_peer(&self, key: &RequestKey, peer: &N)
where
N: PartialEq + Eq,
{
if let Some(entry) = self.entries.read().await.get(key) {
let mut alt_peers = entry.alternative_peers.write().await;
alt_peers.retain(|p| p != peer);
}
}
pub(super) async fn pop_alternative_peer(&self, key: &RequestKey) -> Option<N> {
if let Some(entry) = self.entries.read().await.get(key) {
let mut alt_peers = entry.alternative_peers.write().await;
alt_peers.pop()
} else {
None
}
}
}
/// Outcome of a successful `InFlightTracker::try_subscribe` lookup.
#[derive(Debug)]
pub(super) enum InFlightMatch {
    /// Subscribed to the in-flight request registered under exactly the requested key.
    Exact(Subscribed),
    /// Subscribed to an in-flight request whose key subsumes the requested one.
    ///
    /// NOTE(review): presumably the caller must derive its own answer from this
    /// broader result; confirm against the call site.
    Subsuming {
        /// Key of the in-flight request that was joined.
        key: RequestKey,
        /// Subscription to that request's result.
        outcome: Subscribed,
    },
}
/// Receiving end of an in-flight request's result broadcast.
///
/// Yields the shared result once the request is completed through the tracker.
/// If the entry is dropped without a broadcast (for example, because it was replaced
/// by a newer entry for the same key), the channel reports itself closed instead.
#[derive(Debug)]
pub(super) struct Subscribed(pub(super) broadcast::Receiver<Arc<Result<RequestResult, NodeError>>>);
/// Bookkeeping for a single in-flight request.
#[derive(Debug)]
pub(super) struct InFlightEntry<N> {
    /// Broadcasts the request's result to every subscriber.
    sender: broadcast::Sender<Arc<Result<RequestResult, NodeError>>>,
    /// When the entry was registered; compared against the tracker's timeout.
    started_at: Instant,
    /// Peers recorded for this request; duplicates are rejected on insertion.
    /// NOTE(review): presumably other nodes able to serve the same request — confirm.
    /// It has its own lock, so it is mutated while holding only a read lock on the map.
    alternative_peers: Arc<tokio::sync::RwLock<Vec<N>>>,
}