use crate::account::AccountView;
use crate::error::ProgramError;
use crate::layout::{HopperHeader, LayoutContract};
use crate::zerocopy::AccountLayout;
/// One step in a layout's migration chain.
///
/// An edge transforms an account body that is at schema epoch `from_epoch`
/// into the layout expected at schema epoch `to_epoch`.
#[derive(Clone, Copy)]
pub struct MigrationEdge {
    /// Schema epoch the body must be at before `migrator` runs.
    pub from_epoch: u32,
    /// Schema epoch recorded for the account once `migrator` succeeds.
    pub to_epoch: u32,
    /// In-place rewrite of the account body (the bytes that follow the
    /// header, as handed over by `apply_pending_migrations`).
    pub migrator: fn(&mut [u8]) -> Result<(), ProgramError>,
}
impl MigrationEdge {
    /// Reports whether following this edge strictly increases the epoch.
    ///
    /// Edges that stay on the same epoch or step backwards yield `false`;
    /// `find_edge` refuses such edges so a migration chain cannot loop.
    pub const fn is_forward(&self) -> bool {
        self.from_epoch < self.to_epoch
    }
}
/// Declares the migration chain available for an account layout.
///
/// `apply_pending_migrations` walks this table one edge at a time. At each
/// step it picks the edge whose `from_epoch` equals the epoch reached so far.
pub trait LayoutMigration {
    /// All known migration edges for the layout. Their order matters only
    /// when several edges share a `from_epoch`: the first one wins.
    const MIGRATIONS: &'static [MigrationEdge];
}
/// Brings an account's body up to `T`'s schema epoch by running the
/// registered migrations in sequence, in place.
///
/// Starting at `current_epoch`, each step does three things:
/// - looks up the edge whose `from_epoch` matches the epoch reached so far
///   (via `find_edge`);
/// - runs that edge's `migrator` on the bytes after the `HopperHeader`;
/// - stores the edge's `to_epoch` (little-endian) into header bytes `12..16`.
///
/// `current_epoch` is taken from the caller as-is.
/// NOTE(review): it is not cross-checked against the epoch already stored in
/// the header — confirm callers read it from there.
///
/// Returns the number of migrations applied; `0` when the account is already
/// at the target epoch.
///
/// # Errors
///
/// - `ProgramError::InvalidAccountData` in any of these cases:
///   - `current_epoch` is newer than the target epoch;
///   - no forward edge starts at an intermediate epoch;
///   - an edge would jump past the target epoch.
/// - `ProgramError::AccountDataTooSmall` when the account data is shorter
///   than the header, or the header cannot hold the epoch field.
/// - Any error from mutably borrowing the account data, or any error
///   returned by a migrator. Steps completed before the failing one have
///   already been written into the account data buffer.
#[inline]
pub fn apply_pending_migrations<T>(
    account: &AccountView,
    current_epoch: u32,
) -> Result<u32, ProgramError>
where
    T: AccountLayout + LayoutContract + LayoutMigration,
{
    // Byte range of the schema epoch inside the header. It is kept next to
    // its only use so a header layout change is easy to spot.
    const EPOCH_FIELD: core::ops::Range<usize> = 12..16;

    let target_epoch = <T as AccountLayout>::SCHEMA_EPOCH;
    if current_epoch == target_epoch {
        return Ok(0);
    }
    if current_epoch > target_epoch {
        return Err(ProgramError::InvalidAccountData);
    }
    let edges = <T as LayoutMigration>::MIGRATIONS;
    let mut data = account.try_borrow_mut()?;
    let header_len = core::mem::size_of::<HopperHeader>();
    if data.len() < header_len {
        return Err(ProgramError::AccountDataTooSmall);
    }
    let mut applied = 0u32;
    let mut epoch = current_epoch;
    while epoch < target_epoch {
        let edge = find_edge(edges, epoch)?;
        // An overshooting edge would stamp an epoch newer than this program
        // understands; a later call passing that epoch is rejected above.
        // Refuse it before the migrator touches the body.
        if edge.to_epoch > target_epoch {
            return Err(ProgramError::InvalidAccountData);
        }
        let (header_bytes, body_bytes) = data.split_at_mut(header_len);
        // Resolve the epoch slot before migrating. A header too short to hold
        // it then becomes an error, not a panic after the body was rewritten.
        let epoch_slot = header_bytes
            .get_mut(EPOCH_FIELD)
            .ok_or(ProgramError::AccountDataTooSmall)?;
        (edge.migrator)(body_bytes)?;
        epoch_slot.copy_from_slice(&edge.to_epoch.to_le_bytes());
        epoch = edge.to_epoch;
        applied += 1;
    }
    Ok(applied)
}
/// Looks up the migration edge that leaves `epoch`.
///
/// Only the first edge whose `from_epoch` equals `epoch` is considered. Later
/// edges starting at the same epoch are ignored.
///
/// # Errors
///
/// Returns `ProgramError::InvalidAccountData` in either case:
/// - no edge starts at `epoch`;
/// - the first such edge does not move the epoch forward.
#[inline]
fn find_edge(edges: &[MigrationEdge], epoch: u32) -> Result<&MigrationEdge, ProgramError> {
    edges
        .iter()
        .find(|candidate| candidate.from_epoch == epoch)
        .filter(|candidate| candidate.is_forward())
        .ok_or(ProgramError::InvalidAccountData)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Migrator that leaves the body untouched.
    fn noop(_body: &mut [u8]) -> Result<(), ProgramError> {
        Ok(())
    }

    /// Builds an edge between two epochs using the no-op migrator.
    fn edge(from_epoch: u32, to_epoch: u32) -> MigrationEdge {
        MigrationEdge {
            from_epoch,
            to_epoch,
            migrator: noop,
        }
    }

    #[test]
    fn migration_edge_is_forward_detects_non_monotonic() {
        assert!(edge(1, 2).is_forward());
        assert!(!edge(3, 2).is_forward());
        assert!(!edge(2, 2).is_forward());
    }

    #[test]
    fn find_edge_returns_matching_edge() {
        let edges = [edge(1, 2), edge(2, 3)];
        for (from, expected_to) in [(1, 2), (2, 3)] {
            let found = find_edge(&edges, from).expect("edge exists");
            assert_eq!(found.to_epoch, expected_to);
        }
    }

    #[test]
    fn find_edge_errs_on_missing_epoch() {
        assert!(find_edge(&[edge(1, 2)], 5).is_err());
    }

    #[test]
    fn find_edge_rejects_non_forward_edge() {
        assert!(find_edge(&[edge(3, 2)], 3).is_err());
    }
}