kassandra_shared/communication/
mod.rs1#[cfg(feature = "std")]
7pub mod tcp;
8
9use alloc::string::String;
10use alloc::vec;
11use alloc::vec::Vec;
12
13use fmd::fmd2_compact::FlagCiphertexts;
14use serde::de::{DeserializeOwned, Error};
15use serde::{Deserialize, Deserializer, Serialize, Serializer};
16use thiserror::Error;
17
18use crate::db::{EncryptedResponse, Index};
19use crate::ratls::TlsCiphertext;
20
/// Fixed-size byte array wrapper; this module gives it serde impls that use a
/// hex string instead of a byte sequence.
#[derive(Debug, Copy, Clone)]
pub struct HexBytes<const N: usize>(pub [u8; N]);

impl<const N: usize> From<[u8; N]> for HexBytes<N> {
    /// Wraps the array as-is; no validation is performed.
    fn from(bytes: [u8; N]) -> HexBytes<N> {
        HexBytes(bytes)
    }
}
29
30macro_rules! impl_serde {
31 ($n:literal) => {
32 impl Serialize for HexBytes<$n> {
33 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34 where
35 S: Serializer,
36 {
37 serializer.serialize_str(&hex::encode(self.0))
38 }
39 }
40
41 impl<'de> Deserialize<'de> for HexBytes<$n> {
42 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
43 where
44 D: Deserializer<'de>,
45 {
46 let s = String::deserialize(deserializer)?;
47 Ok(Self(
48 hex::decode(s.as_bytes())
49 .map_err(|_| Error::custom("Invalid hex"))?
50 .try_into()
51 .map_err(|_| Error::custom("Bytes were of wrong size"))?,
52 ))
53 }
54 }
55 };
56}
57
58impl_serde!(32);
59impl_serde!(64);
60
/// Messages addressed to the host (per the type name).
///
/// `ServerMsg::try_from` forwards only `RATLS`, `ErrorForClient` and
/// `KeyRegSuccess` to a client; every other variant is rejected there with
/// "Message not intended for client".
///
/// NOTE(review): (de)serialized via serde's derive, so renaming or reordering
/// variants/fields may break wire compatibility with peers — confirm before changing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MsgToHost {
    /// Free-form text message.
    Basic(String),
    /// Error description that is NOT forwarded to clients (compare `ErrorForClient`).
    Error(String),
    /// Error description forwarded to the client as `ServerMsg::Error`.
    ErrorForClient(String),
    /// `RATLS` report bytes; forwarded to the client unchanged as `ServerMsg::RATLS`.
    RATLS { report: Vec<u8> },
    /// Report bytes that are not forwarded to clients.
    Report(Vec<u8>),
    /// Key registration succeeded; forwarded to the client.
    KeyRegSuccess,
    /// Requested blocks as `u64`s — presumably block heights, TODO confirm.
    BlockRequests(Vec<u64>),
    /// Results as `EncryptedResponse` values (FMD results per the variant name).
    FmdResults(Vec<EncryptedResponse>),
}
73
/// Messages coming from the host (per the type name).
///
/// `RegisterKey`, `RequestReport` and `RATLSAck` are built from the matching
/// `ClientMsg` variants by `MsgFromHost::try_from(&ClientMsg)`.
///
/// NOTE(review): (de)serialized via serde's derive, so renaming or reordering
/// variants/fields may break wire compatibility with peers — confirm before changing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MsgFromHost {
    /// Free-form text message.
    Basic(String),
    /// Key registration relayed from a client; `pk` is 32 bytes, hex-encoded on the wire.
    RegisterKey {
        nonce: u64,
        pk: HexBytes<32>,
    },
    /// Report request relayed from a client; `user_data` is 64 bytes, hex-encoded on the wire.
    RequestReport {
        user_data: HexBytes<64>,
    },
    /// Acknowledgement relayed from a client.
    RATLSAck(AckType),
    /// Unit request; presumably asks which blocks are required — TODO confirm.
    RequiredBlocks,
    /// Per-index optional flag ciphertexts, plus a `synced_to` value
    /// (presumably the block height the host has synced to — confirm).
    RequestedFlags {
        synced_to: u64,
        flags: Vec<(Index, Option<FlagCiphertexts>)>,
    },
}
92
/// Messages sent by a client (per the type name).
///
/// `RegisterKey`, `RequestReport` and `RATLSAck` can be relayed as
/// `MsgFromHost` via `MsgFromHost::try_from(&ClientMsg)`; `RequestUUID` and
/// `RequestIndices` are rejected there ("Message not intended for enclave").
///
/// NOTE(review): (de)serialized via serde's derive, so renaming or reordering
/// variants/fields may break wire compatibility with peers — confirm before changing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClientMsg {
    /// Register a 32-byte key (hex-encoded on the wire) together with a nonce.
    RegisterKey {
        nonce: u64,
        pk: HexBytes<32>,
    },
    /// Request a report carrying 64 bytes of user data (hex-encoded on the wire).
    RequestReport {
        user_data: HexBytes<64>,
    },
    /// Acknowledgement of a `RATLS` exchange.
    RATLSAck(AckType),
    /// Request a UUID; not relayed to the enclave.
    RequestUUID,
    /// Request indices for the given key hash; not relayed to the enclave.
    RequestIndices {
        key_hash: String,
    },
}
112
/// Acknowledgement payload carried by the `RATLSAck` message variants.
///
/// NOTE(review): (de)serialized via serde's derive; renaming variants may break
/// wire compatibility — confirm before changing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AckType {
    /// Success, carrying a `TlsCiphertext` (from `crate::ratls`).
    Success(TlsCiphertext),
    /// Failure; carries no data.
    Fail,
}
118
/// Messages sent to a client (per the type name).
///
/// `RATLS`, `Error` and `KeyRegSuccess` can be produced from a `MsgToHost` via
/// `ServerMsg::try_from`; `UUID` and `IndicesResponse` are not produced there.
///
/// NOTE(review): (de)serialized via serde's derive, so renaming or reordering
/// variants/fields may break wire compatibility — confirm before changing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ServerMsg {
    /// `RATLS` report bytes.
    RATLS {
        report: Vec<u8>,
    },
    /// Error description intended for the client.
    Error(String),
    /// Key registration succeeded.
    KeyRegSuccess,
    /// A UUID, as a string.
    UUID(String),
    /// Response to an indices request.
    IndicesResponse(EncryptedResponse),
}
131
132impl<'a> TryFrom<&'a ClientMsg> for MsgFromHost {
133 type Error = &'static str;
134
135 fn try_from(msg: &'a ClientMsg) -> Result<Self, Self::Error> {
136 match msg {
137 ClientMsg::RegisterKey { nonce, pk } => Ok(MsgFromHost::RegisterKey {
138 nonce: *nonce,
139 pk: *pk,
140 }),
141 ClientMsg::RequestReport { user_data } => Ok(MsgFromHost::RequestReport {
142 user_data: *user_data,
143 }),
144 ClientMsg::RATLSAck(v) => Ok(MsgFromHost::RATLSAck(v.clone())),
145 _ => Err("Message not intended for enclave"),
146 }
147 }
148}
149
150impl TryFrom<MsgToHost> for ServerMsg {
151 type Error = &'static str;
152
153 fn try_from(msg: MsgToHost) -> Result<Self, &'static str> {
154 match msg {
155 MsgToHost::RATLS { report } => Ok(ServerMsg::RATLS { report }),
156 MsgToHost::ErrorForClient(err) => Ok(ServerMsg::Error(err)),
157 MsgToHost::KeyRegSuccess => Ok(ServerMsg::KeyRegSuccess),
158 _ => Err("Message not intended for client"),
159 }
160 }
161}
162
163#[derive(Error, Debug)]
164pub enum MsgError {
165 #[error("COBS failed to decode message from COM 2 with: {0}")]
166 Decode(cobs::DecodeError),
167 #[error("Failed to deserialize CBOR with: {0}")]
168 Deserialize(serde_cbor::Error),
169 #[error("Input bytes were not valid utf-8: {0:?}")]
170 Utf8(Vec<u8>),
171}
172
173pub struct Frame {
174 pub bytes: Vec<u8>,
175}
176
177impl Frame {
178 pub fn deserialize<T: DeserializeOwned>(self) -> Result<T, MsgError> {
179 serde_cbor::from_slice(&self.bytes).map_err(MsgError::Deserialize)
180 }
181}
182
/// Minimal byte-level transport on which [`FramedBytes`] builds.
pub trait ReadWriteByte {
    /// Initial size, in bytes, of the buffer `FramedBytes::get_frame` decodes
    /// incoming frames into. Implementors may override it.
    const FRAME_BUF_SIZE: usize = 1024;

