use std::io::Write;
use flate2::{
Compression,
write::{ZlibDecoder, ZlibEncoder},
};
use log::debug;
use md5::{Digest, Md5};
use crate::{
Error,
flasher::{FLASH_SECTOR_SIZE, SpiAttachParams},
image_format::Segment,
target::{Chip, WDT_WKEY},
};
#[cfg(feature = "serialport")]
use crate::{
command::{Command, CommandType},
connection::Connection,
target::FlashTarget,
target::ProgressCallbacks,
};
/// Flash target that writes image segments to a chip's attached SPI flash
/// over a serial connection.
///
/// The flags captured at construction decide how the flashing operations talk
/// to the device: ROM loader vs. flasher stub commands, and optional MD5-based
/// skipping and verification of segments.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct Esp32Target {
    /// Chip being flashed. Used to pick write block sizes, whether the
    /// begin commands carry the encryption flag, the watchdog register
    /// addresses, and it is passed to `reset_after`.
    chip: Chip,
    /// Parameters sent with the SPI attach command in `begin`.
    spi_attach_params: SpiAttachParams,
    /// Whether the flasher stub is in use. Selects the stub command variants
    /// and block size, enables compressed (deflate) writes, and picks the
    /// matching flash-end command.
    use_stub: bool,
    /// Compare the flash MD5 against the segment's MD5 after writing it.
    verify: bool,
    /// Skip writing a segment whose flash MD5 already matches its data.
    skip: bool,
    /// Set once a flash-begin command has been sent; `finish` only sends the
    /// matching flash-end command when this is true.
    need_flash_end: bool,
}
impl Esp32Target {
pub fn new(
chip: Chip,
spi_attach_params: SpiAttachParams,
use_stub: bool,
verify: bool,
skip: bool,
) -> Self {
Esp32Target {
chip,
spi_attach_params,
use_stub,
verify,
skip,
need_flash_end: false,
}
}
}
#[cfg(feature = "serialport")]
impl FlashTarget for Esp32Target {
    /// Attaches the SPI flash and prepares the chip for flashing.
    ///
    /// Sends `SpiAttachStub` when the stub is in use and `SpiAttach` otherwise.
    /// When connected through USB-Serial-JTAG (and not in secure download
    /// mode), and the chip reports both watchdog registers, the watchdog timer
    /// is turned off. This unlocks its write protection with `WDT_WKEY`, clears
    /// its config register, then re-locks the protection.
    ///
    /// # Errors
    ///
    /// Propagates any error from the attach command or the register writes.
    fn begin(&mut self, connection: &mut Connection) -> Result<(), Error> {
        connection.with_timeout(CommandType::SpiAttach.timeout(), |connection| {
            let command = if self.use_stub {
                Command::SpiAttachStub {
                    spi_params: self.spi_attach_params,
                }
            } else {
                Command::SpiAttach {
                    spi_params: self.spi_attach_params,
                }
            };
            connection.command(command)
        })?;

        if connection.is_using_usb_serial_jtag()
            && !connection.secure_download_mode
            && let (Some(wdt_wprotect), Some(wdt_config0)) =
                (self.chip.wdt_wprotect(), self.chip.wdt_config0())
        {
            // Unlock write protection, switch the watchdog off, lock it again.
            for (address, value) in [
                (wdt_wprotect, WDT_WKEY),
                (wdt_config0, 0x0),
                (wdt_wprotect, 0x0),
            ] {
                connection.command(Command::WriteReg {
                    address,
                    value,
                    mask: None,
                })?;
            }
        }

        Ok(())
    }

