use super::FlashError;
use crate::{
Target,
architecture::{arm, riscv},
core::Architecture,
};
use probe_rs_target::{
CoreType, Endian, FlashProperties, MemoryRegion, PageInfo, RamRegion, RawFlashAlgorithm,
RegionMergeIterator, SectorInfo, TransferEncoding,
};
use std::mem::size_of_val;
/// A flash algorithm whose addresses have been resolved for a concrete RAM layout.
///
/// Instances are produced by `FlashAlgorithm::assemble_from_raw_with_data`, which
/// turns the offsets of a `RawFlashAlgorithm` into absolute addresses.
#[derive(Debug, Default, Clone)]
pub struct FlashAlgorithm {
/// Name of the algorithm, copied from the raw algorithm.
pub name: String,
/// Copied from the raw algorithm's `default` flag.
pub default: bool,
/// Address of the first word of `instructions` (the header precedes the code).
pub load_address: u64,
/// Header words followed by the raw algorithm bytes packed into little-endian
/// `u32` words; a trailing partial word is zero-padded.
pub instructions: Vec<u32>,
/// Absolute address derived from the raw `pc_init` offset, if the raw algorithm has one.
pub pc_init: Option<u64>,
/// Absolute address derived from the raw `pc_uninit` offset, if the raw algorithm has one.
pub pc_uninit: Option<u64>,
/// Absolute address derived from the raw `pc_program_page` offset.
pub pc_program_page: u64,
/// Absolute address derived from the raw `pc_erase_sector` offset.
pub pc_erase_sector: u64,
/// Absolute address derived from the raw `pc_erase_all` offset, if the raw algorithm has one.
pub pc_erase_all: Option<u64>,
/// Absolute address derived from the raw `pc_verify` offset, if the raw algorithm has one.
pub pc_verify: Option<u64>,
/// Absolute address derived from the raw `pc_blank_check` offset, if the raw algorithm has one.
pub pc_blank_check: Option<u64>,
/// Absolute address derived from the raw `pc_read` offset, if the raw algorithm has one.
pub pc_read: Option<u64>,
/// Absolute address derived from the raw `pc_flash_size` offset, if the raw algorithm has one.
pub pc_flash_size: Option<u64>,
/// Start of the code plus the raw algorithm's `data_section_offset`.
pub static_base: u64,
/// Highest address of the stack area (stack bottom plus `stack_size`).
pub stack_top: u64,
/// Size of the stack area in bytes.
pub stack_size: u64,
/// Result of the raw algorithm's `stack_overflow_check()`.
pub stack_overflow_check: bool,
/// Start addresses of the page buffers: one entry, or two when there is enough
/// RAM for a second page-sized buffer.
pub page_buffers: Vec<u64>,
/// Copied from the raw algorithm's `rtt_location`.
pub rtt_control_block: Option<u64>,
/// Copied from the raw algorithm's `rtt_poll_interval`.
// NOTE(review): the unit of this interval is not visible here — confirm against the consumer.
pub rtt_poll_interval: u64,
/// Clone of the raw algorithm's flash properties (address range, sectors, page size, ...).
pub flash_properties: FlashProperties,
/// The raw algorithm's transfer encoding, or the encoding's default when unset.
pub transfer_encoding: TransferEncoding,
}
impl FlashAlgorithm {
/// Looks up the sector that contains the absolute `address`.
///
/// Returns `None` when `address` lies outside the flash address range, or when no
/// sector description starts at or before the address's offset into the flash.
pub fn sector_info(&self, address: u64) -> Option<SectorInfo> {
    let flash = &self.flash_properties;
    let flash_start = flash.address_range.start;

    if !flash.address_range.contains(&address) {
        tracing::trace!("Address {:08x} not contained in this flash device", address);
        return None;
    }

    // Sector descriptions are relative to the start of the flash.
    let offset = address - flash_start;

    // The last description that starts at or before the offset governs it.
    let descriptor = flash
        .sectors
        .iter()
        .rev()
        .find(|candidate| candidate.address <= offset)?;

    // A description covers a run of equally sized sectors; find ours within the run.
    let index_in_run = (offset - descriptor.address) / descriptor.size;

    Some(SectorInfo {
        base_address: flash_start + descriptor.address + index_in_run * descriptor.size,
        size: descriptor.size,
    })
}
/// Returns the page that contains the absolute `address`.
///
/// The page base is `address` rounded down to a multiple of the page size.
/// Returns `None` when `address` lies outside the flash address range.
pub fn page_info(&self, address: u64) -> Option<PageInfo> {
    let flash = &self.flash_properties;

    flash.address_range.contains(&address).then(|| {
        let page_len = flash.page_size as u64;
        PageInfo {
            base_address: address - address % page_len,
            size: flash.page_size,
        }
    })
}
/// Iterates over every sector of the flash, in ascending address order.
///
/// # Panics
///
/// Panics if there are no sector descriptions, or if the first description does
/// not start at offset 0.
pub fn iter_sectors(&self) -> impl Iterator<Item = SectorInfo> + '_ {
    let props = &self.flash_properties;

