1use crate::errors::PortalError::*;
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use std::error::Error;
4use std::io::{Read, Write};
5
6use hkdf::Hkdf;
8use sha2::Sha256;
9use spake2::{Ed25519Group, Spake2};
10
11mod exchange;
13pub use exchange::*;
14
15mod encrypted;
17pub use encrypted::*;
18
19mod transferinfo;
21pub use transferinfo::*;
22
23#[cfg(test)]
24mod tests;
25
/// Stateless namespace for the Portal wire-protocol operations: connection
/// setup and key exchange, key confirmation, and encrypted framing.
///
/// All functionality lives in associated functions; the type carries no data.
pub struct Protocol;
30
/// Which side of a transfer this endpoint is on.
///
/// Sent to the other end inside [`ConnectMessage`]. Also used by
/// `Protocol::confirm_peer` to choose which confirmation value to send and
/// which one to expect back.
// Messages are bincode-encoded (see `PortalMessage::send`), and bincode
// identifies enum variants by position. Reordering variants changes the wire
// format.
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Copy, Clone)]
pub enum Direction {
    /// The sending endpoint.
    Sender,
    /// The receiving endpoint.
    Receiver,
}
38
/// First message sent by `Protocol::connect`. It announces the transfer `id`
/// and this endpoint's [`Direction`].
// Field order is part of the bincode encoding; do not reorder.
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
pub struct ConnectMessage {
    /// Transfer identifier. The same id is also mixed into the HKDF info
    /// strings in `Protocol::confirm_peer`.
    pub id: String,
    /// Which side of the transfer the sender of this message is on.
    pub direction: Direction,
}
46
/// Top-level message exchanged on the wire. It is serialized with bincode by
/// [`PortalMessage::send`] and read back by [`PortalMessage::recv`].
// Variant order determines the bincode discriminant. Append new variants at
// the end rather than reordering existing ones.
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
pub enum PortalMessage {
    /// Connection announcement; the first message sent by `Protocol::connect`.
    Connect(ConnectMessage),

    /// SPAKE2 key-exchange payload (consumed by `Protocol::derive_key`).
    KeyExchange(PortalKeyExchange),

    /// Key-confirmation value exchanged by `Protocol::confirm_peer`.
    Confirm(PortalConfirmation),

