use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
#[allow(unused_imports)]
use log::{debug, error, info, warn};
use bytes::{BufMut, Bytes, BytesMut};
use rustdds::{
dds::{ReadError, ReadResult, WriteError, WriteResult},
rpc::*,
serialization::deserialize_from_cdr_with_rep_id,
*,
};
use crate::{message::Message, message_info::MessageInfo};
use super::{request_id, RmwRequestId, ServiceMapping};
/// Common interface of the service request/response wrappers.
///
/// A wrapper holds a message in serialized form together with the CDR
/// representation identifier it was encoded with. It can be built from raw
/// bytes and hands those bytes back out.
pub(super) trait Wrapper {
    /// Builds a wrapper around `input_bytes`, which are encoded as `encoding`.
    fn from_bytes_and_ri(input_bytes: &[u8], encoding: RepresentationIdentifier) -> Self;
    /// Returns the serialized message bytes held by the wrapper.
    fn bytes(&self) -> Bytes;
}
/// A service request in serialized form, tagged with its CDR representation
/// identifier.
///
/// `R` is the request type the bytes decode to. It is tracked only through
/// `PhantomData`; no `R` value is stored.
pub(crate) struct RequestWrapper<R> {
    // Serialized message. Depending on the service mapping, it may start with
    // an in-band header followed by the request payload.
    serialized_message: Bytes,
    // Representation identifier used to (de)serialize `serialized_message`.
    encoding: RepresentationIdentifier,
    phantom: PhantomData<R>,
}
impl<R: Message> Wrapper for RequestWrapper<R> {
    /// Wraps an owned copy of `input_bytes` and remembers their `encoding`.
    fn from_bytes_and_ri(input_bytes: &[u8], encoding: RepresentationIdentifier) -> Self {
        let serialized_message = Bytes::copy_from_slice(input_bytes);
        Self {
            serialized_message,
            encoding,
            phantom: PhantomData,
        }
    }

    /// Returns a handle to the serialized request bytes.
    fn bytes(&self) -> Bytes {
        Bytes::clone(&self.serialized_message)
    }
}
impl<R: Message> RequestWrapper<R> {
    /// Decodes this request according to `service_mapping`.
    ///
    /// Returns the request id together with the deserialized request.
    ///
    /// * `Basic`: the message is a `BasicRequestHeader` followed by the
    ///   request. The id comes from the header's `request_id`.
    /// * `Enhanced`: the message is the bare request. The id comes from
    ///   `message_info.related_sample_identity()`. When that is absent, a
    ///   warning is logged and `message_info.sample_identity()` is used
    ///   instead. If the resulting sequence number is
    ///   `SequenceNumber::UNKNOWN`, it is replaced by the sequence number of
    ///   `message_info.sample_identity()`.
    /// * `Cyclone`: the message is a `CycloneHeader` followed by the request.
    ///   It is decoded by `cyclone_unwrap`, using `message_info.writer_guid()`
    ///   as the writer GUID.
    ///
    /// # Errors
    ///
    /// Returns a deserialization error if the header or payload cannot be
    /// decoded, or if the message is shorter than its header.
    pub(super) fn unwrap(
        &self,
        service_mapping: ServiceMapping,
        message_info: &MessageInfo,
    ) -> ReadResult<(RmwRequestId, R)> {
        match service_mapping {
            ServiceMapping::Basic => {
                // Cloning `Bytes` shares the underlying buffer; no payload copy.
                let mut bytes = self.serialized_message.clone();
                let (header, header_size) =
                    deserialize_from_cdr_with_rep_id::<BasicRequestHeader>(&bytes, self.encoding)?;
                if bytes.len() < header_size {
                    read_error_deserialization!("Service request too short")
                } else {
                    // `split_to` returns the header bytes [0, header_size) and
                    // leaves only the request payload in `bytes`. (`split_off`
                    // does the reverse and would leave the header in `bytes`.)
                    let _header_bytes = bytes.split_to(header_size);
                    let (request, _request_bytes) =
                        deserialize_from_cdr_with_rep_id::<R>(&bytes, self.encoding)?;
                    Ok((RmwRequestId::from(header.request_id), request))
                }
            }
            ServiceMapping::Enhanced => {
                let (request, _request_bytes) =
                    deserialize_from_cdr_with_rep_id::<R>(&self.serialized_message, self.encoding)?;
                let mut rmw_req_id = RmwRequestId::from(
                    message_info.related_sample_identity()
                        .unwrap_or_else(|| {
                            let backup_identity = message_info.sample_identity();
                            warn!("RequestWrapper::unwrap: related_sample_identity missing. Using sample_identity = {backup_identity:?}");
                            backup_identity
                        })
                );
                // An unknown sequence number is replaced by the one carried in
                // this sample's own identity.
                if rmw_req_id.sequence_number == SequenceNumber::UNKNOWN {
                    rmw_req_id.sequence_number = message_info.sample_identity().sequence_number;
                }
                Ok((rmw_req_id, request))
            }
            ServiceMapping::Cyclone => cyclone_unwrap::<R>(
                self.serialized_message.clone(),
                message_info.writer_guid(),
                self.encoding,
            ),
        }
    }

