use std::io::{Read, Write};
use serde::{Deserialize, Serialize};
use crate::message::base::ProtoError;
use crate::message::wire;
/// Optional authentication credentials carried inside a [`Hello`] message.
///
/// Serialized as a nested JSON object. Unlike `Hello`, no `rename_all` is
/// applied here, so the JSON keys are the Rust field names (`"scheme"`,
/// `"param"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Auth {
    /// Authentication scheme name (the tests use `"Basic"`).
    pub scheme: String,
    /// Scheme-specific credential payload (the tests pair `"Basic"` with a
    /// base64 string; other encodings are not constrained here).
    pub param: String,
}
/// Hello message identifying a client: host, software version, platform,
/// instance number and optional credentials.
///
/// On the wire it travels as a JSON string (see `Hello::read_from` /
/// `Hello::write_to`). JSON keys are the PascalCase form of the field names,
/// except where an explicit `rename` pins an all-caps key. The key spelling is
/// meant to match the C++ implementation's JSON, as the tests in this file check.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct Hello {
    /// MAC address string (the tests use the `"00:11:22:33:44:55"` form).
    // Explicit rename: `rename_all = "PascalCase"` alone would emit `"Mac"`.
    #[serde(rename = "MAC")]
    pub mac: String,
    /// Host name; serialized as `"HostName"`.
    pub host_name: String,
    /// Client software version string; serialized as `"Version"`.
    pub version: String,
    /// Client program name (e.g. `"Snapclient"`); serialized as `"ClientName"`.
    pub client_name: String,
    /// Operating system name.
    // Explicit rename: PascalCase alone would emit `"Os"`.
    #[serde(rename = "OS")]
    pub os: String,
    /// CPU architecture (e.g. `"x86_64"`); serialized as `"Arch"`.
    pub arch: String,
    /// Instance number; serialized as `"Instance"`.
    /// NOTE(review): presumably distinguishes several clients on one host — confirm.
    pub instance: u32,
    /// Client identifier (the tests use the MAC address here).
    // Explicit rename: PascalCase alone would emit `"Id"`.
    #[serde(rename = "ID")]
    pub id: String,
    /// Stream protocol version; serialized as `"SnapStreamProtocolVersion"`.
    pub snap_stream_protocol_version: u32,
    /// Optional credentials. When `None` the `"Auth"` key is omitted from the
    /// JSON entirely; a missing key deserializes back to `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth: Option<Auth>,
}
impl Hello {
pub fn wire_size(&self) -> u32 {
let json = serde_json::to_string(self).unwrap_or_default();
wire::string_wire_size(&json)
}
pub fn read_from<R: Read>(r: &mut R) -> Result<Self, ProtoError> {
let json_str = wire::read_string(r)?;
serde_json::from_str(&json_str)
.map_err(|e| ProtoError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))
}
pub fn write_to<W: Write>(&self, w: &mut W) -> Result<(), ProtoError> {
let json_str = serde_json::to_string(self)
.map_err(|e| ProtoError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))?;
wire::write_string(w, &json_str)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A representative `Hello` without credentials.
    fn sample_hello() -> Hello {
        Hello {
            mac: "00:11:22:33:44:55".into(),
            host_name: "testhost".into(),
            version: "0.32.0".into(),
            client_name: "Snapclient".into(),
            os: "Linux".into(),
            arch: "x86_64".into(),
            instance: 1,
            id: "00:11:22:33:44:55".into(),
            snap_stream_protocol_version: 2,
            auth: None,
        }
    }

    /// Sample credentials used by the auth-related tests.
    fn sample_auth() -> Auth {
        Auth {
            scheme: "Basic".into(),
            param: "dXNlcjpwYXNz".into(),
        }
    }

    /// Returns the sorted top-level JSON keys `hello` serializes to.
    fn json_keys(hello: &Hello) -> Vec<String> {
        let value = serde_json::to_value(hello).unwrap();
        let obj = value.as_object().expect("Hello serializes to a JSON object");
        let mut keys: Vec<String> = obj.keys().cloned().collect();
        // serde_json's map order depends on the `preserve_order` feature; sort to be independent.
        keys.sort_unstable();
        keys
    }

    #[test]
    fn round_trip() {
        let original = sample_hello();
        let mut buf = Vec::new();
        original.write_to(&mut buf).unwrap();
        let mut cursor = std::io::Cursor::new(&buf);
        let decoded = Hello::read_from(&mut cursor).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn round_trip_with_auth() {
        let mut hello = sample_hello();
        hello.auth = Some(sample_auth());
        let mut buf = Vec::new();
        hello.write_to(&mut buf).unwrap();
        let mut cursor = std::io::Cursor::new(&buf);
        let decoded = Hello::read_from(&mut cursor).unwrap();
        assert_eq!(hello, decoded);
        assert_eq!(decoded.auth.unwrap().scheme, "Basic");
    }

    /// `wire_size` must agree with the bytes `write_to` emits, with and without `Auth`.
    #[test]
    fn wire_size_matches_written_bytes() {
        let mut hello = sample_hello();
        for auth in [None, Some(sample_auth())] {
            hello.auth = auth;
            let mut buf = Vec::new();
            hello.write_to(&mut buf).unwrap();
            assert_eq!(hello.wire_size() as usize, buf.len());
        }
    }

    /// Pins the exact key set, so a renamed, missing or extra field is caught.
    #[test]
    fn json_field_names_match_cpp() {
        let mut hello = sample_hello();
        assert_eq!(
            json_keys(&hello),
            [
                "Arch",
                "ClientName",
                "HostName",
                "ID",
                "Instance",
                "MAC",
                "OS",
                "SnapStreamProtocolVersion",
                "Version",
            ]
        );

        // `Auth` is skipped when `None` (checked above) and emitted when set.
        hello.auth = Some(sample_auth());
        assert!(json_keys(&hello).iter().any(|k| k == "Auth"));
    }

    #[test]
    fn deserialize_cpp_json() {
        let json = r#"{"Arch":"x86_64","ClientName":"Snapclient","HostName":"myhost","ID":"aa:bb:cc:dd:ee:ff","Instance":1,"MAC":"aa:bb:cc:dd:ee:ff","OS":"Arch Linux","SnapStreamProtocolVersion":2,"Version":"0.32.0"}"#;
        let hello: Hello = serde_json::from_str(json).unwrap();
        assert_eq!(hello.mac, "aa:bb:cc:dd:ee:ff");
        assert_eq!(hello.host_name, "myhost");
        assert_eq!(hello.version, "0.32.0");
        assert_eq!(hello.client_name, "Snapclient");
        assert_eq!(hello.os, "Arch Linux");
        assert_eq!(hello.arch, "x86_64");
        assert_eq!(hello.instance, 1);
        assert_eq!(hello.id, "aa:bb:cc:dd:ee:ff");
        assert_eq!(hello.snap_stream_protocol_version, 2);
        assert!(hello.auth.is_none());
    }
}