1use super::FlashError;
2use crate::{
3 Target,
4 architecture::{arm, riscv},
5 core::Architecture,
6};
7use probe_rs_target::{
8 CoreType, Endian, FlashProperties, MemoryRegion, PageInfo, RamRegion, RawFlashAlgorithm,
9 RegionMergeIterator, SectorInfo, TransferEncoding,
10};
11use std::mem::size_of_val;
12
/// A flash algorithm ready to be run on a target, with all routine and data
/// addresses resolved to concrete RAM locations.
///
/// Produced from a [`RawFlashAlgorithm`] by [`FlashAlgorithm::assemble_from_raw`]
/// and related constructors.
#[derive(Debug, Default, Clone)]
pub struct FlashAlgorithm {
    /// The name of the flash algorithm.
    pub name: String,
    /// Whether this algorithm is the default one for the target.
    pub default: bool,
    /// Memory address where the flash algorithm (breakpoint header included)
    /// is loaded to.
    pub load_address: u64,
    /// The algorithm blob as 32-bit words: a core-specific header followed by
    /// the raw instruction words.
    pub instructions: Vec<u32>,
    /// Absolute address of the `pc_init` routine, if present.
    pub pc_init: Option<u64>,
    /// Absolute address of the `pc_uninit` routine, if present.
    pub pc_uninit: Option<u64>,
    /// Absolute address of the program-page routine.
    pub pc_program_page: u64,
    /// Absolute address of the erase-sector routine.
    pub pc_erase_sector: u64,
    /// Absolute address of the erase-all routine, if present.
    pub pc_erase_all: Option<u64>,
    /// Absolute address of the verify routine, if present.
    pub pc_verify: Option<u64>,
    /// Absolute address of the blank-check routine, if present.
    pub pc_blank_check: Option<u64>,
    /// Absolute address of the flash-read routine, if present.
    pub pc_read: Option<u64>,
    /// Absolute address of the flash-size query routine, if present.
    pub pc_flash_size: Option<u64>,
    /// Address of the algorithm's static data section (code start plus the
    /// raw algorithm's `data_section_offset`).
    pub static_base: u64,
    /// Address of the top of the stack set up for the algorithm.
    pub stack_top: u64,
    /// Size of the algorithm's stack, in bytes.
    pub stack_size: u64,
    /// Whether stack overflow checking is enabled (taken from the raw
    /// algorithm's `stack_overflow_check()`).
    pub stack_overflow_check: bool,
    /// Addresses of the page buffer(s) flash data is staged in; two entries
    /// enable double buffering.
    pub page_buffers: Vec<u64>,
    /// Address of the RTT control block, if the raw algorithm specifies one.
    pub rtt_control_block: Option<u64>,
    /// RTT polling interval, copied from the raw algorithm.
    pub rtt_poll_interval: u64,
    /// Properties of the flash memory this algorithm programs.
    pub flash_properties: FlashProperties,
    /// Encoding used when transferring page data to the algorithm.
    pub transfer_encoding: TransferEncoding,
}
75
76impl FlashAlgorithm {
77 pub fn sector_info(&self, address: u64) -> Option<SectorInfo> {
83 if !self.flash_properties.address_range.contains(&address) {
84 tracing::trace!("Address {:08x} not contained in this flash device", address);
85 return None;
86 }
87
88 let offset_address = address - self.flash_properties.address_range.start;
89
90 let containing_sector = self
91 .flash_properties
92 .sectors
93 .iter()
94 .rfind(|s| s.address <= offset_address)?;
95
96 let sector_index = (offset_address - containing_sector.address) / containing_sector.size;
97
98 let sector_address = self.flash_properties.address_range.start
99 + containing_sector.address
100 + sector_index * containing_sector.size;
101
102 Some(SectorInfo {
103 base_address: sector_address,
104 size: containing_sector.size,
105 })
106 }
107
108 pub fn page_info(&self, address: u64) -> Option<PageInfo> {
111 if !self.flash_properties.address_range.contains(&address) {
112 return None;
113 }
114
115 Some(PageInfo {
116 base_address: address - (address % self.flash_properties.page_size as u64),
117 size: self.flash_properties.page_size,
118 })
119 }
120
121 pub fn iter_sectors(&self) -> impl Iterator<Item = SectorInfo> + '_ {
123 let props = &self.flash_properties;
124
125 assert!(!props.sectors.is_empty());
126 assert!(props.sectors[0].address == 0);
127
128 let mut addr = props.address_range.start;
129 let mut desc_idx = 0;
130 std::iter::from_fn(move || {
131 if addr >= props.address_range.end {
132 return None;
133 }
134
135 if let Some(next_desc) = props.sectors.get(desc_idx + 1)
137 && props.address_range.start + next_desc.address <= addr
138 {
139 desc_idx += 1;
140 }
141
142 let size = props.sectors[desc_idx].size;
143 let sector = SectorInfo {
144 base_address: addr,
145 size,
146 };
147 addr += size;
148
149 Some(sector)
150 })
151 }
152
153 pub fn iter_pages(&self) -> impl Iterator<Item = PageInfo> + '_ {
155 let props = &self.flash_properties;
156
157 let mut addr = props.address_range.start;
158 std::iter::from_fn(move || {
159 if addr >= props.address_range.end {
160 return None;
161 }
162
163 let page = PageInfo {
164 base_address: addr,
165 size: props.page_size,
166 };
167 addr += props.page_size as u64;
168
169 Some(page)
170 })
171 }
172
173 pub fn is_erased(&self, data: &[u8]) -> bool {
175 for b in data {
176 if *b != self.flash_properties.erased_byte_value {
177 return false;
178 }
179 }
180 true
181 }
182
    /// Default stack size, in bytes, used when the raw algorithm does not
    /// specify one.
    const FLASH_ALGO_STACK_SIZE: u32 = 512;

    /// Header prepended to RISC-V flash blobs: two `ebreak` instructions.
    const RISCV_FLASH_BLOB_HEADER: [u32; 2] = [riscv::assembly::EBREAK, riscv::assembly::EBREAK];

    // Headers prepended to ARM flash blobs: a single breakpoint-style
    // instruction, in the instruction set matching the core (Thumb `BKPT`,
    // A32 `BKPT`, or `HLT`), each available in little- and big-endian byte
    // order.
    const ARM_FLASH_BLOB_HEADER_BKPT_T32_LE: [u32; 1] = [arm::assembly::BKPT_T32];
    const ARM_FLASH_BLOB_HEADER_BKPT_T32_BE: [u32; 1] = [arm::assembly::BKPT_T32.swap_bytes()];
    const ARM_FLASH_BLOB_HEADER_BKPT_A32_LE: [u32; 1] = [arm::assembly::BKPT_A32];
    const ARM_FLASH_BLOB_HEADER_BKPT_A32_BE: [u32; 1] = [arm::assembly::BKPT_A32.swap_bytes()];
    const ARM_FLASH_BLOB_HEADER_HLT_LE: [u32; 1] = [arm::assembly::HLT];
    const ARM_FLASH_BLOB_HEADER_HLT_BE: [u32; 1] = [arm::assembly::HLT.swap_bytes()];

    /// Xtensa flash blobs carry no header.
    const XTENSA_FLASH_BLOB_HEADER: [u32; 0] = [];
