// espflash/connection/mod.rs

1//! Establish a connection with a target device
2//!
3//! The [Connection] struct abstracts over the serial connection and
4//! sending/decoding of commands, and provides higher-level operations with the
5//! device.
6
7use std::{
8    io::{BufWriter, Read, Write},
9    iter::zip,
10    thread::sleep,
11    time::Duration,
12};
13
14use log::{debug, info};
15use regex::Regex;
16use serialport::{SerialPort, UsbPortInfo};
17use slip_codec::SlipDecoder;
18
19#[cfg(unix)]
20use self::reset::UnixTightReset;
21use self::{
22    encoder::SlipEncoder,
23    reset::{
24        construct_reset_strategy_sequence, hard_reset, reset_after_flash, ClassicReset,
25        ResetAfterOperation, ResetBeforeOperation, ResetStrategy, UsbJtagSerialReset,
26    },
27};
28use crate::{
29    command::{Command, CommandType},
30    connection::reset::soft_reset,
31    error::{ConnectionError, Error, ResultExt, RomError, RomErrorKind},
32};
33
34pub mod reset;
35
36const MAX_CONNECT_ATTEMPTS: usize = 7;
37const MAX_SYNC_ATTEMPTS: usize = 5;
38pub(crate) const USB_SERIAL_JTAG_PID: u16 = 0x1001;
39
40#[cfg(unix)]
41pub type Port = serialport::TTYPort;
42#[cfg(windows)]
43pub type Port = serialport::COMPort;
44
45#[derive(Debug, Clone)]
46pub enum CommandResponseValue {
47    ValueU32(u32),
48    ValueU128(u128),
49    Vector(Vec<u8>),
50}
51
52impl TryInto<u32> for CommandResponseValue {
53    type Error = crate::error::Error;
54
55    fn try_into(self) -> Result<u32, Self::Error> {
56        match self {
57            CommandResponseValue::ValueU32(value) => Ok(value),
58            CommandResponseValue::ValueU128(_) => Err(crate::error::Error::InternalError),
59            CommandResponseValue::Vector(_) => Err(crate::error::Error::InternalError),
60        }
61    }
62}
63
64impl TryInto<u128> for CommandResponseValue {
65    type Error = crate::error::Error;
66
67    fn try_into(self) -> Result<u128, Self::Error> {
68        match self {
69            CommandResponseValue::ValueU32(_) => Err(crate::error::Error::InternalError),
70            CommandResponseValue::ValueU128(value) => Ok(value),
71            CommandResponseValue::Vector(_) => Err(crate::error::Error::InternalError),
72        }
73    }
74}
75
76impl TryInto<Vec<u8>> for CommandResponseValue {
77    type Error = crate::error::Error;
78
79    fn try_into(self) -> Result<Vec<u8>, Self::Error> {
80        match self {
81            CommandResponseValue::ValueU32(_) => Err(crate::error::Error::InternalError),
82            CommandResponseValue::ValueU128(_) => Err(crate::error::Error::InternalError),
83            CommandResponseValue::Vector(value) => Ok(value),
84        }
85    }
86}
87
88/// A response from a target device following a command
89#[derive(Debug, Clone)]
90pub struct CommandResponse {
91    pub resp: u8,
92    pub return_op: u8,
93    pub return_length: u16,
94    pub value: CommandResponseValue,
95    pub error: u8,
96    pub status: u8,
97}
98
99/// An established connection with a target device
100pub struct Connection {
101    serial: Port,
102    port_info: UsbPortInfo,
103    decoder: SlipDecoder,
104    after_operation: ResetAfterOperation,
105    before_operation: ResetBeforeOperation,
106}
107
108impl Connection {
109    pub fn new(
110        serial: Port,
111        port_info: UsbPortInfo,
112        after_operation: ResetAfterOperation,
113        before_operation: ResetBeforeOperation,
114    ) -> Self {
115        Connection {
116            serial,
117            port_info,
118            decoder: SlipDecoder::new(),
119            after_operation,
120            before_operation,
121        }
122    }
123
124    /// Initialize a connection with a device
125    pub fn begin(&mut self) -> Result<(), Error> {
126        let port_name = self.serial.name().unwrap_or_default();
127        let reset_sequence = construct_reset_strategy_sequence(
128            &port_name,
129            self.port_info.pid,
130            self.before_operation,
131        );
132
133        for (_, reset_strategy) in zip(0..MAX_CONNECT_ATTEMPTS, reset_sequence.iter().cycle()) {
134            match self.connect_attempt(reset_strategy) {
135                Ok(_) => {
136                    return Ok(());
137                }
138                Err(e) => {
139                    debug!("Failed to reset, error {:#?}, retrying", e);
140                }
141            }
142        }
143
144        Err(Error::Connection(ConnectionError::ConnectionFailed))
145    }
146
147    /// Try to connect to a device
148    #[allow(clippy::borrowed_box)]
149    fn connect_attempt(&mut self, reset_strategy: &Box<dyn ResetStrategy>) -> Result<(), Error> {
150        // If we're doing no_sync, we're likely communicating as a pass through
151        // with an intermediate device to the ESP32
152        if self.before_operation == ResetBeforeOperation::NoResetNoSync {
153            return Ok(());
154        }
155        let mut download_mode: bool = false;
156        let mut boot_mode = String::new();
157        let mut boot_log_detected = false;
158        let mut buff: Vec<u8>;
159        if self.before_operation != ResetBeforeOperation::NoReset {
160            // Reset the chip to bootloader (download mode)
161            reset_strategy.reset(&mut self.serial)?;
162
163            let available_bytes = self.serial.bytes_to_read()?;
164            buff = vec![0; available_bytes as usize];
165            let read_bytes = self.serial.read(&mut buff)? as u32;
166
167            if read_bytes != available_bytes {
168                return Err(Error::Connection(ConnectionError::ReadMissmatch(
169                    available_bytes,
170                    read_bytes,
171                )));
172            }
173
174            let read_slice = String::from_utf8_lossy(&buff[..read_bytes as usize]).into_owned();
175
176            let pattern =
177                Regex::new(r"boot:(0x[0-9a-fA-F]+)([\s\S]*waiting for download)?").unwrap();
178
179            // Search for the pattern in the read data
180            if let Some(data) = pattern.captures(&read_slice) {
181                boot_log_detected = true;
182                // Boot log detected
183                boot_mode = data
184                    .get(1)
185                    .map(|m| m.as_str())
186                    .unwrap_or_default()
187                    .to_string();
188                download_mode = data.get(2).is_some();
189
190                // Further processing or printing the results
191                debug!("Boot Mode: {}", boot_mode);
192                debug!("Download Mode: {}", download_mode);
193            };
194        }
195
196        for _ in 0..MAX_SYNC_ATTEMPTS {
197            self.flush()?;
198
199            if self.sync().is_ok() {
200                return Ok(());
201            }
202        }
203
204        if boot_log_detected {
205            if download_mode {
206                return Err(Error::Connection(ConnectionError::NoSyncReply));
207            } else {
208                return Err(Error::Connection(ConnectionError::WrongBootMode(
209                    boot_mode.to_string(),
210                )));
211            }
212        }
213
214        Err(Error::Connection(ConnectionError::ConnectionFailed))
215    }
216
217    /// Try to sync with the device for a given timeout
218    pub(crate) fn sync(&mut self) -> Result<(), Error> {
219        self.with_timeout(CommandType::Sync.timeout(), |connection| {
220            connection.command(Command::Sync)?;
221            connection.flush()?;
222
223            sleep(Duration::from_millis(10));
224
225            for _ in 0..MAX_CONNECT_ATTEMPTS {
226                match connection.read_response()? {
227                    Some(response) if response.return_op == CommandType::Sync as u8 => {
228                        if response.status == 1 {
229                            connection.flush().ok();
230                            return Err(Error::RomError(RomError::new(
231                                CommandType::Sync,
232                                RomErrorKind::from(response.error),
233                            )));
234                        }
235                    }
236                    _ => {
237                        return Err(Error::RomError(RomError::new(
238                            CommandType::Sync,
239                            RomErrorKind::InvalidMessage,
240                        )))
241                    }
242                }
243            }
244
245            Ok(())
246        })?;
247
248        Ok(())
249    }
250
251    // Reset the device
252    pub fn reset(&mut self) -> Result<(), Error> {
253        reset_after_flash(&mut self.serial, self.port_info.pid)?;
254
255        Ok(())
256    }
257
258    // Reset the device taking into account the reset after argument
259    pub fn reset_after(&mut self, is_stub: bool) -> Result<(), Error> {
260        let pid = self.get_usb_pid()?;
261
262        match self.after_operation {
263            ResetAfterOperation::HardReset => hard_reset(&mut self.serial, pid),
264            ResetAfterOperation::NoReset => {
265                info!("Staying in bootloader");
266                soft_reset(self, true, is_stub)?;
267
268                Ok(())
269            }
270            ResetAfterOperation::NoResetNoStub => {
271                info!("Staying in flasher stub");
272                Ok(())
273            }
274        }
275    }
276
277    // Reset the device to flash mode
278    pub fn reset_to_flash(&mut self, extra_delay: bool) -> Result<(), Error> {
279        if self.port_info.pid == USB_SERIAL_JTAG_PID {
280            UsbJtagSerialReset.reset(&mut self.serial)
281        } else {
282            #[cfg(unix)]
283            if UnixTightReset::new(extra_delay)
284                .reset(&mut self.serial)
285                .is_ok()
286            {
287                return Ok(());
288            }
289
290            ClassicReset::new(extra_delay).reset(&mut self.serial)
291        }
292    }
293
294    /// Set timeout for the serial port
295    pub fn set_timeout(&mut self, timeout: Duration) -> Result<(), Error> {
296        self.serial.set_timeout(timeout)?;
297        Ok(())
298    }
299
300    /// Set baud rate for the serial port
301    pub fn set_baud(&mut self, speed: u32) -> Result<(), Error> {
302        self.serial.set_baud_rate(speed)?;
303
304        Ok(())
305    }
306
307    /// Get the current baud rate of the serial port
308    pub fn get_baud(&self) -> Result<u32, Error> {
309        Ok(self.serial.baud_rate()?)
310    }
311
312    /// Run a command with a timeout defined by the command type
313    pub fn with_timeout<T, F>(&mut self, timeout: Duration, mut f: F) -> Result<T, Error>
314    where
315        F: FnMut(&mut Connection) -> Result<T, Error>,
316    {
317        let old_timeout = {
318            let mut binding = Box::new(&mut self.serial);
319            let serial = binding.as_mut();
320            let old_timeout = serial.timeout();
321            serial.set_timeout(timeout)?;
322            old_timeout
323        };
324
325        let result = f(self);
326
327        self.serial.set_timeout(old_timeout)?;
328
329        result
330    }
331
332    /// Read the response from a serial port
333    pub fn read_response(&mut self) -> Result<Option<CommandResponse>, Error> {
334        match self.read(10)? {
335            None => Ok(None),
336            Some(response) => {
337                // here is what esptool does: https://github.com/espressif/esptool/blob/master/esptool/loader.py#L458
338                // from esptool: things are a bit weird here, bear with us
339
340                // we rely on the known and expected response sizes which should be fine for now - if that changes we need to pass the command type
341                // we are parsing the response for
342                // for most commands the response length is 10 (for the stub) or 12 (for ROM code)
343                // the MD5 command response is 44 for ROM loader, 26 for the stub
344                // see https://docs.espressif.com/projects/esptool/en/latest/esp32/advanced-topics/serial-protocol.html?highlight=md5#response-packet
345                // see https://docs.espressif.com/projects/esptool/en/latest/esp32/advanced-topics/serial-protocol.html?highlight=md5#status-bytes
346                // see https://docs.espressif.com/projects/esptool/en/latest/esp32/advanced-topics/serial-protocol.html?highlight=md5#verifying-uploaded-data
347                let status_len = if response.len() == 10 || response.len() == 26 {
348                    2
349                } else {
350                    4
351                };
352
353                let value = match response.len() {
354                    10 | 12 => CommandResponseValue::ValueU32(u32::from_le_bytes(
355                        response[4..][..4].try_into().unwrap(),
356                    )),
357                    44 => {
358                        // MD5 is in ASCII
359                        CommandResponseValue::ValueU128(
360                            u128::from_str_radix(
361                                std::str::from_utf8(&response[8..][..32]).unwrap(),
362                                16,
363                            )
364                            .unwrap(),
365                        )
366                    }
367                    26 => {
368                        // MD5 is BE bytes
369                        CommandResponseValue::ValueU128(u128::from_be_bytes(
370                            response[8..][..16].try_into().unwrap(),
371                        ))
372                    }
373                    _ => CommandResponseValue::Vector(response.clone()),
374                };
375
376                let header = CommandResponse {
377                    resp: response[0],
378                    return_op: response[1],
379                    return_length: u16::from_le_bytes(response[2..][..2].try_into().unwrap()),
380                    value,
381                    error: response[response.len() - status_len],
382                    status: response[response.len() - status_len + 1],
383                };
384
385                Ok(Some(header))
386            }
387        }
388    }
389
390    /// Write raw data to the serial port
391    pub fn write_raw(&mut self, data: u32) -> Result<(), Error> {
392        let mut binding = Box::new(&mut self.serial);
393        let serial = binding.as_mut();
394        serial.clear(serialport::ClearBuffer::Input)?;
395        let mut writer = BufWriter::new(serial);
396        let mut encoder = SlipEncoder::new(&mut writer)?;
397        encoder.write_all(&data.to_le_bytes())?;
398        encoder.finish()?;
399        writer.flush()?;
400        Ok(())
401    }
402
403    /// Write a command to the serial port
404    pub fn write_command(&mut self, command: Command) -> Result<(), Error> {
405        debug!("Writing command: {:?}", command);
406        let mut binding = Box::new(&mut self.serial);
407        let serial = binding.as_mut();
408
409        serial.clear(serialport::ClearBuffer::Input)?;
410        let mut writer = BufWriter::new(serial);
411        let mut encoder = SlipEncoder::new(&mut writer)?;
412        command.write(&mut encoder)?;
413        encoder.finish()?;
414        writer.flush()?;
415        Ok(())
416    }
417
418    ///  Write a command and reads the response
419    pub fn command(&mut self, command: Command) -> Result<CommandResponseValue, Error> {
420        let ty = command.command_type();
421        self.write_command(command).for_command(ty)?;
422
423        for _ in 0..100 {
424            match self.read_response().for_command(ty)? {
425                Some(response) if response.return_op == ty as u8 => {
426                    return if response.error != 0 {
427                        let _error = self.flush();
428                        Err(Error::RomError(RomError::new(
429                            command.command_type(),
430                            RomErrorKind::from(response.error),
431                        )))
432                    } else {
433                        Ok(response.value)
434                    }
435                }
436                _ => {
437                    continue;
438                }
439            }
440        }
441        Err(Error::Connection(ConnectionError::ConnectionFailed))
442    }
443
444    /// Read a register command with a timeout
445    pub fn read_reg(&mut self, reg: u32) -> Result<u32, Error> {
446        self.with_timeout(CommandType::ReadReg.timeout(), |connection| {
447            connection.command(Command::ReadReg { address: reg })
448        })
449        .map(|v| v.try_into().unwrap())
450    }
451
452    /// Write a register command with a timeout
453    pub fn write_reg(&mut self, addr: u32, value: u32, mask: Option<u32>) -> Result<(), Error> {
454        self.with_timeout(CommandType::WriteReg.timeout(), |connection| {
455            connection.command(Command::WriteReg {
456                address: addr,
457                value,
458                mask,
459            })
460        })?;
461
462        Ok(())
463    }
464
465    pub(crate) fn read(&mut self, len: usize) -> Result<Option<Vec<u8>>, Error> {
466        let mut tmp = Vec::with_capacity(1024);
467        loop {
468            self.decoder.decode(&mut self.serial, &mut tmp)?;
469            if tmp.len() >= len {
470                return Ok(Some(tmp));
471            }
472        }
473    }
474
475    /// Flush the serial port
476    pub fn flush(&mut self) -> Result<(), Error> {
477        self.serial.flush()?;
478        Ok(())
479    }
480
481    /// Turn a serial port into a [Port]
482    pub fn into_serial(self) -> Port {
483        self.serial
484    }
485
486    /// Get the USB PID of the serial port
487    pub fn get_usb_pid(&self) -> Result<u16, Error> {
488        Ok(self.port_info.pid)
489    }
490}
491
492mod encoder {
493    use std::io::Write;
494
495    const END: u8 = 0xC0;
496    const ESC: u8 = 0xDB;
497    const ESC_END: u8 = 0xDC;
498    const ESC_ESC: u8 = 0xDD;
499
500    pub struct SlipEncoder<'a, W: Write> {
501        writer: &'a mut W,
502        len: usize,
503    }
504
505    impl<'a, W: Write> SlipEncoder<'a, W> {
506        /// Creates a new encoder context
507        pub fn new(writer: &'a mut W) -> std::io::Result<Self> {
508            let len = writer.write(&[END])?;
509            Ok(Self { writer, len })
510        }
511
512        pub fn finish(mut self) -> std::io::Result<usize> {
513            self.len += self.writer.write(&[END])?;
514            Ok(self.len)
515        }
516    }
517
518    impl<W: Write> Write for SlipEncoder<'_, W> {
519        /// Writes the given buffer replacing the END and ESC bytes
520        ///
521        /// See https://docs.espressif.com/projects/esptool/en/latest/esp32c3/advanced-topics/serial-protocol.html#low-level-protocol
522        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
523            for value in buf.iter() {
524                match *value {
525                    END => {
526                        self.len += self.writer.write(&[ESC, ESC_END])?;
527                    }
528                    ESC => {
529                        self.len += self.writer.write(&[ESC, ESC_ESC])?;
530                    }
531                    _ => {
532                        self.len += self.writer.write(&[*value])?;
533                    }
534                }
535            }
536
537            Ok(buf.len())
538        }
539
540        fn flush(&mut self) -> std::io::Result<()> {
541            self.writer.flush()
542        }
543    }
544}