rdrive/
lock.rs

1use core::{
2    any::Any,
3    ops::{Deref, DerefMut},
4    sync::atomic::{AtomicI64, Ordering},
5};
6
7use alloc::{
8    boxed::Box,
9    sync::{Arc, Weak},
10};
11use rdif_base::DriverGeneric;
12
13use crate::{Descriptor, Pid, get_pid};
14
15pub struct DeviceOwner {
16    lock: Arc<LockInner>,
17}
18
19impl DeviceOwner {
20    pub fn new<T: DriverGeneric>(descriptor: Descriptor, device: T) -> Self {
21        Self {
22            lock: Arc::new(LockInner::new(descriptor, Box::into_raw(Box::new(device)))),
23        }
24    }
25
26    pub fn weak<T: DriverGeneric>(&self) -> Result<Device<T>, GetDeviceError> {
27        Device::new(&self.lock)
28    }
29
30    pub fn is<T: DriverGeneric>(&self) -> bool {
31        unsafe { &*self.lock.ptr }.is::<T>()
32    }
33}
34
35impl Drop for LockInner {
36    fn drop(&mut self) {
37        unsafe {
38            let ptr = self.ptr;
39            let _ = Box::from_raw(ptr);
40        }
41    }
42}
43
/// Shared state behind a [`DeviceOwner`]: the lock word, the type-erased driver
/// pointer and the device descriptor.
struct LockInner {
    /// Pid of the current lock holder, stored as `i64`.
    /// `Pid::NOT_SET` (cast) means "unlocked": `try_lock` CASes from it and
    /// `DeviceGuard::drop` stores it back.
    borrowed: AtomicI64,
    /// Heap-allocated driver; released with `Box::from_raw` when this value is dropped.
    ptr: *mut dyn Any,
    /// Descriptor of the device, shared by all handles and guards.
    descriptor: Descriptor,
}

// NOTE(review): these impls let the raw driver pointer cross threads without
// requiring the driver type to be `Send`/`Sync`. Access is meant to be serialised
// through `borrowed`, but confirm that every `DriverGeneric` implementor is `Send`.
unsafe impl Send for LockInner {}
unsafe impl Sync for LockInner {}
52
53impl LockInner {
54    fn new(descriptor: Descriptor, ptr: *mut dyn Any) -> Self {
55        Self {
56            borrowed: AtomicI64::new(-1),
57            ptr,
58            descriptor,
59        }
60    }
61
62    pub fn try_lock(self: &Arc<Self>, pid: Pid) -> Result<(), GetDeviceError> {
63        let mut pid = pid;
64        if pid.is_not_set() {
65            pid = Pid::INVALID.into();
66        }
67
68        let id: usize = pid.into();
69
70        match self.borrowed.compare_exchange(
71            Pid::NOT_SET as _,
72            id as _,
73            Ordering::Acquire,
74            Ordering::Relaxed,
75        ) {
76            Ok(_) => Ok(()),
77            Err(old) => {
78                if old as usize == Pid::INVALID {
79                    Err(GetDeviceError::UsedByUnknown)
80                } else {
81                    let pid: Pid = (old as usize).into();
82                    Err(GetDeviceError::UsedByOthers(pid))
83                }
84            }
85        }
86    }
87
88    pub fn lock(self: &Arc<Self>) -> Result<(), GetDeviceError> {
89        let pid = get_pid();
90        loop {
91            match self.try_lock(pid) {
92                Ok(guard) => return Ok(guard),
93                Err(GetDeviceError::UsedByOthers(_)) | Err(GetDeviceError::UsedByUnknown) => {
94                    continue;
95                }
96                Err(e) => return Err(e),
97            }
98        }
99    }
100}
101
/// Guard returned by `Device::lock` / `Device::try_lock`.
///
/// It holds a strong reference to the shared lock state, so the driver cannot be
/// freed while the guard is alive. Dropping the guard releases the lock.
pub struct DeviceGuard<T> {
    /// Strong reference to the shared lock state.
    lock: Arc<LockInner>,
    /// Typed pointer to the driver, copied from the `Device` that created this guard.
    ptr: *mut T,
}

