#[cfg(test)]
mod tests;
use crate::{
db::{
Db,
commit::{
CommitMarker, CommitRowOp, begin_commit_with_migration_state,
clear_migration_state_bytes, finish_commit, load_migration_state_bytes,
},
},
error::InternalError,
traits::CanisterKind,
};
/// Upper bound on the serialized migration-state payload, enforced on both
/// encode and decode.
const MAX_MIGRATION_STATE_BYTES: usize = 64 * 1024;
/// Two-byte magic prefix identifying a migration-state payload.
const MIGRATION_STATE_MAGIC: [u8; 2] = *b"MS";
/// Format version written by this runtime; decode rejects any other version.
const MIGRATION_STATE_VERSION_CURRENT: u8 = 1;
/// Option tag meaning no last-applied row key was recorded.
const MIGRATION_STATE_NONE_ROW_KEY_TAG: u8 = 0;
/// Option tag meaning a length-prefixed last-applied row key follows.
const MIGRATION_STATE_SOME_ROW_KEY_TAG: u8 = 1;
/// Position within a migration plan: the index of the step that runs next.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct MigrationCursor {
    next_step: usize,
}

impl MigrationCursor {
    /// Cursor positioned before the first step.
    #[must_use]
    pub const fn start() -> Self {
        Self { next_step: 0 }
    }

    /// Index of the step that would execute next.
    #[must_use]
    pub const fn next_step(self) -> usize {
        self.next_step
    }

    /// Cursor resuming at the given step index.
    const fn from_step(index: usize) -> Self {
        Self { next_step: index }
    }

    /// Cursor moved past one more step; saturates at `usize::MAX`.
    const fn advance(self) -> Self {
        Self {
            next_step: self.next_step.saturating_add(1),
        }
    }
}
/// Migration progress snapshot persisted durably between step commits so an
/// interrupted migration can resume at the correct step.
#[derive(Clone, Debug, Eq, PartialEq)]
struct PersistedMigrationState {
    /// Plan id this state belongs to; a mismatch on load is a conflict.
    migration_id: String,
    /// Plan version this state belongs to; a mismatch on load is a conflict.
    migration_version: u64,
    /// Index of the next step to execute (may equal the plan length when done).
    step_index: u64,
    /// Key of the last row op applied in the most recent step, if any.
    last_applied_row_key: Option<Vec<u8>>,
}
/// A single row mutation requested by a migration step, expressed as raw bytes.
#[derive(Clone, Debug)]
pub struct MigrationRowOp {
    /// Logical path of the entity the row belongs to.
    pub entity_path: String,
    /// Encoded key of the affected row.
    pub key: Vec<u8>,
    /// Row payload before the change, when present.
    pub before: Option<Vec<u8>>,
    /// Row payload after the change, when present.
    pub after: Option<Vec<u8>>,
    /// Fingerprint of the schema the payloads were encoded against.
    pub schema_fingerprint: [u8; 16],
}

impl MigrationRowOp {
    /// Assembles a row op from its parts, taking ownership of the path.
    #[must_use]
    pub fn new(
        entity_path: impl Into<String>,
        key: Vec<u8>,
        before: Option<Vec<u8>>,
        after: Option<Vec<u8>>,
        schema_fingerprint: [u8; 16],
    ) -> Self {
        let entity_path = entity_path.into();
        Self {
            entity_path,
            key,
            before,
            after,
            schema_fingerprint,
        }
    }
}
impl TryFrom<MigrationRowOp> for CommitRowOp {
type Error = InternalError;
fn try_from(op: MigrationRowOp) -> Result<Self, Self::Error> {
Self::try_new_bytes(
op.entity_path,
op.key.as_slice(),
op.before,
op.after,
op.schema_fingerprint,
)
}
}
/// A named batch of row operations applied atomically as one commit.
#[derive(Clone, Debug)]
pub struct MigrationStep {
    name: String,
    row_ops: Vec<CommitRowOp>,
}

