// buttplug 2.1.5
//
// Buttplug Intimate Hardware Control Library
// Documentation
use super::{ButtplugMessageSerializer, ButtplugSerializedMessage, ButtplugSerializerError};
use crate::{
  core::{
    errors::{ButtplugError, ButtplugHandshakeError},
    messages::{
      self,
      ButtplugClientMessage,
      ButtplugCurrentSpecClientMessage,
      ButtplugCurrentSpecServerMessage,
      ButtplugMessage,
      ButtplugMessageSpecVersion,
      ButtplugServerMessage,
      ButtplugSpecV0ClientMessage,
      ButtplugSpecV0ServerMessage,
      ButtplugSpecV1ClientMessage,
      ButtplugSpecV1ServerMessage,
      ButtplugSpecV2ClientMessage,
      ButtplugSpecV2ServerMessage,
    },
  },
  util::json::JSONValidator,
};
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::convert::TryFrom;

/// JSON schema describing every Buttplug protocol message, embedded into the
/// binary at compile time from `buttplug-schema/schema/buttplug-schema.json`
/// (path is relative to this source file).
static MESSAGE_JSON_SCHEMA: &str =
  include_str!("../../../../buttplug-schema/schema/buttplug-schema.json");

/// Builds a [`JSONValidator`] from the Buttplug message schema that is
/// embedded in this crate at compile time.
///
/// Every call constructs a fresh validator from the same schema text.
pub fn create_message_validator() -> JSONValidator {
  let schema = MESSAGE_JSON_SCHEMA;
  JSONValidator::new(schema)
}
/// Server-side JSON serializer: decodes incoming client messages and encodes
/// outgoing server messages, converting between message spec versions.
///
/// The spec version is not known until the first `RequestServerInfo` message
/// has been deserialized. It is then stored in `message_version` and used for
/// all subsequent messages in both directions.
pub struct ButtplugServerJSONSerializer {
  /// Negotiated message spec version; `None` until a `RequestServerInfo`
  /// message has been received.
  pub(super) message_version: RefCell<Option<messages::ButtplugMessageSpecVersion>>,
  /// Schema validator that incoming text is checked against before parsing.
  validator: JSONValidator,
}

impl Default for ButtplugServerJSONSerializer {
  /// Creates a serializer that has not yet negotiated a spec version and that
  /// validates against the embedded message schema.
  fn default() -> Self {
    let validator = create_message_validator();
    let message_version = RefCell::new(None);
    Self {
      message_version,
      validator,
    }
  }
}

/// Returns a single message as a string in Buttplug JSON Protocol format.
///
/// The protocol always sends messages as a JSON array, so the message is
/// wrapped in a one-element array before serialization.
///
/// # Panics
///
/// Panics if `serde_json` fails to serialize `msg`. Buttplug message types are
/// assumed to always be serializable, so a failure here indicates a bug in a
/// message type's `Serialize` implementation.
pub fn msg_to_protocol_json<T>(msg: T) -> String
where
  T: ButtplugMessage + Serialize + Deserialize<'static>,
{
  serde_json::to_string(&[&msg]).expect("Buttplug messages should always serialize to JSON")
}

/// Returns a batch of messages as a string in Buttplug JSON Protocol format,
/// i.e. a JSON array with one element per message, in order.
///
/// # Panics
///
/// Panics if `serde_json` fails to serialize any message. Buttplug message
/// types are assumed to always be serializable, so a failure here indicates a
/// bug in a message type's `Serialize` implementation.
pub fn vec_to_protocol_json<T>(msg: Vec<T>) -> String
where
  T: ButtplugMessage + Serialize + Deserialize<'static>,
{
  serde_json::to_string(&msg).expect("Buttplug messages should always serialize to JSON")
}

