ctaphid 0.2.0

Rust implementation of the CTAPHID protocol
Documentation
// Copyright (C) 2021 Robin Krahl <robin.krahl@ireas.org>
// SPDX-License-Identifier: Apache-2.0 or MIT

use std::{cell, convert::TryFrom, time::Duration};

use ctaphid_types::{
    Channel, Command, DefragmentedMessage, DeviceError, InitializationPacket, Message, Packet,
};

use crate::error::{RequestError, ResponseError};
use crate::hid::Device;

/// Reusable scratch buffer for exchanging single CTAPHID packets with a HID device.
///
/// The buffer is one byte longer than the packet size. Byte 0 is set to `0` before every send
/// (NOTE(review): presumably the HID report ID, `0` meaning "no report ID" — confirm against the
/// `hid::Device` implementations); the bytes after it carry the serialized packet.
struct PacketBuffer(Vec<u8>);

impl PacketBuffer {
    /// Allocates a zeroed buffer for packets of `packet_size` bytes plus the leading prefix byte.
    pub fn new(packet_size: usize) -> Self {
        let len_with_prefix = packet_size + 1;
        Self(vec![0u8; len_with_prefix])
    }

    /// Returns the packet size this buffer was created for (its length minus the prefix byte).
    pub fn packet_size(&self) -> usize {
        let total_len = self.0.len();
        total_len - 1
    }

    /// Serializes `packet` after a zero prefix byte, zero-fills the unused tail and sends the
    /// whole buffer to `device`.
    ///
    /// # Errors
    ///
    /// Returns `RequestError::PacketSerializationFailed` if `packet` cannot be serialized into
    /// the buffer, or the converted error returned by `device.send`.
    fn send_packet<T: AsRef<[u8]>, D: Device>(
        &mut self,
        device: &D,
        packet: &Packet<T>,
    ) -> Result<(), RequestError> {
        let (prefix, body) = self.0.split_at_mut(1);
        prefix[0] = 0;
        let written = packet
            .serialize(&mut *body)
            .map_err(RequestError::PacketSerializationFailed)?;
        // Wipe whatever a previous, longer packet may have left behind.
        body[written..].fill(0);
        device.send(&self.0)?;
        Ok(())
    }

    /// Reads one packet from `device` into the buffer and parses it.
    ///
    /// Returns `Ok(None)` when the parsed packet belongs to a channel other than `channel`.
    /// `timeout` is forwarded unchanged to this single `device.receive` call.
    ///
    /// # Errors
    ///
    /// Propagates the receive error from `device`, or returns
    /// `ResponseError::PacketParsingFailed` if the received bytes are not a valid packet.
    fn receive_packet<D: Device>(
        &mut self,
        device: &D,
        channel: Channel,
        timeout: Option<Duration>,
    ) -> Result<Option<Packet<&[u8]>>, ResponseError> {
        let raw = device.receive(&mut self.0, timeout)?;
        let parsed = Packet::try_from(raw).map_err(ResponseError::PacketParsingFailed)?;
        // Looping here until a packet for `channel` shows up would be nicer, but returning a
        // borrow of `self` out of such a loop is rejected by the borrow checker, so callers
        // have to do the looping themselves.
        Ok(Some(parsed).filter(|packet| packet.channel() == channel))
    }
}

/// Sends and receives complete CTAPHID messages over a HID device through one reusable packet
/// buffer.
///
/// The buffer lives in a `RefCell` so that the methods can take `&self`. Sending a message and
/// receiving one message each hold the mutable borrow until they finish, so these methods must
/// not be re-entered while one is running (the `RefCell` would panic).
pub struct MessageBuffer(cell::RefCell<PacketBuffer>);

impl MessageBuffer {
    /// Creates a message buffer that fragments and reassembles packets of `packet_size` bytes.
    pub fn new(packet_size: usize) -> Self {
        let packets = PacketBuffer::new(packet_size);
        Self(cell::RefCell::new(packets))
    }

