use std::sync::Arc;
#[cfg(not(feature = "tokio"))]
use std::{net::TcpStream, sync::Mutex, thread, time::Duration};
use protobuf::CodedInputStream;
#[cfg(feature = "tokio")]
use tokio::{net::TcpStream, sync::Mutex};
use crate::{
error::RpcError,
schema::{
self, connection_request, connection_response::Status,
ConnectionRequest, ConnectionResponse, DecodeUntagged, StreamUpdate,
},
stream::StreamWrangler,
};
/// Client holding two TCP connections to the server: one used for
/// request/response calls and one from which stream updates are read.
///
/// The `Mutex` is `std::sync::Mutex` in the blocking build and
/// `tokio::sync::Mutex` when the `tokio` feature is enabled.
pub struct Client {
    /// Connection used by `call`; holding the lock across send and receive
    /// keeps each request paired with its own response.
    rpc: Mutex<TcpStream>,
    /// Connection from which `update_streams` reads `StreamUpdate` messages.
    stream: Mutex<TcpStream>,
    /// Stream results received on `stream`, inserted and looked up by id.
    streams: StreamWrangler,
}
impl Client {
    /// Connects to the server and starts a background thread that feeds
    /// stream updates into the client (blocking build).
    ///
    /// An RPC connection is opened on `ip_addr:rpc_port` with a `RPC`
    /// `ConnectionRequest` carrying `name`. The `client_identifier` from that
    /// handshake's response is then sent in a `STREAM` request on
    /// `ip_addr:stream_port` to open the stream connection.
    ///
    /// # Errors
    /// Propagates any handshake failure: `RpcError::Connection` when a TCP
    /// connection cannot be opened, `RpcError::Client` when the response
    /// status is not `OK`, or an error from sending/receiving the handshake.
    #[cfg(not(feature = "tokio"))]
    pub fn new(
        name: &str,
        ip_addr: &str,
        rpc_port: u16,
        stream_port: u16,
    ) -> Result<Arc<Self>, RpcError> {
        let rpc_request = schema::ConnectionRequest {
            type_: protobuf::EnumOrUnknown::new(connection_request::Type::RPC),
            client_name: String::from(name),
            ..Default::default()
        };
        let (rpc_stream, rpc_result) = connect(ip_addr, rpc_port, rpc_request)?;
        let stream_request = schema::ConnectionRequest {
            type_: protobuf::EnumOrUnknown::new(
                connection_request::Type::STREAM,
            ),
            client_name: String::from(name),
            // Identifier issued by the RPC handshake; presumably lets the
            // server associate both connections with one client.
            client_identifier: rpc_result.client_identifier,
            ..Default::default()
        };
        let (stream_stream, _) = connect(ip_addr, stream_port, stream_request)?;
        let client = Arc::new(Self {
            rpc: Mutex::new(rpc_stream),
            stream: Mutex::new(stream_stream),
            streams: StreamWrangler::default(),
        });
        let bg_client = Arc::clone(&client);
        thread::spawn(move || {
            // Stop at the first error instead of retrying forever: once the
            // stream socket has failed (e.g. the server closed it) every
            // further read fails immediately, and retrying would pin a core.
            // Exiting also releases this thread's reference to the client.
            while bg_client.update_streams().is_ok() {}
        });
        Ok(client)
    }

