ctaphid-types 0.2.0

Data types for the CTAPHID protocol
Documentation
// Copyright (C) 2021 Robin Krahl <robin.krahl@ireas.org>
// SPDX-License-Identifier: Apache-2.0 or MIT

use core::cmp;

use crate::{
    channel::Channel,
    command::Command,
    error::{DefragmentationError, FragmentationError},
    packet::{
        ContinuationPacket, InitializationPacket, Packet, PacketType, CONTINUATION_HEADER_SIZE,
        INITIALIZATION_HEADER_SIZE,
    },
};

/// A complete CTAPHID message: a command with its payload, addressed to a channel.
///
/// Use [`Message::fragments`][] to split a message into packets and
/// [`Message::from_fragments`][] to reassemble one from packets.
///
/// See [§ 11.2.2 of the CTAP specification][spec].
///
/// [spec]: https://fidoalliance.org/specs/fido-v2.1-ps-20210615/fido-client-to-authenticator-protocol-v2.1-ps-20210615.html#usb-protocol-and-framing
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct Message<T>
where
    T: AsRef<[u8]>,
{
    /// Channel on which the message is sent or received.
    pub channel: Channel,
    /// CTAPHID command carried by the message.
    pub command: Command,
    /// Payload bytes of the message.
    pub data: T,
}

impl<T> Message<T>
where
    T: AsRef<[u8]>,
{
    /// Returns an iterator that yields this message as CTAPHID packets for a transport whose
    /// packets are `packet_size` bytes long.
    ///
    /// The first packet is an initialization packet and every following packet is a
    /// continuation packet.  The packets borrow their payload from this message.
    ///
    /// See [§ 11.2.4 of the CTAP specification][spec].
    ///
    /// # Errors
    ///
    /// Returns a [`FragmentationError`][] in two cases: when `packet_size` leaves no room for
    /// payload after the packet headers, or when the payload is too long to be sent with packets
    /// of this size.
    ///
    /// [spec]: https://fidoalliance.org/specs/fido-v2.1-ps-20210615/fido-client-to-authenticator-protocol-v2.1-ps-20210615.html#usb-message-and-packet-structure
    pub fn fragments(&self, packet_size: usize) -> Result<Fragments<'_>, FragmentationError> {
        let iter = Fragments::new(self, packet_size)?;
        Ok(iter)
    }
}

impl<T: AsRef<[u8]> + Default + Extend<u8>> Message<T> {
    /// Assembles a CTAPHID message from a sequence of packets, starting with the given
    /// initialization packet.
    ///
    /// See [§ 11.2.4 of the CTAP specification][spec].
    ///
    /// [spec]: https://fidoalliance.org/specs/fido-v2.1-ps-20210615/fido-client-to-authenticator-protocol-v2.1-ps-20210615.html#usb-message-and-packet-structure
    pub fn from_fragments<S: AsRef<[u8]>>(
        packet: InitializationPacket<S>,
    ) -> DefragmentedMessage<T> {
        packet.into()
    }

    /// Tries to assemble a CTAPHID message from a sequence of packets, starting with the given
    /// packet.
    ///
    /// The packet must be an initialization packet.  This is a shorthand for matching the packet
    /// and calling [`Message::from_fragments`][].
    ///
    /// See [§ 11.2.4 of the CTAP specification][spec].
    ///
    /// [spec]: https://fidoalliance.org/specs/fido-v2.1-ps-20210615/fido-client-to-authenticator-protocol-v2.1-ps-20210615.html#usb-message-and-packet-structure
    pub fn try_from_fragments<S: AsRef<[u8]>>(
        packet: Packet<S>,
    ) -> Result<DefragmentedMessage<T>, DefragmentationError> {
        if let Packet::Initialization(packet) = packet {
            Ok(Self::from_fragments(packet))
        } else {
            Err(DefragmentationError::InvalidPacketType {
                expected: PacketType::Initialization,
                actual: packet.packet_type(),
            })
        }
    }
}

