use std::{borrow::Cow, str::FromStr, thread::sleep};
use bytemuck::{Pod, Zeroable, __core::time::Duration};
use esp_idf_part::PartitionTable;
use log::{debug, info, warn};
use serialport::UsbPortInfo;
use strum::{Display, EnumIter, EnumVariantNames};
use self::stubs::FlashStub;
use crate::{
command::{Command, CommandType},
connection::Connection,
elf::{ElfFirmwareImage, FirmwareImage, RomSegment},
error::{ConnectionError, Error, FlashDetectError, ResultExt},
image_format::ImageFormatKind,
interface::Interface,
targets::Chip,
};
// Flash stub images for each chip; `Flasher::load_stub` uploads one via `FlashStub::get`.
mod stubs;
/// Seed for [`checksum`]. NOTE(review): presumably passed in by callers elsewhere in the crate — not used in this file.
pub(crate) const CHECKSUM_INIT: u8 = 0xEF;
/// Size of one flash sector in bytes (4 KiB); the granularity used by [`get_erase_size`].
pub(crate) const FLASH_SECTOR_SIZE: usize = 0x1000;
/// Block size in bytes announced in the ESP8266 `FlashBegin` command sent by `Flasher::enable_flash`.
pub(crate) const FLASH_WRITE_SIZE: usize = 0x400;
/// Register whose value `Chip::from_magic` maps to the connected chip type.
const CHIP_DETECT_MAGIC_REG_ADDR: u32 = 0x40001000;
/// Timeout applied to the connection right after it is opened in `Flasher::connect`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(3);
/// Bytes the flash stub must send once it is running; checked in `Flasher::load_stub`.
const EXPECTED_STUB_HANDSHAKE: &str = "OHAI";
// Only used to derive FLASH_SECTORS_PER_BLOCK below.
// NOTE(review): 0x100 looks like a page size rather than an erase-block size;
// the ratio still yields 16 sectors (a 64 KiB block) — confirm the naming.
const FLASH_BLOCK_SIZE: usize = 0x100;
/// Number of sectors per erase block, as used by [`get_erase_size`] (evaluates to 16).
const FLASH_SECTORS_PER_BLOCK: usize = FLASH_SECTOR_SIZE / FLASH_BLOCK_SIZE;
/// SPI flash clock frequency.
///
/// `Display` (derived via `strum`) prints the string given in each variant's
/// `serialize` attribute, e.g. `"40M"`. The default is `Flash40M` (40 MHz).
/// Parsing is implemented separately by `FromStr` and is case-insensitive.
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq, Display, EnumVariantNames)]
#[repr(u8)]
pub enum FlashFrequency {
#[strum(serialize = "12M")]
Flash12M,
#[strum(serialize = "15M")]
Flash15M,
#[strum(serialize = "16M")]
Flash16M,
#[strum(serialize = "20M")]
Flash20M,
#[strum(serialize = "24M")]
Flash24M,
#[strum(serialize = "26M")]
Flash26M,
#[strum(serialize = "30M")]
Flash30M,
#[default]
#[strum(serialize = "40M")]
Flash40M,
#[strum(serialize = "48M")]
Flash48M,
#[strum(serialize = "60M")]
Flash60M,
#[strum(serialize = "80M")]
Flash80M,
}
impl FromStr for FlashFrequency {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
use FlashFrequency::*;
match s.to_uppercase().as_str() {
"12M" => Ok(Flash12M),
"15M" => Ok(Flash15M),
"16M" => Ok(Flash16M),
"20M" => Ok(Flash20M),
"24M" => Ok(Flash24M),
"26M" => Ok(Flash26M),
"30M" => Ok(Flash30M),
"40M" => Ok(Flash40M),
"48M" => Ok(Flash48M),
"60M" => Ok(Flash60M),
"80M" => Ok(Flash80M),
_ => Err(Error::InvalidFlashFrequency(s.to_string())),
}
}
}
/// SPI flash access mode.
///
/// `strum`'s variant names are the lowercase identifiers (`"qio"`, `"qout"`,
/// `"dio"`, `"dout"`), and the `FromStr` impl accepts them case-insensitively.
/// The default is `Dio`.
#[derive(Copy, Clone, Debug, Default, EnumVariantNames)]
#[strum(serialize_all = "lowercase")]
pub enum FlashMode {
Qio,
Qout,
#[default]
Dio,
Dout,
}
impl FromStr for FlashMode {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mode = match s.to_lowercase().as_str() {
"qio" => FlashMode::Qio,
"qout" => FlashMode::Qout,
"dio" => FlashMode::Dio,
"dout" => FlashMode::Dout,
_ => return Err(Error::InvalidFlashMode(s.to_string())),
};
Ok(mode)
}
}
/// Capacity of the SPI flash chip.
///
/// Each discriminant is the size-ID byte read back from the flash during
/// detection (`Flasher::flash_detect` converts it with `FlashSize::from`).
/// `Display` and `strum`'s variant names use the `serialize` strings such as
/// `"4M"`. The default is `Flash4Mb`. Use `size()` for the capacity in bytes.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Display, EnumVariantNames, EnumIter)]
#[repr(u8)]
pub enum FlashSize {
#[strum(serialize = "256K")]
Flash256Kb = 0x12,
#[strum(serialize = "512K")]
Flash512Kb = 0x13,
#[strum(serialize = "1M")]
Flash1Mb = 0x14,
#[strum(serialize = "2M")]
Flash2Mb = 0x15,
#[default]
#[strum(serialize = "4M")]
Flash4Mb = 0x16,
#[strum(serialize = "8M")]
Flash8Mb = 0x17,
#[strum(serialize = "16M")]
Flash16Mb = 0x18,
#[strum(serialize = "32M")]
Flash32Mb = 0x19,
#[strum(serialize = "64M")]
Flash64Mb = 0x1a,
#[strum(serialize = "128M")]
Flash128Mb = 0x21,
}
impl FlashSize {
fn from(value: u8) -> Result<FlashSize, Error> {
match value {
0x12 => Ok(FlashSize::Flash256Kb),
0x13 => Ok(FlashSize::Flash512Kb),
0x14 => Ok(FlashSize::Flash1Mb),
0x15 => Ok(FlashSize::Flash2Mb),
0x16 => Ok(FlashSize::Flash4Mb),
0x17 => Ok(FlashSize::Flash8Mb),
0x18 => Ok(FlashSize::Flash16Mb),
0x19 => Ok(FlashSize::Flash32Mb),
0x1a => Ok(FlashSize::Flash64Mb),
0x21 => Ok(FlashSize::Flash128Mb),
_ => Err(Error::UnsupportedFlash(FlashDetectError::from(value))),
}
}
pub fn size(self) -> u32 {
match self {
FlashSize::Flash256Kb => 0x0040000,
FlashSize::Flash512Kb => 0x0080000,
FlashSize::Flash1Mb => 0x0100000,
FlashSize::Flash2Mb => 0x0200000,
FlashSize::Flash4Mb => 0x0400000,
FlashSize::Flash8Mb => 0x0800000,
FlashSize::Flash16Mb => 0x1000000,
FlashSize::Flash32Mb => 0x2000000,
FlashSize::Flash64Mb => 0x4000000,
FlashSize::Flash128Mb => 0x8000000,
}
}
}
impl FromStr for FlashSize {
    type Err = Error;

