use super::Error as BootloaderError;
use crate::bootloader::{command, property};
use core::convert::{TryFrom, TryInto};
use hidapi::{HidDevice, HidResult};
/// HID transport for the bootloader protocol.
///
/// Owns an already-opened [`HidDevice`]; every command, data report and response is
/// exchanged through it.
pub struct Protocol {
    /// Handle that all outgoing reports are written to and incoming reports read from.
    device: HidDevice,
}
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("receiver aborted data phase")]
AbortDataPhase,
#[error("expected data response packet")]
ExpectedDataPacket,
#[error("expected (non-data) response packet")]
ExpectedResponsePacket,
#[error("error from underlying hidapi")]
HidApi(#[from] hidapi::HidError),
#[error("invalid HID report ID ({0})")]
InvalidReportId(u8),
#[error("unknown response tag ({0})")]
UnknownResponseTag(u8),
#[error("unspecified protocol error")]
Unspecified,
}
pub type Result<T> = std::result::Result<T, Error>;
/// A decoded `Response` report received from the bootloader.
#[derive(Debug)]
pub struct ResponsePacket {
    /// Response tag decoded from the first payload byte.
    pub tag: command::ResponseTag,
    /// Bit 0 of the response's flags byte (presumably: a data phase follows when set).
    pub has_data: bool,
    /// Status decoded from the first response parameter; `None` when that code is 0.
    pub status: Option<BootloaderError>,
    /// The response parameters that follow the status word.
    pub parameters: Vec<u32>,
}

/// One HID report received from the device, classified by its report ID.
#[derive(Debug)]
pub enum ReceivedPacket {
    /// A `Response` report, decoded into its fields.
    Response(ResponsePacket),
    /// The raw payload of a `ResponseData` report, with the HID header stripped.
    Data(Vec<u8>),
}
impl TryFrom<ReceivedPacket> for ResponsePacket {
type Error = Error;
fn try_from(packet: ReceivedPacket) -> Result<Self> {
if let ReceivedPacket::Response(packet) = packet {
Ok(packet)
} else {
Err(Error::ExpectedResponsePacket)
}
}
}
impl TryFrom<ReceivedPacket> for Vec<u8> {
type Error = Error;
fn try_from(packet: ReceivedPacket) -> Result<Self> {
if let ReceivedPacket::Data(data) = packet {
Ok(data)
} else {
Err(Error::ExpectedDataPacket)
}
}
}
/// Timeout passed to hidapi's `read_timeout` when reading a packet (hidapi interprets it in
/// milliseconds).
pub const READ_TIMEOUT: i32 = 2_000;
impl Protocol {
    /// Maximum payload bytes carried by one outgoing `CommandData` HID report.
    const DATA_CHUNK_LEN: usize = 32;

    /// Length of a HID report header: report ID, a byte this module writes as 0, and a
    /// little-endian `u16` payload length.
    const REPORT_HEADER_LEN: usize = 4;

    /// Number of bytes expected in the data phase of a `ReadKeystore` command.
    const KEYSTORE_LEN: usize = 3 * 512;

    /// Queries a bootloader property with a `GetProperty` command.
    ///
    /// Returns the response parameters that follow the status word.
    ///
    /// # Panics
    ///
    /// Panics if the HID transaction fails, or whenever [`Protocol::call`] panics
    /// (for example when the device reports a non-zero status).
    pub fn property(
        &self,
        property: property::Property,
    ) -> core::result::Result<Vec<u32>, crate::bootloader::Error> {
        let response = self
            .call(&command::Command::GetProperty(property))
            .expect("GetProperty command failed");
        match response {
            command::Response::GetProperty(values) => Ok(values),
            // `call_progress` only returns `Ok` for a `GetProperty` command from its
            // `GetProperty` branch, so no other response variant can reach this point.
            _ => unreachable!("GetProperty command produced a non-GetProperty response"),
        }
    }

    /// Sends `command` and waits for its complete response, without progress reporting.
    pub fn call(&self, command: &command::Command) -> Result<command::Response> {
        self.call_progress(command, None)
    }