/// An iterator over the CTAPHID packets that carry the data of a CTAPHID message.
///
/// Returned by [`Message::fragments`][].  It first yields one initialization packet and then
/// the continuation packets, numbered by their position starting at 0.
///
/// See [§ 11.2.4 of the CTAP specification][spec].
///
/// [spec]: https://fidoalliance.org/specs/fido-v2.1-ps-20210615/fido-client-to-authenticator-protocol-v2.1-ps-20210615.html#usb-message-and-packet-structure
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Fragments<'a> {
    /// Channel copied into every continuation packet.
    channel: Channel,
    /// The initialization packet; `None` once it has been yielded.
    init: Option<InitializationPacket<&'a [u8]>>,
    /// Payload chunks for the continuation packets, paired with their index (the sequence
    /// number).
    chunks: core::iter::Enumerate<core::slice::Chunks<'a, u8>>,
}

impl<'a> Fragments<'a> {
    /// Prepares the packets for `message` on a transport with `packet_size`-byte packets.
    ///
    /// The initialization packet carries the channel, the command, the total payload length and
    /// as much payload as fits.  The rest of the payload is split into continuation packets of
    /// `packet_size - CONTINUATION_HEADER_SIZE` payload bytes each; the last one may be shorter.
    ///
    /// # Errors
    ///
    /// - [`FragmentationError::PacketSizeTooSmall`][] if `packet_size` does not exceed the
    ///   initialization or continuation header size.
    /// - [`FragmentationError::DataTooLong`][] if the payload needs more than 128 continuation
    ///   packets, or if its length does not fit the `u16` length field of the initialization
    ///   packet.
    fn new<T: AsRef<[u8]>>(
        message: &'a Message<T>,
        packet_size: usize,
    ) -> Result<Self, FragmentationError> {
        // Continuation packets are numbered 0..=127, so at most 128 may follow the
        // initialization packet.
        const MAX_CONTINUATION_PACKETS: usize = 128;

        if INITIALIZATION_HEADER_SIZE >= packet_size || CONTINUATION_HEADER_SIZE >= packet_size {
            return Err(FragmentationError::PacketSizeTooSmall);
        }
        let data = message.data.as_ref();
        let data_size_init = packet_size - INITIALIZATION_HEADER_SIZE;
        let data_size_cont = packet_size - CONTINUATION_HEADER_SIZE;
        // Saturating arithmetic: for huge packet sizes the product would overflow `usize`, and
        // the `u16` limit below is the binding one in that case anyway.
        let max_len_by_packets = data_size_cont
            .saturating_mul(MAX_CONTINUATION_PACKETS)
            .saturating_add(data_size_init);
        // The initialization packet announces the total length as a `u16`.  Without this bound a
        // longer payload would be sent with a silently truncated length.
        let max_len = cmp::min(max_len_by_packets, usize::from(u16::MAX));
        if data.len() > max_len {
            return Err(FragmentationError::DataTooLong);
        }
        // Cannot truncate: `data.len() <= u16::MAX` was checked above.
        let length = data.len() as u16;
        let n = cmp::min(data_size_init, data.len());
        let (data_init, data_cont) = data.split_at(n);
        let init = InitializationPacket {
            channel: message.channel,
            command: message.command,
            length,
            data: data_init,
        };
        Ok(Self {
            channel: message.channel,
            init: Some(init),
            chunks: data_cont.chunks(data_size_cont).enumerate(),
        })
    }
}

impl<'a> Iterator for Fragments<'a> {
    type Item = Packet<&'a [u8]>;

