orientdb-client 0.6.0

A Rust client for OrientDB™
Documentation
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};

use futures::sink::SinkExt;
use futures::stream::StreamExt;
use std::collections::VecDeque;
use std::net::Shutdown;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use crate::common::protocol::messages::request::HandShake;
use crate::common::protocol::messages::{
    response::LiveQueryResult, response::Status, Request, Response,
};
use crate::sync::protocol::WiredProtocol;
use crate::{OrientError, OrientResult};

use super::super::live::LiveQueryManager;
use super::decoder::decode;
use super::reader;
use crate::asynchronous::network::stream::{split, ShutdownStream};
use crate::types::LiveResult;

#[cfg(feature = "async-std-runtime")]
mod async_std_use {
    pub use async_std::io::BufReader;
    pub use async_std::net::TcpStream;
    pub use async_std::sync::Mutex;
    pub use async_std::task;
}
#[cfg(feature = "async-std-runtime")]
use async_std_use::*;

#[cfg(feature = "tokio-runtime")]
mod tokio_use {
    pub use tokio::io::BufReader;
    pub use tokio::net::TcpStream;
    pub use tokio::sync::Mutex;
    pub use tokio::task;
}

use futures::channel::mpsc::{channel, Receiver, Sender};

#[cfg(feature = "tokio-runtime")]
pub use tokio_use::*;

pub type ChannelMsg = Cmd;
pub type ResponseChannel = Sender<OrientResult<Response>>;

#[derive(Debug)]
pub enum Cmd {
    Msg((Sender<OrientResult<Response>>, Request)),
    MsgNoResponse((Sender<OrientResult<()>>, Request)),
    Shutdown,
}

pub struct Connection {
    sender: Sender<ChannelMsg>,
    live_query_manager: Arc<LiveQueryManager>,
}

impl std::fmt::Debug for Connection {
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        fmt.debug_struct("Connection").finish()
    }
}

async fn encode_and_write<T>(
    stream: &mut T,
    protocol: &mut WiredProtocol,
    request: Request,
) -> OrientResult<()>
where
    T: AsyncWrite + Unpin,
{
    let buf = protocol.encode(request)?;
    stream.write_all(buf.as_slice()).await?;

    Ok(())
}
fn sender_loop<T>(
    mut stream: T,
    mut channel: Receiver<ChannelMsg>,
    queue: Arc<Mutex<VecDeque<ResponseChannel>>>,
    mut protocol: WiredProtocol,
    shutdown_flag: Arc<AtomicBool>,
) where
    T: AsyncWrite + Send + ShutdownStream + Unpin + 'static,
{
    task::spawn(async move {
        loop {
            match channel.next().await {
                Some(msg) => match msg {
                    Cmd::Msg(mut m) => {
                        let mut guard = queue.lock().await;

                        match encode_and_write(&mut stream, &mut protocol, m.1).await {
                            Ok(()) => {
                                guard.push_back(m.0);
                                drop(guard);
                            }
                            Err(e) => {
                                drop(guard);
                                match m.0.send(Err(e)).await {
                                    Ok(_) => {}
                                    Err(_e) => {}
                                };
                            }
                        }
                    }
                    Cmd::MsgNoResponse(mut m) => {
                        match encode_and_write(&mut stream, &mut protocol, m.1).await {
                            Ok(()) => match m.0.send(Ok(())).await {
                                Ok(_e) => {}
                                Err(_e) => {}
                            },
                            Err(e) => match m.0.send(Err(e)).await {
                                Ok(_e) => {}
                                Err(_e) => {}
                            },
                        }
                    }
                    Cmd::Shutdown => {
                        shutdown_flag.store(true, Ordering::SeqCst);
                        stream
                            .shutdown(Shutdown::Both)
                            .expect("Failed to shutdown the socket");
                    }
                },
                None => {
                    shutdown_flag.store(true, Ordering::SeqCst);

                    match stream.shutdown(Shutdown::Both) {
                        Ok(_e) => {}
                        Err(_e) => {}
                    }
                    break;
                }
            }
        }
    });
}

fn responder_loop<T>(
    mut stream: T,
    queue: Arc<Mutex<VecDeque<ResponseChannel>>>,
    protocol: WiredProtocol,
    live_manager: Arc<LiveQueryManager>,
    shutdown_flag: Arc<AtomicBool>,
) where
    T: AsyncRead + Send + Unpin + 'static,
{
    task::spawn(async move {
        loop {
            let response = decode(protocol.version, &mut stream).await;

            if let Err(e) = &response {
                match e {
                    OrientError::Protocol(_p) => {
                        if shutdown_flag.load(Ordering::SeqCst) {
                            break;
                        }
                    }
                    _ => {}
                }
            }

            let result = match response {
                Ok(mut r) => match r.header.status {
                    Status::PUSH => {
                        let live_result: LiveQueryResult = r.payload::<LiveQueryResult>();
                        match live_manager.fire_event(live_result).await {
                            Ok(_) => {}
                            Err(_e) => {}
                        }
                        None
                    }
                    _ => Some(Ok(r)),
                },
                Err(e) => Some(Err(e)),
            };

            if let Some(response) = result {
                let mut guard = queue.lock().await;

                match guard.pop_front() {
                    Some(mut s) => {
                        drop(guard);
                        match s.send(response).await {
                            Ok(_m) => {}
                            Err(_e) => {}
                        }
                    }
                    None => {}
                }
            }
        }
    });
}

impl Connection {
    pub async fn connect(addr: &SocketAddr) -> OrientResult<Self> {
        let stream = TcpStream::connect(addr).await?;

        let (mut reader, writer) = split(stream);

        let p = reader::read_i16(&mut reader).await?;

        let protocol = WiredProtocol::from_version(p)?;

        let (sender, receiver) = channel(20);

        let live_query_manager = Arc::new(LiveQueryManager::default());

        let conn = Connection {
            sender,
            live_query_manager: live_query_manager.clone(),
        };

        let shutdown_flag = Arc::new(AtomicBool::new(false));

        let queue = Arc::new(Mutex::new(VecDeque::new()));

        let p_version = protocol.version;

        sender_loop(
            writer,
            receiver,
            queue.clone(),
            protocol.clone(),
            shutdown_flag.clone(),
        );

        responder_loop(reader, queue, protocol, live_query_manager, shutdown_flag);

        conn.handshake(p_version).await
    }

    async fn handshake(mut self, p_version: i16) -> OrientResult<Connection> {
        let handshake = HandShake {
            p_version,
            name: String::from("Rust Driver"),
            version: String::from("0.1"),
        };
        self.send_and_forget(handshake.into()).await?;
        Ok(self)
    }

    pub(crate) async fn send_and_forget(&mut self, request: Request) -> OrientResult<()> {
        let (sender, mut receiver) = channel(1);

        self.sender
            .send(Cmd::MsgNoResponse((sender, request)))
            .await?;

        let result = receiver
            .next()
            .await
            .expect("It should contain the response");

        return result;
    }

    pub async fn register_handler(
        &self,
        monitor_id: i32,
        sender: Sender<OrientResult<LiveResult>>,
    ) -> OrientResult<()> {
        self.live_query_manager
            .register_handler(monitor_id, sender)
            .await
    }

    pub async fn send(&mut self, request: Request) -> OrientResult<Response> {
        let (sender, mut receiver) = channel(1);
        self.sender.send(Cmd::Msg((sender, request))).await?;
        receiver
            .next()
            .await
            .expect("It should contain the response")
    }

    pub async fn close(mut self) -> OrientResult<()> {
        &self.sender.send(Cmd::Shutdown).await;
        Ok(())
    }
}