1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
27
28mod crypt_writer;
29use std::{
30    error,
31    fmt::{self, Write},
32    io,
33    io::Error as IoError,
34    num::ParseIntError,
35    pin::Pin,
36    str::FromStr,
37    task::{Context, Poll},
38};
39
40use crypt_writer::CryptWriter;
41use futures::prelude::*;
42use pin_project::pin_project;
43use rand::RngCore;
44use salsa20::{
45    cipher::{KeyIvInit, StreamCipher},
46    Salsa20, XSalsa20,
47};
48use sha3::{digest::ExtendableOutput, Shake128};
49
/// Length in bytes of a [`PreSharedKey`]; the key is converted directly into the
/// Salsa20 / XSalsa20 key.
const KEY_SIZE: usize = 32;
/// Length in bytes of the random nonce each side sends during the handshake
/// (used as the XSalsa20 nonce).
const NONCE_SIZE: usize = 24;
/// Capacity passed to `CryptWriter::with_capacity` for the outgoing write buffer.
const WRITE_BUFFER_SIZE: usize = 1024;
/// Length in bytes of a [`Fingerprint`].
const FINGERPRINT_SIZE: usize = 16;
54
55#[derive(Copy, Clone, PartialEq, Eq)]
57pub struct PreSharedKey([u8; KEY_SIZE]);
58
59impl PreSharedKey {
60    pub fn new(data: [u8; KEY_SIZE]) -> Self {
62        Self(data)
63    }
64
65    pub fn fingerprint(&self) -> Fingerprint {
71        use std::io::{Read, Write};
72        let mut enc = [0u8; 64];
73        let nonce: [u8; 8] = *b"finprint";
74        let mut out = [0u8; 16];
75        let mut cipher = Salsa20::new(&self.0.into(), &nonce.into());
76        cipher.apply_keystream(&mut enc);
77        let mut hasher = Shake128::default();
78        hasher.write_all(&enc).expect("shake128 failed");
79        hasher
80            .finalize_xof()
81            .read_exact(&mut out)
82            .expect("shake128 failed");
83        Fingerprint(out)
84    }
85}
86
87fn parse_hex_key(s: &str) -> Result<[u8; KEY_SIZE], KeyParseError> {
88    if s.len() == KEY_SIZE * 2 {
89        let mut r = [0u8; KEY_SIZE];
90        for i in 0..KEY_SIZE {
91            r[i] = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16)
92                .map_err(KeyParseError::InvalidKeyChar)?;
93        }
94        Ok(r)
95    } else {
96        Err(KeyParseError::InvalidKeyLength)
97    }
98}
99
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
fn to_hex(bytes: &[u8]) -> String {
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| {
            write!(acc, "{byte:02x}").expect("Can't fail on writing to string");
            acc
        })
}
109
110impl FromStr for PreSharedKey {
114    type Err = KeyParseError;
115
116    fn from_str(s: &str) -> Result<Self, Self::Err> {
117        if let [keytype, encoding, key] = *s.lines().take(3).collect::<Vec<_>>().as_slice() {
118            if keytype != "/key/swarm/psk/1.0.0/" {
119                return Err(KeyParseError::InvalidKeyType);
120            }
121            if encoding != "/base16/" {
122                return Err(KeyParseError::InvalidKeyEncoding);
123            }
124            parse_hex_key(key.trim_end()).map(PreSharedKey)
125        } else {
126            Err(KeyParseError::InvalidKeyFile)
127        }
128    }
129}
130
131impl fmt::Debug for PreSharedKey {
132    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133        f.debug_tuple("PreSharedKey")
134            .field(&to_hex(&self.0))
135            .finish()
136    }
137}
138
139impl fmt::Display for PreSharedKey {
141    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142        writeln!(f, "/key/swarm/psk/1.0.0/")?;
143        writeln!(f, "/base16/")?;
144        writeln!(f, "{}", to_hex(&self.0))
145    }
146}
147
148#[derive(Copy, Clone, PartialEq, Eq)]
150pub struct Fingerprint([u8; FINGERPRINT_SIZE]);
151
152impl fmt::Display for Fingerprint {
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        write!(f, "{}", to_hex(&self.0))
156    }
157}
158
/// Error returned when parsing a [`PreSharedKey`] from its textual key-file form fails.
///
/// Checks run in order (file shape, key type, encoding, length, characters); the first
/// failing check determines the variant.
#[derive(Clone, Debug, PartialEq, Eq)]
// Every variant describes a malformed key, so the shared `InvalidKey` prefix is intentional.
#[allow(clippy::enum_variant_names)] pub enum KeyParseError {
    /// The input has fewer than three lines.
    InvalidKeyFile,
    /// The first line is not `/key/swarm/psk/1.0.0/`.
    InvalidKeyType,
    /// The second line is not `/base16/`.
    InvalidKeyEncoding,
    /// The key line (after trimming trailing whitespace) is not `KEY_SIZE * 2` bytes long.
    InvalidKeyLength,
    /// A two-character chunk of the key line could not be parsed as a base-16 byte.
    InvalidKeyChar(ParseIntError),
}
174
175impl fmt::Display for KeyParseError {
176    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177        write!(f, "{self:?}")
178    }
179}
180
181impl error::Error for KeyParseError {
182    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
183        match *self {
184            KeyParseError::InvalidKeyChar(ref err) => Some(err),
185            _ => None,
186        }
187    }
188}
189
190#[derive(Debug, Copy, Clone)]
192pub struct PnetConfig {
193    key: PreSharedKey,
195}
196impl PnetConfig {
197    pub fn new(key: PreSharedKey) -> Self {
198        Self { key }
199    }
200
201    pub async fn handshake<TSocket>(
206        self,
207        mut socket: TSocket,
208    ) -> Result<PnetOutput<TSocket>, PnetError>
209    where
210        TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
211    {
212        tracing::trace!("exchanging nonces");
213        let mut local_nonce = [0u8; NONCE_SIZE];
214        let mut remote_nonce = [0u8; NONCE_SIZE];
215        rand::thread_rng().fill_bytes(&mut local_nonce);
216        socket
217            .write_all(&local_nonce)
218            .await
219            .map_err(PnetError::HandshakeError)?;
220        socket.flush().await?;
221        socket
222            .read_exact(&mut remote_nonce)
223            .await
224            .map_err(PnetError::HandshakeError)?;
225        tracing::trace!("setting up ciphers");
226        let write_cipher = XSalsa20::new(&self.key.0.into(), &local_nonce.into());
227        let read_cipher = XSalsa20::new(&self.key.0.into(), &remote_nonce.into());
228        Ok(PnetOutput::new(socket, write_cipher, read_cipher))
229    }
230}
231
232#[pin_project]
235pub struct PnetOutput<S> {
236    #[pin]
237    inner: CryptWriter<S>,
238    read_cipher: XSalsa20,
239}
240
241impl<S: AsyncRead + AsyncWrite> PnetOutput<S> {
242    fn new(inner: S, write_cipher: XSalsa20, read_cipher: XSalsa20) -> Self {
243        Self {
244            inner: CryptWriter::with_capacity(WRITE_BUFFER_SIZE, inner, write_cipher),
245            read_cipher,
246        }
247    }
248}
249
250impl<S: AsyncRead + AsyncWrite> AsyncRead for PnetOutput<S> {
251    fn poll_read(
252        self: Pin<&mut Self>,
253        cx: &mut Context<'_>,
254        buf: &mut [u8],
255    ) -> Poll<Result<usize, io::Error>> {
256        let this = self.project();
257        let result = this.inner.get_pin_mut().poll_read(cx, buf);
258        if let Poll::Ready(Ok(size)) = &result {
259            tracing::trace!(bytes=%size, "read bytes");
260            this.read_cipher.apply_keystream(&mut buf[..*size]);
261            tracing::trace!(bytes=%size, "decrypted bytes");
262        }
263        result
264    }
265}
266
267impl<S: AsyncRead + AsyncWrite> AsyncWrite for PnetOutput<S> {
268    fn poll_write(
269        self: Pin<&mut Self>,
270        cx: &mut Context<'_>,
271        buf: &[u8],
272    ) -> Poll<Result<usize, io::Error>> {
273        self.project().inner.poll_write(cx, buf)
274    }
275
276    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
277        self.project().inner.poll_flush(cx)
278    }
279
280    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
281        self.project().inner.poll_close(cx)
282    }
283}
284
/// Error produced by the private-network handshake or by I/O on its stream.
#[derive(Debug)]
pub enum PnetError {
    /// An I/O error that occurred while exchanging nonces in [`PnetConfig::handshake`].
    HandshakeError(IoError),
    /// An I/O error converted through `From<IoError>` (for example by the `?` operator).
    IoError(IoError),
}
293
294impl From<IoError> for PnetError {
295    #[inline]
296    fn from(err: IoError) -> PnetError {
297        PnetError::IoError(err)
298    }
299}
300
301impl error::Error for PnetError {
302    fn cause(&self) -> Option<&dyn error::Error> {
303        match *self {
304            PnetError::HandshakeError(ref err) => Some(err),
305            PnetError::IoError(ref err) => Some(err),
306        }
307    }
308}
309
310impl fmt::Display for PnetError {
311    #[inline]
312    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
313        match self {
314            PnetError::HandshakeError(e) => write!(f, "Handshake error: {e}"),
315            PnetError::IoError(e) => write!(f, "I/O error: {e}"),
316        }
317    }
318}
319
// Unit and property tests for key parsing/formatting and fingerprinting.
#[cfg(test)]
mod tests {
    use quickcheck::*;

