use std::{
io::{BufWriter, Read, Write},
iter::zip,
thread::sleep,
time::Duration,
};
use log::{debug, info};
use regex::Regex;
use serialport::{SerialPort, UsbPortInfo};
use slip_codec::SlipDecoder;
#[cfg(unix)]
use self::reset::UnixTightReset;
use self::{
encoder::SlipEncoder,
reset::{
construct_reset_strategy_sequence, hard_reset, reset_after_flash, ClassicReset,
ResetAfterOperation, ResetBeforeOperation, ResetStrategy, UsbJtagSerialReset,
},
};
use crate::{
command::{Command, CommandType},
connection::reset::soft_reset,
error::{ConnectionError, Error, ResultExt, RomError, RomErrorKind},
};
pub mod reset;
/// Number of connection attempts `Connection::begin` makes (cycling through the
/// reset strategies) before giving up.
///
/// `Connection::sync` also uses it as the number of additional sync replies it reads.
const MAX_CONNECT_ATTEMPTS: usize = 7;
/// Number of sync commands sent during a single connection attempt.
const MAX_SYNC_ATTEMPTS: usize = 5;
/// USB product ID compared against the port's PID in `Connection::reset_to_flash`
/// to choose the USB-Serial-JTAG reset strategy.
pub(crate) const USB_SERIAL_JTAG_PID: u16 = 0x1001;
/// Serial port type used by [`Connection`] on Unix platforms.
#[cfg(unix)]
pub type Port = serialport::TTYPort;
/// Serial port type used by [`Connection`] on Windows.
#[cfg(windows)]
pub type Port = serialport::COMPort;
/// Payload carried by a command response.
///
/// Which variant is produced depends on the layout of the response frame.
#[derive(Clone, Debug)]
pub enum CommandResponseValue {
    /// A 32-bit value.
    ValueU32(u32),
    /// A 128-bit value.
    ValueU128(u128),
    /// The raw response frame, used when no fixed layout applies.
    Vector(Vec<u8>),
}
impl TryInto<u32> for CommandResponseValue {
type Error = crate::error::Error;
fn try_into(self) -> Result<u32, Self::Error> {
match self {
CommandResponseValue::ValueU32(value) => Ok(value),
CommandResponseValue::ValueU128(_) => Err(crate::error::Error::InternalError),
CommandResponseValue::Vector(_) => Err(crate::error::Error::InternalError),
}
}
}
impl TryInto<u128> for CommandResponseValue {
type Error = crate::error::Error;
fn try_into(self) -> Result<u128, Self::Error> {
match self {
CommandResponseValue::ValueU32(_) => Err(crate::error::Error::InternalError),
CommandResponseValue::ValueU128(value) => Ok(value),
CommandResponseValue::Vector(_) => Err(crate::error::Error::InternalError),
}
}
}
impl TryInto<Vec<u8>> for CommandResponseValue {
type Error = crate::error::Error;
fn try_into(self) -> Result<Vec<u8>, Self::Error> {
match self {
CommandResponseValue::ValueU32(_) => Err(crate::error::Error::InternalError),
CommandResponseValue::ValueU128(_) => Err(crate::error::Error::InternalError),
CommandResponseValue::Vector(value) => Ok(value),
}
}
}
/// A parsed response frame, as produced by `Connection::read_response`.
#[derive(Clone, Debug)]
pub struct CommandResponse {
    /// First byte of the frame.
    pub resp: u8,
    /// Second byte of the frame: the opcode of the command being answered.
    pub return_op: u8,
    /// Little-endian `u16` taken from bytes 2..4 of the frame.
    pub return_length: u16,
    /// Decoded payload; its variant depends on the frame length.
    pub value: CommandResponseValue,
    /// First of the trailing status bytes of the frame.
    pub error: u8,
    /// Trailing status byte that immediately follows `error`.
    pub status: u8,
}
/// Serial connection to a target chip, exchanging SLIP-framed commands and responses.
pub struct Connection {
/// Underlying serial port.
serial: Port,
/// USB metadata of the port; its `pid` is used to pick reset strategies.
port_info: UsbPortInfo,
/// SLIP decoder used to read response frames from `serial`.
decoder: SlipDecoder,
/// Action taken on the chip after an operation (see `reset_after`).
after_operation: ResetAfterOperation,
/// Reset/sync behaviour applied before connecting (see `connect_attempt`).
before_operation: ResetBeforeOperation,
}
impl Connection {
pub fn new(
serial: Port,
port_info: UsbPortInfo,
after_operation: ResetAfterOperation,
before_operation: ResetBeforeOperation,
) -> Self {
Connection {
serial,
port_info,
decoder: SlipDecoder::new(),
after_operation,
before_operation,
}
}
pub fn begin(&mut self) -> Result<(), Error> {
let port_name = self.serial.name().unwrap_or_default();
let reset_sequence = construct_reset_strategy_sequence(
&port_name,
self.port_info.pid,
self.before_operation,
);
for (_, reset_strategy) in zip(0..MAX_CONNECT_ATTEMPTS, reset_sequence.iter().cycle()) {
match self.connect_attempt(reset_strategy) {
Ok(_) => {
return Ok(());
}
Err(e) => {
debug!("Failed to reset, error {:#?}, retrying", e);
}
}
}
Err(Error::Connection(ConnectionError::ConnectionFailed))
}
#[allow(clippy::borrowed_box)]
fn connect_attempt(&mut self, reset_strategy: &Box<dyn ResetStrategy>) -> Result<(), Error> {
if self.before_operation == ResetBeforeOperation::NoResetNoSync {
return Ok(());
}
let mut download_mode: bool = false;
let mut boot_mode = String::new();
let mut boot_log_detected = false;
let mut buff: Vec<u8>;
if self.before_operation != ResetBeforeOperation::NoReset {
reset_strategy.reset(&mut self.serial)?;
let available_bytes = self.serial.bytes_to_read()?;
buff = vec![0; available_bytes as usize];
let read_bytes = self.serial.read(&mut buff)? as u32;
if read_bytes != available_bytes {
return Err(Error::Connection(ConnectionError::ReadMissmatch(
available_bytes,
read_bytes,
)));
}
let read_slice = String::from_utf8_lossy(&buff[..read_bytes as usize]).into_owned();
let pattern = Regex::new(r"boot:(0x[0-9a-fA-F]+)(.*waiting for download)?").unwrap();
if let Some(data) = pattern.captures(&read_slice) {
boot_log_detected = true;
boot_mode = data
.get(1)
.map(|m| m.as_str())
.unwrap_or_default()
.to_string();
download_mode = data.get(2).is_some();
debug!("Boot Mode: {}", boot_mode);
debug!("Download Mode: {}", download_mode);
};
}
for _ in 0..MAX_SYNC_ATTEMPTS {
self.flush()?;
if self.sync().is_ok() {
return Ok(());
}
}
if boot_log_detected {
if download_mode {
return Err(Error::Connection(ConnectionError::NoSyncReply));
} else {
return Err(Error::Connection(ConnectionError::WrongBootMode(
boot_mode.to_string(),
)));
}
}
Err(Error::Connection(ConnectionError::ConnectionFailed))
}
pub(crate) fn sync(&mut self) -> Result<(), Error> {
self.with_timeout(CommandType::Sync.timeout(), |connection| {
connection.command(Command::Sync)?;
connection.flush()?;
sleep(Duration::from_millis(10));
for _ in 0..MAX_CONNECT_ATTEMPTS {
match connection.read_response()? {
Some(response) if response.return_op == CommandType::Sync as u8 => {
if response.status == 1 {
connection.flush().ok();
return Err(Error::RomError(RomError::new(
CommandType::Sync,
RomErrorKind::from(response.error),
)));
}
}
_ => {
return Err(Error::RomError(RomError::new(
CommandType::Sync,
RomErrorKind::InvalidMessage,
)))
}
}
}
Ok(())
})?;
Ok(())
}
pub fn reset(&mut self) -> Result<(), Error> {
reset_after_flash(&mut self.serial, self.port_info.pid)?;
Ok(())
}
pub fn reset_after(&mut self, is_stub: bool) -> Result<(), Error> {
let pid = self.get_usb_pid()?;
match self.after_operation {
ResetAfterOperation::HardReset => hard_reset(&mut self.serial, pid),
ResetAfterOperation::NoReset => {
info!("Staying in bootloader");
soft_reset(self, true, is_stub)?;
Ok(())
}
ResetAfterOperation::NoResetNoStub => {
info!("Staying in flasher stub");
Ok(())
}
}
}
pub fn reset_to_flash(&mut self, extra_delay: bool) -> Result<(), Error> {
if self.port_info.pid == USB_SERIAL_JTAG_PID {
UsbJtagSerialReset.reset(&mut self.serial)
} else {
#[cfg(unix)]
if UnixTightReset::new(extra_delay)
.reset(&mut self.serial)
.is_ok()
{
return Ok(());
}
ClassicReset::new(extra_delay).reset(&mut self.serial)
}
}
pub fn set_timeout(&mut self, timeout: Duration) -> Result<(), Error> {
self.serial.set_timeout(timeout)?;
Ok(())
}
pub fn set_baud(&mut self, speed: u32) -> Result<(), Error> {
self.serial.set_baud_rate(speed)?;
Ok(())
}
pub fn get_baud(&self) -> Result<u32, Error> {
Ok(self.serial.baud_rate()?)
}
pub fn with_timeout<T, F>(&mut self, timeout: Duration, mut f: F) -> Result<T, Error>
where
F: FnMut(&mut Connection) -> Result<T, Error>,
{
let old_timeout = {
let mut binding = Box::new(&mut self.serial);
let serial = binding.as_mut();
let old_timeout = serial.timeout();
serial.set_timeout(timeout)?;
old_timeout
};
let result = f(self);
self.serial.set_timeout(old_timeout)?;
result
}
pub fn read_response(&mut self) -> Result<Option<CommandResponse>, Error> {
match self.read(10)? {
None => Ok(None),
Some(response) => {
let status_len = if response.len() == 10 || response.len() == 26 {
2
} else {
4
};
let value = match response.len() {
10 | 12 => CommandResponseValue::ValueU32(u32::from_le_bytes(
response[4..][..4].try_into().unwrap(),
)),
44 => {
CommandResponseValue::ValueU128(
u128::from_str_radix(
std::str::from_utf8(&response[8..][..32]).unwrap(),
16,
)
.unwrap(),
)
}
26 => {
CommandResponseValue::ValueU128(u128::from_be_bytes(
response[8..][..16].try_into().unwrap(),
))
}
_ => CommandResponseValue::Vector(response.clone()),
};
let header = CommandResponse {
resp: response[0],
return_op: response[1],
return_length: u16::from_le_bytes(response[2..][..2].try_into().unwrap()),
value,
error: response[response.len() - status_len],
status: response[response.len() - status_len + 1],
};
Ok(Some(header))
}
}
}
pub fn write_raw(&mut self, data: u32) -> Result<(), Error> {
let mut binding = Box::new(&mut self.serial);
let serial = binding.as_mut();
serial.clear(serialport::ClearBuffer::Input)?;
let mut writer = BufWriter::new(serial);
let mut encoder = SlipEncoder::new(&mut writer)?;
encoder.write_all(&data.to_le_bytes())?;
encoder.finish()?;
writer.flush()?;
Ok(())
}
pub fn write_command(&mut self, command: Command) -> Result<(), Error> {
debug!("Writing command: {:?}", command);
let mut binding = Box::new(&mut self.serial);
let serial = binding.as_mut();
serial.clear(serialport::ClearBuffer::Input)?;
let mut writer = BufWriter::new(serial);
let mut encoder = SlipEncoder::new(&mut writer)?;
command.write(&mut encoder)?;
encoder.finish()?;
writer.flush()?;
Ok(())
}
pub fn command(&mut self, command: Command) -> Result<CommandResponseValue, Error> {
let ty = command.command_type();
self.write_command(command).for_command(ty)?;
for _ in 0..100 {
match self.read_response().for_command(ty)? {
Some(response) if response.return_op == ty as u8 => {
return if response.error != 0 {
let _error = self.flush();
Err(Error::RomError(RomError::new(
command.command_type(),
RomErrorKind::from(response.error),
)))
} else {
Ok(response.value)
}
}
_ => {
continue;
}
}
}
Err(Error::Connection(ConnectionError::ConnectionFailed))
}
pub fn read_reg(&mut self, reg: u32) -> Result<u32, Error> {
self.with_timeout(CommandType::ReadReg.timeout(), |connection| {
connection.command(Command::ReadReg { address: reg })
})
.map(|v| v.try_into().unwrap())
}
pub fn write_reg(&mut self, addr: u32, value: u32, mask: Option<u32>) -> Result<(), Error> {
self.with_timeout(CommandType::WriteReg.timeout(), |connection| {
connection.command(Command::WriteReg {
address: addr,
value,
mask,
})
})?;
Ok(())
}
pub(crate) fn read(&mut self, len: usize) -> Result<Option<Vec<u8>>, Error> {
let mut tmp = Vec::with_capacity(1024);
loop {
self.decoder.decode(&mut self.serial, &mut tmp)?;
if tmp.len() >= len {
return Ok(Some(tmp));
}
}
}
pub fn flush(&mut self) -> Result<(), Error> {
self.serial.flush()?;
Ok(())
}
pub fn into_serial(self) -> Port {
self.serial
}
pub fn get_usb_pid(&self) -> Result<u16, Error> {
Ok(self.port_info.pid)
}
}
mod encoder {
    //! Streaming SLIP frame encoder used for outgoing data.

