use crate::error::{ProblemJson, ViiperError};
use crate::types::*;
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
/// Bidirectional transport used for one-shot API requests: either a raw TCP stream or a
/// stream produced by the password handshake in `crate::auth`.
#[cfg(feature = "async")]
pub enum AsyncStreamWrapper {
    /// Plaintext TCP connection (no password configured).
    Plain(tokio::net::TcpStream),
    /// Stream returned by `crate::auth::perform_handshake_async`.
    Encrypted(crate::auth::AsyncEncryptedStream),
}
/// Read half of a split device connection, plaintext or encrypted.
#[cfg(feature = "async")]
pub enum AsyncReadWrapper {
    /// Read half of a plaintext TCP stream obtained via `into_split`.
    Plain(tokio::net::tcp::OwnedReadHalf),
    /// Read half of an encrypted stream obtained via `into_split`.
    Encrypted(crate::auth::AsyncEncryptedRead),
}
/// Write half of a split device connection, plaintext or encrypted.
#[cfg(feature = "async")]
pub enum AsyncWriteWrapper {
    /// Write half of a plaintext TCP stream obtained via `into_split`.
    Plain(tokio::net::tcp::OwnedWriteHalf),
    /// Write half of an encrypted stream obtained via `into_split`.
    Encrypted(crate::auth::AsyncEncryptedWrite),
}
/// Delegates reads to whichever transport the wrapper currently holds.
#[cfg(feature = "async")]
impl AsyncRead for AsyncStreamWrapper {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;
        // Both inner transports are `Unpin` (required by `Pin::new`), so the wrapper is too and
        // can be unpinned directly.
        match self.get_mut() {
            Self::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Self::Encrypted(enc) => Pin::new(enc).poll_read(cx, buf),
        }
    }
}
/// Delegates reads to the plaintext or encrypted read half.
#[cfg(feature = "async")]
impl AsyncRead for AsyncReadWrapper {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;
        // Inner halves are `Unpin`, so unpinning the wrapper is sound.
        match self.get_mut() {
            Self::Plain(half) => Pin::new(half).poll_read(cx, buf),
            Self::Encrypted(half) => Pin::new(half).poll_read(cx, buf),
        }
    }
}
/// Delegates writes, flushes and shutdowns to the wrapped transport.
#[cfg(feature = "async")]
impl AsyncWrite for AsyncStreamWrapper {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use std::pin::Pin;
        match self.get_mut() {
            Self::Plain(tcp) => Pin::new(tcp).poll_write(cx, buf),
            Self::Encrypted(enc) => Pin::new(enc).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;
        match self.get_mut() {
            Self::Plain(tcp) => Pin::new(tcp).poll_flush(cx),
            Self::Encrypted(enc) => Pin::new(enc).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;
        match self.get_mut() {
            Self::Plain(tcp) => Pin::new(tcp).poll_shutdown(cx),
            Self::Encrypted(enc) => Pin::new(enc).poll_shutdown(cx),
        }
    }
}
/// Delegates writes, flushes and shutdowns to the plaintext or encrypted write half.
#[cfg(feature = "async")]
impl AsyncWrite for AsyncWriteWrapper {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use std::pin::Pin;
        match self.get_mut() {
            Self::Plain(half) => Pin::new(half).poll_write(cx, buf),
            Self::Encrypted(half) => Pin::new(half).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;
        match self.get_mut() {
            Self::Plain(half) => Pin::new(half).poll_flush(cx),
            Self::Encrypted(half) => Pin::new(half).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;
        match self.get_mut() {
            Self::Plain(half) => Pin::new(half).poll_shutdown(cx),
            Self::Encrypted(half) => Pin::new(half).poll_shutdown(cx),
        }
    }
}
/// Asynchronous request/response client for a VIIPER server.
///
/// Holds only the server address and an optional password; each API call made through the
/// client opens its own TCP connection.
#[cfg(feature = "async")]
pub struct AsyncViiperClient {
    /// Address of the VIIPER server.
    addr: std::net::SocketAddr,
    /// Password for the encrypted handshake; `None` means plaintext connections.
    password: Option<String>,
}
#[cfg(feature = "async")]
impl AsyncViiperClient {
    /// Creates a client for the server at `addr` that uses plaintext connections.
    pub fn new(addr: SocketAddr) -> Self {
        Self { addr, password: None }
    }