    /// Writes all of `buf` to the channel.
    fn write_bytes(&mut self, buf: &[u8]);

    /// Returns the next byte from the channel.
    ///
    /// There is no way to signal end-of-stream or an error, so implementations
    /// presumably block until a byte is available — confirm per implementor.
    fn read_byte(&mut self) -> u8;
}
190
191pub trait FramedBytes: ReadWriteByte {
195 fn get_frame(&mut self) -> Result<Frame, MsgError> {
203 let mut buf_size = Self::FRAME_BUF_SIZE;
205 let mut read_bytes = Vec::<u8>::with_capacity(buf_size);
208 loop {
211 let mut frame_buf = vec![0u8; buf_size];
213 let mut decoder = cobs::CobsDecoder::new(&mut frame_buf);
214 decoder
215 .push(&read_bytes)
216 .expect("Previously read bytes should not produce a frame error.");
217
218 loop {
219 let b = self.read_byte();
220 read_bytes.push(b);
221 match decoder.feed(b) {
222 Ok(None) => continue,
223 Ok(Some(len)) => {
224 frame_buf.truncate(len);
225 return Ok(Frame { bytes: frame_buf });
226 }
227 Err(cobs::DecodeError::TargetBufTooSmall) => {
228 buf_size += Self::FRAME_BUF_SIZE;
230 break;
231 }
232 Err(e) => return Err(MsgError::Decode(e)),
233 }
234 }
235 }
236 }
237
238 fn write_frame<T: Serialize>(&mut self, msg: &T) {
241 let data = serde_cbor::to_vec(&msg).unwrap();
242 let mut encoded = cobs::encode_vec_with_sentinel(&data, 0);
243 encoded.push(0);
244 self.write_bytes(&encoded);
245 }
246}
247
248impl<T: ReadWriteByte> FramedBytes for T {}
249
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;

    /// In-memory FIFO channel: `write_bytes` appends and `read_byte` pops from
    /// the front (panicking once drained).
    struct MockChannel(Vec<u8>);

    impl ReadWriteByte for MockChannel {
        // Deliberately tiny so that ordinary messages force the frame buffer to grow.
        const FRAME_BUF_SIZE: usize = 10;
        fn read_byte(&mut self) -> u8 {
            self.0.remove(0)
        }

        fn write_bytes(&mut self, buf: &[u8]) {
            self.0.extend_from_slice(buf);
        }
    }

    /// A message slightly larger than `FRAME_BUF_SIZE` still decodes after the
    /// frame buffer is enlarged.
    #[test]
    fn test_dynamic_frame_resizing() {
        let msg = MsgFromHost::Basic("Test".to_string());
        let data = serde_cbor::to_vec(&msg).expect("Test failed");
        let mut encoded = cobs::encode_vec_with_sentinel(&data, 0);
        encoded.push(0);
        let mut channel = MockChannel(encoded);
        let frame = channel.get_frame().expect("Test failed");
        let Ok(MsgFromHost::Basic(str)) = frame.deserialize() else {
            panic!("Test failed");
        };
        assert_eq!(str, "Test");
    }

    /// A frame that needs several buffer enlargements round-trips through
    /// `write_frame`/`get_frame`, and consecutive frames are not mixed up.
    #[test]
    fn test_round_trip_consecutive_frames() {
        let payload = "x".repeat(100);
        let mut channel = MockChannel(Vec::new());
        channel.write_frame(&MsgFromHost::Basic(payload.clone()));
        channel.write_frame(&MsgFromHost::RequiredBlocks);

        let first = channel.get_frame().expect("Test failed");
        let Ok(MsgFromHost::Basic(str)) = first.deserialize() else {
            panic!("Test failed");
        };
        assert_eq!(str, payload);

        let second = channel.get_frame().expect("Test failed");
        assert!(matches!(
            second.deserialize::<MsgFromHost>(),
            Ok(MsgFromHost::RequiredBlocks)
        ));
        assert!(channel.0.is_empty());
    }
}