/// Validates `msg` against the message schema, then parses it as a JSON array
/// of `T`.
///
/// # Errors
///
/// Returns the validator's error if the schema check fails. Returns
/// `JsonSerializerError` (holding the `Debug` text of the serde error) if the
/// JSON cannot be parsed into `Vec<T>`.
fn deserialize_to_message<T>(
  validator: &JSONValidator,
  msg: String,
) -> Result<Vec<T>, ButtplugSerializerError>
where
  T: serde::de::DeserializeOwned + Clone,
{
  // Schema check first; anything that fails it never reaches serde.
  validator.validate(&msg)?;
  // serde_json's error type isn't clonable, so the error is flattened into a
  // string before being handed back.
  serde_json::from_str::<Vec<T>>(&msg)
    .map_err(|parse_err| ButtplugSerializerError::JsonSerializerError(format!("{:?}", parse_err)))
}

/// Encodes a batch of server messages as JSON for the given spec version.
///
/// Each message is converted to the requested spec version's message type. If
/// a message cannot be represented in that version, it is replaced in the batch
/// by an `Error` message built from the conversion error. The rest of the batch
/// is still encoded.
fn serialize_to_version(
  version: ButtplugMessageSpecVersion,
  msgs: Vec<ButtplugServerMessage>,
) -> ButtplugSerializedMessage {
  // `msgs` is owned, so each message is moved into its conversion rather than
  // being cloned first.
  ButtplugSerializedMessage::Text(match version {
    ButtplugMessageSpecVersion::Version0 => {
      let msg_vec: Vec<ButtplugSpecV0ServerMessage> = msgs
        .into_iter()
        .map(|msg| match ButtplugSpecV0ServerMessage::try_from(msg) {
          Ok(msg_v0) => msg_v0,
          Err(err) => ButtplugSpecV0ServerMessage::Error(
            messages::Error::from(ButtplugError::from(err)).into(),
          ),
        })
        .collect();
      vec_to_protocol_json(msg_vec)
    }
    ButtplugMessageSpecVersion::Version1 => {
      let msg_vec: Vec<ButtplugSpecV1ServerMessage> = msgs
        .into_iter()
        .map(|msg| match ButtplugSpecV1ServerMessage::try_from(msg) {
          Ok(msg_v1) => msg_v1,
          Err(err) => ButtplugSpecV1ServerMessage::Error(
            messages::Error::from(ButtplugError::from(err)).into(),
          ),
        })
        .collect();
      vec_to_protocol_json(msg_vec)
    }
    ButtplugMessageSpecVersion::Version2 => {
      let msg_vec: Vec<ButtplugSpecV2ServerMessage> = msgs
        .into_iter()
        .map(|msg| match ButtplugSpecV2ServerMessage::try_from(msg) {
          Ok(msg_v2) => msg_v2,
          Err(err) => ButtplugSpecV2ServerMessage::Error(ButtplugError::from(err).into()),
        })
        .collect();
      vec_to_protocol_json(msg_vec)
    }
  })
}

// NOTE(review): `message_version` is a `RefCell`, which is not `Sync`; this
// `unsafe impl` overrides that. `deserialize` and `serialize` both borrow the
// cell through `&self` (and `deserialize` may `borrow_mut` it), and `RefCell`'s
// borrow tracking is not thread-safe. This impl is therefore only sound if a
// single serializer instance is never used from two threads at the same time.
// A `Mutex` or an atomic would make this safe without `unsafe` — TODO confirm
// how callers share this type.
unsafe impl Sync for ButtplugServerJSONSerializer {
}
// NOTE(review): `Send` is asserted manually. Whether it is needed depends on
// whether `JSONValidator` is `Send` on its own, which is not visible here.
unsafe impl Send for ButtplugServerJSONSerializer {
}

impl ButtplugMessageSerializer for ButtplugServerJSONSerializer {
  type Inbound = ButtplugClientMessage;
  type Outbound = ButtplugServerMessage;

