use std::{
collections::HashMap,
fmt,
io::{BufWriter, Read, Write},
iter::zip,
thread::sleep,
time::Duration,
};
use log::{debug, info};
use regex::Regex;
use serde::{Deserialize, Serialize};
use serialport::{SerialPort, UsbPortInfo};
use slip_codec::SlipDecoder;
#[cfg(unix)]
use self::reset::UnixTightReset;
use self::{
encoder::SlipEncoder,
reset::{
ClassicReset,
ResetStrategy,
UsbJtagSerialReset,
construct_reset_strategy_sequence,
hard_reset,
reset_after_flash,
soft_reset,
},
};
use crate::{
command::{Command, CommandResponse, CommandResponseValue, CommandType, DEFAULT_MAX_LEN},
error::{ConnectionError, Error, ResultExt, RomError, RomErrorKind},
flasher::stubs::CHIP_DETECT_MAGIC_REG_ADDR,
target::Chip,
};
pub(crate) mod reset;
pub use reset::{ResetAfterOperation, ResetBeforeOperation};
/// Maximum number of reset-and-sync attempts made by `Connection::begin`.
/// `Connection::sync` also uses it as the number of follow-up SYNC replies it reads.
const MAX_CONNECT_ATTEMPTS: usize = 7;
/// Number of SYNC attempts made after each reset in `Connection::connect_attempt`.
const MAX_SYNC_ATTEMPTS: usize = 5;
/// USB product ID compared against the port's PID to detect a USB-Serial-JTAG connection.
const USB_SERIAL_JTAG_PID: u16 = 0x1001;
/// Platform-specific serial port type wrapped by [`Connection`] (Unix).
#[cfg(unix)]
pub type Port = serialport::TTYPort;
/// Platform-specific serial port type wrapped by [`Connection`] (Windows).
#[cfg(windows)]
pub type Port = serialport::COMPort;
/// Security-related state returned by the device's `GetSecurityInfo` command.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub struct SecurityInfo {
    /// Bitfield of security flags; individual bits are queried by name through
    /// `SecurityInfo::security_flag_status`.
    pub flags: u32,
    /// Raw `SPI_BOOT_CRYPT_CNT` value. Flash encryption is reported as enabled when an
    /// odd number of its bits are set.
    pub flash_crypt_cnt: u8,
    /// Seven key-purpose bytes as reported by the device
    /// (NOTE(review): presumably one per key block — confirm).
    pub key_purposes: [u8; 7],
    /// Chip ID; `None` when the response uses the 12-byte ESP32-S2 layout.
    pub chip_id: Option<u32>,
    /// ECO version word (displayed as "API Version"); `None` for the 12-byte
    /// ESP32-S2 layout.
    pub eco_version: Option<u32>,
}
impl SecurityInfo {
    /// Builds the table that maps each security flag name to its bit mask in `flags`.
    ///
    /// The names are listed in bit order, so the entry at index `i` has mask `1 << i`.
    fn security_flag_map() -> HashMap<&'static str, u32> {
        const NAMES_BY_BIT: [&str; 11] = [
            "SECURE_BOOT_EN",
            "SECURE_BOOT_AGGRESSIVE_REVOKE",
            "SECURE_DOWNLOAD_ENABLE",
            "SECURE_BOOT_KEY_REVOKE0",
            "SECURE_BOOT_KEY_REVOKE1",
            "SECURE_BOOT_KEY_REVOKE2",
            "SOFT_DIS_JTAG",
            "HARD_DIS_JTAG",
            "DIS_USB",
            "DIS_DOWNLOAD_DCACHE",
            "DIS_DOWNLOAD_ICACHE",
        ];

