use std::path::Path;
use std::time::Instant;
use anyhow::{Context, Result, anyhow};
use colored::Colorize;
use probe_rs::{
MemoryInterface, Permissions, Session, SessionConfig,
config::Registry,
flashing::{
DownloadOptions, FlashLoader, FlashProgress, ProgressEvent, ProgressOperation, erase,
erase_all, run_blank_check,
},
probe::{DebugProbeSelector, WireProtocol, list::Lister},
};
use probe_rs_target::RawFlashAlgorithm;
use xshell::{Shell, cmd};
use crate::commands::elf::cmd_elf;
/// Generates a target definition from `target_artifact` and exercises its flash algorithm on a
/// connected probe.
///
/// Steps performed:
/// 1. Copies `template_path` to `definition_export_path` and fills it in via `cmd_elf`
///    (optionally renaming the algorithm to `name`).
/// 2. Tries to write debug artifacts for `target_artifact`; a failure there is only reported.
/// 3. Loads the generated YAML into a fresh [`Registry`] and picks the target. If the family
///    defines more than one target, `chip` must name the one to use.
/// 4. Attaches to the single probe matching `probe` (with optional `speed` / `protocol`).
/// 5. Runs three erase / blank-check / write / read-back cycles on the two sectors starting at
///    `test_start_sector_address`:
///    - sector erase, double buffering disabled;
///    - chip erase, double buffering disabled;
///    - sector erase, double buffering enabled.
///
/// If `test_start_sector_address` is `None`, the first flash algorithm of the target is used and
/// the test starts at the beginning of its flash range.
///
/// # Errors
/// Returns an error if:
/// - an input path is invalid;
/// - the definition cannot be generated or parsed;
/// - no unique target or probe can be selected;
/// - the target has no usable flash algorithm or sector description;
/// - the test address is not sector aligned, or does not leave two whole sectors inside the
///   flash range;
/// - any flash operation fails;
/// - the data read back does not match what was written.
#[expect(clippy::too_many_arguments)]
pub fn cmd_test(
    target_artifact: &Path,
    template_path: &Path,
    definition_export_path: &Path,
    test_start_sector_address: Option<u64>,
    chip: Option<String>,
    name: Option<String>,
    probe: Option<DebugProbeSelector>,
    speed: Option<u32>,
    protocol: Option<WireProtocol>,
) -> Result<()> {
    ensure_is_file(target_artifact)?;
    ensure_is_file(template_path)?;
    anyhow::ensure!(
        !definition_export_path.is_dir(),
        "'{}' is a directory. Please specify a file name.",
        definition_export_path.display()
    );

    println!("Generating the YAML file in `{definition_export_path:?}`");
    std::fs::copy(template_path, definition_export_path).with_context(|| {
        format!(
            "Failed to copy template file from '{}' to '{}'",
            template_path.display(),
            definition_export_path.display()
        )
    })?;

    cmd_elf(
        target_artifact,
        true,
        Some(definition_export_path),
        true,
        name,
    )?;

    // Debug artifacts are a convenience only; failing to produce them must not abort the test.
    if let Err(error) = generate_debug_info(target_artifact) {
        println!("Generating debug artifacts failed because:");
        println!("{error}");
    }

    let mut registry = Registry::new();
    let yaml = std::fs::read_to_string(definition_export_path)?;
    let family_name = registry.add_target_family_from_yaml(&yaml)?;
    let targets = registry
        .get_targets_by_family_name(&family_name)
        .with_context(|| format!("Failed to get targets of {family_name}"))?;

    // A lone target is used as-is; otherwise `--chip` selects among several.
    let target_name = match (targets.len(), chip.as_ref()) {
        (0, _) => return Err(anyhow!("No targets found for family {family_name}")),
        (1, _) => &targets[0],
        (count, None) => {
            return Err(anyhow!(
                "{count} targets found for family {family_name}: {targets:#?}. Specify the desired target with --chip."
            ));
        }
        (_, Some(chip)) => targets
            .iter()
            .find(|target| *target == chip)
            .with_context(|| format!("No target found for chip {chip}"))?,
    };

    // The chip-erase phase below calls `erase_all`, so the session must permit erasing everything.
    let permissions = Permissions::new().allow_erase_all();
    let session_config = SessionConfig {
        permissions,
        speed,
        protocol,
    };

    let lister = Lister::new();
    let available_probes = lister.list(probe.as_ref());
    if available_probes.len() > 1 {
        return Err(anyhow!(
            "Multiple probes were found -- please specify a probe with `--probe`"
        ));
    }
    let Some(probe) = available_probes.first() else {
        return Err(anyhow!("No probes were found"));
    };
    let mut probe = probe.open()?;
    if let Some(speed) = session_config.speed {
        probe.set_speed(speed)?;
    }
    if let Some(protocol) = session_config.protocol {
        probe.select_protocol(protocol)?;
    }

    let mut session =
        probe.attach_with_registry(target_name, session_config.permissions, &registry)?;
    let mut progress = progress_callbacks();

    // Pick the flash algorithm whose (half-open) address range contains the requested address,
    // or the first algorithm when no address was given.
    let flash_algorithm = match test_start_sector_address {
        Some(address) => session
            .target()
            .flash_algorithms
            .iter()
            .find(|algorithm: &&RawFlashAlgorithm| {
                algorithm.flash_properties.address_range.contains(&address)
            })
            .context("No flash algorithm matching specified address can be found")?,
        None => session
            .target()
            .flash_algorithms
            .first()
            .context("The target does not define any flash algorithm")?,
    };

    let flash_properties = &flash_algorithm.flash_properties;
    let start_address = flash_properties.address_range.start;
    let end_address = flash_properties.address_range.end;
    let data_size = flash_properties.page_size;
    let sector_size = flash_properties
        .sectors
        .first()
        .map(|sector| sector.size)
        .context("The flash algorithm does not describe any sectors")?;

    let test_start_sector_address = test_start_sector_address.unwrap_or(start_address);
    // Every phase touches two whole sectors. `end_address` is exclusive, so the (exclusive) end of
    // the tested region must not exceed it. Saturating math turns overflow into a range failure.
    let test_end_address =
        test_start_sector_address.saturating_add(sector_size.saturating_mul(2));
    if sector_size == 0
        || test_start_sector_address < start_address
        || test_end_address > end_address
        || !test_start_sector_address.is_multiple_of(sector_size)
    {
        return Err(anyhow!(
            "test_start_sector_address ({test_start_sector_address:#x}) must be a sector aligned \
             address with two whole sectors of flash after it (flash range \
             {start_address:#x}..{end_address:#x}, sector size {sector_size:#x})"
        ));
    }

    // Writing `page_size` bytes starting one byte past the sector start makes the write straddle
    // a page boundary, so two pages are programmed (assumes pages are aligned to the sector start).
    let write_address = test_start_sector_address + 1;
    let data = (0..data_size).map(|n| (n % 256) as u8).collect::<Vec<_>>();

    let test = "Test".green();

    println!("{test}: Erasing sectorwise and writing two pages ...");
    run_flash_erase(
        &mut session,
        EraseType::EraseRange(test_start_sector_address, test_end_address),
    )?;
    println!("{test}: Erase done");
    run_blank_check(
        &mut session,
        &mut progress,
        test_start_sector_address,
        test_end_address,
        true,
    )?;
    println!("{test}: Writing two pages ...");
    let mut loader = session.target().flash_loader();
    loader.add_data(write_address, &data)?;
    run_flash_download(&mut session, loader, true)?;
    println!("{test}: Write done");
    let mut readback = vec![0; data.len()];
    session.core(0)?.read(write_address, &mut readback)?;
    anyhow::ensure!(
        readback == data,
        "Read-back after sectorwise erase does not match the written data"
    );

    println!("{test}: Erasing the entire chip and writing two pages ...");
    run_flash_erase(&mut session, EraseType::EraseAll)?;
    println!("{test}: Erase done");
    run_blank_check(
        &mut session,
        &mut progress,
        test_start_sector_address,
        test_end_address,
        true,
    )?;
    println!("{test}: Writing two pages ...");
    let mut loader = session.target().flash_loader();
    loader.add_data(write_address, &data)?;
    run_flash_download(&mut session, loader, true)?;
    println!("{test}: Write done");
    let mut readback = vec![0; data.len()];
    session.core(0)?.read_8(write_address, &mut readback)?;
    anyhow::ensure!(
        readback == data,
        "Read-back after chip erase does not match the written data"
    );

    println!("{test}: Erasing sectorwise and writing two pages double buffered ...");
    run_flash_erase(
        &mut session,
        EraseType::EraseRange(test_start_sector_address, test_end_address),
    )?;
    println!("{test}: Erase done");
    run_blank_check(
        &mut session,
        &mut progress,
        test_start_sector_address,
        test_end_address,
        true,
    )?;
    println!("{test}: Writing two pages ...");
    let mut loader = session.target().flash_loader();
    loader.add_data(write_address, &data)?;
    // `false` keeps double buffering enabled for this phase.
    run_flash_download(&mut session, loader, false)?;
    println!("{test}: Write done");
    let mut readback = vec![0; data.len()];
    session.core(0)?.read_8(write_address, &mut readback)?;
    anyhow::ensure!(
        readback == data,
        "Read-back after double-buffered write does not match the written data"
    );

    Ok(())
}
/// Builds a [`FlashProgress`] handler that reports how long each erase and program operation
/// took, and echoes diagnostic messages with a yellow `Message` prefix.
///
/// The timer is reset whenever an erase or program operation starts. Finished/failed events
/// print the time elapsed since that reset. All other events are ignored.
fn progress_callbacks() -> FlashProgress<'static> {
    let mut started = Instant::now();
    FlashProgress::new(move |event| match event {
        ProgressEvent::Started(ProgressOperation::Program | ProgressOperation::Erase) => {
            started = Instant::now();
        }
        ProgressEvent::Finished(ProgressOperation::Erase) => {
            println!("Finished erasing in {:?}", started.elapsed());
        }
        ProgressEvent::Failed(ProgressOperation::Erase) => {
            println!("Failed erasing in {:?}", started.elapsed());
        }
        ProgressEvent::Finished(ProgressOperation::Program) => {
            println!("Finished programming in {:?}", started.elapsed());
        }
        ProgressEvent::Failed(ProgressOperation::Program) => {
            println!("Failed programming in {:?}", started.elapsed());
        }
        ProgressEvent::DiagnosticMessage { message } => {
            // Avoid a blank line when the message already carries its own newline.
            let terminator = if message.ends_with('\n') { "" } else { "\n" };
            print!("{}: {message}{terminator}", "Message".yellow());
        }
        _ => {}
    })
}
fn ensure_is_file(file_path: &Path) -> Result<()> {
anyhow::ensure!(
file_path.is_file(),
"'{}' does not seem to be a valid file.",
file_path.display()
);
Ok(())
}
/// Commits the data queued in `loader` to flash, reporting timing through
/// [`progress_callbacks`].
///
/// The download skips erasing (`skip_erase = true`) and does not keep unwritten bytes, so the
/// target range must already be erased. `disable_double_buffering` is passed through to
/// [`DownloadOptions`].
///
/// # Errors
/// Propagates any error returned by `FlashLoader::commit`.
pub fn run_flash_download(
    session: &mut Session,
    loader: FlashLoader,
    disable_double_buffering: bool,
) -> Result<()> {
    let mut options = DownloadOptions::default();
    options.keep_unwritten_bytes = false;
    options.skip_erase = true;
    options.disable_double_buffering = disable_double_buffering;
    options.progress = progress_callbacks();

    loader.commit(session, options)?;
    Ok(())
}
/// Which erase [`run_flash_erase`] should perform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EraseType {
    /// Erase the whole chip via `erase_all`.
    EraseAll,
    /// Erase the address range `(start, end)` via `erase`.
    EraseRange(u64, u64),
}
/// Erases flash on `session` as selected by `erase_type`, reporting timing through
/// [`progress_callbacks`].
///
/// # Errors
/// Propagates any error from `erase` / `erase_all`.
pub fn run_flash_erase(session: &mut Session, erase_type: EraseType) -> Result<()> {
    let mut progress = progress_callbacks();
    match erase_type {
        EraseType::EraseAll => {
            erase_all(session, &mut progress, true)?;
        }
        EraseType::EraseRange(start, end) => {
            erase(session, &mut progress, start, end, true)?;
        }
    }
    Ok(())
}
/// Writes inspection dumps of `target_artifact` into `target/`, relative to the current
/// working directory.
///
/// The files written are:
/// - `disassembly.s`: stdout of `rust-objdump --disassemble`;
/// - `dump.txt`: stdout of `rust-objdump -x`;
/// - `nm.txt`: stdout of `rust-nm -n`.
///
/// # Errors
/// Returns the first error from spawning or running a tool, or from writing an output file.
/// Files produced before the failure are left in place.
fn generate_debug_info(target_artifact: &Path) -> Result<()> {
    let sh = Shell::new()?;

    let disassembly = cmd!(sh, "rust-objdump --disassemble {target_artifact}").output()?;
    std::fs::write("target/disassembly.s", disassembly.stdout)?;

    let headers = cmd!(sh, "rust-objdump -x {target_artifact}").output()?;
    std::fs::write("target/dump.txt", headers.stdout)?;

    let symbols = cmd!(sh, "rust-nm {target_artifact} -n").output()?;
    std::fs::write("target/nm.txt", symbols.stdout)?;

    Ok(())
}