pub mod comm_managers;
pub mod device_manager;
mod logger;
mod ping_timer;
pub mod remote_server;
pub use remote_server::ButtplugRemoteServer;
use crate::{
core::{
errors::*,
messages::{
self,
ButtplugClientMessage,
ButtplugDeviceCommandMessageUnion,
ButtplugDeviceManagerMessageUnion,
ButtplugMessage,
ButtplugServerMessage,
StopAllDevices,
BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION,
},
},
test::TestDeviceCommunicationManagerHelper,
util::async_manager,
};
use async_channel::{bounded, Receiver, Sender};
use comm_managers::{DeviceCommunicationManager, DeviceCommunicationManagerCreator};
use device_manager::DeviceManager;
use futures::{future::BoxFuture, StreamExt};
use logger::ButtplugLogHandler;
use ping_timer::PingTimer;
use std::{
convert::{TryFrom, TryInto},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use thiserror::Error;
/// Result of handling a single client message: the server's reply on success, or a
/// [`ButtplugError`] on failure.
pub type ButtplugServerResult = Result<ButtplugServerMessage, ButtplugError>;
/// Boxed, `'static` future resolving to a [`ButtplugServerResult`]. This is the return type of the
/// private message handlers on [`ButtplugServer`].
pub type ButtplugServerResultFuture = BoxFuture<'static, ButtplugServerResult>;
/// Error type returned by [`ButtplugServer::add_comm_manager`] and
/// [`ButtplugServer::add_test_comm_manager`].
#[derive(Error, Debug)]
pub enum ButtplugServerStartupError {
/// A device manager of the given type was already added. The payload is interpolated into the
/// error message as the type description.
#[error("DeviceManager of type {0} has already been added.")]
DeviceManagerTypeAlreadyAdded(String),
}
/// Server front end: gates incoming client messages on handshake and ping state, answers
/// handshake/ping/log requests itself, and forwards device related messages to its
/// [`DeviceManager`].
pub struct ButtplugServer {
// Name given at construction; sent to clients in the `ServerInfo` handshake reply.
server_name: String,
// Returned by `client_name()`. Initialised to the empty string.
// NOTE(review): nothing in this file ever assigns it after construction — confirm whether the
// handshake was meant to record the client's name here.
client_name: String,
// Value given at construction; reported in the handshake reply. 0 means no ping timer exists.
max_ping_time: u64,
// Receives every device manager / device command message from `parse_message`.
device_manager: DeviceManager,
// Sending half of the event channel whose receiver is handed out by `new()`.
event_sender: Sender<ButtplugServerMessage>,
// `Some` only when `max_ping_time > 0`.
ping_timer: Option<PingTimer>,
// Set to true by the task spawned in `new()` once the ping out signal arrives.
pinged_out: Arc<AtomicBool>,
// True after a successful handshake; cleared by `disconnect()` and on ping out.
connected: Arc<AtomicBool>,
}
impl ButtplugServer {
pub fn new(name: &str, max_ping_time: u64) -> (Self, Receiver<ButtplugServerMessage>) {
let (send, recv) = bounded(256);
let pinged_out = Arc::new(AtomicBool::new(false));
let connected = Arc::new(AtomicBool::new(false));
let (ping_timer, ping_receiver) = if max_ping_time > 0 {
let (timer, mut receiver) = PingTimer::new(max_ping_time);
let (device_manager_sender, device_manager_receiver) = bounded(1);
let pinged_out_clone = pinged_out.clone();
let connected_clone = connected.clone();
let event_sender_clone = send.clone();
async_manager::spawn(async move {
receiver.next().await;
error!("Ping out signal received, stopping server");
pinged_out_clone.store(true, Ordering::SeqCst);
connected_clone.store(false, Ordering::SeqCst);
if event_sender_clone
.send(messages::Error::new(messages::ErrorCode::ErrorPing, "Ping Timeout").into())
.await
.is_err()
{
error!("Server disappeared, cannot update about ping out.");
};
if device_manager_sender.send(()).await.is_err() {
error!("Device Manager disappeared, cannot update about ping out.");
}
})
.unwrap();
(Some(timer), Some(device_manager_receiver))
} else {
(None, None)
};
(
Self {
server_name: name.to_string(),
client_name: String::default(),
max_ping_time,
device_manager: DeviceManager::new(send.clone(), ping_receiver),
ping_timer,
pinged_out,
connected,
event_sender: send,
},
recv,
)
}
/// Returns an owned copy of the stored client name.
///
/// NOTE(review): within this file the field is only ever set to the empty string at
/// construction, so this appears to always return `""` — confirm against the rest of the crate.
pub fn client_name(&self) -> String {
  String::from(self.client_name.as_str())
}
/// Registers a device communication manager of type `T` by delegating to the device manager.
///
/// # Errors
///
/// Returns whatever [`ButtplugServerStartupError`] the device manager reports.
pub fn add_comm_manager<T>(&self) -> Result<(), ButtplugServerStartupError>
where
  T: DeviceCommunicationManager + DeviceCommunicationManagerCreator + 'static,
{
  let manager = &self.device_manager;
  manager.add_comm_manager::<T>()
}
/// Registers the test communication manager by delegating to the device manager, and hands back
/// the helper the device manager returns for it.
///
/// # Errors
///
/// Returns whatever [`ButtplugServerStartupError`] the device manager reports.
pub fn add_test_comm_manager(
  &self,
) -> Result<TestDeviceCommunicationManagerHelper, ButtplugServerStartupError> {
  let manager = &self.device_manager;
  manager.add_test_comm_manager()
}
pub fn connected(&self) -> bool {
self.connected.load(Ordering::SeqCst)
}
pub fn disconnect(&self) -> BoxFuture<Result<(), ButtplugServerError>> {
let mut ping_fut = None;
if let Some(ping_timer) = &self.ping_timer {
ping_fut = Some(ping_timer.stop_ping_timer());
}
let stop_fut = self.parse_message(ButtplugClientMessage::StopAllDevices(
StopAllDevices::default(),
));
let connected = self.connected.clone();
Box::pin(async move {
connected.store(false, Ordering::SeqCst);
if let Some(pfut) = ping_fut {
pfut.await;
}
stop_fut.await.map(|_| ())
})
}
/// Routes one client message to its handler and returns a future resolving to the reply.
///
/// Gate applied while the server is not connected:
/// * if the ping-out flag is set, the message is rejected with a ping error;
/// * otherwise anything other than `RequestServerInfo` is rejected with a handshake error.
///
/// Messages convertible to a device manager or device command message are forwarded to the
/// device manager. `RequestServerInfo`, `Ping` and `RequestLog` are handled by the server
/// itself; every other message type is an "unexpected message type" error.
///
/// # Returns
///
/// A `'static` future. On success the reply's id is overwritten with the id of `msg`.
///
/// # Errors
///
/// Handler errors are wrapped into a [`ButtplugServerError`] carrying the id of `msg`.
pub fn parse_message(
  &self,
  msg: ButtplugClientMessage,
) -> BoxFuture<'static, Result<ButtplugServerMessage, ButtplugServerError>> {
  let id = msg.get_id();
  if !self.connected() {
    if self.pinged_out.load(Ordering::SeqCst) {
      return ButtplugServerError::new_message_error(id, ButtplugPingError::PingedOut.into())
        .into();
    } else if !matches!(msg, ButtplugClientMessage::RequestServerInfo(_)) {
      // NOTE(review): unlike the ping error above, this error is not tagged with the message
      // id — confirm whether clients need the id to correlate the failure.
      return ButtplugServerError::from(ButtplugHandshakeError::RequestServerInfoExpected).into();
    }
  }
  // The conversions consume their argument, so probing the message kind needs clones; the
  // message itself can then be moved into whichever branch handles it.
  let out_fut = if ButtplugDeviceManagerMessageUnion::try_from(msg.clone()).is_ok()
    || ButtplugDeviceCommandMessageUnion::try_from(msg.clone()).is_ok()
  {
    self.device_manager.parse_message(msg)
  } else {
    match msg {
      ButtplugClientMessage::RequestServerInfo(rsi_msg) => self.perform_handshake(rsi_msg),
      ButtplugClientMessage::Ping(p) => self.handle_ping(p),
      ButtplugClientMessage::RequestLog(l) => self.handle_log(l),
      _ => ButtplugMessageError::UnexpectedMessageType(format!("{:?}", msg)).into(),
    }
  };
  Box::pin(async move {
    out_fut
      .await
      .map(|mut ok_msg| {
        ok_msg.set_id(id);
        ok_msg
      })
      .map_err(|err| ButtplugServerError::new_message_error(id, err))
  })
}
fn perform_handshake(&self, msg: messages::RequestServerInfo) -> ButtplugServerResultFuture {
if self.connected() {
return ButtplugHandshakeError::HandshakeAlreadyHappened.into();
}
if BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION < msg.message_version {
return ButtplugHandshakeError::MessageSpecVersionMismatch(
BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION,
msg.message_version,
)
.into();
}
info!("Performing server handshake check");
let mut ping_timer_fut = None;
if let Some(timer) = &self.ping_timer {
ping_timer_fut = Some(timer.start_ping_timer());
}
let out_msg = messages::ServerInfo::new(
&self.server_name,
BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION,
self.max_ping_time.try_into().unwrap(),
);
let connected = self.connected.clone();
Box::pin(async move {
if let Some(fut) = ping_timer_fut {
fut.await;
}
connected.store(true, Ordering::SeqCst);
info!("Server handshake check successful.");
Result::Ok(out_msg.into())
})
}
/// Handles `Ping`: refreshes the ping timer and acknowledges with an `Ok` message carrying the
/// ping's id.
///
/// # Errors
///
/// Returns `PingTimerNotRunning` when the server was built without a ping timer.
fn handle_ping(&self, msg: messages::Ping) -> ButtplugServerResultFuture {
  let timer = match &self.ping_timer {
    Some(timer) => timer,
    None => return ButtplugPingError::PingTimerNotRunning.into(),
  };
  let ping_update = timer.update_ping_time();
  Box::pin(async move {
    ping_update.await;
    Result::Ok(messages::Ok::new(msg.get_id()).into())
  })
}
/// Handles `RequestLog` by acknowledging it with an `Ok` message carrying the request's id.
///
/// The requested log level is not read or applied here; the request is only acknowledged.
fn handle_log(&self, msg: messages::RequestLog) -> ButtplugServerResultFuture {
  let request_id = msg.get_id();
  Box::pin(async move { Result::Ok(messages::Ok::new(request_id).into()) })
}
/// Creates a [`ButtplugLogHandler`] wired to a clone of this server's event sender, constructed
/// with `LogLevel::Off`.
pub fn create_tracing_layer(&self) -> ButtplugLogHandler {
  let events = self.event_sender.clone();
  ButtplugLogHandler::new(&messages::LogLevel::Off, events)
}
}
#[cfg(test)]
mod test {
  use crate::{
    core::messages::{self, ButtplugServerMessage, BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION},
    server::ButtplugServer,
    util::async_manager,
  };
  use futures::StreamExt;