    /// Writes one image segment to flash.
    ///
    /// With the stub, the data is zlib-compressed and sent with
    /// `FlashDeflBegin`/`FlashDeflData`. Without it, raw blocks are sent with
    /// `FlashBegin`/`FlashData`, each padded with `0xff` to the write block size.
    ///
    /// If `skip` is set and the flash region's MD5 already matches the segment,
    /// nothing is written and `progress.finish(true)` is reported. If `verify`
    /// is set, the flash MD5 is compared against the segment after writing.
    ///
    /// # Errors
    ///
    /// Returns `Error::VerifyFailed` when verification is enabled and the
    /// checksums differ. Compression and command errors are propagated.
    fn write_segment(
        &mut self,
        connection: &mut Connection,
        segment: Segment<'_>,
        progress: &mut dyn ProgressCallbacks,
    ) -> Result<(), Error> {
        let addr = segment.addr;
        let data_len = segment.data.len() as u32;
        let checksum_md5 = Md5::digest(&segment.data);

        // Compressed (deflate) writes are only used together with the stub.
        let use_compression = self.use_stub;
        let flash_write_size = if self.use_stub {
            self.chip.stub_flash_write_size()
        } else {
            self.chip.flash_write_size()
        };

        // Region to erase: the uncompressed length rounded up to whole sectors.
        let erase_count = segment.data.len().div_ceil(FLASH_SECTOR_SIZE);
        let erase_size = (erase_count * FLASH_SECTOR_SIZE) as u32;

        if self.skip {
            let flash_checksum_md5 = read_flash_md5(connection, addr, data_len)?;
            if checksum_md5[..] == flash_checksum_md5.to_be_bytes() {
                debug!("Segment at address '0x{addr:x}' has not changed, skipping write");
                progress.finish(true);
                return Ok(());
            }
        }

        // Borrow the segment directly when no compression is needed.
        let compressed;
        let data: &[u8] = if use_compression {
            let mut encoder = ZlibEncoder::new(Vec::new(), Compression::best());
            encoder.write_all(&segment.data)?;
            compressed = encoder.finish()?;
            &compressed
        } else {
            &segment.data
        };

        let block_count = data.len().div_ceil(flash_write_size) as u32;
        let chunks = data.chunks(flash_write_size);
        progress.init(addr, chunks.len());

        if use_compression {
            connection.with_timeout(
                CommandType::FlashDeflBegin.timeout_for_size(erase_size),
                |connection| {
                    connection.command(Command::FlashDeflBegin {
                        size: data_len,
                        blocks: block_count,
                        block_size: flash_write_size as u32,
                        offset: addr,
                        // Evaluates to false whenever the stub is used, which is
                        // always the case on this branch.
                        supports_encryption: self.chip != Chip::Esp32 && !self.use_stub,
                    })?;
                    Ok(())
                },
            )?;
        } else {
            connection.with_timeout(
                CommandType::FlashBegin.timeout_for_size(erase_size),
                |connection| {
                    connection.command(Command::FlashBegin {
                        size: erase_size,
                        blocks: block_count,
                        block_size: flash_write_size as u32,
                        offset: addr,
                        supports_encryption: self.chip != Chip::Esp32,
                    })?;
                    Ok(())
                },
            )?;
        }
        self.need_flash_end = true;

        // For compressed writes, each chunk is inflated locally. This tells us
        // how many bytes the device must write for it, and the per-chunk
        // timeout scales with that size. The buffer is cleared after every
        // chunk so the whole decompressed segment is never held in memory.
        let mut decoder = ZlibDecoder::new(Vec::new());

        for (i, block) in chunks.enumerate() {
            let sequence = i as u32;
            if use_compression {
                decoder.write_all(block)?;
                decoder.flush()?;
                let size = decoder.get_ref().len();
                decoder.get_mut().clear();

                connection.with_timeout(
                    CommandType::FlashDeflData.timeout_for_size(size as u32),
                    |connection| {
                        connection.command(Command::FlashDeflData {
                            sequence,
                            pad_to: 0,
                            pad_byte: 0xff,
                            data: block,
                        })?;
                        Ok(())
                    },
                )?;
            } else {
                connection.with_timeout(
                    CommandType::FlashData.timeout_for_size(block.len() as u32),
                    |connection| {
                        connection.command(Command::FlashData {
                            sequence,
                            pad_to: flash_write_size,
                            pad_byte: 0xff,
                            data: block,
                        })?;
                        Ok(())
                    },
                )?;
            }
            progress.update(i + 1);
        }

        if self.verify {
            progress.verifying();
            let flash_checksum_md5 = read_flash_md5(connection, addr, data_len)?;
            if checksum_md5[..] != flash_checksum_md5.to_be_bytes() {
                return Err(Error::VerifyFailed);
            }
            debug!("Segment at address '0x{addr:x}' verified successfully");
        }

        progress.finish(false);
        Ok(())
    }

    /// Ends the flashing session and optionally reboots the chip.
    ///
    /// If a flash-begin command was sent, the matching end command is sent:
    /// `FlashDeflEnd` with the stub, `FlashEnd` otherwise.
    ///
    /// Rebooting happens exactly once when `reboot` is requested. In secure
    /// download mode it is done by the end command. Otherwise the end command
    /// keeps the chip in the loader and `reset_after` performs the reset.
    ///
    /// # Errors
    ///
    /// Propagates end-command and reset errors. The exception is a ROM error
    /// from the end command in secure download mode, which is logged and
    /// tolerated.
    fn finish(&mut self, connection: &mut Connection, reboot: bool) -> Result<(), Error> {
        if self.need_flash_end {
            // Only let the end command reboot when `reset_after` below will
            // not run. Otherwise the chip would be reset twice.
            let flash_end_reboot = connection.secure_download_mode && reboot;
            let result = if self.use_stub {
                connection.with_timeout(CommandType::FlashDeflEnd.timeout(), |connection| {
                    connection.command(Command::FlashDeflEnd {
                        reboot: flash_end_reboot,
                    })
                })
            } else {
                connection.with_timeout(CommandType::FlashEnd.timeout(), |connection| {
                    connection.command(Command::FlashEnd {
                        reboot: flash_end_reboot,
                    })
                })
            };

            match result {
                Ok(_) => {}
                // Deliberately tolerated in secure download mode; log it
                // instead of discarding it silently.
                Err(Error::RomError(err)) if connection.secure_download_mode => {
                    debug!("Ignoring ROM error from flash end in secure download mode: {err}");
                }
                Err(e) => return Err(e),
            }

            // The session is closed; don't send another end command on a
            // repeated `finish`.
            self.need_flash_end = false;
        }

        if reboot && !connection.secure_download_mode {
            connection.reset_after(self.use_stub, self.chip)?;
        }

        Ok(())
    }
}

/// Asks the device for the MD5 digest of `size` bytes of flash at `offset`.
///
/// The timeout scales with `size`. The response is converted into a `u128`.
/// Callers compare its big-endian bytes with a locally computed digest.
#[cfg(feature = "serialport")]
fn read_flash_md5(connection: &mut Connection, offset: u32, size: u32) -> Result<u128, Error> {
    connection.with_timeout(CommandType::FlashMd5.timeout_for_size(size), |connection| {
        connection
            .command(Command::FlashMd5 { offset, size })?
            .try_into()
    })
}