use anyhow::{anyhow, Context, Result};
use quick_protobuf::message::MessageRead;
use quick_protobuf::BytesReader;
use crate::usp::{self, Error, Msg, Notify};
use crate::usp_record::mod_Record::OneOfrecord_type;
use crate::usp_record::{NoSessionContextRecord, Record, SessionContextRecord};
/// Parses `bytes` as a protobuf-encoded USP [`Record`] without any semantic validation.
///
/// Use [`Record::from_bytes`] when the decoded Record should also be checked with
/// [`Record::check_validity`].
///
/// # Errors
///
/// Returns an error if `bytes` is not a valid protobuf encoding of a USP Record.
pub fn try_decode_record(bytes: &[u8]) -> Result<Record> {
    Record::from_reader(&mut BytesReader::from_bytes(bytes), bytes)
        .context("while parsing protobuf as USP Record")
}
/// Parses `bytes` as a protobuf-encoded USP [`Msg`] without any semantic validation.
///
/// Use [`Msg::from_bytes`] when the decoded Message should also be checked with
/// [`Msg::check_validity`].
///
/// # Errors
///
/// Returns an error if `bytes` is not a valid protobuf encoding of a USP Message.
pub fn try_decode_msg(bytes: &[u8]) -> Result<Msg> {
    Msg::from_reader(&mut BytesReader::from_bytes(bytes), bytes)
        .context("while parsing protobuf as USP Message")
}
impl Msg {
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
let this = try_decode_msg(bytes)?;
this.check_validity()?;
Ok(this)
}
#[must_use]
pub fn msg_id(&self) -> &str {
self.header
.as_ref()
.map_or("", |header| header.msg_id.as_ref())
}
#[must_use]
pub const fn is_request(&self) -> bool {
if let Some(body) = self.body.as_ref() {
matches!(&body.msg_body, usp::mod_Body::OneOfmsg_body::request(_))
} else {
false
}
}
#[must_use]
pub const fn is_notify_request(&self) -> bool {
self.get_notify_request().is_some()
}
#[must_use]
pub const fn get_notify_request(&self) -> Option<&Notify> {
if let Some(body) = self.body.as_ref() {
if let usp::mod_Body::OneOfmsg_body::request(request) = &body.msg_body {
if let usp::mod_Request::OneOfreq_type::notify(notify) = &request.req_type {
return Some(notify);
}
}
}
None
}
#[must_use]
pub const fn is_response(&self) -> bool {
if let Some(body) = self.body.as_ref() {
matches!(&body.msg_body, usp::mod_Body::OneOfmsg_body::response(_))
} else {
false
}
}
#[must_use]
pub fn is_error(&self) -> bool {
self.get_error().is_some()
}
#[must_use]
pub fn get_error(&self) -> Option<Error> {
if let Some(body) = self.body.as_ref() {
if let usp::mod_Body::OneOfmsg_body::error(error) = &body.msg_body {
return Some(error.clone());
}
}
None
}
pub fn check_validity(&self) -> Result<()> {
use crate::usp::mod_Body::OneOfmsg_body;
use crate::usp::mod_Request::OneOfreq_type;
use crate::usp::mod_Response::OneOfresp_type;
use crate::usp::{Request, Response};
self.header
.as_ref()
.filter(|h| !h.msg_id.is_empty())
.ok_or_else(|| anyhow!("Empty message ID"))?;
let body = self
.body
.as_ref()
.filter(|b| !matches!(b.msg_body, OneOfmsg_body::None))
.ok_or_else(|| anyhow!("Invalid message body"))?;
match body.msg_body {
OneOfmsg_body::request(Request {
req_type: OneOfreq_type::None,
}) => Err(anyhow!("Invalid Request message")),
OneOfmsg_body::response(Response {
resp_type: OneOfresp_type::None,
}) => Err(anyhow!("Invalid Response message")),
_ => Ok(()),
}
}
}
impl SessionContextRecord {
    /// Merges all payload segments into one buffer and returns a mutable reference to it.
    ///
    /// After this call `self.payload` holds exactly one element: the in-order concatenation
    /// of the segments it held before. If it already held exactly one segment, nothing is
    /// copied. If it held none, the result is a single empty buffer.
    pub fn payload_flatten(&mut self) -> &mut Vec<u8> {
        if self.payload.len() != 1 {
            // `concat` allocates the merged buffer once, sized from the segment lengths.
            self.payload = vec![self.payload.concat()];
        }
        // Cannot panic: `payload` has exactly one element at this point.
        &mut self.payload[0]
    }
}
impl Record {
    /// Decodes a USP Record from its protobuf encoding and validates it.
    ///
    /// # Errors
    ///
    /// Fails if `bytes` cannot be decoded as a Record, or if the decoded Record is rejected
    /// by [`Record::check_validity`].
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let record = try_decode_record(bytes)?;
        record.check_validity().map(|()| record)
    }

    /// Checks the structural requirements this crate places on a decoded Record.
    ///
    /// The checks run in this order, and the first failure is reported:
    /// 1. `version` is non-empty,
    /// 2. `to_id` is non-empty,
    /// 3. `from_id` is non-empty,
    /// 4. a record type is set,
    /// 5. a NoSessionContext Record carries a non-empty payload.
    ///
    /// # Errors
    ///
    /// Returns an error describing the first requirement that is not met.
    pub fn check_validity(&self) -> Result<()> {
        // Required string fields, listed in the order they must be reported.
        let required_fields = [
            (&self.version, "Invalid USP version"),
            (&self.to_id, "Invalid to_id field in Record"),
            (&self.from_id, "Invalid from_id field in Record"),
        ];
        if let Some(&(_, message)) = required_fields.iter().find(|(value, _)| value.is_empty()) {
            return Err(anyhow!(message));
        }

        match &self.record_type {
            OneOfrecord_type::None => Err(anyhow!("Invalid Record type")),
            OneOfrecord_type::no_session_context(no_session) if no_session.payload.is_empty() => {
                Err(anyhow!("NoSessionContext Record containing an empty payload"))
            }
            _ => Ok(()),
        }
    }

    /// Returns the Record's payload as one contiguous buffer.
    ///
    /// A NoSessionContext payload is returned as-is. A SessionContext payload is first merged
    /// into a single segment via [`SessionContextRecord::payload_flatten`]. Every other record
    /// type, including an unset one, yields `None`.
    pub fn payload_flatten(&mut self) -> Option<&mut Vec<u8>> {
        let payload = match &mut self.record_type {
            OneOfrecord_type::session_context(session) => session.payload_flatten(),
            OneOfrecord_type::no_session_context(no_session) => &mut no_session.payload,
            _ => return None,
        };
        Some(payload)
    }
}
// Unit tests that decode hand-encoded protobuf fixtures and run them through the validity
// checks and payload helpers defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // The fixture decodes cleanly as a USP Message. As the test name says, its header
    // carries no msg_id, so `check_validity` must reject it.
    #[test]
    fn empty_msg_id() {
        let raw = [
            0x0a, 0x02, 0x10, 0x01, 0x12, 0x28, 0x0a, 0x26, 0x0a, 0x24, 0x0a, 0x22, 0x44, 0x65,
            0x76, 0x69, 0x63, 0x65, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49, 0x6e, 0x66,
            0x6f, 0x2e, 0x53, 0x6f, 0x66, 0x74, 0x77, 0x61, 0x72, 0x65, 0x56, 0x65, 0x72, 0x73,
            0x69, 0x6f, 0x6e, 0x2e,
        ];
        let msg = try_decode_msg(&raw)
            .expect("raw should be a valid USP Message according to the protobuf schema");
        assert!(msg.check_validity().is_err());
    }

    // The fixture decodes as a Record but fails `check_validity` (per the test name, because
    // to_id is missing). Its NoSessionContext payload is then asserted to be a *valid*
    // Message, so the failure is attributable to the Record's own fields.
    #[test]
    fn invalid_record_to_id() {
        let raw = [
            0x0a, 0x03, 0x31, 0x2e, 0x33, 0x1a, 0x09, 0x64, 0x6f, 0x63, 0x3a, 0x3a, 0x66, 0x72,
            0x6f, 0x6d, 0x3a, 0x35, 0x12, 0x33, 0x0a, 0x07, 0x0a, 0x03, 0x67, 0x65, 0x74, 0x10,
            0x01, 0x12, 0x28, 0x0a, 0x26, 0x0a, 0x24, 0x0a, 0x22, 0x44, 0x65, 0x76, 0x69, 0x63,
            0x65, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x53,
            0x6f, 0x66, 0x74, 0x77, 0x61, 0x72, 0x65, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
            0x2e,
        ];
        let record = try_decode_record(&raw)
            .expect("raw should be a valid Record according to the protobuf schema");
        assert!(record.check_validity().is_err());
        let OneOfrecord_type::no_session_context(NoSessionContextRecord { payload: msg_raw }) =
            record.record_type
        else {
            panic!("Record should have a NoSessionContext type");
        };
        let msg = try_decode_msg(&msg_raw)
            .expect("msg_raw should be a valid USP Message according to the protobuf schema");
        msg.check_validity().unwrap();
    }

    // Same structure as `invalid_record_to_id`, but per the test name the Record lacks
    // from_id. The embedded Message must still pass validation.
    #[test]
    fn invalid_record_from_id() {
        let raw = [
            0x0a, 0x03, 0x31, 0x2e, 0x33, 0x12, 0x07, 0x64, 0x6f, 0x63, 0x3a, 0x3a, 0x74, 0x6f,
            0x3a, 0x35, 0x12, 0x33, 0x0a, 0x07, 0x0a, 0x03, 0x67, 0x65, 0x74, 0x10, 0x01, 0x12,
            0x28, 0x0a, 0x26, 0x0a, 0x24, 0x0a, 0x22, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2e,
            0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x53, 0x6f, 0x66,
            0x74, 0x77, 0x61, 0x72, 0x65, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x2e,
        ];
        let record = try_decode_record(&raw)
            .expect("raw should be a valid Record according to the protobuf schema");
        assert!(record.check_validity().is_err());
        let OneOfrecord_type::no_session_context(NoSessionContextRecord { payload: msg_raw }) =
            record.record_type
        else {
            panic!("Record should have a NoSessionContext type");
        };
        let msg = try_decode_msg(&msg_raw)
            .expect("msg_raw should be a valid USP Message according to the protobuf schema");
        msg.check_validity().unwrap();
    }

    // Round-trip test for payload flattening on a SessionContext Record:
    // 1. capture the flattened payload,
    // 2. re-split it into 2-byte segments (asserting there is more than one),
    // 3. flatten again and check the result is byte-for-byte the captured payload.
    #[test]
    fn flat_record() {
        use crate::usp_record::mod_Record::OneOfrecord_type;
        let raw = [
            0x0a, 0x03, 0x31, 0x2e, 0x33, 0x12, 0x07, 0x64, 0x6f, 0x63, 0x3a, 0x3a, 0x74, 0x6f,
            0x1a, 0x09, 0x64, 0x6f, 0x63, 0x3a, 0x3a, 0x66, 0x72, 0x6f, 0x6d, 0x42, 0x3c, 0x08,
            0xd2, 0x09, 0x10, 0x01, 0x18, 0x02, 0x3a, 0x33, 0x0a, 0x07, 0x0a, 0x03, 0x67, 0x65,
            0x74, 0x10, 0x01, 0x12, 0x28, 0x0a, 0x26, 0x0a, 0x24, 0x0a, 0x22, 0x44, 0x65, 0x76,
            0x69, 0x63, 0x65, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f,
            0x2e, 0x53, 0x6f, 0x66, 0x74, 0x77, 0x61, 0x72, 0x65, 0x56, 0x65, 0x72, 0x73, 0x69,
            0x6f, 0x6e, 0x2e,
        ];
        let mut record = Record::from_bytes(&raw).expect("raw should be a valid Record");
        let payload = record.payload_flatten().unwrap().clone();
        assert!(!payload.is_empty());
        // Re-split into 2-byte segments so the next flatten has real work to do.
        let split = payload.chunks(2).map(Vec::from).collect::<Vec<_>>();
        assert!(split.len() > 1);
        let OneOfrecord_type::session_context(session) = &mut record.record_type else {
            panic!("Record should be of type SessionContext");
        };
        session.payload = split;
        let flatten = record.payload_flatten().unwrap();
        assert_eq!(flatten, &payload);
    }
}