    /// Parses a flash size such as `"4M"` or `"256k"`.
    ///
    /// The input is upper-cased and compared against `strum`'s variant names
    /// (`"256K"`, `"512K"`, `"1M"`, …, `"128M"`). The matching position in
    /// `VARIANTS` selects the variant at the same position in `iter()`.
    ///
    /// # Errors
    /// Returns `Error::InvalidFlashSize` with the original input if nothing matches.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        use strum::{IntoEnumIterator, VariantNames};

        let wanted = s.to_uppercase();
        FlashSize::VARIANTS
            .iter()
            .position(|name| *name == wanted)
            .and_then(|index| FlashSize::iter().nth(index))
            .ok_or_else(|| Error::InvalidFlashSize(s.to_string()))
    }
}
/// SPI pin assignment sent to the target with the `SpiAttach` /
/// `SpiAttachStub` commands.
///
/// Each field is a pin number. See [`SpiAttachParams::encode`] for the wire format.
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct SpiAttachParams {
    clk: u8,
    q: u8,
    d: u8,
    hd: u8,
    cs: u8,
}

impl SpiAttachParams {
    /// All pins set to 0.
    ///
    /// NOTE(review): an all-zero configuration presumably selects the chip's
    /// default SPI flash pins — confirm against the ROM documentation.
    pub const fn default() -> Self {
        SpiAttachParams {
            clk: 0,
            q: 0,
            d: 0,
            hd: 0,
            cs: 0,
        }
    }