        NAMES_BY_BIT
            .into_iter()
            .enumerate()
            .map(|(bit, name)| (name, 1_u32 << bit))
            .collect()
    }

    /// Reports whether the named security flag is set in `flags`.
    ///
    /// Names that are not in the flag table are reported as not set.
    pub(crate) fn security_flag_status(&self, flag_name: &str) -> bool {
        Self::security_flag_map()
            .get(flag_name)
            .is_some_and(|&mask| (self.flags & mask) != 0)
    }
}
impl TryFrom<&[u8]> for SecurityInfo {
type Error = Error;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
let esp32s2 = bytes.len() == 12;
if bytes.len() < 12 {
return Err(Error::InvalidResponse(format!(
"expected response of at least 12 bytes, received {} bytes",
bytes.len()
)));
}
let flags = u32::from_le_bytes(bytes[0..4].try_into()?);
let flash_crypt_cnt = bytes[4];
let key_purposes: [u8; 7] = bytes[5..12].try_into()?;
let (chip_id, eco_version) = if esp32s2 {
(None, None) } else {
if bytes.len() < 20 {
return Err(Error::InvalidResponse(format!(
"expected response of at least 20 bytes, received {} bytes",
bytes.len()
)));
}
let chip_id = u32::from_le_bytes(bytes[12..16].try_into()?);
let eco_version = u32::from_le_bytes(bytes[16..20].try_into()?);
(Some(chip_id), Some(eco_version))
};
Ok(SecurityInfo {
flags,
flash_crypt_cnt,
key_purposes,
chip_id,
eco_version,
})
}
}
impl fmt::Display for SecurityInfo {
    /// Writes a multi-line, human-readable security report.
    ///
    /// Optional values (chip ID, ECO version) and feature-specific lines are only
    /// written when present or when their flag is set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let enabled = |name: &str| self.security_flag_status(name);
        let purposes = self
            .key_purposes
            .map(|purpose| purpose.to_string())
            .join(", ");

        writeln!(f, "\nSecurity Information:")?;
        writeln!(f, "=====================")?;
        writeln!(f, "Flags: {:#010x} ({:b})", self.flags, self.flags)?;
        writeln!(f, "Key Purposes: [{purposes}]")?;
        if let Some(chip_id) = self.chip_id {
            writeln!(f, "Chip ID: {chip_id}")?;
        }
        if let Some(api_version) = self.eco_version {
            writeln!(f, "API Version: {api_version}")?;
        }

        if enabled("SECURE_BOOT_EN") {
            writeln!(f, "Secure Boot: Enabled")?;
            if enabled("SECURE_BOOT_AGGRESSIVE_REVOKE") {
                writeln!(f, "Secure Boot Aggressive key revocation: Enabled")?;
            }

            let mut revoked_keys = Vec::new();
            let revoke_flags = [
                "SECURE_BOOT_KEY_REVOKE0",
                "SECURE_BOOT_KEY_REVOKE1",
                "SECURE_BOOT_KEY_REVOKE2",
            ];
            for (i, key) in revoke_flags.into_iter().enumerate() {
                if enabled(key) {
                    revoked_keys.push(format!("Secure Boot Key{i} is Revoked"));
                }
            }
            if !revoked_keys.is_empty() {
                writeln!(
                    f,
                    "Secure Boot Key Revocation Status:\n {}",
                    revoked_keys.join("\n ")
                )?;
            }
        } else {
            writeln!(f, "Secure Boot: Disabled")?;
        }

        // Flash encryption is on when the counter has an odd number of bits set.
        let flash_encryption = if self.flash_crypt_cnt.count_ones() % 2 == 1 {
            "Flash Encryption: Enabled"
        } else {
            "Flash Encryption: Disabled"
        };
        writeln!(f, "{flash_encryption}")?;

        let crypt_cnt_str = "SPI Boot Crypt Count (SPI_BOOT_CRYPT_CNT)";
        writeln!(f, "{crypt_cnt_str}: 0x{:x}", self.flash_crypt_cnt)?;

        let cache_lines = [
            ("DIS_DOWNLOAD_DCACHE", "Dcache in UART download mode: Disabled"),
            ("DIS_DOWNLOAD_ICACHE", "Icache in UART download mode: Disabled"),
        ];
        for (flag, line) in cache_lines {
            if enabled(flag) {
                writeln!(f, "{line}")?;
            }
        }

        // A hard JTAG disable takes precedence over the software-only one.
        if enabled("HARD_DIS_JTAG") {
            writeln!(f, "JTAG: Permanently Disabled")?;
        } else if enabled("SOFT_DIS_JTAG") {
            writeln!(f, "JTAG: Software Access Disabled")?;
        }

        if enabled("DIS_USB") {
            writeln!(f, "USB Access: Disabled")?;
        }

