use crate::error::{Error, Result};
use crate::logging::trace;
use crate::transport::serial::TIMEOUT;
use crate::{
proto,
transport::{
TransportRaw,
serial::{
FLIPPER_BAUD,
helpers::{drain_until, drain_until_str},
},
},
};
use crate::proto::CommandStatus;
use prost::Message;
use serialport::SerialPort;
/// RPC transport that sends and receives length-delimited protobuf `proto::Main`
/// messages over a serial port.
///
/// Constructed either by `SerialRpcTransport::new` (opens the port and performs the
/// `start_rpc_session` handshake) or by `SerialRpcTransport::from_port` (wraps an
/// already-open port as-is).
#[derive(Debug)]
pub struct SerialRpcTransport {
    // Counter exposed through the `CommandIndex` trait; both constructors start it at 0.
    command_index: u32,
    // Serial port used for every read and write performed by this transport.
    port: Box<dyn SerialPort>,
}
/// Access to a per-transport command counter.
///
/// NOTE(review): presumably the value is used as `proto::Main::command_id` for
/// outgoing requests; confirm against callers.
pub trait CommandIndex {
    /// Advances the counter by `by` and returns the updated value.
    fn increment_command_index(&mut self, by: u32) -> u32;
    /// Returns the current counter value.
    fn command_index(&mut self) -> u32;
}
impl CommandIndex for SerialRpcTransport {
    /// Adds `by` to the stored command index and returns the new value.
    ///
    /// Uses plain `+`. With default build profiles, overflow panics in debug builds
    /// and wraps in release builds.
    fn increment_command_index(&mut self, by: u32) -> u32 {
        let updated = self.command_index + by;
        self.command_index = updated;
        updated
    }

    /// Returns the stored command index without modifying it.
    fn command_index(&mut self) -> u32 {
        self.command_index
    }
}
impl SerialRpcTransport {
    /// Opens the serial port named `port` and switches the session into RPC mode.
    ///
    /// The port is opened at `FLIPPER_BAUD` with `TIMEOUT` as its timeout. This
    /// function then:
    /// 1. drains input until the `">: "` prompt,
    /// 2. writes and flushes `start_rpc_session\r`, and
    /// 3. drains input until a `\n`.
    ///
    /// The command index starts at 0.
    ///
    /// # Errors
    /// Returns an error if the port cannot be opened, if writing or flushing fails,
    /// or if either drain helper returns an error.
    #[cfg_attr(feature = "tracing", tracing::instrument)]
    pub fn new<S: AsRef<str> + std::fmt::Debug>(port: S) -> Result<Self> {
        let mut serial = serialport::new(port.as_ref(), FLIPPER_BAUD)
            .timeout(TIMEOUT)
            .open()?;

        // Wait for the shell prompt before issuing the mode-switch command.
        trace!("draining(prompt)");
        drain_until_str(&mut serial, ">: ", TIMEOUT)?;

        trace!("start_rpc_session");
        serial.write_all(b"start_rpc_session\r")?;
        serial.flush()?;

        // Discard the remainder of the command line; presumably its echo, which
        // must not be mistaken for an RPC frame.
        trace!("draining(start_rpc_session, \\n)");
        drain_until(&mut serial, b'\n', TIMEOUT)?;

        Ok(Self {
            command_index: 0,
            port: serial,
        })
    }

    /// Wraps an already-open serial port without performing any handshake.
    ///
    /// Unlike `new`, this sends nothing and reads nothing. NOTE(review): the caller
    /// presumably must have put the port into RPC mode already. The command index
    /// starts at 0. This never fails; the `Result` return type appears to exist
    /// only for symmetry with `new`.
    #[cfg_attr(feature = "tracing", tracing::instrument)]
    pub fn from_port(port: Box<dyn SerialPort>) -> Result<Self> {
        let transport = Self {
            port,
            command_index: 0,
        };
        Ok(transport)
    }
}
impl proto::Main {
    /// Returns this message with `command_id` replaced; all other fields are kept.
    pub fn with_command_id(self, command_id: u32) -> Self {
        Self { command_id, ..self }
    }

    /// Returns this message with `has_next` replaced; all other fields are kept.
    pub fn with_has_next(self, has_next: bool) -> Self {
        Self { has_next, ..self }
    }
}
impl TransportRaw<proto::Main> for SerialRpcTransport {
    type Err = Error;

    /// Encodes `value` as a length-delimited protobuf frame, writes it to the port
    /// and flushes the port.
    ///
    /// # Errors
    /// Returns an error if writing to or flushing the port fails.
    #[cfg_attr(feature = "tracing", tracing::instrument)]
    fn send_raw(&mut self, value: proto::Main) -> std::result::Result<(), Self::Err> {
        let encoded = value.encode_length_delimited_to_vec();
        self.port.write_all(&encoded)?;
        self.port.flush()?;
        Ok(())
    }

