use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;

use futures::pin_mut;
use futures::prelude::*;
use futures::task::AtomicWaker;
use futures::task::Context;
use futures::task::Poll;
use log::debug;
use log::info;
use pin_project::pin_project;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tonic::transport::Channel;

use crate::internal_err;
use crate::proto::pdpb::pd_client::PdClient;
use crate::proto::pdpb::*;
use crate::Result;
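
/// Maximum number of timestamp requests that may be batched into a single `TsoRequest`.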
const MAX_BATCH_SIZE: usize = 64;
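/// Maximum number of request batches allowed to be in flight (sent to PD but not yet
/// answered) before the request stream stops pulling new requests.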
const MAX_PENDING_COUNT: usize = 1 << 16;
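
/// Channel used to send an allocated `Timestamp` back to the caller of `get_timestamp`.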
type TimestampRequest = oneshot::Sender<Timestamp>;
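
/// A cheaply clonable handle that forwards timestamp requests to a background task,
/// which batches them and sends them to PD over a single bidirectional TSO stream.
///
/// A minimal usage sketch (assuming a Tokio runtime, an established `PdClient<Channel>`
/// named `pd`, and a known `cluster_id`; both names are placeholders):
///
/// ```ignore
/// let oracle = TimestampOracle::new(cluster_id, &pd)?;
/// let ts = oracle.clone().get_timestamp().await?;
/// ```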
#[derive(Clone)]
pub(crate) struct TimestampOracle {
request_tx: mpsc::Sender<TimestampRequest>,
}

impl TimestampOracle {
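/// Spawns the background task that drives the TSO stream and returns a handle to it.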
pub(crate) fn new(cluster_id: u64, pd_client: &PdClient<Channel>) -> Result<TimestampOracle> {
let pd_client = pd_client.clone();
let (request_tx, request_rx) = mpsc::channel(MAX_BATCH_SIZE);
tokio::spawn(run_tso(cluster_id, pd_client, request_rx));
Ok(TimestampOracle { request_tx })
}
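
/// Requests a single timestamp. The request is queued, batched with other concurrent
/// requests, and resolved once PD answers the batch containing it.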
pub(crate) async fn get_timestamp(self) -> Result<Timestamp> {
debug!("getting current timestamp");
let (request, response) = oneshot::channel();
self.request_tx
.send(request)
.await
.map_err(|_| internal_err!("TimestampRequest channel is closed"))?;
Ok(response.await?)
}
}
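
/// Drives the TSO stream: `TsoRequestStream` turns incoming oneshot senders into batched
/// `TsoRequest`s, and this loop hands each `TsoResponse` back to the waiting callers.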
async fn run_tso(
cluster_id: u64,
mut pd_client: PdClient<Channel>,
request_rx: mpsc::Receiver<TimestampRequest>,
) -> Result<()> {
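// Batches that have been handed to the TSO stream but not yet answered by PD.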
let pending_requests = Arc::new(Mutex::new(VecDeque::with_capacity(MAX_PENDING_COUNT)));
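// Lets the response loop below wake the request stream after a response is processed.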
let sending_future_waker = Arc::new(AtomicWaker::new());
let request_stream = TsoRequestStream {
cluster_id,
request_rx,
pending_requests: pending_requests.clone(),
self_waker: sending_future_waker.clone(),
};
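
// Open the bidirectional TSO stream: requests are pulled from `request_stream`,
// responses arrive on `responses`.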
let mut responses = pd_client.tso(request_stream).await?.into_inner();
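// Each response settles the oldest pending batch; hand out its timestamps and
// wake the request stream so it can make progress again.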
while let Some(Ok(resp)) = responses.next().await {
{
let mut pending_requests = pending_requests.lock().await;
allocate_timestamps(&resp, &mut pending_requests)?;
}
sending_future_waker.wake();
}
info!("TSO stream terminated");
Ok(())
}
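
/// A batch of timestamp requests together with the `TsoRequest` that was sent for them.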
struct RequestGroup {
tso_request: TsoRequest,
requests: Vec<TimestampRequest>,
}
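
/// A stream of batched `TsoRequest`s fed to PD. Pending batches are recorded in
/// `pending_requests`; `self_waker` is signalled by `run_tso` whenever a response has
/// been processed, so the stream is re-polled when it previously had nothing to send.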
#[pin_project]
struct TsoRequestStream {
cluster_id: u64,
#[pin]
request_rx: mpsc::Receiver<oneshot::Sender<Timestamp>>,
pending_requests: Arc<Mutex<VecDeque<RequestGroup>>>,
self_waker: Arc<AtomicWaker>,
}

impl Stream for TsoRequestStream {
type Item = TsoRequest;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let mut this = self.project();
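// Try to take the pending-request lock without blocking; if it is not immediately
// available, register our waker and wait to be re-polled.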
let pending_requests = this.pending_requests.lock();
pin_mut!(pending_requests);
let mut pending_requests = if let Poll::Ready(pending_requests) = pending_requests.poll(cx)
{
pending_requests
} else {
this.self_waker.register(cx.waker());
return Poll::Pending;
};
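// Drain queued senders into a batch, up to `MAX_BATCH_SIZE` per request and only
// while there is room in the pending queue.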
let mut requests = Vec::new();
while requests.len() < MAX_BATCH_SIZE && pending_requests.len() < MAX_PENDING_COUNT {
match this.request_rx.poll_recv(cx) {
Poll::Ready(Some(sender)) => {
requests.push(sender);
}
Poll::Ready(None) if requests.is_empty() => return Poll::Ready(None),
_ => break,
}
}
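// Emit one `TsoRequest` covering the whole batch and remember the senders so the
// response handler can fan the allocated timestamps back out.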
if !requests.is_empty() {
let req = TsoRequest {
header: Some(RequestHeader {
cluster_id: *this.cluster_id,
sender_id: 0,
}),
count: requests.len() as u32,
dc_location: String::new(),
};
let request_group = RequestGroup {
tso_request: req.clone(),
requests,
};
pending_requests.push_back(request_group);
Poll::Ready(Some(req))
} else {
this.self_waker.register(cx.waker());
Poll::Pending
}
}
}
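
/// Distributes the timestamps covered by one `TsoResponse` to the senders of the oldest
/// pending request group.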
fn allocate_timestamps(
resp: &TsoResponse,
pending_requests: &mut VecDeque<RequestGroup>,
) -> Result<()> {
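// PD returns only the largest allocated timestamp; the batch of `count` timestamps
// ends at `tail_ts.logical`, so earlier requests in the group get `logical - offset`.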
let tail_ts = resp
.timestamp
.as_ref()
.ok_or_else(|| internal_err!("No timestamp in TsoResponse"))?;
let mut offset = resp.count;
if let Some(RequestGroup {
tso_request,
requests,
}) = pending_requests.pop_front()
{
if tso_request.count != offset {
return Err(internal_err!(
"PD gives different number of timestamps than expected"
));
}
for request in requests {
offset -= 1;
let ts = Timestamp {
physical: tail_ts.physical,
logical: tail_ts.logical - offset as i64,
suffix_bits: tail_ts.suffix_bits,
};
let _ = request.send(ts);
}
} else {
return Err(internal_err!("PD gives more TsoResponse than expected"));
};
Ok(())
}