    /// Pin assignment used for the ESP32-PICO-D4
    /// (CLK 6, Q 17, D 8, HD 11, CS 16).
    pub const fn esp32_pico_d4() -> Self {
        SpiAttachParams {
            clk: 6,
            q: 17,
            d: 8,
            hd: 11,
            cs: 16,
        }
    }

    /// Encodes the pins into the attach-command payload.
    ///
    /// The pins are packed into a little-endian `u32` as
    /// `clk | q << 6 | d << 12 | cs << 18 | hd << 24`. When `stub` is `false`,
    /// four zero bytes are appended (8 bytes total, as esptool sends to the ROM
    /// loader). With the stub, only the 4-byte word is sent.
    pub fn encode(self, stub: bool) -> Vec<u8> {
        // clk, q, d and cs occupy 6-bit fields; larger values would corrupt
        // the neighbouring field.
        debug_assert!(self.clk < 64 && self.q < 64 && self.d < 64 && self.cs < 64);

        let packed = ((self.hd as u32) << 24)
            | ((self.cs as u32) << 18)
            | ((self.d as u32) << 12)
            | ((self.q as u32) << 6)
            | (self.clk as u32);

        let mut encoded = packed.to_le_bytes().to_vec();
        if !stub {
            encoded.extend_from_slice(&[0u8; 4]);
        }
        encoded
    }
}
/// SPI attach configurations tried, in order, by `Flasher::spi_autodetect`:
/// the all-zero default first, then the ESP32-PICO-D4 pinout.
const TRY_SPI_PARAMS: [SpiAttachParams; 2] =
[SpiAttachParams::default(), SpiAttachParams::esp32_pico_d4()];
/// `#[repr(C)]` plain-old-data record of four `u32`s (derives `bytemuck::Pod`,
/// so it can be viewed as raw bytes).
///
/// NOTE(review): presumably the header preceding a data block in a flash/RAM
/// data command (payload size, sequence number, two reserved words). It is not
/// referenced in this file — confirm usage.
#[derive(Zeroable, Pod, Copy, Clone, Debug)]
#[repr(C)]
struct BlockParams {
size: u32,
sequence: u32,
dummy1: u32,
dummy2: u32,
}
/// `#[repr(C)]` plain-old-data record of five `u32`s (derives `bytemuck::Pod`).
///
/// NOTE(review): the fields mirror those of `Command::FlashBegin` (size,
/// blocks, block size, offset, encryption flag), so this is presumably its
/// wire layout. It is not referenced in this file — confirm usage.
#[derive(Zeroable, Pod, Copy, Clone, Debug)]
#[repr(C)]
struct BeginParams {
size: u32,
blocks: u32,
block_size: u32,
offset: u32,
encrypted: u32,
}
/// `#[repr(C)]` plain-old-data pair of `u32`s (derives `bytemuck::Pod`).
///
/// NOTE(review): presumably the payload of a "finish RAM load" command (a flag
/// saying whether to jump, and the entry address). It is not referenced in this
/// file — confirm usage.
#[derive(Zeroable, Pod, Copy, Clone)]
#[repr(C)]
struct EntryParams {
no_entry: u32,
entry: u32,
}
/// A connected target chip, together with the state needed to load and flash it.
///
/// Created with `Flasher::connect`.
pub struct Flasher {
/// Connection to the target used for every command.
connection: Connection,
/// Chip identified from the magic register when connecting.
chip: Chip,
/// Flash size found by SPI autodetection (a 4 MB placeholder until then).
flash_size: FlashSize,
/// SPI attach parameters with which flash detection succeeded.
spi_params: SpiAttachParams,
/// Whether the flash stub was loaded; selects stub variants of commands and flash targets.
use_stub: bool,
}
impl Flasher {
    /// Connects to a target, identifies the chip and prepares its SPI flash.
    ///
    /// Steps, in order:
    /// 1. start the connection and apply `DEFAULT_TIMEOUT`;
    /// 2. read `CHIP_DETECT_MAGIC_REG_ADDR` and map it to a [`Chip`];
    /// 3. if `use_stub` is set, upload and start the flash stub;
    /// 4. try each entry of `TRY_SPI_PARAMS` until the flash is detected (this
    ///    also records the detected flash size);
    /// 5. if `speed` is given, switch to that baud rate. A warning is logged for
    ///    rates above 115,200. The switch is skipped (with a warning) for the
    ///    ESP8266.
    ///
    /// # Errors
    /// Returns connection and chip-detection errors, `Error::FlashConnect` when
    /// no SPI configuration lets the flash be detected, and any error from
    /// loading the stub or changing the baud rate.
    pub fn connect(
        serial: Interface,
        port_info: UsbPortInfo,
        speed: Option<u32>,
        use_stub: bool,
    ) -> Result<Self, Error> {
        let mut connection = Connection::new(serial, port_info);
        connection.begin()?;
        connection.set_timeout(DEFAULT_TIMEOUT)?;

        let magic = connection.read_reg(CHIP_DETECT_MAGIC_REG_ADDR)?;
        let chip = Chip::from_magic(magic)?;

        let mut flasher = Flasher {
            connection,
            chip,
            // Placeholders until `spi_autodetect` finds working SPI parameters.
            flash_size: FlashSize::Flash4Mb,
            spi_params: SpiAttachParams::default(),
            use_stub,
        };

        if use_stub {
            info!("Using flash stub");
            flasher.load_stub()?;
        }

        flasher.spi_autodetect()?;

        if let Some(baud) = speed {
            if flasher.chip == Chip::Esp8266 {
                // The baud rate is never changed for the ESP8266; tell the user
                // instead of silently dropping the request.
                warn!("Ignoring requested baud rate {} for the ESP8266", baud);
            } else {
                if baud > 115_200 {
                    warn!("Setting baud rate higher than 115,200 can cause issues");
                }
                // Apply every explicit request, not only rates above 115,200.
                flasher.change_baud(baud)?;
            }
        }

        Ok(flasher)
    }