    /// Header for an encrypted payload. The payload bytes follow this message
    /// directly on the stream.
    EncryptedDataHeader(EncryptedMessage),
}
64
65impl PortalMessage {
66 pub fn send<W: Write>(&mut self, writer: &mut W) -> Result<usize, Box<dyn Error>> {
68 let data = bincode::serialize(&self).or(Err(SerializeError))?;
69 writer.write_all(&data).or(Err(IOError))?;
70 Ok(data.len())
71 }
72
73 pub fn recv<R: Read>(reader: &mut R) -> Result<Self, Box<dyn Error>> {
75 Ok(bincode::deserialize_from::<&mut R, PortalMessage>(reader)?)
76 }
77
78 pub fn parse(data: &[u8]) -> Result<Self, Box<dyn Error>> {
80 Ok(bincode::deserialize(data)?)
81 }
82}
83
84impl Protocol {
85 pub fn connect<P: Read + Write>(
87 peer: &mut P,
88 id: &str,
89 direction: Direction,
90 msg: PortalKeyExchange,
91 ) -> Result<PortalKeyExchange, Box<dyn Error>> {
92 let c = ConnectMessage {
94 id: id.to_owned(),
95 direction,
96 };
97
98 PortalMessage::Connect(c).send(peer)?;
100
101 let _info = PortalMessage::recv(peer)?;
105
106 PortalMessage::KeyExchange(msg).send(peer)?;
108
109 match PortalMessage::recv(peer).or(Err(IOError))? {
111 PortalMessage::KeyExchange(data) => Ok(data),
112 _ => Err(Box::new(BadMsg)),
113 }
114 }
115
116 pub fn derive_key(
120 state: Spake2<Ed25519Group>,
121 peer_data: &PortalKeyExchange,
122 ) -> Result<Vec<u8>, Box<dyn Error>> {
123 Ok(state.finish(peer_data.into()).or(Err(BadMsg))?)
124 }
125
126 pub fn confirm_peer<P: Read + Write>(
129 peer: &mut P,
130 id: &str,
131 direction: Direction,
132 key: &[u8],
133 ) -> Result<(), Box<dyn Error>> {
134 let sender_info = format!("{}-{}", id, "senderinfo");
136 let receiver_info = format!("{}-{}", id, "receiverinfo");
137
138 let h = Hkdf::<Sha256>::new(None, key);
140 let mut sender_confirm = [0u8; 42];
141 let mut receiver_confirm = [0u8; 42];
142 h.expand(sender_info.as_bytes(), &mut sender_confirm)
143 .or(Err(BadMsg))?;
144 h.expand(receiver_info.as_bytes(), &mut receiver_confirm)
145 .or(Err(BadMsg))?;
146
147 let (to_send, expected) = match direction {
149 Direction::Sender => (sender_confirm, receiver_confirm),
150 Direction::Receiver => (receiver_confirm, sender_confirm),
151 };
152
153 let expected = PortalConfirmation(expected);
155
156 PortalMessage::Confirm(PortalConfirmation(to_send)).send(peer)?;
158
159 let peer_msg = match PortalMessage::recv(peer)? {
161 PortalMessage::Confirm(inner) => inner,
162 _ => return Err(BadMsg.into()),
163 };
164
165 if peer_msg != expected {
167 return Err(PeerKeyMismatch.into());
168 }
169
170 Ok(())
172 }
173
174 pub fn read_encrypted_from<R, D>(reader: &mut R, key: &[u8]) -> Result<D, Box<dyn Error>>
176 where
177 R: Read,
178 D: DeserializeOwned,
179 {
180 let mut storage = [0u8; 2048];
182
183 Protocol::read_encrypted_zero_copy(reader, key, &mut storage)?;
185
186 Ok(bincode::deserialize(&storage).or(Err(BadMsg))?)
188 }
189
190 pub fn read_encrypted_zero_copy<R>(
195 reader: &mut R,
196 key: &[u8],
197 storage: &mut [u8],
198 ) -> Result<usize, Box<dyn Error>>
199 where
200 R: Read,
201 {
202 let mut msg = match PortalMessage::recv(reader).or(Err(IOError))? {
204 PortalMessage::EncryptedDataHeader(inner) => inner,
205 _ => return Err(BadMsg.into()),
206 };
207
208 if storage.len() < msg.len {
210 return Err(BufferTooSmall.into());
211 }
212
213 let mut pos = 0;
215 while pos < msg.len {
216 match reader.read(&mut storage[pos..msg.len]) {
217 Ok(0) => break,
218 Ok(len) => {
219 pos += len;
220 }
221 Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
222 Err(e) => return Err(e.into()),
223 };
224 }
225
226 msg.decrypt(key, &mut storage[..pos])
228 }
229
230 pub fn encrypt_and_write_object<W, S>(
232 writer: &mut W,
233 key: &[u8],
234 nseq: &mut NonceSequence,
235 msg: &S,
236 ) -> Result<usize, Box<dyn Error>>
237 where
238 W: Write,
239 S: Serialize,
240 {
241 let mut data = bincode::serialize(msg)?;
243
244 let encmsg = EncryptedMessage::encrypt(key, nseq, &mut data)?;
246
247 PortalMessage::EncryptedDataHeader(encmsg).send(writer)?;
249
250 writer.write_all(&data).or(Err(IOError))?;
252
253 Ok(data.len())
254 }
255
256 pub fn encrypt_and_write_header_only<W>(
258 writer: &mut W,
259 key: &[u8],
260 nseq: &mut NonceSequence,
261 data: &mut [u8],
262 ) -> Result<usize, Box<dyn Error>>
263 where
264 W: Write,
265 {
266 let header = EncryptedMessage::encrypt(key, nseq, data)?;
268
269 PortalMessage::EncryptedDataHeader(header).send(writer)
271 }
272}