1use super::FlashError;
2use crate::{
3 Target,
4 architecture::{arm, riscv},
5 core::Architecture,
6};
7use probe_rs_target::{
8 FlashProperties, MemoryRegion, PageInfo, RamRegion, RawFlashAlgorithm, RegionMergeIterator,
9 SectorInfo, TransferEncoding,
10};
11use std::mem::size_of_val;
12
/// A flash algorithm ready to be loaded and executed on a target, assembled
/// from a [`RawFlashAlgorithm`] plus an architecture-specific blob header.
#[derive(Debug, Default, Clone)]
pub struct FlashAlgorithm {
    /// The name of the flash algorithm.
    pub name: String,
    /// Whether this algorithm is the default one for its flash region.
    pub default: bool,
    /// RAM address at which the algorithm blob (header + code) is loaded.
    pub load_address: u64,
    /// The algorithm code as little-endian 32-bit words, prefixed with the
    /// architecture-specific breakpoint header.
    pub instructions: Vec<u32>,
    /// Absolute address of the `Init()` routine, if the algorithm has one.
    pub pc_init: Option<u64>,
    /// Absolute address of the `UnInit()` routine, if the algorithm has one.
    pub pc_uninit: Option<u64>,
    /// Absolute address of the `ProgramPage()` routine.
    pub pc_program_page: u64,
    /// Absolute address of the `EraseSector()` routine.
    pub pc_erase_sector: u64,
    /// Absolute address of the `EraseAll()` routine, if available.
    pub pc_erase_all: Option<u64>,
    /// Absolute address of the `Verify()` routine, if available.
    pub pc_verify: Option<u64>,
    /// Absolute address of the read routine, if available.
    pub pc_read: Option<u64>,
    /// Absolute address of the `BlankCheck()` routine, if available.
    pub pc_blank_check: Option<u64>,
    /// Static base register value: start of the algorithm's data section.
    pub static_base: u64,
    /// Initial stack pointer (top of the algorithm stack).
    pub stack_top: u64,
    /// Size of the algorithm stack, in bytes.
    pub stack_size: u64,
    /// Whether stack overflow checking is enabled for this algorithm.
    pub stack_overflow_check: bool,
    /// Addresses of the page buffers; two entries enable double buffering.
    pub page_buffers: Vec<u64>,

    /// Location of the RTT control block, if the algorithm declares one.
    pub rtt_control_block: Option<u64>,

    /// The properties of the flash on the device.
    pub flash_properties: FlashProperties,