    /// Serializes `request` for sending under `service_mapping`.
    ///
    /// * `Basic`: writes a `BasicRequestHeader` built from `r_id`, then the request.
    /// * `Enhanced`: writes only the request; there is no in-band header.
    /// * `Cyclone`: writes a `CycloneHeader` built from `r_id`, then the request.
    ///
    /// Every part is serialized with `encoding`, which is also stored in the
    /// returned wrapper.
    ///
    /// # Errors
    ///
    /// Propagates any serialization error.
    pub(super) fn new(
        service_mapping: ServiceMapping,
        r_id: RmwRequestId,
        encoding: RepresentationIdentifier,
        request: R,
    ) -> WriteResult<Self, ()> {
        // The initial capacity is only a hint; the buffer grows as needed.
        let mut ser_buffer = BytesMut::with_capacity(std::mem::size_of::<R>() * 3 / 2).writer();
        match service_mapping {
            ServiceMapping::Basic => {
                let basic_header = BasicRequestHeader::new(r_id.into());
                serialization::to_writer_with_rep_id(&mut ser_buffer, &basic_header, encoding)?;
            }
            ServiceMapping::Enhanced => {
                // No in-band header. On the receiving side, `unwrap` takes the
                // request id from the sample's `MessageInfo` instead.
            }
            ServiceMapping::Cyclone => {
                let cyclone_header = CycloneHeader::new(r_id);
                serialization::to_writer_with_rep_id(&mut ser_buffer, &cyclone_header, encoding)?;
            }
        }
        serialization::to_writer_with_rep_id(&mut ser_buffer, &request, encoding)?;
        Ok(RequestWrapper {
            serialized_message: ser_buffer.into_inner().freeze(),
            encoding,
            phantom: PhantomData,
        })
    }
}
/// A service response in serialized form, tagged with its CDR representation
/// identifier.
///
/// `R` is the response type the bytes decode to. It is tracked only through
/// `PhantomData`; no `R` value is stored.
pub(crate) struct ResponseWrapper<R> {
    // Serialized message. Depending on the service mapping, it may start with
    // an in-band header followed by the response payload.
    serialized_message: Bytes,
    // Representation identifier used to (de)serialize `serialized_message`.
    encoding: RepresentationIdentifier,
    phantom: PhantomData<R>,
}
impl<R: Message> Wrapper for ResponseWrapper<R> {
    /// Wraps an owned copy of `input_bytes` and remembers their `encoding`.
    fn from_bytes_and_ri(input_bytes: &[u8], encoding: RepresentationIdentifier) -> Self {
        let serialized_message = Bytes::copy_from_slice(input_bytes);
        Self {
            serialized_message,
            encoding,
            phantom: PhantomData,
        }
    }