    /// Yields the initialization packet first, then one continuation packet per remaining
    /// payload chunk.
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(init) = self.init.take() {
            return Some(Packet::Initialization(init));
        }
        let (sequence, data) = self.chunks.next()?;
        Some(Packet::Continuation(ContinuationPacket {
            channel: self.channel,
            // `Fragments::new` limits the payload to at most 128 continuation chunks, so the
            // chunk index is in 0..=127 and fits into a `u8`.
            sequence: sequence as u8,
            data,
        }))
    }

    /// Exact number of packets left: the pending initialization packet (if not yet yielded)
    /// plus the remaining continuation chunks.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending_init = usize::from(self.init.is_some());
        let (lower, upper) = self.chunks.size_hint();
        (
            lower.saturating_add(pending_init),
            upper.and_then(|n| n.checked_add(pending_init)),
        )
    }
}

impl ExactSizeIterator for Fragments<'_> {}

/// The result of reassembling one or more CTAPHID packets: either the finished message or the
/// state needed to continue reassembly.
///
/// See [§ 11.2.4 of the CTAP specification][spec].
///
/// [spec]: https://fidoalliance.org/specs/fido-v2.1-ps-20210615/fido-client-to-authenticator-protocol-v2.1-ps-20210615.html#usb-message-and-packet-structure
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum DefragmentedMessage<T>
where
    T: AsRef<[u8]> + Default + Extend<u8>,
{
    /// All payload bytes have been received.
    Complete(Message<T>),
    /// More continuation packets are needed.
    Partial(PartialMessage<T>),
}

impl<T: AsRef<[u8]> + Default + Extend<u8>, S: AsRef<[u8]>> From<InitializationPacket<S>>
    for DefragmentedMessage<T>
{
    fn from(packet: InitializationPacket<S>) -> Self {
        PartialMessage::from(packet).into()
    }
}

impl<T: AsRef<[u8]> + Default + Extend<u8>> From<PartialMessage<T>> for DefragmentedMessage<T> {
    fn from(message: PartialMessage<T>) -> Self {
        if message.length == message.data.as_ref().len() {
            Self::Complete(Message {
                channel: message.channel,
                command: message.command,
                data: message.data,
            })
        } else {
            Self::Partial(message)
        }
    }
}

/// A [`Message`][] that has been partially assembled from CTAPHID packets.
///
/// It is created from an initialization packet and advanced with
/// [`PartialMessage::extend`][] or [`PartialMessage::try_extend`][].
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct PartialMessage<T>
where
    T: AsRef<[u8]> + Default + Extend<u8>,
{
    /// Channel of the initialization packet; continuation packets must match it.
    channel: Channel,
    /// Command announced by the initialization packet.
    command: Command,
    /// Total payload length announced by the initialization packet.
    length: usize,
    /// Payload bytes collected so far (never more than `length`).
    data: T,
    /// Sequence number the next continuation packet must carry.
    next_sequence: u8,
}

impl<T: AsRef<[u8]> + Default + Extend<u8>> PartialMessage<T> {
    /// Continues assembling a [`Message`][] using the given continuation packet.
    ///
    /// Returns [`DefragmentedMessage::Complete`][] once the announced payload length has been
    /// reached, and [`DefragmentedMessage::Partial`][] if more packets are needed.  Payload bytes
    /// beyond the announced length are ignored.
    ///
    /// # Errors
    ///
    /// - [`DefragmentationError::InvalidChannel`][] if the packet's channel differs from the
    ///   channel of the initialization packet.
    /// - [`DefragmentationError::InvalidSequence`][] if the packet's sequence number is not the
    ///   expected next one, or lies outside the range `0..=0x7f` allowed by the specification.
    pub fn extend<S: AsRef<[u8]>>(
        mut self,
        packet: &ContinuationPacket<S>,
    ) -> Result<DefragmentedMessage<T>, DefragmentationError> {
        // Highest sequence number a continuation packet may carry (CTAP § 11.2.4).  Enforcing
        // it also keeps `next_sequence` from overflowing below.
        const MAX_SEQUENCE: u8 = 0x7f;

        if self.channel != packet.channel {
            return Err(DefragmentationError::InvalidChannel {
                expected: self.channel,
                actual: packet.channel,
            });
        }
        if packet.sequence > MAX_SEQUENCE || self.next_sequence != packet.sequence {
            return Err(DefragmentationError::InvalidSequence {
                expected: self.next_sequence,
                actual: packet.sequence,
            });
        }
        self.extend_data(packet.data.as_ref());
        // Cannot overflow: `next_sequence == packet.sequence <= MAX_SEQUENCE` here.
        self.next_sequence += 1;
        Ok(self.into())
    }