    /// Encoding used when transferring page data to the algorithm.
    pub transfer_encoding: TransferEncoding,
}
70
71impl FlashAlgorithm {
72 pub fn sector_info(&self, address: u64) -> Option<SectorInfo> {
78 if !self.flash_properties.address_range.contains(&address) {
79 tracing::trace!("Address {:08x} not contained in this flash device", address);
80 return None;
81 }
82
83 let offset_address = address - self.flash_properties.address_range.start;
84
85 let containing_sector = self
86 .flash_properties
87 .sectors
88 .iter()
89 .rfind(|s| s.address <= offset_address)?;
90
91 let sector_index = (offset_address - containing_sector.address) / containing_sector.size;
92
93 let sector_address = self.flash_properties.address_range.start
94 + containing_sector.address
95 + sector_index * containing_sector.size;
96
97 Some(SectorInfo {
98 base_address: sector_address,
99 size: containing_sector.size,
100 })
101 }
102
103 pub fn page_info(&self, address: u64) -> Option<PageInfo> {
106 if !self.flash_properties.address_range.contains(&address) {
107 return None;
108 }
109
110 Some(PageInfo {
111 base_address: address - (address % self.flash_properties.page_size as u64),
112 size: self.flash_properties.page_size,
113 })
114 }
115
116 pub fn iter_sectors(&self) -> impl Iterator<Item = SectorInfo> + '_ {
118 let props = &self.flash_properties;
119
120 assert!(!props.sectors.is_empty());
121 assert!(props.sectors[0].address == 0);
122
123 let mut addr = props.address_range.start;
124 let mut desc_idx = 0;
125 std::iter::from_fn(move || {
126 if addr >= props.address_range.end {
127 return None;
128 }
129
130 if let Some(next_desc) = props.sectors.get(desc_idx + 1) {
132 if props.address_range.start + next_desc.address <= addr {
133 desc_idx += 1;
134 }
135 }
136
137 let size = props.sectors[desc_idx].size;
138 let sector = SectorInfo {
139 base_address: addr,
140 size,
141 };
142 addr += size;
143
144 Some(sector)
145 })
146 }
147
148 pub fn iter_pages(&self) -> impl Iterator<Item = PageInfo> + '_ {
150 let props = &self.flash_properties;
151
152 let mut addr = props.address_range.start;
153 std::iter::from_fn(move || {
154 if addr >= props.address_range.end {
155 return None;
156 }
157
158 let page = PageInfo {
159 base_address: addr,
160 size: props.page_size,
161 };
162 addr += props.page_size as u64;
163
164 Some(page)
165 })
166 }
167
168 pub fn is_erased(&self, data: &[u8]) -> bool {
170 for b in data {
171 if *b != self.flash_properties.erased_byte_value {
172 return false;
173 }
174 }
175 true
176 }
177
    /// Stack size, in bytes, used when the raw algorithm does not specify one.
    const FLASH_ALGO_STACK_SIZE: u32 = 512;

    /// Blob header prepended to RISC-V algorithms: two `ebreak` instructions.
    // NOTE(review): presumably acts as the return trap the core halts on —
    // confirm against the flash loader protocol.
    const RISCV_FLASH_BLOB_HEADER: [u32; 2] = [riscv::assembly::EBREAK, riscv::assembly::EBREAK];

    /// Blob header prepended to ARM algorithms: a single breakpoint instruction.
    const ARM_FLASH_BLOB_HEADER: [u32; 1] = [arm::assembly::BRKPT];

    /// Xtensa algorithms need no blob header.
    const XTENSA_FLASH_BLOB_HEADER: [u32; 0] = [];
187 pub fn get_max_algorithm_header_size() -> u64 {
190 let algos = [
191 Self::algorithm_header(Architecture::Arm),
192 Self::algorithm_header(Architecture::Riscv),
193 Self::algorithm_header(Architecture::Xtensa),
194 ];
195
196 algos.iter().copied().map(size_of_val).max().unwrap() as u64
197 }
198
    /// Returns the blob header that is prepended to the algorithm code for
    /// the given architecture.
    fn algorithm_header(architecture: Architecture) -> &'static [u32] {
        match architecture {
            Architecture::Arm => &Self::ARM_FLASH_BLOB_HEADER,
            Architecture::Riscv => &Self::RISCV_FLASH_BLOB_HEADER,
            Architecture::Xtensa => &Self::XTENSA_FLASH_BLOB_HEADER,
        }
    }
206
    /// Returns the stack alignment, in bytes, required by the given
    /// architecture's calling convention.
    fn required_stack_alignment(architecture: Architecture) -> u64 {
        match architecture {
            Architecture::Arm => 8,
            Architecture::Riscv => 16,
            Architecture::Xtensa => 16,
        }
    }
214
    /// Assembles a [`FlashAlgorithm`] from a raw one, placing both code and
    /// data buffers in the same RAM region.
    pub fn assemble_from_raw(
        raw: &RawFlashAlgorithm,
        ram_region: &RamRegion,
        target: &Target,
    ) -> Result<Self, FlashError> {
        Self::assemble_from_raw_with_data(raw, ram_region, ram_region, target)
    }