    /// Returns a handle to the serialized response bytes.
    fn bytes(&self) -> Bytes {
        Bytes::clone(&self.serialized_message)
    }
}
impl<R: Message> ResponseWrapper<R> {
    /// Decodes this response according to `service_mapping`.
    ///
    /// Returns the id of the request this response answers, together with
    /// the deserialized response.
    ///
    /// * `Basic`: the message is a `BasicReplyHeader` followed by the
    ///   response. The id is the header's `related_request_id`.
    /// * `Enhanced`: the message is the bare response. The id is
    ///   `message_info.related_sample_identity()`, which must be present.
    /// * `Cyclone`: the message is a `CycloneHeader` followed by the response.
    ///   The id's GUID is assembled from the first 8 bytes of `client_guid`
    ///   and the last 8 bytes of `message_info.writer_guid()`.
    ///
    /// # Errors
    ///
    /// Returns a deserialization error in these cases:
    /// * the header or payload cannot be decoded;
    /// * the message is shorter than its header;
    /// * under `Enhanced`, `related_sample_identity` is missing.
    pub(super) fn unwrap(
        &self,
        service_mapping: ServiceMapping,
        message_info: MessageInfo,
        client_guid: GUID,
    ) -> ReadResult<(RmwRequestId, R)> {
        match service_mapping {
            ServiceMapping::Basic => {
                // Cloning `Bytes` shares the underlying buffer; no payload copy.
                let mut bytes = self.serialized_message.clone();
                let (header, header_size) =
                    deserialize_from_cdr_with_rep_id::<BasicReplyHeader>(&bytes, self.encoding)?;
                if bytes.len() < header_size {
                    read_error_deserialization!("Service response too short")
                } else {
                    // `split_to` returns the header bytes [0, header_size) and
                    // leaves only the response payload in `bytes`. (`split_off`
                    // does the reverse and would leave the header in `bytes`.)
                    let _header_bytes = bytes.split_to(header_size);
                    let (response, _bytes) =
                        deserialize_from_cdr_with_rep_id::<R>(&bytes, self.encoding)?;
                    Ok((RmwRequestId::from(header.related_request_id), response))
                }
            }
            ServiceMapping::Enhanced => {
                let (response, _response_bytes) =
                    deserialize_from_cdr_with_rep_id::<R>(&self.serialized_message, self.encoding)?;
                let related_sample_identity = match message_info.related_sample_identity() {
                    Some(rsi) => rsi,
                    None => {
                        return read_error_deserialization!("ServiceMapping=Enhanced, but response message did not have related_sample_identity parameter!")
                    }
                };
                Ok((RmwRequestId::from(related_sample_identity), response))
            }
            ServiceMapping::Cyclone => {
                // GUID = first half of the client GUID + second half of the
                // GUID of the writer that sent this response.
                let mut client_guid_bytes = [0; 16];
                {
                    let (first_half, second_half) = client_guid_bytes.split_at_mut(8);
                    first_half.copy_from_slice(&client_guid.to_bytes().as_slice()[0..8]);
                    second_half.copy_from_slice(&message_info.writer_guid().to_bytes()[8..16]);
                }
                let client_guid = GUID::from_bytes(client_guid_bytes);
                cyclone_unwrap::<R>(self.serialized_message.clone(), client_guid, self.encoding)
            }
        }
    }

