use crate::{
    bits::roundup,
    elf::{LoadingAction, ProgramMetadata},
    machine::SupportMachine,
    memory::{get_page_indices, Memory, FLAG_DIRTY},
    Error, Register, RISCV_GENERAL_REGISTER_NUMBER, RISCV_PAGESIZE,
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::collections::HashMap;

const PAGE_SIZE: u64 = RISCV_PAGESIZE as u64;

pub trait DataSource<I: Clone + PartialEq> {
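    /// Loads `length` bytes starting at `offset` from the data entry
    /// identified by `id`. On success, returns the loaded bytes together
    /// with the full length of the whole entry; returns `None` when the
    /// entry cannot be located.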
    fn load_data(&self, id: &I, offset: u64, length: u64) -> Option<(Bytes, u64)>;
}

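/// Bookkeeping context that makes [Snapshot2] possible: it remembers, for
/// every memory page written from a [DataSource], which data entry and
/// offset back that page, so a snapshot can reference the source rather
/// than copy page contents.
///
/// A rough usage sketch (`machine`, `VecSource` and the address values
/// here are illustrative, not part of this crate):
///
/// ```ignore
/// let mut ctx = Snapshot2Context::new(VecSource::default());
/// // Write source-backed data into memory, tracking the pages.
/// ctx.store_bytes(&mut machine, addr, &id, 0, length, size_addr)?;
/// // ... run the machine ...
/// let snapshot = ctx.make_snapshot(&mut machine)?;
/// // Later, restore the state on a fresh machine.
/// ctx.resume(&mut machine2, &snapshot)?;
/// ```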
#[derive(Clone, Debug)]
pub struct Snapshot2Context<I: Clone + PartialEq, D: DataSource<I>> {
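    // Page index -> (id, offset, flag): pages whose contents are backed
    // by an entry in the data source.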
    pages: HashMap<u64, (I, u64, u8)>,
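    // The data source that tracked pages can be reloaded from.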
    data_source: D,
}

impl<I: Clone + PartialEq, D: DataSource<I> + Default> Default for Snapshot2Context<I, D> {
    fn default() -> Self {
        Self::new(D::default())
    }
}

impl<I: Clone + PartialEq, D: DataSource<I>> Snapshot2Context<I, D> {
    pub fn new(data_source: D) -> Self {
        Self {
            pages: HashMap::default(),
            data_source,
        }
    }

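    /// Resumes a machine from a [Snapshot2]: verifies the version,
    /// restores registers, PC, cycles, memory pages and the load
    /// reservation, and re-tracks all source-backed pages. Previously
    /// tracked pages in this context are discarded first.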
    pub fn resume<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        snapshot: &Snapshot2<I>,
    ) -> Result<(), Error> {
        if machine.version() != snapshot.version {
            return Err(Error::InvalidVersion);
        }
        self.pages.clear();
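        // Restore general registers, PC and cycle counters.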
        for (i, v) in snapshot.registers.iter().enumerate() {
            machine.set_register(i, M::REG::from_u64(*v));
        }
        machine.update_pc(M::REG::from_u64(snapshot.pc));
        machine.commit_pc();
        machine.set_cycles(snapshot.cycles);
        machine.set_max_cycles(snapshot.max_cycles);
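        // Replay pages whose contents come from the data source: load the
        // backing bytes, write them into memory, restore page flags, then
        // track the pages again.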
        for (address, flag, id, offset, length) in &snapshot.pages_from_source {
            if address % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            let (data, _) = self.load_data(id, *offset, *length)?;
            if data.len() as u64 % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            machine.memory_mut().store_bytes(*address, &data)?;
            for i in 0..(data.len() as u64 / PAGE_SIZE) {
                let page = address / PAGE_SIZE + i;
                machine.memory_mut().set_flag(page, *flag)?;
            }
            self.track_pages(machine, *address, data.len() as u64, id, *offset)?;
        }
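        // Replay dirty pages, whose contents are stored verbatim in the
        // snapshot.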
        for (address, flag, content) in &snapshot.dirty_pages {
            if address % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            if content.len() as u64 % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            machine.memory_mut().store_bytes(*address, content)?;
            for i in 0..(content.len() as u64 / PAGE_SIZE) {
                let page = address / PAGE_SIZE + i;
                machine.memory_mut().set_flag(page, *flag)?;
            }
        }
        machine
            .memory_mut()
            .set_lr(&M::REG::from_u64(snapshot.load_reservation_address));
        Ok(())
    }

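    /// Loads data from the wrapped data source, mapping a missing entry
    /// to [Error::SnapshotDataLoadError].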
    pub fn load_data(&self, id: &I, offset: u64, length: u64) -> Result<(Bytes, u64), Error> {
        self.data_source
            .load_data(id, offset, length)
            .ok_or(Error::SnapshotDataLoadError)
    }

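    /// Similar to Memory::store_bytes, except the bytes come from the data
    /// source. The full length of the data entry is written to `size_addr`
    /// as a u64, and the written pages are tracked so a later snapshot can
    /// reference the data source instead of copying the raw bytes. Returns
    /// the number of bytes written together with the full length.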
    pub fn store_bytes<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        addr: u64,
        id: &I,
        offset: u64,
        length: u64,
        size_addr: u64,
    ) -> Result<(u64, u64), Error> {
        let (data, full_length) = self.load_data(id, offset, length)?;
        machine
            .memory_mut()
            .store64(&M::REG::from_u64(size_addr), &M::REG::from_u64(full_length))?;
        self.untrack_pages(machine, addr, data.len() as u64)?;
        machine.memory_mut().store_bytes(addr, &data)?;
        self.track_pages(machine, addr, data.len() as u64, id, offset)?;
        Ok((data.len() as u64, full_length))
    }

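    /// Marks a program as loaded from the data source: for each loading
    /// action in the program metadata, the pages it initialized are
    /// tracked against `id` and `offset`. This only does the bookkeeping;
    /// the program bytes themselves must have been loaded into memory
    /// through other means.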
    pub fn mark_program<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        metadata: &ProgramMetadata,
        id: &I,
        offset: u64,
    ) -> Result<(), Error> {
        for action in &metadata.actions {
            self.init_pages(machine, action, id, offset)?;
        }
        Ok(())
    }

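    /// Creates a snapshot of the passed machine. Pages still backed by
    /// the data source are stored as (id, offset, length) references,
    /// while dirty pages are copied out in full; contiguous runs with
    /// identical flags are merged to keep the snapshot compact.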
    pub fn make_snapshot<M: SupportMachine>(&self, machine: &mut M) -> Result<Snapshot2<I>, Error> {
        let mut dirty_pages: Vec<(u64, u8, Vec<u8>)> = vec![];
        for i in 0..machine.memory().memory_pages() as u64 {
            let flag = machine.memory_mut().fetch_flag(i)?;
            if flag & FLAG_DIRTY == 0 {
                continue;
            }
            let address = i * PAGE_SIZE;
            let mut data: Vec<u8> = machine.memory_mut().load_bytes(address, PAGE_SIZE)?.into();
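            // When this page extends the previous run (contiguous address,
            // same flag), `append` drains `data`, so the push below is
            // skipped automatically.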
            if let Some(last) = dirty_pages.last_mut() {
                if last.0 + last.2.len() as u64 == address && last.1 == flag {
                    last.2.append(&mut data);
                }
            }
            if !data.is_empty() {
                dirty_pages.push((address, flag, data));
            }
        }
        let mut pages_from_source: Vec<(u64, u8, I, u64, u64)> = vec![];
        let mut pages: Vec<u64> = self.pages.keys().copied().collect();
        pages.sort_unstable();
        for page in pages {
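            // A tracked page that has been written to since tracking is
            // already captured in dirty_pages above, skip it here.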
            if machine.memory_mut().fetch_flag(page)? & FLAG_DIRTY != 0 {
                continue;
            }
            let address = page * PAGE_SIZE;
            let (id, offset, flag) = &self.pages[&page];
            let mut appended_to_last = false;
            if let Some((last_address, last_flag, last_id, last_offset, last_length)) =
                pages_from_source.last_mut()
            {
                if *last_address + *last_length == address
                    && *last_flag == *flag
                    && *last_id == *id
                    && *last_offset + *last_length == *offset
                {
                    *last_length += PAGE_SIZE;
                    appended_to_last = true;
                }
            }
            if !appended_to_last {
                pages_from_source.push((address, *flag, id.clone(), *offset, PAGE_SIZE));
            }
        }
        let mut registers = [0u64; RISCV_GENERAL_REGISTER_NUMBER];
        for (i, v) in machine.registers().iter().enumerate() {
            registers[i] = v.to_u64();
        }
        Ok(Snapshot2 {
            pages_from_source,
            dirty_pages,
            version: machine.version(),
            registers,
            pc: machine.pc().to_u64(),
            cycles: machine.cycles(),
            max_cycles: machine.max_cycles(),
            load_reservation_address: machine.memory().lr().to_u64(),
        })
    }

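    // Tracks the pages initialized by a single LoadingAction, taking the
    // action's in-page offset into account.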
    fn init_pages<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        action: &LoadingAction,
        id: &I,
        offset: u64,
    ) -> Result<(), Error> {
        let start = action.addr + action.offset_from_addr;
        let length = min(
            action.source.end - action.source.start,
            action.size - action.offset_from_addr,
        );
        self.track_pages(machine, start, length, id, offset + action.source.start)
    }

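    /// Tracks a region of memory as backed by the data source. Only fully
    /// covered pages are tracked: a partially covered page at either end
    /// of the region is skipped, since its full contents cannot be
    /// reconstructed from the source alone.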
    pub fn track_pages<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        start: u64,
        mut length: u64,
        id: &I,
        mut offset: u64,
    ) -> Result<(), Error> {
        let mut aligned_start = roundup(start, PAGE_SIZE);
        let aligned_bytes = aligned_start - start;
        if length < aligned_bytes {
            return Ok(());
        }
        offset += aligned_bytes;
        length -= aligned_bytes;
        while length >= PAGE_SIZE {
            let page = aligned_start / PAGE_SIZE;
            machine.memory_mut().clear_flag(page, FLAG_DIRTY)?;
            let flag = machine.memory_mut().fetch_flag(page)?;
            self.pages.insert(page, (id.clone(), offset, flag));
            aligned_start += PAGE_SIZE;
            length -= PAGE_SIZE;
            offset += PAGE_SIZE;
        }
        Ok(())
    }

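    /// Stops tracking the pages touched by a region of memory and marks
    /// them dirty, so a later snapshot copies their contents instead of
    /// referencing the data source.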
    pub fn untrack_pages<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        start: u64,
        length: u64,
    ) -> Result<(), Error> {
        if length == 0 {
            return Ok(());
        }
        let page_indices = get_page_indices(start, length)?;
        for page in page_indices.0..=page_indices.1 {
            machine.memory_mut().set_flag(page, FLAG_DIRTY)?;
            self.pages.remove(&page);
        }
        Ok(())
    }
}

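/// A serializable snapshot of a machine's full state: pages backed by a
/// [DataSource] are stored as references, dirty pages are stored verbatim.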
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Snapshot2<I: Clone + PartialEq> {
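    // (address, flag, id, offset, length)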
    pub pages_from_source: Vec<(u64, u8, I, u64, u64)>,
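    // (address, flag, content)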
    pub dirty_pages: Vec<(u64, u8, Vec<u8>)>,
    pub version: u32,
    pub registers: [u64; RISCV_GENERAL_REGISTER_NUMBER],
    pub pc: u64,
    pub cycles: u64,
    pub max_cycles: u64,
    pub load_reservation_address: u64,
}