1use core::{
2 any::Any,
3 ops::{Deref, DerefMut},
4 sync::atomic::{AtomicI64, Ordering},
5};
6
7use alloc::{
8 boxed::Box,
9 sync::{Arc, Weak},
10};
11
12use crate::{Descriptor, Pid, driver::Class, get_pid};
13
14pub struct DeviceOwner {
15 lock: Arc<LockInner>,
16}
17
18impl DeviceOwner {
19 pub fn new<T: Class>(descriptor: Descriptor, device: T) -> Self {
20 Self {
21 lock: Arc::new(LockInner::new(descriptor, Box::into_raw(Box::new(device)))),
22 }
23 }
24
25 pub fn weak<T: Class>(&self) -> Result<Device<T>, GetDeviceError> {
26 Device::new(&self.lock)
27 }
28
29 pub fn is<T: Class>(&self) -> bool {
30 unsafe { &*self.lock.ptr }.is::<T>()
31 }
32}
33
34impl Drop for LockInner {
35 fn drop(&mut self) {
36 unsafe {
37 let ptr = self.ptr;
38 let _ = Box::from_raw(ptr);
39 }
40 }
41}
42
43struct LockInner {
44 borrowed: AtomicI64,
45 ptr: *mut dyn Any,
46 descriptor: Descriptor,
47}
48
49unsafe impl Send for LockInner {}
50unsafe impl Sync for LockInner {}
51
52impl LockInner {
53 fn new(descriptor: Descriptor, ptr: *mut dyn Any) -> Self {
54 Self {
55 borrowed: AtomicI64::new(-1),
56 ptr,
57 descriptor,
58 }
59 }
60
61 pub fn try_lock(self: &Arc<Self>, pid: Pid) -> Result<(), GetDeviceError> {
62 let mut pid = pid;
63 if pid.is_not_set() {
64 pid = Pid::INVALID.into();
65 }
66
67 let id: usize = pid.into();
68
69 match self.borrowed.compare_exchange(
70 Pid::NOT_SET as _,
71 id as _,
72 Ordering::Acquire,
73 Ordering::Relaxed,
74 ) {
75 Ok(_) => Ok(()),
76 Err(old) => {
77 if old as usize == Pid::INVALID {
78 Err(GetDeviceError::UsedByUnknown)
79 } else {
80 let pid: Pid = (old as usize).into();
81 Err(GetDeviceError::UsedByOthers(pid))
82 }
83 }
84 }
85 }
86
87 pub fn lock(self: &Arc<Self>) -> Result<(), GetDeviceError> {
88 let pid = get_pid();
89 loop {
90 match self.try_lock(pid) {
91 Ok(guard) => return Ok(guard),
92 Err(GetDeviceError::UsedByOthers(_)) | Err(GetDeviceError::UsedByUnknown) => {
93 continue;
94 }
95 Err(e) => return Err(e),
96 }
97 }
98 }
99}
100
101pub struct DeviceGuard<T> {
102 lock: Arc<LockInner>,
103 ptr: *mut T,
104}
105
106unsafe impl<T> Send for DeviceGuard<T> {}
107
108impl<T> Drop for DeviceGuard<T> {
109 fn drop(&mut self) {
110 self.lock
111 .borrowed
112 .store(Pid::NOT_SET as _, Ordering::Release);
113 }
114}
115
116impl<T> Deref for DeviceGuard<T> {
117 type Target = T;
118
119 fn deref(&self) -> &Self::Target {
120 unsafe { &*self.ptr }
121 }
122}
123
124impl<T> DerefMut for DeviceGuard<T> {
125 fn deref_mut(&mut self) -> &mut Self::Target {
126 unsafe { &mut *self.ptr }
127 }
128}
129
130impl<T> DeviceGuard<T> {
131 pub fn descriptor(&self) -> &Descriptor {
132 &self.lock.descriptor
133 }
134}
135
136#[derive(Clone)]
137pub struct Device<T> {
138 lock: Weak<LockInner>,
139 descriptor: Descriptor,
140 ptr: *mut T,
141}
142
143unsafe impl<T> Send for Device<T> {}
144unsafe impl<T> Sync for Device<T> {}
145
146impl<T: Any> Device<T> {
147 fn new(lock: &Arc<LockInner>) -> Result<Self, GetDeviceError> {
148 let ptr = match unsafe { &*lock.ptr }.downcast_ref::<T>() {
149 Some(v) => v as *const T as *mut T,
150 None => return Err(GetDeviceError::TypeNotMatch),
151 };
152
153 Ok(Self {
154 lock: Arc::downgrade(lock),
155 descriptor: lock.descriptor.clone(),
156 ptr,
157 })
158 }
159
160 pub fn lock(&self) -> Result<DeviceGuard<T>, GetDeviceError> {
161 let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
162 lock.lock()?;
163
164 Ok(DeviceGuard {
165 lock,
166 ptr: self.ptr,
167 })
168 }
169 pub fn try_lock(&self) -> Result<DeviceGuard<T>, GetDeviceError> {
170 let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
171 lock.try_lock(get_pid())?;
172
173 Ok(DeviceGuard {
174 lock,
175 ptr: self.ptr,
176 })
177 }
178
179 pub fn descriptor(&self) -> &Descriptor {
180 &self.descriptor
181 }
182
183 pub unsafe fn force_use(&self) -> *mut T {
188 self.ptr
189 }
190}
191
192impl<T: Class> Device<T> {
193 pub fn downcast<T2: 'static>(&self) -> Result<Device<T2>, GetDeviceError> {
194 let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
195
196 let t2_any = unsafe { &mut *self.ptr }
197 .raw_any_mut()
198 .ok_or(GetDeviceError::TypeNotMatch)?;
199
200 let t2_type = t2_any
201 .downcast_mut::<T2>()
202 .ok_or(GetDeviceError::TypeNotMatch)?;
203
204 Ok(Device {
205 lock: Arc::downgrade(&lock),
206 descriptor: self.descriptor.clone(),
207 ptr: t2_type as *mut T2,
208 })
209 }
210}
211
212#[derive(thiserror::Error, Debug, Clone, Copy)]
213pub enum GetDeviceError {
214 #[error("Used by pid: {0:?}")]
215 UsedByOthers(Pid),
216 #[error("Used by unknown pid")]
217 UsedByUnknown,
218 #[error("Device type not match")]
219 TypeNotMatch,
220 #[error("Device released")]
221 DeviceReleased,
222 #[error("Device not found")]
223 NotFound,
224}