1#![allow(unsafe_code)]
2
3use core::{cell::UnsafeCell, marker::PhantomData};
4
5use bitmaps::{Bits, BitsImpl};
6
7use crate::{
8 ShadowError,
9 persist::PersistTrigger,
10 policy::{AccessPolicy, PersistPolicy},
11 shadow::{HostShadow, KernelShadow},
12 table::ShadowTable,
13 types::StagingBuffer,
14};
15
/// Stage-state marker for a [`ShadowStorageBase`] with no staging buffer attached.
///
/// `ShadowStorageBase::new` produces storage in this state; `with_staging`
/// replaces it with [`WithStage`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoStage;
18
19pub struct WithStage<SB: StagingBuffer> {
21 pub(crate) sb: SB,
22}
23
24pub struct ShadowStorageBase<const TS: usize, const BS: usize, const BC: usize, AP, PP, PT, PK, SS>
38where
39 AP: AccessPolicy,
40 PP: PersistPolicy<PK>,
41 PT: PersistTrigger<PK>,
42 BitsImpl<BC>: Bits,
43{
44 pub(crate) table: UnsafeCell<ShadowTable<TS, BS, BC>>,
45 pub(crate) access_policy: AP,
46 pub(crate) persist_policy: PP,
47 pub(crate) persist_trigger: UnsafeCell<PT>,
48 pub(crate) stage_state: UnsafeCell<SS>,
49 _phantom: PhantomData<PK>,
50}
51
/// [`ShadowStorageBase`] with no staging buffer (stage state [`NoStage`]).
///
/// This is the type `ShadowStorageBase::new` returns; call `with_staging` to
/// attach a staging buffer.
pub type ShadowStorage<const TS: usize, const BS: usize, const BC: usize, AP, PP, PT, PK> =
    ShadowStorageBase<TS, BS, BC, AP, PP, PT, PK, NoStage>;
55
56impl<const TS: usize, const BS: usize, const BC: usize, AP, PP, PT, PK>
57 ShadowStorageBase<TS, BS, BC, AP, PP, PT, PK, NoStage>
58where
59 AP: AccessPolicy,
60 PP: PersistPolicy<PK>,
61 PT: PersistTrigger<PK>,
62 BitsImpl<BC>: Bits,
63{
64 pub fn new(policy: AP, persist: PP, trigger: PT) -> Self {
65 Self {
66 table: UnsafeCell::new(ShadowTable::new()),
67 access_policy: policy,
68 persist_policy: persist,
69 persist_trigger: UnsafeCell::new(trigger),
70 stage_state: UnsafeCell::new(NoStage),
71 _phantom: PhantomData,
72 }
73 }
74
75 pub fn with_staging<SB: StagingBuffer>(
77 self,
78 sb: SB,
79 ) -> ShadowStorageBase<TS, BS, BC, AP, PP, PT, PK, WithStage<SB>> {
80 ShadowStorageBase {
81 table: self.table,
82 access_policy: self.access_policy,
83 persist_policy: self.persist_policy,
84 persist_trigger: self.persist_trigger,
85 stage_state: UnsafeCell::new(WithStage { sb }),
86 _phantom: PhantomData,
87 }
88 }
89}
90
/// Write callback handed to the closure given to `load_defaults` /
/// `load_defaults_unchecked`.
///
/// Call it as `write(addr, data)` to store `data` in the shadow table starting at
/// `addr`; it returns the table's error for a write it rejects.
///
/// NOTE(review): a bare `dyn` in a type alias presumably defaults to `+ 'static`.
/// `load_defaults_unchecked` only satisfies that because its table reference comes
/// from a raw-pointer deref (unbounded lifetime). Confirm before changing either.
pub type WriteFn = dyn FnMut(u16, &[u8]) -> Result<(), ShadowError>;
93
94impl<const TS: usize, const BS: usize, const BC: usize, AP, PP, PT, PK, SS>
95 ShadowStorageBase<TS, BS, BC, AP, PP, PT, PK, SS>
96where
97 AP: AccessPolicy,
98 PP: PersistPolicy<PK>,
99 PT: PersistTrigger<PK>,
100 BitsImpl<BC>: Bits,
101{
102 pub fn host_shadow(&self) -> HostShadow<'_, TS, BS, BC, AP, PP, PT, PK, SS> {
103 HostShadow::new(self)
104 }
105
106 pub fn kernel_shadow(&self) -> KernelShadow<'_, TS, BS, BC, AP, PP, PT, PK, SS> {
107 KernelShadow::new(self)
108 }
109
110 pub unsafe fn load_defaults_unchecked(
120 &self,
121 f: impl FnOnce(&mut WriteFn) -> Result<(), ShadowError>,
122 ) -> Result<(), ShadowError> {
123 let table = unsafe { &mut *self.table.get() };
124 let mut write = |addr: u16, data: &[u8]| table.write_range(addr, data);
125 f(&mut write)
126 }
127
128 pub fn load_defaults(
135 &self,
136 f: impl FnOnce(&mut WriteFn) -> Result<(), ShadowError>,
137 ) -> Result<(), ShadowError> {
138 critical_section::with(|_| unsafe { self.load_defaults_unchecked(f) })
139 }
140}
141
#[cfg(test)]
mod tests {
    use crate::test_support::test_storage;