199
200 pub fn get_max_algorithm_header_size() -> u64 {
203 let algos = [
204 Self::algorithm_header(CoreType::Armv6m, Endian::Big),
205 Self::algorithm_header(CoreType::Armv6m, Endian::Little),
206 Self::algorithm_header(CoreType::Armv7a, Endian::Big),
207 Self::algorithm_header(CoreType::Armv7a, Endian::Little),
208 Self::algorithm_header(CoreType::Armv7m, Endian::Big),
209 Self::algorithm_header(CoreType::Armv7m, Endian::Little),
210 Self::algorithm_header(CoreType::Armv7em, Endian::Big),
211 Self::algorithm_header(CoreType::Armv7em, Endian::Little),
212 Self::algorithm_header(CoreType::Armv8a, Endian::Big),
213 Self::algorithm_header(CoreType::Armv8a, Endian::Little),
214 Self::algorithm_header(CoreType::Armv8a, Endian::Big),
215 Self::algorithm_header(CoreType::Armv8a, Endian::Little),
216 Self::algorithm_header(CoreType::Armv8m, Endian::Big),
217 Self::algorithm_header(CoreType::Armv8m, Endian::Little),
218 Self::algorithm_header(CoreType::Riscv, Endian::Little),
219 Self::algorithm_header(CoreType::Xtensa, Endian::Big),
220 Self::algorithm_header(CoreType::Xtensa, Endian::Little),
221 ];
222
223 algos.iter().copied().map(size_of_val).max().unwrap() as u64
224 }
225
226 fn algorithm_header(core_type: CoreType, endian: Endian) -> &'static [u32] {
227 match core_type {
228 CoreType::Armv6m | CoreType::Armv7m | CoreType::Armv7em | CoreType::Armv8m => {
229 match endian {
230 Endian::Little => &Self::ARM_FLASH_BLOB_HEADER_BKPT_T32_LE,
231 Endian::Big => &Self::ARM_FLASH_BLOB_HEADER_BKPT_T32_BE,
232 }
233 }
234 CoreType::Armv7a => match endian {
235 Endian::Little => &Self::ARM_FLASH_BLOB_HEADER_BKPT_A32_LE,
236 Endian::Big => &Self::ARM_FLASH_BLOB_HEADER_BKPT_A32_BE,
237 },
238 CoreType::Armv8a => match endian {
239 Endian::Little => &Self::ARM_FLASH_BLOB_HEADER_HLT_LE,
240 Endian::Big => &Self::ARM_FLASH_BLOB_HEADER_HLT_BE,
241 },
242 CoreType::Riscv => &Self::RISCV_FLASH_BLOB_HEADER,
243 CoreType::Xtensa => &Self::XTENSA_FLASH_BLOB_HEADER,
244 }
245 }
246
247 fn required_stack_alignment(architecture: Architecture) -> u64 {
248 match architecture {
249 Architecture::Arm => 8,
250 Architecture::Riscv => 16,
251 Architecture::Xtensa => 16,
252 }
253 }
254
    /// Constructs a complete flash algorithm from its raw description.
    ///
    /// Convenience wrapper around [`Self::assemble_from_raw_with_data`] that
    /// places both the algorithm code and its data buffers in `ram_region`.
    pub fn assemble_from_raw(
        raw: &RawFlashAlgorithm,
        ram_region: &RamRegion,
        target: &Target,
    ) -> Result<Self, FlashError> {
        Self::assemble_from_raw_with_data(raw, ram_region, ram_region, target)
    }
