1use core::{
2 any::Any,
3 ops::{Deref, DerefMut},
4 sync::atomic::{AtomicI64, Ordering},
5};
6
7use alloc::{
8 boxed::Box,
9 sync::{Arc, Weak},
10};
11use rdif_base::DriverGeneric;
12
13use crate::{Descriptor, Pid, get_pid};
14
15pub struct DeviceOwner {
16 lock: Arc<LockInner>,
17}
18
19impl DeviceOwner {
20 pub fn new<T: DriverGeneric>(descriptor: Descriptor, device: T) -> Self {
21 Self {
22 lock: Arc::new(LockInner::new(descriptor, Box::into_raw(Box::new(device)))),
23 }
24 }
25
26 pub fn weak<T: DriverGeneric>(&self) -> Result<Device<T>, GetDeviceError> {
27 Device::new(&self.lock)
28 }
29
30 pub fn is<T: DriverGeneric>(&self) -> bool {
31 unsafe { &*self.lock.ptr }.is::<T>()
32 }
33}
34
35impl Drop for LockInner {
36 fn drop(&mut self) {
37 unsafe {
38 let ptr = self.ptr;
39 let _ = Box::from_raw(ptr);
40 }
41 }
42}
43
44struct LockInner {
45 borrowed: AtomicI64,
46 ptr: *mut dyn Any,
47 descriptor: Descriptor,
48}
49
50unsafe impl Send for LockInner {}
51unsafe impl Sync for LockInner {}
52
53impl LockInner {
54 fn new(descriptor: Descriptor, ptr: *mut dyn Any) -> Self {
55 Self {
56 borrowed: AtomicI64::new(-1),
57 ptr,
58 descriptor,
59 }
60 }
61
62 pub fn try_lock(self: &Arc<Self>, pid: Pid) -> Result<(), GetDeviceError> {
63 let mut pid = pid;
64 if pid.is_not_set() {
65 pid = Pid::INVALID.into();
66 }
67
68 let id: usize = pid.into();
69
70 match self.borrowed.compare_exchange(
71 Pid::NOT_SET as _,
72 id as _,
73 Ordering::Acquire,
74 Ordering::Relaxed,
75 ) {
76 Ok(_) => Ok(()),
77 Err(old) => {
78 if old as usize == Pid::INVALID {
79 Err(GetDeviceError::UsedByUnknown)
80 } else {
81 let pid: Pid = (old as usize).into();
82 Err(GetDeviceError::UsedByOthers(pid))
83 }
84 }
85 }
86 }
87
88 pub fn lock(self: &Arc<Self>) -> Result<(), GetDeviceError> {
89 let pid = get_pid();
90 loop {
91 match self.try_lock(pid) {
92 Ok(guard) => return Ok(guard),
93 Err(GetDeviceError::UsedByOthers(_)) | Err(GetDeviceError::UsedByUnknown) => {
94 continue;
95 }
96 Err(e) => return Err(e),
97 }
98 }
99 }
100}
101
102pub struct DeviceGuard<T> {
103 lock: Arc<LockInner>,
104 ptr: *mut T,
105}
106
107unsafe impl<T> Send for DeviceGuard<T> {}
108
109impl<T> Drop for DeviceGuard<T> {
110 fn drop(&mut self) {
111 self.lock
112 .borrowed
113 .store(Pid::NOT_SET as _, Ordering::Release);
114 }
115}
116
117impl<T> Deref for DeviceGuard<T> {
118 type Target = T;
119
120 fn deref(&self) -> &Self::Target {
121 unsafe { &*self.ptr }
122 }
123}
124
125impl<T> DerefMut for DeviceGuard<T> {
126 fn deref_mut(&mut self) -> &mut Self::Target {
127 unsafe { &mut *self.ptr }
128 }
129}
130
131impl<T> DeviceGuard<T> {
132 pub fn descriptor(&self) -> &Descriptor {
133 &self.lock.descriptor
134 }
135}
136
137#[derive(Clone)]
138pub struct Device<T> {
139 lock: Weak<LockInner>,
140 descriptor: Descriptor,
141 ptr: *mut T,
142}
143
144unsafe impl<T> Send for Device<T> {}
145unsafe impl<T> Sync for Device<T> {}
146
147impl<T: Any> Device<T> {
148 fn new(lock: &Arc<LockInner>) -> Result<Self, GetDeviceError> {
149 let ptr = match unsafe { &*lock.ptr }.downcast_ref::<T>() {
150 Some(v) => v as *const T as *mut T,
151 None => return Err(GetDeviceError::TypeNotMatch),
152 };
153
154 Ok(Self {
155 lock: Arc::downgrade(lock),
156 descriptor: lock.descriptor.clone(),
157 ptr,
158 })
159 }
160
161 pub fn lock(&self) -> Result<DeviceGuard<T>, GetDeviceError> {
162 let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
163 lock.lock()?;
164
165 Ok(DeviceGuard {
166 lock,
167 ptr: self.ptr,
168 })
169 }
170 pub fn try_lock(&self) -> Result<DeviceGuard<T>, GetDeviceError> {
171 let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
172 lock.try_lock(get_pid())?;
173
174 Ok(DeviceGuard {
175 lock,
176 ptr: self.ptr,
177 })
178 }
179
180 pub fn descriptor(&self) -> &Descriptor {
181 &self.descriptor
182 }
183
184 pub fn type_name(&self) -> &'static str {
185 core::any::type_name::<T>()
186 }
187
188 pub unsafe fn force_use(&self) -> *mut T {
193 self.ptr
194 }
195}
196
197impl<T: DriverGeneric> Device<T> {
198 pub fn downcast<T2: 'static>(&self) -> Result<Device<T2>, GetDeviceError> {
199 let lock = self.lock.upgrade().ok_or(GetDeviceError::DeviceReleased)?;
200
201 let t2_any = unsafe { &mut *self.ptr }
202 .raw_any_mut()
203 .ok_or(GetDeviceError::TypeNotMatch)?;
204
205 let t2_type = t2_any
206 .downcast_mut::<T2>()
207 .ok_or(GetDeviceError::TypeNotMatch)?;
208
209 Ok(Device {
210 lock: Arc::downgrade(&lock),
211 descriptor: self.descriptor.clone(),
212 ptr: t2_type as *mut T2,
213 })
214 }
215}
216
217#[derive(thiserror::Error, Debug, Clone, Copy)]
218pub enum GetDeviceError {
219 #[error("Used by pid: {0:?}")]
220 UsedByOthers(Pid),
221 #[error("Used by unknown pid")]
222 UsedByUnknown,
223 #[error("Device type not match")]
224 TypeNotMatch,
225 #[error("Device released")]
226 DeviceReleased,
227 #[error("Device not found")]
228 NotFound,
229}