    /// Creates a client that authenticates every connection with `password`.
    ///
    /// An empty `password` is treated as "no password", i.e. the same as [`Self::new`].
    pub fn new_with_password(addr: SocketAddr, password: String) -> Self {
        let password = (!password.is_empty()).then_some(password);
        Self { addr, password }
    }

    /// Sends one request on a fresh connection and decodes the JSON reply.
    ///
    /// Wire format: `path`, optionally followed by a space and `payload`, terminated by `\0`.
    /// The reply is read until the server closes the connection; trailing newlines are removed.
    ///
    /// # Errors
    /// - I/O and handshake failures are propagated.
    /// - A non-UTF-8 reply yields `ViiperError::UnexpectedResponse`.
    /// - A reply starting with `{"status":` is parsed as a `ProblemJson` and returned as
    ///   `ViiperError::Protocol`.
    /// - Otherwise JSON decoding errors for `T` are propagated.
    async fn do_request<T: for<'de> serde::Deserialize<'de>>(
        &self,
        path: &str,
        payload: Option<&str>,
    ) -> Result<T, ViiperError> {
        let tcp_stream = TcpStream::connect(self.addr).await?;
        tcp_stream.set_nodelay(true)?;
        let mut stream = match self.password.as_deref() {
            Some(pwd) => AsyncStreamWrapper::Encrypted(
                crate::auth::perform_handshake_async(tcp_stream, pwd).await?,
            ),
            None => AsyncStreamWrapper::Plain(tcp_stream),
        };

        // Assemble the whole request first: with TCP_NODELAY every separate `write_all` may go
        // out as its own segment (and possibly its own encrypted frame).
        let mut request =
            Vec::with_capacity(path.len() + payload.map_or(0, |p| p.len() + 1) + 1);
        request.extend_from_slice(path.as_bytes());
        if let Some(p) = payload {
            request.push(b' ');
            request.extend_from_slice(p.as_bytes());
        }
        request.push(0);
        stream.write_all(&request).await?;
        // We block on the reply next, so anything still sitting in a buffering writer would
        // never reach the server.
        stream.flush().await?;

        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await?;
        let text = String::from_utf8(buf)
            .map_err(|_| ViiperError::UnexpectedResponse("invalid UTF-8".into()))?;
        let response = text.trim_end_matches('\n');
        if response.starts_with("{\"status\":") {
            let problem: ProblemJson = serde_json::from_str(response)?;
            return Err(ViiperError::Protocol(problem));
        }
        serde_json::from_str(response).map_err(Into::into)
    }

    /// Sends `ping` to the server.
    pub async fn ping(&self) -> Result<PingResponse, ViiperError> {
        self.do_request("ping", None).await
    }

    /// Lists the buses known to the server (`bus/list`).
    pub async fn bus_list(&self) -> Result<BusListResponse, ViiperError> {
        self.do_request("bus/list", None).await
    }

    /// Sends `bus/create`; `uint32`, when present, is sent as the decimal argument.
    pub async fn bus_create(&self, uint32: Option<u32>) -> Result<BusCreateResponse, ViiperError> {
        let payload = uint32.map(|v| v.to_string());
        self.do_request("bus/create", payload.as_deref()).await
    }

    /// Sends `bus/remove`; `uint32`, when present, is sent as the decimal argument.
    pub async fn bus_remove(&self, uint32: Option<u32>) -> Result<BusRemoveResponse, ViiperError> {
        let payload = uint32.map(|v| v.to_string());
        self.do_request("bus/remove", payload.as_deref()).await
    }

    /// Lists the devices on bus `id` (`bus/{id}/list`).
    pub async fn bus_devices_list(&self, id: u32) -> Result<DevicesListResponse, ViiperError> {
        let path = format!("bus/{id}/list");
        self.do_request(&path, None).await
    }

