Skip to main content

usb_gadget/function/
util.rs

1//! Utils for implementing USB gadget functions.
2
3use std::{
4    collections::HashMap,
5    ffi::{OsStr, OsString},
6    fmt, fs,
7    io::{Error, ErrorKind, Result},
8    os::unix::prelude::{OsStrExt, OsStringExt},
9    path::{Component, Path, PathBuf},
10    sync::{Arc, Mutex, MutexGuard, Once, OnceLock},
11};
12
13use crate::{function::register_remove_handlers, trim_os_str};
14
15/// USB gadget function.
16pub trait Function: fmt::Debug + Send + Sync + 'static {
17    /// Name of the function driver.
18    fn driver(&self) -> OsString;
19
20    /// Function directory.
21    fn dir(&self) -> FunctionDir;
22
23    /// Register the function in configfs at the specified path.
24    fn register(&self) -> Result<()>;
25
26    /// Notifies the function that the USB gadget is about to be removed.
27    fn pre_removal(&self) -> Result<()> {
28        Ok(())
29    }
30
31    /// Notifies the function that the USB gadget has been removed.
32    fn post_removal(&self, _dir: &Path) -> Result<()> {
33        Ok(())
34    }
35}
36
/// Registration state of a USB function, as reported by [`Status::state`].
///
/// Variants are listed in lifecycle order, so the derived ordering ranks
/// `Unregistered < Registered < Bound < Removed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum State {
    /// Not registered in configfs.
    Unregistered,
    /// Registered in configfs, but not bound to a UDC.
    Registered,
    /// Registered in configfs and bound to a UDC.
    Bound,
    /// Removed; the function is expected to remain in this state.
    Removed,
}
49
50/// Provides access to the status of a USB function.
51#[derive(Clone, Debug)]
52pub struct Status(FunctionDir);
53
54impl Status {
55    /// Registration state.
56    pub fn state(&self) -> State {
57        let inner = self.0.inner.lock().unwrap();
58        match (&inner.dir, inner.dir_was_set, inner.bound) {
59            (None, false, _) => State::Unregistered,
60            (None, true, _) => State::Removed,
61            (Some(_), _, false) => State::Registered,
62            (Some(_), _, true) => State::Bound,
63        }
64    }
65
66    /// Waits for the function to be bound to a UDC.
67    ///
68    /// Returns with a broken pipe error if gadget is removed.
69    #[cfg(feature = "tokio")]
70    pub async fn bound(&self) -> Result<()> {
71        loop {
72            let notifier = self.0.notify.notified();
73            match self.state() {
74                State::Bound => return Ok(()),
75                State::Removed => return Err(Error::new(ErrorKind::BrokenPipe, "gadget was removed")),
76                _ => (),
77            }
78            notifier.await;
79        }
80    }
81
82    /// Waits for the function to be unbound from a UDC.
83    #[cfg(feature = "tokio")]
84    pub async fn unbound(&self) {
85        loop {
86            let notifier = self.0.notify.notified();
87            if self.state() != State::Bound {
88                return;
89            }
90            notifier.await;
91        }
92    }
93
94    /// The USB gadget function directory in configfs, if registered.
95    pub fn path(&self) -> Option<PathBuf> {
96        self.0.inner.lock().unwrap().dir.clone()
97    }
98}
99
100/// USB gadget function directory container.
101///
102/// Stores the directory in configfs of a USB function and provides access methods.
103#[derive(Clone)]
104pub struct FunctionDir {
105    inner: Arc<Mutex<FunctionDirInner>>,
106    #[cfg(feature = "tokio")]
107    notify: Arc<tokio::sync::Notify>,
108}
109
/// Mutable state guarded by the mutex inside a `FunctionDir`.
#[derive(Debug, Default)]
struct FunctionDirInner {
    /// Configfs directory of the function while it is registered.
    dir: Option<PathBuf>,
    /// Set once a directory has ever been assigned; separates "removed" from "never registered".
    dir_was_set: bool,
    /// Whether the function is currently bound to a UDC.
    bound: bool,
}
116
117impl fmt::Debug for FunctionDir {
118    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
119        f.debug_tuple("FunctionDir").field(&*self.inner.lock().unwrap()).finish()
120    }
121}
122
123impl Default for FunctionDir {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl FunctionDir {
130    /// Creates an empty function directory container.
131    pub fn new() -> Self {
132        Self {
133            inner: Arc::new(Mutex::new(FunctionDirInner::default())),
134            #[cfg(feature = "tokio")]
135            notify: Arc::new(tokio::sync::Notify::new()),
136        }
137    }
138
139    pub(crate) fn set_dir(&self, function_dir: &Path) {
140        let mut inner = self.inner.lock().unwrap();
141        inner.dir = Some(function_dir.to_path_buf());
142        inner.dir_was_set = true;
143
144        #[cfg(feature = "tokio")]
145        self.notify.notify_waiters();
146    }
147
148    pub(crate) fn reset_dir(&self) {
149        self.inner.lock().unwrap().dir = None;
150
151        #[cfg(feature = "tokio")]
152        self.notify.notify_waiters();
153    }
154
155    pub(crate) fn set_bound(&self, bound: bool) {
156        self.inner.lock().unwrap().bound = bound;
157
158        #[cfg(feature = "tokio")]
159        self.notify.notify_waiters();
160    }
161
162    /// Create status accessor.
163    pub fn status(&self) -> Status {
164        Status(self.clone())
165    }
166
167    /// The USB gadget function directory in configfs.
168    pub fn dir(&self) -> Result<PathBuf> {
169        self.inner
170            .lock()
171            .unwrap()
172            .dir
173            .clone()
174            .ok_or_else(|| Error::new(ErrorKind::NotFound, "USB function not registered"))
175    }
176
177    /// Driver name.
178    pub fn driver(&self) -> Result<OsString> {
179        let dir = self.dir()?;
180        let (driver, _instance) = split_function_dir(&dir)
181            .ok_or_else(|| Error::new(ErrorKind::InvalidData, "invalid function directory name"))?;
182        Ok(driver.to_os_string())
183    }
184
185    /// Instance name.
186    pub fn instance(&self) -> Result<OsString> {
187        let dir = self.dir()?;
188        let (_driver, instance) = split_function_dir(&dir)
189            .ok_or_else(|| Error::new(ErrorKind::InvalidData, "invalid function directory name"))?;
190        Ok(instance.to_os_string())
191    }
192
193    /// Path to the specified property.
194    pub fn property_path(&self, name: impl AsRef<Path>) -> Result<PathBuf> {
195        let path = name.as_ref();
196        if path.components().all(|c| matches!(c, Component::Normal(_))) {
197            Ok(self.dir()?.join(path))
198        } else {
199            Err(Error::new(ErrorKind::InvalidInput, "property path must be relative"))
200        }
201    }
202
203    /// Create a subdirectory.
204    pub fn create_dir(&self, name: impl AsRef<Path>) -> Result<()> {
205        let path = self.property_path(name)?;
206        log::debug!("creating directory {}", path.display());
207        fs::create_dir(path)
208    }
209
210    /// Create a subdirectory and its parent directories.
211    pub fn create_dir_all(&self, name: impl AsRef<Path>) -> Result<()> {
212        let path = self.property_path(name)?;
213        log::debug!("creating directories {}", path.display());
214        fs::create_dir_all(path)
215    }
216
217    /// Remove a subdirectory.
218    pub fn remove_dir(&self, name: impl AsRef<Path>) -> Result<()> {
219        let path = self.property_path(name)?;
220        log::debug!("removing directory {}", path.display());
221        fs::remove_dir(path)
222    }
223
224    /// Read a binary property.
225    pub fn read(&self, name: impl AsRef<Path>) -> Result<Vec<u8>> {
226        let path = self.property_path(name)?;
227        let res = fs::read(&path);
228
229        match &res {
230            Ok(value) => {
231                log::debug!("read property {} with value {}", path.display(), String::from_utf8_lossy(value))
232            }
233            Err(err) => log::debug!("reading property {} failed: {}", path.display(), err),
234        }
235
236        res
237    }
238
239    /// Read and trim a string property.
240    pub fn read_string(&self, name: impl AsRef<Path>) -> Result<String> {
241        let mut data = self.read(name)?;
242        while data.last() == Some(&b'\0') || data.last() == Some(&b' ') || data.last() == Some(&b'\n') {
243            data.truncate(data.len() - 1);
244        }
245
246        Ok(String::from_utf8(data).map_err(|err| Error::new(ErrorKind::InvalidData, err))?.trim().to_string())
247    }
248
249    /// Read an trim an OS string property.
250    pub fn read_os_string(&self, name: impl AsRef<Path>) -> Result<OsString> {
251        Ok(trim_os_str(&OsString::from_vec(self.read(name)?)).to_os_string())
252    }
253
254    /// Write a property.
255    pub fn write(&self, name: impl AsRef<Path>, value: impl AsRef<[u8]>) -> Result<()> {
256        let path = self.property_path(name)?;
257        let value = value.as_ref();
258        log::debug!("setting property {} to {}", path.display(), String::from_utf8_lossy(value));
259        fs::write(path, value)
260    }
261
262    /// Create a symbolic link.
263    pub fn symlink(&self, target: impl AsRef<Path>, link: impl AsRef<Path>) -> Result<()> {
264        let target = self.property_path(target)?;
265        let link = self.property_path(link)?;
266        log::debug!("creating symlink {} -> {}", link.display(), target.display());
267        std::os::unix::fs::symlink(target, link)
268    }
269
270    /// Device major and minor numbers.
271    ///
272    /// Reads the `dev` property from configfs, which contains the device
273    /// numbers in the `major:minor` format.
274    ///
275    /// Not all function types expose a `dev` property.
276    pub fn dev_numbers(&self) -> Result<(u32, u32)> {
277        let dev = self.read_string("dev")?;
278        let Some((major, minor)) = dev.split_once(':') else {
279            return Err(Error::new(ErrorKind::InvalidData, "invalid device number format"));
280        };
281        let major = major.parse().map_err(|err| Error::new(ErrorKind::InvalidData, err))?;
282        let minor = minor.parse().map_err(|err| Error::new(ErrorKind::InvalidData, err))?;
283        Ok((major, minor))
284    }
285
286    /// The device path in `/dev` that corresponds to this function.
287    ///
288    /// Not all function types expose a `dev` property.
289    pub fn dev_path(&self) -> Result<PathBuf> {
290        let (major, minor) = self.dev_numbers()?;
291        dev_path(major, minor)
292    }
293}
294
295/// Find the device path in `/dev` for the given major and minor numbers.
296///
297/// Scans `/dev` (non-recursively) for a character device with matching
298/// device numbers.
299pub fn dev_path(major: u32, minor: u32) -> Result<PathBuf> {
300    use std::os::{linux::fs::MetadataExt, unix::fs::FileTypeExt};
301
302    let target = rustix::fs::makedev(major, minor);
303
304    for entry in fs::read_dir("/dev")? {
305        let Ok(entry) = entry else { continue };
306        let Ok(meta) = entry.metadata() else { continue };
307        if meta.file_type().is_char_device() && meta.st_rdev() == target {
308            return Ok(entry.path());
309        }
310    }
311
312    Err(Error::new(ErrorKind::NotFound, format!("no device file found in /dev for {major}:{minor}")))
313}
314
/// Split configfs function directory path into driver name and instance name.
///
/// The last path component is expected to look like `<driver>.<instance>`. It is split
/// at the first `.`, so the instance name may itself contain dots.
///
/// Returns `None` if the path has no final component (e.g. `/` or a path ending in
/// `..`) or if that component contains no `.`.
pub fn split_function_dir(function_dir: &Path) -> Option<(&OsStr, &OsStr)> {
    let name = function_dir.file_name()?.as_bytes();
    let dot = name.iter().position(|&c| c == b'.')?;
    Some((OsStr::from_bytes(&name[..dot]), OsStr::from_bytes(&name[dot + 1..])))
}
326
/// Handler function for removing a function instance.
///
/// It is called with the function directory passed to `call_remove_handler`.
type RemoveHandler = Arc<dyn Fn(PathBuf) -> Result<()> + Send + Sync>;

/// Registered handlers for removing function instances, keyed by function driver name.
static REMOVE_HANDLERS: OnceLock<Mutex<HashMap<OsString, RemoveHandler>>> = OnceLock::new();

/// Locks the registry of remove handlers, creating it on first use.
///
/// A poisoned lock is recovered instead of panicking. Within this module the map is
/// only touched through single `insert`/`get` calls, so a panic while the guard was
/// held cannot leave it half-updated. Without recovery, one panic would make every
/// later registration and gadget removal panic as well.
fn remove_handlers() -> MutexGuard<'static, HashMap<OsString, RemoveHandler>> {
    REMOVE_HANDLERS
        .get_or_init(|| Mutex::new(HashMap::new()))
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}
338
339/// Initializes handlers for removing function instances.
340pub(crate) fn init_remove_handlers() {
341    static ONCE: Once = Once::new();
342    ONCE.call_once(|| {
343        register_remove_handlers();
344    });
345}
346
347/// Register a function remove handler for the specified function driver.
348pub fn register_remove_handler(
349    driver: impl AsRef<OsStr>, handler: impl Fn(PathBuf) -> Result<()> + Send + Sync + 'static,
350) {
351    remove_handlers().insert(driver.as_ref().to_os_string(), Arc::new(handler));
352}
353
354/// Calls the remove handler for the function directory, if any is registered.
355pub(crate) fn call_remove_handler(function_dir: &Path) -> Result<()> {
356    let Some((driver, _)) = split_function_dir(function_dir) else {
357        return Err(Error::new(ErrorKind::InvalidInput, "invalid function directory"));
358    };
359
360    let handler_opt = remove_handlers().get(driver).cloned();
361    match handler_opt {
362        Some(handler) => handler(function_dir.to_path_buf()),
363        None => Ok(()),
364    }
365}
366
/// Value channel.
///
/// A one-shot channel: the [`value::Sender`] delivers at most one value. The
/// [`value::Receiver`] caches that value once received and hands it out either by
/// reference or by move.
pub(crate) mod value {
    use std::{
        error::Error,
        fmt::{self, Display},
        io, mem,
        sync::{mpsc, Mutex},
    };