impl MigrationStep {
    /// Builds a step from public row ops, converting each to a `CommitRowOp`.
    ///
    /// # Errors
    /// Fails when any conversion is rejected, the name is blank, or
    /// `row_ops` is empty.
    pub fn from_row_ops(
        name: impl Into<String>,
        row_ops: Vec<MigrationRowOp>,
    ) -> Result<Self, InternalError> {
        let converted: Vec<CommitRowOp> = row_ops
            .into_iter()
            .map(CommitRowOp::try_from)
            .collect::<Result<_, _>>()?;
        Self::new(name, converted)
    }

    /// Crate-internal constructor over already-converted commit row ops.
    ///
    /// # Errors
    /// Fails when the name is blank or `row_ops` is empty.
    pub(in crate::db) fn new(
        name: impl Into<String>,
        row_ops: Vec<CommitRowOp>,
    ) -> Result<Self, InternalError> {
        let name = name.into();
        validate_non_empty_label(&name, "migration step name")?;
        if row_ops.is_empty() {
            return Err(InternalError::migration_step_row_ops_required(&name));
        }
        Ok(Self { name, row_ops })
    }

    /// Step name, used in diagnostics.
    #[must_use]
    pub const fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Number of row operations this step applies.
    #[must_use]
    pub const fn row_op_count(&self) -> usize {
        self.row_ops.len()
    }
}
/// An ordered, versioned list of migration steps under a stable identifier.
#[derive(Clone, Debug)]
pub struct MigrationPlan {
    id: String,
    version: u64,
    steps: Vec<MigrationStep>,
}

impl MigrationPlan {
    /// Validates and builds a plan.
    ///
    /// # Errors
    /// Fails when the id is blank, the version is zero, or `steps` is empty.
    pub fn new(
        id: impl Into<String>,
        version: u64,
        steps: Vec<MigrationStep>,
    ) -> Result<Self, InternalError> {
        let id = id.into();
        validate_non_empty_label(&id, "migration plan id")?;
        if version == 0 {
            return Err(InternalError::migration_plan_version_required(&id));
        }
        if steps.is_empty() {
            return Err(InternalError::migration_plan_steps_required(&id));
        }
        Ok(Self { id, version, steps })
    }

    /// Stable identifier of the plan.
    #[must_use]
    pub const fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Plan version; never zero after construction.
    #[must_use]
    pub const fn version(&self) -> u64 {
        self.version
    }

    /// Number of steps in the plan.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.steps.len()
    }

    /// Whether the plan has no steps; never true after construction.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.steps.is_empty()
    }

    /// Step at `index`, or a cursor-out-of-bounds error naming this plan.
    fn step_at(&self, index: usize) -> Result<&MigrationStep, InternalError> {
        match self.steps.get(index) {
            Some(step) => Ok(step),
            None => Err(InternalError::migration_cursor_out_of_bounds(
                self.id(),
                self.version(),
                index,
                self.len(),
            )),
        }
    }
}
/// Terminal state of a single migration run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MigrationRunState {
    /// Every step of the plan has committed and durable state was cleared.
    Complete,
    /// The step budget ran out before the plan finished; run again to resume.
    NeedsResume,
}
/// Summary of one `execute_migration_plan` invocation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MigrationRunOutcome {
    cursor: MigrationCursor,
    applied_steps: usize,
    applied_row_ops: usize,
    state: MigrationRunState,
}

impl MigrationRunOutcome {
    /// Bundles the run counters with the resulting cursor and state.
    const fn new(
        cursor: MigrationCursor,
        applied_steps: usize,
        applied_row_ops: usize,
        state: MigrationRunState,
    ) -> Self {
        Self {
            cursor,
            applied_steps,
            applied_row_ops,
            state,
        }
    }

    /// Cursor position after the run.
    #[must_use]
    pub const fn cursor(self) -> MigrationCursor {
        self.cursor
    }

    /// Steps committed during this run.
    #[must_use]
    pub const fn applied_steps(self) -> usize {
        self.applied_steps
    }

    /// Total row operations committed during this run.
    #[must_use]
    pub const fn applied_row_ops(self) -> usize {
        self.applied_row_ops
    }