    /// Uploads the flash stub for the detected chip into RAM and starts it.
    ///
    /// The stub's text and data segments are written through the chip's RAM
    /// target, and the stub must then answer with `EXPECTED_STUB_HANDSHAKE`.
    /// Afterwards the chip is re-identified from the magic register. The result
    /// is only logged.
    ///
    /// # Errors
    /// `Error::Connection(ConnectionError::InvalidStubHandshake)` when the
    /// handshake is missing or wrong, plus any transfer or register-read error.
    fn load_stub(&mut self) -> Result<(), Error> {
        debug!("Loading flash stub for chip: {:?}", self.chip);

        let stub = FlashStub::get(self.chip);

        let mut ram_target = self.chip.ram_target(
            Some(stub.entry()),
            self.chip
                .into_target()
                .max_ram_block_size(&mut self.connection)?,
        );
        ram_target.begin(&mut self.connection).flashing()?;

        let (text_addr, text) = stub.text();
        debug!("Write {} byte stub text", text.len());
        ram_target
            .write_segment(
                &mut self.connection,
                RomSegment {
                    addr: text_addr,
                    data: Cow::Borrowed(&text),
                },
                None,
            )
            .flashing()?;

        let (data_addr, data) = stub.data();
        debug!("Write {} byte stub data", data.len());
        ram_target
            .write_segment(
                &mut self.connection,
                RomSegment {
                    addr: data_addr,
                    data: Cow::Borrowed(&data),
                },
                None,
            )
            .flashing()?;

        debug!("Finish stub write");
        ram_target.finish(&mut self.connection, true).flashing()?;
        debug!("Stub written!");

        match self.connection.read(EXPECTED_STUB_HANDSHAKE.len())? {
            Some(resp) if resp == EXPECTED_STUB_HANDSHAKE.as_bytes() => Ok(()),
            _ => Err(Error::Connection(ConnectionError::InvalidStubHandshake)),
        }?;

        let magic = self.connection.read_reg(CHIP_DETECT_MAGIC_REG_ADDR)?;
        let chip = Chip::from_magic(magic)?;
        debug!("Re-detected chip: {:?}", chip);

        Ok(())
    }

