use super::{ButtplugMessageSerializer, ButtplugSerializedMessage, ButtplugSerializerError};
use crate::{
core::{
errors::{ButtplugError, ButtplugHandshakeError},
messages::{
self,
ButtplugClientMessage,
ButtplugCurrentSpecClientMessage,
ButtplugCurrentSpecServerMessage,
ButtplugMessage,
ButtplugMessageSpecVersion,
ButtplugServerMessage,
ButtplugSpecV0ClientMessage,
ButtplugSpecV0ServerMessage,
ButtplugSpecV1ClientMessage,
ButtplugSpecV1ServerMessage,
ButtplugSpecV2ClientMessage,
ButtplugSpecV2ServerMessage,
},
},
util::json::JSONValidator,
};
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::convert::TryFrom;
/// Buttplug message JSON schema, embedded into the binary at compile time (via
/// `include_str!`) from the `buttplug-schema` directory.
static MESSAGE_JSON_SCHEMA: &str =
include_str!("../../../../buttplug-schema/schema/buttplug-schema.json");
/// Builds a `JSONValidator` for the embedded Buttplug message schema.
///
/// Every call constructs a fresh validator from `MESSAGE_JSON_SCHEMA`.
pub fn create_message_validator() -> JSONValidator {
  let schema: &str = MESSAGE_JSON_SCHEMA;
  JSONValidator::new(schema)
}
/// Server-side JSON serializer.
///
/// Remembers the message spec version the client announced in its `RequestServerInfo`;
/// `message_version` stays `None` until that message has been deserialized.
pub struct ButtplugServerJSONSerializer {
  // Spec version learned from the client's `RequestServerInfo`; `None` before the handshake.
  pub(super) message_version: RefCell<Option<messages::ButtplugMessageSpecVersion>>,
  // Schema validator applied to every inbound JSON frame.
  validator: JSONValidator,
}