    use super::*;

    // Lets quickcheck generate random keys for the property test below.
    impl Arbitrary for PreSharedKey {
        fn arbitrary(g: &mut Gen) -> PreSharedKey {
            let key = core::array::from_fn(|_| u8::arbitrary(g));
            PreSharedKey(key)
        }
    }

    // Round trip: the `Display` output of any key must parse back to the same key.
    #[test]
    fn psk_tostring_parse() {
        fn prop(key: PreSharedKey) -> bool {
            let text = key.to_string();
            text.parse::<PreSharedKey>()
                .map(|res| res == key)
                .unwrap_or(false)
        }
        QuickCheck::new()
            .tests(10)
            .quickcheck(prop as fn(PreSharedKey) -> _);
    }

    // Each malformed input is rejected with the error for the first check it fails:
    // too few lines, wrong key type, wrong encoding, wrong key length.
    #[test]
    fn psk_parse_failure() {
        use KeyParseError::*;
        assert_eq!("".parse::<PreSharedKey>().unwrap_err(), InvalidKeyFile);
        assert_eq!(
            "a\nb\nc".parse::<PreSharedKey>().unwrap_err(),
            InvalidKeyType
        );
        assert_eq!(
            "/key/swarm/psk/1.0.0/\nx\ny"
                .parse::<PreSharedKey>()
                .unwrap_err(),
            InvalidKeyEncoding
        );
        assert_eq!(
            "/key/swarm/psk/1.0.0/\n/base16/\ny"
                .parse::<PreSharedKey>()
                .unwrap_err(),
            InvalidKeyLength
        );
    }

    // Known-answer test pinning the fingerprint derivation for a fixed key; any change to
    // the derivation (nonce, keystream length, hash) will break this.
    #[test]
    fn fingerprint() {
        let key = "/key/swarm/psk/1.0.0/\n/base16/\n6189c5cf0b87fb800c1a9feeda73c6ab5e998db48fb9e6a978575c770ceef683".parse::<PreSharedKey>().unwrap();
        let expected = "45fc986bbc9388a11d939df26f730f0c";
        let actual = key.fingerprint().to_string();
        assert_eq!(expected, actual);
    }
}