use probe_rs_target::{FlashProperties, PageInfo, RamRegion, RawFlashAlgorithm, SectorInfo};
use super::FlashError;
use crate::core::Architecture;
use crate::{architecture::riscv, Target};
use std::convert::TryInto;
/// A flash algorithm laid out for execution from target RAM.
///
/// Produced by [`FlashAlgorithm::assemble_from_raw`], which places the instruction
/// blob, the stack and the page buffer(s) inside a RAM region and converts the raw
/// algorithm's entry-point offsets into absolute addresses.
#[derive(Debug, Default, Clone)]
pub struct FlashAlgorithm {
    /// Name of the algorithm, copied from the raw algorithm description.
    pub name: String,
    /// Whether this is the default algorithm, copied from the raw algorithm description.
    pub default: bool,
    /// RAM address from which `instructions` (header included) is laid out.
    pub load_address: u64,
    /// Architecture-specific header words followed by the algorithm's code words.
    pub instructions: Vec<u32>,
    /// Absolute address of the optional init entry point (code start + raw offset).
    pub pc_init: Option<u64>,
    /// Absolute address of the optional uninit entry point (code start + raw offset).
    pub pc_uninit: Option<u64>,
    /// Absolute address of the program-page entry point (code start + raw offset).
    pub pc_program_page: u64,
    /// Absolute address of the erase-sector entry point (code start + raw offset).
    pub pc_erase_sector: u64,
    /// Absolute address of the optional erase-all entry point (code start + raw offset).
    pub pc_erase_all: Option<u64>,
    /// Absolute address of the algorithm's data section (code start + raw data offset).
    pub static_base: u64,
    /// Address just past the reserved stack region.
    /// NOTE(review): presumably used as the initial stack pointer — confirm in the flasher.
    pub begin_stack: u64,
    /// Address of the first page buffer (equal to `page_buffers[0]`).
    pub begin_data: u64,
    /// RAM addresses of the page-sized data buffers (one or two, depending on free RAM).
    pub page_buffers: Vec<u64>,
    /// RTT control block location, passed through unchanged from the raw description.
    pub rtt_control_block: Option<u64>,
    /// Geometry of the flash device this algorithm operates on (address range,
    /// page size, sector layout, erased byte value).
    pub flash_properties: FlashProperties,
}
impl FlashAlgorithm {
    /// Returns the sector that contains `address`.
    ///
    /// Sector descriptions in `flash_properties.sectors` are offsets relative to
    /// `flash_properties.address_range.start`. The last description whose offset is
    /// at or below the queried offset determines the sector size, so the
    /// descriptions are expected to be sorted by address.
    ///
    /// Returns `None` if `address` is outside the flash address range, if no
    /// description covers it, or if the covering description has a size of zero.
    pub fn sector_info(&self, address: u64) -> Option<SectorInfo> {
        let props = &self.flash_properties;
        if !props.address_range.contains(&address) {
            tracing::trace!("Address {:08x} not contained in this flash device", address);
            return None;
        }

        let offset_address = address - props.address_range.start;
        let containing_sector = props
            .sectors
            .iter()
            .rfind(|s| s.address <= offset_address)?;

        // `checked_div` turns a malformed zero-sized description into `None`
        // instead of a division-by-zero panic.
        let sector_index =
            (offset_address - containing_sector.address).checked_div(containing_sector.size)?;
        let sector_address = props.address_range.start
            + containing_sector.address
            + sector_index * containing_sector.size;

        Some(SectorInfo {
            base_address: sector_address,
            size: containing_sector.size,
        })
    }

    /// Returns the page that contains `address`.
    ///
    /// Pages are `page_size` bytes long and counted from the start of the flash
    /// address range, the same way [`FlashAlgorithm::iter_pages`] enumerates them.
    ///
    /// Returns `None` if `address` is outside the flash address range or the page
    /// size is zero.
    pub fn page_info(&self, address: u64) -> Option<PageInfo> {
        let props = &self.flash_properties;
        if !props.address_range.contains(&address) {
            return None;
        }

        // Align relative to the range start (not absolute address 0) so that a flash
        // range which is not page aligned never yields a page base below the range.
        let offset_address = address - props.address_range.start;
        let offset_in_page = offset_address.checked_rem(props.page_size as u64)?;

        Some(PageInfo {
            base_address: address - offset_in_page,
            size: props.page_size,
        })
    }

    /// Iterates over all sectors of the flash, in ascending address order.
    ///
    /// Iteration stops at the end of the flash address range, or early if a sector
    /// description has a size of zero (it could never advance).
    ///
    /// # Panics
    ///
    /// Panics if `flash_properties.sectors` is empty or if the first sector
    /// description does not start at offset 0.
    pub fn iter_sectors(&self) -> impl Iterator<Item = SectorInfo> + '_ {
        let props = &self.flash_properties;