263
    /// Constructs a complete flash algorithm from its raw description,
    /// placing the code in `ram_region` and the data (page buffers) in
    /// `data_ram_region`.
    ///
    /// The two regions may be the same, in which case code, page buffers and
    /// stack share the region. Returns an error when load addresses are
    /// invalid or the available RAM cannot fit code, stack and at least one
    /// page buffer.
    pub fn assemble_from_raw_with_data(
        raw: &RawFlashAlgorithm,
        ram_region: &RamRegion,
        data_ram_region: &RamRegion,
        target: &Target,
    ) -> Result<Self, FlashError> {
        use std::mem::size_of;

        // Reinterpret the raw instruction bytes as little-endian 32-bit words.
        let assembled_instructions = raw.instructions.chunks_exact(size_of::<u32>());

        // If the blob length is not a multiple of 4, zero-pad the trailing
        // 1-3 bytes into one final word.
        let remainder = assembled_instructions.remainder();
        let last_elem = if !remainder.is_empty() {
            let word = u32::from_le_bytes(
                remainder
                    .iter()
                    .cloned()
                    .chain([0u8, 0u8, 0u8])
                    .take(4)
                    .collect::<Vec<u8>>()
                    .try_into()
                    .unwrap(),
            );
            Some(word)
        } else {
            None
        };

        // Breakpoint-style header matching the core type and byte order; the
        // algorithm code is loaded immediately after it.
        let header = Self::algorithm_header(
            target.default_core().core_type,
            if raw.big_endian {
                Endian::Big
            } else {
                Endian::Little
            },
        );

        // Final instruction stream: header, all full words, then the padded
        // remainder word (if any).
        let instructions: Vec<u32> = header
            .iter()
            .copied()
            .chain(
                assembled_instructions.map(|bytes| u32::from_le_bytes(bytes.try_into().unwrap())),
            )
            .chain(last_elem)
            .collect();

        let header_size = size_of_val(header) as u64;

        // `raw.load_address` is where the code itself must land, so the blob
        // (with header) is loaded `header_size` bytes earlier. Underflow
        // means the requested address cannot accommodate the header.
        let addr_load = match raw.load_address {
            Some(address) => address
                .checked_sub(header_size)
                .ok_or(FlashError::InvalidFlashAlgorithmLoadAddress { address })?,

            // No fixed address: load at the start of the chosen RAM region.
            None => ram_region.range.start,
        };

        if addr_load < ram_region.range.start {
            return Err(FlashError::InvalidFlashAlgorithmLoadAddress { address: addr_load });
        }

        let code_start = addr_load + header_size;
        let code_size_bytes = (instructions.len() * size_of::<u32>()) as u64;

        // Align the end of the code so whatever follows it (data buffers or
        // stack) starts at the architecture's required stack alignment.
        let stack_align = Self::required_stack_alignment(target.architecture());
        let code_end = (code_start + code_size_bytes).next_multiple_of(stack_align);

        let buffer_page_size = raw.flash_properties.page_size as u64;

        let stack_size = raw.stack_size.unwrap_or(Self::FLASH_ALGO_STACK_SIZE) as u64;
        tracing::info!("The flash algorithm will be configured with {stack_size} bytes of stack");

        // Data goes to an explicit address if the algorithm pins one, right
        // after the code when sharing the code RAM, or to the start of the
        // dedicated data RAM otherwise.
        let data_load_addr = if let Some(data_load_addr) = raw.data_load_address {
            data_load_addr
        } else if ram_region == data_ram_region {
            code_end
        } else {
            data_ram_region.range.start
        };

        if data_ram_region.range.end < data_load_addr {
            return Err(FlashError::InvalidDataAddress {
                data_load_addr,
                data_ram: data_ram_region.range.clone(),
            });
        }
        // RAM available for page buffers beyond the data load address. When
        // code, stack and data share a region and the stack would overlap
        // the data area, reserve the stack out of this budget.
        let mut ram_for_data = data_ram_region.range.end - data_load_addr;
        if code_end + stack_size > data_load_addr && ram_region == data_ram_region {
            if stack_size > ram_for_data {
                return Err(FlashError::InvalidFlashAlgorithmStackSize { size: stack_size });
            }
            ram_for_data -= stack_size;
        }

        // Prefer two page buffers (double buffering) when they fit, fall
        // back to one; fail if not even a single page buffer fits.
        let double_buffering = if ram_for_data >= 2 * buffer_page_size {
            true
        } else if ram_for_data >= buffer_page_size {
            false
        } else {
            return Err(FlashError::InvalidFlashAlgorithmStackSize { size: stack_size });
        };

        // The stack sits right behind the code unless that would collide
        // with the shared data area, in which case it moves behind the page
        // buffers (re-aligned for the stack).
        let stack_bottom =
            if code_end + stack_size <= data_load_addr || ram_region != data_ram_region {
                code_end
            } else {
                let page_count = if double_buffering { 2 } else { 1 };
                (data_load_addr + page_count * buffer_page_size).next_multiple_of(stack_align)
            };

        let stack_top = stack_bottom + stack_size;
        tracing::info!("Stack top: {:#010x}", stack_top);

        if stack_top > ram_region.range.end {
            return Err(FlashError::InvalidFlashAlgorithmStackSize { size: stack_size });
        }

        let page_buffers = if double_buffering {
            let second_buffer_start = data_load_addr + buffer_page_size;
            vec![data_load_addr, second_buffer_start]
        } else {
            vec![data_load_addr]
        };

        tracing::debug!("Page buffers: {:#010x?}", page_buffers);

        let name = raw.name.clone();

        Ok(FlashAlgorithm {
            name,
            default: raw.default,
            load_address: addr_load,
            instructions,
            // Routine addresses in the raw algorithm are offsets from the
            // start of the code; translate them to absolute addresses.
            pc_init: raw.pc_init.map(|v| code_start + v),
            pc_uninit: raw.pc_uninit.map(|v| code_start + v),
            pc_program_page: code_start + raw.pc_program_page,
            pc_erase_sector: code_start + raw.pc_erase_sector,
            pc_erase_all: raw.pc_erase_all.map(|v| code_start + v),
            pc_verify: raw.pc_verify.map(|v| code_start + v),
            pc_blank_check: raw.pc_blank_check.map(|v| code_start + v),
            pc_read: raw.pc_read.map(|v| code_start + v),
            pc_flash_size: raw.pc_flash_size.map(|v| code_start + v),
            static_base: code_start + raw.data_section_offset,
            stack_top,
            stack_size,
            page_buffers,
            rtt_control_block: raw.rtt_location,
            rtt_poll_interval: raw.rtt_poll_interval,
            flash_properties: raw.flash_properties.clone(),
            transfer_encoding: raw.transfer_encoding.unwrap_or_default(),
            stack_overflow_check: raw.stack_overflow_check(),
        })
    }