    /// Splits `message` into packets of the buffer's packet size and sends them in order.
    ///
    /// # Errors
    ///
    /// Returns `RequestError::MessageFragmentationFailed` if the message cannot be fragmented,
    /// otherwise the first error produced while sending a packet.
    pub fn send_message<T: AsRef<[u8]>, D: Device>(
        &self,
        device: &D,
        message: Message<T>,
    ) -> Result<(), RequestError> {
        let mut packets = self.0.borrow_mut();
        let fragment_size = packets.packet_size();
        let fragments = match message.fragments(fragment_size) {
            Ok(fragments) => fragments,
            Err(err) => return Err(RequestError::MessageFragmentationFailed(err)),
        };
        for fragment in fragments {
            packets.send_packet(device, &fragment)?;
        }
        Ok(())
    }

    /// Waits for the response to `command` on `channel`.
    ///
    /// KEEPALIVE messages are logged at debug level and skipped while waiting for a `Cbor` or
    /// `Message` response. Each skipped KEEPALIVE starts a fresh receive with the same `timeout`.
    ///
    /// # Errors
    ///
    /// - `ResponseError::CommandFailed` if the device answers with an ERROR message carrying an
    ///   error code, or `ResponseError::MissingErrorCode` if that message is empty.
    /// - `ResponseError::UnexpectedKeepAlive` if a KEEPALIVE arrives for any other command.
    /// - `ResponseError::UnexpectedCommand` if the response carries a different command.
    /// - Any error from receiving or reassembling the message.
    pub fn receive_message<D: Device>(
        &self,
        device: &D,
        channel: Channel,
        command: Command,
        timeout: Option<Duration>,
    ) -> Result<Message<Vec<u8>>, ResponseError> {
        let keepalive_allowed = matches!(command, Command::Cbor | Command::Message);
        loop {
            let message = self.receive_message_raw(device, channel, timeout)?;
            match message.command {
                Command::Error => {
                    return Err(match message.data.first() {
                        Some(&code) => ResponseError::CommandFailed(DeviceError::from(code)),
                        None => ResponseError::MissingErrorCode,
                    });
                }
                Command::KeepAlive if keepalive_allowed => {
                    log::debug!("Received KEEPALIVE message");
                }
                Command::KeepAlive => return Err(ResponseError::UnexpectedKeepAlive(command)),
                actual if actual != command => {
                    return Err(ResponseError::UnexpectedCommand {
                        expected: command,
                        actual,
                    });
                }
                _ => return Ok(message),
            }
        }
    }

    /// Receives the packets of one message on `channel` and reassembles them, without
    /// interpreting the message's command.
    ///
    /// Packets addressed to other channels are skipped. `timeout` is applied to every single
    /// packet read. NOTE(review): traffic on other channels therefore restarts the wait and can
    /// stretch the total time well beyond `timeout` — confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns `ResponseError::PacketParsingFailed` if the first packet on the channel is not an
    /// initialization packet, `ResponseError::MessageDefragmentationFailed` if a later packet does
    /// not continue the message, and propagates receive errors.
    fn receive_message_raw<D: Device>(
        &self,
        device: &D,
        channel: Channel,
        timeout: Option<Duration>,
    ) -> Result<Message<Vec<u8>>, ResponseError> {
        let mut packets = self.0.borrow_mut();

        // The received packet borrows from `packets`, so the "skip foreign channels" loop must be
        // written inline; a helper returning that borrow is rejected by the borrow checker.
        let first = loop {
            match packets.receive_packet(device, channel, timeout)? {
                Some(packet) => break packet,
                None => continue,
            }
        };
        let init =
            InitializationPacket::try_from(first).map_err(ResponseError::PacketParsingFailed)?;

        let mut state = Message::from_fragments(init);
        loop {
            let partial = match state {
                DefragmentedMessage::Complete(message) => return Ok(message),
                DefragmentedMessage::Partial(partial) => partial,
            };
            let next = loop {
                match packets.receive_packet(device, channel, timeout)? {
                    Some(packet) => break packet,
                    None => continue,
                }
            };
            state = partial
                .try_extend(&next)
                .map_err(ResponseError::MessageDefragmentationFailed)?;
        }
    }
}