impl Default for ButtplugServerJSONSerializer {
  /// Creates a serializer with no negotiated spec version and a fresh schema validator.
  fn default() -> Self {
    let validator = create_message_validator();
    let message_version = RefCell::new(None);
    Self {
      message_version,
      validator,
    }
  }
}
/// Serializes a single message as a one-element JSON array, matching the array
/// framing produced by `vec_to_protocol_json`.
///
/// # Panics
///
/// Panics if `serde_json` fails to serialize `msg`.
pub fn msg_to_protocol_json<T>(msg: T) -> String
where
  T: ButtplugMessage + Serialize + Deserialize<'static>,
{
  // View the single message as a one-element slice so it serializes as `[msg]`.
  let framed = std::slice::from_ref(&msg);
  serde_json::to_string(framed).unwrap()
}
/// Serializes a list of messages as a JSON array.
///
/// # Panics
///
/// Panics if `serde_json` fails to serialize any element of `msg`.
pub fn vec_to_protocol_json<T>(msg: Vec<T>) -> String
where
  T: ButtplugMessage + Serialize + Deserialize<'static>,
{
  let as_slice: &[T] = &msg;
  serde_json::to_string(as_slice).unwrap()
}
/// Validates `msg` against the Buttplug JSON schema, then parses it as a JSON array of `T`.
///
/// # Errors
///
/// Returns the validator's error if schema validation fails. Returns
/// `ButtplugSerializerError::JsonSerializerError`, carrying the `Debug` text of the serde
/// error, if the JSON does not deserialize into `Vec<T>`.
fn deserialize_to_message<T>(
  validator: &JSONValidator,
  msg: String,
) -> Result<Vec<T>, ButtplugSerializerError>
where
  T: serde::de::DeserializeOwned + Clone,
{
  // Schema validation runs first, so input it rejects never reaches serde.
  validator.validate(&msg)?;
  serde_json::from_str::<Vec<T>>(&msg).map_err(|parse_err| {
    let detail = format!("{:?}", parse_err);
    ButtplugSerializerError::JsonSerializerError(detail)
  })
}
/// Serializes server messages as a JSON text frame in the wire format of spec `version`.
///
/// Each message is converted down to the requested spec version with `try_from`. If a
/// message has no representation in that version (the conversion fails), it is replaced
/// in place by an `Error` message built from the conversion error. The output therefore
/// always has exactly one entry per input message.
fn serialize_to_version(
  version: ButtplugMessageSpecVersion,
  msgs: Vec<ButtplugServerMessage>,
) -> ButtplugSerializedMessage {
  // `msgs` is owned, so each arm consumes it with `into_iter()` instead of cloning
  // every message.
  ButtplugSerializedMessage::Text(match version {
    ButtplugMessageSpecVersion::Version0 => {
      let msg_vec: Vec<ButtplugSpecV0ServerMessage> = msgs
        .into_iter()
        .map(|msg| match ButtplugSpecV0ServerMessage::try_from(msg) {
          Ok(converted) => converted,
          Err(err) => ButtplugSpecV0ServerMessage::Error(
            messages::Error::from(ButtplugError::from(err)).into(),
          ),
        })
        .collect();
      vec_to_protocol_json(msg_vec)
    }
    ButtplugMessageSpecVersion::Version1 => {
      let msg_vec: Vec<ButtplugSpecV1ServerMessage> = msgs
        .into_iter()
        .map(|msg| match ButtplugSpecV1ServerMessage::try_from(msg) {
          Ok(converted) => converted,
          Err(err) => ButtplugSpecV1ServerMessage::Error(
            messages::Error::from(ButtplugError::from(err)).into(),
          ),
        })
        .collect();
      vec_to_protocol_json(msg_vec)
    }
    ButtplugMessageSpecVersion::Version2 => {
      let msg_vec: Vec<ButtplugSpecV2ServerMessage> = msgs
        .into_iter()
        .map(|msg| match ButtplugSpecV2ServerMessage::try_from(msg) {
          Ok(converted) => converted,
          Err(err) => ButtplugSpecV2ServerMessage::Error(ButtplugError::from(err).into()),
        })
        .collect();
      vec_to_protocol_json(msg_vec)
    }
  })
}
// SAFETY: NOTE(review): `message_version` is a `RefCell`, which is `!Sync`, and
// `deserialize` calls `borrow_mut()` on it. Using one serializer from several threads at
// once would race on the cell's non-atomic borrow flag, which is undefined behavior. This
// `Sync` impl is only sound if callers guarantee that a single thread uses the serializer
// at a time. TODO: confirm that, or switch the field to a thread-safe cell (e.g.
// `Mutex`/`OnceCell`). The `Send` impl likewise bypasses the compiler's auto-trait check
// for `JSONValidator`; confirm it is safe to move across threads.
unsafe impl Sync for ButtplugServerJSONSerializer {
}
unsafe impl Send for ButtplugServerJSONSerializer {
}
impl ButtplugMessageSerializer for ButtplugServerJSONSerializer {
type Inbound = ButtplugClientMessage;
type Outbound = ButtplugServerMessage;
fn deserialize(
&self,
serialized_msg: ButtplugSerializedMessage,
) -> Result<Vec<ButtplugClientMessage>, ButtplugSerializerError> {
let msg = if let ButtplugSerializedMessage::Text(text_msg) = serialized_msg {
text_msg
} else {
return Err(ButtplugSerializerError::BinaryDeserializationError);
};
if let Some(version) = *self.message_version.borrow() {
return Ok(match version {
ButtplugMessageSpecVersion::Version0 => {
deserialize_to_message::<ButtplugSpecV0ClientMessage>(&self.validator, msg)?
.iter()
.cloned()
.map(|m| m.into())
.collect()
}
ButtplugMessageSpecVersion::Version1 => {
deserialize_to_message::<ButtplugSpecV1ClientMessage>(&self.validator, msg)?
.iter()
.cloned()
.map(|m| m.into())
.collect()
}
ButtplugMessageSpecVersion::Version2 => {
deserialize_to_message::<ButtplugSpecV2ClientMessage>(&self.validator, msg)?
.iter()
.cloned()
.map(|m| m.into())
.collect()
}
});
}
let msg_union = deserialize_to_message::<ButtplugSpecV2ClientMessage>(&self.validator, msg)?;
if let ButtplugSpecV2ClientMessage::RequestServerInfo(rsi) = &msg_union[0] {
info!(
"Setting JSON Wrapper message version to {}",
rsi.message_version()
);
*self.message_version.borrow_mut() = Some(rsi.message_version());
} else {
return Err(ButtplugSerializerError::MessageSpecVersionNotReceived);
}
Ok(msg_union.iter().cloned().map(|m| m.into()).collect())
}
fn serialize(&self, msgs: Vec<ButtplugServerMessage>) -> ButtplugSerializedMessage {
if let Some(version) = *self.message_version.borrow() {
serialize_to_version(version, msgs)
} else {
if let ButtplugServerMessage::Error(_) = &msgs[0] {
serialize_to_version(ButtplugMessageSpecVersion::Version2, msgs)
} else {
ButtplugSerializedMessage::Text(msg_to_protocol_json(
ButtplugCurrentSpecServerMessage::Error(
ButtplugError::from(ButtplugHandshakeError::RequestServerInfoExpected).into(),
),
))
}
}
}
}
/// Client-side JSON serializer.
///
/// It reads and writes only current-spec messages, so it keeps no version state, only
/// a schema validator.
pub struct ButtplugClientJSONSerializer {
  // Schema validator applied to every inbound JSON frame.
  validator: JSONValidator,
}