    /// Serializes `response` for sending under `service_mapping`.
    ///
    /// * `Basic`: writes a `BasicReplyHeader` built from `r_id`, then the response.
    /// * `Enhanced`: writes only the response; there is no in-band header.
    /// * `Cyclone`: writes a `CycloneHeader` built from `r_id`, then the response.
    ///
    /// Every part is serialized with `encoding`, which is also stored in the
    /// returned wrapper.
    ///
    /// # Errors
    ///
    /// Propagates any serialization error.
    pub(super) fn new(
        service_mapping: ServiceMapping,
        r_id: RmwRequestId,
        encoding: RepresentationIdentifier,
        response: R,
    ) -> WriteResult<Self, ()> {
        // The initial capacity is only a hint; the buffer grows as needed.
        let mut ser_buffer = BytesMut::with_capacity(std::mem::size_of::<R>() * 3 / 2).writer();
        match service_mapping {
            ServiceMapping::Basic => {
                let basic_header = BasicReplyHeader::new(r_id.into());
                serialization::to_writer_with_rep_id(&mut ser_buffer, &basic_header, encoding)?;
            }
            ServiceMapping::Enhanced => {
                // No in-band header. On the receiving side, `unwrap` takes the
                // related request id from the sample's `MessageInfo` instead.
            }
            ServiceMapping::Cyclone => {
                let cyclone_header = CycloneHeader::new(r_id);
                serialization::to_writer_with_rep_id(&mut ser_buffer, &cyclone_header, encoding)?;
            }
        }
        serialization::to_writer_with_rep_id(&mut ser_buffer, &response, encoding)?;
        let serialized_message = ser_buffer.into_inner().freeze();
        Ok(ResponseWrapper {
            serialized_message,
            encoding,
            phantom: PhantomData,
        })
    }
}
/// In-band header written in front of each request under
/// `ServiceMapping::Basic`.
///
/// Serde serializes the fields in declaration order, so the order below is
/// the wire layout. Do not reorder.
#[derive(Serialize, Deserialize)]
pub struct BasicRequestHeader {
    // Identity of the request; converted to `RmwRequestId` when unwrapping.
    request_id: SampleIdentity,
    // Set to an empty string by `BasicRequestHeader::new`.
    instance_name: String,
}
impl BasicRequestHeader {
    /// Creates a header for `request_id` with an empty `instance_name`.
    fn new(request_id: SampleIdentity) -> Self {
        let instance_name = String::new();
        Self {
            request_id,
            instance_name,
        }
    }
}

impl Message for BasicRequestHeader {}
/// In-band header written in front of each response under
/// `ServiceMapping::Basic`.
///
/// Serde serializes the fields in declaration order, so the order below is
/// the wire layout. Do not reorder.
#[derive(Serialize, Deserialize)]
pub struct BasicReplyHeader {
    // Identity of the request this reply answers.
    related_request_id: SampleIdentity,
    // Set to 0 by `BasicReplyHeader::new`; not inspected when unwrapping.
    remote_exception_code: u32,
}
impl BasicReplyHeader {
    /// Creates a reply header answering `related_request_id`, with an
    /// exception code of zero.
    fn new(related_request_id: SampleIdentity) -> Self {
        let remote_exception_code = 0;
        Self {
            related_request_id,
            remote_exception_code,
        }
    }
}

impl Message for BasicReplyHeader {}
/// In-band header written in front of both requests and responses under
/// `ServiceMapping::Cyclone`.
///
/// Serde serializes the fields in declaration order, so the order below is
/// the wire layout. Do not reorder.
#[derive(Serialize, Deserialize)]
pub struct CycloneHeader {
    // Bytes 8..16 of the request writer's GUID (see `CycloneHeader::new`).
    guid_second_half: [u8; 8],
    // High 32 bits of the request sequence number.
    sequence_number_high: i32,
    // Low 32 bits of the request sequence number.
    sequence_number_low: u32,
}
impl CycloneHeader {
    /// Builds the header from a request id. It keeps the second half of the
    /// writer GUID and splits the sequence number into high and low words.
    fn new(r_id: RmwRequestId) -> Self {
        let full_guid = r_id.writer_guid.to_bytes();
        let mut guid_second_half = [0u8; 8];
        guid_second_half.copy_from_slice(&full_guid[8..16]);

        let seq = r_id.sequence_number;
        Self {
            guid_second_half,
            sequence_number_high: seq.high(),
            sequence_number_low: seq.low(),
        }
    }
}

