use byteorder::{ByteOrder, ReadBytesExt, WriteBytesExt, LE};
use hidapi::{HidDevice, HidError};
use log::{info, trace};
use num_enum::TryFromPrimitive;
use std::convert::TryFrom;
use std::io::{Read, Write};
use std::thread::sleep;
use std::time::Duration;
use thiserror::Error;
/// Bytes of transfer header following the report ID in an upload/download feature report:
/// request code (1) + block number (2) + payload length (2).
const XFER_HEADER_SIZE: usize = 1 + 2 + 2;
/// Maximum number of firmware payload bytes carried by a single upload/download report.
const XFER_DATA_SIZE: usize = 1017;
pub fn download(device: &HidDevice, file: &mut impl Read) -> Result<(), Error> {
let mut report = vec![];
let mut block_num = 0u16;
let mut prev_delay = Duration::from_millis(0);
loop {
report.clear();
report.resize(1 + XFER_HEADER_SIZE, 0u8);
let data_size = file.take(XFER_DATA_SIZE as _).read_to_end(&mut report)?;
let mut cursor = std::io::Cursor::new(&mut report);
cursor.write_u8(DfuReportId::UploadDownload as _).unwrap();
cursor.write_u8(DfuRequest::DFU_DNLOAD as _).unwrap();
cursor.write_u16::<LE>(block_num).unwrap();
cursor.write_u16::<LE>(data_size as u16).unwrap();
assert!(cursor.position() == (1 + XFER_HEADER_SIZE) as _);
device
.send_feature_report(&report)
.map_err(|e| Error::DeviceIoError {
source: e,
action: "sending firmware data chunk",
})?;
if data_size == 0 {
info!(
"Waiting {:?}, as requested by device, for firmware to manifest",
prev_delay
);
sleep(prev_delay);
}
let status = DfuStatusResult::read_from_device(device)?;
status.ensure_ok()?;
prev_delay = Duration::from_millis(status.poll_timeout as _);
trace!(
"Successfully downloaded block {:#06x} ({} bytes)",
block_num,
data_size
);
if data_size == 0 {
status.ensure_state(DfuState::dfuIDLE)?;
break;
} else {
status.ensure_state(DfuState::dfuDNLOAD_IDLE)?;
}
block_num = match block_num.checked_add(1) {
Some(i) => i,
None => return Err(ProtocolError::FileTooLarge.into()),
};
}
Ok(())
}
/// Reads the firmware image from `device` and writes it to `file`.
///
/// Each iteration fetches one feature report with the `UploadDownload` report ID. After the
/// report ID comes a [`XFER_HEADER_SIZE`]-byte header, then up to [`XFER_DATA_SIZE`] payload
/// bytes. The header is assumed to use the same layout that [`download`] writes: request code
/// (u8), block number (u16 LE), payload length (u16 LE).
///
/// The device status is read after every report and must be `OK`. A block whose length differs
/// from [`XFER_DATA_SIZE`] marks the end of the image; the device must then be in `dfuIDLE`.
/// Otherwise it must be in `dfuUPLOAD_IDLE` and the loop continues.
///
/// # Errors
///
/// * [`Error::DeviceIoError`] if a HID transfer fails.
/// * [`Error::FileIoError`] if writing to `file` fails.
/// * [`Error::ProtocolError`] in three cases:
///   * the report is shorter than its header plus the advertised payload
///     ([`ProtocolError::ReportTooShort`]);
///   * the device reports an error status;
///   * the device reports an unexpected state.
pub fn upload(device: &HidDevice, file: &mut impl Write) -> Result<(), Error> {
    let mut report = [0u8; 1 + XFER_HEADER_SIZE + XFER_DATA_SIZE];
    // Offset of the first payload byte: report ID + transfer header.
    let data_start = 1 + XFER_HEADER_SIZE;

    loop {
        report.fill(0u8);
        report[0] = DfuReportId::UploadDownload as u8;

        let report_size = map_gfr(
            device.get_feature_report(&mut report),
            data_start,
            "reading firmware data chunk",
        )?;

        let status = DfuStatusResult::read_from_device(device)?;
        status.ensure_ok()?;

        // Report layout: [0] report ID, [1] request, [2..4] block number, [4..6] payload length.
        // The length is the last header field, directly before the payload.
        let data_size = LE::read_u16(&report[4..6]) as usize;
        if report_size < data_start + data_size {
            return Err(ProtocolError::ReportTooShort {
                expected: data_start + data_size,
                actual: report_size,
            }
            .into());
        }

        trace!("Successfully uploaded block ({} bytes)", data_size);
        file.write_all(&report[data_start..data_start + data_size])?;

        if data_size != XFER_DATA_SIZE {
            // A short (or empty) block terminates the upload.
            status.ensure_state(DfuState::dfuIDLE)?;
            break;
        }
        status.ensure_state(DfuState::dfuUPLOAD_IDLE)?;
    }

    Ok(())
}
/// Identification strings that [`read_info_field`] can query from a device.
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum InfoField {
    /// Device model, requested with the two-byte code `pl`.
    DeviceModel,
    /// Serial number, requested with the two-byte code `sn`.
    SerialNumber,
    /// Currently installed firmware version, requested with the two-byte code `vr`.
    CurrentFirmware,
}
pub fn read_info_field(device: &HidDevice, field: InfoField) -> Result<String, Error> {
const INFO_REPORT_ID: u8 = 2;
const INFO_REPORT_LEN: usize = 126;
use InfoField::*;
let mut request_report = [0u8; 1 + 2 + 1];
request_report[0] = INFO_REPORT_ID;
request_report[1..3].copy_from_slice(match field {
DeviceModel => b"pl",
SerialNumber => b"sn",
CurrentFirmware => b"vr",
});
device
.send_feature_report(&request_report)
.map_err(|e| Error::DeviceIoError {
source: e,
action: "requesting info field",
})?;
let mut response_report = [0u8; 1 + INFO_REPORT_LEN];
response_report[0] = INFO_REPORT_ID;
map_gfr(
device.get_feature_report(&mut response_report),
1,
"reading info field",
)?;
let result = response_report[1..].split(|&x| x == 0).next().unwrap();
Ok(std::str::from_utf8(result)
.map_err(|e| Error::ProtocolError(e.into()))?
.to_owned())
}
pub fn enter_dfu(device: &HidDevice) -> Result<(), Error> {
const ENTER_DFU_REPORT_ID: u8 = 1;
device
.send_feature_report(&[ENTER_DFU_REPORT_ID, 0xb0, 0x07]) .map_err(|e| Error::DeviceIoError {
source: e,
action: "entering DFU mode",
})
}
pub fn leave_dfu(device: &HidDevice) -> Result<(), Error> {
device
.send_feature_report(&[DfuReportId::StateCmd as u8, DfuRequest::BOSE_EXIT_DFU as u8])
.map_err(|e| Error::DeviceIoError {
source: e,
action: "leaving DFU mode",
})
}
/// Brings a device that is in DFU mode back to `dfuIDLE`, if it is in a state this function
/// knows how to recover from.
///
/// If the device is already idle this returns immediately. Recovery depends on the state:
///
/// * In `dfuDNLOAD_SYNC`, `dfuDNLOAD_IDLE`, `dfuMANIFEST_SYNC` or `dfuUPLOAD_IDLE`, the device
///   is sent `DFU_ABORT`.
/// * In `dfuERROR`, it is sent `DFU_CLRSTATUS`.
///
/// After either request the status is read again and must be `OK` with state `dfuIDLE`.
///
/// # Errors
///
/// * [`ProtocolError::BadInitialState`] for any other starting state.
/// * [`Error::DeviceIoError`] if a HID transfer fails.
/// * [`Error::ProtocolError`] if the device does not end up `OK` and idle.
pub fn ensure_idle(device: &HidDevice) -> Result<(), Error> {
    use DfuState::*;

    let initial = DfuStatusResult::read_from_device(device)?;

    // Choose the recovery request (and the action label used if sending it fails).
    let (request, action) = match initial.state {
        dfuIDLE => return Ok(()),
        dfuDNLOAD_SYNC | dfuDNLOAD_IDLE | dfuMANIFEST_SYNC | dfuUPLOAD_IDLE => {
            info!(
                "Device not idle, state = {:?}; sending DFU_ABORT",
                initial.state
            );
            (DfuRequest::DFU_ABORT, "sending DFU_ABORT")
        }
        dfuERROR => {
            info!(
                "Device in error state, status = {:?} ({}); sending DFU_CLRSTATUS",
                initial.status,
                initial.status.error_str()
            );
            (DfuRequest::DFU_CLRSTATUS, "sending DFU_CLRSTATUS")
        }
        other => return Err(ProtocolError::BadInitialState(other).into()),
    };

    device
        .send_feature_report(&[DfuReportId::StateCmd as u8, request as u8])
        .map_err(|source| Error::DeviceIoError { source, action })?;

    let recovered = DfuStatusResult::read_from_device(device)?;
    recovered.ensure_ok()?;
    recovered.ensure_state(dfuIDLE)?;
    Ok(())
}
/// HID feature report IDs used for DFU traffic.
#[repr(u8)]
enum DfuReportId {
    /// Firmware data chunks, both when downloading to and uploading from the device.
    UploadDownload = 0x01,
    /// Six-byte DFU status block (status, poll timeout, state, one more byte).
    GetStatus = 0x02,
    /// State-changing requests (abort, clear status, exit DFU) and the current-state query.
    StateCmd = 0x03,
}
/// `bStatus` codes a device can report in its DFU status block.
///
/// Variant names follow the DFU specification. A human-readable description of each one is
/// available from [`DfuStatus::error_str`].
#[repr(u8)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, TryFromPrimitive)]
#[allow(non_camel_case_types)] // keep the spec's spelling of the status names
pub enum DfuStatus {
    OK = 0x00,
    errTARGET = 0x01,
    errFILE = 0x02,
    errWRITE = 0x03,
    errERASE = 0x04,
    errCHECK_ERASED = 0x05,
    errPROG = 0x06,
    errVERIFY = 0x07,
    errADDRESS = 0x08,
    errNOTDONE = 0x09,
    errFIRMWARE = 0x0a,
    errVENDOR = 0x0b,
    errUSBR = 0x0c,
    errPOR = 0x0d,
    errUNKNOWN = 0x0e,
    errSTALLEDPKT = 0x0f,
}
impl DfuStatus {
pub fn error_str(&self) -> &'static str {
use DfuStatus::*;
match self {
OK => "No error condition is present.",
errTARGET => "File is not targeted for use by this device.",
errFILE => "File is for this device but fails some vendor-specific verification test.",
errWRITE => "Device is unable to write memory.",
errERASE => "Memory erase function failed.",
errCHECK_ERASED => "Memory erase check failed.",
errPROG => "Program memory function failed.",
errVERIFY => "Programmed memory failed verification.",
errADDRESS => "Cannot program memory due to received address that is out of range.",
errNOTDONE => "Received DFU_DNLOAD with wLength = 0, but device does not think it has all of the data yet.",
errFIRMWARE => "Device's firmware is corrupt. It cannot return to run-time (non-DFU) operations.",
errVENDOR => "iString indicates a vendor-specific error.",
errUSBR => "Device detected unexpected USB reset signaling.",
errPOR => "Device detected unexpected power on reset.",
errUNKNOWN => "Something went wrong, but the device does not know what it was.",
errSTALLEDPKT => "Device stalled an unexpected request.",
}
}
}
/// `bState` values of the DFU state machine, as reported by the device.
///
/// Variant names and numeric values follow the DFU specification.
#[repr(u8)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, TryFromPrimitive)]
#[allow(non_camel_case_types)] // keep the spec's spelling of the state names
pub enum DfuState {
    appIDLE = 0,
    appDETACH = 1,
    dfuIDLE = 2,
    dfuDNLOAD_SYNC = 3,
    dfuDNBUSY = 4,
    dfuDNLOAD_IDLE = 5,
    dfuMANIFEST_SYNC = 6,
    dfuMANIFEST = 7,
    dfuMANIFEST_WAIT_RESET = 8,
    dfuUPLOAD_IDLE = 9,
    dfuERROR = 10,
}
impl DfuState {
    /// Reads the device's current DFU state through the `StateCmd` feature report.
    ///
    /// Marked `#[allow(dead_code)]`, so it may currently have no callers.
    ///
    /// # Errors
    ///
    /// * [`Error::DeviceIoError`] if the HID transfer fails.
    /// * [`ProtocolError::ReportTooShort`] if fewer than two bytes come back.
    /// * [`ProtocolError::UnknownState`] if the state byte is not a known [`DfuState`].
    #[allow(dead_code)]
    fn read_from_device(device: &HidDevice) -> Result<Self, Error> {
        let mut report = [DfuReportId::StateCmd as u8, 0u8];
        let expected_len = report.len();
        map_gfr(
            device.get_feature_report(&mut report),
            expected_len,
            "querying state",
        )?;
        Self::try_from(report[1]).map_err(|e| ProtocolError::UnknownState(e.number).into())
    }

