use std::fmt::{Debug, Display};
use std::sync::Arc;
use std::time::SystemTime;
use async_trait::async_trait;
use p256::elliptic_curve::sec1::ToEncodedPoint;
use p256::{NonZeroScalar, SecretKey};
use rand::rngs::OsRng;
use rand::RngCore;
use serde::Serialize;
use serde_bytes::ByteArray;
use serde_indexed::SerializeIndexed;
use serde_repr::Serialize_repr;
use tokio::sync::{broadcast, mpsc, watch};
use tokio::task;
use tracing::instrument;
use super::connection_stages::{
connection_stage, handshake_stage, proximity_check_stage, ConnectionInput, HandshakeInput,
MpscUxUpdateSender, ProximityCheckInput, TunnelConnectionInput, UxUpdateSender,
};
use super::known_devices::CableKnownDeviceInfoStore;
use super::protocol;
use super::tunnel::KNOWN_TUNNEL_DOMAINS;
use super::{channel::CableChannel, channel::ConnectionState, Cable};
use crate::proto::ctap2::cbor;
use crate::transport::cable::digit_encode;
use crate::transport::Device;
use crate::webauthn::error::Error;
use crate::webauthn::TransportError;
/// Operation the client intends to perform, advertised in the QR code (key 0x05).
///
/// Serialized as the short strings `"ga"` and `"mc"` rather than the variant names.
#[derive(Clone, Copy, Debug, PartialEq, Serialize)]
pub enum QrCodeOperationHint {
    /// Get-assertion request; serialized as `"ga"`.
    #[serde(rename = "ga")]
    GetAssertionRequest,
    /// Make-credential request; serialized as `"mc"`.
    #[serde(rename = "mc")]
    MakeCredential,
}
/// Transport identifier listed in the QR code's transports field (key 0x06).
///
/// Serialized as its `u8` discriminant (via `Serialize_repr`), not by name.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize_repr)]
#[repr(u8)]
pub(crate) enum CableTransportChannel {
    /// WebSocket transport; encoded as `0`.
    WebSocket = 0,
    /// BLE transport; encoded as `1`.
    Ble = 1,
}
/// Which transports the client offers through the QR code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CableTransports {
    /// Emit no explicit transport list (QR key 0x06 is omitted).
    CloudAssistedOnly,
    /// List both WebSocket and BLE in QR key 0x06.
    CloudAssistedOrLocal,
}
impl CableTransports {
    /// Maps this setting onto the optional QR-code transports field (key 0x06).
    ///
    /// `CloudAssistedOnly` yields `None`, so the key is left out of the QR code;
    /// `CloudAssistedOrLocal` lists WebSocket followed by BLE.
    pub(crate) fn to_qr_field(self) -> Option<Vec<CableTransportChannel>> {
        match self {
            Self::CloudAssistedOnly => None,
            Self::CloudAssistedOrLocal => {
                let channels = [CableTransportChannel::WebSocket, CableTransportChannel::Ble];
                Some(Vec::from(channels))
            }
        }
    }
}
/// Contents of a caBLE QR code, serialized as an integer-keyed CBOR map.
///
/// Each field's `index` attribute is its CBOR map key. Optional fields are
/// omitted from the map entirely when `None`.
#[derive(Debug, Clone, SerializeIndexed)]
pub struct CableQrCode {
    /// Key 0x00: 33-byte SEC1-compressed P-256 public key of this client.
    #[serde(index = 0x00)]
    pub public_key: ByteArray<33>,
    /// Key 0x01: 16 random bytes generated when the QR code is created.
    #[serde(index = 0x01)]
    pub qr_secret: ByteArray<16>,
    /// Key 0x02: number of tunnel domains this client knows about.
    #[serde(index = 0x02)]
    pub known_tunnel_domains_count: u8,
    /// Key 0x03: creation time in seconds since the Unix epoch, if available.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(index = 0x03)]
    pub current_time: Option<u64>,
    /// Key 0x04: `Some(true)` when state-assisted transactions are requested;
    /// otherwise `None` and the key is omitted.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(index = 0x04)]
    pub state_assisted: Option<bool>,
    /// Key 0x05: the operation the client intends to perform.
    #[serde(index = 0x05)]
    pub operation_hint: QrCodeOperationHint,
    /// Key 0x06: supported transports; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(index = 0x06)]
    pub(crate) transports: Option<Vec<CableTransportChannel>>,
}
impl std::fmt::Display for CableQrCode {
    /// Renders the QR payload as `FIDO:/` followed by the digit-encoded CBOR map.
    ///
    /// A CBOR serialization failure is reported as [`std::fmt::Error`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cbor_bytes = match cbor::to_vec(self) {
            Ok(bytes) => bytes,
            Err(_) => return Err(std::fmt::Error),
        };
        let digits = digit_encode(&cbor_bytes);
        write!(f, "FIDO:/{}", digits)
    }
}
/// A caBLE device backed by a QR code generated on this client.
///
/// Holds the QR code payload together with the private scalar whose public key
/// is embedded in it. `Debug` is implemented manually and leaves the private
/// key out.
#[derive(Clone)]
pub struct CableQrCodeDevice {
    /// QR code payload; its `Display` impl yields the `FIDO:/...` string.
    pub qr_code: CableQrCode,
    /// Private scalar matching `qr_code.public_key`.
    pub private_key: NonZeroScalar,
    /// Known-device store: `Some` for persistent devices, `None` for transient ones.
    pub(crate) store: Option<Arc<dyn CableKnownDeviceInfoStore>>,
}
impl Debug for CableQrCodeDevice {
    /// Debug-prints the QR code and store. The private key is deliberately not
    /// included, which keeps secret key material out of debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CableQrCodeDevice");
        builder.field("qr_code", &self.qr_code);
        builder.field("store", &self.store);
        builder.finish()
    }
}
impl CableQrCodeDevice {
    /// Creates a QR-code device that requests state-assisted transactions
    /// (QR key 0x04 set to `true`) and keeps `store` for the tunnel connection.
    ///
    /// # Errors
    ///
    /// Returns `Error::Transport(TransportError::InvalidKey)` if the generated
    /// public key cannot be encoded as a 33-byte compressed point.
    pub fn new_persistent(
        hint: QrCodeOperationHint,
        store: Arc<dyn CableKnownDeviceInfoStore>,
        transports: CableTransports,
    ) -> Result<Self, Error> {
        Self::new(hint, true, Some(store), transports)
    }

