1#![warn(
52 missing_copy_implementations,
53 missing_debug_implementations,
54 missing_docs,
55 non_ascii_idents,
56 trivial_casts,
57 unused,
58 unused_qualifications
59)]
60#![deny(unsafe_code)]
61#![cfg_attr(
63 not(any(feature = "usb-device-v0.2", feature = "usb-device-v0.3",)),
64 allow(unused)
65)]
66
67#[cfg(feature = "usb-device-v0.2")]
68mod impl02;
69#[cfg(feature = "usb-device-v0.3")]
70mod impl03;
71
72use std::{
73 collections::{btree_map::Entry, BTreeMap},
74 sync::atomic::{AtomicBool, Ordering},
75};
76
77use crossbeam::channel::{self, Receiver, Sender};
78
/// Direction of a USB endpoint, named from the host's point of view.
///
/// The derived ordering follows the declaration order, so [`UsbDirection::Out`] sorts before
/// [`UsbDirection::In`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum UsbDirection {
    /// OUT: packets travel from the host to the device.
    Out,
    /// IN: packets travel from the device to the host.
    In,
}
87
88#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
93pub struct EndpointAddress {
94 direction: UsbDirection,
95 idx: u8,
96}
97
98impl EndpointAddress {
99 pub const fn new(idx: u8, direction: UsbDirection) -> Option<Self> {
103 if idx.leading_zeros() >= 1 {
105 Some(Self { direction, idx })
106 } else {
107 None
108 }
109 }
110
111 pub const fn direction(&self) -> UsbDirection {
113 self.direction
114 }
115
116 pub const fn index(&self) -> u8 {
118 self.idx
119 }
120}
121
/// Transfer type of an endpoint.
///
/// The derived ordering follows the declaration order of the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum EndpointType {
    /// Control endpoint.
    Control,
    /// Isochronous endpoint.
    Isochronous,
    /// Bulk endpoint.
    Bulk,
    /// Interrupt endpoint.
    Interrupt,
}
134
135#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
137#[non_exhaustive]
138pub struct Endpoint {
139 pub address: EndpointAddress,
141 pub ty: EndpointType,
143}
144
145#[derive(Debug)]
146struct EndpointData {
147 ty: EndpointType,
148 tx: Sender<Vec<u8>>,
149 rx: Receiver<Vec<u8>>,
150 in_started: AtomicBool,
151}
152
153impl EndpointData {
154 fn new(ty: EndpointType) -> Self {
155 let (tx, rx) = channel::bounded(1);
156 Self {
157 ty,
158 tx,
159 rx,
160 in_started: AtomicBool::new(false),
161 }
162 }
163}
164
165#[derive(Debug, Default)]
184pub struct Bus {
185 enabled: bool,
186 endpoints: BTreeMap<EndpointAddress, EndpointData>,
187}
188
189impl Bus {
190 pub fn new() -> Self {
192 Default::default()
193 }
194
195 pub fn endpoints(&self) -> impl Iterator<Item = Endpoint> + '_ {
197 self.endpoints.iter().map(|(&address, endpoint)| Endpoint {
198 address,
199 ty: endpoint.ty,
200 })
201 }
202
203 pub fn endpoint_tx(&self, addr: EndpointAddress) -> Option<Sender<Vec<u8>>> {
208 if addr.direction() == UsbDirection::Out {
209 self.endpoints.get(&addr).map(|ep| ep.tx.clone())
210 } else {
211 None
212 }
213 }
214
215 pub fn endpoint_rx(&self, addr: EndpointAddress) -> Option<Receiver<Vec<u8>>> {
220 if addr.direction() == UsbDirection::In {
221 self.endpoints.get(&addr).map(|ep| ep.rx.clone())
222 } else {
223 None
224 }
225 }
226
227 fn next_address(&self, direction: UsbDirection) -> Option<EndpointAddress> {
228 let mut idx = 1;
229 for key in self.endpoints.keys() {
230 if key.direction() != direction {
231 continue;
232 }
233 if key.index() == 0 {
234 continue;
235 }
236 if key.index() == idx {
237 idx = idx.checked_add(1)?;
238 } else {
239 break;
240 }
241 }
242 EndpointAddress::new(idx, direction)
243 }
244
245 fn impl_alloc_ep(
246 &mut self,
247 ep_dir: UsbDirection,
248 ep_addr: Option<EndpointAddress>,
249 ep_type: EndpointType,
250 ) -> Result<EndpointAddress, Error> {
251 if self.enabled {
252 return Err(Error::InvalidState);
253 }
254
255 let ep_addr = if let Some(ep_addr) = ep_addr {
256 if ep_dir != ep_addr.direction() {
257 return Err(Error::InvalidEndpoint);
258 }
259 ep_addr
260 } else {
261 self.next_address(ep_dir).ok_or(Error::EndpointOverflow)?
262 };
263
264 if let Entry::Vacant(e) = self.endpoints.entry(ep_addr) {
265 e.insert(EndpointData::new(ep_type));
266 Ok(ep_addr)
267 } else {
268 Err(Error::InvalidEndpoint)
269 }
270 }
271
272 fn impl_enable(&mut self) {
273 self.enabled = true;
274 }
275
276 fn impl_write(&self, ep_addr: EndpointAddress, buf: &[u8]) -> Result<usize, Error> {
277 if ep_addr.direction() != UsbDirection::In {
278 return Err(Error::InvalidEndpoint);
279 }
280 let n = buf.len();
281 let endpoint = self.endpoints.get(&ep_addr).ok_or(Error::InvalidEndpoint)?;
282 endpoint
283 .tx
284 .try_send(buf.into())
285 .map_err(|_| Error::WouldBlock)?;
286 endpoint.in_started.store(true, Ordering::Relaxed);
287 Ok(n)
288 }
289
290 fn impl_read(&self, ep_addr: EndpointAddress, buf: &mut [u8]) -> Result<usize, Error> {
291 if ep_addr.direction() != UsbDirection::Out {
292 return Err(Error::InvalidEndpoint);
293 }
294 let packet = self
295 .endpoints
296 .get(&ep_addr)
297 .ok_or(Error::InvalidEndpoint)?
298 .rx
299 .try_recv()
300 .map_err(|_| Error::WouldBlock)?;
301 let n = packet.len();
302 if n <= buf.len() {
303 buf[..n].copy_from_slice(&packet);
304 Ok(n)
305 } else {
306 Err(Error::BufferOverflow)
307 }
308 }
309
310 fn impl_poll(&self) -> PollResult {
311 let mut ep_out = 0;
312 let mut ep_in_complete = 0;
313 for (ep_addr, endpoint) in &self.endpoints {
314 match ep_addr.direction() {
315 UsbDirection::Out => {
316 if !endpoint.rx.is_empty() {
317 ep_out |= 1 << ep_addr.index();
318 }
319 }
320 UsbDirection::In => {
321 if endpoint.tx.is_empty() && endpoint.in_started.swap(false, Ordering::Relaxed)
322 {
323 ep_in_complete |= 1 << ep_addr.index();
324 }
325 }
326 }
327 }
328 if ep_out > 0 || ep_in_complete > 0 {
329 PollResult::Data {
330 ep_out,
331 ep_in_complete,
332 }
333 } else {
334 PollResult::None
335 }
336 }
337}
338
/// Outcome of polling the bus, as produced by `Bus::impl_poll`.
#[derive(PartialEq, Debug)]
enum PollResult {
    /// At least one endpoint needs attention.
    ///
    /// Both fields are bit masks in which bit `n` stands for endpoint index `n`: `ep_out` marks
    /// OUT endpoints with a packet waiting to be read, `ep_in_complete` marks IN endpoints whose
    /// written packet has been taken.
    Data { ep_out: u16, ep_in_complete: u16 },
    /// No endpoint needs attention.
    None,
}
344
/// Errors reported by the private `impl_*` operations of `Bus`.
#[derive(PartialEq, Debug)]
enum Error {
    /// A received packet is larger than the buffer it should be copied into.
    BufferOverflow,
    /// No free endpoint index is left for an automatic allocation.
    EndpointOverflow,
    /// The endpoint address is not allocated, is already allocated, or has the wrong direction
    /// for the requested operation.
    InvalidEndpoint,
    /// The operation is not allowed in the current state of the bus, e.g. allocating an
    /// endpoint after the bus has been enabled.
    InvalidState,
    /// The operation cannot make progress right now: there is no packet to read, or the
    /// endpoint cannot take another packet yet.
    WouldBlock,
}
353
354#[cfg(test)]
355mod tests {
356 use super::{Bus, Endpoint, EndpointAddress, EndpointType, Error, PollResult, UsbDirection};
357
358 #[test]
359 fn test_address() {
360 for direction in [UsbDirection::Out, UsbDirection::In] {
361 for idx in [0, 1, 10, 127] {
362 let address = EndpointAddress::new(idx, direction);
363 assert_eq!(address, Some(EndpointAddress { direction, idx }));
364 let address = address.unwrap();
365 assert_eq!(address.direction(), direction);
366 assert_eq!(address.index(), idx);
367 }
368 for idx in [128, 129, 230, u8::MAX] {
369 assert_eq!(EndpointAddress::new(idx, direction), None);
370 }
371 }
372 }
373
374 #[test]
375 fn test_address_order() {
376 let out0 = EndpointAddress::new(0, UsbDirection::Out).unwrap();
377 let out1 = EndpointAddress::new(1, UsbDirection::Out).unwrap();
378 let out127 = EndpointAddress::new(127, UsbDirection::Out).unwrap();
379 let in0 = EndpointAddress::new(0, UsbDirection::In).unwrap();
380 let in1 = EndpointAddress::new(1, UsbDirection::In).unwrap();
381 let in127 = EndpointAddress::new(127, UsbDirection::In).unwrap();
382
383 let mut addresses = [in127, in1, in0, out127, out1, out0];
384 addresses.sort();
385 assert_eq!(addresses, [out0, out1, out127, in0, in1, in127,]);
386 }
387
388 fn alloc(
389 bus: &mut Bus,
390 direction: UsbDirection,
391 addr: Option<u8>,
392 ty: EndpointType,
393 ) -> Result<EndpointAddress, Error> {
394 bus.impl_alloc_ep(
395 direction,
396 addr.map(|idx| EndpointAddress::new(idx, direction).unwrap()),
397 ty,
398 )
399 }
400
401 fn alloc_ok(
402 bus: &mut Bus,
403 direction: UsbDirection,
404 addr: Option<u8>,
405 ty: EndpointType,
406 ) -> EndpointAddress {
407 let addr = alloc(bus, direction, addr, ty).unwrap();
408 assert_eq!(addr.direction(), direction);
409 addr
410 }
411
412 fn alloc_err(
413 bus: &mut Bus,
414 direction: UsbDirection,
415 addr: Option<u8>,
416 ty: EndpointType,
417 ) -> Error {
418 alloc(bus, direction, addr, ty).unwrap_err()
419 }
420
421 #[test]
422 fn test_bus_alloc() {
423 let mut bus = Bus::default();
424
425 for direction in [UsbDirection::Out, UsbDirection::In] {
426 let addr = alloc_ok(&mut bus, direction, None, EndpointType::Interrupt);
427 assert_eq!(addr.index(), 1);
428 let addr = alloc_ok(&mut bus, direction, None, EndpointType::Bulk);
429 assert_eq!(addr.index(), 2);
430 let err = alloc_err(&mut bus, direction, Some(1), EndpointType::Interrupt);
431 assert_eq!(err, Error::InvalidEndpoint);
432 let addr = alloc_ok(&mut bus, direction, Some(9), EndpointType::Bulk);
433 assert_eq!(addr.index(), 9);
434 let addr = alloc_ok(&mut bus, direction, None, EndpointType::Bulk);
435 assert_eq!(addr.index(), 3);
436 let addr = alloc_ok(&mut bus, direction, Some(0), EndpointType::Control);
437 assert_eq!(addr.index(), 0);
438 let addr = alloc_ok(&mut bus, direction, None, EndpointType::Bulk);
439 assert_eq!(addr.index(), 4);
440 }
441
442 let expected = [
443 (0, EndpointType::Control),
444 (1, EndpointType::Interrupt),
445 (2, EndpointType::Bulk),
446 (3, EndpointType::Bulk),
447 (4, EndpointType::Bulk),
448 (9, EndpointType::Bulk),
449 ];
450 let endpoints: Vec<_> = bus.endpoints().collect();
451 let mut expected_endpoints = Vec::new();
452 for direction in [UsbDirection::Out, UsbDirection::In] {
453 expected_endpoints.extend(expected.into_iter().map(|(idx, ty)| Endpoint {
454 address: EndpointAddress::new(idx, direction).unwrap(),
455 ty,
456 }));
457 }
458 assert_eq!(endpoints, expected_endpoints);
459 }
460
461 #[test]
462 fn test_bus_alloc_wrong_type() {
463 let mut bus = Bus::default();
464
465 let result = bus.impl_alloc_ep(
466 UsbDirection::In,
467 Some(EndpointAddress::new(0, UsbDirection::Out).unwrap()),
468 EndpointType::Interrupt,
469 );
470 assert_eq!(result, Err(Error::InvalidEndpoint));
471
472 let result = bus.impl_alloc_ep(
473 UsbDirection::Out,
474 Some(EndpointAddress::new(0, UsbDirection::In).unwrap()),
475 EndpointType::Interrupt,
476 );
477 assert_eq!(result, Err(Error::InvalidEndpoint));
478 }
479
480 #[test]
481 fn test_bus_alloc_overflow() {
482 for direction in [UsbDirection::In, UsbDirection::Out] {
483 let mut bus = Bus::default();
484 for i in 1..128 {
485 let addr = alloc_ok(&mut bus, direction, None, EndpointType::Interrupt);
486 assert_eq!(addr.index(), i);
487 }
488 let err = alloc_err(&mut bus, direction, None, EndpointType::Interrupt);
489 assert_eq!(err, Error::EndpointOverflow);
490 }
491 }
492
493 #[test]
494 fn test_bus_write() {
495 let mut bus = Bus::default();
496
497 let ep_in1 = alloc_ok(&mut bus, UsbDirection::In, Some(1), EndpointType::Interrupt);
498 let ep_in2 = alloc_ok(&mut bus, UsbDirection::In, Some(2), EndpointType::Interrupt);
499 let ep_in3 = EndpointAddress::new(3, UsbDirection::In).unwrap();
500 let ep_out1 = alloc_ok(
501 &mut bus,
502 UsbDirection::Out,
503 Some(1),
504 EndpointType::Interrupt,
505 );
506 let ep_out2 = EndpointAddress::new(2, UsbDirection::Out).unwrap();
507
508 let rx1 = bus.endpoint_rx(ep_in1).unwrap();
509 let rx2 = bus.endpoint_rx(ep_in2).unwrap();
510
511 assert!(bus.endpoint_rx(ep_in3).is_none());
512 assert!(bus.endpoint_rx(ep_out1).is_none());
513 assert!(bus.endpoint_rx(ep_out2).is_none());
514
515 assert_eq!(bus.impl_poll(), PollResult::None);
516
517 assert!(rx1.is_empty());
518 assert!(rx2.is_empty());
519
520 let data = b"testdata";
521 let result = bus.impl_write(ep_in2, data);
522 assert_eq!(result, Ok(data.len()));
523
524 assert_eq!(bus.impl_poll(), PollResult::None);
525
526 assert!(rx1.is_empty());
527 assert!(!rx2.is_empty());
528
529 let result = bus.impl_write(ep_in2, data);
530 assert_eq!(result, Err(Error::WouldBlock));
531 assert_eq!(bus.impl_poll(), PollResult::None);
532
533 let received = rx2.recv().unwrap();
534 assert_eq!(&received, data);
535
536 assert_eq!(
537 bus.impl_poll(),
538 PollResult::Data {
539 ep_out: 0,
540 ep_in_complete: 0b100
541 }
542 );
543 assert_eq!(bus.impl_poll(), PollResult::None);
544
545 assert!(rx2.try_recv().is_err());
546 assert_eq!(bus.impl_poll(), PollResult::None);
547 assert!(bus.impl_write(ep_in2, data).is_ok());
548 }
549
550 #[test]
551 fn test_bus_write_multi() {
552 let mut bus = Bus::default();
553
554 let ep_in1 = alloc_ok(&mut bus, UsbDirection::In, Some(1), EndpointType::Interrupt);
555 let ep_in2 = alloc_ok(&mut bus, UsbDirection::In, Some(2), EndpointType::Interrupt);
556
557 let rx1 = bus.endpoint_rx(ep_in1).unwrap();
558 let rx2 = bus.endpoint_rx(ep_in2).unwrap();
559
560 assert_eq!(bus.impl_poll(), PollResult::None);
561
562 assert!(rx1.is_empty());
563 assert!(rx2.is_empty());
564
565 let data1 = b"testdata";
566 let result = bus.impl_write(ep_in1, data1);
567 assert_eq!(result, Ok(data1.len()));
568
569 let data2 = b"some other important data";
570 let result = bus.impl_write(ep_in2, data2);
571 assert_eq!(result, Ok(data2.len()));
572
573 assert_eq!(bus.impl_poll(), PollResult::None);
574
575 assert!(!rx1.is_empty());
576 assert!(!rx2.is_empty());
577
578 assert_eq!(rx1.recv().unwrap(), data1);
579 assert_eq!(rx2.recv().unwrap(), data2);
580
581 assert_eq!(
582 bus.impl_poll(),
583 PollResult::Data {
584 ep_out: 0,
585 ep_in_complete: 0b110
586 }
587 );
588 assert_eq!(bus.impl_poll(), PollResult::None);
589 }
590
591 #[test]
592 fn test_bus_write_out() {
593 let mut bus = Bus::default();
594
595 let ep = alloc_ok(
596 &mut bus,
597 UsbDirection::Out,
598 Some(1),
599 EndpointType::Interrupt,
600 );
601 assert_eq!(bus.impl_write(ep, b"data"), Err(Error::InvalidEndpoint));
602 }
603
604 #[test]
605 fn test_bus_write_unalloc() {
606 let bus = Bus::default();
607
608 let ep = EndpointAddress::new(3, UsbDirection::In).unwrap();
609 assert_eq!(bus.impl_write(ep, b"data"), Err(Error::InvalidEndpoint));
610 }
611
612 #[test]
613 fn test_bus_read() {
614 let mut bus = Bus::default();
615
616 let ep_in1 = alloc_ok(&mut bus, UsbDirection::In, Some(1), EndpointType::Interrupt);
617 let ep_in2 = EndpointAddress::new(2, UsbDirection::In).unwrap();
618 let ep_out1 = alloc_ok(
619 &mut bus,
620 UsbDirection::Out,
621 Some(1),
622 EndpointType::Interrupt,
623 );
624 let ep_out2 = alloc_ok(
625 &mut bus,
626 UsbDirection::Out,
627 Some(2),
628 EndpointType::Interrupt,
629 );
630 let ep_out3 = EndpointAddress::new(3, UsbDirection::Out).unwrap();
631
632 assert!(bus.endpoint_tx(ep_out1).is_some());
633 let tx = bus.endpoint_tx(ep_out2).unwrap();
634
635 assert!(bus.endpoint_tx(ep_out3).is_none());
636 assert!(bus.endpoint_tx(ep_in1).is_none());
637 assert!(bus.endpoint_tx(ep_in2).is_none());
638
639 assert_eq!(bus.impl_poll(), PollResult::None);
640
641 let mut buffer = [0; 1024];
642 assert_eq!(bus.impl_read(ep_out1, &mut buffer), Err(Error::WouldBlock));
643 assert_eq!(bus.impl_read(ep_out2, &mut buffer), Err(Error::WouldBlock));
644
645 let data = b"testdata";
646 tx.send(data.into()).unwrap();
647
648 assert_eq!(
649 bus.impl_poll(),
650 PollResult::Data {
651 ep_out: 0b100,
652 ep_in_complete: 0
653 }
654 );
655 assert_eq!(
656 bus.impl_poll(),
657 PollResult::Data {
658 ep_out: 0b100,
659 ep_in_complete: 0
660 }
661 );
662 assert_eq!(
663 bus.impl_poll(),
664 PollResult::Data {
665 ep_out: 0b100,
666 ep_in_complete: 0
667 }
668 );
669
670 assert!(tx.try_send(data.into()).is_err());
671
672 assert_eq!(bus.impl_read(ep_out1, &mut buffer), Err(Error::WouldBlock));
673 assert_eq!(bus.impl_read(ep_out2, &mut buffer), Ok(data.len()));
674 assert_eq!(data, &buffer[..data.len()]);
675
676 assert_eq!(bus.impl_poll(), PollResult::None);
677
678 assert_eq!(bus.impl_read(ep_out1, &mut buffer), Err(Error::WouldBlock));
679 assert_eq!(bus.impl_read(ep_out2, &mut buffer), Err(Error::WouldBlock));
680 assert!(tx.try_send(data.into()).is_ok());
681 }
682
683 #[test]
684 fn test_bus_read_in() {
685 let mut bus = Bus::default();
686
687 let ep = alloc_ok(&mut bus, UsbDirection::In, Some(1), EndpointType::Interrupt);
688 let mut buffer = [0; 1024];
689 assert_eq!(bus.impl_read(ep, &mut buffer), Err(Error::InvalidEndpoint));
690 }
691
692 #[test]
693 fn test_bus_read_unalloc() {
694 let bus = Bus::default();
695
696 let ep = EndpointAddress::new(3, UsbDirection::Out).unwrap();
697 let mut buffer = [0; 1024];
698 assert_eq!(bus.impl_read(ep, &mut buffer), Err(Error::InvalidEndpoint));
699 }
700
701 #[test]
702 fn test_bus_read_overflow() {
703 let mut bus = Bus::default();
704
705 let ep = alloc_ok(
706 &mut bus,
707 UsbDirection::Out,
708 Some(1),
709 EndpointType::Interrupt,
710 );
711 let tx = bus.endpoint_tx(ep).unwrap();
712 tx.send(vec![0; 128]).unwrap();
713
714 let mut buffer = [0; 1];
715 assert_eq!(bus.impl_read(ep, &mut buffer), Err(Error::BufferOverflow));
716 assert_eq!(bus.impl_read(ep, &mut buffer), Err(Error::WouldBlock));
717 }
718
719 #[test]
720 fn test_bus_read_write() {
721 let mut bus = Bus::default();
722
723 let ep_in = alloc_ok(&mut bus, UsbDirection::In, Some(1), EndpointType::Interrupt);
724 let ep_out = alloc_ok(
725 &mut bus,
726 UsbDirection::Out,
727 Some(1),
728 EndpointType::Interrupt,
729 );
730
731 let rx = bus.endpoint_rx(ep_in).unwrap();
732 let tx = bus.endpoint_tx(ep_out).unwrap();
733
734 assert_eq!(bus.impl_poll(), PollResult::None);
735
736 let data1 = b"testdata";
737 tx.send(data1.into()).unwrap();
738
739 assert_eq!(
740 bus.impl_poll(),
741 PollResult::Data {
742 ep_out: 0b10,
743 ep_in_complete: 0
744 }
745 );
746 assert_eq!(
747 bus.impl_poll(),
748 PollResult::Data {
749 ep_out: 0b10,
750 ep_in_complete: 0
751 }
752 );
753
754 let data2 = b"some other important data";
755 let result = bus.impl_write(ep_in, data2);
756 assert_eq!(result, Ok(data2.len()));
757
758 assert_eq!(
759 bus.impl_poll(),
760 PollResult::Data {
761 ep_out: 0b10,
762 ep_in_complete: 0b00
763 }
764 );
765 assert_eq!(
766 bus.impl_poll(),
767 PollResult::Data {
768 ep_out: 0b10,
769 ep_in_complete: 0
770 }
771 );
772
773 let received = rx.recv().unwrap();
774 assert_eq!(&received, data2);
775
776 assert_eq!(
777 bus.impl_poll(),
778 PollResult::Data {
779 ep_out: 0b10,
780 ep_in_complete: 0b10,
781 }
782 );
783 assert_eq!(
784 bus.impl_poll(),
785 PollResult::Data {
786 ep_out: 0b10,
787 ep_in_complete: 0
788 }
789 );
790
791 let mut buffer = [0; 1024];
792 assert_eq!(bus.impl_read(ep_out, &mut buffer), Ok(data1.len()));
793 assert_eq!(data1, &buffer[..data1.len()]);
794
795 assert_eq!(bus.impl_poll(), PollResult::None);
796 }
797
798 #[test]
799 fn test_bus_enable() {
800 let mut bus = Bus::default();
801 bus.impl_enable();
802
803 let err = alloc_err(&mut bus, UsbDirection::In, None, EndpointType::Interrupt);
804 assert_eq!(err, Error::InvalidState);
805 }
806}