rdrive/
lock.rs

1use core::{
2    any::Any,
3    ops::{Deref, DerefMut},
4    sync::atomic::{AtomicI64, Ordering},
5};
6
7use alloc::{
8    boxed::Box,
9    sync::{Arc, Weak},
10};
11
12use crate::{Descriptor, Pid, driver::Class, get_pid};
13
14pub struct DeviceOwner {
15    lock: Arc<LockInner>,
16}
17
18impl DeviceOwner {
19    pub fn new<T: Class>(descriptor: Descriptor, device: T) -> Self {
20        Self {
21            lock: Arc::new(LockInner::new(descriptor, Box::into_raw(Box::new(device)))),
22        }
23    }
24
25    pub fn weak<T: Class>(&self) -> Result<Device<T>, GetDeviceError> {
26        Device::new(&self.lock)
27    }
28
29    pub fn is<T: Class>(&self) -> bool {
30        unsafe { &*self.lock.ptr }.is::<T>()
31    }
32}
33
34impl Drop for LockInner {
35    fn drop(&mut self) {
36        unsafe {
37            let ptr = self.ptr;
38            let _ = Box::from_raw(ptr);
39        }
40    }
41}
42
43struct LockInner {
44    borrowed: AtomicI64,
45    ptr: *mut dyn Any,
46    descriptor: Descriptor,
47}
48
49unsafe impl Send for LockInner {}
50unsafe impl Sync for LockInner {}
51
52impl LockInner {
53    fn new(descriptor: Descriptor, ptr: *mut dyn Any) -> Self {
54        Self {
55            borrowed: AtomicI64::new(-1),
56            ptr,
57            descriptor,
58        }
59    }
60
61    pub fn try_lock(self: &Arc<Self>, pid: Pid) -> Result<(), GetDeviceError> {
62        let mut pid = pid;
63        if pid.is_not_set() {
64            pid = Pid::INVALID.into();
65        }
66
67        let id: usize = pid.into();
68
69        match self.borrowed.compare_exchange(
70            Pid::NOT_SET as _,
71            id as _,
72            Ordering::Acquire,
73            Ordering::Relaxed,
74        ) {
75            Ok(_) => Ok(()),
76            Err(old) => {
77                if old as usize == Pid::INVALID {
78                    Err(GetDeviceError::UsedByUnknown)
79                } else {
80                    let pid: Pid = (old as usize).into();
81                    Err(GetDeviceError::UsedByOthers(pid))
82                }
83            }
84        }
85    }
86
87    pub fn lock(self: &Arc<Self>) -> Result<(), GetDeviceError> {
88        let pid = get_pid();
89        loop {
90            match self.try_lock(pid) {
91                Ok(guard) => return Ok(guard),
92                Err(GetDeviceError::UsedByOthers(_)) | Err(GetDeviceError::UsedByUnknown) => {
93                    continue;
94                }
95                Err(e) => return Err(e),
96            }
97        }
98    }
99}
100
101pub struct DeviceGuard<T> {
102    lock: Arc<LockInner>,
103    ptr: *mut T,
104}
105
106unsafe impl<T> Send for DeviceGuard<T> {}
107
108impl<T> Drop for DeviceGuard<T> {
109    fn drop(&mut self) {
110        self.lock
111            .borrowed
112            .store(Pid::NOT_SET as _, Ordering::Release);
113    }
114}
115
116impl<T> Deref for DeviceGuard<T> {
117    type Target = T;
118
119    fn deref(&self) -> &Self::Target {
120        unsafe { &*self.ptr }
121    }
122}
123
124impl<T> DerefMut for DeviceGuard<T> {
125    fn deref_mut(&mut self) -> &mut Self::Target {
126        unsafe { &mut *self.ptr }
127    }
128}
129
130impl<T> DeviceGuard<T> {
131    pub fn descriptor(&self) -> &Descriptor {
132        &self.lock.descriptor
133    }
134}
135
136#[derive(Clone)]
137pub struct Device<T> {
138    lock: Weak<LockInner>,
139    descriptor: Descriptor,
140    ptr: *mut T,
141}
142
143unsafe impl<T> Send for Device<T> {}
144unsafe impl<T> Sync for Device<T> {}
145
146impl<T: Any> Device<T> {
147    fn new(lock: &Arc<LockInner>) -> Result<Self, GetDeviceError> {
148        let ptr = match unsafe { &*lock.ptr }.downcast_ref::<T>() {
149            Some(v) => v as *const T as *mut T,
150            None => return Err(GetDeviceError::TypeNotMatch),
151        };
152
153        Ok(Self {
154            lock: Arc::downgrade(lock),
155            descriptor: lock.descriptor.clone(),
156            ptr,
157        })
158    }
159
160    pub fn lock(&self) -> Result<DeviceGuard<T>, GetDeviceError> {
161        let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
162        lock.lock()?;
163
164        Ok(DeviceGuard {
165            lock,
166            ptr: self.ptr,
167        })
168    }
169    pub fn try_lock(&self) -> Result<DeviceGuard<T>, GetDeviceError> {
170        let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
171        lock.try_lock(get_pid())?;
172
173        Ok(DeviceGuard {
174            lock,
175            ptr: self.ptr,
176        })
177    }
178
179    pub fn descriptor(&self) -> &Descriptor {
180        &self.descriptor
181    }
182
183    /// 强制获取设备
184    ///
185    /// # Safety
186    /// 一般用于中断处理中
187    pub unsafe fn force_use(&self) -> *mut T {
188        self.ptr
189    }
190}
191
192impl<T: Class> Device<T> {
193    pub fn downcast<T2: 'static>(&self) -> Result<Device<T2>, GetDeviceError> {
194        let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
195
196        let t2_any = unsafe { &mut *self.ptr }
197            .raw_any_mut()
198            .ok_or(GetDeviceError::TypeNotMatch)?;
199
200        let t2_type = t2_any
201            .downcast_mut::<T2>()
202            .ok_or(GetDeviceError::TypeNotMatch)?;
203
204        Ok(Device {
205            lock: Arc::downgrade(&lock),
206            descriptor: self.descriptor.clone(),
207            ptr: t2_type as *mut T2,
208        })
209    }
210}
211
212#[derive(thiserror::Error, Debug, Clone, Copy)]
213pub enum GetDeviceError {
214    #[error("Used by pid: {0:?}")]
215    UsedByOthers(Pid),
216    #[error("Used by unknown pid")]
217    UsedByUnknown,
218    #[error("Device type not match")]
219    TypeNotMatch,
220    #[error("Device released")]
221    DeviceReleased,
222    #[error("Device not found")]
223    NotFound,
224}