        assert!(!props.sectors.is_empty());
        assert!(props.sectors[0].address == 0);

        let mut addr = props.address_range.start;
        let mut desc_idx = 0;
        std::iter::from_fn(move || {
            if addr >= props.address_range.end {
                return None;
            }

            // Move to the latest description whose region has already begun; a loop
            // (rather than a single step) also copes with several descriptions
            // starting at or before the current address.
            while props
                .sectors
                .get(desc_idx + 1)
                .map_or(false, |next| props.address_range.start + next.address <= addr)
            {
                desc_idx += 1;
            }

            let size = props.sectors[desc_idx].size;
            // A zero-sized sector would never advance `addr`; stop instead of looping forever.
            if size == 0 {
                return None;
            }

            let sector = SectorInfo {
                base_address: addr,
                size,
            };
            addr += size;
            Some(sector)
        })
    }

    /// Iterates over all pages of the flash, starting at the beginning of the flash
    /// address range and stepping by `page_size`.
    ///
    /// Yields nothing if the page size is zero (it could never advance).
    pub fn iter_pages(&self) -> impl Iterator<Item = PageInfo> + '_ {
        let props = &self.flash_properties;
        let page_size = props.page_size;
        let mut addr = props.address_range.start;
        std::iter::from_fn(move || {
            if addr >= props.address_range.end || page_size == 0 {
                return None;
            }

            let page = PageInfo {
                base_address: addr,
                size: page_size,
            };
            addr += page_size as u64;
            Some(page)
        })
    }

    /// Returns `true` if every byte of `data` equals the flash's erased byte value
    /// (vacuously `true` for empty `data`).
    pub fn is_erased(&self, data: &[u8]) -> bool {
        let erased = self.flash_properties.erased_byte_value;
        data.iter().all(|&byte| byte == erased)
    }

    /// Stack size used when the raw algorithm does not specify one.
    const FLASH_ALGO_STACK_SIZE: u32 = 512;

    /// Step by which the stack is shrunk when the layout does not fit in RAM; also
    /// the minimum stack size.
    const FLASH_ALGO_STACK_DECREMENT: u32 = 64;

    /// Header prepended to RISC-V algorithms: two `ebreak` instructions.
    /// NOTE(review): presumably the return target for algorithm calls — confirm.
    const RISCV_FLASH_BLOB_HEADER: [u32; 2] = [riscv::assembly::EBREAK, riscv::assembly::EBREAK];

    /// Header prepended to ARM algorithms.
    /// NOTE(review): the low halfword 0xBE00 encodes Thumb `bkpt #0`; presumably the
    /// header acts as a breakpoint/return trampoline — confirm.
    const ARM_FLASH_BLOB_HEADER: [u32; 8] = [
        0xE00A_BE00,
        0x062D_780D,
        0x2408_4068,
        0xD300_0040,
        0x1E64_4058,
        0x1C49_D1FA,
        0x2A00_1E52,
        0x0477_0D1F,
    ];

    /// Returns the header words prepended to an algorithm for `architecture`.
    fn get_algorithm_header(architecture: Architecture) -> &'static [u32] {
        match architecture {
            Architecture::Arm => &Self::ARM_FLASH_BLOB_HEADER,
            Architecture::Riscv => &Self::RISCV_FLASH_BLOB_HEADER,
        }
    }

    /// Picks the stack size for a layout of `blob_bytes` of header+code, followed by
    /// the stack, followed by one `page_bytes` page buffer, given `available` bytes of
    /// RAM from the load address onwards.
    ///
    /// Candidates start at `requested` and shrink by
    /// [`Self::FLASH_ALGO_STACK_DECREMENT`] while staying at least that large. The
    /// first (largest) candidate whose layout fits is returned. If none fits, the
    /// smallest candidate is returned and a warning is logged, because the page
    /// buffer will then extend past the RAM region.
    fn choose_stack_size(requested: u32, blob_bytes: u64, page_bytes: u64, available: u64) -> u32 {
        let step = Self::FLASH_ALGO_STACK_DECREMENT;
        let candidate_count = requested / step;
        let smallest = requested - step * candidate_count.saturating_sub(1);

        (0..candidate_count)
            .map(|i| requested - step * i)
            .find(|&stack| blob_bytes + u64::from(stack) + page_bytes <= available)
            .unwrap_or_else(|| {
                tracing::warn!(
                    "Flash algorithm does not fit into RAM even with a {} byte stack; \
                     its page buffer will extend past the end of the RAM region",
                    smallest
                );
                smallest
            })
    }

    /// Builds a loadable [`FlashAlgorithm`] from a target-description algorithm.
    ///
    /// The RAM layout, starting at the load address, is:
    ///
    /// 1. the architecture-specific header followed by the algorithm code (together
    ///    stored in `instructions`);
    /// 2. the stack (`begin_stack` is the address just past it);
    /// 3. one page buffer, plus a second one if it also fits inside `ram_region`.
    ///
    /// The stack starts at `raw.stack_size` bytes (or [`Self::FLASH_ALGO_STACK_SIZE`])
    /// and is shrunk in steps of [`Self::FLASH_ALGO_STACK_DECREMENT`] until the code,
    /// the stack and one page buffer fit in `ram_region`.
    ///
    /// # Errors
    ///
    /// - [`FlashError::InvalidPageSize`] if the page size is zero or not a multiple of 4.
    /// - [`FlashError::InvalidFlashAlgorithmLength`] if the instruction blob length is
    ///   not a multiple of 4 bytes.
    /// - [`FlashError::InvalidFlashAlgorithmLoadAddress`] if the requested load address
    ///   leaves no room for the header, or the resulting load address lies outside
    ///   `ram_region`.
    pub fn assemble_from_raw(
        raw: &RawFlashAlgorithm,
        ram_region: &RamRegion,
        target: &Target,
    ) -> Result<Self, FlashError> {
        use std::mem::size_of;

        let page_size = raw.flash_properties.page_size;
        // Zero passes the multiple-of-4 test but would break every page computation.
        if page_size == 0 || page_size % 4 != 0 {
            return Err(FlashError::InvalidPageSize { size: page_size });
        }
        let page_bytes = page_size as u64;

        let assembled_instructions = raw.instructions.chunks_exact(size_of::<u32>());
        if !assembled_instructions.remainder().is_empty() {
            return Err(FlashError::InvalidFlashAlgorithmLength {
                name: raw.name.to_string(),
                algorithm_source: Some(target.source().clone()),
            });
        }

        let header = Self::get_algorithm_header(target.architecture());
        let instructions: Vec<u32> = header
            .iter()
            .copied()
            .chain(assembled_instructions.map(|bytes| {
                u32::from_le_bytes(
                    bytes
                        .try_into()
                        .expect("chunks_exact only yields 4-byte chunks"),
                )
            }))
            .collect();

        let header_bytes = (header.len() * size_of::<u32>()) as u64;
        // Everything that gets loaded: header + algorithm code (the header is already
        // part of `instructions`, so it must not be counted a second time).
        let blob_bytes = (instructions.len() * size_of::<u32>()) as u64;

        // `raw.load_address` is where the algorithm code should live, so the header is
        // placed immediately in front of it.
        let addr_load = match raw.load_address {
            Some(address) => address
                .checked_sub(header_bytes)
                .ok_or(FlashError::InvalidFlashAlgorithmLoadAddress { address })?,
            None => ram_region.range.start,
        };
        if !ram_region.range.contains(&addr_load) {
            return Err(FlashError::InvalidFlashAlgorithmLoadAddress { address: addr_load });
        }
        let code_start = addr_load + header_bytes;
        // Bytes from the load address to the end of the RAM region; cannot underflow
        // because `addr_load` was just checked to lie inside the region.
        let available = ram_region.range.end - addr_load;

        let requested_stack_size = {
            let stack_size = raw.stack_size.unwrap_or(Self::FLASH_ALGO_STACK_SIZE);
            if stack_size < Self::FLASH_ALGO_STACK_DECREMENT {
                tracing::warn!(
                    "Stack size of {} bytes is too small; overriding to {} bytes",
                    stack_size,
                    Self::FLASH_ALGO_STACK_DECREMENT
                );
                Self::FLASH_ALGO_STACK_DECREMENT
            } else {
                stack_size
            }
        };

        let stack_size =
            Self::choose_stack_size(requested_stack_size, blob_bytes, page_bytes, available);
        tracing::debug!("The flash algorithm will be configured with {stack_size} bytes of stack");

        // The stack occupies the `stack_size` bytes right after the loaded blob and
        // the first page buffer starts where the stack region ends.
        let addr_stack = addr_load + blob_bytes + u64::from(stack_size);
        let addr_data = addr_stack;

        // Offer a second page buffer only if it also fits inside the RAM region.
        let used = blob_bytes + u64::from(stack_size) + page_bytes;
        let page_buffers = if used + page_bytes <= available {
            vec![addr_data, addr_data + page_bytes]
        } else {
            vec![addr_data]
        };

        Ok(FlashAlgorithm {
            name: raw.name.clone(),
            default: raw.default,
            load_address: addr_load,
            instructions,
            pc_init: raw.pc_init.map(|v| code_start + v),
            pc_uninit: raw.pc_uninit.map(|v| code_start + v),
            pc_program_page: code_start + raw.pc_program_page,
            pc_erase_sector: code_start + raw.pc_erase_sector,
            pc_erase_all: raw.pc_erase_all.map(|v| code_start + v),
            static_base: code_start + raw.data_section_offset,
            begin_stack: addr_stack,
            begin_data: addr_data,
            page_buffers,
            rtt_control_block: raw.rtt_location,
            flash_properties: raw.flash_properties.clone(),
        })
    }
}
#[cfg(test)]
mod test {
    use probe_rs_target::{FlashProperties, SectorDescription, SectorInfo};

