use crate::error::{ProblemJson, ViiperError};
use crate::types::*;
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, Shutdown};
/// Transport used by a single connection to the server.
///
/// Both variants are driven through the `Read`/`Write` impls below, which simply
/// delegate to the wrapped stream.
enum StreamWrapper {
/// Raw TCP socket, used when no password is configured.
Plain(TcpStream),
/// Stream returned by `crate::auth::perform_handshake` when a password is configured
/// (presumably encrypts traffic, per its name — see `crate::auth`).
Encrypted(crate::auth::EncryptedStream),
}
impl StreamWrapper {
fn try_clone(&self) -> std::io::Result<Self> {
match self {
StreamWrapper::Plain(s) => Ok(StreamWrapper::Plain(s.try_clone()?)),
StreamWrapper::Encrypted(s) => Ok(StreamWrapper::Encrypted(s.try_clone()?)),
}
}
fn shutdown(&self, how: Shutdown) -> std::io::Result<()> {
match self {
StreamWrapper::Plain(s) => s.shutdown(how),
StreamWrapper::Encrypted(s) => s.shutdown(how),
}
}
}
impl Read for StreamWrapper {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
StreamWrapper::Plain(s) => s.read(buf),
StreamWrapper::Encrypted(s) => s.read(buf),
}
}
}
impl Write for StreamWrapper {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
StreamWrapper::Plain(s) => s.write(buf),
StreamWrapper::Encrypted(s) => s.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
StreamWrapper::Plain(s) => s.flush(),
StreamWrapper::Encrypted(s) => s.flush(),
}
}
}
/// Client for the VIIPER request/response API.
///
/// Every request opens its own TCP connection to `addr`. When a password is set,
/// each connection first runs the `crate::auth` handshake.
///
/// `Debug` is implemented by hand so the password never appears in debug output.
#[derive(Clone)]
pub struct ViiperClient {
    /// Server address every request connects to.
    addr: SocketAddr,
    /// Authentication password; `None` means connections are not authenticated.
    password: Option<String>,
}

impl std::fmt::Debug for ViiperClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ViiperClient")
            .field("addr", &self.addr)
            // Only reveal whether a password is configured, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
impl ViiperClient {
pub fn new(addr: SocketAddr) -> Self {
Self { addr, password: None }
}
pub fn new_with_password(addr: SocketAddr, password: String) -> Self {
let password = if password.is_empty() { None } else { Some(password) };
Self { addr, password }
}
fn do_request<T: for<'de> serde::Deserialize<'de>>(
&self,
path: &str,
payload: Option<&str>,
) -> Result<T, ViiperError> {
let tcp_stream = TcpStream::connect(self.addr)?;
tcp_stream.set_nodelay(true)?;
let mut stream = if let Some(ref pwd) = self.password {
StreamWrapper::Encrypted(crate::auth::perform_handshake(tcp_stream, pwd)?)
} else {
StreamWrapper::Plain(tcp_stream)
};
stream.write_all(path.as_bytes())?;
if let Some(p) = payload {
stream.write_all(b" ")?;
stream.write_all(p.as_bytes())?;
}
stream.write_all(b"\0")?;
let mut buf = Vec::new();
stream.read_to_end(&mut buf)?;
let response = String::from_utf8(buf)
.map_err(|_| ViiperError::UnexpectedResponse("invalid UTF-8".into()))?
.trim_end_matches('\n')
.to_string();
if response.starts_with("{\"status\":") {
let problem: ProblemJson = serde_json::from_str(&response)?;
return Err(ViiperError::Protocol(problem));
}
serde_json::from_str(&response).map_err(Into::into)
}
pub fn ping(&self) -> Result<PingResponse, ViiperError> {
let path = "ping".to_string();
let payload: Option<String> = None;
self.do_request(&path, payload.as_deref())
}
pub fn bus_list(&self) -> Result<BusListResponse, ViiperError> {
let path = "bus/list".to_string();
let payload: Option<String> = None;
self.do_request(&path, payload.as_deref())
}
pub fn bus_create(&self, uint32: Option<u32>) -> Result<BusCreateResponse, ViiperError> {
let path = "bus/create".to_string();
let payload = uint32.map(|v| v.to_string());
self.do_request(&path, payload.as_deref())
}
pub fn bus_remove(&self, uint32: Option<u32>) -> Result<BusRemoveResponse, ViiperError> {
let path = "bus/remove".to_string();
let payload = uint32.map(|v| v.to_string());
self.do_request(&path, payload.as_deref())
}
pub fn bus_devices_list(&self, id: u32) -> Result<DevicesListResponse, ViiperError> {
let path = format!("bus/{}/list", id);
let payload: Option<String> = None;
self.do_request(&path, payload.as_deref())
}
pub fn bus_device_add(&self, id: u32, device_create_request: &DeviceCreateRequest) -> Result<Device, ViiperError> {
let path = format!("bus/{}/add", id);
let payload = Some(serde_json::to_string(&device_create_request)?);
self.do_request(&path, payload.as_deref())
}
pub fn bus_device_remove(&self, id: u32, string: Option<&str>) -> Result<DeviceRemoveResponse, ViiperError> {
let path = format!("bus/{}/remove", id);
let payload = string.map(|s| s.to_string());
self.do_request(&path, payload.as_deref())
}
pub fn connect_device(&self, bus_id: u32, dev_id: &str) -> Result<DeviceStream, ViiperError> {
DeviceStream::connect(self.addr, bus_id, dev_id, self.password.as_deref())
}
}
/// An open connection to a single device, created by `DeviceStream::connect`.
///
/// Input is written with `send`/`send_raw`. Device output is read either
/// directly (`read_raw`/`read_exact`) or by a background thread started with
/// `on_output`.
pub struct DeviceStream {
/// Connection to the server; the `bus/<bus_id>/<dev_id>` request has already been written.
stream: StreamWrapper,
/// Reader thread spawned by `on_output`; joined when the stream is dropped.
output_thread: Option<std::thread::JoinHandle<()>>,
/// Callback run by the reader thread after its loop ends. `on_output` moves it into
/// the thread when it is called, so it only takes effect if registered before that.
disconnect_callback: Option<Box<dyn FnOnce() + Send + 'static>>,
}
impl DeviceStream {
    /// Connects to the server at `addr` and attaches to device `dev_id` on bus `bus_id`.
    ///
    /// When `password` is `Some`, the `crate::auth` handshake runs before the
    /// NUL-terminated `bus/<bus_id>/<dev_id>` request is sent.
    ///
    /// # Errors
    /// Returns an error if connecting, the handshake, or sending the request fails.
    pub fn connect(
        addr: SocketAddr,
        bus_id: u32,
        dev_id: &str,
        password: Option<&str>,
    ) -> Result<Self, ViiperError> {
        let tcp_stream = TcpStream::connect(addr)?;
        tcp_stream.set_nodelay(true)?;
        let mut stream = match password {
            Some(pwd) => {
                StreamWrapper::Encrypted(crate::auth::perform_handshake(tcp_stream, pwd)?)
            }
            None => StreamWrapper::Plain(tcp_stream),
        };
        let handshake = format!("bus/{}/{}\0", bus_id, dev_id);
        stream.write_all(handshake.as_bytes())?;
        // Make sure the attach request leaves any buffering layer in the wrapper.
        stream.flush()?;
        Ok(Self {
            stream,
            output_thread: None,
            disconnect_callback: None,
        })
    }