454
    /// Assembles a flash algorithm for the given core, picking suitable RAM
    /// regions from the target's memory map.
    ///
    /// Chooses the largest executable RAM region accessible by `core_name`
    /// for the code, and — when the algorithm pins a data load address — the
    /// region containing that address for the data.
    pub(crate) fn assemble_from_raw_with_core(
        algo: &RawFlashAlgorithm,
        core_name: &str,
        target: &Target,
    ) -> Result<FlashAlgorithm, FlashError> {
        let mm = &target.memory_map;

        // All RAM regions this core can access; consecutive regions are
        // merged so a single algorithm can span adjacent regions.
        let ram_regions = mm
            .iter()
            .filter_map(MemoryRegion::as_ram_region)
            .filter(|ram| ram.accessible_by(core_name))
            .merge_consecutive();

        // Largest executable region that satisfies the algorithm's load
        // address (if it has a fixed one).
        let ram = ram_regions
            .clone()
            .filter(|ram| is_ram_suitable_for_algo(ram, algo.load_address))
            .max_by_key(|region| region.range.end - region.range.start)
            .ok_or(FlashError::NoRamDefined {
                name: target.name.clone(),
            })?;
        tracing::info!("Chosen RAM to run the algo: {:x?}", ram);

        // If the algorithm pins its data to a fixed address, find the region
        // containing it; otherwise the data shares the code region.
        // `data_ram` is declared up front so the borrow taken in the `if`
        // branch outlives the branch itself.
        let data_ram;
        let data_ram = if let Some(data_load_address) = algo.data_load_address {
            data_ram = ram_regions
                .clone()
                .find(|ram| is_ram_suitable_for_data(ram, data_load_address))
                .ok_or(FlashError::NoRamDefined {
                    name: target.name.clone(),
                })?;

            &data_ram
        } else {
            &ram
        };
        tracing::info!("Data will be loaded to: {:x?}", data_ram);

        Self::assemble_from_raw_with_data(algo, &ram, data_ram, target)
    }