    /// Connects to the server and spawns a tokio task that feeds stream
    /// updates into the client (async build).
    ///
    /// Performs the same two handshakes as the blocking constructor: an
    /// `RPC` request on `ip_addr:rpc_port`, then a `STREAM` request carrying
    /// the returned `client_identifier` on `ip_addr:stream_port`.
    ///
    /// # Errors
    /// Propagates any handshake failure: `RpcError::Connection` when a TCP
    /// connection cannot be opened, `RpcError::Client` when the response
    /// status is not `OK`, or an error from sending/receiving the handshake.
    #[cfg(feature = "tokio")]
    pub async fn new(
        name: &str,
        ip_addr: &str,
        rpc_port: u16,
        stream_port: u16,
    ) -> Result<Arc<Self>, RpcError> {
        let rpc_request = schema::ConnectionRequest {
            type_: protobuf::EnumOrUnknown::new(connection_request::Type::RPC),
            client_name: String::from(name),
            ..Default::default()
        };
        let (rpc_stream, rpc_result) =
            connect(ip_addr, rpc_port, rpc_request).await?;
        let stream_request = schema::ConnectionRequest {
            type_: protobuf::EnumOrUnknown::new(
                connection_request::Type::STREAM,
            ),
            client_name: String::from(name),
            // Identifier issued by the RPC handshake; presumably lets the
            // server associate both connections with one client.
            client_identifier: rpc_result.client_identifier,
            ..Default::default()
        };
        let (stream_stream, _) =
            connect(ip_addr, stream_port, stream_request).await?;
        let client = Arc::new(Self {
            rpc: Mutex::new(rpc_stream),
            stream: Mutex::new(stream_stream),
            streams: StreamWrangler::default(),
        });
        let bg_client = Arc::clone(&client);
        tokio::task::spawn(async move {
            // Stop at the first error instead of retrying forever: once the
            // stream socket has failed every further read fails immediately,
            // so retrying would only busy-loop on a runtime worker. Exiting
            // also releases this task's reference to the client.
            while bg_client.update_streams().await.is_ok() {}
        });
        Ok(client)
    }

    /// Sends `request` on the RPC connection and returns the next response.
    ///
    /// The RPC lock is held across both the send and the receive so
    /// concurrent callers cannot interleave their request/response pairs.
    ///
    /// # Errors
    /// `RpcError::Client` if the lock is poisoned, otherwise any send/receive
    /// error.
    #[cfg(not(feature = "tokio"))]
    pub(crate) fn call(
        &self,
        request: schema::Request,
    ) -> Result<schema::Response, RpcError> {
        let mut rpc = self.rpc.lock().map_err(|_| RpcError::Client)?;
        send(&mut rpc, request)?;
        recv(&mut rpc)
    }

    /// Sends `request` on the RPC connection and returns the next response.
    ///
    /// The RPC lock is held across both the send and the receive so
    /// concurrent callers cannot interleave their request/response pairs.
    ///
    /// # Errors
    /// Any send/receive error.
    #[cfg(feature = "tokio")]
    pub(crate) async fn call(
        &self,
        request: schema::Request,
    ) -> Result<schema::Response, RpcError> {
        let mut rpc = self.rpc.lock().await;
        send(&mut rpc, request).await?;
        recv(&mut rpc).await
    }

    /// Builds a `ProcedureCall` for `service.procedure` with `args`, leaving
    /// all other fields at their defaults.
    pub(crate) fn proc_call(
        service: &str,
        procedure: &str,
        args: Vec<schema::Argument>,
    ) -> schema::ProcedureCall {
        schema::ProcedureCall {
            service: service.into(),
            procedure: procedure.into(),
            arguments: args,
            ..Default::default()
        }
    }

    /// Reads one `StreamUpdate` from the stream connection and stores each
    /// of its results in the stream table under the result's id.
    ///
    /// # Errors
    /// `RpcError::Client` if the stream lock is poisoned or a result carries
    /// no payload; otherwise any error from receiving the update or from
    /// `StreamWrangler::insert`.
    #[cfg(not(feature = "tokio"))]
    pub(crate) fn update_streams(self: &Arc<Self>) -> Result<(), RpcError> {
        // Map lock poisoning the same way `call` does.
        let mut stream = self.stream.lock().map_err(|_| RpcError::Client)?;
        let update = recv::<StreamUpdate>(&mut stream)?;
        for result in update.results {
            self.streams.insert(
                result.id,
                result.result.into_option().ok_or(RpcError::Client)?,
            )?;
        }
        Ok(())
    }

    /// Increments the reference count of stream `stream_id` in the stream
    /// table and returns what `increment_refcount` reports (presumably the
    /// new count — confirm in `StreamWrangler`).
    #[cfg(feature = "tokio")]
    pub(crate) fn register_stream(self: &Arc<Self>, stream_id: u64) -> u32 {
        self.streams.increment_refcount(stream_id)
    }

    /// Decrements the reference count of stream `stream_id` in the stream
    /// table and returns what `decrement_refcount` reports (presumably the
    /// remaining count — confirm in `StreamWrangler`).
    #[cfg(feature = "tokio")]
    pub(crate) fn release_stream(self: &Arc<Self>, stream_id: u64) -> u32 {
        self.streams.decrement_refcount(stream_id)
    }

