use crate::error::*;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::marker::PhantomData;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::sync::RwLock;
/// Marker trait for byte streams that can carry varlink traffic: they must be
/// readable and writable asynchronously, shareable across threads
/// (`Send + Sync`) and `Unpin`.
pub trait AsyncVarlinkStream
where
    Self: AsyncReadExt + AsyncWriteExt + Send + Sync + Unpin,
{
}

// TCP transport: available on every platform.
impl AsyncVarlinkStream for TcpStream {}

// Unix domain socket transport: unix platforms only.
#[cfg(unix)]
impl AsyncVarlinkStream for UnixStream {}
/// A connected varlink transport, as produced by `async_varlink_connect`.
// Variant names are upper-case acronyms; renaming them would break the public API.
#[allow(clippy::upper_case_acronyms)]
pub enum AsyncStream {
    /// A TCP connection (`tcp:` addresses).
    TCP(TcpStream),
    /// A Unix domain socket connection (`unix:` addresses); unix platforms only.
    #[cfg(unix)]
    UNIX(UnixStream),
}

impl AsyncStream {
    /// Shuts down the output side of the stream via `AsyncWriteExt::shutdown`.
    ///
    /// Both transports are handled the same way. Previously the Unix variant
    /// was silently skipped, even though `UnixStream` supports the same
    /// shutdown as `TcpStream`.
    ///
    /// # Errors
    /// Returns the I/O error reported by the underlying stream's shutdown.
    // Not called from the code in this module, hence the lint allowance.
    #[allow(dead_code)]
    async fn shutdown(&mut self) -> std::io::Result<()> {
        match self {
            AsyncStream::TCP(stream) => stream.shutdown().await,
            #[cfg(unix)]
            AsyncStream::UNIX(stream) => stream.shutdown().await,
        }
    }
}
/// Opens an asynchronous connection to a varlink service.
///
/// Recognised address forms:
/// * `tcp:<host>:<port>`: a TCP connection to everything after `tcp:`.
/// * `unix:<path>[;…]`: a Unix domain socket (unix platforms only). Text
///   from the first `;` onwards is not part of the socket path.
/// * `unix:@<name>[;…]`: an abstract-namespace socket (Linux/Android only),
///   with the same `;` handling.
///
/// On success, returns the stream together with an owned copy of the full
/// `address` string exactly as given.
///
/// # Errors
/// * `ErrorKind::InvalidAddress` for an unknown scheme, or a socket kind the
///   current platform does not support.
/// * `ErrorKind::Io` carrying the I/O error kind when connecting fails.
pub async fn async_varlink_connect<S: AsRef<str>>(address: S) -> Result<(AsyncStream, String)> {
    let full = address.as_ref();

    if let Some(target) = full.strip_prefix("tcp:") {
        let tcp = TcpStream::connect(target)
            .await
            .map_err(|e| context!(ErrorKind::Io(e.kind())))?;
        return Ok((AsyncStream::TCP(tcp), full.to_string()));
    }

    let target = match full.strip_prefix("unix:") {
        Some(target) => target,
        None => return Err(context!(ErrorKind::InvalidAddress)),
    };

    #[cfg(unix)]
    {
        // Only the text before the first `;` names the socket.
        fn socket_name(spec: &str) -> &str {
            spec.split_once(';').map_or(spec, |(name, _)| name)
        }

        match target.strip_prefix('@') {
            // Abstract sockets are addressed with a leading NUL byte.
            #[cfg(any(target_os = "linux", target_os = "android"))]
            Some(abstract_name) => {
                let path = format!("\0{}", socket_name(abstract_name));
                let unix = UnixStream::connect(path)
                    .await
                    .map_err(|e| context!(ErrorKind::Io(e.kind())))?;
                Ok((AsyncStream::UNIX(unix), full.to_string()))
            }
            // Abstract sockets are unsupported on other unix platforms.
            #[cfg(not(any(target_os = "linux", target_os = "android")))]
            Some(_) => Err(context!(ErrorKind::InvalidAddress)),
            None => {
                let unix = UnixStream::connect(socket_name(target))
                    .await
                    .map_err(|e| context!(ErrorKind::Io(e.kind())))?;
                Ok((AsyncStream::UNIX(unix), full.to_string()))
            }
        }
    }

    #[cfg(not(unix))]
    {
        let _ = target;
        Err(context!(ErrorKind::InvalidAddress))
    }
}
/// A connection to a varlink service, handed out behind an `Arc` so that
/// several method calls can share it.
pub struct AsyncConnection {
    /// The full address string the connection was opened with, including the
    /// scheme prefix and any `;` parameters.
    address: String,
    /// The transport, behind an async lock so that a method call can take
    /// exclusive access while reading or writing. `AsyncMethodCall` reports
    /// `None` here as `ConnectionClosed`.
    stream: Arc<RwLock<Option<AsyncStream>>>,
}

impl AsyncConnection {
    /// Connects to `address` via `async_varlink_connect` and returns the new
    /// connection wrapped in an `Arc`.
    ///
    /// # Errors
    /// Propagates any error returned by `async_varlink_connect`.
    pub async fn with_address<S: AsRef<str>>(address: S) -> Result<Arc<Self>> {
        let (connected, full_address) = async_varlink_connect(address).await?;
        let shared_stream = Arc::new(RwLock::new(Some(connected)));
        Ok(Arc::new(Self {
            address: full_address,
            stream: shared_stream,
        }))
    }