    /// Value was already sent.
    #[derive(Debug, Clone)]
    pub struct AlreadySentError;

    impl Display for AlreadySentError {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("value already sent")
        }
    }

    impl Error for AlreadySentError {}

    /// Sender of value channel.
    ///
    /// Holds the underlying `mpsc` sender until the first `send` consumes it.
    #[derive(Debug)]
    pub struct Sender<T>(Mutex<Option<mpsc::Sender<T>>>);

    impl<T> Sender<T> {
        /// Sends a value.
        ///
        /// This can only be called once; later calls return [`AlreadySentError`].
        /// If the receiver has already been dropped, the value is silently discarded.
        pub fn send(&self, value: T) -> Result<(), AlreadySentError> {
            let Some(tx) = self.0.lock().unwrap().take() else {
                return Err(AlreadySentError);
            };
            // A vanished receiver is not an error for this one-shot channel.
            let _ = tx.send(value);
            Ok(())
        }
    }

    /// Value channel receive error.
    #[derive(Debug, Clone)]
    pub enum RecvError {
        /// Value was not yet sent.
        Empty,
        /// Sender was dropped without sending a value.
        Disconnected,
        /// Value was taken.
        Taken,
    }

    impl Display for RecvError {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            let message = match self {
                Self::Empty => "value was not yet sent",
                Self::Disconnected => "no value was sent",
                Self::Taken => "value was taken",
            };
            f.write_str(message)
        }
    }

    impl Error for RecvError {}

    impl From<mpsc::RecvError> for RecvError {
        /// A blocking receive only fails once the sender is gone.
        fn from(_: mpsc::RecvError) -> Self {
            Self::Disconnected
        }
    }

    impl From<mpsc::TryRecvError> for RecvError {
        fn from(err: mpsc::TryRecvError) -> Self {
            match err {
                mpsc::TryRecvError::Disconnected => Self::Disconnected,
                mpsc::TryRecvError::Empty => Self::Empty,
            }
        }
    }

    impl From<RecvError> for io::Error {
        fn from(err: RecvError) -> Self {
            let kind = match &err {
                RecvError::Empty => io::ErrorKind::WouldBlock,
                RecvError::Disconnected => io::ErrorKind::BrokenPipe,
                RecvError::Taken => io::ErrorKind::Other,
            };
            io::Error::new(kind, err)
        }
    }

    /// Receiver state.
    #[derive(Default)]
    enum State<T> {
        /// Still waiting on the underlying channel. The `Mutex` makes the state `Sync`
        /// (`mpsc::Receiver` is not); it is only ever accessed through `get_mut`.
        Receiving(Mutex<mpsc::Receiver<T>>),
        /// The value has arrived and is cached here.
        Received(T),
        /// The value has been moved out; also the `Default` used by `mem::take`.
        #[default]
        Taken,
    }

    /// Receiver of value channel.
    #[derive(Default)]
    pub struct Receiver<T>(State<T>);

    impl<T: fmt::Debug> fmt::Debug for Receiver<T> {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            match &self.0 {
                State::Received(value) => fmt::Debug::fmt(value, f),
                State::Receiving(_) => f.write_str("<uninit>"),
                State::Taken => f.write_str("<taken>"),
            }
        }
    }

    impl<T> Receiver<T> {
        /// Get the value, if it has been sent.
        ///
        /// Polls the channel without blocking and caches the value once it arrives.
        pub fn get(&mut self) -> Result<&mut T, RecvError> {
            if let State::Receiving(rx) = &mut self.0 {
                let arrived = rx.get_mut().unwrap().try_recv()?;
                self.0 = State::Received(arrived);
            }

            match &mut self.0 {
                State::Received(value) => Ok(value),
                State::Taken => Err(RecvError::Taken),
                // The value was either stored above or `try_recv` already returned an error.
                State::Receiving(_) => unreachable!(),
            }
        }

        /// Wait for the value.
        ///
        /// Blocks until the value arrives or the sender is dropped without sending.
        #[allow(dead_code)]
        pub fn wait(&mut self) -> Result<&mut T, RecvError> {
            if let State::Receiving(rx) = &mut self.0 {
                let arrived = rx.get_mut().unwrap().recv().map_err(RecvError::from)?;
                self.0 = State::Received(arrived);
            }

            self.get()
        }

        /// Take the value, if it has been sent.
        ///
        /// Afterwards the receiver is in the taken state and further calls fail with
        /// [`RecvError::Taken`].
        #[allow(dead_code)]
        pub fn take(&mut self) -> Result<T, RecvError> {
            self.get()?;

            match mem::take(&mut self.0) {
                State::Received(value) => Ok(value),
                // `get` succeeded, so the state was `Received`.
                State::Receiving(_) | State::Taken => unreachable!(),
            }
        }
    }

    /// Creates a new value channel.
    pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
        let (tx, rx) = mpsc::channel();
        let sender = Sender(Mutex::new(Some(tx)));
        let receiver = Receiver(State::Receiving(Mutex::new(rx)));
        (sender, receiver)
    }
}