    use std::io::Write;

    /// Frame delimiter.
    const END: u8 = 0xC0;
    /// Escape introducer.
    const ESC: u8 = 0xDB;
    /// Second byte of the escape sequence for a literal `END` (`ESC ESC_END`).
    const ESC_END: u8 = 0xDC;
    /// Second byte of the escape sequence for a literal `ESC` (`ESC ESC_ESC`).
    const ESC_ESC: u8 = 0xDD;

    /// Writes SLIP-encoded data into `writer`.
    ///
    /// The opening `END` is written by [`SlipEncoder::new`] and the closing one
    /// by [`SlipEncoder::finish`]. Bytes passed through `Write` are escaped in
    /// between.
    pub struct SlipEncoder<'a, W: Write> {
        writer: &'a mut W,
        /// Total number of bytes emitted to `writer` so far, including
        /// delimiters and escape bytes.
        len: usize,
    }

    impl<'a, W: Write> SlipEncoder<'a, W> {
        /// Starts a frame by writing the opening `END` delimiter.
        pub fn new(writer: &'a mut W) -> std::io::Result<Self> {
            let mut encoder = Self { writer, len: 0 };
            encoder.emit(&[END])?;
            Ok(encoder)
        }

        /// Closes the frame with an `END` delimiter.
        ///
        /// Returns the total number of encoded bytes written.
        pub fn finish(mut self) -> std::io::Result<usize> {
            self.emit(&[END])?;
            Ok(self.len)
        }

        /// Writes `bytes` completely and accounts for them in `len`.
        ///
        /// Uses `write_all` because a plain `write` may accept only part of the
        /// slice, which would silently drop escape bytes and miscount `len`.
        fn emit(&mut self, bytes: &[u8]) -> std::io::Result<()> {
            self.writer.write_all(bytes)?;
            self.len += bytes.len();
            Ok(())
        }
    }

    impl<'a, W: Write> Write for SlipEncoder<'a, W> {
        /// Escapes and writes every byte of `buf`, reporting all of it consumed.
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            for &byte in buf {
                match byte {
                    END => self.emit(&[ESC, ESC_END])?,
                    ESC => self.emit(&[ESC, ESC_ESC])?,
                    other => self.emit(&[other])?,
                }
            }
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.writer.flush()
        }
    }
}