// NOTE(review): this impl is unconditional and does not require `T: Send`, so a
// guard for a non-`Send` driver can still be moved across threads. Confirm that
// every `DriverGeneric` implementor is `Send`.
unsafe impl<T> Send for DeviceGuard<T> {}
108
109impl<T> Drop for DeviceGuard<T> {
110    fn drop(&mut self) {
111        self.lock
112            .borrowed
113            .store(Pid::NOT_SET as _, Ordering::Release);
114    }
115}
116
impl<T> Deref for DeviceGuard<T> {
    type Target = T;

    /// Shared access to the locked driver.
    fn deref(&self) -> &Self::Target {
        // SAFETY (assumed): `ptr` points into the allocation kept alive by
        // `self.lock`, and holding the guard means the lock is owned by this pid.
        unsafe { &*self.ptr }
    }
}

impl<T> DerefMut for DeviceGuard<T> {
    /// Exclusive access to the locked driver.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY (assumed): same reasoning as `deref`. Exclusivity relies on the
        // lock, which `Device::force_use` and `Device::downcast` do not take.
        unsafe { &mut *self.ptr }
    }
}
130
131impl<T> DeviceGuard<T> {
132    pub fn descriptor(&self) -> &Descriptor {
133        &self.lock.descriptor
134    }
135}
136
137#[derive(Clone)]
138pub struct Device<T> {
139    lock: Weak<LockInner>,
140    descriptor: Descriptor,
141    ptr: *mut T,
142}
143
144unsafe impl<T> Send for Device<T> {}
145unsafe impl<T> Sync for Device<T> {}
146
147impl<T: Any> Device<T> {
148    fn new(lock: &Arc<LockInner>) -> Result<Self, GetDeviceError> {
149        let ptr = match unsafe { &*lock.ptr }.downcast_ref::<T>() {
150            Some(v) => v as *const T as *mut T,
151            None => return Err(GetDeviceError::TypeNotMatch),
152        };
153
154        Ok(Self {
155            lock: Arc::downgrade(lock),
156            descriptor: lock.descriptor.clone(),
157            ptr,
158        })
159    }
160
161    pub fn lock(&self) -> Result<DeviceGuard<T>, GetDeviceError> {
162        let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
163        lock.lock()?;
164
165        Ok(DeviceGuard {
166            lock,
167            ptr: self.ptr,
168        })
169    }
170    pub fn try_lock(&self) -> Result<DeviceGuard<T>, GetDeviceError> {
171        let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
172        lock.try_lock(get_pid())?;
173
174        Ok(DeviceGuard {
175            lock,
176            ptr: self.ptr,
177        })
178    }
179
180    pub fn descriptor(&self) -> &Descriptor {
181        &self.descriptor
182    }
183
184    pub fn type_name(&self) -> &'static str {
185        core::any::type_name::<T>()
186    }
187
188    /// 强制获取设备
189    ///
190    /// # Safety
191    /// 一般用于中断处理中
192    pub unsafe fn force_use(&self) -> *mut T {
193        self.ptr
194    }
195}
196
197impl<T: DriverGeneric> Device<T> {
198    pub fn downcast<T2: 'static>(&self) -> Result<Device<T2>, GetDeviceError> {
199        let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
200
201        let t2_any = unsafe { &mut *self.ptr }
202            .raw_any_mut()
203            .ok_or(GetDeviceError::TypeNotMatch)?;
204
205        let t2_type = t2_any
206            .downcast_mut::<T2>()
207            .ok_or(GetDeviceError::TypeNotMatch)?;
208
209        Ok(Device {
210            lock: Arc::downgrade(&lock),
211            descriptor: self.descriptor.clone(),
212            ptr: t2_type as *mut T2,
213        })
214    }
215}
216
/// Reasons a typed device handle, or its lock, could not be obtained.
#[derive(thiserror::Error, Debug, Clone, Copy)]
pub enum GetDeviceError {
    /// The lock is held by the given pid.
    #[error("Used by pid: {0:?}")]
    UsedByOthers(Pid),
    /// The lock is held by a caller whose pid was not set (recorded as `Pid::INVALID`).
    #[error("Used by unknown pid")]
    UsedByUnknown,
    /// The stored driver is not of the requested type.
    #[error("Device type not match")]
    TypeNotMatch,
    /// The shared device state has been dropped, so the `Weak` handle could not be upgraded.
    #[error("Device released")]
    DeviceReleased,
    /// Not produced in this file; presumably returned by device lookups elsewhere.
    #[error("Device not found")]
    NotFound,
}