    /// Tries each entry of `TRY_SPI_PARAMS` until flash detection succeeds.
    ///
    /// A failure to enable the flash is only logged. Detection is still
    /// attempted with the same parameters. On success the detected size and the
    /// working parameters are stored on `self`.
    ///
    /// # Errors
    /// `Error::FlashConnect` when no parameter set works, or any error raised by
    /// the detection command itself.
    fn spi_autodetect(&mut self) -> Result<(), Error> {
        for spi_params in TRY_SPI_PARAMS.iter().copied() {
            debug!("Attempting flash enable with: {:?}", spi_params);

            if let Err(_e) = self.enable_flash(spi_params) {
                debug!("Flash enable failed");
            }

            if let Some(flash_size) = self.flash_detect()? {
                debug!("Flash detect OK!");
                self.flash_size = flash_size;
                self.spi_params = spi_params;
                return Ok(());
            }

            debug!("Flash detect failed");
        }

        debug!("SPI flash autodetection failed");
        Err(Error::FlashConnect)
    }

    /// Issues the `FlashDetect` SPI command and decodes the flash size.
    ///
    /// The size ID is bits 16..24 of the 24-bit reply. Returns `Ok(None)` when
    /// it is `0xFF`, which `spi_autodetect` treats as "try other parameters".
    /// Unknown size IDs are logged and mapped to 4 MB rather than failing.
    fn flash_detect(&mut self) -> Result<Option<FlashSize>, Error> {
        const FLASH_RETRY: u8 = 0xFF;

        let flash_id = self.spi_command(CommandType::FlashDetect, &[], 24)?;
        let size_id = (flash_id >> 16) as u8;

        if size_id == FLASH_RETRY {
            return Ok(None);
        }

        let flash_size = match FlashSize::from(size_id) {
            Ok(size) => size,
            Err(_) => {
                warn!(
                    "Could not detect flash size (FlashID=0x{:02X}, SizeID=0x{:02X}), defaulting to 4MB",
                    flash_id,
                    size_id
                );
                FlashSize::Flash4Mb
            }
        };

        Ok(Some(flash_size))
    }

    /// Prepares the flash for SPI commands.
    ///
    /// On the ESP8266 this sends an empty `FlashBegin`, and `spi_params` is
    /// unused. On other chips it sends `SpiAttach` (or `SpiAttachStub` when the
    /// stub is in use) with `spi_params`, using that command's own timeout.
    fn enable_flash(&mut self, spi_params: SpiAttachParams) -> Result<(), Error> {
        match self.chip {
            Chip::Esp8266 => {
                self.connection.command(Command::FlashBegin {
                    supports_encryption: false,
                    offset: 0,
                    block_size: FLASH_WRITE_SIZE as u32,
                    size: 0,
                    blocks: 0,
                })?;
            }
            _ => {
                self.connection
                    .with_timeout(CommandType::SpiAttach.timeout(), |connection| {
                        connection.command(if self.use_stub {
                            Command::SpiAttachStub { spi_params }
                        } else {
                            Command::SpiAttach { spi_params }
                        })
                    })?;
            }
        }
        Ok(())
    }

