use color_print::cstr;
use indicatif::{ProgressBar, ProgressStyle};
use log::{info, trace};
use packets::{
Packet, PacketParse,
command::{CmdResponse, CommandHeader, CommandPacket},
data_phase::DataPhasePacket,
};
use protocols::Protocol;
use tags::{
ToAddress,
command::{CommandTag, CommandToParams, KeyProvOperation, TrustProvOperation},
command_flag::CommandFlag,
command_response::CmdResponseTag,
property::{PropertyTag, PropertyTagDiscriminants},
status::StatusCode,
};
use crate::CommunicationError;
mod formatters;
pub mod memory;
pub mod packets;
pub mod protocols;
pub mod tags;
/// Result of [`McuBoot::get_property`].
#[derive(Clone, Debug)]
pub struct GetPropertyResponse {
/// Status code taken from the command response.
pub status: StatusCode,
/// Raw words carried by the `GetProperty` response tag.
pub response_words: Box<[u32]>,
/// Property decoded from `response_words` for the requested tag.
pub property: PropertyTag,
}
/// Result of [`McuBoot::read_memory`] and [`McuBoot::fuse_read`].
#[derive(Clone, Debug)]
pub struct ReadMemoryResponse {
/// Status code taken from the command response.
pub status: StatusCode,
/// Both producers in this file fill this with a single word: the length of `bytes`.
pub response_words: Box<[u32]>,
/// Bytes carried by the `ReadMemory` response tag.
pub bytes: Box<[u8]>,
}
/// Result of [`McuBoot::key_provisioning`].
#[derive(Clone, Debug)]
pub enum KeyProvisioningResponse {
/// Returned for every operation other than `KeyProvOperation::ReadKeyStore`:
/// only the response status is reported.
Status(StatusCode),
/// Returned for `KeyProvOperation::ReadKeyStore`.
KeyStore {
/// Status code taken from the command response.
status: StatusCode,
/// Words carried by the `KeyProvisioning` response tag.
response_words: Box<[u32]>,
/// Data-phase bytes attached to the response tag; empty when the tag carried none.
bytes: Box<[u8]>,
},
}
/// Extension trait for collapsing an arbitrary failure into
/// [`CommunicationError::InvalidData`].
trait InvalidData<T> {
/// Converts `self` into a `Result` whose error type is [`CommunicationError`].
fn or_invalid(self) -> Result<T, CommunicationError>;
}
impl<T, E> InvalidData<T> for Result<T, E> {
    /// Passes `Ok` values through unchanged and replaces any `Err`, whatever its
    /// payload, with [`CommunicationError::InvalidData`].
    fn or_invalid(self) -> Result<T, CommunicationError> {
        match self {
            Ok(value) => Ok(value),
            Err(_) => Err(CommunicationError::InvalidData),
        }
    }
}
/// Command layer that drives a device through a [`Protocol`] implementation.
pub struct McuBoot<T>
where
T: Protocol,
{
// Transport used to write and read packets.
device: T,
/// When `true`, data-phase transfers create and advance a progress bar.
pub progress_bar: bool,
/// When `true`, `read_command` returns the initial response as-is and does not
/// read an incoming data phase, regardless of the response flag.
pub mask_read_data_phase: bool,
}
/// Result whose error type is [`CommunicationError`].
pub type ResultComm<T> = Result<T, CommunicationError>;
/// Result of a command that reports only a [`StatusCode`].
pub type ResultStatus = ResultComm<StatusCode>;
impl<T> McuBoot<T>
where
T: Protocol,
{
#[must_use]
pub fn new(device: T) -> Self {
info!(
"Initialized MCU Boot with device identifier: {}",
device.get_identifier()
);
McuBoot {
device,
progress_bar: false,
mask_read_data_phase: false,
}
}
/// Sends a `GetProperty` command for `tag` / `memory_index` and decodes the reply.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response
/// (a non-success status is reported as an error by `read_cmd_response`).
/// Returns [`CommunicationError::InvalidPacketReceived`] when the response tag
/// is not `GetProperty`.
pub fn get_property(
    &mut self,
    tag: PropertyTagDiscriminants,
    memory_index: u32,
) -> ResultComm<GetPropertyResponse> {
    let request = CommandPacket::new_none_flag(CommandTag::GetProperty { tag, memory_index });
    self.send_command(&request)?;
    let reply = self.read_cmd_response()?;
    match reply.tag {
        CmdResponseTag::GetProperty(words) => {
            // Decode before the raw words are moved into the result.
            let property = PropertyTag::from_code(tag, &words);
            Ok(GetPropertyResponse {
                status: reply.status,
                response_words: words,
                property,
            })
        }
        _ => Err(CommunicationError::InvalidPacketReceived),
    }
}
/// Sends a `SetProperty` command assigning `value` to `tag` and returns the
/// status of the response.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status reported by the device.
pub fn set_property(&mut self, tag: PropertyTagDiscriminants, value: u32) -> ResultStatus {
    let request = CommandPacket::new_none_flag(CommandTag::SetProperty { tag, value });
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends a `Reset` command and returns the status of the response.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status reported by the device.
pub fn reset(&mut self) -> ResultStatus {
    let request = CommandPacket::new_none_flag(CommandTag::Reset);
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends a `Call` command with `start_address` and `argument` and returns the
/// status of the response.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status reported by the device.
pub fn call(&mut self, start_address: u32, argument: u32) -> ResultStatus {
    let tag = CommandTag::Call {
        start_address,
        argument,
    };
    let request = CommandPacket::new_none_flag(tag);
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends an `Execute` command with `start_address`, `argument` and
/// `stackpointer`, and returns the status of the response.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status reported by the device.
pub fn execute(&mut self, start_address: u32, argument: u32, stackpointer: u32) -> ResultStatus {
    let tag = CommandTag::Execute {
        start_address,
        argument,
        stackpointer,
    };
    let request = CommandPacket::new_none_flag(tag);
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends a `FillMemory` command with `start_address`, `byte_count` and
/// `pattern`, and returns the status of the response.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status reported by the device.
pub fn fill_memory(&mut self, start_address: u32, byte_count: u32, pattern: u32) -> ResultStatus {
    let tag = CommandTag::FillMemory {
        start_address,
        byte_count,
        pattern,
    };
    let request = CommandPacket::new_none_flag(tag);
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends a `WriteMemory` command built as a data-phase packet carrying `bytes`
/// for `start_address` in `memory_id`, and returns the status of the final
/// response.
///
/// # Errors
///
/// Propagates any failure from sending the command or its data, or from reading
/// the response, including a non-success status reported by the device.
pub fn write_memory(&mut self, start_address: u32, memory_id: u32, bytes: &[u8]) -> ResultStatus {
    let tag = CommandTag::WriteMemory {
        start_address,
        memory_id,
        bytes,
    };
    let request = CommandPacket::new_data_phase(tag);
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends a `FlashEraseAll` command for `memory_id` and returns the status of the
/// response.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status reported by the device.
pub fn flash_erase_all(&mut self, memory_id: u32) -> ResultStatus {
    let request = CommandPacket::new_none_flag(CommandTag::FlashEraseAll { memory_id });
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends a `FlashEraseRegion` command with `start_address`, `byte_count` and
/// `memory_id`, and returns the status of the response.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status reported by the device.
pub fn flash_erase_region(&mut self, start_address: u32, byte_count: u32, memory_id: u32) -> ResultStatus {
    let tag = CommandTag::FlashEraseRegion {
        start_address,
        byte_count,
        memory_id,
    };
    let request = CommandPacket::new_none_flag(tag);
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends a `FlashEraseAllUnsecure` command and returns the status of the
/// response.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status reported by the device.
pub fn flash_erase_all_unsecure(&mut self) -> ResultStatus {
    let request = CommandPacket::new_none_flag(CommandTag::FlashEraseAllUnsecure);
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends a `ReadMemory` command and returns the bytes delivered with the reply.
///
/// Unlike most commands, the response is read through `read_command` directly
/// so that the "memory blank page read disallowed" status can be accepted in
/// addition to success.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response.
/// A status that is neither success nor the blank-page status is converted into
/// a [`CommunicationError`]. Returns
/// [`CommunicationError::InvalidPacketReceived`] when the response tag is not
/// `ReadMemory`.
pub fn read_memory(
    &mut self,
    start_address: u32,
    byte_count: u32,
    memory_id: u32,
) -> ResultComm<ReadMemoryResponse> {
    let tag = CommandTag::ReadMemory {
        start_address,
        byte_count,
        memory_id,
    };
    let request = CommandPacket::new_none_flag(tag);
    self.send_command(&request)?;
    let reply = self.read_command()?;

    let acceptable = reply.status.is_success() || reply.status.is_memory_blank_page_read_disallowed();
    if !acceptable {
        return Err(reply.status.into());
    }

    match reply.tag {
        CmdResponseTag::ReadMemory(bytes) => Ok(ReadMemoryResponse {
            status: reply.status,
            // The single response word reports how many bytes were received.
            response_words: Box::new([bytes.len() as u32]),
            bytes,
        }),
        _ => Err(CommunicationError::InvalidPacketReceived),
    }
}
/// Sends a `ConfigureMemory` command with `memory_id` and `address` and returns
/// the status of the response.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status reported by the device.
pub fn configure_memory(&mut self, memory_id: u32, address: u32) -> ResultStatus {
    let request = CommandPacket::new_none_flag(CommandTag::ConfigureMemory { memory_id, address });
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends a `ReceiveSBFile` command with `bytes` as its data phase and returns
/// the status of the final response.
///
/// A [`CommunicationError::Aborted`] raised while sending is not treated as
/// fatal: the response is still read and its outcome is what gets returned.
///
/// # Errors
///
/// Returns any other sending error unchanged, and propagates failures from
/// reading the response, including a non-success status.
pub fn receive_sb_file(&mut self, bytes: &[u8]) -> ResultStatus {
    let request = CommandPacket::new_data_phase(CommandTag::ReceiveSBFile { bytes });
    if let Err(err) = self.send_command(&request) {
        if !matches!(err, CommunicationError::Aborted) {
            return Err(err);
        }
    }
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends a `TrustProvisioning` command for `operation` and returns the response
/// status together with the words carried by the response tag.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status. Returns
/// [`CommunicationError::InvalidPacketReceived`] when the response tag is not
/// `TrustProvisioning`.
pub fn trust_provisioning(&mut self, operation: &TrustProvOperation) -> ResultComm<(StatusCode, Box<[u32]>)> {
    let request = CommandPacket::new_none_flag(CommandTag::TrustProvisioning(operation));
    self.send_command(&request)?;
    let reply = self.read_cmd_response()?;
    if let CmdResponseTag::TrustProvisioning(words) = reply.tag {
        Ok((reply.status, words))
    } else {
        Err(CommunicationError::InvalidPacketReceived)
    }
}
/// Sends a `KeyProvisioning` command for `operation`.
///
/// * `KeyProvOperation::ReadKeyStore` returns
///   [`KeyProvisioningResponse::KeyStore`] with the response words and any
///   data-phase bytes attached to the response (empty when there were none).
/// * Every other operation returns [`KeyProvisioningResponse::Status`]. While
///   such a command is being sent, `mask_read_data_phase` is switched on so
///   responses read during sending are not followed by a data-phase read; it
///   is switched off again before the final response is read.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status. Returns
/// [`CommunicationError::InvalidPacketReceived`] when a `ReadKeyStore` reply
/// does not carry a `KeyProvisioning` response tag.
pub fn key_provisioning(
    &mut self,
    operation: &KeyProvOperation,
) -> Result<KeyProvisioningResponse, CommunicationError> {
    let command = CommandPacket::new_none_flag(CommandTag::KeyProvisioning(operation));
    if let KeyProvOperation::ReadKeyStore { .. } = operation {
        self.send_command(&command)?;
        let response = self.read_cmd_response()?;
        match response.tag {
            CmdResponseTag::KeyProvisioning(data, data_phase) => Ok(KeyProvisioningResponse::KeyStore {
                status: response.status,
                response_words: data,
                bytes: data_phase.unwrap_or_default(),
            }),
            _ => Err(CommunicationError::InvalidPacketReceived),
        }
    } else {
        self.mask_read_data_phase = true;
        let sent = self.send_command(&command);
        // Clear the mask before propagating a send failure; returning early with
        // it still set would make every later command skip its data phase.
        self.mask_read_data_phase = false;
        sent?;
        let response = self.read_cmd_response()?;
        Ok(KeyProvisioningResponse::Status(response.status))
    }
}
/// Sends a `FlashReadOnce` command with `index` and `count` and returns the
/// value carried by the response tag.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response,
/// including a non-success status. Returns
/// [`CommunicationError::InvalidPacketReceived`] when the response tag is not
/// `FlashReadOnce`.
pub fn flash_read_once(&mut self, index: u32, count: u32) -> ResultComm<u32> {
    let request = CommandPacket::new_none_flag(CommandTag::FlashReadOnce { index, count });
    self.send_command(&request)?;
    let reply = self.read_cmd_response()?;
    if let CmdResponseTag::FlashReadOnce(value) = reply.tag {
        Ok(value)
    } else {
        Err(CommunicationError::InvalidPacketReceived)
    }
}
/// Sends a `FlashProgramOnce` command with `index`, `count` and `data`.
///
/// When `verify` is set and the command succeeded, the value is read back with
/// [`Self::flash_read_once`] using only the low 24 bits of `index`. If any bit
/// set in `data` is missing from the value read back,
/// `StatusCode::OtpVerifyFail` is returned instead of the command status.
///
/// # Errors
///
/// Propagates any failure from sending the command, reading the response, or
/// the verification read.
pub fn flash_program_once(&mut self, index: u32, count: u32, data: u32, verify: bool) -> ResultStatus {
    let request = CommandPacket::new_none_flag(CommandTag::FlashProgramOnce { index, count, data });
    self.send_command(&request)?;
    let reply = self.read_cmd_response()?;

    if !(verify && reply.status.is_success()) {
        return Ok(reply.status);
    }

    // Same mask as `(1 << 24) - 1`: the upper byte of `index` is dropped.
    let read_back = self.flash_read_once(index & 0x00FF_FFFF, count)?;
    let all_bits_present = (read_back & data) == data;
    Ok(if all_bits_present {
        reply.status
    } else {
        StatusCode::OtpVerifyFail
    })
}
/// Sends a `FuseRead` command and returns the bytes delivered with the reply.
///
/// # Errors
///
/// Propagates any failure from sending the command or reading the response; a
/// non-success status is converted into a [`CommunicationError`]. Returns
/// [`CommunicationError::InvalidPacketReceived`] when the response tag is not
/// `ReadMemory`.
pub fn fuse_read(&mut self, start_address: u32, byte_count: u32, memory_id: u32) -> ResultComm<ReadMemoryResponse> {
    let tag = CommandTag::FuseRead {
        start_address,
        byte_count,
        memory_id,
    };
    let request = CommandPacket::new_none_flag(tag);
    self.send_command(&request)?;
    let reply = self.read_cmd_response()?;

    if !reply.status.is_success() {
        return Err(reply.status.into());
    }

    if let CmdResponseTag::ReadMemory(bytes) = reply.tag {
        Ok(ReadMemoryResponse {
            status: reply.status,
            // The single response word reports how many bytes were received.
            response_words: Box::new([bytes.len() as u32]),
            bytes,
        })
    } else {
        Err(CommunicationError::InvalidPacketReceived)
    }
}
/// Sends a `FuseProgram` command built as a data-phase packet carrying `bytes`
/// for `start_address` in `memory_id`, and returns the status of the final
/// response.
///
/// # Errors
///
/// Propagates any failure from sending the command or its data, or from reading
/// the response, including a non-success status reported by the device.
pub fn fuse_program(&mut self, start_address: u32, memory_id: u32, bytes: &[u8]) -> ResultStatus {
    let tag = CommandTag::FuseProgram {
        start_address,
        memory_id,
        bytes,
    };
    let request = CommandPacket::new_data_phase(tag);
    self.send_command(&request)?;
    self.read_cmd_response().map(|reply| reply.status)
}
/// Sends `bytes` as a data phase under `CommandTag::NoCommand`.
///
/// No response is read afterwards; `StatusCode::Success` is returned as soon as
/// sending completes without error.
///
/// # Errors
///
/// Propagates any failure from `send_command`.
pub fn load_image(&mut self, bytes: &[u8]) -> ResultStatus {
    let request = CommandPacket::new_data_phase(CommandTag::NoCommand { bytes });
    self.send_command(&request).map(|()| StatusCode::Success)
}
/// Reads one command response, logs it, and rejects non-success statuses.
///
/// # Errors
///
/// Propagates any failure from `read_command`. A response whose status is not
/// a success is converted into a [`CommunicationError`].
fn read_cmd_response(&mut self) -> ResultComm<CmdResponse> {
    let response = self.read_command()?;
    info!("{}: {response:02X?}", cstr!("<bold>Received"));
    if !response.status.is_success() {
        return Err(response.status.into());
    }
    Ok(response)
}
/// Writes `command` to the device, followed by its data phase when it has one.
///
/// Without a data phase, only the command frame is written. With a data phase:
///
/// 1. the device's `MaxPacketSize` property is queried to choose the chunk size;
/// 2. unless the tag is `CommandTag::NoCommand`, the command frame is written
///    and its initial response is read and checked;
/// 3. the data is written in chunks of at most `MaxPacketSize` bytes.
///
/// The final response that follows a data phase is NOT read here; callers do
/// that themselves.
///
/// # Errors
///
/// Propagates any protocol or parsing failure. Returns
/// [`CommunicationError::InvalidData`] when the property response is not a
/// `MaxPacketSize`, or when the reported size is zero.
///
/// # Panics
///
/// Panics if the reported packet size does not fit into `usize`.
fn send_command(&mut self, command: &CommandPacket) -> ResultComm<()> {
    let tag = &command.tag;
    let (params, data_phase) = tag.to_params();
    let packet = command.header.construct_frame(&params, tag.code());
    info!("{}: {command:02X?}", cstr!("<bold>Sending"));
    if let Some(data) = data_phase {
        info!("Sending data phase: {data:02X?}");
        let max_packet_size: u32 = {
            let response = self.get_property(PropertyTagDiscriminants::MaxPacketSize, 0)?;
            match response.property {
                PropertyTag::MaxPacketSize(size) => size,
                _ => return Err(CommunicationError::InvalidData),
            }
        };
        let chunk_size: usize = max_packet_size
            .try_into()
            .expect("pointer size of this platform is too small");
        // `chunks(0)` panics, so reject a zero size before anything is written.
        if chunk_size == 0 {
            return Err(CommunicationError::InvalidData);
        }
        if !matches!(tag, CommandTag::NoCommand { .. }) {
            self.device.write_packet_raw(&packet)?;
            self.read_cmd_response()?;
        }
        let progress_bar = self.create_progress_bar(data.len() as u64, "Sending data");
        for bytes in data.chunks(chunk_size) {
            self.device.write_packet_concrete(DataPhasePacket::parse(bytes)?)?;
            if let Some(bar) = progress_bar.as_ref() {
                // Advance by what was actually sent; the last chunk may be shorter.
                bar.inc(bytes.len() as u64);
            }
        }
    } else {
        self.device.write_packet_raw(&packet)?;
    }
    Ok(())
}
/// Reads one command response from the device, including its data phase when
/// the response announces one.
///
/// The first eight bytes of the raw packet are used as follows: byte 0 is the
/// response code handed to `CmdResponseTag::from_code`, byte 1 the command
/// flag, byte 2 the reserved byte, byte 3 a count compared against the
/// parameter length, and bytes 4..8 the little-endian status. Everything after
/// that is passed on as the parameter slice.
///
/// When `mask_read_data_phase` is set, the response is returned as-is and no
/// data phase is read, whatever the flag says. Otherwise a `HasDataPhase`
/// response is followed by data packets until the announced length is reached
/// (or the transfer is aborted), and then by a final response whose status
/// replaces the initial one.
///
/// The status is returned without being checked for success.
///
/// # Errors
///
/// Propagates protocol failures. Returns [`CommunicationError::InvalidData`]
/// for packets that are too short, have a parameter length that is not a
/// multiple of four, carry an unknown flag, or cannot be decoded into a
/// response tag. An unknown status value is reported by `parse_status`.
fn read_command(&mut self) -> ResultComm<CmdResponse> {
    trace!("Starting to read command");
    let data = self.device.read_packet_raw(CmdResponse::get_code())?;
    // The fixed-position bytes below are indexed directly, so a truncated packet
    // must be rejected here rather than panic on a slice out of range.
    if data.len() < 8 {
        return Err(CommunicationError::InvalidData);
    }
    let params_slice = &data[8..];
    // NOTE(review): with `&&` the second comparison can never change the result
    // (a length that is not a multiple of 4 can never equal `4 * data[3]`), so
    // only the alignment is enforced. Confirm the intended rule before tightening.
    if params_slice.len() % 4 != 0 && params_slice.len() != 4 * data[3] as usize {
        return Err(CommunicationError::InvalidData);
    }
    let header = CommandHeader {
        flag: CommandFlag::try_from(data[1]).or(Err(CommunicationError::InvalidData))?,
        reserved: data[2],
    };
    let status = parse_status(data[4..8].try_into().or_invalid()?)?;
    if self.mask_read_data_phase {
        return Ok(CmdResponse {
            header,
            status,
            tag: CmdResponseTag::from_code(data[0], params_slice, None).ok_or(CommunicationError::InvalidData)?,
        });
    }
    match header.flag {
        CommandFlag::NoData => Ok(CmdResponse {
            header,
            status,
            tag: CmdResponseTag::from_code(data[0], params_slice, None).ok_or(CommunicationError::InvalidData)?,
        }),
        CommandFlag::HasDataPhase => {
            // The first parameter holds the data-phase length; a response without
            // parameters is malformed and must not panic.
            let length_bytes = params_slice.get(0..4).ok_or(CommunicationError::InvalidData)?;
            let length = u32::from_le_bytes(length_bytes.try_into().or_invalid()?);
            trace!("Data phase length: {length}");
            let mut data_phase = Vec::new();
            {
                let progress_bar = self.create_progress_bar(length.into(), "Receiving data");
                while data_phase.len() != length as usize {
                    trace!("Reading data phase packet");
                    data_phase.extend(match self.device.read_packet_concrete::<DataPhasePacket>() {
                        Ok(data) => {
                            if let Some(bar) = progress_bar.as_ref() {
                                bar.inc(data.data.len() as u64);
                            }
                            data.data
                        }
                        // An abort ends the data phase early; the final response
                        // is still read below.
                        Err(CommunicationError::Aborted) => break,
                        Err(err) => return Err(err),
                    });
                }
            }
            trace!("Reading final response");
            let final_response = self.device.read_packet_raw(CmdResponse::get_code())?;
            let final_status_bytes = final_response.get(4..8).ok_or(CommunicationError::InvalidData)?;
            let status = parse_status(final_status_bytes.try_into().or_invalid()?)?;
            Ok(CmdResponse {
                header: CommandHeader {
                    flag: CommandFlag::NoData,
                    reserved: data[2],
                },
                status,
                tag: CmdResponseTag::from_code(data[0], params_slice, Some(&data_phase))
                    .ok_or(CommunicationError::InvalidData)?,
            })
        }
    }
}
/// Builds a byte-counting progress bar of total length `len`, labelled with
/// `prefix`, or returns `None` when `progress_bar` is disabled.
fn create_progress_bar(&self, len: u64, prefix: &'static str) -> Option<ProgressBar> {
    if !self.progress_bar {
        return None;
    }
    let style = ProgressStyle::with_template("{prefix} [{bar:40}] {binary_bytes:>}/{binary_total_bytes}")
        .unwrap()
        .progress_chars("##-");
    let bar = ProgressBar::new(len);
    bar.set_style(style);
    bar.set_prefix(prefix);
    Some(bar)
}
}
/// Decodes a little-endian status word into a [`StatusCode`].
///
/// # Errors
///
/// Returns [`CommunicationError::UnexpectedStatus`] carrying
/// `StatusCode::UnknownStatusCode` and the raw value when the value does not
/// convert into a `StatusCode`.
fn parse_status(data: [u8; 4]) -> ResultComm<StatusCode> {
    let discriminant = u32::from_le_bytes(data);
    match StatusCode::try_from(discriminant) {
        Ok(status) => Ok(status),
        Err(_) => Err(CommunicationError::UnexpectedStatus(
            StatusCode::UnknownStatusCode,
            discriminant,
        )),
    }
}
// Hardware-in-the-loop tests; they are `#[ignore]`d because they need a board
// attached to the serial port named in `DEVICE`.
#[cfg(test)]
mod tests {
use crate::mboot::{
McuBoot,
protocols::{ProtocolOpen, uart::UARTProtocol},
tags::property::{PropertyTag, PropertyTagDiscriminants},
};
// Serial port the test board is expected on.
const DEVICE: &str = "COM3";
// Opens the UART device and wraps it in a `McuBoot`; panics if the port cannot be opened.
fn get_boot() -> McuBoot<UARTProtocol> {
McuBoot::new(UARTProtocol::open(DEVICE).unwrap())
}
// Queries the `CurrentVersion` property and checks it against the version
// expected on the test board (K3.1.1).
#[test]
#[ignore = "Requires hardware connection to board"]
fn test_board_get_version() {
let mut boot = get_boot();
let version = boot.get_property(PropertyTagDiscriminants::CurrentVersion, 0).unwrap();
if let PropertyTag::CurrentVersion(ver) = version.property {
assert_eq!(ver.mark, 'K');
assert_eq!(ver.major, 3);
assert_eq!(ver.minor, 1);
assert_eq!(ver.fixation, 1);
} else {
// Any other property variant means the response was decoded incorrectly.
panic!()
}
}
}