    /// Sends `command`, runs whatever data phase it requires, and decodes the final response.
    ///
    /// `progress`, if given, is invoked during a `ReceiveSbFile` data phase with the
    /// cumulative number of payload bytes written to the device so far.
    ///
    /// # Errors
    ///
    /// Returns an error if HID I/O fails or a received packet is malformed or of the
    /// wrong kind (see [`Protocol::read_packet`]).
    ///
    /// # Panics
    ///
    /// Panics if the device reports a non-zero status, if a response does not have the
    /// expected shape, or if the command / data-phase combination is not implemented yet.
    pub fn call_progress<'a>(
        &self,
        command: &command::Command,
        progress: Option<&'a dyn Fn(usize)>,
    ) -> Result<command::Response> {
        let command_packet = command.hid_packet();
        self.write(command_packet.as_slice())?;
        trace!("--> {}", hex_str!(&command_packet));
        let initial_response = self.read_packet()?;
        // Match on the borrowed command: cloning it (and any payload it carries) is not needed.
        match (command, command.data_phase()) {
            (_, command::DataPhase::None) => {
                let packet = ResponsePacket::try_from(initial_response)?;
                assert!(!packet.has_data);
                if let Some(status) = &packet.status {
                    panic!("{:?}", status);
                }
                use command::Command::*;
                match command {
                    Reset
                    | EraseFlash { .. }
                    | EraseFlashAll
                    | ConfigureMemory { .. }
                    | Keystore(command::KeystoreOperation::Enroll)
                    | Keystore(command::KeystoreOperation::GenerateKey { .. })
                    | Keystore(command::KeystoreOperation::WriteNonVolatile)
                    | Keystore(command::KeystoreOperation::ReadNonVolatile) => {
                        assert_eq!(packet.tag, command::ResponseTag::Generic);
                        assert_eq!(packet.parameters.len(), 1);
                        // The single parameter must match the command header's first two bytes.
                        assert_eq!(
                            packet.parameters[0].to_le_bytes()[..2],
                            command.header()[..2]
                        );
                        Ok(command::Response::Generic)
                    }
                    GetProperty(_) => {
                        assert_eq!(packet.tag, command::ResponseTag::GetProperty);
                        assert!(!packet.parameters.is_empty());
                        Ok(command::Response::GetProperty(packet.parameters))
                    }
                    _ => todo!(),
                }
            }
            (_, command::DataPhase::CommandData(data)) => {
                let packet = ResponsePacket::try_from(initial_response)?;
                assert!(packet.status.is_none());
                match command {
                    command::Command::Keystore(command::KeystoreOperation::SetKey { .. })
                    | command::Command::WriteMemory { .. }
                    | command::Command::WriteMemoryWords { .. } => {
                        self.write_data_phase(&data, None)?;
                        let packet = ResponsePacket::try_from(self.read_packet()?)?;
                        Self::check_data_phase_response(&packet, command);
                        Ok(command::Response::Generic)
                    }
                    command::Command::ReceiveSbFile { data } => {
                        self.write_data_phase(data, progress)?;
                        let response = match self.read_packet() {
                            // After an abort the final response arrives as a separate packet.
                            Err(Error::AbortDataPhase) => {
                                debug!("aborting");
                                self.read_packet()?
                            }
                            other => other?,
                        };
                        let packet = ResponsePacket::try_from(response)?;
                        Self::check_data_phase_response(&packet, command);
                        Ok(command::Response::Generic)
                    }
                    _ => todo!(),
                }
            }
            (command::Command::Keystore(command::KeystoreOperation::ReadKeystore), _) => {
                let _packet = ResponsePacket::try_from(initial_response)?;
                let data = self.read_data_phase(Self::KEYSTORE_LEN)?;
                let packet = ResponsePacket::try_from(self.read_packet()?)?;
                assert_eq!(packet.parameters[0].to_le_bytes()[0], command.header()[0]);
                debug!("read {} in total", data.len());
                Ok(command::Response::Data(data))
            }
            (command::Command::ReadMemory { length, .. }, _) => {
                let length = *length;
                let packet = ResponsePacket::try_from(initial_response)?;
                assert!(packet.has_data);
                assert!(packet.status.is_none());
                assert_eq!(packet.tag, command::ResponseTag::ReadMemory);
                assert_eq!(packet.parameters.len(), 1);
                assert_eq!(packet.parameters[0] as usize, length);
                let data = self.read_data_phase(length)?;
                let packet = ResponsePacket::try_from(self.read_packet()?)?;
                assert!(!packet.has_data);
                assert!(packet.status.is_none());
                assert_eq!(packet.tag, command::ResponseTag::Generic);
                assert_eq!(packet.parameters.len(), 1);
                assert_eq!(
                    packet.parameters[0].to_le_bytes()[..2],
                    command.header()[..2]
                );
                Ok(command::Response::ReadMemory(data))
            }
            _ => todo!(),
        }
    }

    /// Streams `data` to the device as `CommandData` reports.
    ///
    /// Each report carries at most [`Self::DATA_CHUNK_LEN`] payload bytes, and the last one
    /// is zero-padded.
    ///
    /// After each report is written, `progress` (if any) receives the cumulative number of
    /// payload bytes sent so far, so the final call reports `data.len()`.
    fn write_data_phase(&self, data: &[u8], progress: Option<&dyn Fn(usize)>) -> Result<()> {
        let report_len = Self::REPORT_HEADER_LEN + Self::DATA_CHUNK_LEN;
        let mut sent = 0;
        for chunk in data.chunks(Self::DATA_CHUNK_LEN) {
            let mut data_packet = Vec::with_capacity(report_len);
            // `chunk.len()` is at most DATA_CHUNK_LEN (32), so the cast cannot truncate.
            data_packet.extend_from_slice(&[
                command::ReportId::CommandData as u8,
                0,
                chunk.len() as u8,
                0,
            ]);
            data_packet.extend_from_slice(chunk);
            data_packet.resize(report_len, 0);
            trace!("--> {}", hex_str!(&data_packet, 4));
            self.write(data_packet.as_slice())?;
            sent += chunk.len();
            if let Some(progress) = progress {
                progress(sent);
            }
        }
        Ok(())
    }

    /// Reads `ResponseData` packets until exactly `length` payload bytes have arrived.
    ///
    /// # Errors
    ///
    /// Propagates [`Protocol::read_packet`] errors. Returns [`Error::ExpectedDataPacket`] if a
    /// non-data packet arrives first.
    ///
    /// # Panics
    ///
    /// Panics if the device sends more than `length` bytes in total.
    fn read_data_phase(&self, length: usize) -> Result<Vec<u8>> {
        let mut data = Vec::with_capacity(length);
        while data.len() < length {
            let partial_data: Vec<u8> = self.read_packet()?.try_into()?;
            assert!(data.len() + partial_data.len() <= length);
            data.extend_from_slice(&partial_data);
        }
        Ok(data)
    }

    /// Asserts that `packet` is the successful generic response closing `cmd`'s data phase.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the packet announces data;
    /// - it carries a non-zero status;
    /// - it is not a `Generic` response with exactly one parameter;
    /// - that parameter's low byte differs from the first byte of `cmd`'s header.
    fn check_data_phase_response(packet: &ResponsePacket, cmd: &command::Command) {
        assert!(!packet.has_data);
        if let Some(status) = &packet.status {
            panic!("unexpected status {:?}", status);
        }
        assert_eq!(packet.tag, command::ResponseTag::Generic);
        assert_eq!(packet.parameters.len(), 1);
        assert_eq!(packet.parameters[0].to_le_bytes()[0], cmd.header()[0]);
    }

    /// Reads one HID report from the device (waiting up to [`READ_TIMEOUT`]) and decodes it.
    ///
    /// # Errors
    ///
    /// - [`Error::HidApi`] if the hidapi read fails.
    /// - [`Error::Unspecified`] in either of these cases:
    ///   - fewer than the 4 header bytes arrive (hidapi reports 0 bytes when the timeout
    ///     elapses);
    ///   - a response payload is malformed.
    /// - [`Error::InvalidReportId`] if the report ID is unknown or is not a device-to-host
    ///   report.
    /// - [`Error::UnknownResponseTag`] or [`Error::AbortDataPhase`] while decoding a response.
    pub fn read_packet(&self) -> Result<ReceivedPacket> {
        let mut data = vec![0u8; 256];
        let read = self.device.read_timeout(&mut data, READ_TIMEOUT)?;
        data.truncate(read);
        if data.len() < Self::REPORT_HEADER_LEN {
            return Err(Error::Unspecified);
        }
        let report_byte = data[0];
        let report_id = command::ReportId::try_from(report_byte).map_err(Error::InvalidReportId)?;
        let expected_packet_len = usize::from(u16::from_le_bytes([data[2], data[3]]));
        // Normalise the buffer to the payload length announced in the header.
        data.resize(Self::REPORT_HEADER_LEN + expected_packet_len, 0);
        trace!("<-- {} ({}B)", hex_str!(&data, 4), data.len());
        let payload = data.split_off(Self::REPORT_HEADER_LEN);
        match report_id {
            command::ReportId::Response => {
                Self::parse_response(&payload).map(ReceivedPacket::Response)
            }
            command::ReportId::ResponseData => Ok(ReceivedPacket::Data(payload)),
            // Other IDs (e.g. the host-to-device `CommandData` report) are not expected here.
            _ => Err(Error::InvalidReportId(report_byte)),
        }
    }

    /// Decodes the payload of a `Response` report into a [`ResponsePacket`].
    ///
    /// The payload is laid out as follows:
    /// - byte 0: the response tag;
    /// - bit 0 of byte 1: the has-data flag;
    /// - byte 3: the parameter count;
    /// - from byte 4: that many little-endian `u32` parameters, the first being the status
    ///   code (0 = no error).
    fn parse_response(payload: &[u8]) -> Result<ResponsePacket> {
        // An empty `Response` payload signals that the receiver aborted the data phase.
        if payload.is_empty() {
            return Err(Error::AbortDataPhase);
        }
        let tag = command::ResponseTag::try_from(payload[0]).map_err(Error::UnknownResponseTag)?;
        if payload.len() < 4 {
            return Err(Error::Unspecified);
        }
        let has_data = (payload[1] & 1) != 0;
        let expected_param_count = usize::from(payload[3]);
        let raw_parameters = &payload[4..];
        // Reject truncated words and count mismatches instead of panicking on device data.
        if raw_parameters.len() != 4 * expected_param_count {
            return Err(Error::Unspecified);
        }
        let mut parameters = raw_parameters
            .chunks_exact(4)
            .map(|word| u32::from_le_bytes([word[0], word[1], word[2], word[3]]));
        let status_code = parameters.next().ok_or(Error::Unspecified)?;
        let status = match status_code {
            0 => None,
            code => Some(BootloaderError::from(code)),
        };
        Ok(ResponsePacket {
            tag,
            has_data,
            status,
            parameters: parameters.collect(),
        })
    }

    /// Writes one raw HID report, failing if hidapi sends fewer bytes than requested.
    pub fn write(&self, data: &[u8]) -> Result<()> {
        let sent = self.device.write(data)?;
        let all = data.len();
        if sent < all {
            return Err(hidapi::HidError::IncompleteSendError { sent, all }.into());
        }
        Ok(())
    }

    /// Reads one raw HID report (up to 256 bytes), waiting at most `timeout` milliseconds.
    ///
    /// Timeouts above `i32::MAX` are clamped to `i32::MAX`.
    pub fn read_timeout(&self, timeout: usize) -> HidResult<Vec<u8>> {
        // hidapi takes an `i32` where a negative value means "block indefinitely". A plain
        // `as` cast could wrap a large `usize` into that range, so clamp instead.
        let timeout = i32::try_from(timeout).unwrap_or(i32::MAX);
        let mut data = vec![0u8; 256];
        let read = self.device.read_timeout(&mut data, timeout)?;
        data.truncate(read);
        Ok(data)
    }
}
impl Protocol {
pub fn new(device: HidDevice) -> Self {
Self { device }
}
}
impl std::fmt::Debug for Protocol {
    /// Shows the manufacturer, product and serial-number strings that hidapi reports for
    /// the wrapped device.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let device = &self.device;
        let manufacturer = device.get_manufacturer_string();
        let product = device.get_product_string();
        let serial_number = device.get_serial_number_string();
        f.debug_struct("Device")
            .field("manufacturer", &manufacturer)
            .field("product", &product)
            .field("serial number", &serial_number)
            .finish()
    }
}