    /// Whether the plan completed or needs another run.
    #[must_use]
    pub const fn state(self) -> MigrationRunState {
        self.state
    }
}
/// Executes up to `max_steps` steps of `plan` against `db`, resuming from any
/// durably persisted cursor for the same plan.
///
/// Each step's row ops are committed together with the serialized cursor that
/// points past the step (via `begin_commit_with_migration_state`); the durable
/// state is cleared once the cursor reaches the end of the plan.
///
/// # Errors
/// Fails when `max_steps` is zero, when recovery or cursor loading fails, or
/// when any step's commit fails.
pub(in crate::db) fn execute_migration_plan<C: CanisterKind>(
    db: &Db<C>,
    plan: &MigrationPlan,
    max_steps: usize,
) -> Result<MigrationRunOutcome, InternalError> {
    // A zero budget could never make progress; reject it up front.
    if max_steps == 0 {
        return Err(InternalError::migration_execution_requires_max_steps(
            plan.id(),
        ));
    }
    db.ensure_recovered_state()?;
    let mut next_cursor = load_durable_cursor_for_plan(plan)?;
    let mut applied_steps = 0usize;
    let mut applied_row_ops = 0usize;
    while applied_steps < max_steps {
        // Stop once the cursor has moved past the final step.
        if next_cursor.next_step() >= plan.len() {
            break;
        }
        let step_index = next_cursor.next_step();
        let step = plan.step_at(step_index)?;
        let next_cursor_after_step = next_cursor.advance();
        // Serialize the post-step cursor BEFORE executing, so the commit can
        // persist it together with the step's row ops.
        let next_state_bytes =
            encode_durable_cursor_state(plan, next_cursor_after_step, step.row_ops.last())?;
        execute_migration_step(db, plan, step_index, step, next_state_bytes)?;
        applied_steps = applied_steps.saturating_add(1);
        applied_row_ops = applied_row_ops.saturating_add(step.row_op_count());
        next_cursor = next_cursor_after_step;
    }
    // The cursor equals the plan length exactly when every step has committed.
    let state = if next_cursor.next_step() == plan.len() {
        clear_migration_state_bytes()?;
        MigrationRunState::Complete
    } else {
        MigrationRunState::NeedsResume
    };
    Ok(MigrationRunOutcome::new(
        next_cursor,
        applied_steps,
        applied_row_ops,
        state,
    ))
}
/// Loads the durable resume cursor for `plan`, defaulting to the start when
/// no state is persisted.
///
/// A persisted step index equal to the plan length means the plan already
/// completed: the stale state is cleared and an at-end cursor is returned.
///
/// # Errors
/// Fails on decode corruption, when the persisted state belongs to a
/// different plan id/version, or when the step index is out of range.
fn load_durable_cursor_for_plan(plan: &MigrationPlan) -> Result<MigrationCursor, InternalError> {
    let Some(bytes) = load_migration_state_bytes()? else {
        return Ok(MigrationCursor::start());
    };
    let state = decode_persisted_migration_state(&bytes)?;
    // Refuse to resume with state written by a different plan or version.
    if state.migration_id != plan.id() || state.migration_version != plan.version() {
        return Err(InternalError::migration_in_progress_conflict(
            plan.id(),
            plan.version(),
            &state.migration_id,
            state.migration_version,
        ));
    }
    let step_index = usize::try_from(state.step_index).map_err(|_| {
        InternalError::migration_persisted_step_index_invalid_usize(
            plan.id(),
            plan.version(),
            state.step_index,
        )
    })?;
    // Index may equal the plan length (completed); anything beyond is corrupt.
    if step_index > plan.len() {
        return Err(InternalError::migration_persisted_step_index_out_of_bounds(
            plan.id(),
            plan.version(),
            step_index,
            plan.len(),
        ));
    }
    // Index == len means the plan finished previously; drop the stale state.
    if step_index == plan.len() {
        clear_migration_state_bytes()?;
    }
    Ok(MigrationCursor::from_step(step_index))
}
/// Serializes the durable resume state for `cursor`, recording the key of the
/// last row op in the step about to commit (when one is provided).
///
/// # Errors
/// Fails when the step index does not fit in `u64` or when encoding fails.
fn encode_durable_cursor_state(
    plan: &MigrationPlan,
    cursor: MigrationCursor,
    last_applied_row_op: Option<&CommitRowOp>,
) -> Result<Vec<u8>, InternalError> {
    let Ok(step_index) = u64::try_from(cursor.next_step()) else {
        return Err(InternalError::migration_next_step_index_u64_required(
            plan.id(),
            plan.version(),
        ));
    };
    let last_applied_row_key = last_applied_row_op.map(|op| op.key.as_bytes().to_vec());
    encode_persisted_migration_state(&PersistedMigrationState {
        migration_id: plan.id().to_string(),
        migration_version: plan.version(),
        step_index,
        last_applied_row_key,
    })
}
/// Decodes a persisted migration state payload, rejecting oversized,
/// truncated, malformed, or over-long inputs.
///
/// # Errors
/// Returns a corruption error when any field fails to decode or when bytes
/// remain after the last field.
fn decode_persisted_migration_state(
    bytes: &[u8],
) -> Result<PersistedMigrationState, InternalError> {
    if bytes.len() > MAX_MIGRATION_STATE_BYTES {
        return Err(InternalError::serialize_corruption(format!(
            "migration state decode failed: payload size {} exceeds limit {MAX_MIGRATION_STATE_BYTES}",
            bytes.len(),
        )));
    }
    let mut remaining = bytes;
    decode_migration_state_magic(&mut remaining)?;
    let format_version = decode_migration_state_u8(&mut remaining, "format version")?;
    validate_migration_state_format_version(format_version)?;
    let migration_id = decode_migration_state_string(&mut remaining, "migration_id")?;
    let migration_version = decode_migration_state_u64(&mut remaining, "migration_version")?;
    let step_index = decode_migration_state_u64(&mut remaining, "step_index")?;
    let last_applied_row_key =
        decode_migration_state_optional_bytes(&mut remaining, "last_applied_row_key")?;
    // The payload must be consumed exactly; leftovers indicate corruption.
    if remaining.is_empty() {
        Ok(PersistedMigrationState {
            migration_id,
            migration_version,
            step_index,
            last_applied_row_key,
        })
    } else {
        Err(InternalError::serialize_corruption(
            "migration state decode failed: trailing bytes",
        ))
    }
}
/// Encodes `state` into the compact binary layout:
/// magic "MS" · format-version u8 · id-length u32-BE · id bytes ·
/// migration_version u64-BE · step_index u64-BE · option tag u8 ·
/// (row-key length u32-BE · row-key bytes, only when the tag is Some).
///
/// # Errors
/// Fails when the encoded payload would exceed `MAX_MIGRATION_STATE_BYTES`
/// or when a length does not fit in `u32`.
fn encode_persisted_migration_state(
    state: &PersistedMigrationState,
) -> Result<Vec<u8>, InternalError> {
    let row_key = state.last_applied_row_key.as_deref();
    let row_key_len = row_key.map_or(0usize, |key| key.len());
    // Fixed overhead: magic + version byte + id length prefix + two u64
    // fields + option tag; an optional row key adds its own length prefix.
    let optional_part = if row_key.is_some() {
        4usize.saturating_add(row_key_len)
    } else {
        0
    };
    let encoded_len = MIGRATION_STATE_MAGIC
        .len()
        .saturating_add(1 + 4 + 8 + 8 + 1)
        .saturating_add(state.migration_id.len())
        .saturating_add(optional_part);
    if encoded_len > MAX_MIGRATION_STATE_BYTES {
        return Err(InternalError::migration_state_serialize_failed(format!(
            "payload size {encoded_len} exceeds limit {MAX_MIGRATION_STATE_BYTES}",
        )));
    }
    let id_len = u32::try_from(state.migration_id.len()).map_err(|_| {
        InternalError::migration_state_serialize_failed("migration_id exceeds u32 length")
    })?;
    let row_key_len_u32 = u32::try_from(row_key_len).map_err(|_| {
        InternalError::migration_state_serialize_failed("last_applied_row_key exceeds u32 length")
    })?;
    let mut out = Vec::with_capacity(encoded_len);
    out.extend_from_slice(&MIGRATION_STATE_MAGIC);
    out.push(MIGRATION_STATE_VERSION_CURRENT);
    out.extend_from_slice(&id_len.to_be_bytes());
    out.extend_from_slice(state.migration_id.as_bytes());
    out.extend_from_slice(&state.migration_version.to_be_bytes());
    out.extend_from_slice(&state.step_index.to_be_bytes());
    if let Some(key) = row_key {
        out.push(MIGRATION_STATE_SOME_ROW_KEY_TAG);
        out.extend_from_slice(&row_key_len_u32.to_be_bytes());
        out.extend_from_slice(key);
    } else {
        out.push(MIGRATION_STATE_NONE_ROW_KEY_TAG);
    }
    Ok(out)
}
/// Consumes and verifies the two-byte magic prefix.
fn decode_migration_state_magic(bytes: &mut &[u8]) -> Result<(), InternalError> {
    let magic = take_migration_state_bytes(bytes, MIGRATION_STATE_MAGIC.len(), "magic")?;
    if magic == MIGRATION_STATE_MAGIC {
        Ok(())
    } else {
        Err(InternalError::serialize_corruption(
            "migration state decode failed: invalid magic",
        ))
    }
}
/// Reads one byte from the front of `bytes`, advancing the slice.
fn decode_migration_state_u8(bytes: &mut &[u8], label: &'static str) -> Result<u8, InternalError> {
    let raw = take_migration_state_bytes(bytes, 1, label)?;
    Ok(raw[0])
}
/// Reads a big-endian `u64` from the front of `bytes`, advancing the slice.
///
/// # Errors
/// Returns a corruption error naming `label` when fewer than 8 bytes remain.
fn decode_migration_state_u64(
    bytes: &mut &[u8],
    label: &'static str,
) -> Result<u64, InternalError> {
    let raw = take_migration_state_bytes(bytes, 8, label)?;
    // `take_migration_state_bytes` returned exactly 8 bytes, so the slice-to-
    // array conversion cannot fail.
    let raw: [u8; 8] = raw
        .try_into()
        .expect("take_migration_state_bytes returned exactly 8 bytes");
    Ok(u64::from_be_bytes(raw))
}
/// Reads a length-prefixed UTF-8 string from the front of `bytes`.
///
/// # Errors
/// Returns a corruption error when the prefix/body is truncated or the body
/// is not valid UTF-8.
fn decode_migration_state_string(
    bytes: &mut &[u8],
    label: &'static str,
) -> Result<String, InternalError> {
    let raw = decode_migration_state_length_prefixed_bytes(bytes, label)?;
    match std::str::from_utf8(raw) {
        Ok(text) => Ok(text.to_owned()),
        Err(_) => Err(InternalError::serialize_corruption(format!(
            "migration state decode failed: {label} is not valid UTF-8",
        ))),
    }
}
/// Reads an optional byte run: a one-byte tag, then (for the Some tag) a
/// length-prefixed payload.
///
/// # Errors
/// Returns a corruption error for an unknown tag or a truncated payload.
fn decode_migration_state_optional_bytes(
    bytes: &mut &[u8],
    label: &'static str,
) -> Result<Option<Vec<u8>>, InternalError> {
    match decode_migration_state_u8(bytes, label)? {
        MIGRATION_STATE_NONE_ROW_KEY_TAG => Ok(None),
        MIGRATION_STATE_SOME_ROW_KEY_TAG => {
            let raw = decode_migration_state_length_prefixed_bytes(bytes, label)?;
            Ok(Some(raw.to_vec()))
        }
        tag => Err(InternalError::serialize_corruption(format!(
            "migration state decode failed: invalid {label} tag {tag}",
        ))),
    }
}
/// Reads a `u32`-big-endian-length-prefixed byte run from the front of
/// `bytes`, advancing the slice past both prefix and payload.
///
/// # Errors
/// Returns a corruption error when the prefix or payload is truncated, or
/// when the length does not fit in `usize` on this target.
fn decode_migration_state_length_prefixed_bytes<'a>(
    bytes: &mut &'a [u8],
    label: &'static str,
) -> Result<&'a [u8], InternalError> {
    let raw_len = take_migration_state_bytes(bytes, 4, label)?;
    // Exactly 4 bytes were taken, so the slice-to-array conversion cannot fail.
    let raw_len: [u8; 4] = raw_len
        .try_into()
        .expect("take_migration_state_bytes returned exactly 4 bytes");
    let len = usize::try_from(u32::from_be_bytes(raw_len)).map_err(|_| {
        InternalError::serialize_corruption(format!(
            "migration state decode failed: {label} length out of range",
        ))
    })?;
    take_migration_state_bytes(bytes, len, label)
}
/// Splits the first `len` bytes off the front of `bytes`, advancing it past
/// the returned head.
///
/// # Errors
/// Returns a corruption error naming `label` when fewer than `len` bytes
/// remain.
fn take_migration_state_bytes<'a>(
    bytes: &mut &'a [u8],
    len: usize,
    label: &'static str,
) -> Result<&'a [u8], InternalError> {
    if len > bytes.len() {
        Err(InternalError::serialize_corruption(format!(
            "migration state decode failed: truncated {label}",
        )))
    } else {
        let (head, tail) = bytes.split_at(len);
        *bytes = tail;
        Ok(head)
    }
}
/// Accepts only payloads written with the current format version.
fn validate_migration_state_format_version(format_version: u8) -> Result<(), InternalError> {
    if format_version != MIGRATION_STATE_VERSION_CURRENT {
        return Err(InternalError::serialize_incompatible_persisted_format(
            format!(
                "migration state format version {format_version} is unsupported by runtime version {MIGRATION_STATE_VERSION_CURRENT}",
            ),
        ));
    }
    Ok(())
}
/// Commits one migration step: the step's row ops and the serialized
/// next-cursor bytes are staged in a single commit marker, then applied via
/// `finish_commit`.
///
/// Errors from every phase are annotated with the plan id, step index, and
/// step name before propagating.
fn execute_migration_step<C: CanisterKind>(
    db: &Db<C>,
    plan: &MigrationPlan,
    step_index: usize,
    step: &MigrationStep,
    next_state_bytes: Vec<u8>,
) -> Result<(), InternalError> {
    let marker = CommitMarker::new(step.row_ops.clone())
        .map_err(|err| annotate_step_error(plan, step_index, step.name(), err))?;
    // Stage the resume state alongside the marker so it persists with the step.
    let commit = begin_commit_with_migration_state(marker, next_state_bytes)
        .map_err(|err| annotate_step_error(plan, step_index, step.name(), err))?;
    finish_commit(commit, |_| apply_marker_row_ops(db, &step.row_ops))
        .map_err(|err| annotate_step_error(plan, step_index, step.name(), err))?;
    Ok(())
}
/// Applies the marker's row ops in two phases: every op is prepared before
/// any op is applied, so a preparation failure aborts with nothing applied.
fn apply_marker_row_ops<C: CanisterKind>(
    db: &Db<C>,
    row_ops: &[CommitRowOp],
) -> Result<(), InternalError> {
    let prepared = row_ops
        .iter()
        .map(|row_op| db.prepare_row_commit_op(row_op))
        .collect::<Result<Vec<_>, _>>()?;
    for prepared_op in prepared {
        prepared_op.apply();
    }
    Ok(())
}
/// Wraps `err`'s message with the plan id, step index, and step name so step
/// failures are attributable.
fn annotate_step_error(
    plan: &MigrationPlan,
    step_index: usize,
    step_name: &str,
    err: InternalError,
) -> InternalError {
    let original = err.message().to_string();
    let annotated = format!(
        "migration '{}' step {} ('{}') failed: {}",
        plan.id(),
        step_index,
        step_name,
        original,
    );
    err.with_message(annotated)
}
/// Rejects labels that are empty or whitespace-only, naming `label` in the
/// error.
fn validate_non_empty_label(value: &str, label: &str) -> Result<(), InternalError> {
    if value.trim().is_empty() {
        Err(InternalError::migration_label_empty(label))
    } else {
        Ok(())
    }
}