    /// Reads one `StreamUpdate` from the stream connection and stores each
    /// of its results in the stream table under the result's id.
    ///
    /// # Errors
    /// `RpcError::Client` if a result carries no payload; otherwise any error
    /// from receiving the update or from `StreamWrangler::insert`.
    #[cfg(feature = "tokio")]
    pub(crate) async fn update_streams(
        self: &Arc<Self>,
    ) -> Result<(), RpcError> {
        let mut stream = self.stream.lock().await;
        let update = recv::<StreamUpdate>(&mut stream).await?;
        for result in update.results {
            self.streams
                .insert(
                    result.id,
                    result.result.into_option().ok_or(RpcError::Client)?,
                )
                .await?;
        }
        Ok(())
    }

    /// Decodes the value stored for stream `id` as `T` by delegating to
    /// `StreamWrangler::get` (which receives a clone of this client handle).
    #[cfg(not(feature = "tokio"))]
    pub(crate) fn read_stream<T: DecodeUntagged>(
        self: &Arc<Self>,
        id: u64,
    ) -> Result<T, RpcError> {
        self.streams.get(self.clone(), id)
    }

    /// Decodes the value stored for stream `id` as `T` by delegating to
    /// `StreamWrangler::get` (which receives a clone of this client handle).
    #[cfg(feature = "tokio")]
    pub(crate) async fn read_stream<T: DecodeUntagged>(
        self: &Arc<Self>,
        id: u64,
    ) -> Result<T, RpcError> {
        self.streams.get(self.clone(), id).await
    }

    /// Removes stream `id` from the local stream table. Always returns `Ok`.
    #[cfg(not(feature = "tokio"))]
    pub(crate) fn remove_stream(
        self: &Arc<Self>,
        id: u64,
    ) -> Result<(), RpcError> {
        self.streams.remove(id);
        Ok(())
    }

    /// Removes stream `id` from the local stream table. Always returns `Ok`.
    #[cfg(feature = "tokio")]
    pub(crate) async fn remove_stream(
        self: &Arc<Self>,
        id: u64,
    ) -> Result<(), RpcError> {
        self.streams.remove(id).await;
        Ok(())
    }

    /// Delegates to `StreamWrangler::wait` for stream `id` (presumably
    /// blocks until a new value for it arrives — confirm there).
    #[cfg(not(feature = "tokio"))]
    pub(crate) fn await_stream(&self, id: u64) {
        self.streams.wait(id)
    }

    /// Delegates to `StreamWrangler::wait_timeout` for stream `id`, bounded
    /// by `dur` (presumably returns early once a new value arrives).
    #[cfg(not(feature = "tokio"))]
    pub(crate) fn await_stream_timeout(&self, id: u64, dur: Duration) {
        self.streams.wait_timeout(id, dur)
    }