    /// Defaults become readable through the host view but leave nothing dirty.
    #[test]
    fn load_defaults_writes_data_without_marking_dirty() {
        const LOW: [u8; 4] = [0x11, 0x22, 0x33, 0x44];
        const HIGH: [u8; 4] = [0xAA, 0xBB, 0xCC, 0xDD];

        let storage = test_storage();

        let loaded = storage.load_defaults(|write| {
            write(0, &LOW)?;
            write(32, &HIGH)
        });
        loaded.unwrap();

        storage.host_shadow().with_view(|view| {
            let mut buf = [0u8; 4];
            for (addr, expected) in [(0, LOW), (32, HIGH)] {
                view.read_range(addr, &mut buf).unwrap();
                assert_eq!(buf, expected);
            }
        });

        storage.kernel_shadow().with_view(|view| {
            assert!(!view.any_dirty());
        });
    }

    /// Several separate ranges written in one `load_defaults` call all land.
    #[test]
    fn load_defaults_multiple_ranges() {
        let storage = test_storage();

        storage
            .load_defaults(|write| (0..4).try_for_each(|i| write(i * 16, &[i as u8; 4])))
            .unwrap();

        storage.host_shadow().with_view(|view| {
            for i in 0..4 {
                let mut buf = [0u8; 4];
                view.read_range(i * 16, &mut buf).unwrap();
                assert_eq!(buf, [i as u8; 4]);
            }
        });
    }

    /// An error from the write callback surfaces as the `load_defaults` result.
    #[test]
    fn load_defaults_error_propagates() {
        let storage = test_storage();

        // The second write is expected to be rejected by the test table
        // (presumably address 100 lies outside it — depends on `test_storage()`).
        let outcome = storage.load_defaults(|write| {
            write(0, &[0x11; 4])?;
            write(100, &[0xAA; 4])
        });

        assert!(outcome.is_err());
    }

    /// Host writes made after loading defaults are still tracked as dirty.
    #[test]
    fn normal_writes_work_after_load_defaults() {
        let storage = test_storage();

        storage
            .load_defaults(|write| write(0, &[0x11, 0x22, 0x33, 0x44]))
            .unwrap();

        storage.host_shadow().with_view(|view| {
            view.write_range(0, &[0xAA, 0xBB, 0xCC, 0xDD]).unwrap();
        });

        storage.kernel_shadow().with_view(|view| {
            assert!(view.any_dirty());
            assert!(view.is_dirty(0, 4).unwrap());
        });
    }
}