zookeeper-client 0.11.1

ZooKeeper async client
Documentation
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::{Buf, BufMut};
use futures::channel::oneshot;
use ignore_result::Ignore;

use super::types::WatchMode;
use super::watch::{WatchReceiver, WatcherId};
use crate::error::Error;
use crate::proto::{self, AddWatchMode, ConnectRequest, OpCode, RequestHeader};
use crate::record::{self, Record, StaticRecord};

#[derive(Clone, Debug, Default)]
pub struct MarshalledRequest(pub Vec<u8>);

#[derive(Clone, Copy, Debug)]
pub enum OpStat<'a> {
    None,
    #[allow(dead_code)]
    Path(&'a str),
    Watch {
        path: &'a str,
        mode: WatchMode,
    },
}

impl MarshalledRequest {
    pub fn new(code: OpCode, body: &dyn Record) -> MarshalledRequest {
        let header = RequestHeader::with_code(code);
        let buf = proto::build_session_request(&header, body);
        MarshalledRequest(buf)
    }

    pub fn new_record(body: &dyn Record) -> MarshalledRequest {
        MarshalledRequest(proto::build_record_request(body))
    }

    pub fn as_slice(&self) -> &[u8] {
        self.0.as_slice()
    }

    pub fn get_code(&self) -> OpCode {
        let mut buf = &self.0[8..12];
        buf.get_i32().try_into().unwrap()
    }

    pub fn get_xid(&self) -> i32 {
        let mut xid_buf = &self.0[4..8];
        xid_buf.get_i32()
    }

    pub fn set_xid(&mut self, xid: i32) {
        let mut xid_buf = &mut self.0[4..8];
        xid_buf.put_i32(xid);
    }

    fn get_body(&self) -> &[u8] {
        let offset = 4 + RequestHeader::record_len();
        &self.0[offset..]
    }

    pub fn get_operation_info(&self) -> (OpCode, OpStat<'_>) {
        let op_code = self.get_code();
        let stat = match op_code {
            OpCode::Create
            | OpCode::Create2
            | OpCode::CreateTtl
            | OpCode::CreateContainer
            | OpCode::SetData
            | OpCode::Delete
            | OpCode::DeleteContainer
            | OpCode::GetACL
            | OpCode::SetACL
            | OpCode::Sync
            | OpCode::Check
            | OpCode::CheckWatches
            | OpCode::GetEphemerals
            | OpCode::GetAllChildrenNumber => {
                let mut body = self.get_body();
                let server_path = record::deserialize::<&str>(&mut body).unwrap();
                OpStat::Path(server_path)
            },
            OpCode::GetData
            | OpCode::Exists
            | OpCode::GetChildren
            | OpCode::GetChildren2
            | OpCode::AddWatch
            | OpCode::RemoveWatches => {
                let mut body = self.get_body();
                let server_path = record::deserialize::<&str>(&mut body).unwrap();
                if op_code == OpCode::AddWatch || op_code == OpCode::RemoveWatches {
                    body.advance(3);
                }
                let watch = body.get_u8();
                if op_code == OpCode::AddWatch {
                    let add_mode = AddWatchMode::try_from(watch as i32).unwrap();
                    OpStat::Watch { path: server_path, mode: WatchMode::from(add_mode) }
                } else if op_code == OpCode::RemoveWatches {
                    OpStat::Watch { path: server_path, mode: WatchMode::try_from(watch as i32).unwrap() }
                } else if watch == 1 {
                    let mode = if op_code == OpCode::GetData || op_code == OpCode::Exists {
                        WatchMode::Data
                    } else {
                        WatchMode::Child
                    };
                    OpStat::Watch { path: server_path, mode }
                } else {
                    assert!(watch == 0);
                    OpStat::Path(server_path)
                }
            },
            _ => OpStat::None,
        };
        (op_code, stat)
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

pub enum Operation {
    Connect(ConnectOperation),
    Session(SessionOperation),
}

pub enum Request {
    Session(SessionOperation),
    RemoveWatcher { id: WatcherId, responser: StateResponser },
}

impl From<SessionOperation> for Request {
    fn from(operation: SessionOperation) -> Self {
        Self::Session(operation)
    }
}

impl Request {
    pub fn into_responser(self) -> StateResponser {
        match self {
            Self::Session(operation) => operation.responser,
            Self::RemoveWatcher { responser, .. } => responser,
        }
    }
}

impl Operation {
    pub fn get_data(&self) -> &[u8] {
        match self {
            Operation::Connect(operation) => operation.request.as_slice(),
            Operation::Session(operation) => operation.request.as_slice(),
        }
    }
}

pub struct ConnectOperation {
    pub request: Vec<u8>,
}

impl ConnectOperation {
    pub fn new(request: &ConnectRequest) -> Self {
        let buf = proto::build_record_request(request);
        Self { request: buf }
    }
}

#[derive(Debug)]
pub struct SessionOperation {
    pub request: MarshalledRequest,
    pub responser: StateResponser,
}

impl SessionOperation {
    pub fn new(code: OpCode, body: &dyn Record) -> Self {
        let request = MarshalledRequest::new(code, body);
        Self { request, responser: Default::default() }
    }

    pub fn new_without_body(code: OpCode) -> Self {
        let header = RequestHeader::with_code(code);
        let request = MarshalledRequest::new_record(&header);
        Self { request, responser: StateResponser::default() }
    }

    pub fn new_marshalled(request: MarshalledRequest) -> Self {
        Self { request, responser: Default::default() }
    }

    pub fn with_responser(self) -> (Self, StateReceiver) {
        let (sender, receiver) = oneshot::channel();
        let request = self.request;
        let code = request.get_code();
        let operation = Self { request, responser: StateResponser::new(sender) };
        (operation, StateReceiver { code, receiver })
    }
}

impl From<MarshalledRequest> for SessionOperation {
    fn from(request: MarshalledRequest) -> Self {
        SessionOperation { request, responser: StateResponser::none() }
    }
}

pub struct StateReceiver {
    code: OpCode,
    receiver: oneshot::Receiver<Result<(Vec<u8>, WatchReceiver), Error>>,
}

impl StateReceiver {
    pub fn new(code: OpCode, receiver: oneshot::Receiver<Result<(Vec<u8>, WatchReceiver), Error>>) -> Self {
        Self { code, receiver }
    }
}

impl Future for StateReceiver {
    type Output = Result<(Vec<u8>, WatchReceiver), Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let code = self.code;
        let receiver = unsafe { self.map_unchecked_mut(|r| &mut r.receiver) };
        match receiver.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(result) => match result {
                Err(_) => {
                    Poll::Ready(Err(Error::UnexpectedError(format!("BUG: {code} expect response, but got none"))))
                },
                Ok(r) => Poll::Ready(r),
            },
        }
    }
}

type StateSender = oneshot::Sender<Result<(Vec<u8>, WatchReceiver), Error>>;

#[derive(Default, Debug)]
pub struct StateResponser(Option<StateSender>);

impl StateResponser {
    pub fn new(sender: oneshot::Sender<Result<(Vec<u8>, WatchReceiver), Error>>) -> Self {
        StateResponser(Some(sender))
    }

    pub fn none() -> Self {
        StateResponser(None)
    }

    pub fn send(mut self, result: Result<(Vec<u8>, WatchReceiver), Error>) -> bool {
        if let Some(sender) = self.0.take() {
            sender.send(result).ignore();
            return true;
        }
        false
    }

    pub fn send_empty(self) -> bool {
        self.send(Ok((Vec::new(), WatchReceiver::None)))
    }
}