    /// Succeeds when `self` equals `expected`; otherwise reports both states in
    /// [`ProtocolError::UnexpectedState`].
    fn ensure(self, expected: Self) -> Result<(), ProtocolError> {
        if self == expected {
            return Ok(());
        }
        Err(ProtocolError::UnexpectedState {
            expected,
            actual: self,
        })
    }
}
/// Request codes sent in the byte that follows the report ID.
///
/// Values 0 through 6 are the standard DFU class requests; `BOSE_EXIT_DFU` is vendor-specific.
#[repr(u8)]
#[allow(non_camel_case_types)] // keep the spec's spelling of the request names
#[allow(dead_code)] // some requests are never sent by this module
enum DfuRequest {
    DFU_DETACH = 0,
    DFU_DNLOAD = 1,
    DFU_UPLOAD = 2,
    DFU_GETSTATUS = 3,
    DFU_CLRSTATUS = 4,
    DFU_GETSTATE = 5,
    DFU_ABORT = 6,
    BOSE_EXIT_DFU = 0xff,
}
/// Decoded contents of the device's DFU status report.
#[derive(Clone, Copy, Debug)]
struct DfuStatusResult {
    /// Status code reported by the device.
    pub status: DfuStatus,
    /// State the device reports being in.
    pub state: DfuState,
    /// Poll timeout requested by the device; callers interpret it as milliseconds.
    pub poll_timeout: u32,
}
impl DfuStatusResult {
fn read_from_device(device: &HidDevice) -> Result<Self, Error> {
let mut report = [0u8; 1 + 6]; report[0] = DfuReportId::GetStatus as u8;
map_gfr(
device.get_feature_report(&mut report),
report.len(),
"querying status",
)?;
let mut cursor = std::io::Cursor::new(report);
cursor.set_position(1);
let status = DfuStatus::try_from(cursor.read_u8().unwrap())
.map_err(|e| ProtocolError::UnknownState(e.number))?;
let poll_timeout = cursor.read_u24::<LE>().unwrap();
let state = DfuState::try_from(cursor.read_u8().unwrap())
.map_err(|e| ProtocolError::UnknownStatus(e.number))?;
Ok(Self {
status,
poll_timeout,
state,
})
}
fn ensure_ok(&self) -> Result<(), ProtocolError> {
if self.status != DfuStatus::OK {
Err(ProtocolError::ErrorStatus(self.status))
} else {
Ok(())
}
}
fn ensure_state(&self, expected: DfuState) -> Result<(), ProtocolError> {
self.state.ensure(expected)
}
}
fn map_gfr(
r: Result<usize, HidError>,
min_size: usize,
action: &'static str,
) -> Result<usize, Error> {
match r {
Err(e) => Err(Error::DeviceIoError { source: e, action }),
Ok(s) if s < min_size => Err(ProtocolError::ReportTooShort {
expected: min_size,
actual: s,
}
.into()),
Ok(s) => Ok(s),
}
}
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum Error {
#[error("DFU protocol error")]
ProtocolError(#[from] ProtocolError),
#[error("USB transaction error while {action}")]
DeviceIoError {
source: HidError,
action: &'static str,
},
#[error("file I/O error")]
FileIoError(#[from] std::io::Error),
}
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ProtocolError {
#[error("device reported state ({0}) that is not in the DFU spec")]
UnknownState(u8),
#[error("device reported status ({0}) that is not in the DFU spec")]
UnknownStatus(u8),
#[error("device reported an error: {0:?} ({})", .0.error_str())]
ErrorStatus(DfuStatus),
#[error("device entered unexpected state: expected {expected:?}, got {actual:?}")]
UnexpectedState {
expected: DfuState,
actual: DfuState,
},
#[error("don't know how to safely leave initial state {0:?}; please re-enter DFU mode")]
BadInitialState(DfuState),
#[error("file too large: overflowed 16-bit block number while sending")]
FileTooLarge,
#[error("device returned invalid UTF-8 string")]
InvalidString(#[from] std::str::Utf8Error),
#[error("feature report from device was {actual} bytes, expected at least {expected}")]
ReportTooShort { expected: usize, actual: usize },
}