    /// Runs a raw SPI flash command by driving the chip's SPI controller
    /// registers directly. Returns the contents of `W0` afterwards.
    ///
    /// `data` (at most 64 bytes) is sent in the MOSI phase, and `read_bits` (at
    /// most 32) bits are clocked in during the MISO phase. The USR and USR2
    /// registers are restored on success.
    ///
    /// # Panics
    /// If `read_bits > 32` or `data.len() > 64`.
    ///
    /// # Errors
    /// Register access errors, or `ConnectionError::Timeout` if the controller
    /// does not finish within about 10 polls.
    fn spi_command(
        &mut self,
        command: CommandType,
        data: &[u8],
        read_bits: u32,
    ) -> Result<u32, Error> {
        // Only W0 is read back (<= 32 bits) and the W0.. buffer holds 64 bytes;
        // these are the limits esptool enforces as well.
        assert!(
            read_bits <= 32,
            "reading more than 32 bits from a SPI flash command is unsupported"
        );
        assert!(
            data.len() <= 64,
            "writing more than 64 bytes with one SPI flash command is unsupported"
        );

        let spi_registers = self.chip.into_target().spi_registers();

        // Save the registers this routine overwrites so they can be restored.
        let old_spi_usr = self.connection.read_reg(spi_registers.usr())?;
        let old_spi_usr2 = self.connection.read_reg(spi_registers.usr2())?;

        // Command phase (bit 31), plus MOSI (bit 27) / MISO (bit 28) as needed,
        // matching esptool's SPI_USR_COMMAND / SPI_USR_MOSI / SPI_USR_MISO.
        let mut flags = 1 << 31;
        if !data.is_empty() {
            flags |= 1 << 27;
        }
        if read_bits > 0 {
            flags |= 1 << 28;
        }

        self.connection
            .write_reg(spi_registers.usr(), flags, None)?;
        // NOTE(review): `7 << 28` presumably encodes an 8-bit command length
        // (stored as length - 1); the opcode goes in the low bits.
        self.connection
            .write_reg(spi_registers.usr2(), (7 << 28) | command as u32, None)?;

        if let (Some(mosi_data_length), Some(miso_data_length)) =
            (spi_registers.mosi_length(), spi_registers.miso_length())
        {
            // Chips with dedicated data-length registers.
            if !data.is_empty() {
                self.connection
                    .write_reg(mosi_data_length, data.len() as u32 * 8 - 1, None)?;
            }
            if read_bits > 0 {
                self.connection
                    .write_reg(miso_data_length, read_bits - 1, None)?;
            }
        } else {
            // Otherwise both bit lengths are packed into USR1.
            let mosi_mask = if data.is_empty() {
                0
            } else {
                data.len() as u32 * 8 - 1
            };
            let miso_mask = if read_bits == 0 { 0 } else { read_bits - 1 };
            self.connection.write_reg(
                spi_registers.usr1(),
                (miso_mask << 8) | (mosi_mask << 17),
                None,
            )?;
        }

        if data.is_empty() {
            // Clear W0 so the value read back afterwards is not stale.
            self.connection.write_reg(spi_registers.w0(), 0, None)?;
        } else {
            for (i, chunk) in data.chunks(4).enumerate() {
                let mut word = [0u8; 4];
                word[..chunk.len()].copy_from_slice(chunk);
                // W0, W1, ... are consecutive 32-bit registers, i.e. 4 bytes apart.
                let reg = spi_registers.w0() + 4 * i as u32;
                self.connection
                    .write_reg(reg, u32::from_le_bytes(word), None)?;
            }
        }

        // Start the transaction by setting bit 18 (USR) of the CMD register...
        self.connection
            .write_reg(spi_registers.cmd(), 1 << 18, None)?;

        // ...and wait until the controller clears that same bit again.
        let mut attempts = 0;
        loop {
            sleep(Duration::from_millis(1));
            if self.connection.read_reg(spi_registers.cmd())? & (1 << 18) == 0 {
                break;
            }

            attempts += 1;
            if attempts > 10 {
                return Err(Error::Connection(ConnectionError::Timeout(command.into())));
            }
        }

        let result = self.connection.read_reg(spi_registers.w0())?;

        self.connection
            .write_reg(spi_registers.usr(), old_spi_usr, None)?;
        self.connection
            .write_reg(spi_registers.usr2(), old_spi_usr2, None)?;

        Ok(result)
    }

    /// Returns the underlying connection.
    pub fn connection(&mut self) -> &mut Connection {
        &mut self.connection
    }

    /// Returns the chip detected when connecting.
    pub fn chip(&self) -> Chip {
        self.chip
    }