223
    /// Assembles a [`FlashAlgorithm`] from a raw one, loading the code into
    /// `ram_region` and the page buffers into `data_ram_region` (which may be
    /// the same region).
    ///
    /// Lays out, in order: blob header, instructions, page buffer(s) and
    /// stack, honoring the architecture's stack alignment. Enables double
    /// buffering when two page buffers fit into the remaining data RAM.
    ///
    /// # Errors
    ///
    /// Returns [`FlashError::InvalidFlashAlgorithmLoadAddress`],
    /// [`FlashError::InvalidDataAddress`] or
    /// [`FlashError::InvalidFlashAlgorithmStackSize`] when the requested
    /// layout does not fit the given RAM regions.
    pub fn assemble_from_raw_with_data(
        raw: &RawFlashAlgorithm,
        ram_region: &RamRegion,
        data_ram_region: &RamRegion,
        target: &Target,
    ) -> Result<Self, FlashError> {
        use std::mem::size_of;

        // Re-interpret the raw byte blob as little-endian 32-bit words.
        let assembled_instructions = raw.instructions.chunks_exact(size_of::<u32>());

        // A trailing partial word (1-3 bytes) is zero-padded up to 4 bytes.
        let remainder = assembled_instructions.remainder();
        let last_elem = if !remainder.is_empty() {
            let word = u32::from_le_bytes(
                remainder
                    .iter()
                    .cloned()
                    .chain([0u8, 0u8, 0u8])
                    .take(4)
                    .collect::<Vec<u8>>()
                    .try_into()
                    .unwrap(),
            );
            Some(word)
        } else {
            None
        };

        // Prepend the architecture-specific breakpoint header to the code.
        let header = Self::algorithm_header(target.architecture());
        let instructions: Vec<u32> = header
            .iter()
            .copied()
            .chain(
                assembled_instructions.map(|bytes| u32::from_le_bytes(bytes.try_into().unwrap())),
            )
            .chain(last_elem)
            .collect();

        let header_size = size_of_val(header) as u64;

        // The raw load address refers to the code itself; the blob (header
        // included) is loaded `header_size` bytes earlier.
        let addr_load = match raw.load_address {
            Some(address) => {
                address
                    .checked_sub(header_size)
                    .ok_or(FlashError::InvalidFlashAlgorithmLoadAddress { address })?
            }

            None => {
                // Default: load at the very start of the chosen RAM region.
                ram_region.range.start
            }
        };

        if addr_load < ram_region.range.start {
            return Err(FlashError::InvalidFlashAlgorithmLoadAddress { address: addr_load });
        }

        let code_start = addr_load + header_size;
        let code_size_bytes = (instructions.len() * size_of::<u32>()) as u64;

        // Round the end of the code up so everything placed after it is
        // stack-aligned.
        let stack_align = Self::required_stack_alignment(target.architecture());
        let code_end = (code_start + code_size_bytes).next_multiple_of(stack_align);

        let buffer_page_size = raw.flash_properties.page_size as u64;

        let stack_size = raw.stack_size.unwrap_or(Self::FLASH_ALGO_STACK_SIZE) as u64;
        tracing::info!("The flash algorithm will be configured with {stack_size} bytes of stack");

        // Data buffers go at the explicit address if given, otherwise right
        // after the code (same region) or at the start of the data region.
        let data_load_addr = if let Some(data_load_addr) = raw.data_load_address {
            data_load_addr
        } else if ram_region == data_ram_region {
            code_end
        } else {
            data_ram_region.range.start
        };

        if data_ram_region.range.end < data_load_addr {
            return Err(FlashError::InvalidDataAddress {
                data_load_addr,
                data_ram: data_ram_region.range.clone(),
            });
        }
        let mut ram_for_data = data_ram_region.range.end - data_load_addr;
        // If code + stack would overlap the data buffers in a shared region,
        // reserve the stack out of the data budget instead.
        if code_end + stack_size > data_load_addr && ram_region == data_ram_region {
            if stack_size > ram_for_data {
                return Err(FlashError::InvalidFlashAlgorithmStackSize { size: stack_size });
            }
            ram_for_data -= stack_size;
        }

        // Two page buffers enable double buffering; one is the minimum.
        let double_buffering = if ram_for_data >= 2 * buffer_page_size {
            true
        } else if ram_for_data >= buffer_page_size {
            false
        } else {
            return Err(FlashError::InvalidFlashAlgorithmStackSize { size: stack_size });
        };

        // Stack sits directly after the code unless it had to be pushed past
        // the page buffers (shared-region overlap case above).
        let stack_bottom =
            if code_end + stack_size <= data_load_addr || ram_region != data_ram_region {
                code_end
            } else {
                let page_count = if double_buffering { 2 } else { 1 };
                (data_load_addr + page_count * buffer_page_size).next_multiple_of(stack_align)
            };

        let stack_top = stack_bottom + stack_size;
        tracing::info!("Stack top: {:#010x}", stack_top);

        if stack_top > ram_region.range.end {
            return Err(FlashError::InvalidFlashAlgorithmStackSize { size: stack_size });
        }

        let page_buffers = if double_buffering {
            let second_buffer_start = data_load_addr + buffer_page_size;
            vec![data_load_addr, second_buffer_start]
        } else {
            vec![data_load_addr]
        };

        tracing::debug!("Page buffers: {:#010x?}", page_buffers);

        let name = raw.name.clone();

        Ok(FlashAlgorithm {
            name,
            default: raw.default,
            load_address: addr_load,
            instructions,
            // Raw entry points are offsets into the code; make them absolute.
            pc_init: raw.pc_init.map(|v| code_start + v),
            pc_uninit: raw.pc_uninit.map(|v| code_start + v),
            pc_program_page: code_start + raw.pc_program_page,
            pc_erase_sector: code_start + raw.pc_erase_sector,
            pc_erase_all: raw.pc_erase_all.map(|v| code_start + v),
            pc_verify: raw.pc_verify.map(|v| code_start + v),
            pc_read: raw.pc_read.map(|v| code_start + v),
            pc_blank_check: raw.pc_blank_check.map(|v| code_start + v),
            static_base: code_start + raw.data_section_offset,
            stack_top,
            stack_size,
            page_buffers,
            rtt_control_block: raw.rtt_location,
            flash_properties: raw.flash_properties.clone(),
            transfer_encoding: raw.transfer_encoding.unwrap_or_default(),
            stack_overflow_check: raw.stack_overflow_check(),
        })
    }
