1use super::*;
2use bytes::Buf;
3use pretty_hex::PrettyHex;
4
/// `bmRequestType` shared by the DFU class requests issued from this module.
///
/// The literal is grouped as direction / type / recipient under the USB 2.0
/// `bmRequestType` layout: `0` (host-to-device), `01` (class), `00001` (interface).
/// NOTE(review): the same value is passed to both `UsbReadControl` and
/// `UsbWriteControl`; presumably the read path sets the device-to-host bit
/// itself — confirm.
const REQUEST_TYPE: u8 = 0b0_01_00001;
/// DFU class request code for `DFU_GETSTATUS`.
const DFU_GETSTATUS: u8 = 0x03;
/// DFU class request code for `DFU_CLRSTATUS`.
const DFU_CLRSTATUS: u8 = 0x04;
8
9pub struct GetStatusMessage {
11 pub status: Status,
13 pub poll_timeout: u64,
15 pub state: State,
17 pub index: u8,
19}
20
21#[must_use]
23pub struct GetStatus<T: ChainedCommand<Arg = GetStatusMessage>> {
24 pub(crate) chained_command: T,
25}
26
27impl<T: ChainedCommand<Arg = GetStatusMessage>> GetStatus<T> {
28 pub fn get_status(self, buffer: &'_ mut [u8]) -> (GetStatusRecv<T>, UsbReadControl<'_>) {
30 debug_assert!(buffer.len() >= 6);
31 let next = GetStatusRecv {
32 chained_command: self.chained_command,
33 };
34
35 let control = UsbReadControl::new(REQUEST_TYPE, DFU_GETSTATUS, 0, buffer);
36 (next, control)
37 }
38}
39
40#[must_use]
42pub struct GetStatusRecv<T: ChainedCommand<Arg = GetStatusMessage>> {
43 chained_command: T,
44}
45
46impl<T: ChainedCommand<Arg = GetStatusMessage>> GetStatusRecv<T> {
48 pub fn chain(self, mut bytes: &[u8]) -> Result<T::Into, Error> {
50 log::trace!("Received device status: {}", bytes.hex_dump());
51 if bytes.len() < 6 {
52 return Err(Error::ResponseTooShort {
53 got: bytes.len(),
54 expected: 6,
55 });
56 }
57
58 let status = bytes.get_u8().into();
59 log::trace!("Device status: {:?}", status);
60 let poll_timeout = bytes.get_uint_le(3);
61 log::trace!("Poll timeout: {}", poll_timeout);
62 let state: State = bytes.get_u8().into();
63 let state = state.for_status();
64 log::trace!("Device state: {:?}", state);
65 let i_string = bytes.get_u8();
66 log::trace!("Device i string: {:#x}", i_string);
67
68 Ok(self.chained_command.chain(GetStatusMessage {
69 status,
70 poll_timeout,
71 state,
72 index: i_string,
73 }))
74 }
75}
76
77#[must_use]
79pub struct ClearStatus<T> {
80 pub(crate) chained_command: T,
81}
82
83impl<T> ChainedCommand for ClearStatus<T> {
84 type Arg = get_status::GetStatusMessage;
85 type Into = (T, Option<UsbWriteControl<[u8; 0]>>);
86
87 fn chain(
89 self,
90 get_status::GetStatusMessage {
91 status: _,
92 poll_timeout: _,
93 state,
94 index: _,
95 }: Self::Arg,
96 ) -> (T, Option<UsbWriteControl<[u8; 0]>>) {
97 let next = self.chained_command;
98 if state == State::DfuError {
99 log::trace!("Device is in error state, clearing status...");
100 let control = UsbWriteControl::new(REQUEST_TYPE, DFU_CLRSTATUS, 0, []);
101
102 (next, Some(control))
103 } else {
104 log::trace!("Device is not in error state, skip clearing status");
105 (next, None)
106 }
107 }
108}
109
110#[must_use]
112pub struct WaitState<T> {
113 intermediate: State,
114 state: State,
115 chained_command: T,
116 end: bool,
117 poll_timeout: u64,
118}
119
120#[allow(missing_docs)]
122pub enum Step<T> {
123 Break(T),
124 Wait(GetStatus<WaitState<T>>, u64),
126}
127
128impl<T> WaitState<T> {
129 pub fn new(intermediate: State, state: State, chained_command: T) -> Self {
131 Self {
132 intermediate,
133 state,
134 chained_command,
135 end: false,
136 poll_timeout: 0,
137 }
138 }
139
140 pub fn next(self) -> Step<T> {
142 if self.end {
143 log::trace!("Device state OK");
144 Step::Break(self.chained_command)
145 } else {
146 let poll_timeout = self.poll_timeout;
147 log::trace!(
148 "Waiting for device state: {:?} (poll timeout: {})",
149 self.state,
150 poll_timeout,
151 );
152
153 Step::Wait(
154 GetStatus {
155 chained_command: self,
156 },
157 poll_timeout,
158 )
159 }
160 }
161}
162
163impl<T> ChainedCommand for WaitState<T> {
164 type Arg = GetStatusMessage;
165 type Into = Result<Self, Error>;
166
167 fn chain(
168 self,
169 GetStatusMessage {
170 status: _,
171 poll_timeout,
172 state,
173 index: _,
174 }: Self::Arg,
175 ) -> Self::Into {
176 log::trace!("Device state: {:?}", state);
177 if state == self.state || state == self.intermediate {
178 Ok(WaitState {
179 chained_command: self.chained_command,
180 intermediate: self.intermediate,
181 state: self.state,
182 end: state == self.state,
183 poll_timeout,
184 })
185 } else {
186 Err(Error::InvalidState {
187 got: state,
188 expected: self.intermediate,
189 })
190 }
191 }
192}