use std::fmt;
use std::future::{Future, IntoFuture};
use std::pin::Pin;
use crate::client::Transport;
use crate::error::{ErrKind, Error, Result};
use crate::pb::{self, watch_client::WatchClient as PbWatchClient};
use crate::utils::build_prefix_end;
use crate::Client;
use tokio::sync::mpsc::{channel, Sender};
use tonic::codec::Streaming;
/// Capacity of the bounded channel that carries outbound `WatchRequest`s for one watch stream.
///
/// A single slot is enough to enqueue the initial create request before the gRPC stream
/// is opened; later sends wait until the stream has taken the previous request.
const MPSC_CHANNEL_SIZE: usize = 1;
#[derive(Debug, Clone)]
pub struct WatchClient {
inner: PbWatchClient<Transport>,
}
impl WatchClient {
pub(crate) fn new(transport: Transport) -> Self {
let inner = PbWatchClient::new(transport);
WatchClient { inner }
}
pub fn with_client(client: &Client) -> Self {
Self::new(client.transport.clone())
}
pub fn do_watch(&mut self, key: impl AsRef<[u8]>) -> DoCreateWatch {
DoCreateWatch::new(key, self.clone())
}
pub async fn watch(
&mut self,
request: impl tonic::IntoStreamingRequest<Message = pb::WatchRequest>,
) -> Result<tonic::codec::Streaming<pb::WatchResponse>> {
Ok(self.inner.watch(request).await?.into_inner())
}
pub async fn watch_key(&mut self, key: impl AsRef<[u8]>) -> Result<Watcher> {
self.do_watch(key).await
}
pub async fn watch_prefix(&mut self, key: impl AsRef<[u8]>) -> Result<Watcher> {
self.do_watch(key).with_prefix().await
}
}
pub struct DoCreateWatch {
pub request: pb::WatchCreateRequest,
client: WatchClient,
}
impl DoCreateWatch {
pub fn new(key: impl AsRef<[u8]>, client: WatchClient) -> Self {
DoCreateWatch {
request: pb::WatchCreateRequest::new(key),
client,
}
}
async fn send(self) -> Result<Watcher> {
let DoCreateWatch {
request,
mut client,
} = self;
let create_watch = pb::watch_request::RequestUnion::CreateRequest(request);
let create_req = pb::WatchRequest {
request_union: Some(create_watch),
};
let (req_tx, req_rx) = channel::<pb::WatchRequest>(MPSC_CHANNEL_SIZE);
req_tx
.send(create_req)
.await
.map_err(|err| Error::new(ErrKind::WatchRequestFailed, err))?;
let rx = tokio_stream::wrappers::ReceiverStream::new(req_rx);
let mut resp = client.watch(rx).await?;
let watch_id = match resp.message().await? {
Some(msg) => msg.watch_id,
None => return Err(Error::from_kind(ErrKind::WatchStartFailed)),
};
let watcher = Watcher::new(watch_id, req_tx, resp);
Ok(watcher)
}
pub fn with_range_end(mut self, end: impl AsRef<[u8]>) -> Self {
self.request.range_end = end.as_ref().to_vec();
self
}
pub fn with_prefix(mut self) -> Self {
self.request.range_end = build_prefix_end(&self.request.key);
self
}
pub fn with_prev_kv(mut self) -> Self {
self.request.prev_kv = true;
self
}
}
impl fmt::Debug for DoCreateWatch {
    /// Shows only the pending create request; the client handle is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("DoCreateWatch");
        builder.field("request", &self.request);
        builder.finish()
    }
}
/// Lets a [`DoCreateWatch`] builder be `.await`ed directly to open the watch.
impl IntoFuture for DoCreateWatch {
    type Output = Result<Watcher>;
    // The boxed future is `Send` so that code awaiting a watch creation (such as
    // `WatchClient::watch_key`) can run on multi-threaded executors, e.g. inside
    // `tokio::spawn`. Without the bound every such caller became `!Send`.
    type IntoFuture = Pin<Box<dyn Future<Output = Result<Watcher>> + Send>>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
impl pb::WatchCreateRequest {
pub fn new(key: impl AsRef<[u8]>) -> Self {
pb::WatchCreateRequest {
key: key.as_ref().to_vec(),
..Default::default()
}
}
}
impl pb::WatchRequest {
    /// Wraps a create request for exactly `key` (no range end or other options).
    pub fn create_watch(key: impl AsRef<[u8]>) -> Self {
        Self::from(pb::WatchCreateRequest::new(key))
    }

    /// Wraps an empty progress request.
    pub fn progress_watch() -> Self {
        Self::from(pb::WatchProgressRequest {})
    }

    /// Wraps a cancel request for the watch identified by `watch_id`.
    pub fn cancel_watch(watch_id: i64) -> Self {
        Self::from(pb::WatchCancelRequest { watch_id })
    }
}
impl From<pb::WatchCreateRequest> for pb::WatchRequest {
    /// Places the create request into the `request_union` envelope.
    fn from(create: pb::WatchCreateRequest) -> Self {
        Self {
            request_union: Some(pb::watch_request::RequestUnion::CreateRequest(create)),
        }
    }
}
impl From<pb::WatchProgressRequest> for pb::WatchRequest {
    /// Places the progress request into the `request_union` envelope.
    fn from(progress: pb::WatchProgressRequest) -> Self {
        Self {
            request_union: Some(pb::watch_request::RequestUnion::ProgressRequest(progress)),
        }
    }
}
impl From<pb::WatchCancelRequest> for pb::WatchRequest {
    /// Places the cancel request into the `request_union` envelope.
    fn from(cancel: pb::WatchCancelRequest) -> Self {
        Self {
            request_union: Some(pb::watch_request::RequestUnion::CancelRequest(cancel)),
        }
    }
}
pub struct Watcher {
watch_id: i64,
req_tx: Sender<pb::WatchRequest>,
inbound: Streaming<crate::pb::WatchResponse>,
}
impl Watcher {
pub(crate) fn new(
watch_id: i64,
req_tx: Sender<pb::WatchRequest>,
inbound: Streaming<crate::pb::WatchResponse>,
) -> Watcher {
Watcher {
watch_id,
req_tx,
inbound,
}
}
pub async fn progress(&mut self) -> Result<()> {
let request = pb::WatchRequest::progress_watch();
self.req_tx
.send(request)
.await
.map_err(|err| Error::new(ErrKind::WatchRequestFailed, err))?;
Ok(())
}
pub async fn cancel(&mut self) -> Result<()> {
let request = pb::WatchRequest::cancel_watch(self.watch_id);
self.req_tx
.send(request)
.await
.map_err(|err| Error::new(ErrKind::WatchRequestFailed, err))?;
Ok(())
}
pub async fn message(&mut self) -> Result<Option<pb::WatchResponse>> {
match self.inbound.message().await? {
Some(resp) => Ok(Some(resp)),
None => Ok(None),
}
}
}
impl fmt::Debug for Watcher {
    /// Shows only the watch id; the channel and response stream are left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Watcher");
        builder.field("watch_id", &self.watch_id);
        builder.finish()
    }
}