impl Default for ButtplugClientJSONSerializer {
  /// Creates a serializer with a fresh schema validator.
  fn default() -> Self {
    let validator = create_message_validator();
    Self { validator }
  }
}
// SAFETY: NOTE(review): this struct holds only a `JSONValidator`. These manual impls
// suggest that `JSONValidator` is not automatically `Send`/`Sync`. Because the compiler
// no longer checks this, confirm that moving and sharing the validator across threads is
// actually safe. TODO confirm.
unsafe impl Sync for ButtplugClientJSONSerializer {
}
unsafe impl Send for ButtplugClientJSONSerializer {
}
impl ButtplugMessageSerializer for ButtplugClientJSONSerializer {
  type Inbound = ButtplugCurrentSpecServerMessage;
  type Outbound = ButtplugCurrentSpecClientMessage;

  /// Parses a JSON text frame from the server into current-spec server messages.
  ///
  /// # Errors
  ///
  /// `BinaryDeserializationError` for non-text frames. Otherwise, whatever
  /// `deserialize_to_message` reports for invalid JSON.
  fn deserialize(
    &self,
    msg: ButtplugSerializedMessage,
  ) -> Result<Vec<ButtplugCurrentSpecServerMessage>, ButtplugSerializerError> {
    match msg {
      ButtplugSerializedMessage::Text(text_msg) => {
        deserialize_to_message::<Self::Inbound>(&self.validator, text_msg)
      }
      _ => Err(ButtplugSerializerError::BinaryDeserializationError),
    }
  }

  /// Encodes current-spec client messages as a JSON array text frame.
  fn serialize(&self, msg: Vec<ButtplugCurrentSpecClientMessage>) -> ButtplugSerializedMessage {
    let json = vec_to_protocol_json(msg);
    ButtplugSerializedMessage::Text(json)
  }
}
#[cfg(test)]
mod test {
  use super::*;
  use crate::core::messages::{RequestServerInfo, BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION};

  /// A valid V2 `RequestServerInfo` records the negotiated spec version.
  #[test]
  fn test_correct_message_version() {
    let json = r#"[{
"RequestServerInfo": {
"Id": 1,
"ClientName": "Test Client",
"MessageVersion": 2
}
}]"#;
    let serializer = ButtplugServerJSONSerializer::default();
    serializer
      .deserialize(ButtplugSerializedMessage::Text(json.to_owned()))
      .unwrap();
    assert_eq!(
      *serializer.message_version.borrow(),
      Some(ButtplugMessageSpecVersion::Version2)
    );
  }

  /// An unknown `MessageVersion` is rejected.
  #[test]
  fn test_wrong_message_version() {
    let json = r#"[{
"RequestServerInfo": {
"Id": 1,
"ClientName": "Test Client",
"MessageVersion": 100
}
}]"#;
    let serializer = ButtplugServerJSONSerializer::default();
    let msg = serializer.deserialize(ButtplugSerializedMessage::Text(json.to_owned()));
    assert!(msg.is_err(), "Expected an error for MessageVersion 100");
  }

  /// Before the handshake, a frame that does not start with `RequestServerInfo` is
  /// rejected, and no version is recorded.
  #[test]
  fn test_message_before_handshake() {
    let json = r#"[{"Ping": {"Id": 1}}]"#;
    let serializer = ButtplugServerJSONSerializer::default();
    match serializer.deserialize(ButtplugSerializedMessage::Text(json.to_owned())) {
      Err(ButtplugSerializerError::MessageSpecVersionNotReceived) => {}
      _ => panic!("Expected MessageSpecVersionNotReceived before the handshake"),
    }
    assert_eq!(*serializer.message_version.borrow(), None);
  }

  /// Malformed or schema-invalid frames are rejected by the client serializer with a
  /// parse or validation error, never with a handshake error.
  #[test]
  fn test_client_incorrect_messages() {
    let incorrect_incoming_messages = vec![
      "not a json message",
      "{}",
      "[]",
      "[{\"NotAMessage\":{}}]",
      "[{\"Ok\":[]}]",
      "[{\"Ok\":{}}]",
      "{\"Ok\":{\"Id\":0}}",
      "[{\"Ok\":{\"NotAField\":\"NotAValue\",\"Id\":1}}]",
    ];
    let serializer = ButtplugClientJSONSerializer::default();
    // The client serializer holds no version state, so this call does not influence the
    // deserialization results below; it only exercises `serialize`.
    let _ = serializer.serialize(vec![RequestServerInfo::new(
      "test client",
      BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION,
    )
    .into()]);
    for msg in incorrect_incoming_messages {
      let res = serializer.deserialize(ButtplugSerializedMessage::Text(msg.to_owned()));
      assert!(res.is_err(), "Expected an error for {:?}", msg);
      if let Err(ButtplugSerializerError::MessageSpecVersionNotReceived) = res {
        panic!("Wrong error (MessageSpecVersionNotReceived) for {:?}", msg);
      }
    }
  }
}