    use crate::flashing::FlashAlgorithm;

    /// Wraps the given flash geometry in an otherwise default `FlashAlgorithm`.
    fn algorithm(flash_properties: FlashProperties) -> FlashAlgorithm {
        FlashAlgorithm {
            flash_properties,
            ..Default::default()
        }
    }

    /// Three sector regions: 0x4000-byte sectors from offset 0, 0x1_0000-byte
    /// sectors from offset 0x1_0000 and 0x2_0000-byte sectors from offset 0x2_0000.
    fn mixed_sectors() -> Vec<SectorDescription> {
        let layout: [(u64, u64); 3] = [(0x0, 0x4000), (0x1_0000, 0x1_0000), (0x2_0000, 0x2_0000)];
        layout
            .iter()
            .map(|&(address, size)| SectorDescription { address, size })
            .collect()
    }

    #[test]
    fn flash_sector_single_size() {
        let algo = algorithm(FlashProperties {
            sectors: vec![SectorDescription {
                size: 0x100,
                address: 0x0,
            }],
            address_range: 0x1000..0x2000,
            page_size: 0x10,
            ..Default::default()
        });
        let first_sector = Some(SectorInfo {
            base_address: 0x1000,
            size: 0x100,
        });

        assert!(algo.sector_info(0x0fff).is_none());
        for address in [0x1000, 0x10ff, 0x100b, 0x10ea] {
            assert_eq!(algo.sector_info(address), first_sector);
        }
    }