    /// Tries to continue assembling a [`Message`][] using the given packet.
    ///
    /// The packet must be a continuation packet.  This is a shorthand for matching the packet and
    /// calling [`PartialMessage::extend`][].
    ///
    /// # Errors
    ///
    /// Returns [`DefragmentationError::InvalidPacketType`][] for an initialization packet, and
    /// otherwise the same errors as [`PartialMessage::extend`][].
    pub fn try_extend<S: AsRef<[u8]>>(
        self,
        packet: &Packet<S>,
    ) -> Result<DefragmentedMessage<T>, DefragmentationError> {
        if let Packet::Continuation(packet) = packet {
            self.extend(packet)
        } else {
            Err(DefragmentationError::InvalidPacketType {
                expected: PacketType::Continuation,
                actual: packet.packet_type(),
            })
        }
    }

    /// Appends as much of `data` as is still missing from the announced length, so `self.data`
    /// never grows beyond `self.length`.
    fn extend_data(&mut self, data: &[u8]) {
        // TODO: use something like extend_from_slice
        let n = cmp::min(data.len(), self.length - self.data.as_ref().len());
        self.data.extend(data[..n].iter().copied());
    }
}

impl<T, S> From<InitializationPacket<S>> for PartialMessage<T>
where
    T: AsRef<[u8]> + Default + Extend<u8>,
    S: AsRef<[u8]>,
{
    /// Starts reassembly from an initialization packet.
    ///
    /// It records the channel, the command and the announced length, copies the packet's
    /// payload (capped at the announced length), and expects continuation sequence number 0
    /// next.
    fn from(packet: InitializationPacket<S>) -> Self {
        let announced_len = usize::from(packet.length);
        // TODO: initialize data with capacity
        let mut partial = Self {
            channel: packet.channel,
            command: packet.command,
            length: announced_len,
            data: T::default(),
            next_sequence: 0,
        };
        let first_chunk = packet.data.as_ref();
        partial.extend_data(first_chunk);
        partial
    }
}

#[cfg(test)]
mod test {
    use quickcheck::Arbitrary;

    use super::{DefragmentedMessage, Message};

    impl Arbitrary for Message<Vec<u8>> {
        /// Generates a message with arbitrary channel, command and payload (generated in that
        /// order).
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            let channel = Arbitrary::arbitrary(g);
            let command = Arbitrary::arbitrary(g);
            let data = Arbitrary::arbitrary(g);
            Self {
                channel,
                command,
                data,
            }
        }

        /// Shrinks only the payload and keeps channel and command fixed.
        fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
            let (channel, command) = (self.channel, self.command);
            let smaller = self.data.shrink().map(move |data| Self {
                channel,
                command,
                data,
            });
            Box::new(smaller)
        }
    }

    quickcheck::quickcheck! {
        /// Round trip: fragmenting into 64-byte packets and reassembling yields the original
        /// message.
        fn fragments(message: Message<Vec<u8>>) -> bool {
            use std::convert::TryInto;

            let mut packets = message.fragments(64).unwrap();
            let first = packets.next().unwrap();
            let mut state: DefragmentedMessage<Vec<u8>> =
                Message::from_fragments(first.try_into().unwrap());
            for packet in packets {
                state = match state {
                    DefragmentedMessage::Partial(partial) => {
                        partial.extend(&packet.try_into().unwrap()).unwrap()
                    },
                    // A complete message must not be followed by further packets.
                    DefragmentedMessage::Complete(_) => unreachable!(),
                };
            }

            match state {
                DefragmentedMessage::Complete(assembled) => assembled == message,
                DefragmentedMessage::Partial(_) => false,
            }
        }
    }
}