1use std::{
8 io::{BufWriter, Read, Write},
9 iter::zip,
10 thread::sleep,
11 time::Duration,
12};
13
14use log::{debug, info};
15use regex::Regex;
16use serialport::{SerialPort, UsbPortInfo};
17use slip_codec::SlipDecoder;
18
19#[cfg(unix)]
20use self::reset::UnixTightReset;
21use self::{
22 encoder::SlipEncoder,
23 reset::{
24 construct_reset_strategy_sequence, hard_reset, reset_after_flash, ClassicReset,
25 ResetAfterOperation, ResetBeforeOperation, ResetStrategy, UsbJtagSerialReset,
26 },
27};
28use crate::{
29 command::{Command, CommandType},
30 connection::reset::soft_reset,
31 error::{ConnectionError, Error, ResultExt, RomError, RomErrorKind},
32};
33
34pub mod reset;
35
/// Upper bound on connection attempts made by `Connection::begin`.
///
/// `Connection::sync` also uses it as the number of responses it reads after
/// sending a sync command.
const MAX_CONNECT_ATTEMPTS: usize = 7;

/// Number of sync commands tried within a single connection attempt.
const MAX_SYNC_ATTEMPTS: usize = 5;

/// USB product ID for which `Connection::reset_to_flash` picks the
/// USB-Serial-JTAG reset strategy.
pub(crate) const USB_SERIAL_JTAG_PID: u16 = 0x1001;
39
40#[cfg(unix)]
41pub type Port = serialport::TTYPort;
42#[cfg(windows)]
43pub type Port = serialport::COMPort;
44
/// Value carried by a response packet, as decoded by
/// `Connection::read_response`.
///
/// Convert it into the expected concrete type with `try_into()`. Asking for
/// the wrong type yields an error rather than a panic.
#[derive(Debug, Clone)]
pub enum CommandResponseValue {
    /// A 32-bit value.
    ValueU32(u32),
    /// A 128-bit value.
    ValueU128(u128),
    /// Raw packet bytes, used when the response is not decoded into a number.
    Vector(Vec<u8>),
}
51
52impl TryInto<u32> for CommandResponseValue {
53 type Error = crate::error::Error;
54
55 fn try_into(self) -> Result<u32, Self::Error> {
56 match self {
57 CommandResponseValue::ValueU32(value) => Ok(value),
58 CommandResponseValue::ValueU128(_) => Err(crate::error::Error::InternalError),
59 CommandResponseValue::Vector(_) => Err(crate::error::Error::InternalError),
60 }
61 }
62}
63
64impl TryInto<u128> for CommandResponseValue {
65 type Error = crate::error::Error;
66
67 fn try_into(self) -> Result<u128, Self::Error> {
68 match self {
69 CommandResponseValue::ValueU32(_) => Err(crate::error::Error::InternalError),
70 CommandResponseValue::ValueU128(value) => Ok(value),
71 CommandResponseValue::Vector(_) => Err(crate::error::Error::InternalError),
72 }
73 }
74}
75
76impl TryInto<Vec<u8>> for CommandResponseValue {
77 type Error = crate::error::Error;
78
79 fn try_into(self) -> Result<Vec<u8>, Self::Error> {
80 match self {
81 CommandResponseValue::ValueU32(_) => Err(crate::error::Error::InternalError),
82 CommandResponseValue::ValueU128(_) => Err(crate::error::Error::InternalError),
83 CommandResponseValue::Vector(value) => Ok(value),
84 }
85 }
86}
87
88#[derive(Debug, Clone)]
90pub struct CommandResponse {
91 pub resp: u8,
92 pub return_op: u8,
93 pub return_length: u16,
94 pub value: CommandResponseValue,
95 pub error: u8,
96 pub status: u8,
97}
98
99pub struct Connection {
101 serial: Port,
102 port_info: UsbPortInfo,
103 decoder: SlipDecoder,
104 after_operation: ResetAfterOperation,
105 before_operation: ResetBeforeOperation,
106}
107
108impl Connection {
109 pub fn new(
110 serial: Port,
111 port_info: UsbPortInfo,
112 after_operation: ResetAfterOperation,
113 before_operation: ResetBeforeOperation,
114 ) -> Self {
115 Connection {
116 serial,
117 port_info,
118 decoder: SlipDecoder::new(),
119 after_operation,
120 before_operation,
121 }
122 }
123
124 pub fn begin(&mut self) -> Result<(), Error> {
126 let port_name = self.serial.name().unwrap_or_default();
127 let reset_sequence = construct_reset_strategy_sequence(
128 &port_name,
129 self.port_info.pid,
130 self.before_operation,
131 );
132
133 for (_, reset_strategy) in zip(0..MAX_CONNECT_ATTEMPTS, reset_sequence.iter().cycle()) {
134 match self.connect_attempt(reset_strategy) {
135 Ok(_) => {
136 return Ok(());
137 }
138 Err(e) => {
139 debug!("Failed to reset, error {:#?}, retrying", e);
140 }
141 }
142 }
143
144 Err(Error::Connection(ConnectionError::ConnectionFailed))
145 }
146
147 #[allow(clippy::borrowed_box)]
149 fn connect_attempt(&mut self, reset_strategy: &Box<dyn ResetStrategy>) -> Result<(), Error> {
150 if self.before_operation == ResetBeforeOperation::NoResetNoSync {
153 return Ok(());
154 }
155 let mut download_mode: bool = false;
156 let mut boot_mode = String::new();
157 let mut boot_log_detected = false;
158 let mut buff: Vec<u8>;
159 if self.before_operation != ResetBeforeOperation::NoReset {
160 reset_strategy.reset(&mut self.serial)?;
162
163 let available_bytes = self.serial.bytes_to_read()?;
164 buff = vec![0; available_bytes as usize];
165 let read_bytes = self.serial.read(&mut buff)? as u32;
166
167 if read_bytes != available_bytes {
168 return Err(Error::Connection(ConnectionError::ReadMissmatch(
169 available_bytes,
170 read_bytes,
171 )));
172 }
173
174 let read_slice = String::from_utf8_lossy(&buff[..read_bytes as usize]).into_owned();
175
176 let pattern =
177 Regex::new(r"boot:(0x[0-9a-fA-F]+)([\s\S]*waiting for download)?").unwrap();
178
179 if let Some(data) = pattern.captures(&read_slice) {
181 boot_log_detected = true;
182 boot_mode = data
184 .get(1)
185 .map(|m| m.as_str())
186 .unwrap_or_default()
187 .to_string();
188 download_mode = data.get(2).is_some();
189
190 debug!("Boot Mode: {}", boot_mode);
192 debug!("Download Mode: {}", download_mode);
193 };
194 }
195
196 for _ in 0..MAX_SYNC_ATTEMPTS {
197 self.flush()?;
198
199 if self.sync().is_ok() {
200 return Ok(());
201 }
202 }
203
204 if boot_log_detected {
205 if download_mode {
206 return Err(Error::Connection(ConnectionError::NoSyncReply));
207 } else {
208 return Err(Error::Connection(ConnectionError::WrongBootMode(
209 boot_mode.to_string(),
210 )));
211 }
212 }
213
214 Err(Error::Connection(ConnectionError::ConnectionFailed))
215 }
216
217 pub(crate) fn sync(&mut self) -> Result<(), Error> {
219 self.with_timeout(CommandType::Sync.timeout(), |connection| {
220 connection.command(Command::Sync)?;
221 connection.flush()?;
222
223 sleep(Duration::from_millis(10));
224
225 for _ in 0..MAX_CONNECT_ATTEMPTS {
226 match connection.read_response()? {
227 Some(response) if response.return_op == CommandType::Sync as u8 => {
228 if response.status == 1 {
229 connection.flush().ok();
230 return Err(Error::RomError(RomError::new(
231 CommandType::Sync,
232 RomErrorKind::from(response.error),
233 )));
234 }
235 }
236 _ => {
237 return Err(Error::RomError(RomError::new(
238 CommandType::Sync,
239 RomErrorKind::InvalidMessage,
240 )))
241 }
242 }
243 }
244
245 Ok(())
246 })?;
247
248 Ok(())
249 }
250
251 pub fn reset(&mut self) -> Result<(), Error> {
253 reset_after_flash(&mut self.serial, self.port_info.pid)?;
254
255 Ok(())
256 }
257
258 pub fn reset_after(&mut self, is_stub: bool) -> Result<(), Error> {
260 let pid = self.get_usb_pid()?;
261
262 match self.after_operation {
263 ResetAfterOperation::HardReset => hard_reset(&mut self.serial, pid),
264 ResetAfterOperation::NoReset => {
265 info!("Staying in bootloader");
266 soft_reset(self, true, is_stub)?;
267
268 Ok(())
269 }
270 ResetAfterOperation::NoResetNoStub => {
271 info!("Staying in flasher stub");
272 Ok(())
273 }
274 }
275 }
276
277 pub fn reset_to_flash(&mut self, extra_delay: bool) -> Result<(), Error> {
279 if self.port_info.pid == USB_SERIAL_JTAG_PID {
280 UsbJtagSerialReset.reset(&mut self.serial)
281 } else {
282 #[cfg(unix)]
283 if UnixTightReset::new(extra_delay)
284 .reset(&mut self.serial)
285 .is_ok()
286 {
287 return Ok(());
288 }
289
290 ClassicReset::new(extra_delay).reset(&mut self.serial)
291 }
292 }
293
294 pub fn set_timeout(&mut self, timeout: Duration) -> Result<(), Error> {
296 self.serial.set_timeout(timeout)?;
297 Ok(())
298 }
299
300 pub fn set_baud(&mut self, speed: u32) -> Result<(), Error> {
302 self.serial.set_baud_rate(speed)?;
303
304 Ok(())
305 }
306
307 pub fn get_baud(&self) -> Result<u32, Error> {
309 Ok(self.serial.baud_rate()?)
310 }
311
312 pub fn with_timeout<T, F>(&mut self, timeout: Duration, mut f: F) -> Result<T, Error>
314 where
315 F: FnMut(&mut Connection) -> Result<T, Error>,
316 {
317 let old_timeout = {
318 let mut binding = Box::new(&mut self.serial);
319 let serial = binding.as_mut();
320 let old_timeout = serial.timeout();
321 serial.set_timeout(timeout)?;
322 old_timeout
323 };
324
325 let result = f(self);
326
327 self.serial.set_timeout(old_timeout)?;
328
329 result
330 }
331
332 pub fn read_response(&mut self) -> Result<Option<CommandResponse>, Error> {
334 match self.read(10)? {
335 None => Ok(None),
336 Some(response) => {
337 let status_len = if response.len() == 10 || response.len() == 26 {
348 2
349 } else {
350 4
351 };
352
353 let value = match response.len() {
354 10 | 12 => CommandResponseValue::ValueU32(u32::from_le_bytes(
355 response[4..][..4].try_into().unwrap(),
356 )),
357 44 => {
358 CommandResponseValue::ValueU128(
360 u128::from_str_radix(
361 std::str::from_utf8(&response[8..][..32]).unwrap(),
362 16,
363 )
364 .unwrap(),
365 )
366 }
367 26 => {
368 CommandResponseValue::ValueU128(u128::from_be_bytes(
370 response[8..][..16].try_into().unwrap(),
371 ))
372 }
373 _ => CommandResponseValue::Vector(response.clone()),
374 };
375
376 let header = CommandResponse {
377 resp: response[0],
378 return_op: response[1],
379 return_length: u16::from_le_bytes(response[2..][..2].try_into().unwrap()),
380 value,
381 error: response[response.len() - status_len],
382 status: response[response.len() - status_len + 1],
383 };
384
385 Ok(Some(header))
386 }
387 }
388 }
389
390 pub fn write_raw(&mut self, data: u32) -> Result<(), Error> {
392 let mut binding = Box::new(&mut self.serial);
393 let serial = binding.as_mut();
394 serial.clear(serialport::ClearBuffer::Input)?;
395 let mut writer = BufWriter::new(serial);
396 let mut encoder = SlipEncoder::new(&mut writer)?;
397 encoder.write_all(&data.to_le_bytes())?;
398 encoder.finish()?;
399 writer.flush()?;
400 Ok(())
401 }
402
403 pub fn write_command(&mut self, command: Command) -> Result<(), Error> {
405 debug!("Writing command: {:?}", command);
406 let mut binding = Box::new(&mut self.serial);
407 let serial = binding.as_mut();
408
409 serial.clear(serialport::ClearBuffer::Input)?;
410 let mut writer = BufWriter::new(serial);
411 let mut encoder = SlipEncoder::new(&mut writer)?;
412 command.write(&mut encoder)?;
413 encoder.finish()?;
414 writer.flush()?;
415 Ok(())
416 }
417
418 pub fn command(&mut self, command: Command) -> Result<CommandResponseValue, Error> {
420 let ty = command.command_type();
421 self.write_command(command).for_command(ty)?;
422
423 for _ in 0..100 {
424 match self.read_response().for_command(ty)? {
425 Some(response) if response.return_op == ty as u8 => {
426 return if response.error != 0 {
427 let _error = self.flush();
428 Err(Error::RomError(RomError::new(
429 command.command_type(),
430 RomErrorKind::from(response.error),
431 )))
432 } else {
433 Ok(response.value)
434 }
435 }
436 _ => {
437 continue;
438 }
439 }
440 }
441 Err(Error::Connection(ConnectionError::ConnectionFailed))
442 }
443
444 pub fn read_reg(&mut self, reg: u32) -> Result<u32, Error> {
446 self.with_timeout(CommandType::ReadReg.timeout(), |connection| {
447 connection.command(Command::ReadReg { address: reg })
448 })
449 .map(|v| v.try_into().unwrap())
450 }
451
452 pub fn write_reg(&mut self, addr: u32, value: u32, mask: Option<u32>) -> Result<(), Error> {
454 self.with_timeout(CommandType::WriteReg.timeout(), |connection| {
455 connection.command(Command::WriteReg {
456 address: addr,
457 value,
458 mask,
459 })
460 })?;
461
462 Ok(())
463 }
464
465 pub(crate) fn read(&mut self, len: usize) -> Result<Option<Vec<u8>>, Error> {
466 let mut tmp = Vec::with_capacity(1024);
467 loop {
468 self.decoder.decode(&mut self.serial, &mut tmp)?;
469 if tmp.len() >= len {
470 return Ok(Some(tmp));
471 }
472 }
473 }
474
475 pub fn flush(&mut self) -> Result<(), Error> {
477 self.serial.flush()?;
478 Ok(())
479 }
480
481 pub fn into_serial(self) -> Port {
483 self.serial
484 }
485
486 pub fn get_usb_pid(&self) -> Result<u16, Error> {
488 Ok(self.port_info.pid)
489 }
490}
491
mod encoder {
    use std::io::Write;