    /// Adds a device to bus `id`; the request is sent as its JSON serialization.
    pub async fn bus_device_add(&self, id: u32, device_create_request: &DeviceCreateRequest) -> Result<Device, ViiperError> {
        let path = format!("bus/{id}/add");
        let payload = serde_json::to_string(device_create_request)?;
        self.do_request(&path, Some(&payload)).await
    }

    /// Sends `bus/{id}/remove`, passing `string` through verbatim as the argument when present.
    pub async fn bus_device_remove(&self, id: u32, string: Option<&str>) -> Result<DeviceRemoveResponse, ViiperError> {
        let path = format!("bus/{id}/remove");
        self.do_request(&path, string).await
    }

    /// Opens a streaming connection to device `dev_id` on bus `bus_id`, reusing this client's
    /// address and password.
    pub async fn connect_device(&self, bus_id: u32, dev_id: &str) -> Result<AsyncDeviceStream, ViiperError> {
        AsyncDeviceStream::connect(self.addr, bus_id, dev_id, self.password.as_deref()).await
    }
}
/// One-shot callback run when the output task of an [`AsyncDeviceStream`] ends.
#[cfg(feature = "async")]
type DisconnectCallback = Box<dyn FnOnce() + Send + 'static>;

/// A live streaming connection to a single device on a VIIPER bus.
///
/// The connection is split into separately locked read and write halves, so input can be sent
/// while output is being read.
#[cfg(feature = "async")]
pub struct AsyncDeviceStream {
    read_stream: std::sync::Arc<tokio::sync::Mutex<AsyncReadWrapper>>,
    write_stream: std::sync::Arc<tokio::sync::Mutex<AsyncWriteWrapper>>,
    /// Present once `on_output` has spawned its background task; cancelled on drop.
    cancel_token: Option<tokio_util::sync::CancellationToken>,
    /// Shared with the output task so the callback is looked up when the task ends, which lets
    /// `on_disconnect` be called before or after `on_output`.
    disconnect_callback: std::sync::Arc<std::sync::Mutex<Option<DisconnectCallback>>>,
}

#[cfg(feature = "async")]
impl AsyncDeviceStream {
    /// Connects to `addr` and attaches to device `dev_id` on bus `bus_id`.
    ///
    /// When `password` is `Some`, the encrypted handshake from `crate::auth` runs first. Then
    /// the device handshake `bus/{bus_id}/{dev_id}\0` is written and flushed.
    ///
    /// # Errors
    /// Propagates failures from connecting, the auth handshake, or writing the device handshake.
    pub async fn connect(addr: SocketAddr, bus_id: u32, dev_id: &str, password: Option<&str>) -> Result<Self, ViiperError> {
        let tcp_stream = TcpStream::connect(addr).await?;
        tcp_stream.set_nodelay(true)?;
        let (read_stream, mut write_stream) = if let Some(pwd) = password {
            let encrypted = crate::auth::perform_handshake_async(tcp_stream, pwd).await?;
            let (read_half, write_half) = encrypted.into_split();
            (AsyncReadWrapper::Encrypted(read_half), AsyncWriteWrapper::Encrypted(write_half))
        } else {
            let (read_half, write_half) = tcp_stream.into_split();
            (AsyncReadWrapper::Plain(read_half), AsyncWriteWrapper::Plain(write_half))
        };
        let handshake = format!("bus/{}/{}\0", bus_id, dev_id);
        write_stream.write_all(handshake.as_bytes()).await?;
        // Make sure the handshake actually leaves a buffering writer before we return.
        write_stream.flush().await?;
        Ok(Self {
            read_stream: std::sync::Arc::new(tokio::sync::Mutex::new(read_stream)),
            write_stream: std::sync::Arc::new(tokio::sync::Mutex::new(write_stream)),
            cancel_token: None,
            disconnect_callback: std::sync::Arc::new(std::sync::Mutex::new(None)),
        })
    }

    /// Writes `data` while holding the write lock, then flushes it.
    async fn write_frame(&self, data: &[u8]) -> std::io::Result<()> {
        let mut stream = self.write_stream.lock().await;
        stream.write_all(data).await?;
        stream.flush().await
    }

    /// Serializes `input` with `DeviceInput::to_bytes` and sends it.
    pub async fn send<T: crate::wire::DeviceInput>(
        &self,
        input: &T,
    ) -> Result<(), ViiperError> {
        let bytes = input.to_bytes();
        self.write_frame(&bytes).await?;
        Ok(())
    }