    #[test]
    fn flash_sector_single_size_weird_sector_size() {
        const BASE: u64 = 0x800_0000;

        let algo = algorithm(FlashProperties {
            sectors: vec![SectorDescription {
                size: 258,
                address: 0x0,
            }],
            address_range: BASE..BASE + 258 * 10,
            page_size: 0x10,
            ..Default::default()
        });
        let first_sector = Some(SectorInfo {
            base_address: BASE,
            size: 258,
        });

        assert!(algo.sector_info(BASE - 1).is_none());
        for address in [BASE, BASE + 257, BASE + 0xb, BASE + 0xe0] {
            assert_eq!(algo.sector_info(address), first_sector);
        }
    }

    #[test]
    fn flash_sector_multiple_sizes() {
        let algo = algorithm(FlashProperties {
            sectors: mixed_sectors(),
            address_range: 0x800_0000..0x810_0000,
            page_size: 0x10,
            ..Default::default()
        });

        // (queried address, expected sector base, expected sector size)
        let cases: [(u64, u64, u64); 3] = [
            (0x800_4000, 0x800_4000, 0x4000),
            (0x801_0000, 0x801_0000, 0x1_0000),
            (0x80A_0000, 0x80A_0000, 0x2_0000),
        ];
        for (address, base_address, size) in cases.iter().copied() {
            assert_eq!(
                algo.sector_info(address),
                Some(SectorInfo { base_address, size })
            );
        }
    }

    #[test]
    fn flash_sector_multiple_sizes_iter() {
        let algo = algorithm(FlashProperties {
            sectors: mixed_sectors(),
            address_range: 0x800_0000..0x808_0000,
            page_size: 0x10,
            ..Default::default()
        });

        let expected_layout: [(u64, u64); 8] = [
            (0x800_0000, 0x4000),
            (0x800_4000, 0x4000),
            (0x800_8000, 0x4000),
            (0x800_c000, 0x4000),
            (0x801_0000, 0x1_0000),
            (0x802_0000, 0x2_0000),
            (0x804_0000, 0x2_0000),
            (0x806_0000, 0x2_0000),
        ];
        let expected: Vec<SectorInfo> = expected_layout
            .iter()
            .map(|&(base_address, size)| SectorInfo { base_address, size })
            .collect();

        let got: Vec<SectorInfo> = algo.iter_sectors().collect();
        assert_eq!(got, expected);
    }
}