sos_net/pairing/websocket.rs

1//! Protocol for pairing devices.
2use super::{DeviceEnrollment, Error, Result, ServerPairUrl};
3use crate::NetworkAccount;
4use futures::{
5    stream::{SplitSink, SplitStream},
6    SinkExt, StreamExt,
7};
8use prost::bytes::Bytes;
9use snow::{Builder, HandshakeState, Keypair, TransportState};
10use sos_account::Account;
11use sos_backend::BackendTarget;
12use sos_core::{
13    device::{DeviceMetaData, DevicePublicKey, TrustedDevice},
14    events::DeviceEvent,
15    AccountId, Origin,
16};
17use sos_protocol::{
18    network_client::WebSocketRequest,
19    pairing_message,
20    tokio_tungstenite::{
21        connect_async,
22        tungstenite::{
23            protocol::{frame::coding::CloseCode, CloseFrame, Message},
24            Utf8Bytes,
25        },
26        MaybeTlsStream, WebSocketStream,
27    },
28    AccountSync, PairingConfirm, PairingMessage, PairingReady,
29    PairingRequest, ProtoMessage, RelayHeader, RelayPacket, RelayPayload,
30    SyncOptions,
31};
32use std::collections::HashSet;
33use tokio::{net::TcpStream, sync::mpsc};
34use url::Url;
35
/// Noise handshake pattern: `XX` with the pre-shared key mixed in at
/// position 3 (`psk3`), Curve25519 key agreement, ChaCha20-Poly1305 and
/// BLAKE2s.
const PATTERN: &str = "Noise_XXpsk3_25519_ChaChaPoly_BLAKE2s";

/// Path of the relay websocket endpoint on the server.
const RELAY_PATH: &str = "api/v1/relay";