497}
498
499fn is_ram_suitable_for_algo(ram: &RamRegion, load_address: Option<u64>) -> bool {
501 if !ram.is_executable() {
502 return false;
503 }
504
505 if let Some(load_addr) = load_address {
512 ram.range.contains(&load_addr)
516 } else {
517 true
518 }
519}
520
521fn is_ram_suitable_for_data(ram: &RamRegion, load_address: u64) -> bool {
523 ram.range.contains(&load_address)
527}
528
#[cfg(test)]
mod test {
    use probe_rs_target::{FlashProperties, SectorDescription, SectorInfo};

    use crate::flashing::FlashAlgorithm;

    /// Sector lookup with a single, uniform sector size.
    #[test]
    fn flash_sector_single_size() {
        // One description covering the whole 4 KiB range with 0x100-byte sectors.
        let config = FlashAlgorithm {
            flash_properties: FlashProperties {
                sectors: vec![SectorDescription {
                    size: 0x100,
                    address: 0x0,
                }],
                address_range: 0x1000..0x1000 + 0x1000,
                page_size: 0x10,
                ..Default::default()
            },
            ..Default::default()
        };

        let expected_first = SectorInfo {
            base_address: 0x1000,
            size: 0x100,
        };

        // Just below the flash range: no sector.
        assert!(config.sector_info(0x1000 - 1).is_none());

        // First and last byte of the first sector.
        assert_eq!(Some(expected_first), config.sector_info(0x1000));
        assert_eq!(Some(expected_first), config.sector_info(0x10ff));

        // Arbitrary addresses inside the first sector.
        assert_eq!(Some(expected_first), config.sector_info(0x100b));
        assert_eq!(Some(expected_first), config.sector_info(0x10ea));
    }