    /// Reads one length-delimited `proto::Main` frame from the port.
    ///
    /// Reads whatever is already available into a `STACK_LIMIT`-byte stack buffer.
    /// It then completes the length prefix if needed and reads the rest of the
    /// frame, using a second stack buffer when it fits and a heap buffer otherwise.
    ///
    /// # Errors
    /// - An `UnexpectedEof` I/O error if no byte at all is read (the first read
    ///   times out or returns 0).
    /// - I/O errors from the port, and protobuf decode errors for a malformed
    ///   length prefix or payload.
    /// - `Error::InvalidCommandStatus` for an unknown status code. Otherwise
    ///   `CommandStatus::into_result` decides the outcome (presumably non-OK
    ///   statuses become errors).
    #[cfg_attr(feature = "tracing", tracing::instrument)]
    #[cfg(feature = "transport-serial-optimized")]
    fn receive_raw(&mut self) -> std::result::Result<proto::Main, Self::Err> {
        use crate::logging::warn;
        use prost::bytes::Buf;

        // Longest possible varint encoding of a length prefix.
        const MAX_VARINT_LEN: usize = 10;
        #[cfg(feature = "transport-serial-optimized-large-stack-limit")]
        const STACK_LIMIT: usize = 10 + 512;
        #[cfg(not(feature = "transport-serial-optimized-large-stack-limit"))]
        const STACK_LIMIT: usize = 10 + 128;

        self.port.flush()?;

        let mut buf = [0u8; STACK_LIMIT];
        let mut read = 0;
        let mut available_bytes = buf.len();
        trace!("reading varint");
        while read < available_bytes {
            match self.port.read(&mut buf[read..]) {
                Ok(0) => break,
                Ok(n) => {
                    available_bytes = self.port.bytes_to_read()? as usize;
                    read += n;
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => break,
                Err(e) => return Err(e.into()),
            }
        }
        if read == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "no data read, failed to parse varint",
            )
            .into());
        }

        // The loop above can stop between the bytes of a multi-byte length prefix
        // (nothing else pending yet, or a timeout). Finish the prefix one byte at a
        // time, never reading past its terminating byte, instead of failing to
        // decode a truncated varint.
        while read < MAX_VARINT_LEN && buf[..read].iter().all(|&byte| byte & 0x80 != 0) {
            self.port.read_exact(&mut buf[read..=read])?;
            read += 1;
        }

        let total_data_length = prost::decode_length_delimiter(&buf[..read])?;
        trace!(total_data_length, "decoded response length");
        let varint_length = prost::length_delimiter_len(total_data_length);
        trace!(varint_length, "varint length");
        let partial_data = &buf[varint_length..read];
        let read_all_data = total_data_length <= partial_data.len();
        trace!("decoding response");
        let main = if read_all_data {
            trace!("L3 decode");
            // Bytes past the end of this frame belong to the next message. Decoding
            // them together would merge the two messages (protobuf merges
            // concatenated encodings) or reject the input, so decode this frame only.
            let (frame, surplus) = partial_data.split_at(total_data_length);
            if !surplus.is_empty() {
                // NOTE(review): these bytes cannot be returned to the port. Keeping
                // them for the next call would require a read-ahead buffer stored on
                // the transport.
                let surplus_bytes = surplus.len();
                warn!(
                    surplus_bytes,
                    "read past the end of the response frame; discarding surplus bytes"
                );
            }
            proto::Main::decode(frame)?
        } else {
            let remaining_length = total_data_length - partial_data.len();
            if remaining_length <= STACK_LIMIT {
                trace!("L2 decode");
                let mut stack_buf = [0u8; STACK_LIMIT];
                self.port.read_exact(&mut stack_buf[..remaining_length])?;
                let chained = partial_data.chain(&stack_buf[..remaining_length]);
                proto::Main::decode(chained)?
            } else {
                trace!(
                    "L1 decode - WARN: Increase STACK_LIMIT, current: {STACK_LIMIT}, need: {remaining_length}"
                );
                #[cfg(feature = "transport-serial-optimized-large-stack-limit")]
                warn!(remaining_length, "extremely large response");
                #[cfg(not(feature = "transport-serial-optimized-large-stack-limit"))]
                warn!(
                    remaining_length,
                    "large response; consider enabling the 'transport-serial-optimized-large-stack-limit' feature"
                );
                let mut remaining_data = vec![0u8; remaining_length];
                self.port.read_exact(&mut remaining_data)?;
                let chained = partial_data.chain(remaining_data.as_slice());
                proto::Main::decode(chained)?
            }
        };
        decode_command_status(main.command_status)?.into_result(main)
    }

    /// Reads one length-delimited `proto::Main` frame. The length prefix is read one
    /// byte at a time (at most 10 bytes), then the payload with a single exact read
    /// into a heap buffer.
    ///
    /// Deprecated since 0.4.0: prefer the `transport-serial-optimized` feature. This
    /// implementation is very slow; only use it when the optimized method is broken,
    /// and please submit a PR/issue on GitHub if it is.
    ///
    /// This note is documentation rather than `#[deprecated]` because that attribute
    /// has no effect on trait-impl items. The deny-by-default `useless_deprecated`
    /// lint rejects it, which broke builds without `transport-serial-optimized`.
    ///
    /// # Errors
    /// I/O errors from the port (including read timeouts), protobuf decode errors,
    /// and `Error::InvalidCommandStatus` for an unknown status code. Otherwise
    /// `CommandStatus::into_result` decides the outcome.
    #[cfg(not(feature = "transport-serial-optimized"))]
    #[cfg_attr(feature = "tracing", tracing::instrument)]
    fn receive_raw(&mut self) -> std::result::Result<proto::Main, Self::Err> {
        self.port.flush()?;
        let mut buf = [0u8; 10];
        let mut index = 0;
        while index < 10 {
            self.port.read_exact(&mut buf[index..=index])?;
            // A clear high bit marks the last byte of the varint.
            if buf[index] & 0x80 == 0 {
                break;
            }
            index += 1;
        }
        let len = prost::decode_length_delimiter(buf.as_slice())?;
        let mut msg_buf = vec![0u8; len];
        self.port.read_exact(&mut msg_buf)?;
        let main = proto::Main::decode(msg_buf.as_slice())?;
        decode_command_status(main.command_status)?.into_result(main)
    }
}
fn decode_command_status(raw: i32) -> Result<CommandStatus> {
CommandStatus::try_from(raw).map_err(|_| Error::InvalidCommandStatus(raw))
}