    /// Prints chip information to stdout.
    ///
    /// Prints the chip type (with revision, except on the ESP8266), crystal
    /// frequency, detected flash size, feature list and MAC address.
    pub fn board_info(&mut self) -> Result<(), Error> {
        let chip = self.chip();
        let target = chip.into_target();

        let features = target.chip_features(self.connection())?;
        let freq = target.crystal_freq(self.connection())?;
        let mac = target.mac_address(self.connection())?;

        print!("Chip type: {chip}");
        if chip != Chip::Esp8266 {
            let (major, minor) = target.chip_revision(self.connection())?;
            println!(" (revision v{major}.{minor})");
        } else {
            println!();
        }
        println!("Crystal frequency: {freq}MHz");
        println!("Flash size: {}", self.flash_size);
        println!("Features: {}", features.join(", "));
        println!("MAC address: {mac}");

        Ok(())
    }

    /// Loads an ELF image into RAM through the chip's RAM target.
    ///
    /// The image's entry point is passed to the RAM target, and `finish` is
    /// called with `true` (NOTE(review): presumably "jump to entry" — confirm).
    ///
    /// # Errors
    /// `Error::ElfNotRamLoadable` if the image has any ROM segments, plus any
    /// parsing or transfer error.
    pub fn load_elf_to_ram(&mut self, elf_data: &[u8]) -> Result<(), Error> {
        let image = ElfFirmwareImage::try_from(elf_data)?;
        if image.rom_segments(self.chip).next().is_some() {
            return Err(Error::ElfNotRamLoadable);
        }

        let mut target = self.chip.ram_target(
            Some(image.entry()),
            self.chip
                .into_target()
                .max_ram_block_size(&mut self.connection)?,
        );
        target.begin(&mut self.connection).flashing()?;

        for segment in image.ram_segments(self.chip) {
            let progress_cb = if cfg!(feature = "cli") {
                use crate::cli::{build_progress_bar_callback, progress_bar};

                let progress = progress_bar(format!("segment 0x{:X}", segment.addr), None);
                let progress_cb = build_progress_bar_callback(progress);

                Some(progress_cb)
            } else {
                None
            };

            target
                .write_segment(&mut self.connection, segment.into(), progress_cb)
                .flashing()?;
        }

        target.finish(&mut self.connection, true).flashing()
    }

    /// Builds a flash image from an ELF file and writes it to flash.
    ///
    /// The optional bootloader, partition table, image format, flash mode and
    /// frequency are passed to the chip target's image builder together with
    /// the chip revision. `flash_size` defaults to the detected flash size.
    pub fn load_elf_to_flash_with_format(
        &mut self,
        elf_data: &[u8],
        bootloader: Option<Vec<u8>>,
        partition_table: Option<PartitionTable>,
        image_format: Option<ImageFormatKind>,
        flash_mode: Option<FlashMode>,
        flash_size: Option<FlashSize>,
        flash_freq: Option<FlashFrequency>,
    ) -> Result<(), Error> {
        let image = ElfFirmwareImage::try_from(elf_data)?;

        let mut target = self.chip.flash_target(self.spi_params, self.use_stub);
        target.begin(&mut self.connection).flashing()?;

        let image = self.chip.into_target().get_flash_image(
            &image,
            bootloader,
            partition_table,
            image_format,
            Some(
                self.chip
                    .into_target()
                    .chip_revision(&mut self.connection)?,
            ),
            flash_mode,
            flash_size.or(Some(self.flash_size)),
            flash_freq,
        )?;

        #[cfg(feature = "cli")]
        crate::cli::display_image_size(image.app_size(), image.part_size());

        for segment in image.flash_segments() {
            let progress_cb = if cfg!(feature = "cli") {
                use crate::cli::{build_progress_bar_callback, progress_bar};

                let progress = progress_bar(format!("segment 0x{:X}", segment.addr), None);
                let progress_cb = build_progress_bar_callback(progress);

                Some(progress_cb)
            } else {
                None
            };

            target
                .write_segment(&mut self.connection, segment, progress_cb)
                .flashing()?;
        }

        target.finish(&mut self.connection, true).flashing()?;

        Ok(())
    }

