1use core::convert::TryFrom;
16use core::convert::TryInto;
17use core::sync::atomic::Ordering;
18use ctaphid_dispatch::{app::Command, Requester};
22use heapless_bytes::Bytes;
23use ref_swap::OptionRefSwap;
24use trussed_core::InterruptFlag;
25use usb_device::{
27 bus::UsbBus,
28 endpoint::{EndpointAddress, EndpointIn, EndpointOut},
29 UsbError,
30 };
32
33use crate::{
34 constants::{
35 MESSAGE_SIZE,
37 PACKET_SIZE,
39 },
40 types::KeepaliveStatus,
41};
42
/// Error codes carried in the one-byte payload of an error response.
///
/// Each variant's discriminant is the byte that is put on the wire, so
/// `u8::from(error)`, `error.into()` and a plain `error as u8` cast all
/// produce the same value. Without explicit discriminants an `as u8` cast
/// would yield the declaration index instead of the protocol code.
#[repr(u8)]
enum AuthenticatorError {
    InvalidCommand = 0x01,
    InvalidLength = 0x03,
    InvalidSeq = 0x04,
    Timeout = 0x05,
    ChannelBusy = 0x06,
    InvalidChannel = 0x0B,
}

impl From<AuthenticatorError> for u8 {
    /// Returns the wire code of `error`.
    fn from(error: AuthenticatorError) -> Self {
        // The discriminants above are the single source of truth.
        error as u8
    }
}
66
67#[derive(Copy, Clone, Debug, Eq, PartialEq)]
69pub struct Request {
70 channel: u32,
71 command: Command,
72 length: u16,
73 timestamp: u32,
74}
75
76#[derive(Copy, Clone, Debug, Eq, PartialEq)]
78pub struct Response {
79 channel: u32,
80 command: Command,
81 length: u16,
82}
83
84impl Response {
85 pub fn from_request_and_size(request: Request, size: usize) -> Self {
86 Self {
87 channel: request.channel,
88 command: request.command,
89 length: size as u16,
90 }
91 }
92
93 pub fn error_from_request(request: Request) -> Self {
94 Self::error_on_channel(request.channel)
95 }
96
97 pub fn error_on_channel(channel: u32) -> Self {
98 Self {
99 channel,
100 command: ctaphid_dispatch::app::Command::Error,
101 length: 1,
102 }
103 }
104}
105
106#[derive(Copy, Clone, Debug, Eq, PartialEq)]
107pub struct MessageState {
108 next_sequence: u8,
110 transmitted: usize,
112}
113
114impl Default for MessageState {
115 fn default() -> Self {
116 Self {
117 next_sequence: 0,
118 transmitted: PACKET_SIZE - 7,
119 }
120 }
121}
122
123impl MessageState {
124 pub fn absorb_packet(&mut self) {
126 self.next_sequence += 1;
127 self.transmitted += PACKET_SIZE - 5;
128 }
129}
130
131#[derive(Clone, Debug, Eq, PartialEq)]
132#[allow(unused)]
133pub enum State {
134 Idle,
135
136 Receiving((Request, MessageState)),
138
139 WaitingOnAuthenticator(Request),
146
147 WaitingToSend(Response),
148
149 Sending((Response, MessageState)),
150}
151
152pub struct Pipe<'alloc, 'pipe, 'interrupt, Bus: UsbBus> {
153 read_endpoint: EndpointOut<'alloc, Bus>,
154 write_endpoint: EndpointIn<'alloc, Bus>,
155 state: State,
156
157 interchange: Requester<'pipe>,
158 interrupt: Option<&'interrupt OptionRefSwap<'interrupt, InterruptFlag>>,
159
160 buffer: [u8; MESSAGE_SIZE],
162
163 last_channel: u32,
166
167 pub(crate) implements: u8,
169
170 pub(crate) last_milliseconds: u32,
172
173 started_processing: bool,
175
176 needs_keepalive: bool,
177
178 pub(crate) version: crate::Version,
179}
180
181impl<'alloc, 'pipe, Bus: UsbBus> Pipe<'alloc, 'pipe, '_, Bus> {
182 pub(crate) fn new(
183 read_endpoint: EndpointOut<'alloc, Bus>,
184 write_endpoint: EndpointIn<'alloc, Bus>,
185 interchange: Requester<'pipe>,
186 initial_milliseconds: u32,
187 ) -> Self {
188 Self {
189 read_endpoint,
190 write_endpoint,
191 state: State::Idle,
192 interchange,
193 buffer: [0u8; MESSAGE_SIZE],
194 last_channel: 0,
195 interrupt: None,
196 implements: 0x80,
198 last_milliseconds: initial_milliseconds,
199 started_processing: false,
200 needs_keepalive: false,
201 version: Default::default(),
202 }
203 }
204}
205
206impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus> {
207 pub(crate) fn with_interrupt(
212 read_endpoint: EndpointOut<'alloc, Bus>,
213 write_endpoint: EndpointIn<'alloc, Bus>,
214 interchange: Requester<'pipe>,
215 interrupt: Option<&'interrupt OptionRefSwap<'interrupt, InterruptFlag>>,
216 initial_milliseconds: u32,
217 ) -> Self {
218 Self {
219 read_endpoint,
220 write_endpoint,
221 state: State::Idle,
222 interchange,
223 buffer: [0u8; MESSAGE_SIZE],
224 last_channel: 0,
225 interrupt,
226 implements: 0x80,
228 last_milliseconds: initial_milliseconds,
229 started_processing: false,
230 needs_keepalive: false,
231 version: Default::default(),
232 }
233 }
234
235 pub(crate) fn set_version(&mut self, version: crate::Version) {
236 self.version = version;
237 }
238
239 pub fn read_address(&self) -> EndpointAddress {
240 self.read_endpoint.address()
241 }
242
243 pub fn write_address(&self) -> EndpointAddress {
244 self.write_endpoint.address()
245 }
246
247 pub(crate) fn read_endpoint(&self) -> &EndpointOut<'alloc, Bus> {
249 &self.read_endpoint
250 }
251
252 pub(crate) fn write_endpoint(&self) -> &EndpointIn<'alloc, Bus> {
254 &self.write_endpoint
255 }
256
257 fn cancel_ongoing_activity(&mut self) {
258 if matches!(self.state, State::WaitingOnAuthenticator(_)) {
259 info_now!("Interrupting request");
260 if let Some(Some(i)) = self.interrupt.map(|i| i.load(Ordering::Relaxed)) {
261 info_now!("Loaded some interrupter");
262 i.interrupt();
263 }
264 }
265 }
266
267 pub(crate) fn read_and_handle_packet(&mut self) {
272 let mut packet = [0u8; PACKET_SIZE];
274 match self.read_endpoint.read(&mut packet) {
275 Ok(PACKET_SIZE) => {}
276 Ok(_size) => {
277 info!("error unexpected size {}", _size);
284 return;
285 }
286 Err(_error) => {
291 info!("error no {}", _error as i32);
292 return;
293 }
294 };
295 info!(">> ");
296 info!("{}", hex_str!(&packet[..16]));
297
298 let channel = u32::from_be_bytes(packet[..4].try_into().unwrap());
300 let is_initialization = (packet[4] >> 7) != 0;
303 if is_initialization {
306 info!("init");
308
309 let command_number = packet[4] & !0x80;
310 let command = match Command::try_from(command_number) {
313 Ok(command) => command,
314 Err(_) => {
316 info!("Received invalid command.");
317 self.start_sending_error_on_channel(
318 channel,
319 AuthenticatorError::InvalidCommand,
320 );
321 return;
322 }
323 };
324
325 let length = u16::from_be_bytes(packet[5..][..2].try_into().unwrap());
327
328 let timestamp = self.last_milliseconds;
329 let current_request = Request {
330 channel,
331 command,
332 length,
333 timestamp,
334 };
335
336 if !(self.state == State::Idle) {
337 let request = match self.state {
338 State::WaitingOnAuthenticator(request) => request,
339 State::Receiving((request, _message_state)) => request,
340 _ => {
341 info_now!("Ignoring transaction as we're already transmitting.");
342 return;
343 }
344 };
345 if packet[4] == 0x86 {
346 info_now!("Resyncing!");
347 self.cancel_ongoing_activity();
348 } else {
349 if channel == request.channel {
350 if command == Command::Cancel {
351 info_now!("Cancelling");
352 self.cancel_ongoing_activity();
353 } else {
354 info_now!("Expected seq, {:?}", request.command);
355 self.start_sending_error(request, AuthenticatorError::InvalidSeq);
356 }
357 } else {
358 info_now!("busy.");
359 self.send_error_now(current_request, AuthenticatorError::ChannelBusy);
360 }
361
362 return;
363 }
364 }
365
366 if length > MESSAGE_SIZE as u16 {
367 info!("Error message too big.");
368 self.send_error_now(current_request, AuthenticatorError::InvalidLength);
369 return;
370 }
371
372 if length > PACKET_SIZE as u16 - 7 {
373 self.buffer[..PACKET_SIZE - 7].copy_from_slice(&packet[7..]);
376 self.state = State::Receiving((current_request, { MessageState::default() }));
377 } else {
379 self.buffer[..length as usize].copy_from_slice(&packet[7..][..length as usize]);
381 self.dispatch_request(current_request);
382 }
383 } else {
384 match self.state {
386 State::Receiving((request, mut message_state)) => {
387 let sequence = packet[4];
388 if sequence != message_state.next_sequence {
390 info!("Error invalid cont pkt");
394 self.start_sending_error(request, AuthenticatorError::InvalidSeq);
395 return;
396 }
397 if channel != request.channel {
398 info!("Ignore invalid channel");
402 return;
403 }
404
405 let payload_length = request.length as usize;
406 if message_state.transmitted + (PACKET_SIZE - 5) < payload_length {
407 self.buffer[message_state.transmitted..][..PACKET_SIZE - 5]
411 .copy_from_slice(&packet[5..]);
412 message_state.absorb_packet();
413 self.state = State::Receiving((request, message_state));
414 } else {
416 let missing = request.length as usize - message_state.transmitted;
417 self.buffer[message_state.transmitted..payload_length]
418 .copy_from_slice(&packet[5..][..missing]);
419 self.dispatch_request(request);
420 }
421 }
422 _ => {
423 info!("Ignore unexpected cont pkt");
425 }
426 }
427 }
428 }
429
430 pub fn check_timeout(&mut self, milliseconds: u32) {
431 let last = self.last_milliseconds;
434 self.last_milliseconds = milliseconds;
435 if let State::Receiving((request, _message_state)) = &mut self.state {
436 if (milliseconds - last) > 200 {
437 debug!(
441 "lapse in hid check.. {} {} {}",
442 request.timestamp, milliseconds, last
443 );
444 request.timestamp = milliseconds;
445 }
446 else if (milliseconds > request.timestamp && (milliseconds - request.timestamp) > 550)
448 || (milliseconds < request.timestamp && milliseconds > 550)
449 {
450 debug!(
451 "Channel timeout. {}, {}, {}",
452 request.timestamp, milliseconds, last
453 );
454 let req = *request;
455 self.start_sending_error(req, AuthenticatorError::Timeout);
456 }
457 }
458 }
459
460 fn dispatch_request(&mut self, request: Request) {
461 info!("Got request: {:?}", request.command);
462 match request.command {
463 Command::Init => {}
464 _ => {
465 if request.channel == 0xffffffff {
466 self.start_sending_error(request, AuthenticatorError::InvalidChannel);
467 return;
468 }
469 }
470 }
471 match request.command {
473 Command::Init => {
474 match request.channel {
477 0 => {
478 self.start_sending_error(request, AuthenticatorError::InvalidChannel);
480 }
481
482 cid => {
484 if request.length != 8 {
485 info!("Invalid length for init. ignore.");
487 } else {
488 self.last_channel += 1;
489 let _nonce = &self.buffer[..8];
492 let response = Response {
493 channel: cid,
494 command: request.command,
495 length: 17,
496 };
497
498 self.buffer[8..12].copy_from_slice(&self.last_channel.to_be_bytes());
499 self.buffer[12] = 2;
501 self.buffer[13] = self.version.major;
503 self.buffer[14] = self.version.minor;
505 self.buffer[15] = self.version.build;
507 self.buffer[16] = self.implements;
513 self.start_sending(response);
514 }
515 }
516 }
517 }
518
519 Command::Ping => {
520 let response = Response::from_request_and_size(request, request.length as usize);
521 self.start_sending(response);
522 }
523
524 Command::Cancel => {
525 info!("CTAPHID_CANCEL");
526 self.cancel_ongoing_activity();
527 }
528
529 _ => {
530 self.needs_keepalive = request.command == Command::Cbor;
531 if self.interchange.state() == interchange::State::Responded {
532 info!("dumping stale response");
533 self.interchange.take_response();
534 }
535 match self.interchange.request((
536 request.command,
537 Bytes::from_slice(&self.buffer[..request.length as usize]).unwrap(),
538 )) {
539 Ok(_) => {
540 self.state = State::WaitingOnAuthenticator(request);
541 self.started_processing = true;
542 }
543 Err(_) => {
544 info_now!("STATE: {:?}", self.interchange.state());
546 info!("can't handle more than one authenticator request at a time.");
547 self.send_error_now(request, AuthenticatorError::ChannelBusy);
548 }
549 }
550 }
551 }
552 }
553
554 pub fn did_start_processing(&mut self) -> bool {
555 if self.started_processing {
556 self.started_processing = false;
557 true
558 } else {
559 false
560 }
561 }
562
563 pub fn send_keepalive(&mut self, is_waiting_for_user_presence: bool) -> bool {
564 if let State::WaitingOnAuthenticator(request) = &self.state {
565 if !self.needs_keepalive {
566 info!("cmd does not need keepalive messages");
568 false
569 } else {
570 info!("keepalive");
571
572 let mut packet = [0u8; PACKET_SIZE];
573
574 packet[..4].copy_from_slice(&request.channel.to_be_bytes());
575 packet[4] = 0x80 | 0x3B;
576 packet[5..7].copy_from_slice(&1u16.to_be_bytes());
577
578 if is_waiting_for_user_presence {
579 packet[7] = KeepaliveStatus::UpNeeded as u8;
580 } else {
581 packet[7] = KeepaliveStatus::Processing as u8;
582 }
583
584 self.write_endpoint.write(&packet).ok();
585
586 true
587 }
588 } else {
589 info!("keepalive done");
590 false
591 }
592 }
593
594 #[inline(never)]
595 pub fn handle_response(&mut self) {
596 if let State::WaitingOnAuthenticator(request) = self.state {
597 if let Ok(response) = self.interchange.response() {
598 match &response.0 {
599 Err(ctaphid_dispatch::app::Error::InvalidCommand) => {
600 info!("Got waiting reply from authenticator??");
601 self.start_sending_error(request, AuthenticatorError::InvalidCommand);
602 }
603 Err(ctaphid_dispatch::app::Error::InvalidLength) => {
604 info!("Error, payload needed app command.");
605 self.start_sending_error(request, AuthenticatorError::InvalidLength);
606 }
607 Err(ctaphid_dispatch::app::Error::NoResponse) => {
608 info!("Got waiting noresponse from authenticator??");
609 }
610
611 Ok(message) => {
612 if message.len() > self.buffer.len() {
613 error!(
614 "Message is longer than buffer ({} > {})",
615 message.len(),
616 self.buffer.len(),
617 );
618 self.start_sending_error(request, AuthenticatorError::InvalidLength);
619 } else {
620 info!(
621 "Got {} bytes response from authenticator, starting send",
622 message.len()
623 );
624 let response = Response::from_request_and_size(request, message.len());
625 self.buffer[..message.len()].copy_from_slice(message);
626 self.start_sending(response);
627 }
628 }
629 }
630 }
631 }
632 }
633
634 fn start_sending(&mut self, response: Response) {
635 self.state = State::WaitingToSend(response);
636 self.maybe_write_packet();
637 }
638
639 fn start_sending_error(&mut self, request: Request, error: AuthenticatorError) {
640 self.start_sending_error_on_channel(request.channel, error);
641 }
642
643 fn start_sending_error_on_channel(&mut self, channel: u32, error: AuthenticatorError) {
644 self.buffer[0] = error.into();
645 let response = Response::error_on_channel(channel);
646 self.start_sending(response);
647 }
648
649 fn send_error_now(&mut self, request: Request, error: AuthenticatorError) {
650 let last_state = core::mem::replace(&mut self.state, State::Idle);
651 let last_first_byte = self.buffer[0];
652
653 self.buffer[0] = error as u8;
654 let response = Response::error_from_request(request);
655 self.start_sending(response);
656 self.maybe_write_packet();
657
658 self.state = last_state;
659 self.buffer[0] = last_first_byte;
660 }
661
662 #[inline(never)]
664 pub(crate) fn maybe_write_packet(&mut self) {
665 match self.state {
666 State::WaitingToSend(response) => {
667 let mut packet = [0u8; PACKET_SIZE];
669 packet[..4].copy_from_slice(&response.channel.to_be_bytes());
670 packet[4] = response.command.into_u8() | 0x80;
672 packet[5..7].copy_from_slice(&response.length.to_be_bytes());
673
674 let fits_in_one_packet = 7 + response.length as usize <= PACKET_SIZE;
675 if fits_in_one_packet {
676 packet[7..][..response.length as usize]
677 .copy_from_slice(&self.buffer[..response.length as usize]);
678 self.state = State::Idle;
679 } else {
680 packet[7..].copy_from_slice(&self.buffer[..PACKET_SIZE - 7]);
681 }
682
683 let result = self.write_endpoint.write(&packet);
687
688 match result {
689 Err(UsbError::WouldBlock) => {
690 info!("hid usb WouldBlock");
693 }
694 Err(_) => {
695 panic!("unexpected error writing packet!");
697 }
698 Ok(PACKET_SIZE) => {
699 if fits_in_one_packet {
701 self.state = State::Idle;
702 } else {
705 self.state = State::Sending((response, MessageState::default()));
706 }
711 }
712 Ok(_) => {
713 panic!("unexpected size writing packet!");
715 }
716 };
717 }
718
719 State::Sending((response, mut message_state)) => {
720 let mut packet = [0u8; PACKET_SIZE];
722 packet[..4].copy_from_slice(&response.channel.to_be_bytes());
723 packet[4] = message_state.next_sequence;
724
725 let sent = message_state.transmitted;
726 let remaining = response.length as usize - sent;
727 let last_packet = 5 + remaining <= PACKET_SIZE;
728 if last_packet {
729 packet[5..][..remaining]
730 .copy_from_slice(&self.buffer[message_state.transmitted..][..remaining]);
731 } else {
732 packet[5..].copy_from_slice(
733 &self.buffer[message_state.transmitted..][..PACKET_SIZE - 5],
734 );
735 }
736
737 let result = self.write_endpoint.write(&packet);
741
742 match result {
743 Err(UsbError::WouldBlock) => {
744 }
749 Err(_) => {
750 panic!("unexpected error writing packet!");
752 }
753 Ok(PACKET_SIZE) => {
754 if last_packet {
756 self.state = State::Idle;
757 } else {
759 message_state.absorb_packet();
760 self.state = State::Sending((response, message_state));
764 }
765 }
766 Ok(_) => {
767 debug!("short write");
768 panic!("unexpected size writing packet!");
769 }
770 };
771 }
772
773 _ => {}
775 }
776 }
777}