  fn deserialize(
    &self,
    serialized_msg: ButtplugSerializedMessage,
  ) -> Result<Vec<ButtplugClientMessage>, ButtplugSerializerError> {
    let msg = if let ButtplugSerializedMessage::Text(text_msg) = serialized_msg {
      text_msg
    } else {
      return Err(ButtplugSerializerError::BinaryDeserializationError);
    };
    // If we don't have a message version yet, we need to parse this as a
    // RequestServerInfo message to get the version. RequestServerInfo can
    // always be parsed as the latest message version, as we keep it
    // compatible across versions via serde options.
    if let Some(version) = *self.message_version.borrow() {
      return Ok(match version {
        ButtplugMessageSpecVersion::Version0 => {
          deserialize_to_message::<ButtplugSpecV0ClientMessage>(&self.validator, msg)?
            .iter()
            .cloned()
            .map(|m| m.into())
            .collect()
        }
        ButtplugMessageSpecVersion::Version1 => {
          deserialize_to_message::<ButtplugSpecV1ClientMessage>(&self.validator, msg)?
            .iter()
            .cloned()
            .map(|m| m.into())
            .collect()
        }
        ButtplugMessageSpecVersion::Version2 => {
          deserialize_to_message::<ButtplugSpecV2ClientMessage>(&self.validator, msg)?
            .iter()
            .cloned()
            .map(|m| m.into())
            .collect()
        }
      });
    }
    // instead of using if/else here, return in the if, which drops the borrow.
    // so we can possibly mutate it now.
    let msg_union = deserialize_to_message::<ButtplugSpecV2ClientMessage>(&self.validator, msg)?;
    if let ButtplugSpecV2ClientMessage::RequestServerInfo(rsi) = &msg_union[0] {
      info!(
        "Setting JSON Wrapper message version to {}",
        rsi.message_version()
      );
      *self.message_version.borrow_mut() = Some(rsi.message_version());
    } else {
      return Err(ButtplugSerializerError::MessageSpecVersionNotReceived);
    }
    Ok(msg_union.iter().cloned().map(|m| m.into()).collect())
  }

  fn serialize(&self, msgs: Vec<ButtplugServerMessage>) -> ButtplugSerializedMessage {
    if let Some(version) = *self.message_version.borrow() {
      serialize_to_version(version, msgs)
    } else {
      // In the rare event that there is a problem with the
      // RequestServerInfo message (so we can't set up our known spec
      // version), just encode to the latest and return.
      if let ButtplugServerMessage::Error(_) = &msgs[0] {
        serialize_to_version(ButtplugMessageSpecVersion::Version2, msgs)
      } else {
        // If we don't even have enough info to know which message
        // version to convert to, consider this a handshake error.
        ButtplugSerializedMessage::Text(msg_to_protocol_json(
          ButtplugCurrentSpecServerMessage::Error(
            ButtplugError::from(ButtplugHandshakeError::RequestServerInfoExpected).into(),
          ),
        ))
      }
    }
  }
}

/// Client-side JSON serializer: encodes current-spec client messages and
/// decodes current-spec server messages.
///
/// It holds no negotiated-version state; both directions always use the
/// current spec types.
pub struct ButtplugClientJSONSerializer {
  /// Schema validator that incoming server text is checked against before parsing.
  validator: JSONValidator,
}

impl Default for ButtplugClientJSONSerializer {
  /// Creates a client serializer that validates against the embedded message
  /// schema.
  fn default() -> Self {
    let validator = create_message_validator();
    Self { validator }
  }
}

// NOTE(review): the only field is a `JSONValidator`. These `unsafe impl`s
// assert that it is safe to share and move across threads, which depends on
// `JSONValidator`'s internals (not visible here) — TODO confirm. If
// `JSONValidator` is already `Send + Sync`, these impls are redundant and can
// be removed.
unsafe impl Sync for ButtplugClientJSONSerializer {
}
unsafe impl Send for ButtplugClientJSONSerializer {
}