404
405 pub(crate) fn assemble_from_raw_with_core(
407 algo: &RawFlashAlgorithm,
408 core_name: &str,
409 target: &Target,
410 ) -> Result<FlashAlgorithm, FlashError> {
411 let mm = &target.memory_map;
413
414 let ram_regions = mm
415 .iter()
416 .filter_map(MemoryRegion::as_ram_region)
417 .filter(|ram| ram.accessible_by(core_name))
418 .merge_consecutive();
419
420 let ram = ram_regions
421 .clone()
422 .filter(|ram| is_ram_suitable_for_algo(ram, algo.load_address))
423 .max_by_key(|region| region.range.end - region.range.start)
424 .ok_or(FlashError::NoRamDefined {
425 name: target.name.clone(),
426 })?;
427 tracing::info!("Chosen RAM to run the algo: {:x?}", ram);
428
429 let data_ram;
430 let data_ram = if let Some(data_load_address) = algo.data_load_address {
431 data_ram = ram_regions
432 .clone()
433 .find(|ram| is_ram_suitable_for_data(ram, data_load_address))
434 .ok_or(FlashError::NoRamDefined {
435 name: target.name.clone(),
436 })?;
437
438 &data_ram
439 } else {
440 &ram
442 };
443 tracing::info!("Data will be loaded to: {:x?}", data_ram);
444
445 Self::assemble_from_raw_with_data(algo, &ram, data_ram, target)
446 }
447}
448
449fn is_ram_suitable_for_algo(ram: &RamRegion, load_address: Option<u64>) -> bool {
451 if !ram.is_executable() {
452 return false;
453 }
454
455 if let Some(load_addr) = load_address {
462 ram.range.contains(&load_addr)
466 } else {
467 true
468 }
469}
470
471fn is_ram_suitable_for_data(ram: &RamRegion, load_address: u64) -> bool {
473 ram.range.contains(&load_address)
477}
478
#[cfg(test)]
mod test {
    use probe_rs_target::{FlashProperties, SectorDescription, SectorInfo};

    use crate::flashing::FlashAlgorithm;

    // Uniform sector size: every lookup inside the range maps to a sector of
    // the single descriptor's size; lookups outside the range fail.
    #[test]
    fn flash_sector_single_size() {
        let config = FlashAlgorithm {
            flash_properties: FlashProperties {
                sectors: vec![SectorDescription {
                    size: 0x100,
                    address: 0x0,
                }],
                address_range: 0x1000..0x1000 + 0x1000,
                page_size: 0x10,
                ..Default::default()
            },
            ..Default::default()
        };

        let expected_first = SectorInfo {
            base_address: 0x1000,
            size: 0x100,
        };

        // Just below the flash base: no sector.
        assert!(config.sector_info(0x1000 - 1).is_none());

        // First and last byte of the first sector.
        assert_eq!(Some(expected_first), config.sector_info(0x1000));
        assert_eq!(Some(expected_first), config.sector_info(0x10ff));

        // Arbitrary addresses inside the first sector.
        assert_eq!(Some(expected_first), config.sector_info(0x100b));
        assert_eq!(Some(expected_first), config.sector_info(0x10ea));
    }