/// Length in bytes of the authentication tag that the noise cipher
/// appends to every transport ciphertext.
const TAGLEN: usize = 16;
41
42/// State of the encrypted tunnel.
43enum Tunnel {
44    /// Handshake state.
45    Handshake(HandshakeState),
46    /// Transport state.
47    Transport(TransportState),
48}
49
50type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
51type WsStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
52
/// Progress of the pairing protocol on one side of the connection.
#[derive(Debug)]
enum PairProtocolState {
    /// Nothing has been exchanged yet.
    Pending,
    /// The first stage of the noise handshake is done.
    Handshake,
    /// The handshake, including the pre-shared key stage, is done.
    PskHandshake,
    /// Pairing has finished; the event loop can stop.
    Done,
}
65
66#[derive(Debug)]
67enum IncomingAction {
68    Reply(PairProtocolState, RelayPacket),
69    HandleMessage(PairingMessage),
70}
71
72/// Listen for incoming messages on the websocket stream.
73async fn listen(
74    mut rx: WsStream,
75    tx: mpsc::Sender<RelayPacket>,
76    close_tx: mpsc::Sender<()>,
77) {
78    while let Some(message) = rx.next().await {
79        match message {
80            Ok(message) => {
81                if let Message::Binary(msg) = message {
82                    let buf: Bytes = msg.into();
83                    match RelayPacket::decode_proto(buf).await {
84                        Ok(result) => {
85                            if let Err(e) = tx.send(result).await {
86                                tracing::error!(error = ?e);
87                            }
88                        }
89                        Err(e) => {
90                            tracing::error!(error = ?e);
91                            let _ = close_tx.send(()).await;
92                            break;
93                        }
94                    }
95                }
96            }
97            Err(e) => {
98                tracing::error!(error = ?e);
99                let _ = close_tx.send(()).await;
100                break;
101            }
102        }
103    }
104    tracing::debug!("pairing::websocket::connection_closed");
105}
106
107/// Offer is the device that is authenticated and can
108/// authorize the new device.
109pub struct OfferPairing<'a> {
110    /// Noise session keypair.
111    keypair: Keypair,
112    /// Network account.
113    account: &'a mut NetworkAccount,
114    /// Pairing URL to share with the other device.
115    share_url: ServerPairUrl,
116    /// Noise protocol state.
117    tunnel: Option<Tunnel>,
118    /// Sink side of the websocket.
119    tx: WsSink,
120    /// Current state of the protocol.
121    state: PairProtocolState,
122    /// Determine if the URL sharing is inverted.
123    is_inverted: bool,
124}
125
126impl<'a> OfferPairing<'a> {
127    /// Create a new pairing offer.
128    pub async fn new(
129        account: &'a mut NetworkAccount,
130        url: Url,
131    ) -> Result<(OfferPairing<'a>, WsStream)> {
132        let builder = Builder::new(PATTERN.parse()?);
133        let keypair = builder.generate_keypair()?;
134        let share_url = ServerPairUrl::new(
135            *account.account_id(),
136            url.clone(),
137            keypair.public.clone(),
138        );
139        Self::new_connection(account, share_url, keypair, false).await
140    }
141
142    /// Create a new pairing offer from a share URL generated
143    /// by the accepting device.
144    pub async fn new_inverted(
145        account: &'a mut NetworkAccount,
146        share_url: ServerPairUrl,
147    ) -> Result<(OfferPairing<'a>, WsStream)> {
148        let builder = Builder::new(PATTERN.parse()?);
149        let keypair = builder.generate_keypair()?;
150        Self::new_connection(account, share_url, keypair, true).await
151    }
152
153    async fn new_connection(
154        account: &'a mut NetworkAccount,
155        share_url: ServerPairUrl,
156        keypair: Keypair,
157        is_inverted: bool,
158    ) -> Result<(OfferPairing<'a>, WsStream)> {
159        let psk = share_url.pre_shared_key().to_vec();
160        let tunnel = if is_inverted {
161            Builder::new(PATTERN.parse()?)
162                .local_private_key(&keypair.private)
163                .remote_public_key(share_url.public_key())
164                .psk(3, &psk)
165                .build_initiator()?
166        } else {
167            Builder::new(PATTERN.parse()?)
168                .local_private_key(&keypair.private)
169                .psk(3, &psk)
170                .build_responder()?
171        };
172
173        let mut request = WebSocketRequest::new(
174            *account.account_id(),
175            share_url.server(),
176            RELAY_PATH,
177        )?;
178        request
179            .uri
180            .query_pairs_mut()
181            .append_pair("public_key", &hex::encode(&keypair.public));
182
183        let (socket, _) = connect_async(request).await?;
184        let (tx, rx) = socket.split();
185        Ok((
186            Self {
187                keypair,
188                account,
189                share_url,
190                tunnel: Some(Tunnel::Handshake(tunnel)),
191                tx,
192                state: PairProtocolState::Pending,
193                is_inverted,
194            },
195            rx,
196        ))
197    }
198
199    /// URL that can be shared with the other device.
200    pub fn share_url(&self) -> &ServerPairUrl {
201        &self.share_url
202    }
203
204    /// Start the event loop.
205    pub async fn run(
206        &mut self,
207        stream: WsStream,
208        mut shutdown_rx: mpsc::Receiver<()>,
209    ) -> Result<()> {
210        if self.is_inverted {
211            // Start pairing
212            self.noise_send_e().await?;
213            self.state = PairProtocolState::Handshake;
214        }
215
216        let (offer_tx, mut offer_rx) = mpsc::channel::<RelayPacket>(32);
217        let (close_tx, mut close_rx) = mpsc::channel::<()>(1);
218        tokio::task::spawn(listen(stream, offer_tx, close_tx));
219        loop {
220            tokio::select! {
221                biased;
222                // Explicit shutdown notification
223                Some(_) = shutdown_rx.recv() => {
224                    tracing::debug!("pairing::offer::shutdown_received");
225                    if let Err(error) = self.tx.send(Message::Close(Some(CloseFrame {
226                        code: CloseCode::Normal,
227                        reason: Utf8Bytes::from_static("closed"),
228                    }))).await {
229                        tracing::error!(
230                            error = %error,
231                            "pairing::offer::websocket_close_frame::error");
232                    }
233                    break;
234                }
235                // Close signal from the websocket stream
236                Some(_) = close_rx.recv() => {
237                    break;
238                }
239                // Incoming event
240                Some(event) = offer_rx.recv() => {
241                    self.incoming(event).await?;
242                    if self.is_finished() {
243                        break;
244                    }
245                }
246            }
247        }
248
249        Ok(())
250    }
251
252    /// Determine if the protocol has completed.
253    pub fn is_finished(&self) -> bool {
254        matches!(&self.state, PairProtocolState::Done)
255    }
256
257    /// Process incoming packet.
258    async fn incoming(&mut self, packet: RelayPacket) -> Result<()> {
259        if packet.header.as_ref().unwrap().to_public_key
260            != self.keypair.public
261        {
262            return Err(Error::NotForMe);
263        }
264
265        let action = if !self.is_inverted {
266            match (&self.state, packet.is_handshake()) {
267                (PairProtocolState::Pending, true) => {
268                    let reply = self.noise_read_e(&packet).await?;
269                    IncomingAction::Reply(PairProtocolState::Handshake, reply)
270                }
271                (PairProtocolState::Handshake, true) => {
272                    let reply = self.noise_read_s(&packet).await?;
273                    IncomingAction::Reply(
274                        PairProtocolState::PskHandshake,
275                        reply,
276                    )
277                }
278                (PairProtocolState::PskHandshake, false) => {
279                    if let Some(Tunnel::Transport(transport)) =
280                        self.tunnel.as_mut()
281                    {
282                        let payload = packet.payload.as_ref().unwrap();
283                        let body = payload.body.as_ref().unwrap();
284                        let (len, buf) =
285                            (body.length as usize, &body.contents);
286
287                        IncomingAction::HandleMessage(
288                            decrypt(transport, len, buf).await?,
289                        )
290                    } else {
291                        unreachable!();
292                    }
293                }
294                _ => {
295                    return Err(Error::BadState);
296                }
297            }
298        } else {
299            match (&self.state, packet.is_handshake()) {
300                (PairProtocolState::Handshake, true) => {
301                    let reply = self.noise_send_s(&packet).await?;
302                    IncomingAction::Reply(
303                        PairProtocolState::PskHandshake,
304                        reply,
305                    )
306                }
307                (PairProtocolState::PskHandshake, false) => {
308                    if let Some(Tunnel::Transport(transport)) =
309                        self.tunnel.as_mut()
310                    {
311                        let payload = packet.payload.as_ref().unwrap();
312                        let body = payload.body.as_ref().unwrap();
313                        let (len, buf) =
314                            (body.length as usize, &body.contents);
315
316                        IncomingAction::HandleMessage(
317                            decrypt(transport, len, buf).await?,
318                        )
319                    } else {
320                        unreachable!();
321                    }
322                }
323                _ => {
324                    return Err(Error::BadState);
325                }
326            }
327        };
328
329        match action {
330            IncomingAction::Reply(next_state, reply) => {
331                self.state = next_state;
332                let buffer = reply.encode_prefixed().await?;
333                self.tx.send(Message::Binary(buffer.into())).await?;
334            }
335            IncomingAction::HandleMessage(msg) => {
336                let msg = msg.inner.unwrap();
337                // In inverted mode we can get a ready event
338                // so we just reply with another ready event
339                // to trigger the usual exchange of information
340                if let pairing_message::Inner::Ready(_) = msg {
341                    let payload = if let Some(Tunnel::Transport(transport)) =
342                        self.tunnel_mut()
343                    {
344                        let private_message = PairingReady {};
345                        encrypt(
346                            transport,
347                            PairingMessage {
348                                inner: Some(pairing_message::Inner::Ready(
349                                    private_message,
350                                )),
351                            },
352                        )
353                        .await?
354                    } else {
355                        unreachable!();
356                    };
357                    let reply = RelayPacket {
358                        header: Some(RelayHeader {
359                            to_public_key: packet
360                                .header
361                                .as_ref()
362                                .unwrap()
363                                .from_public_key
364                                .clone(),
365                            from_public_key: self.keypair().public.clone(),
366                        }),
367                        payload: Some(payload),
368                    };
369
370                    let buffer = reply.encode_prefixed().await?;
371                    self.tx.send(Message::Binary(buffer.into())).await?;
372                } else if let pairing_message::Inner::Request(message) = msg {
373                    tracing::debug!("<- device");
374
375                    let device_bytes = message.device_meta_data;
376                    let device: DeviceMetaData =
377                        serde_json::from_slice(&device_bytes)?;
378
379                    let (device_signer, manager) =
380                        self.account.new_device_vault().await?;
381                    let device_vault = manager.into_vault_buffer().await?;
382                    let servers = self.account.servers().await;
383                    let account_name = self.account.account_name().await?;
384
385                    self.register_device(device_signer.public_key(), device)
386                        .await?;
387
388                    let private_message = PairingConfirm {
389                        account_id: message.account_id,
390                        account_name,
391                        device_signing_key: device_signer.to_bytes().to_vec(),
392                        device_vault,
393                        servers: servers
394                            .into_iter()
395                            .map(|s| s.into())
396                            .collect(),
397                    };
398
399                    let payload = if let Some(Tunnel::Transport(transport)) =
400                        self.tunnel.as_mut()
401                    {
402                        encrypt(
403                            transport,
404                            PairingMessage {
405                                inner: Some(pairing_message::Inner::Confirm(
406                                    private_message,
407                                )),
408                            },
409                        )
410                        .await?
411                    } else {
412                        unreachable!();
413                    };
414
415                    let reply = RelayPacket {
416                        header: Some(RelayHeader {
417                            to_public_key: packet
418                                .header
419                                .unwrap()
420                                .from_public_key
421                                .to_vec(),
422                            from_public_key: self.keypair.public.to_vec(),
423                        }),
424                        payload: Some(payload),
425                    };
426
427                    tracing::debug!("-> private-key");
428                    let buffer = reply.encode_prefixed().await?;
429                    self.tx.send(Message::Binary(buffer.into())).await?;
430                    self.state = PairProtocolState::Done;
431                } else {
432                    return Err(Error::BadState);
433                }
434            }
435        }
436
437        Ok(())
438    }
439
440    async fn register_device(
441        &mut self,
442        public_key: DevicePublicKey,
443        device: DeviceMetaData,
444    ) -> Result<()> {
445        let trusted_device =
446            TrustedDevice::new(public_key, Some(device), None);
447        // Trust the other device in our local event log
448        let events: Vec<DeviceEvent> =
449            vec![DeviceEvent::Trust(trusted_device)];
450        {
451            self.account
452                .patch_devices_unchecked(events.as_slice())
453                .await?;
454        }
455
456        // Send the patch to the remote server.
457        //
458        // We only send to the target server otherwise
459        // another server that is down can prevent pairing
460        // from completing.
461        //
462        // Other servers will need to eventually get the updated
463        // devices the next time they are synced.
464        let origins = vec![self.share_url.server().clone().into()];
465        let options = SyncOptions {
466            origins,
467            ..Default::default()
468        };
469        if let Some(sync_error) =
470            self.account.sync_with_options(&options).await.first_error()
471        {
472            return Err(Error::DevicePatchSync(Box::new(sync_error)));
473        }
474
475        // Creating a new device vault saves the folder password
476        // and therefore updates the identity folder so we need
477        // to sync to ensure the other half of the pairing will
478        // fetch data that includes the password for the device
479        // vault we will send
480        if let Some(sync_error) =
481            self.account.sync_with_options(&options).await.first_error()
482        {
483            return Err(Error::EnrollSync(Box::new(sync_error)));
484        }
485
486        Ok(())
487    }
488}
489
490impl<'a> NoiseTunnel for OfferPairing<'a> {
491    async fn send(&mut self, message: Message) -> Result<()> {
492        Ok(self.tx.send(message).await?)
493    }
494
495    fn pairing_public_key(&self) -> &[u8] {
496        self.share_url.public_key()
497    }
498
499    fn keypair(&self) -> &Keypair {
500        &self.keypair
501    }
502
503    fn tunnel_mut(&mut self) -> Option<&mut Tunnel> {
504        self.tunnel.as_mut()
505    }
506
507    fn into_transport_mode(&mut self) -> Result<()> {
508        let tunnel = self.tunnel.take().unwrap();
509        if let Tunnel::Handshake(state) = tunnel {
510            self.tunnel =
511                Some(Tunnel::Transport(state.into_transport_mode()?));
512        }
513        Ok(())
514    }
515}
516
517/// Accept is the device being paired.
518pub struct AcceptPairing<'a> {
519    /// Noise session keypair.
520    keypair: Keypair,
521    /// Current device information.
522    device: &'a DeviceMetaData,
523    /// Backend target.
524    target: BackendTarget,
525    /// URL shared by the offering device.
526    share_url: ServerPairUrl,
527    /// Noise protocol state.
528    tunnel: Option<Tunnel>,
529    /// Sink side of the websocket.
530    tx: WsSink,
531    /// Current state of the protocol.
532    state: PairProtocolState,
533    /// Device enrollment.
534    enrollment: Option<DeviceEnrollment>,
535    /// Whether the pairing is inverted.
536    is_inverted: bool,
537}
538
539impl<'a> AcceptPairing<'a> {
540    /// Create a new pairing connection.
541    pub async fn new(
542        share_url: ServerPairUrl,
543        device: &'a DeviceMetaData,
544        target: BackendTarget,
545    ) -> Result<(AcceptPairing<'a>, WsStream)> {
546        let builder = Builder::new(PATTERN.parse()?);
547        let keypair = builder.generate_keypair()?;
548        Self::new_connection(share_url, device, target, keypair, false).await
549    }
550
551    /// Create a new inverted pairing connection.
552    pub async fn new_inverted(
553        account_id: AccountId,
554        server: Url,
555        device: &'a DeviceMetaData,
556        target: BackendTarget,
557    ) -> Result<(ServerPairUrl, AcceptPairing<'a>, WsStream)> {
558        let builder = Builder::new(PATTERN.parse()?);
559        let keypair = builder.generate_keypair()?;
560        let share_url =
561            ServerPairUrl::new(account_id, server, keypair.public.clone());
562        let (pairing, stream) = Self::new_connection(
563            share_url.clone(),
564            device,
565            target,
566            keypair,
567            true,
568        )
569        .await?;
570        Ok((share_url, pairing, stream))
571    }
572
573    async fn new_connection(
574        share_url: ServerPairUrl,
575        device: &'a DeviceMetaData,
576        target: BackendTarget,
577        keypair: Keypair,
578        is_inverted: bool,
579    ) -> Result<(AcceptPairing<'a>, WsStream)> {
580        let psk = share_url.pre_shared_key().to_vec();
581        let tunnel = if is_inverted {
582            Builder::new(PATTERN.parse()?)
583                .local_private_key(&keypair.private)
584                .psk(3, &psk)
585                .build_responder()?
586        } else {
587            Builder::new(PATTERN.parse()?)
588                .local_private_key(&keypair.private)
589                .remote_public_key(share_url.public_key())
590                .psk(3, &psk)
591                .build_initiator()?
592        };
593
594        let mut request = WebSocketRequest::new(
595            *share_url.account_id(),
596            share_url.server(),
597            RELAY_PATH,
598        )?;
599        request
600            .uri
601            .query_pairs_mut()
602            .append_pair("public_key", &hex::encode(&keypair.public));
603        let (socket, _) = connect_async(request).await?;
604        let (tx, rx) = socket.split();
605        Ok((
606            Self {
607                keypair,
608                device,
609                share_url,
610                target,
611                tunnel: Some(Tunnel::Handshake(tunnel)),
612                tx,
613                state: PairProtocolState::Pending,
614                enrollment: None,
615                is_inverted,
616            },
617            rx,
618        ))
619    }
620
621    /// Start the event loop and the pairing protocol.
622    pub async fn run(
623        &mut self,
624        stream: WsStream,
625        mut shutdown_rx: mpsc::Receiver<()>,
626    ) -> Result<()> {
627        if !self.is_inverted {
628            // Start pairing
629            self.noise_send_e().await?;
630            self.state = PairProtocolState::Handshake;
631        }
632
633        // Run the event loop
634        let (offer_tx, mut offer_rx) = mpsc::channel::<RelayPacket>(32);
635        let (close_tx, mut close_rx) = mpsc::channel::<()>(1);
636        tokio::task::spawn(listen(stream, offer_tx, close_tx));
637
638        loop {
639            tokio::select! {
640                biased;
641                event = shutdown_rx.recv() => {
642                    if event.is_some() {
643                        let _ = self.tx.send(Message::Close(Some(CloseFrame {
644                            code: CloseCode::Normal,
645                            reason: Utf8Bytes::from_static("closed"),
646                        }))).await;
647                        break;
648                    }
649                }
650                event = offer_rx.recv() => {
651                    if let Some(event) = event {
652                        self.incoming(event).await?;
653                        if self.is_finished() {
654                            break;
655                        }
656                    }
657                }
658                event = close_rx.recv() => {
659                    if event.is_some() {
660                        break;
661                    }
662                }
663            }
664        }
665
666        Ok(())
667    }
668
669    /// Determine if the protocol has completed.
670    pub fn is_finished(&self) -> bool {
671        matches!(&self.state, PairProtocolState::Done)
672    }
673
674    /// Take the final device enrollment.
675    ///
676    /// Errors if the protocol has not reached completion.
677    pub fn take_enrollment(self) -> Result<DeviceEnrollment> {
678        self.enrollment.ok_or(Error::NoEnrollment)
679    }
680
681    /// Process incoming packet.
682    async fn incoming(&mut self, packet: RelayPacket) -> Result<()> {
683        if packet.header.as_ref().unwrap().to_public_key
684            != self.keypair.public
685        {
686            return Err(Error::NotForMe);
687        }
688
689        let action = if !self.is_inverted {
690            match (&self.state, packet.is_handshake()) {
691                (PairProtocolState::Handshake, true) => {
692                    let reply = self.noise_send_s(&packet).await?;
693                    IncomingAction::Reply(
694                        PairProtocolState::PskHandshake,
695                        reply,
696                    )
697                }
698                (PairProtocolState::PskHandshake, false) => {
699                    if let Some(Tunnel::Transport(transport)) =
700                        self.tunnel.as_mut()
701                    {
702                        let payload = packet.payload.as_ref().unwrap();
703                        let body = payload.body.as_ref().unwrap();
704                        let (len, buf) =
705                            (body.length as usize, &body.contents);
706
707                        IncomingAction::HandleMessage(
708                            decrypt(transport, len, buf).await?,
709                        )
710                    } else {
711                        unreachable!();
712                    }
713                }
714                _ => {
715                    return Err(Error::BadState);
716                }
717            }
718        } else {
719            match (&self.state, packet.is_handshake()) {
720                (PairProtocolState::Pending, true) => {
721                    let reply = self.noise_read_e(&packet).await?;
722                    IncomingAction::Reply(PairProtocolState::Handshake, reply)
723                }
724                (PairProtocolState::Handshake, true) => {
725                    let reply = self.noise_read_s(&packet).await?;
726                    IncomingAction::Reply(
727                        PairProtocolState::PskHandshake,
728                        reply,
729                    )
730                }
731                (
732                    PairProtocolState::PskHandshake,
733                    false,
734                    // RelayPayload::Transport(len, buf),
735                ) => {
736                    if let Some(Tunnel::Transport(transport)) =
737                        self.tunnel.as_mut()
738                    {
739                        let payload = packet.payload.as_ref().unwrap();
740                        let body = payload.body.as_ref().unwrap();
741                        let (len, buf) =
742                            (body.length as usize, &body.contents);
743
744                        IncomingAction::HandleMessage(
745                            decrypt(transport, len, buf).await?,
746                        )
747                    } else {
748                        unreachable!();
749                    }
750                }
751                _ => {
752                    return Err(Error::BadState);
753                }
754            }
755        };
756
757        match action {
758            IncomingAction::Reply(next_state, reply) => {
759                self.state = next_state;
760
761                let buffer = reply.encode_prefixed().await?;
762                self.tx.send(Message::Binary(buffer.into())).await?;
763            }
764            IncomingAction::HandleMessage(msg) => {
765                let msg = msg.inner.unwrap();
766
767                // When the noise handshake is complete start
768                // pairing by sending the trusted device information
769                if let pairing_message::Inner::Ready(_) = msg {
770                    tracing::debug!("<- ready");
771                    if let Some(Tunnel::Transport(transport)) =
772                        self.tunnel.as_mut()
773                    {
774                        let device_bytes = serde_json::to_vec(&self.device)?;
775
776                        let private_message = PairingRequest {
777                            device_meta_data: device_bytes,
778                            account_id: self
779                                .share_url
780                                .account_id()
781                                .to_string(),
782                        };
783
784                        let payload = encrypt(
785                            transport,
786                            PairingMessage {
787                                inner: Some(pairing_message::Inner::Request(
788                                    private_message,
789                                )),
790                            },
791                        )
792                        .await?;
793                        let reply = RelayPacket {
794                            header: Some(RelayHeader {
795                                to_public_key: packet
796                                    .header
797                                    .as_ref()
798                                    .unwrap()
799                                    .from_public_key
800                                    .to_vec(),
801                                from_public_key: self.keypair.public.to_vec(),
802                            }),
803                            payload: Some(payload),
804                        };
805                        tracing::debug!("-> device");
806                        let buffer = reply.encode_prefixed().await?;
807                        self.tx.send(Message::Binary(buffer.into())).await?;
808                    } else {
809                        unreachable!();
810                    }
811                } else if let pairing_message::Inner::Confirm(confirmation) =
812                    msg
813                {
814                    self.create_enrollment(confirmation).await?;
815                    self.state = PairProtocolState::Done;
816                } else {
817                    return Err(Error::BadState);
818                }
819            }
820        }
821
822        Ok(())
823    }
824
825    /// Create the device enrollment once pairing is complete.
826    ///
827    /// Callers can now access the device enrollment using
828    /// [AcceptPairing::take_enrollment] and then call
829    /// [DeviceEnrollment::fetch_account] to retrieve the
830    /// account data.
831    async fn create_enrollment(
832        &mut self,
833        confirmation: PairingConfirm,
834    ) -> Result<()> {
835        // let signing_key: [u8; 32] =
836        //     confirmation.account_signing_key.as_slice().try_into()?;
837
838        let device_signing_key: [u8; 32] =
839            confirmation.device_signing_key.as_slice().try_into()?;
840        let device_vault = confirmation.device_vault;
841        let mut servers = HashSet::new();
842        for server in confirmation.servers {
843            servers.insert(server.try_into()?);
844        }
845        let account_id: AccountId = confirmation.account_id.parse()?;
846
847        // let signer: SingleParty = signing_key.try_into()?;
848
849        let server = self.share_url.server().clone();
850        let origin: Origin = server.into();
851        // let data_dir = self.data_dir.clone();
852
853        let enrollment = DeviceEnrollment::new(
854            self.target.clone(),
855            account_id,
856            confirmation.account_name,
857            origin,
858            device_signing_key.try_into()?,
859            device_vault,
860            servers,
861        )
862        .await?;
863        self.enrollment = Some(enrollment);
864
865        Ok(())
866    }
867}
868
869/// Serialize and encrypt a message.
870async fn encrypt<T: prost::Message>(
871    transport: &mut TransportState,
872    message: T,
873) -> crate::pairing::Result<RelayPayload> {
874    let mut plaintext = Vec::new();
875    message.encode(&mut plaintext)?;
876    let mut contents = vec![0u8; plaintext.len() + TAGLEN];
877    let length = transport.write_message(&plaintext, &mut contents)?;
878    Ok(RelayPayload::new_transport(length, contents))
879}
880
881/// Decrypt a message and deserialize the content.
882async fn decrypt<T: prost::Message + Default>(
883    transport: &mut TransportState,
884    length: usize,
885    message: &[u8],
886) -> crate::pairing::Result<T> {
887    let mut contents = vec![0; length];
888    transport.read_message(&message[..length], &mut contents)?;
889    let message = &contents[..contents.len() - TAGLEN];
890    let message: prost::bytes::Bytes = message.to_vec().into();
891    Ok(T::decode(message)?)
892}
893
894impl<'a> NoiseTunnel for AcceptPairing<'a> {
895    async fn send(&mut self, message: Message) -> Result<()> {
896        Ok(self.tx.send(message).await?)
897    }
898
899    fn pairing_public_key(&self) -> &[u8] {
900        self.share_url.public_key()
901    }
902
903    fn keypair(&self) -> &Keypair {
904        &self.keypair
905    }
906
907    fn tunnel_mut(&mut self) -> Option<&mut Tunnel> {
908        self.tunnel.as_mut()
909    }
910
911    fn into_transport_mode(&mut self) -> Result<()> {
912        let tunnel = self.tunnel.take().unwrap();
913        if let Tunnel::Handshake(state) = tunnel {
914            self.tunnel =
915                Some(Tunnel::Transport(state.into_transport_mode()?));
916        }
917        Ok(())
918    }
919}
920
921trait NoiseTunnel {
922    /// Send a message.
923    async fn send(&mut self, message: Message) -> Result<()>;
924
925    /// Public key of the party that created the pairing URL.
926    fn pairing_public_key(&self) -> &[u8];
927
928    /// Noise keypair.
929    fn keypair(&self) -> &Keypair;
930
931    /// Noise tunnel state.
932    fn tunnel_mut(&mut self) -> Option<&mut Tunnel>;
933
934    /// Update the noise tunnel state.
935    fn into_transport_mode(&mut self) -> Result<()>;
936
937    /// Send the first packet of the initial noise handshake.
938    async fn noise_send_e(&mut self) -> Result<()> {
939        let buffer = if let Some(Tunnel::Handshake(state)) = self.tunnel_mut()
940        {
941            let mut buf = [0u8; 1024];
942            // -> e
943            tracing::debug!("-> e");
944            let len = state.write_message(&[], &mut buf)?;
945            let message = RelayPacket {
946                header: Some(RelayHeader {
947                    to_public_key: self.pairing_public_key().to_vec(),
948                    from_public_key: self.keypair().public.to_vec(),
949                }),
950                payload: Some(RelayPayload::new_handshake(len, buf.to_vec())),
951            };
952            message.encode_prefixed().await?
953        } else {
954            unreachable!();
955        };
956        self.send(Message::Binary(buffer.into())).await?;
957        Ok(())
958    }
959
960    /// Respond to the first packet of the noise protocol handshake.
961    async fn noise_read_e(
962        &mut self,
963        packet: &RelayPacket,
964    ) -> Result<RelayPacket> {
965        if let (Some(Tunnel::Handshake(state)), true) =
966            (self.tunnel_mut(), packet.is_handshake())
967        {
968            let payload = packet.payload.as_ref().unwrap();
969            let body = payload.body.as_ref().unwrap();
970            let (len, init_msg) = (body.length as usize, &body.contents);
971
972            let mut buf = [0; 1024];
973            let mut reply = [0; 1024];
974            // <- e
975            tracing::debug!("<- e");
976            state.read_message(&init_msg[..len], &mut buf)?;
977            // -> e, ee, s, es
978            tracing::debug!("-> e, ee, s, es");
979            let len = state.write_message(&[], &mut reply)?;
980            Ok(RelayPacket {
981                header: Some(RelayHeader {
982                    to_public_key: packet
983                        .header
984                        .as_ref()
985                        .unwrap()
986                        .from_public_key
987                        .clone(),
988                    from_public_key: self.keypair().public.clone(),
989                }),
990                payload: Some(RelayPayload::new_handshake(
991                    len,
992                    reply.to_vec(),
993                )),
994            })
995        } else {
996            Err(Error::BadState)
997        }
998    }
999
1000    /// Handle the second packet of the noise protocol handshake
1001    /// and transition into transport mode.
1002    async fn noise_send_s(
1003        &mut self,
1004        packet: &RelayPacket,
1005    ) -> Result<RelayPacket> {
1006        let packet = if let (Some(Tunnel::Handshake(state)), true) =
1007            (self.tunnel_mut(), packet.is_handshake())
1008        {
1009            let payload = packet.payload.as_ref().unwrap();
1010            let body = payload.body.as_ref().unwrap();
1011            let (len, init_msg) = (body.length as usize, &body.contents);
1012
1013            let mut buf = [0; 1024];
1014            let mut reply = [0; 1024];
1015            // <- e, ee, s, es
1016            tracing::debug!("<- e, ee, s, es");
1017            state.read_message(&init_msg[..len], &mut buf)?;
1018            // -> s, se
1019            tracing::debug!("-> s, se");
1020            let len = state.write_message(&[], &mut reply)?;
1021            Some(RelayPacket {
1022                header: Some(RelayHeader {
1023                    to_public_key: packet
1024                        .header
1025                        .as_ref()
1026                        .unwrap()
1027                        .from_public_key
1028                        .clone(),
1029                    from_public_key: self.keypair().public.clone(),
1030                }),
1031                payload: Some(RelayPayload::new_handshake(
1032                    len,
1033                    reply.to_vec(),
1034                )),
1035            })
1036        } else {
1037            None
1038        };
1039
1040        if let Some(packet) = packet {
1041            self.into_transport_mode()?;
1042            Ok(packet)
1043        } else {
1044            return Err(Error::BadState);
1045        }
1046    }
1047
1048    /// Handle the final packet of the noise protocol handshake
1049    /// and transition into transport mode.
1050    async fn noise_read_s(
1051        &mut self,
1052        packet: &RelayPacket,
1053    ) -> Result<RelayPacket> {
1054        if let (Some(Tunnel::Handshake(state)), true) =
1055            (self.tunnel_mut(), packet.is_handshake())
1056        {
1057            let payload = packet.payload.as_ref().unwrap();
1058            let body = payload.body.as_ref().unwrap();
1059            let (len, init_msg) = (body.length as usize, &body.contents);
1060
1061            let mut buf = [0; 1024];
1062            // <- s, se
1063            tracing::debug!("<- s, se");
1064            state.read_message(&init_msg[..len], &mut buf)?;
1065
1066            self.into_transport_mode()?;
1067
1068            let payload = if let Some(Tunnel::Transport(transport)) =
1069                self.tunnel_mut()
1070            {
1071                let private_message = PairingReady {};
1072                encrypt(
1073                    transport,
1074                    PairingMessage {
1075                        inner: Some(pairing_message::Inner::Ready(
1076                            private_message,
1077                        )),
1078                    },
1079                )
1080                .await?
1081            } else {
1082                unreachable!();
1083            };
1084            Ok(RelayPacket {
1085                header: Some(RelayHeader {
1086                    to_public_key: packet
1087                        .header
1088                        .as_ref()
1089                        .unwrap()
1090                        .from_public_key
1091                        .clone(),
1092                    from_public_key: self.keypair().public.clone(),
1093                }),
1094                payload: Some(payload),
1095            })
1096        } else {
1097            Err(Error::BadState)
1098        }
1099    }
1100}