    /// Shared constructor. It generates a fresh P-256 key pair and a random
    /// 16-byte QR secret, then assembles the QR payload.
    ///
    /// `state_assisted` decides whether key 0x04 is emitted (`Some(true)`) or
    /// omitted (`None`).
    ///
    /// # Errors
    ///
    /// Returns `Error::Transport(TransportError::InvalidKey)` if the compressed
    /// public key is not exactly 33 bytes.
    fn new(
        hint: QrCodeOperationHint,
        state_assisted: bool,
        store: Option<Arc<dyn CableKnownDeviceInfoStore>>,
        transports: CableTransports,
    ) -> Result<Self, Error> {
        let private_key_scalar = NonZeroScalar::random(&mut OsRng);
        let private_key = SecretKey::from(private_key_scalar);
        // SEC1 compressed encoding (`to_encoded_point(true)`) of a P-256 point
        // is 33 bytes; the conversion fails only if that invariant is violated.
        let public_key: [u8; 33] = private_key
            .public_key()
            .as_affine()
            .to_encoded_point(true)
            .as_bytes()
            .try_into()
            .map_err(|_| Error::Transport(TransportError::InvalidKey))?;

        let mut qr_secret = [0u8; 16];
        OsRng.fill_bytes(&mut qr_secret);

        // A clock set before the Unix epoch yields `None`, which omits key 0x03.
        let current_unix_time = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .ok()
            .map(|t| t.as_secs());

        // The QR field is a single byte. Saturate instead of letting `as u8`
        // silently wrap (256 domains would otherwise be advertised as 0).
        let known_tunnel_domains_count =
            u8::try_from(KNOWN_TUNNEL_DOMAINS.len()).unwrap_or(u8::MAX);

        Ok(Self {
            qr_code: CableQrCode {
                public_key: ByteArray::from(public_key),
                qr_secret: ByteArray::from(qr_secret),
                known_tunnel_domains_count,
                current_time: current_unix_time,
                operation_hint: hint,
                state_assisted: state_assisted.then_some(true),
                transports: transports.to_qr_field(),
            },
            private_key: private_key_scalar,
            store,
        })
    }
}
impl CableQrCodeDevice {
    /// Creates a QR-code device that does not request state-assisted
    /// transactions (QR key 0x04 omitted) and keeps no known-device store.
    ///
    /// # Errors
    ///
    /// Fails only if the freshly generated public key cannot be encoded.
    pub fn new_transient(
        hint: QrCodeOperationHint,
        transports: CableTransports,
    ) -> Result<Self, Error> {
        Self::new(hint, false, None, transports)
    }