    /// Writes raw `data` to flash at `addr`, reporting progress via `progress_cb`.
    pub fn write_bin_to_flash(
        &mut self,
        addr: u32,
        data: &[u8],
        progress_cb: Option<Box<dyn Fn(usize, usize)>>,
    ) -> Result<(), Error> {
        let segment = RomSegment {
            addr,
            data: Cow::from(data),
        };

        let mut target = self.chip.flash_target(self.spi_params, self.use_stub);
        target.begin(&mut self.connection).flashing()?;
        // Same error context as every other segment write in this impl.
        target
            .write_segment(&mut self.connection, segment, progress_cb)
            .flashing()?;
        target.finish(&mut self.connection, true).flashing()?;

        Ok(())
    }

    /// Same as [`Flasher::load_elf_to_flash_with_format`] with the default image format.
    pub fn load_elf_to_flash(
        &mut self,
        elf_data: &[u8],
        bootloader: Option<Vec<u8>>,
        partition_table: Option<PartitionTable>,
        flash_mode: Option<FlashMode>,
        flash_size: Option<FlashSize>,
        flash_freq: Option<FlashFrequency>,
    ) -> Result<(), Error> {
        self.load_elf_to_flash_with_format(
            elf_data,
            bootloader,
            partition_table,
            None,
            flash_mode,
            flash_size,
            flash_freq,
        )
    }

    /// Asks the target to switch to `speed` baud, then switches the local port.
    ///
    /// With the stub, the current baud rate is sent as `prior_baud`; otherwise
    /// 0 is sent. After switching, waits 50 ms and flushes the connection.
    pub fn change_baud(&mut self, speed: u32) -> Result<(), Error> {
        debug!("Change baud to: {}", speed);

        let prior_baud = if self.use_stub {
            self.connection.get_baud()?
        } else {
            0
        };

        self.connection
            .with_timeout(CommandType::ChangeBaud.timeout(), |connection| {
                connection.command(Command::ChangeBaud {
                    new_baud: speed,
                    prior_baud,
                })
            })?;
        self.connection.set_baud(speed)?;
        sleep(Duration::from_millis(50));
        self.connection.flush()?;

        Ok(())
    }

    /// Returns the USB product ID of the connected port.
    pub fn get_usb_pid(&self) -> Result<u16, Error> {
        self.connection.get_usb_pid()
    }

    /// Erases `size` bytes of flash starting at `offset`.
    ///
    /// Uses the `EraseRegion` command's own timeout, then waits 50 ms and
    /// flushes the connection.
    pub fn erase_region(&mut self, offset: u32, size: u32) -> Result<(), Error> {
        debug!("Erasing region of 0x{:x}B at 0x{:08x}", size, offset);

        self.connection
            .with_timeout(CommandType::EraseRegion.timeout(), |connection| {
                connection.command(Command::EraseRegion { offset, size })
            })?;
        sleep(Duration::from_millis(50));
        self.connection.flush()?;

        Ok(())
    }

    /// Consumes the flasher and returns the underlying interface.
    pub fn into_interface(self) -> Interface {
        self.connection.into_interface()
    }
}
/// Computes the size (in bytes) to request when erasing `size` bytes at `offset`.
///
/// The request is rounded up to whole sectors. It is then adjusted by how many
/// sectors remain before the next block boundary ("head" sectors). If the head
/// makes up more than half of the region, half the sectors (rounded up) are
/// requested; otherwise the head sectors are subtracted.
/// NOTE(review): this mirrors esptool's `get_erase_size` workaround for a ROM
/// erase bug — confirm the target still needs it.
pub(crate) fn get_erase_size(offset: usize, size: usize) -> usize {
    let sectors = (size + FLASH_SECTOR_SIZE - 1) / FLASH_SECTOR_SIZE;
    let first_sector = offset / FLASH_SECTOR_SIZE;

    let until_block_end = FLASH_SECTORS_PER_BLOCK - (first_sector % FLASH_SECTORS_PER_BLOCK);
    let head = until_block_end.min(sectors);

    let erase_sectors = if 2 * head > sectors {
        (sectors + 1) / 2
    } else {
        sectors - head
    };

    erase_sectors * FLASH_SECTOR_SIZE
}
/// XORs every byte of `data` into `checksum` and returns the result.
///
/// Feed the returned value back in as `checksum` to continue over further data.
pub(crate) fn checksum(data: &[u8], checksum: u8) -> u8 {
    data.iter().fold(checksum, |acc, &byte| acc ^ byte)
}