    assert!(!props.sectors.is_empty());
    assert!(props.sectors[0].address == 0);

    let mut cursor = props.address_range.start;
    let mut active = 0;

    std::iter::from_fn(move || {
        if cursor >= props.address_range.end {
            // Walked past the end of the flash.
            return None;
        }

        // Switch to the next description once the cursor has reached its start.
        let reached_next = props
            .sectors
            .get(active + 1)
            .is_some_and(|next| props.address_range.start + next.address <= cursor);
        if reached_next {
            active += 1;
        }

        let size = props.sectors[active].size;
        let base_address = cursor;
        cursor += size;

        Some(SectorInfo { base_address, size })
    })
}
/// Iterates over every page of the flash, in ascending address order.
///
/// Pages start at the beginning of the flash address range and are spaced one
/// page size apart.
pub fn iter_pages(&self) -> impl Iterator<Item = PageInfo> + '_ {
    let flash = &self.flash_properties;
    let mut next_base = flash.address_range.start;

    std::iter::from_fn(move || {
        (next_base < flash.address_range.end).then(|| {
            let base_address = next_base;
            next_base += flash.page_size as u64;
            PageInfo {
                base_address,
                size: flash.page_size,
            }
        })
    })
}
/// Returns `true` when every byte of `data` equals the flash's erased byte value.
///
/// An empty slice counts as erased.
pub fn is_erased(&self, data: &[u8]) -> bool {
    let erased = self.flash_properties.erased_byte_value;
    data.iter().all(|&byte| byte == erased)
}
/// Stack size in bytes used when the raw algorithm does not specify `stack_size`.
const FLASH_ALGO_STACK_SIZE: u32 = 512;
// The header constants below are the words that `algorithm_header` selects per
// core type and endianness; they are placed in front of the algorithm code.
// NOTE(review): presumably these act as the halt location the algorithm returns to — confirm.
/// Header for RISC-V cores: two `EBREAK` words.
const RISCV_FLASH_BLOB_HEADER: [u32; 2] = [riscv::assembly::EBREAK, riscv::assembly::EBREAK];
/// Header for Armv6-M / Armv7-M / Armv7E-M / Armv8-M cores, little-endian algorithm.
const ARM_FLASH_BLOB_HEADER_BKPT_T32_LE: [u32; 1] = [arm::assembly::BKPT_T32];
/// Byte-swapped variant of the T32 `BKPT` header, used for big-endian algorithms.
const ARM_FLASH_BLOB_HEADER_BKPT_T32_BE: [u32; 1] = [arm::assembly::BKPT_T32.swap_bytes()];
/// Header for Armv7-A cores, little-endian algorithm.
const ARM_FLASH_BLOB_HEADER_BKPT_A32_LE: [u32; 1] = [arm::assembly::BKPT_A32];
/// Byte-swapped variant of the A32 `BKPT` header, used for big-endian algorithms.
const ARM_FLASH_BLOB_HEADER_BKPT_A32_BE: [u32; 1] = [arm::assembly::BKPT_A32.swap_bytes()];
/// Header for Armv8-A cores, little-endian algorithm.
const ARM_FLASH_BLOB_HEADER_HLT_LE: [u32; 1] = [arm::assembly::HLT];
/// Byte-swapped variant of the `HLT` header, used for big-endian algorithms.
const ARM_FLASH_BLOB_HEADER_HLT_BE: [u32; 1] = [arm::assembly::HLT.swap_bytes()];
/// Xtensa cores get no header at all.
const XTENSA_FLASH_BLOB_HEADER: [u32; 0] = [];
pub fn get_max_algorithm_header_size() -> u64 {
let algos = [
Self::algorithm_header(CoreType::Armv6m, Endian::Big),
Self::algorithm_header(CoreType::Armv6m, Endian::Little),
Self::algorithm_header(CoreType::Armv7a, Endian::Big),
Self::algorithm_header(CoreType::Armv7a, Endian::Little),
Self::algorithm_header(CoreType::Armv7m, Endian::Big),
Self::algorithm_header(CoreType::Armv7m, Endian::Little),
Self::algorithm_header(CoreType::Armv7em, Endian::Big),
Self::algorithm_header(CoreType::Armv7em, Endian::Little),
Self::algorithm_header(CoreType::Armv8a, Endian::Big),
Self::algorithm_header(CoreType::Armv8a, Endian::Little),
Self::algorithm_header(CoreType::Armv8a, Endian::Big),
Self::algorithm_header(CoreType::Armv8a, Endian::Little),
Self::algorithm_header(CoreType::Armv8m, Endian::Big),
Self::algorithm_header(CoreType::Armv8m, Endian::Little),
Self::algorithm_header(CoreType::Riscv, Endian::Little),
Self::algorithm_header(CoreType::Xtensa, Endian::Big),
Self::algorithm_header(CoreType::Xtensa, Endian::Little),
];
algos.iter().copied().map(size_of_val).max().unwrap() as u64
}
/// Selects the header words that are placed in front of the algorithm code for
/// the given core type and endianness.
///
/// RISC-V and Xtensa use the same header regardless of `endian`.
fn algorithm_header(core_type: CoreType, endian: Endian) -> &'static [u32] {
    match (core_type, endian) {
        (
            CoreType::Armv6m | CoreType::Armv7m | CoreType::Armv7em | CoreType::Armv8m,
            Endian::Little,
        ) => &Self::ARM_FLASH_BLOB_HEADER_BKPT_T32_LE,
        (
            CoreType::Armv6m | CoreType::Armv7m | CoreType::Armv7em | CoreType::Armv8m,
            Endian::Big,
        ) => &Self::ARM_FLASH_BLOB_HEADER_BKPT_T32_BE,
        (CoreType::Armv7a, Endian::Little) => &Self::ARM_FLASH_BLOB_HEADER_BKPT_A32_LE,
        (CoreType::Armv7a, Endian::Big) => &Self::ARM_FLASH_BLOB_HEADER_BKPT_A32_BE,
        (CoreType::Armv8a, Endian::Little) => &Self::ARM_FLASH_BLOB_HEADER_HLT_LE,
        (CoreType::Armv8a, Endian::Big) => &Self::ARM_FLASH_BLOB_HEADER_HLT_BE,
        (CoreType::Riscv, _) => &Self::RISCV_FLASH_BLOB_HEADER,
        (CoreType::Xtensa, _) => &Self::XTENSA_FLASH_BLOB_HEADER,
    }
}
/// Returns the alignment in bytes applied to the stack area for `architecture`.
fn required_stack_alignment(architecture: Architecture) -> u64 {
    match architecture {
        Architecture::Arm => 8,
        Architecture::Riscv | Architecture::Xtensa => 16,
    }
}
/// Assembles `raw` into a [`FlashAlgorithm`] that keeps code, stack and page
/// buffers in the single RAM region `ram_region`.
///
/// This forwards to `assemble_from_raw_with_data`, passing `ram_region` as both
/// the code region and the data region.
///
/// # Errors
///
/// Returns whatever error `assemble_from_raw_with_data` reports.
pub fn assemble_from_raw(
raw: &RawFlashAlgorithm,
ram_region: &RamRegion,
target: &Target,
) -> Result<Self, FlashError> {
Self::assemble_from_raw_with_data(raw, ram_region, ram_region, target)
}
pub fn assemble_from_raw_with_data(
raw: &RawFlashAlgorithm,
ram_region: &RamRegion,
data_ram_region: &RamRegion,
target: &Target,
) -> Result<Self, FlashError> {
use std::mem::size_of;
let assembled_instructions = raw.instructions.chunks_exact(size_of::<u32>());
let remainder = assembled_instructions.remainder();
let last_elem = if !remainder.is_empty() {
let word = u32::from_le_bytes(
remainder
.iter()
.cloned()
.chain([0u8, 0u8, 0u8])
.take(4)
.collect::<Vec<u8>>()
.try_into()
.unwrap(),
);
Some(word)
} else {
None
};
let header = Self::algorithm_header(
target.default_core().core_type,
if raw.big_endian {
Endian::Big
} else {
Endian::Little
},
);
let instructions: Vec<u32> = header
.iter()
.copied()
.chain(
assembled_instructions.map(|bytes| u32::from_le_bytes(bytes.try_into().unwrap())),
)
.chain(last_elem)
.collect();
let header_size = size_of_val(header) as u64;
let addr_load = match raw.load_address {
Some(address) => {
address
.checked_sub(header_size)
.ok_or(FlashError::InvalidFlashAlgorithmLoadAddress { address })?
}
None => {
ram_region.range.start
}
};
if addr_load < ram_region.range.start {
return Err(FlashError::InvalidFlashAlgorithmLoadAddress { address: addr_load });
}
let code_start = addr_load + header_size;
let code_size_bytes = (instructions.len() * size_of::<u32>()) as u64;
let stack_align = Self::required_stack_alignment(target.architecture());
let code_end = (code_start + code_size_bytes).next_multiple_of(stack_align);
let buffer_page_size = raw.flash_properties.page_size as u64;
let stack_size = raw.stack_size.unwrap_or(Self::FLASH_ALGO_STACK_SIZE) as u64;
tracing::info!("The flash algorithm will be configured with {stack_size} bytes of stack");
let data_load_addr = if let Some(data_load_addr) = raw.data_load_address {
data_load_addr
} else if ram_region == data_ram_region {
code_end
} else {
data_ram_region.range.start
};
if data_ram_region.range.end < data_load_addr {
return Err(FlashError::InvalidDataAddress {
data_load_addr,
data_ram: data_ram_region.range.clone(),
});
}
let mut ram_for_data = data_ram_region.range.end - data_load_addr;
if code_end + stack_size > data_load_addr && ram_region == data_ram_region {
if stack_size > ram_for_data {
return Err(FlashError::InvalidFlashAlgorithmStackSize { size: stack_size });
}
ram_for_data -= stack_size;
}
let double_buffering = if ram_for_data >= 2 * buffer_page_size {
true
} else if ram_for_data >= buffer_page_size {
false
} else {
return Err(FlashError::InvalidFlashAlgorithmStackSize { size: stack_size });
};
let stack_bottom =
if code_end + stack_size <= data_load_addr || ram_region != data_ram_region {
code_end } else {
let page_count = if double_buffering { 2 } else { 1 };
(data_load_addr + page_count * buffer_page_size).next_multiple_of(stack_align)
};
let stack_top = stack_bottom + stack_size;
tracing::info!("Stack top: {:#010x}", stack_top);
if stack_top > ram_region.range.end {
return Err(FlashError::InvalidFlashAlgorithmStackSize { size: stack_size });
}
let page_buffers = if double_buffering {
let second_buffer_start = data_load_addr + buffer_page_size;
vec![data_load_addr, second_buffer_start]
} else {
vec![data_load_addr]
};
tracing::debug!("Page buffers: {:#010x?}", page_buffers);
let name = raw.name.clone();
Ok(FlashAlgorithm {
name,
default: raw.default,
load_address: addr_load,
instructions,
pc_init: raw.pc_init.map(|v| code_start + v),
pc_uninit: raw.pc_uninit.map(|v| code_start + v),
pc_program_page: code_start + raw.pc_program_page,
pc_erase_sector: code_start + raw.pc_erase_sector,
pc_erase_all: raw.pc_erase_all.map(|v| code_start + v),
pc_verify: raw.pc_verify.map(|v| code_start + v),
pc_blank_check: raw.pc_blank_check.map(|v| code_start + v),
pc_read: raw.pc_read.map(|v| code_start + v),
pc_flash_size: raw.pc_flash_size.map(|v| code_start + v),
static_base: code_start + raw.data_section_offset,
stack_top,
stack_size,
page_buffers,
rtt_control_block: raw.rtt_location,
rtt_poll_interval: raw.rtt_poll_interval,
flash_properties: raw.flash_properties.clone(),
transfer_encoding: raw.transfer_encoding.unwrap_or_default(),
stack_overflow_check: raw.stack_overflow_check(),
})
}
/// Assembles `algo` for the core named `core_name`, picking suitable RAM regions
/// from the target's memory map.
///
/// Consecutive RAM regions accessible by the core are merged first. The code goes
/// into the largest region accepted by `is_ram_suitable_for_algo`. The data goes
/// into the region containing the algorithm's explicit data load address, or into
/// the code region when no such address is given.
///
/// # Errors
///
/// Returns `FlashError::NoRamDefined` when no suitable region exists for the code
/// or for the explicit data load address, and otherwise any error reported by
/// `assemble_from_raw_with_data`.
pub(crate) fn assemble_from_raw_with_core(
    algo: &RawFlashAlgorithm,
    core_name: &str,
    target: &Target,
) -> Result<FlashAlgorithm, FlashError> {
    let no_ram_defined = || FlashError::NoRamDefined {
        name: target.name.clone(),
    };

    let candidate_rams = target
        .memory_map
        .iter()
        .filter_map(MemoryRegion::as_ram_region)
        .filter(|region| region.accessible_by(core_name))
        .merge_consecutive();

    let code_ram = candidate_rams
        .clone()
        .filter(|region| is_ram_suitable_for_algo(region, algo.load_address))
        .max_by_key(|region| region.range.end - region.range.start)
        .ok_or_else(no_ram_defined)?;
    tracing::info!("Chosen RAM to run the algo: {:x?}", code_ram);

    // Only look for a dedicated data region when the algorithm pins its data
    // to an explicit address.
    let dedicated_data_ram = match algo.data_load_address {
        Some(address) => Some(
            candidate_rams
                .clone()
                .find(|region| is_ram_suitable_for_data(region, address))
                .ok_or_else(no_ram_defined)?,
        ),
        None => None,
    };
    let data_ram = dedicated_data_ram.as_ref().unwrap_or(&code_ram);
    tracing::info!("Data will be loaded to: {:x?}", data_ram);

    Self::assemble_from_raw_with_data(algo, &code_ram, data_ram, target)
}
}
/// Returns `true` when `ram` can hold the algorithm code: the region must be
/// executable and, if a load address is given, must contain that address.
fn is_ram_suitable_for_algo(ram: &RamRegion, load_address: Option<u64>) -> bool {
    ram.is_executable() && load_address.is_none_or(|address| ram.range.contains(&address))
}
/// Returns `true` when `load_address` lies inside the half-open address range of `ram`.
fn is_ram_suitable_for_data(ram: &RamRegion, load_address: u64) -> bool {
    let bounds = &ram.range;
    bounds.start <= load_address && load_address < bounds.end
}
#[cfg(test)]
mod test {
use probe_rs_target::{FlashProperties, SectorDescription, SectorInfo};
use crate::flashing::FlashAlgorithm;
/// Addresses inside the first sector of a uniformly sectored flash resolve to
/// that sector; the address just below the flash resolves to nothing.
#[test]
fn flash_sector_single_size() {
    let flash_properties = FlashProperties {
        sectors: vec![SectorDescription {
            size: 0x100,
            address: 0x0,
        }],
        address_range: 0x1000..0x1000 + 0x1000,
        page_size: 0x10,
        ..Default::default()
    };
    let config = FlashAlgorithm {
        flash_properties,
        ..Default::default()
    };

    let expected_first = SectorInfo {
        base_address: 0x1000,
        size: 0x100,
    };

    assert!(config.sector_info(0x1000 - 1).is_none());

    for address in [0x1000, 0x10ff, 0x100b, 0x10ea] {
        assert_eq!(Some(expected_first), config.sector_info(address));
    }
}
/// Same as the uniform-sector test, but with a sector size that is not a power of two.
#[test]
fn flash_sector_single_size_weird_sector_size() {
    let flash_properties = FlashProperties {
        sectors: vec![SectorDescription {
            size: 258,
            address: 0x0,
        }],
        address_range: 0x800_0000..0x800_0000 + 258 * 10,
        page_size: 0x10,
        ..Default::default()
    };
    let config = FlashAlgorithm {
        flash_properties,
        ..Default::default()
    };

    let expected_first = SectorInfo {
        base_address: 0x800_0000,
        size: 258,
    };

    assert!(config.sector_info(0x800_0000 - 1).is_none());

    for address in [0x800_0000, 0x800_0000 + 257, 0x800_000b, 0x800_00e0] {
        assert_eq!(Some(expected_first), config.sector_info(address));
    }
}
/// With three runs of differently sized sectors, a sector start address resolves
/// to a sector with that base address and the size of the run it belongs to.
#[test]
fn flash_sector_multiple_sizes() {
    let sectors = vec![
        SectorDescription {
            size: 0x4000,
            address: 0x0,
        },
        SectorDescription {
            size: 0x1_0000,
            address: 0x1_0000,
        },
        SectorDescription {
            size: 0x2_0000,
            address: 0x2_0000,
        },
    ];
    let config = FlashAlgorithm {
        flash_properties: FlashProperties {
            sectors,
            address_range: 0x800_0000..0x800_0000 + 0x10_0000,
            page_size: 0x10,
            ..Default::default()
        },
        ..Default::default()
    };

    // (sector start address, expected sector size)
    let cases = [
        (0x800_4000, 0x4000),
        (0x801_0000, 0x1_0000),
        (0x80A_0000, 0x2_0000),
    ];

    for (base_address, size) in cases {
        let expected = SectorInfo { base_address, size };
        assert_eq!(Some(expected), config.sector_info(base_address));
    }
}
/// `iter_sectors` walks all sectors in order, switching sector size at the start
/// of each run.
#[test]
fn flash_sector_multiple_sizes_iter() {
    let sectors = vec![
        SectorDescription {
            size: 0x4000,
            address: 0x0,
        },
        SectorDescription {
            size: 0x1_0000,
            address: 0x1_0000,
        },
        SectorDescription {
            size: 0x2_0000,
            address: 0x2_0000,
        },
    ];
    let config = FlashAlgorithm {
        flash_properties: FlashProperties {
            sectors,
            address_range: 0x800_0000..0x800_0000 + 0x8_0000,
            page_size: 0x10,
            ..Default::default()
        },
        ..Default::default()
    };

    // (base address, size) of every sector the iterator must yield, in order.
    let expected: Vec<SectorInfo> = [
        (0x800_0000, 0x4000),
        (0x800_4000, 0x4000),
        (0x800_8000, 0x4000),
        (0x800_c000, 0x4000),
        (0x801_0000, 0x1_0000),
        (0x802_0000, 0x2_0000),
        (0x804_0000, 0x2_0000),
        (0x806_0000, 0x2_0000),
    ]
    .into_iter()
    .map(|(base_address, size)| SectorInfo { base_address, size })
    .collect();

    let got: Vec<SectorInfo> = config.iter_sectors().collect();

    assert_eq!(got, expected);
}
}