kassandra_shared/communication/mod.rs

//! Generic components for framing and serializing data passed between the
//! different service components. Framing is especially necessary between a
//! host environment and an enclave, because enclaves may be resource-constrained
//! and therefore lack higher-level abstractions.
5
6#[cfg(feature = "std")]
7pub mod tcp;
8
9use alloc::string::String;
10use alloc::vec;
11use alloc::vec::Vec;
12
13use fmd::fmd2_compact::FlagCiphertexts;
14use serde::de::{DeserializeOwned, Error};
15use serde::{Deserialize, Deserializer, Serialize, Serializer};
16use thiserror::Error;
17
18use crate::db::{EncryptedResponse, Index};
19use crate::ratls::TlsCiphertext;
20
/// A fixed-size byte array that is serialized as a hex string by the serde
/// impls in this module.
#[derive(Debug, Copy, Clone)]
pub struct HexBytes<const N: usize>(pub [u8; N]);

/// Wrap a raw byte array as [`HexBytes`]; no validation is performed.
impl<const N: usize> From<[u8; N]> for HexBytes<N> {
    fn from(bytes: [u8; N]) -> Self {
        HexBytes(bytes)
    }
}
29
30macro_rules! impl_serde {
31    ($n:literal) => {
32        impl Serialize for HexBytes<$n> {
33            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34            where
35                S: Serializer,
36            {
37                serializer.serialize_str(&hex::encode(self.0))
38            }
39        }
40
41        impl<'de> Deserialize<'de> for HexBytes<$n> {
42            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
43            where
44                D: Deserializer<'de>,
45            {
46                let s = String::deserialize(deserializer)?;
47                Ok(Self(
48                    hex::decode(s.as_bytes())
49                        .map_err(|_| Error::custom("Invalid hex"))?
50                        .try_into()
51                        .map_err(|_| Error::custom("Bytes were of wrong size"))?,
52                ))
53            }
54        }
55    };
56}
57
58impl_serde!(32);
59impl_serde!(64);
60
/// Messages to host environment from the enclave
///
/// Only `RATLS`, `ErrorForClient` and `KeyRegSuccess` can be converted into a
/// [`ServerMsg`] for a client (see `impl TryFrom<MsgToHost> for ServerMsg`);
/// that conversion rejects every other variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MsgToHost {
    /// A plain string payload. NOTE(review): its exact use is not visible in
    /// this module — confirm.
    Basic(String),
    /// An error for the host; it is not converted into a client message.
    Error(String),
    /// An error that converts into [`ServerMsg::Error`] for the client.
    ErrorForClient(String),
    /// Report bytes that convert unchanged into [`ServerMsg::RATLS`].
    RATLS { report: Vec<u8> },
    /// Report bytes for the host; not converted into a client message.
    /// NOTE(review): presumably the answer to [`MsgFromHost::RequestReport`] — confirm.
    Report(Vec<u8>),
    /// Key registration succeeded; converts into [`ServerMsg::KeyRegSuccess`].
    KeyRegSuccess,
    /// A list of `u64` values sent to the host. NOTE(review): presumably block
    /// heights the enclave wants flags for (cf. [`MsgFromHost::RequestedFlags`]) — confirm.
    BlockRequests(Vec<u64>),
    /// Encrypted responses sent from the enclave to the host; not converted
    /// into a client message.
    FmdResults(Vec<EncryptedResponse>),
}
73
/// Messages from host environment to the enclave
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MsgFromHost {
    /// A plain string payload. NOTE(review): its exact use is not visible in
    /// this module — confirm.
    Basic(String),
    /// The `TryFrom<&ClientMsg>` conversion builds this from
    /// [`ClientMsg::RegisterKey`], copying `nonce` and `pk`.
    RegisterKey {
        nonce: u64,
        pk: HexBytes<32>,
    },
    /// The `TryFrom<&ClientMsg>` conversion builds this from
    /// [`ClientMsg::RequestReport`], copying `user_data`.
    RequestReport {
        user_data: HexBytes<64>,
    },
    /// The `TryFrom<&ClientMsg>` conversion builds this from
    /// [`ClientMsg::RATLSAck`] with a clone of its [`AckType`].
    RATLSAck(AckType),
    /// NOTE(review): presumably asks the enclave which blocks it still needs
    /// (cf. [`MsgToHost::BlockRequests`]) — confirm.
    RequiredBlocks,
    /// Pairs of [`Index`] and optional [`FlagCiphertexts`].
    /// NOTE(review): `synced_to` presumably is the height the host has synced
    /// to — confirm.
    RequestedFlags {
        synced_to: u64,
        flags: Vec<(Index, Option<FlagCiphertexts>)>,
    },
}
92
/// Messages from clients to hosts
///
/// `RegisterKey`, `RequestReport` and `RATLSAck` can be converted into a
/// [`MsgFromHost`] for the enclave; `RequestUUID` and `RequestIndices` cannot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClientMsg {
    /// Gives the client's public part of the shared key
    /// and requests the enclave's part.
    RegisterKey {
        nonce: u64,
        pk: HexBytes<32>,
    },
    /// Converts into [`MsgFromHost::RequestReport`] with the same `user_data`.
    RequestReport {
        user_data: HexBytes<64>,
    },
    /// Converts into [`MsgFromHost::RATLSAck`].
    RATLSAck(AckType),
    /// Request the host's UUID
    RequestUUID,
    /// Not forwarded to the enclave. NOTE(review): presumably answered with
    /// [`ServerMsg::IndicesResponse`] — confirm.
    RequestIndices {
        key_hash: String,
    },
}
112
/// Payload of a `RATLSAck` message (see [`ClientMsg::RATLSAck`] and
/// [`MsgFromHost::RATLSAck`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AckType {
    /// Success, carrying a [`TlsCiphertext`]. NOTE(review): what the
    /// ciphertext contains is not visible in this module — confirm.
    Success(TlsCiphertext),
    /// Failure; carries no data.
    Fail,
}
118
/// Messages from hosts to clients
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ServerMsg {
    /// The raw report bytes
    RATLS {
        report: Vec<u8>,
    },
    /// An error message; [`MsgToHost::ErrorForClient`] converts into this.
    Error(String),
    /// Key registration succeeded; [`MsgToHost::KeyRegSuccess`] converts into this.
    KeyRegSuccess,
    /// The host's UUID. NOTE(review): presumably the reply to
    /// [`ClientMsg::RequestUUID`] — confirm.
    UUID(String),
    /// NOTE(review): presumably the reply to [`ClientMsg::RequestIndices`] — confirm.
    IndicesResponse(EncryptedResponse),
}
131
132impl<'a> TryFrom<&'a ClientMsg> for MsgFromHost {
133    type Error = &'static str;
134
135    fn try_from(msg: &'a ClientMsg) -> Result<Self, Self::Error> {
136        match msg {
137            ClientMsg::RegisterKey { nonce, pk } => Ok(MsgFromHost::RegisterKey {
138                nonce: *nonce,
139                pk: *pk,
140            }),
141            ClientMsg::RequestReport { user_data } => Ok(MsgFromHost::RequestReport {
142                user_data: *user_data,
143            }),
144            ClientMsg::RATLSAck(v) => Ok(MsgFromHost::RATLSAck(v.clone())),
145            _ => Err("Message not intended for enclave"),
146        }
147    }
148}
149
150impl TryFrom<MsgToHost> for ServerMsg {
151    type Error = &'static str;
152
153    fn try_from(msg: MsgToHost) -> Result<Self, &'static str> {
154        match msg {
155            MsgToHost::RATLS { report } => Ok(ServerMsg::RATLS { report }),
156            MsgToHost::ErrorForClient(err) => Ok(ServerMsg::Error(err)),
157            MsgToHost::KeyRegSuccess => Ok(ServerMsg::KeyRegSuccess),
158            _ => Err("Message not intended for client"),
159        }
160    }
161}
162
163#[derive(Error, Debug)]
164pub enum MsgError {
165    #[error("COBS failed to decode message from COM 2 with: {0}")]
166    Decode(cobs::DecodeError),
167    #[error("Failed to deserialize CBOR with: {0}")]
168    Deserialize(serde_cbor::Error),
169    #[error("Input bytes were not valid utf-8: {0:?}")]
170    Utf8(Vec<u8>),
171}
172
173pub struct Frame {
174    pub bytes: Vec<u8>,
175}
176
177impl Frame {
178    pub fn deserialize<T: DeserializeOwned>(self) -> Result<T, MsgError> {
179        serde_cbor::from_slice(&self.bytes).map_err(MsgError::Deserialize)
180    }
181}
182
/// A byte-oriented transport: reads the next byte of a stream and writes
/// byte slices to it.
///
/// Implementing this trait is enough to obtain [`FramedBytes`] through the
/// blanket impl in this module.
pub trait ReadWriteByte {
    /// Initial size, in bytes, of the buffer [`FramedBytes::get_frame`] decodes
    /// into. The default is 1024; implementors may override it.
    const FRAME_BUF_SIZE: usize = 1024;
    /// Return the next byte from the stream.
    ///
    /// The signature cannot report errors or end-of-stream.
    /// NOTE(review): implementors presumably block until a byte is available — confirm.
    fn read_byte(&mut self) -> u8;