    // A sector size that is not a power of two must still resolve correctly.
    #[test]
    fn flash_sector_single_size_weird_sector_size() {
        let config = FlashAlgorithm {
            flash_properties: FlashProperties {
                sectors: vec![SectorDescription {
                    size: 258,
                    address: 0x0,
                }],
                address_range: 0x800_0000..0x800_0000 + 258 * 10,
                page_size: 0x10,
                ..Default::default()
            },
            ..Default::default()
        };

        let expected_first = SectorInfo {
            base_address: 0x800_0000,
            size: 258,
        };

        assert!(config.sector_info(0x800_0000 - 1).is_none());

        assert_eq!(Some(expected_first), config.sector_info(0x800_0000));
        assert_eq!(Some(expected_first), config.sector_info(0x800_0000 + 257));

        assert_eq!(Some(expected_first), config.sector_info(0x800_00b));
        assert_eq!(Some(expected_first), config.sector_info(0x800_00e0));
    }

    // Multiple descriptors: the lookup must pick the descriptor governing the
    // address and compute the sector base within that descriptor's group.
    #[test]
    fn flash_sector_multiple_sizes() {
        let config = FlashAlgorithm {
            flash_properties: FlashProperties {
                sectors: vec![
                    SectorDescription {
                        size: 0x4000,
                        address: 0x0,
                    },
                    SectorDescription {
                        size: 0x1_0000,
                        address: 0x1_0000,
                    },
                    SectorDescription {
                        size: 0x2_0000,
                        address: 0x2_0000,
                    },
                ],
                address_range: 0x800_0000..0x800_0000 + 0x10_0000,
                page_size: 0x10,
                ..Default::default()
            },
            ..Default::default()
        };

        let expected_a = SectorInfo {
            base_address: 0x800_4000,
            size: 0x4000,
        };

        let expected_b = SectorInfo {
            base_address: 0x801_0000,
            size: 0x1_0000,
        };

        let expected_c = SectorInfo {
            base_address: 0x80A_0000,
            size: 0x2_0000,
        };

        assert_eq!(Some(expected_a), config.sector_info(0x800_4000));
        assert_eq!(Some(expected_b), config.sector_info(0x801_0000));
        assert_eq!(Some(expected_c), config.sector_info(0x80A_0000));
    }

    // Full iteration across mixed sector sizes: the iterator must emit each
    // sector of each descriptor group, in address order, until range end.
    #[test]
    fn flash_sector_multiple_sizes_iter() {
        let config = FlashAlgorithm {
            flash_properties: FlashProperties {
                sectors: vec![
                    SectorDescription {
                        size: 0x4000,
                        address: 0x0,
                    },
                    SectorDescription {
                        size: 0x1_0000,
                        address: 0x1_0000,
                    },
                    SectorDescription {
                        size: 0x2_0000,
                        address: 0x2_0000,
                    },
                ],
                address_range: 0x800_0000..0x800_0000 + 0x8_0000,
                page_size: 0x10,
                ..Default::default()
            },
            ..Default::default()
        };

        let got: Vec<SectorInfo> = config.iter_sectors().collect();

        let expected = &[
            SectorInfo {
                base_address: 0x800_0000,
                size: 0x4000,
            },
            SectorInfo {
                base_address: 0x800_4000,
                size: 0x4000,
            },
            SectorInfo {
                base_address: 0x800_8000,
                size: 0x4000,
            },
            SectorInfo {
                base_address: 0x800_c000,
                size: 0x4000,
            },
            SectorInfo {
                base_address: 0x801_0000,
                size: 0x1_0000,
            },
            SectorInfo {
                base_address: 0x802_0000,
                size: 0x2_0000,
            },
            SectorInfo {
                base_address: 0x804_0000,
                size: 0x2_0000,
            },
            SectorInfo {
                base_address: 0x806_0000,
                size: 0x2_0000,
            },
        ];
        assert_eq!(&got, expected);
    }
}