    const END: u8 = 0xC0;
    const ESC: u8 = 0xDB;
    const ESC_END: u8 = 0xDC;
    const ESC_ESC: u8 = 0xDD;

    /// SLIP frame encoder that writes into a borrowed writer.
    ///
    /// Create it with [`SlipEncoder::new`] (writes the opening `END`), feed
    /// payload through [`Write`], and close the frame with
    /// [`SlipEncoder::finish`].
    pub struct SlipEncoder<'a, W: Write> {
        writer: &'a mut W,
        /// Bytes emitted to `writer` for this frame so far, delimiters included.
        len: usize,
    }

    impl<'a, W: Write> SlipEncoder<'a, W> {
        /// Starts a frame by writing the opening `END` delimiter.
        pub fn new(writer: &'a mut W) -> std::io::Result<Self> {
            // `write_all`: a short write must not silently drop the delimiter.
            writer.write_all(&[END])?;
            Ok(Self { writer, len: 1 })
        }

        /// Writes the closing `END` delimiter.
        ///
        /// Returns the total number of bytes emitted for the frame, including
        /// both delimiters and escape bytes.
        pub fn finish(mut self) -> std::io::Result<usize> {
            self.writer.write_all(&[END])?;
            self.len += 1;
            Ok(self.len)
        }
    }

    impl<W: Write> Write for SlipEncoder<'_, W> {
        /// Escapes `END`/`ESC` bytes in `buf` and forwards everything to the
        /// underlying writer.
        ///
        /// Runs of ordinary bytes are written in one call. Every write uses
        /// `write_all`, so an underlying short write can never split an escape
        /// pair or lose payload bytes. On success all of `buf` is consumed.
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let mut start = 0;
            for (index, &byte) in buf.iter().enumerate() {
                let escaped: &[u8] = match byte {
                    END => &[ESC, ESC_END],
                    ESC => &[ESC, ESC_ESC],
                    _ => continue,
                };
                self.writer.write_all(&buf[start..index])?;
                self.writer.write_all(escaped)?;
                self.len += index - start + escaped.len();
                start = index + 1;
            }

            self.writer.write_all(&buf[start..])?;
            self.len += buf.len() - start;

            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.writer.flush()
        }
    }
}