    /// Write `buf` to the stream.
    fn write_bytes(&mut self, buf: &[u8]);
}
190
191/// A trait for reading / writing framed data from a byte stream.
192/// This trait should not be implemented directly, but rely on
193/// the default implementation.
194pub trait FramedBytes: ReadWriteByte {
195    /// Blocking method that reads a frame
196    ///
197    /// Uses an initial buffer with 1Kb in size. Dynamically increases the
198    /// size of the frame buffer by 1Kb until either the message is decoded
199    /// or an error occurs.
200    ///
201    /// Returns the raw framed bytes
202    fn get_frame(&mut self) -> Result<Frame, MsgError> {
203        // initial buffer size for the frame
204        let mut buf_size = Self::FRAME_BUF_SIZE;
205        // keep track of bytes processed so far incase we need to increase
206        // buffer size
207        let mut read_bytes = Vec::<u8>::with_capacity(buf_size);
208        // continue trying to populate the frame buffer until
209        // a successful frame decoding or a decode error occurs.
210        loop {
211            // initial buffer
212            let mut frame_buf = vec![0u8; buf_size];
213            let mut decoder = cobs::CobsDecoder::new(&mut frame_buf);
214            decoder
215                .push(&read_bytes)
216                .expect("Previously read bytes should not produce a frame error.");
217
218            loop {
219                let b = self.read_byte();
220                read_bytes.push(b);
221                match decoder.feed(b) {
222                    Ok(None) => continue,
223                    Ok(Some(len)) => {
224                        frame_buf.truncate(len);
225                        return Ok(Frame { bytes: frame_buf });
226                    }
227                    Err(cobs::DecodeError::TargetBufTooSmall) => {
228                        // increase the buffer size ny 1Kb
229                        buf_size += Self::FRAME_BUF_SIZE;
230                        break;
231                    }
232                    Err(e) => return Err(MsgError::Decode(e)),
233                }
234            }
235        }
236    }
237
238    /// Write a serializable message out to the serial port in CBOR,
239    /// framed with COBS.
240    fn write_frame<T: Serialize>(&mut self, msg: &T) {
241        let data = serde_cbor::to_vec(&msg).unwrap();
242        let mut encoded = cobs::encode_vec_with_sentinel(&data, 0);
243        encoded.push(0);
244        self.write_bytes(&encoded);
245    }
246}
247
248impl<T: ReadWriteByte> FramedBytes for T {}
249
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;

    /// In-memory transport: reads pop bytes from the front, writes append to
    /// the back.
    struct MockChannel(Vec<u8>);

    impl ReadWriteByte for MockChannel {
        // Deliberately tiny so that even small messages overflow the initial
        // frame buffer.
        const FRAME_BUF_SIZE: usize = 10;
        fn read_byte(&mut self) -> u8 {
            self.0.remove(0)
        }

        fn write_bytes(&mut self, buf: &[u8]) {
            self.0.extend_from_slice(buf);
        }
    }

    /// Test that if the data we are decoding does not initially
    /// fit into the frame buffer, we dynamically resize it until the
    /// data fits and decoding is successful.
    #[test]
    fn test_dynamic_frame_resizing() {
        let msg = MsgFromHost::Basic("Test".to_string());
        let data = serde_cbor::to_vec(&msg).expect("Test failed");
        let mut encoded = cobs::encode_vec_with_sentinel(&data, 0);
        encoded.push(0);
        let mut channel = MockChannel(encoded);
        let frame = channel.get_frame().expect("Test failed");
        let Ok(MsgFromHost::Basic(str)) = frame.deserialize() else {
            panic!("Test failed");
        };
        assert_eq!(str, "Test");
    }

    /// A payload many times larger than the initial buffer still decodes
    /// (the buffer has to grow more than once), and a `write_frame` /
    /// `get_frame` round trip consumes the whole frame, including its delimiter.
    #[test]
    fn test_frame_round_trip_with_repeated_resizing() {
        let text = "x".repeat(100);
        let mut channel = MockChannel(Vec::new());
        channel.write_frame(&MsgFromHost::Basic(text.clone()));
        let frame = channel.get_frame().expect("Test failed");
        let Ok(MsgFromHost::Basic(decoded)) = frame.deserialize() else {
            panic!("Test failed");
        };
        assert_eq!(decoded, text);
        assert!(channel.0.is_empty());
    }

    /// `HexBytes` round-trips through CBOR and rejects bad hex or wrong lengths.
    #[test]
    fn test_hex_bytes_serde() {
        let original = HexBytes([0xabu8; 32]);
        let encoded = serde_cbor::to_vec(&original).expect("Test failed");
        let decoded: HexBytes<32> = serde_cbor::from_slice(&encoded).expect("Test failed");
        assert_eq!(decoded.0, original.0);
        // A 32-byte hex string has the wrong length for a 64-byte value.
        assert!(serde_cbor::from_slice::<HexBytes<64>>(&encoded).is_err());
        // Non-hex input is rejected.
        let bogus = serde_cbor::to_vec(&"zz").expect("Test failed");
        assert!(serde_cbor::from_slice::<HexBytes<32>>(&bogus).is_err());
    }

    /// Only enclave-bound client messages convert into `MsgFromHost`, and only
    /// client-bound enclave messages convert into `ServerMsg`.
    #[test]
    fn test_message_routing() {
        assert!(MsgFromHost::try_from(&ClientMsg::RequestUUID).is_err());
        assert!(MsgFromHost::try_from(&ClientMsg::RequestIndices {
            key_hash: "k".to_string()
        })
        .is_err());

        let pk = HexBytes([1u8; 32]);
        let Ok(MsgFromHost::RegisterKey {
            nonce,
            pk: forwarded,
        }) = MsgFromHost::try_from(&ClientMsg::RegisterKey { nonce: 7, pk })
        else {
            panic!("Test failed");
        };
        assert_eq!(nonce, 7);
        assert_eq!(forwarded.0, pk.0);

        assert!(ServerMsg::try_from(MsgToHost::Basic("hi".to_string())).is_err());
        assert!(matches!(
            ServerMsg::try_from(MsgToHost::KeyRegSuccess),
            Ok(ServerMsg::KeyRegSuccess)
        ));
    }
}