    /// Returns an owned copy of the address this connection was opened with.
    pub fn address(&self) -> String {
        String::from(self.address.as_str())
    }
}

impl Drop for AsyncConnection {
    // No explicit cleanup is performed here. The stream is dropped, closing
    // its socket, together with the last `Arc` that owns it.
    fn drop(&mut self) {}
}
pub struct AsyncMethodCall<MRequest, MReply, MError>
where
MRequest: Serialize,
MReply: DeserializeOwned,
MError: From<Error>,
{
connection: Arc<AsyncConnection>,
request: Option<MRequest>,
method: Option<String>,
continues: bool,
phantom_reply: PhantomData<MReply>,
phantom_error: PhantomData<MError>,
}
impl<MRequest, MReply, MError> AsyncMethodCall<MRequest, MReply, MError>
where
MRequest: Serialize,
MReply: DeserializeOwned,
MError: From<Error>,
{
pub fn new<S: Into<String>>(
connection: Arc<AsyncConnection>,
method: S,
parameters: MRequest,
) -> Self {
AsyncMethodCall {
connection,
request: Some(parameters),
method: Some(method.into()),
continues: false,
phantom_reply: PhantomData,
phantom_error: PhantomData,
}
}
async fn send(
&mut self,
oneway: bool,
more: bool,
upgrade: bool,
) -> std::result::Result<(), MError> {
use crate::Request;
let mut req = match (self.method.take(), self.request.take()) {
(Some(method), Some(request)) => Request::create(
method,
Some(serde_json::to_value(request).map_err(map_context!())?),
),
_ => {
return Err(MError::from(context!(ErrorKind::MethodCalledAlready)));
}
};
if oneway {
req.oneway = Some(true);
}
if more {
req.more = Some(true);
self.continues = true;
}
if upgrade {
req.upgrade = Some(true);
}
let data = crate::sansio::protocol::serialize_request(&req)?;
let stream_lock = self.connection.stream.clone();
let mut stream_guard = stream_lock.write().await;
let stream = stream_guard
.as_mut()
.ok_or_else(|| MError::from(context!(ErrorKind::ConnectionClosed)))?;
match stream {
AsyncStream::TCP(s) => {
s.write_all(&data)
.await
.map_err(|_| MError::from(context!(ErrorKind::ConnectionClosed)))?;
s.flush()
.await
.map_err(|_| MError::from(context!(ErrorKind::ConnectionClosed)))?;
}
#[cfg(unix)]
AsyncStream::UNIX(s) => {
s.write_all(&data)
.await
.map_err(|_| MError::from(context!(ErrorKind::ConnectionClosed)))?;
s.flush()
.await
.map_err(|_| MError::from(context!(ErrorKind::ConnectionClosed)))?;
}
}
Ok(())
}
pub async fn recv(&mut self) -> std::result::Result<MReply, MError> {
let mut buf = Vec::new();
let stream_lock = self.connection.stream.clone();
let mut stream_guard = stream_lock.write().await;
let stream = stream_guard
.as_mut()
.ok_or_else(|| MError::from(context!(ErrorKind::ConnectionClosed)))?;
let n = match stream {
AsyncStream::TCP(s) => {
let mut reader = BufReader::new(s);
reader
.read_until(0, &mut buf)
.await
.map_err(|_| MError::from(context!(ErrorKind::ConnectionClosed)))?
}
#[cfg(unix)]
AsyncStream::UNIX(s) => {
let mut reader = BufReader::new(s);
reader
.read_until(0, &mut buf)
.await
.map_err(|_| MError::from(context!(ErrorKind::ConnectionClosed)))?
}
};
if n == 0 || buf.is_empty() {
return Err(MError::from(context!(ErrorKind::ConnectionClosed)));
}
use crate::sansio::types::ParseResult;
let reply: crate::Reply = match crate::sansio::protocol::parse_message(&buf) {
ParseResult::Complete { message, .. } => {
crate::sansio::protocol::parse_reply(&message)?
}
ParseResult::Incomplete { .. } => {
return Err(MError::from(context!(ErrorKind::ConnectionClosed)));
}
ParseResult::Invalid { error } => {
return Err(MError::from(context!(ErrorKind::InvalidParameter(error))));
}
};
match reply.continues {
Some(true) => self.continues = true,
_ => self.continues = false,
}
if reply.error.is_some() {
return Err(MError::from(context!(ErrorKind::from(reply))));
}
match reply {
crate::Reply {
parameters: Some(p),
..
} => {
let mreply: MReply = serde_json::from_value(p).map_err(map_context!())?;
Ok(mreply)
}
crate::Reply {
parameters: None, ..
} => {
let mreply: MReply =
serde_json::from_value(serde_json::Value::Object(serde_json::Map::new()))
.map_err(map_context!())?;
Ok(mreply)
}
}
}
pub async fn call(&mut self) -> std::result::Result<MReply, MError> {
self.send(false, false, false).await?;
self.recv().await
}
pub async fn more(&mut self) -> std::result::Result<&mut Self, MError> {
self.send(false, true, false).await?;
Ok(self)
}
pub async fn oneway(&mut self) -> std::result::Result<(), MError> {
self.send(true, false, false).await
}
pub async fn upgrade(&mut self) -> std::result::Result<MReply, MError> {
self.send(false, false, true).await?;
self.recv().await
}
}