        Ok(())
    }
}
/// A serial link to a chip's bootloader, used to send SLIP-framed commands and
/// read their responses.
#[derive(Debug)]
pub struct Connection {
    /// Underlying serial port.
    serial: Port,
    /// USB descriptor of the port. Its PID selects reset strategies and identifies
    /// USB-Serial-JTAG connections.
    port_info: UsbPortInfo,
    /// SLIP decoder used to frame incoming responses.
    decoder: SlipDecoder,
    /// What to do with the chip once an operation finishes (see `reset_after`).
    after_operation: ResetAfterOperation,
    /// How to reset/sync the chip before connecting (see `begin`).
    before_operation: ResetBeforeOperation,
    /// Whether the chip is in secure download mode. Starts out `false` in `new`;
    /// NOTE(review): presumably updated elsewhere in the crate — confirm.
    pub(crate) secure_download_mode: bool,
    /// Baud rate given at construction or most recently applied via `set_baud`.
    pub(crate) baud: u32,
}
impl Connection {
pub fn new(
serial: Port,
port_info: UsbPortInfo,
after_operation: ResetAfterOperation,
before_operation: ResetBeforeOperation,
baud: u32,
) -> Self {
Connection {
serial,
port_info,
decoder: SlipDecoder::new(),
after_operation,
before_operation,
secure_download_mode: false,
baud,
}
}
pub fn begin(&mut self) -> Result<(), Error> {
let port_name = self.serial.name().unwrap_or_default();
let reset_sequence = construct_reset_strategy_sequence(
&port_name,
self.port_info.pid,
self.before_operation,
);
for (_, reset_strategy) in zip(0..MAX_CONNECT_ATTEMPTS, reset_sequence.iter().cycle()) {
match self.connect_attempt(reset_strategy.as_ref()) {
Ok(_) => {
return Ok(());
}
Err(e) => {
debug!("Failed to reset, error {e:#?}, retrying");
}
}
}
Err(Error::Connection(Box::new(
ConnectionError::ConnectionFailed,
)))
}
fn connect_attempt(&mut self, reset_strategy: &dyn ResetStrategy) -> Result<(), Error> {
if self.before_operation == ResetBeforeOperation::NoResetNoSync {
return Ok(());
}
let mut download_mode: bool = false;
let mut boot_mode = String::new();
let mut boot_log_detected = false;
let mut buff: Vec<u8>;
if self.before_operation != ResetBeforeOperation::NoReset {
reset_strategy.reset(&mut self.serial)?;
let available_bytes = self.serial.bytes_to_read()?;
buff = vec![0; available_bytes as usize];
let read_bytes = if available_bytes > 0 {
let read_bytes = self.serial.read(&mut buff)? as u32;
if read_bytes != available_bytes {
return Err(Error::Connection(Box::new(ConnectionError::ReadMismatch(
available_bytes,
read_bytes,
))));
}
read_bytes
} else {
0
};
let read_slice = String::from_utf8_lossy(&buff[..read_bytes as usize]).into_owned();
let pattern =
Regex::new(r"boot:(0x[0-9a-fA-F]+)([\s\S]*waiting for download)?").unwrap();
if let Some(data) = pattern.captures(&read_slice) {
boot_log_detected = true;
boot_mode = data
.get(1)
.map(|m| m.as_str())
.unwrap_or_default()
.to_string();
download_mode = data.get(2).is_some();
debug!("Boot Mode: {boot_mode}");
debug!("Download Mode: {download_mode}");
};
}
for _ in 0..MAX_SYNC_ATTEMPTS {
self.flush()?;
if self.sync().is_ok() {
return Ok(());
}
}
if boot_log_detected {
if download_mode {
return Err(Error::Connection(Box::new(ConnectionError::NoSyncReply)));
} else {
return Err(Error::Connection(Box::new(ConnectionError::WrongBootMode(
boot_mode.to_string(),
))));
}
}
Err(Error::Connection(Box::new(
ConnectionError::ConnectionFailed,
)))
}
pub(crate) fn sync(&mut self) -> Result<(), Error> {
self.with_timeout(CommandType::Sync.timeout(), |connection| {
connection.command(Command::Sync)?;
connection.flush()?;
sleep(Duration::from_millis(10));
for _ in 0..MAX_CONNECT_ATTEMPTS {
match connection.read_response_for_command(CommandType::Sync)? {
Some(response) if response.return_op == CommandType::Sync as u8 => {
if response.status == 1 {
connection.flush().ok();
return Err(Error::RomError(Box::new(RomError::new(
CommandType::Sync,
RomErrorKind::from(response.error),
))));
}
}
_ => {
return Err(Error::RomError(Box::new(RomError::new(
CommandType::Sync,
RomErrorKind::InvalidMessage,
))));
}
}
}
Ok(())
})?;
Ok(())
}
pub fn reset(&mut self) -> Result<(), Error> {
reset_after_flash(&mut self.serial, self.port_info.pid)?;
Ok(())
}
pub fn reset_after(&mut self, is_stub: bool, chip: Chip) -> Result<(), Error> {
let pid = self.usb_pid();
match self.after_operation {
ResetAfterOperation::HardReset => hard_reset(&mut self.serial, pid),
ResetAfterOperation::NoReset => {
info!("Staying in bootloader");
soft_reset(self, true, is_stub)?;
Ok(())
}
ResetAfterOperation::NoResetNoStub => {
info!("Staying in flasher stub");
Ok(())
}
ResetAfterOperation::WatchdogReset => {
info!("Resetting device with watchdog");
match chip {
Chip::Esp32c3 => {
if self.is_using_usb_serial_jtag() {
chip.rtc_wdt_reset(self)?;
}
}
Chip::Esp32p4 => {
if chip.is_using_usb_otg(self)? {
chip.rtc_wdt_reset(self)?;
}
}
Chip::Esp32s2 => {
if chip.is_using_usb_otg(self)? {
if chip.can_rtc_wdt_reset(self)? {
chip.rtc_wdt_reset(self)?;
}
}
}
Chip::Esp32s3 => {
if self.is_using_usb_serial_jtag() || chip.is_using_usb_otg(self)? {
if chip.can_rtc_wdt_reset(self)? {
chip.rtc_wdt_reset(self)?;
}
}
}
_ => {
return Err(Error::UnsupportedFeature {
chip,
feature: "watchdog reset".into(),
});
}
}
Ok(())
}
}
}
pub fn reset_to_flash(&mut self, extra_delay: bool) -> Result<(), Error> {
if self.is_using_usb_serial_jtag() {
UsbJtagSerialReset.reset(&mut self.serial)
} else {
#[cfg(unix)]
if UnixTightReset::new(extra_delay)
.reset(&mut self.serial)
.is_ok()
{
return Ok(());
}
ClassicReset::new(extra_delay).reset(&mut self.serial)
}
}
pub fn set_timeout(&mut self, timeout: Duration) -> Result<(), Error> {
self.serial.set_timeout(timeout)?;
Ok(())
}
pub fn set_baud(&mut self, baud: u32) -> Result<(), Error> {
self.serial.set_baud_rate(baud)?;
self.baud = baud;
Ok(())
}
pub fn baud(&self) -> Result<u32, Error> {
Ok(self.serial.baud_rate()?)
}
pub fn with_timeout<T, F>(&mut self, timeout: Duration, mut f: F) -> Result<T, Error>
where
F: FnMut(&mut Connection) -> Result<T, Error>,
{
let old_timeout = {
let mut binding = Box::new(&mut self.serial);
let serial = binding.as_mut();
let old_timeout = serial.timeout();
serial.set_timeout(timeout)?;
old_timeout
};
let result = f(self);
self.serial.set_timeout(old_timeout)?;
result
}
pub fn read_flash_response(&mut self) -> Result<Option<CommandResponse>, Error> {
let mut response = Vec::new();
self.decoder.decode(&mut self.serial, &mut response)?;
if response.is_empty() {
return Ok(None);
}
let value = CommandResponseValue::Vector(response.clone());
let header = CommandResponse {
resp: 1_u8,
return_op: CommandType::ReadFlash as u8,
return_length: response.len() as u16,
value,
error: 0_u8,
status: 0_u8,
};
Ok(Some(header))
}
pub fn read_response_for_command(
&mut self,
ty: CommandType,
) -> Result<Option<CommandResponse>, Error> {
self.read_response_bounded(ty.max_response_len())
.for_command(ty)
}
#[deprecated = "May halt on unexpected input from the port --please use `read_response_for_command` instead. Deprecated in https://github.com/esp-rs/espflash/pull/1007"]
pub fn read_response(&mut self) -> Result<Option<CommandResponse>, Error> {
self.read_response_bounded(DEFAULT_MAX_LEN)
}
fn read_response_bounded(&mut self, max_len: u64) -> Result<Option<CommandResponse>, Error> {
match self.read_bounded(10, max_len)? {
None => Ok(None),
Some(response) => {
let status_len = if response.len() == 10 || response.len() == 26 {
2
} else {
4
};
let value = match response.len() {
10 | 12 => CommandResponseValue::ValueU32(u32::from_le_bytes(
response[4..][..4].try_into()?,
)),
44 => CommandResponseValue::ValueU128(u128::from_str_radix(
std::str::from_utf8(&response[8..][..32])?,
16,
)?),
26 => CommandResponseValue::ValueU128(u128::from_be_bytes(
response[8..][..16].try_into()?,
)),
_ => CommandResponseValue::Vector(response.clone()),
};
let header = CommandResponse {
resp: response[0],
return_op: response[1],
return_length: u16::from_le_bytes(response[2..][..2].try_into()?),
value,
error: response[response.len() - status_len + 1],
status: response[response.len() - status_len],
};
Ok(Some(header))
}
}
}
pub fn write_raw(&mut self, data: u32) -> Result<(), Error> {
let mut binding = Box::new(&mut self.serial);
let serial = binding.as_mut();
serial.clear(serialport::ClearBuffer::Input)?;
let mut writer = BufWriter::new(serial);
let mut encoder = SlipEncoder::new(&mut writer)?;
encoder.write_all(&data.to_le_bytes())?;
encoder.finish()?;
writer.flush()?;
Ok(())
}
pub fn write_command(&mut self, command: Command<'_>) -> Result<(), Error> {
debug!("Writing command: {command:02x?}");
let mut binding = Box::new(&mut self.serial);
let serial = binding.as_mut();
serial.clear(serialport::ClearBuffer::Input)?;
let mut writer = BufWriter::new(serial);
let mut encoder = SlipEncoder::new(&mut writer)?;
command.write(&mut encoder)?;
encoder.finish()?;
writer.flush()?;
Ok(())
}
pub fn command(&mut self, command: Command<'_>) -> Result<CommandResponseValue, Error> {
let ty = command.command_type();
self.write_command(command).for_command(ty)?;
for _ in 0..100 {
match self.read_response_for_command(ty)? {
Some(response) if response.return_op == ty as u8 => {
return if response.status != 0 {
let _error = self.flush();
Err(Error::RomError(Box::new(RomError::new(
command.command_type(),
RomErrorKind::from(response.error),
))))
} else {
let modified_value = match response.value {
CommandResponseValue::Vector(mut vec) if vec.len() >= 8 => {
vec = vec[8..][..response.return_length as usize].to_vec();
CommandResponseValue::Vector(vec)
}
_ => response.value, };
Ok(modified_value)
};
}
_ => continue,
}
}
Err(Error::Connection(Box::new(
ConnectionError::ConnectionFailed,
)))
}
pub fn read_reg(&mut self, addr: u32) -> Result<u32, Error> {
let resp = self.with_timeout(CommandType::ReadReg.timeout(), |connection| {
connection.command(Command::ReadReg { address: addr })
})?;
resp.try_into()
}
pub fn write_reg(&mut self, addr: u32, value: u32, mask: Option<u32>) -> Result<(), Error> {
self.with_timeout(CommandType::WriteReg.timeout(), |connection| {
connection.command(Command::WriteReg {
address: addr,
value,
mask,
})
})?;
Ok(())
}
pub(crate) fn update_reg(&mut self, addr: u32, mask: u32, new_value: u32) -> Result<(), Error> {
let masked_new_value = new_value.checked_shl(mask.trailing_zeros()).unwrap_or(0) & mask;
let masked_old_value = self.read_reg(addr)? & !mask;
self.write_reg(addr, masked_old_value | masked_new_value, None)
}
pub(crate) fn read(&mut self, len: usize) -> Result<Option<Vec<u8>>, Error> {
self.read_bounded(len, u64::MAX)
}
pub(crate) fn read_bounded(
&mut self,
len: usize,
max_len: u64,
) -> Result<Option<Vec<u8>>, Error> {
let mut tmp = Vec::with_capacity(1024);
let mut serial = (&mut self.serial).take(max_len);
loop {
self.decoder.decode(&mut serial, &mut tmp)?;
if tmp.len() >= len {
return Ok(Some(tmp));
}
}
}
pub fn flush(&mut self) -> Result<(), Error> {
self.serial.flush()?;
Ok(())
}
pub fn into_serial(self) -> Port {
self.serial
}
pub fn usb_pid(&self) -> u16 {
self.port_info.pid
}
pub(crate) fn is_using_usb_serial_jtag(&self) -> bool {
self.port_info.pid == USB_SERIAL_JTAG_PID
}
pub fn after_operation(&self) -> ResetAfterOperation {
self.after_operation
}
pub fn before_operation(&self) -> ResetBeforeOperation {
self.before_operation
}
#[cfg(feature = "serialport")]
pub fn security_info(&mut self, use_stub: bool) -> Result<SecurityInfo, crate::error::Error> {
self.with_timeout(CommandType::GetSecurityInfo.timeout(), |connection| {
let response = connection.command(Command::GetSecurityInfo)?;
if let crate::command::CommandResponseValue::Vector(data) = response {
let end = if use_stub { data.len() } else { data.len() - 4 };
SecurityInfo::try_from(&data[..end])
} else {
Err(Error::InvalidResponse(
"response was not a vector of bytes".into(),
))
}
})
}
#[cfg(feature = "serialport")]
pub fn detect_chip(
&mut self,
use_stub: bool,
) -> Result<crate::target::Chip, crate::error::Error> {
match self.security_info(use_stub) {
Ok(info) if info.chip_id.is_some() => {
let chip_id = info.chip_id.unwrap() as u16;
let chip = Chip::try_from(chip_id)?;
Ok(chip)
}
_ => {
let magic = if use_stub {
self.with_timeout(CommandType::ReadReg.timeout(), |connection| {
connection.command(Command::ReadReg {
address: CHIP_DETECT_MAGIC_REG_ADDR,
})
})?
.try_into()?
} else {
self.read_reg(CHIP_DETECT_MAGIC_REG_ADDR)?
};
debug!("Read chip magic value: 0x{magic:08x}");
Chip::from_magic(magic)
}
}
}
}
impl From<Connection> for Port {
fn from(conn: Connection) -> Self {
conn.into_serial()
}
}
mod encoder {
    //! Streaming SLIP frame encoder used to send commands to the device.