impl ButtplugMessageSerializer for ButtplugClientJSONSerializer {
  type Inbound = ButtplugCurrentSpecServerMessage;
  type Outbound = ButtplugCurrentSpecClientMessage;

  /// Parses a text frame from the server into current-spec server messages.
  ///
  /// # Errors
  ///
  /// Returns `BinaryDeserializationError` for non-text frames. Otherwise
  /// returns any schema-validation or JSON-parsing error.
  fn deserialize(
    &self,
    msg: ButtplugSerializedMessage,
  ) -> Result<Vec<ButtplugCurrentSpecServerMessage>, ButtplugSerializerError> {
    match msg {
      ButtplugSerializedMessage::Text(json_text) => {
        deserialize_to_message::<Self::Inbound>(&self.validator, json_text)
      }
      _ => Err(ButtplugSerializerError::BinaryDeserializationError),
    }
  }

  /// Encodes a batch of current-spec client messages as a JSON text frame.
  fn serialize(&self, msg: Vec<ButtplugCurrentSpecClientMessage>) -> ButtplugSerializedMessage {
    let json_text = vec_to_protocol_json(msg);
    ButtplugSerializedMessage::Text(json_text)
  }
}

#[cfg(test)]
mod test {
  use super::*;
  use crate::core::messages::{RequestServerInfo, BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION};

  #[test]
  fn test_correct_message_version() {
    let json = r#"[{
            "RequestServerInfo": {
                "Id": 1,
                "ClientName": "Test Client",
                "MessageVersion": 2
            }
        }]"#;
    let serializer = ButtplugServerJSONSerializer::default();
    serializer
      .deserialize(ButtplugSerializedMessage::Text(json.to_owned()))
      .unwrap();
    assert_eq!(
      *serializer.message_version.borrow(),
      Some(ButtplugMessageSpecVersion::Version2)
    );
  }

  #[test]
  fn test_wrong_message_version() {
    let json = r#"[{
            "RequestServerInfo": {
                "Id": 1,
                "ClientName": "Test Client",
                "MessageVersion": 100
            }
        }]"#;
    let serializer = ButtplugServerJSONSerializer::default();
    let msg = serializer.deserialize(ButtplugSerializedMessage::Text(json.to_owned()));
    assert!(msg.is_err());
    // A failed handshake must not leave a negotiated version behind.
    assert_eq!(*serializer.message_version.borrow(), None);
  }

  #[test]
  fn test_client_incorrect_messages() {
    let incorrect_incoming_messages = vec![
      // Not valid JSON
      "not a json message",
      // Valid json object but no contents
      "{}",
      // Valid json but not an object
      "[]",
      // Not a message type
      "[{\"NotAMessage\":{}}]",
      // Valid json and message type but not in correct format
      "[{\"Ok\":[]}]",
      // Valid json and message type but not in correct format
      "[{\"Ok\":{}}]",
      // Valid json and message type but not an array.
      "{\"Ok\":{\"Id\":0}}",
      // Valid json array and message type, but with an invalid Id.
      // TODO This should fail (Ok can't have an Id of 0), but currently doesn't.
      // "[{\"Ok\":{\"Id\":0}}]",
      // Valid json and message type but with extra content
      "[{\"Ok\":{\"NotAField\":\"NotAValue\",\"Id\":1}}]",
    ];
    let serializer = ButtplugClientJSONSerializer::default();
    let _ = serializer.serialize(vec![RequestServerInfo::new(
      "test client",
      BUTTPLUG_CURRENT_MESSAGE_SPEC_VERSION,
    )
    .into()]);
    for msg in incorrect_incoming_messages {
      let res = serializer.deserialize(ButtplugSerializedMessage::Text(msg.to_owned()));
      assert!(res.is_err(), "Message should have been rejected: {}", msg);
      assert!(
        !matches!(
          res,
          Err(ButtplugSerializerError::MessageSpecVersionNotReceived)
        ),
        "Wrong error for message: {}",
        msg
      );
    }
  }
}