1#![forbid(unsafe_code)]
2
3pub use helpers::EnumResultHelpers;
13use parking_lot::RwLock;
14use pico_common::PicoError;
15use pico_common::{Driver, PicoResult};
16use pico_device::PicoDevice;
17use pico_driver::{kernel_driver, ArcDriver, DriverLoadError, LoadDriverExt, Resolution};
18use rayon::prelude::*;
19use std::{collections::HashMap, sync::Arc};
20use thiserror::Error;
21
22mod helpers;
23
/// USB vendor ID used to filter the USB bus for Pico Technology devices.
const PICO_VENDOR_ID: u16 = 0x0ce9;
25
26#[cfg_attr(feature = "serde", derive(serde::Serialize))]
28#[derive(Clone, Debug)]
29pub struct EnumeratedDevice {
30 #[cfg_attr(feature = "serde", serde(skip))]
31 driver: ArcDriver,
32 pub variant: String,
33 pub serial: String,
34}
35
36impl EnumeratedDevice {
37 pub(crate) fn new(driver: ArcDriver, variant: String, serial: String) -> Self {
38 EnumeratedDevice {
39 driver,
40 variant,
41 serial,
42 }
43 }
44
45 pub fn open(&self) -> PicoResult<PicoDevice> {
47 PicoDevice::try_open(&self.driver, Some(&self.serial))
48 }
49}
50
51#[cfg_attr(
52 feature = "serde",
53 derive(serde::Serialize, serde::Deserialize),
54 serde(tag = "type")
55)]
56#[derive(Error, Debug, Clone)]
57pub enum EnumerationError {
58 #[error("Pico driver error: {error}")]
59 DriverError {
60 driver: Driver,
61 #[source]
62 error: PicoError,
63 },
64
65 #[error("The {driver} driver could not find any devices. The Pico Technology kernel driver appears to be missing.")]
66 KernelDriverError { driver: Driver },
67
68 #[error("The {driver} driver could not be found or failed to load")]
69 DriverLoadError { driver: Driver, error: String },
70
71 #[error("Invalid Driver Version: Requires >= {required}, Found: {found}")]
72 VersionError {
73 driver: Driver,
74 found: String,
75 required: String,
76 },
77}
78
79impl EnumerationError {
80 pub fn from(driver: Driver, error: DriverLoadError) -> Self {
81 match error {
82 DriverLoadError::DriverError(error) => EnumerationError::DriverError { driver, error },
83 DriverLoadError::LibloadingError(error) => EnumerationError::DriverLoadError {
84 driver,
85 error: error.to_string(),
86 },
87 DriverLoadError::VersionError { found, required } => EnumerationError::VersionError {
88 driver,
89 found,
90 required,
91 },
92 }
93 }
94}
95
96#[derive(Clone, Default)]
108pub struct DeviceEnumerator {
109 resolution: Resolution,
110 loaded_drivers: Arc<RwLock<HashMap<Driver, ArcDriver>>>,
111}
112
113impl DeviceEnumerator {
114 pub fn new() -> Self {
115 DeviceEnumerator::with_resolution(Default::default())
116 }
117
118 #[tracing::instrument(level = "info")]
120 pub fn with_resolution(resolution: Resolution) -> Self {
121 DeviceEnumerator {
122 resolution,
123 loaded_drivers: Default::default(),
124 }
125 }
126
127 #[tracing::instrument(level = "debug")]
130 pub fn enumerate_raw() -> HashMap<Driver, usize> {
131 usb_enumeration::enumerate(Some(PICO_VENDOR_ID), None)
132 .iter()
133 .map(|d| Driver::from_pid(d.product_id))
134 .flatten()
135 .fold(HashMap::new(), |mut map, x| {
136 map.entry(x).and_modify(|count| *count += 1).or_insert(1);
137 map
138 })
139 }
140
141 #[tracing::instrument(level = "info", skip(self))]
142 pub fn enumerate(&self) -> Vec<Result<EnumeratedDevice, EnumerationError>> {
144 DeviceEnumerator::enumerate_raw()
145 .into_par_iter()
146 .flat_map(|(driver_type, device_count)| {
147 self.enumerate_driver(driver_type, Some(device_count))
148 })
149 .collect()
150 }
151
152 #[tracing::instrument(level = "debug", skip(self))]
154 fn enumerate_driver(
155 &self,
156 driver_type: Driver,
157 device_count: Option<usize>,
158 ) -> Vec<Result<EnumeratedDevice, EnumerationError>> {
159 let device_count = device_count.unwrap_or(1);
160
161 let driver = match self.get_or_load_driver(driver_type) {
162 Ok(driver) => driver,
163 Err(error) => {
164 return vec![Err(EnumerationError::from(driver_type, error)); device_count]
165 }
166 };
167
168 match driver.enumerate_units() {
169 Ok(results) => {
170 if results.is_empty() && kernel_driver::is_missing() {
174 vec![
175 Err(EnumerationError::KernelDriverError {
176 driver: driver_type,
177 });
178 device_count
179 ]
180 } else {
181 results
182 .into_iter()
183 .map(|r| Ok(EnumeratedDevice::new(driver.clone(), r.variant, r.serial)))
184 .collect()
185 }
186 }
187 Err(error) => vec![
188 Err(EnumerationError::DriverError {
189 driver: driver_type,
190 error,
191 });
192 device_count
193 ],
194 }
195 }
196
197 #[tracing::instrument(level = "debug", skip(self))]
198 fn get_or_load_driver(&self, driver_type: Driver) -> Result<ArcDriver, DriverLoadError> {
199 let driver = {
200 let loaded_drivers = self.loaded_drivers.read();
201 loaded_drivers.get(&driver_type).cloned()
202 };
203
204 match driver {
205 Some(driver) => Ok(driver),
206 None => match driver_type.try_load_with_resolution(&self.resolution) {
207 Ok(driver) => {
208 self.loaded_drivers
209 .write()
210 .insert(driver_type, driver.clone());
211
212 Ok(driver)
213 }
214 Err(e) => Err(e),
215 },
216 }
217 }
218}