    /// Like [`Self::send`], but gives up with `ViiperError::Timeout` after `timeout`.
    ///
    /// The timeout also covers waiting for the write lock held by concurrent senders. If it
    /// fires mid-write, a partial frame may already have been sent.
    pub async fn send_timeout<T: crate::wire::DeviceInput>(
        &self,
        input: &T,
        timeout: std::time::Duration,
    ) -> Result<(), ViiperError> {
        let bytes = input.to_bytes();
        tokio::time::timeout(timeout, self.write_frame(&bytes))
            .await
            .map_err(|_| ViiperError::Timeout)?
            .map_err(Into::into)
    }

    /// Spawns a tokio task that repeatedly calls `callback` with the shared read half.
    ///
    /// The loop stops when `callback` returns `Err` or the stream is dropped. When the loop
    /// stops, the disconnect callback currently registered (if any) is run once. This applies
    /// whether `on_disconnect` was called before or after this method. Must be called from
    /// within a tokio runtime.
    ///
    /// # Errors
    /// Returns `ViiperError::UnexpectedResponse` if an output callback was already registered.
    pub fn on_output<F, Fut>(&mut self, callback: F) -> Result<(), ViiperError>
    where
        F: Fn(std::sync::Arc<tokio::sync::Mutex<AsyncReadWrapper>>) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = std::io::Result<()>> + Send + 'static,
    {
        if self.cancel_token.is_some() {
            return Err(ViiperError::UnexpectedResponse("Output callback already registered".into()));
        }
        let stream = std::sync::Arc::clone(&self.read_stream);
        let disconnect = std::sync::Arc::clone(&self.disconnect_callback);
        let cancel_token = tokio_util::sync::CancellationToken::new();
        let cancel_clone = cancel_token.clone();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = cancel_clone.cancelled() => break,
                    result = callback(std::sync::Arc::clone(&stream)) => {
                        if result.is_err() {
                            break;
                        }
                    }
                }
            }
            // Take the callback out and release the lock before running it, so a panicking
            // callback cannot poison the mutex.
            let pending = disconnect
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .take();
            if let Some(cb) = pending {
                cb();
            }
        });
        self.cancel_token = Some(cancel_token);
        Ok(())
    }

    /// Registers (or replaces) the callback run once when the output task started by
    /// [`Self::on_output`] ends.
    ///
    /// If the output task has already finished, the callback will not be invoked.
    ///
    /// # Errors
    /// Returns `ViiperError::UnexpectedResponse` if the callback mutex is poisoned.
    pub fn on_disconnect<F>(&mut self, callback: F) -> Result<(), ViiperError>
    where
        F: FnOnce() + Send + 'static,
    {
        let Ok(mut guard) = self.disconnect_callback.lock() else {
            return Err(ViiperError::UnexpectedResponse("Disconnect callback mutex poisoned".into()));
        };
        *guard = Some(Box::new(callback));
        Ok(())
    }

    /// Sends raw bytes without any framing.
    pub async fn send_raw(&self, data: &[u8]) -> Result<(), ViiperError> {
        self.write_frame(data).await?;
        Ok(())
    }

    /// Reads whatever is available into `buf` and returns the byte count (0 means EOF).
    pub async fn read_raw(&self, buf: &mut [u8]) -> Result<usize, ViiperError> {
        let mut stream = self.read_stream.lock().await;
        stream.read(buf).await.map_err(Into::into)
    }

    /// Reads exactly `buf.len()` bytes, failing on early EOF.
    pub async fn read_exact(&self, buf: &mut [u8]) -> Result<(), ViiperError> {
        let mut stream = self.read_stream.lock().await;
        stream.read_exact(buf).await?;
        Ok(())
    }
}
/// Cancels the background output task (if `on_output` started one). This makes its loop exit
/// and run any registered disconnect callback.
#[cfg(feature = "async")]
impl Drop for AsyncDeviceStream {
    fn drop(&mut self) {
        if let Some(token) = self.cancel_token.take() {
            token.cancel();
        }
    }
}