  // A server must accept a handshake, reject a second one while connected, and accept a new
  // handshake again after `disconnect()`.
  #[test]
  fn test_server_reuse() {
    let (server, _) = ButtplugServer::new("Test Server", 0);
    async_manager::block_on(async {
      let msg =
        messages::RequestServerInfo::new("Test Client", BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION);
      let mut reply = server.parse_message(msg.clone().into()).await;
      // Format arguments are passed straight to `assert!`; a pre-built `String` as the panic
      // message is rejected by newer editions.
      assert!(reply.is_ok(), "Should get back ok: {:?}", reply);
      reply = server.parse_message(msg.clone().into()).await;
      assert!(
        reply.is_err(),
        "Should get back err on double handshake: {:?}",
        reply
      );
      assert!(server.disconnect().await.is_ok(), "Should disconnect ok");
      reply = server.parse_message(msg.clone().into()).await;
      assert!(
        reply.is_ok(),
        "Should get back ok on handshake after disconnect: {:?}",
        reply
      );
    });
  }

  // Expects a `Log` event on the event channel after requesting debug logging. Currently
  // ignored.
  #[test]
  #[ignore]
  fn test_log_handler() {
    let (server, mut recv) = ButtplugServer::new("Test Server", 0);
    async_manager::block_on(async {
      let msg =
        messages::RequestServerInfo::new("Test Client", BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION);
      let mut reply = server.parse_message(msg.into()).await;
      assert!(reply.is_ok(), "Should get back ok: {:?}", reply);
      reply = server
        .parse_message(messages::RequestLog::new(messages::LogLevel::Debug).into())
        .await;
      assert!(reply.is_ok(), "Should get back ok: {:?}", reply);
      debug!("Test log message");
      let mut did_log = false;
      while let Some(msg) = recv.next().await {
        if let ButtplugServerMessage::Log(log) = msg {
          assert_eq!(log.log_level, messages::LogLevel::Debug);
          assert!(log.log_message.contains("Test log message"));
          did_log = true;
          break;
        }
      }
      assert!(did_log, "Should've gotten log message");
    });
  }
}