    /// Delegates to `StreamWrangler::wait` for stream `id` (presumably
    /// resolves once a new value for it arrives — confirm there).
    #[cfg(feature = "tokio")]
    pub(crate) async fn await_stream(&self, id: u64) {
        self.streams.wait(id).await
    }
}
/// Opens a TCP connection to `ip_addr:port` and performs the handshake:
/// `request` is sent and a `ConnectionResponse` is read back.
///
/// # Errors
/// `RpcError::Connection` if the TCP connection cannot be established,
/// `RpcError::Client` if the response status is not `OK`, and any error
/// from sending the request or receiving the response.
#[cfg(not(feature = "tokio"))]
fn connect(
    ip_addr: &str,
    port: u16,
    request: ConnectionRequest,
) -> Result<(TcpStream, ConnectionResponse), RpcError> {
    let address = format!("{ip_addr}:{port}");
    let mut socket =
        TcpStream::connect(address).map_err(RpcError::Connection)?;
    send(&mut socket, request)?;
    let reply: ConnectionResponse = recv(&mut socket)?;
    let accepted = reply.status.value() == Status::OK as i32;
    if accepted {
        Ok((socket, reply))
    } else {
        Err(RpcError::Client)
    }
}
/// Opens a TCP connection to `ip_addr:port` and performs the handshake:
/// `request` is sent and a `ConnectionResponse` is read back.
///
/// # Errors
/// `RpcError::Connection` if the TCP connection cannot be established,
/// `RpcError::Client` if the response status is not `OK`, and any error
/// from sending the request or receiving the response.
#[cfg(feature = "tokio")]
async fn connect(
    ip_addr: &str,
    port: u16,
    request: ConnectionRequest,
) -> Result<(TcpStream, ConnectionResponse), RpcError> {
    let address = format!("{ip_addr}:{port}");
    let mut socket = TcpStream::connect(address)
        .await
        .map_err(RpcError::Connection)?;
    send(&mut socket, request).await?;
    let reply: ConnectionResponse = recv(&mut socket).await?;
    let accepted = reply.status.value() == Status::OK as i32;
    if accepted {
        Ok((socket, reply))
    } else {
        Err(RpcError::Client)
    }
}
/// Serializes `message` with a varint length prefix and writes the frame to
/// `rpc`.
///
/// # Errors
/// Any serialization or write failure reported by protobuf, converted into
/// `RpcError`.
#[cfg(not(feature = "tokio"))]
fn send<T: protobuf::Message>(
    rpc: &mut TcpStream,
    message: T,
) -> Result<(), RpcError> {
    message.write_length_delimited_to_writer(rpc)?;
    Ok(())
}
/// Serializes `message` with a varint length prefix into an in-memory frame
/// and writes the whole frame to `rpc`.
///
/// # Errors
/// A protobuf serialization error or an I/O error from the write, converted
/// into `RpcError`.
#[cfg(feature = "tokio")]
async fn send<T: protobuf::Message>(
    rpc: &mut TcpStream,
    message: T,
) -> Result<(), RpcError> {
    use tokio::io::AsyncWriteExt;

    let frame = message.write_length_delimited_to_bytes()?;
    rpc.write_all(&frame).await.map_err(Into::into)
}
#[cfg(not(feature = "tokio"))]
fn recv<T: protobuf::Message + Default>(
rpc: &mut TcpStream,
) -> Result<T, RpcError> {
CodedInputStream::new(rpc)
.read_message()
.map_err(Into::into)
}
/// Reads exactly one length-delimited protobuf message of type `T` from
/// `rpc`.
///
/// A frame is a varint byte length followed by that many payload bytes.
/// Only this frame's bytes are consumed, so a message that follows on the
/// same connection is left intact for the next call. Reading into a
/// growable buffer would instead over-read it and feed it to the decoder. A
/// closed connection yields an error rather than polling forever on
/// zero-byte reads.
///
/// # Errors
/// * An I/O error (including `UnexpectedEof` if the peer closes mid-frame),
///   converted into `RpcError`.
/// * A protobuf-derived `RpcError` if the length prefix is not a valid varint
///   or the payload does not decode as `T`.
#[cfg(feature = "tokio")]
async fn recv<T: protobuf::Message + Default>(
    rpc: &mut TcpStream,
) -> Result<T, RpcError> {
    use tokio::io::AsyncReadExt;

    // A varint-encoded u64 is at most 10 bytes long; read it one byte at a
    // time so nothing past the prefix is pulled off the socket.
    let mut header = [0u8; 10];
    let mut used = 0;
    loop {
        let byte = rpc.read_u8().await.map_err(Into::<RpcError>::into)?;
        header[used] = byte;
        used += 1;
        if byte & 0x80 == 0 || used == header.len() {
            break;
        }
    }
    let mut decoder = CodedInputStream::from_bytes(&header[..used]);
    let length = decoder
        .read_raw_varint64()
        .map_err(Into::<RpcError>::into)?;

    // `take` + `read_to_end` grows the buffer only as bytes actually arrive,
    // so a bogus length cannot force a huge up-front allocation, and it never
    // reads beyond this frame.
    let mut payload = Vec::new();
    (&mut *rpc)
        .take(length)
        .read_to_end(&mut payload)
        .await
        .map_err(Into::<RpcError>::into)?;
    if payload.len() as u64 != length {
        return Err(
            std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into()
        );
    }
    T::parse_from_bytes(&payload).map_err(Into::into)
}