    /// Sector lookup still works when the sector size is not a power of two.
    #[test]
    fn flash_sector_single_size_weird_sector_size() {
        // Ten sectors of 258 bytes each.
        let config = FlashAlgorithm {
            flash_properties: FlashProperties {
                sectors: vec![SectorDescription {
                    size: 258,
                    address: 0x0,
                }],
                address_range: 0x800_0000..0x800_0000 + 258 * 10,
                page_size: 0x10,
                ..Default::default()
            },
            ..Default::default()
        };

        let expected_first = SectorInfo {
            base_address: 0x800_0000,
            size: 258,
        };

        // Just below the flash range: no sector.
        assert!(config.sector_info(0x800_0000 - 1).is_none());

        // First and last byte of the first sector (size 258 = 0..=257).
        assert_eq!(Some(expected_first), config.sector_info(0x800_0000));
        assert_eq!(Some(expected_first), config.sector_info(0x800_0000 + 257));

        // Arbitrary addresses inside the first sector.
        assert_eq!(Some(expected_first), config.sector_info(0x800_000b));
        assert_eq!(Some(expected_first), config.sector_info(0x800_00e0));
    }

    /// Sector lookup with multiple descriptions of increasing sector size.
    #[test]
    fn flash_sector_multiple_sizes() {
        let config = FlashAlgorithm {
            flash_properties: FlashProperties {
                sectors: vec![
                    SectorDescription {
                        size: 0x4000,
                        address: 0x0,
                    },
                    SectorDescription {
                        size: 0x1_0000,
                        address: 0x1_0000,
                    },
                    SectorDescription {
                        size: 0x2_0000,
                        address: 0x2_0000,
                    },
                ],
                address_range: 0x800_0000..0x800_0000 + 0x10_0000,
                page_size: 0x10,
                ..Default::default()
            },
            ..Default::default()
        };

        // One probe address per description region.
        let expected_a = SectorInfo {
            base_address: 0x800_4000,
            size: 0x4000,
        };

        let expected_b = SectorInfo {
            base_address: 0x801_0000,
            size: 0x1_0000,
        };

        let expected_c = SectorInfo {
            base_address: 0x80A_0000,
            size: 0x2_0000,
        };

        assert_eq!(Some(expected_a), config.sector_info(0x800_4000));
        assert_eq!(Some(expected_b), config.sector_info(0x801_0000));
        assert_eq!(Some(expected_c), config.sector_info(0x80A_0000));
    }

    /// `iter_sectors` walks contiguous sectors across all descriptions.
    #[test]
    fn flash_sector_multiple_sizes_iter() {
        let config = FlashAlgorithm {
            flash_properties: FlashProperties {
                sectors: vec![
                    SectorDescription {
                        size: 0x4000,
                        address: 0x0,
                    },
                    SectorDescription {
                        size: 0x1_0000,
                        address: 0x1_0000,
                    },
                    SectorDescription {
                        size: 0x2_0000,
                        address: 0x2_0000,
                    },
                ],
                address_range: 0x800_0000..0x800_0000 + 0x8_0000,
                page_size: 0x10,
                ..Default::default()
            },
            ..Default::default()
        };

        let got: Vec<SectorInfo> = config.iter_sectors().collect();

        // Four 0x4000 sectors, one 0x1_0000 sector, then three 0x2_0000
        // sectors fill the 0x8_0000-byte range exactly.
        let expected = &[
            SectorInfo {
                base_address: 0x800_0000,
                size: 0x4000,
            },
            SectorInfo {
                base_address: 0x800_4000,
                size: 0x4000,
            },
            SectorInfo {
                base_address: 0x800_8000,
                size: 0x4000,
            },
            SectorInfo {
                base_address: 0x800_c000,
                size: 0x4000,
            },
            SectorInfo {
                base_address: 0x801_0000,
                size: 0x1_0000,
            },
            SectorInfo {
                base_address: 0x802_0000,
                size: 0x2_0000,
            },
            SectorInfo {
                base_address: 0x804_0000,
                size: 0x2_0000,
            },
            SectorInfo {
                base_address: 0x806_0000,
                size: 0x2_0000,
            },
        ];
        assert_eq!(&got, expected);
    }
}