impl Message for CycloneHeader {}
/// Decodes a message laid out as a `CycloneHeader` followed by a payload of
/// type `R`. The Cyclone mapping uses this layout for both requests and
/// responses.
///
/// The returned `RmwRequestId` pairs the caller-supplied `writer_guid` with
/// the sequence number stored in the header. The header's
/// `guid_second_half` is not used here.
///
/// # Errors
///
/// Returns a deserialization error if the header or payload cannot be
/// decoded, or if the message is shorter than the header.
fn cyclone_unwrap<R: Message>(
    serialized_message: Bytes,
    writer_guid: GUID,
    encoding: RepresentationIdentifier,
) -> ReadResult<(RmwRequestId, R)> {
    let mut bytes = serialized_message;
    let (header, header_size) = deserialize_from_cdr_with_rep_id::<CycloneHeader>(&bytes, encoding)?;
    if bytes.len() < header_size {
        read_error_deserialization!("Service message too short")
    } else {
        // `split_to` returns the header bytes [0, header_size) and leaves only
        // the payload in `bytes`. (`split_off` does the reverse and would leave
        // the header in `bytes`, so the payload would be decoded from it.)
        let _header_bytes = bytes.split_to(header_size);
        let (response, _response_bytes) = deserialize_from_cdr_with_rep_id::<R>(&bytes, encoding)?;
        let req_id = RmwRequestId {
            writer_guid,
            sequence_number: request_id::SequenceNumber::from_high_low(
                header.sequence_number_high,
                header.sequence_number_low,
            ),
        };
        Ok((req_id, response))
    }
}
/// Keyless reader whose samples are service wrappers of type `RW`, decoded
/// through `ServiceDeserializerAdapter`.
pub(super) type SimpleDataReaderR<RW> =
    no_key::SimpleDataReader<RW, ServiceDeserializerAdapter<RW>>;
/// Keyless writer for service wrappers of type `RW`, encoded through
/// `ServiceSerializerAdapter`.
pub(super) type DataWriterR<RW> = no_key::DataWriter<RW, ServiceSerializerAdapter<RW>>;

/// Type-level adapter implementing the read-side traits for wrapper `RW`.
/// It holds no data besides `PhantomData`.
pub(super) struct ServiceDeserializerAdapter<RW> {
    phantom: PhantomData<RW>,
}
/// Type-level adapter implementing the write-side trait for wrapper `RW`.
/// It holds no data besides `PhantomData`.
pub(super) struct ServiceSerializerAdapter<RW> {
    phantom: PhantomData<RW>,
}
impl<RW> ServiceDeserializerAdapter<RW> {
    // Representation identifiers accepted on the read side: plain CDR in
    // big-endian or little-endian byte order.
    const REPR_IDS: [RepresentationIdentifier; 2] = [
        RepresentationIdentifier::CDR_BE,
        RepresentationIdentifier::CDR_LE,
    ];
}

// The "decoded" value is already the wrapper type `RW`: decoding only stores
// the raw bytes, and the payload is deserialized later (e.g. by the
// wrappers' `unwrap`).
impl<RW: Wrapper> no_key::DeserializerAdapter<RW> for ServiceDeserializerAdapter<RW> {
    type Error = ReadError;
    type Decoded = RW;
    /// Encodings this adapter accepts: `CDR_BE` and `CDR_LE`.
    fn supported_encodings() -> &'static [RepresentationIdentifier] {
        &Self::REPR_IDS
    }
    /// Identity: decoding has already produced an `RW`.
    fn transform_decoded(decoded: Self::Decoded) -> RW {
        decoded
    }
}

// Default decoding goes through `WrapperDecoder`, which only wraps the
// incoming bytes.
impl<RW: Wrapper> no_key::DefaultDecoder<RW> for ServiceDeserializerAdapter<RW> {
    type Decoder = WrapperDecoder;
    const DECODER: Self::Decoder = WrapperDecoder;
}
/// Stateless decoder that turns raw sample bytes into a service wrapper.
///
/// It does not interpret the payload. It only wraps the bytes together with
/// their representation identifier, so decoding never fails.
#[derive(Clone)]
pub struct WrapperDecoder;

impl<RW> no_key::Decode<RW> for WrapperDecoder
where
    RW: Wrapper,
{
    type Error = ReadError;

    fn decode_bytes(
        self,
        input_bytes: &[u8],
        encoding: RepresentationIdentifier,
    ) -> Result<RW, Self::Error> {
        let wrapped = <RW as Wrapper>::from_bytes_and_ri(input_bytes, encoding);
        Ok(wrapped)
    }
}
impl<RW: Wrapper> no_key::SerializerAdapter<RW> for ServiceSerializerAdapter<RW> {
type Error = WriteError<()>;
fn output_encoding() -> RepresentationIdentifier {
RepresentationIdentifier::CDR_LE
}
fn to_bytes(value: &RW) -> WriteResult<Bytes, ()> {
Ok(value.bytes())
}
}