    /// Serializes `input` with `DeviceInput::to_bytes` and sends it to the device.
    pub fn send<T: crate::wire::DeviceInput>(&mut self, input: &T) -> Result<(), ViiperError> {
        let bytes = input.to_bytes();
        self.send_raw(&bytes)
    }

    /// Starts a background thread that calls `callback` with a buffered reader over
    /// the device's output.
    ///
    /// The thread stops when the stream reaches EOF, a read error occurs, or
    /// `callback` returns an error. It then runs the disconnect callback, if one
    /// was registered beforehand with [`DeviceStream::on_disconnect`].
    ///
    /// # Errors
    /// Fails if an output callback is already registered or the stream cannot be cloned.
    pub fn on_output<F>(&mut self, mut callback: F) -> Result<(), ViiperError>
    where
        F: FnMut(&mut dyn std::io::BufRead) -> std::io::Result<()> + Send + 'static,
    {
        if self.output_thread.is_some() {
            return Err(ViiperError::UnexpectedResponse("Output callback already registered".into()));
        }
        let stream = self.stream.try_clone()?;
        let disconnect = self.disconnect_callback.take();
        let handle = std::thread::spawn(move || {
            let mut reader = std::io::BufReader::new(stream);
            loop {
                // Stop at EOF or on a read error rather than handing the callback an
                // exhausted reader. A callback that returns `Ok(())` without consuming
                // anything would otherwise spin forever: `Drop`'s join would hang and
                // the disconnect callback would never run.
                match std::io::BufRead::fill_buf(&mut reader) {
                    Ok(pending) if !pending.is_empty() => {}
                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                    _ => break,
                }
                if callback(&mut reader).is_err() {
                    break;
                }
            }
            if let Some(on_disconnect) = disconnect {
                on_disconnect();
            }
        });
        self.output_thread = Some(handle);
        Ok(())
    }

    /// Registers a callback to run once the `on_output` reader thread finishes.
    ///
    /// # Errors
    /// Fails if `on_output` has already been called. Its thread captured the
    /// disconnect callback when it started, so a later one could never run.
    pub fn on_disconnect<F>(&mut self, callback: F) -> Result<(), ViiperError>
    where
        F: FnOnce() + Send + 'static,
    {
        if self.output_thread.is_some() {
            return Err(ViiperError::UnexpectedResponse(
                "Disconnect callback must be registered before the output callback".into(),
            ));
        }
        self.disconnect_callback = Some(Box::new(callback));
        Ok(())
    }

    /// Writes `data` to the device unchanged, then flushes the stream.
    pub fn send_raw(&mut self, data: &[u8]) -> Result<(), ViiperError> {
        self.stream.write_all(data)?;
        // `write_all` alone does not guarantee a buffering wrapper has sent the bytes.
        self.stream.flush()?;
        Ok(())
    }

    /// Reads whatever output is available into `buf` and returns the byte count.
    ///
    /// NOTE(review): the `on_output` thread reads from a clone of this stream, so
    /// calling this while that thread runs likely makes both compete for the same bytes.
    pub fn read_raw(&mut self, buf: &mut [u8]) -> Result<usize, ViiperError> {
        self.stream.read(buf).map_err(Into::into)
    }

    /// Reads exactly `buf.len()` bytes of device output, failing on early EOF.
    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), ViiperError> {
        self.stream.read_exact(buf).map_err(Into::into)
    }
}
impl Drop for DeviceStream {
fn drop(&mut self) {
let _ = self.stream.shutdown(std::net::Shutdown::Both);
if let Some(handle) = self.output_thread.take() {
let _ = handle.join();
}
}
}