    /// Runs the proximity-check, connection and handshake stages in sequence
    /// and returns the handshake result.
    ///
    /// Every stage reports progress through `ux_sender`. The first failing
    /// step aborts the sequence, and its error is propagated.
    #[instrument(skip_all, err)]
    async fn connection(
        qr_device: &CableQrCodeDevice,
        ux_sender: &MpscUxUpdateSender,
    ) -> Result<super::connection_stages::HandshakeOutput, Error> {
        // Stage 1: proximity check.
        let proximity_request = ProximityCheckInput::new_for_qr_code(qr_device)?;
        let proximity = proximity_check_stage(proximity_request, ux_sender).await?;

        // Stage 2: establish the connection; only borrows the proximity result.
        let connect_request = ConnectionInput::new_for_qr_code(qr_device, &proximity)?;
        let connection = connection_stage(connect_request, ux_sender).await?;

        // Stage 3: handshake, consuming both earlier results.
        let handshake_request = HandshakeInput::new_for_qr_code(qr_device, connection, proximity)?;
        let handshake = handshake_stage(handshake_request, ux_sender).await?;
        Ok(handshake)
    }
}
impl Display for CableQrCodeDevice {
    /// Prints a fixed label; no QR contents or key material are included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CableQrCodeDevice")
    }
}
#[async_trait]
impl<'d> Device<'d, Cable, CableChannel> for CableQrCodeDevice {
    /// Opens a caBLE channel for this QR-code device.
    ///
    /// The connection stages and the tunnel protocol run in a spawned task.
    /// The returned [`CableChannel`] holds the task handle, the CBOR
    /// request/response queues, the UX broadcast sender and a watch on the
    /// connection state. Failures inside the task are reported through the UX
    /// sender, not through this function's return value.
    async fn channel(&'d mut self) -> Result<CableChannel, Error> {
        // The initial broadcast receiver is dropped here.
        // NOTE(review): listeners presumably subscribe via the sender stored in
        // `CableChannel` — confirm.
        let (ux_update_sender, _) = broadcast::channel(16);
        // Caller -> tunnel CBOR requests.
        let (cbor_tx_send, cbor_tx_recv) = mpsc::channel(16);
        // Tunnel -> caller CBOR responses.
        let (cbor_rx_send, cbor_rx_recv) = mpsc::channel(16);
        let (connection_state_sender, connection_state_receiver) =
            watch::channel(ConnectionState::Connecting);

        // The spawned task owns its own sender handle and device copy, so it
        // does not borrow `self`.
        let task_ux_update_sender = ux_update_sender.clone();
        let qr_device = self.clone();
        let handle_connection = task::spawn(async move {
            let ux_sender =
                MpscUxUpdateSender::new(task_ux_update_sender, connection_state_sender);
            let handshake_output = match Self::connection(&qr_device, &ux_sender).await {
                Ok(handshake_output) => handshake_output,
                Err(e) => {
                    // `send_error` takes a `TransportError`; any other error kind
                    // is collapsed into `ConnectionFailed`.
                    let transport_err = match e {
                        Error::Transport(t) => t,
                        _ => TransportError::ConnectionFailed,
                    };
                    ux_sender.send_error(transport_err).await;
                    return;
                }
            };
            let tunnel_input = TunnelConnectionInput::from_handshake_output(
                handshake_output,
                qr_device.store,
                cbor_tx_recv,
                cbor_rx_send,
            );
            match protocol::connection(tunnel_input).await {
                Ok(()) => {
                    ux_sender
                        .set_connection_state(ConnectionState::Terminated)
                        .await;
                }
                Err(e) => {
                    ux_sender.send_error(e).await;
                }
            }
        });

        Ok(CableChannel {
            handle_connection,
            cbor_sender: cbor_tx_send,
            cbor_receiver: cbor_rx_recv,
            ux_update_sender,
            connection_state_receiver,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    // Compile-time check: the device must be `Send + Sync` so it can move into tasks.
    const _: fn() = || {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CableQrCodeDevice>();
    };

    /// Builds a transient make-credential QR code with the given transports
    /// and decodes its CBOR encoding back into an integer-keyed map.
    fn encoded_qr_map(transports: CableTransports) -> BTreeMap<u64, cbor::Value> {
        let device =
            CableQrCodeDevice::new_transient(QrCodeOperationHint::MakeCredential, transports)
                .expect("creating a transient QR code device should succeed");
        let encoded = cbor::to_vec(&device.qr_code).expect("QR code should serialize to CBOR");
        cbor::from_slice(&encoded).expect("QR code CBOR should decode as an integer-keyed map")
    }

    #[test]
    fn qr_code_omits_key_6_for_cloud_assisted_only() {
        let map = encoded_qr_map(CableTransports::CloudAssistedOnly);
        assert!(!map.contains_key(&6));
    }

    #[test]
    fn qr_code_encodes_key_6_for_cloud_assisted_or_local() {
        let map = encoded_qr_map(CableTransports::CloudAssistedOrLocal);
        let expected = cbor::Value::Array(vec![
            cbor::Value::Integer(0),
            cbor::Value::Integer(1),
        ]);
        assert_eq!(map.get(&6), Some(&expected));
    }
}