    use std::io::Write;

    /// Frame delimiter.
    const END: u8 = 0xC0;
    /// Escape introducer.
    const ESC: u8 = 0xDB;
    /// Second byte of the escape sequence that replaces a literal `END`.
    const ESC_END: u8 = 0xDC;
    /// Second byte of the escape sequence that replaces a literal `ESC`.
    const ESC_ESC: u8 = 0xDD;

    /// Writes one SLIP frame into `writer`.
    ///
    /// [`SlipEncoder::new`] emits the opening delimiter. Payload bytes passed to
    /// [`Write::write`] are escaped, and [`SlipEncoder::finish`] emits the closing
    /// delimiter.
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct SlipEncoder<'a, W: Write> {
        writer: &'a mut W,
        /// Total bytes emitted to `writer`, including delimiters and escape bytes.
        len: usize,
    }

    impl<'a, W: Write> SlipEncoder<'a, W> {
        /// Starts a frame by writing the opening delimiter.
        pub fn new(writer: &'a mut W) -> std::io::Result<Self> {
            // `write_all`, not `write`: a short write would silently drop the delimiter.
            writer.write_all(&[END])?;
            Ok(Self { writer, len: 1 })
        }

        /// Ends the frame and returns the total number of bytes emitted.
        pub fn finish(mut self) -> std::io::Result<usize> {
            self.writer.write_all(&[END])?;
            self.len += 1;
            Ok(self.len)
        }
    }

    impl<W: Write> Write for SlipEncoder<'_, W> {
        /// Escapes and forwards all of `buf`, always reporting `buf.len()` on success.
        ///
        /// Runs of bytes that need no escaping are forwarded in a single call.
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let mut start = 0;

            for (index, &byte) in buf.iter().enumerate() {
                let escaped: &[u8] = match byte {
                    END => &[ESC, ESC_END],
                    ESC => &[ESC, ESC_ESC],
                    _ => continue,
                };
                self.writer.write_all(&buf[start..index])?;
                self.writer.write_all(escaped)?;
                self.len += (index - start) + escaped.len();
                start = index + 1;
            }

            self.writer.write_all(&buf[start..])?;
            self.len += buf.len() - start;

            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.writer.flush()
        }
    }
}