use crate::gba::cartridge::save_type::SaveType;
use serde::{Deserialize, Serialize};
/// Bytes in one flash bank (64 KiB). Reads and writes address a single bank at a time.
pub const FLASH_BANK_SIZE: usize = 0x1_0000;
/// Bytes in one erasable flash sector (4 KiB).
pub const FLASH_SECTOR_SIZE: usize = 0x1000;
/// Address of the first unlock write, also the address that receives command bytes.
const CMD_ADDR_1: usize = 0x5555;
/// Address of the second unlock write.
const CMD_ADDR_2: usize = 0x2AAA;
/// Chip-ID byte pair reported in ID mode by the 64 KiB variant.
const ID_64K: (u8, u8) = (0x1F, 0x3D);
/// Chip-ID byte pair reported in ID mode by the 128 KiB variant.
const ID_128K: (u8, u8) = (0x62, 0x13);
/// Position of the flash command state machine advanced by `Flash::write`.
///
/// This enum is serialized as part of `FlashStateSnapshot`. Some serde formats encode enum
/// variants by index, so do not reorder the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum FlashState {
    /// Idle; waiting for the first unlock write (`0xAA` to `0x5555`).
    Ready,
    /// First unlock write seen; expecting `0x55` to `0x2AAA`.
    Cmd1,
    /// Unlock complete; the next write to `0x5555` is a command byte.
    Cmd2,
    /// Erase setup (`0x80`) accepted; expecting `0xAA` to `0x5555`.
    EraseCmd1,
    /// Expecting `0x55` to `0x2AAA` to complete the erase unlock.
    EraseCmd2,
    /// Program command (`0xA0`) accepted; the next write programs one byte.
    AwaitWriteData,
    /// Bank-select command (`0xB0`) accepted; the next write to offset 0 picks the bank.
    AwaitBankSelect,
    /// Erase unlock complete; expecting `0x10` (chip erase) or `0x30` (sector erase).
    AwaitErase,
}
/// Emulated GBA flash save chip, in a 64 KiB single-bank or 128 KiB two-bank variant.
#[derive(Debug, Clone)]
pub struct Flash {
    /// Storage for all banks, bank 0 first, then bank 1. Erased bytes read as `0xFF`.
    data: Vec<u8>,
    /// Index of the currently selected 64 KiB bank.
    bank: usize,
    /// While set, reads of offsets 0 and 1 return `id` instead of stored data.
    id_mode: bool,
    /// Chip-ID byte pair returned in ID mode.
    id: (u8, u8),
    /// Current position in the command state machine.
    state: FlashState,
}
/// Serializable copy of a `Flash` chip's complete state.
///
/// Produced by `Flash::capture_state` and consumed by `Flash::restore_state`. The fields
/// mirror `Flash` one-to-one. Keep their order stable, because some serde formats encode
/// struct fields positionally.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlashStateSnapshot {
    /// Storage for all banks (64 KiB or 128 KiB).
    data: Vec<u8>,
    /// Selected bank index.
    bank: usize,
    /// Whether chip-ID mode was active.
    id_mode: bool,
    /// Chip-ID byte pair.
    id: (u8, u8),
    /// Command state-machine position, so half-finished command sequences survive a restore.
    state: FlashState,
}
impl Flash {
pub fn new_64k() -> Self {
Self {
data: vec![0xFF; FLASH_BANK_SIZE],
bank: 0,
id_mode: false,
id: ID_64K,
state: FlashState::Ready,
}
}
pub fn new_128k() -> Self {
Self {
data: vec![0xFF; FLASH_BANK_SIZE * 2],
bank: 0,
id_mode: false,
id: ID_128K,
state: FlashState::Ready,
}
}
pub(crate) fn new(save_type: SaveType) -> Self {
match save_type {
SaveType::Flash64K => Self::new_64k(),
SaveType::Flash128K => Self::new_128k(),
other => panic!("Flash::new called with non-Flash save type: {other:?}"),
}
}
pub fn size(&self) -> usize {
self.data.len()
}
pub fn has_two_banks(&self) -> bool {
self.data.len() > FLASH_BANK_SIZE
}
pub fn read(&self, offset: usize) -> u8 {
let masked = offset & (FLASH_BANK_SIZE - 1);
if self.id_mode {
match masked {
0x0000 => return self.id.0,
0x0001 => return self.id.1,
_ => {}
}
}
self.data[self.bank * FLASH_BANK_SIZE + masked]
}
pub fn write(&mut self, offset: usize, value: u8) {
let masked = offset & (FLASH_BANK_SIZE - 1);
match self.state {
FlashState::Ready => {
if masked == CMD_ADDR_1 && value == 0xAA {
self.state = FlashState::Cmd1;
}
}
FlashState::Cmd1 => {
if masked == CMD_ADDR_2 && value == 0x55 {
self.state = FlashState::Cmd2;
} else {
self.state = FlashState::Ready;
}
}
FlashState::Cmd2 => {
if masked == CMD_ADDR_1 {
match value {
0x90 => {
self.id_mode = true;
self.state = FlashState::Ready;
}
0xF0 => {
self.id_mode = false;
self.state = FlashState::Ready;
}
0x80 => self.state = FlashState::EraseCmd1,
0xA0 => self.state = FlashState::AwaitWriteData,
0xB0 if self.has_two_banks() => {
self.state = FlashState::AwaitBankSelect;
}
_ => self.state = FlashState::Ready,
}
} else {
self.state = FlashState::Ready;
}
}
FlashState::EraseCmd1 => {
if masked == CMD_ADDR_1 && value == 0xAA {
self.state = FlashState::EraseCmd2;
} else {
self.state = FlashState::Ready;
}
}
FlashState::EraseCmd2 => {
if masked == CMD_ADDR_2 && value == 0x55 {
self.state = FlashState::AwaitErase;
} else {
self.state = FlashState::Ready;
}
}
FlashState::AwaitErase => match (masked, value) {
(CMD_ADDR_1, 0x10) => {
for byte in self.data.iter_mut() {
*byte = 0xFF;
}
self.state = FlashState::Ready;
}
(addr, 0x30) if addr & 0x0FFF == 0 => {
let base = self.bank * FLASH_BANK_SIZE + addr;
for byte in &mut self.data[base..base + FLASH_SECTOR_SIZE] {
*byte = 0xFF;
}
self.state = FlashState::Ready;
}
_ => self.state = FlashState::Ready,
},
FlashState::AwaitWriteData => {
let idx = self.bank * FLASH_BANK_SIZE + masked;
self.data[idx] &= value;
self.state = FlashState::Ready;
}
FlashState::AwaitBankSelect => {
if masked == 0x0000 {
self.bank = (value as usize) & 0x1;
}
self.state = FlashState::Ready;
}
}
}
pub fn snapshot(&self) -> &[u8] {
&self.data
}
pub fn capture_state(&self) -> FlashStateSnapshot {
FlashStateSnapshot {
data: self.data.clone(),
bank: self.bank,
id_mode: self.id_mode,
id: self.id,
state: self.state,
}
}
pub fn restore_state(&mut self, state: &FlashStateSnapshot) -> Result<(), String> {
if state.data.len() != FLASH_BANK_SIZE && state.data.len() != FLASH_BANK_SIZE * 2 {
return Err(format!(
"Flash save-state length mismatch: expected {FLASH_BANK_SIZE} or {}, got {}",
FLASH_BANK_SIZE * 2,
state.data.len()
));
}
let bank_count = state.data.len() / FLASH_BANK_SIZE;
if state.bank >= bank_count {
return Err(format!(
"Flash save-state bank out of range: bank {} for {bank_count} banks",
state.bank
));
}
if self.data.len() != state.data.len() {
return Err(format!(
"Flash save-state variant mismatch: live={} bytes, state={} bytes",
self.data.len(),
state.data.len()
));
}
self.data.clone_from(&state.data);
self.bank = state.bank;
self.id_mode = state.id_mode;
self.id = state.id;
self.state = state.state;
Ok(())
}
pub fn restore(&mut self, data: &[u8]) {
let n = data.len().min(self.data.len());
self.data[..n].copy_from_slice(&data[..n]);
}
}
#[cfg(test)]
mod tests {
    // These tests drive `Flash` only through raw `write` sequences, so each statement's
    // order matters: every command is an unlock pair followed by a command byte.
    use super::*;

    /// Writes the two-step unlock sequence (0xAA to 0x5555, then 0x55 to 0x2AAA) that the
    /// state machine requires before a command byte.
    fn magic_prefix(flash: &mut Flash) {
        flash.write(0x5555, 0xAA);
        flash.write(0x2AAA, 0x55);
    }

    // A new chip is fully erased, at both ends of the bank.
    #[test]
    fn fresh_flash_reads_as_0xff() {
        let flash = Flash::new(SaveType::Flash64K);
        assert_eq!(flash.read(0), 0xFF);
        assert_eq!(flash.read(0xFFFF), 0xFF);
    }

    // Command 0x90 enters ID mode (offsets 0/1 report the chip ID); 0xF0 leaves it.
    #[test]
    fn id_mode_returns_chip_ids_for_64k() {
        let mut flash = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0x90);
        assert_eq!(flash.read(0x0000), ID_64K.0);
        assert_eq!(flash.read(0x0001), ID_64K.1);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xF0);
        // Back in normal mode, offset 0 reads stored (erased) data again.
        assert_eq!(flash.read(0x0000), 0xFF);
    }

    // The 128K variant reports its own ID pair.
    #[test]
    fn id_mode_returns_chip_ids_for_128k() {
        let mut flash = Flash::new(SaveType::Flash128K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0x90);
        assert_eq!(flash.read(0x0000), ID_128K.0);
        assert_eq!(flash.read(0x0001), ID_128K.1);
    }

    // Command 0xA0 programs exactly one byte and leaves its neighbours erased.
    #[test]
    fn write_byte_command_programs_single_byte() {
        let mut flash = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x1234, 0x42);
        assert_eq!(flash.read(0x1234), 0x42);
        assert_eq!(flash.read(0x1233), 0xFF);
        assert_eq!(flash.read(0x1235), 0xFF);
    }

    // Reprogramming without an erase ANDs the new value in: 0xF0 & 0x33 == 0x30.
    #[test]
    fn programming_only_clears_bits() {
        let mut flash = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x100, 0xF0);
        assert_eq!(flash.read(0x100), 0xF0);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x100, 0x33);
        assert_eq!(flash.read(0x100), 0x30);
    }

    // Programs bytes in sectors 0x1000 and 0x2000, then sector-erases 0x1000
    // (unlock, 0x80, unlock, 0x30 at the sector address). Only that sector is erased.
    #[test]
    fn sector_erase_clears_only_target_sector() {
        let mut flash = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x1000, 0x11);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x2000, 0x22);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0x80);
        magic_prefix(&mut flash);
        flash.write(0x1000, 0x30);
        assert_eq!(flash.read(0x1000), 0xFF, "target sector must be erased");
        assert_eq!(
            flash.read(0x1FFF),
            0xFF,
            "target sector tail must be erased"
        );
        assert_eq!(flash.read(0x2000), 0x22, "other sector must be untouched");
    }

    // Chip erase: unlock, 0x80, unlock, 0x10 at 0x5555.
    #[test]
    fn chip_erase_clears_all_data() {
        let mut flash = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x0042, 0x77);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0x80);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0x10);
        assert_eq!(flash.read(0x0042), 0xFF);
    }

    // On the 128K chip, 0xB0 followed by a write to offset 0 selects a bank. Data written
    // in one bank is invisible from the other.
    #[test]
    fn bank_switch_isolates_two_banks_for_128k() {
        let mut flash = Flash::new(SaveType::Flash128K);
        // Bank 0: program 0x11 at offset 0.
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x0000, 0x11);
        // Switch to bank 1.
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xB0);
        flash.write(0x0000, 0x01);
        assert_eq!(flash.read(0x0000), 0xFF, "bank 1 must read as erased");
        // Bank 1: program 0x22 at offset 0.
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x0000, 0x22);
        assert_eq!(flash.read(0x0000), 0x22);
        // Switch back to bank 0 and see the original byte.
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xB0);
        flash.write(0x0000, 0x00);
        assert_eq!(flash.read(0x0000), 0x11);
    }

    // On the 64K chip, 0xB0 is not a command. The following write at offset 0 is neither a
    // bank switch nor a program, so the byte stays erased.
    #[test]
    fn bank_switch_ignored_on_64k_chip() {
        let mut flash = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xB0);
        flash.write(0x0000, 0x01);
        assert_eq!(flash.read(0x0000), 0xFF);
    }

    // Raw storage copied out with `snapshot` can be loaded into a fresh chip with `restore`.
    #[test]
    fn snapshot_restore_round_trip() {
        let mut a = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut a);
        a.write(0x5555, 0xA0);
        a.write(0x0500, 0xCD);
        let snap = a.snapshot().to_vec();
        let mut b = Flash::new(SaveType::Flash64K);
        b.restore(&snap);
        assert_eq!(b.read(0x0500), 0xCD);
        assert_eq!(b.snapshot().len(), 64 * 1024);
    }

    // A broken unlock sequence (wrong second write) must return to idle, so the stray
    // write that follows does not program anything.
    #[test]
    fn invalid_command_resets_state_machine() {
        let mut flash = Flash::new(SaveType::Flash64K);
        flash.write(0x5555, 0xAA);
        flash.write(0x0000, 0x00);
        flash.write(0x1234, 0x99);
        assert_eq!(flash.read(0x1234), 0xFF);
    }

    // Restoring a state captured in ID mode re-enters ID mode, even after the live chip
    // has left it.
    #[test]
    fn save_state_restores_id_mode() {
        let mut flash = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0x90);
        let state = flash.capture_state();
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xF0);
        flash.restore_state(&state).expect("restore Flash state");
        assert_eq!(flash.read(0x0000), ID_64K.0);
        assert_eq!(flash.read(0x0001), ID_64K.1);
    }

    // A state captured mid-command (after 0xA0) still accepts the pending program write
    // once restored into another chip.
    #[test]
    fn save_state_restores_await_write_data_command_state() {
        let mut flash = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        let state = flash.capture_state();
        let mut restored = Flash::new(SaveType::Flash64K);
        restored.restore_state(&state).expect("restore Flash state");
        restored.write(0x2468, 0x5A);
        assert_eq!(restored.read(0x2468), 0x5A);
    }

    // Restoring a 128K state brings back the selected bank (bank 1) and both banks' data.
    #[test]
    fn save_state_restores_128k_bank_and_data() {
        let mut flash = Flash::new(SaveType::Flash128K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x0000, 0x11);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xB0);
        flash.write(0x0000, 0x01);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x0000, 0x22);
        let state = flash.capture_state();
        let mut restored = Flash::new(SaveType::Flash128K);
        restored.restore_state(&state).expect("restore Flash state");
        assert_eq!(restored.read(0x0000), 0x22);
        magic_prefix(&mut restored);
        restored.write(0x5555, 0xB0);
        restored.write(0x0000, 0x00);
        assert_eq!(restored.read(0x0000), 0x11);
    }

    // A captured state survives a JSON serialize/deserialize round trip.
    #[test]
    fn save_state_roundtrips_through_json() {
        let mut flash = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut flash);
        flash.write(0x5555, 0xA0);
        flash.write(0x3210, 0x66);
        let bytes = serde_json::to_vec(&flash.capture_state()).expect("serialize Flash state");
        let decoded: FlashStateSnapshot =
            serde_json::from_slice(&bytes).expect("deserialize Flash state");
        let mut restored = Flash::new(SaveType::Flash64K);
        restored
            .restore_state(&decoded)
            .expect("restore Flash state");
        assert_eq!(restored.read(0x3210), 0x66);
    }

    // Restoring a 64K state into a 128K chip must fail and leave the live chip untouched:
    // same size, still on bank 1, with its programmed byte intact.
    #[test]
    fn save_state_rejects_flash_variant_mismatch_without_mutating() {
        let mut flash64 = Flash::new(SaveType::Flash64K);
        magic_prefix(&mut flash64);
        flash64.write(0x5555, 0xA0);
        flash64.write(0x3210, 0x66);
        let state = flash64.capture_state();
        let mut flash128 = Flash::new(SaveType::Flash128K);
        magic_prefix(&mut flash128);
        flash128.write(0x5555, 0xB0);
        flash128.write(0x0000, 0x01);
        magic_prefix(&mut flash128);
        flash128.write(0x5555, 0xA0);
        flash128.write(0x3210, 0x99);
        let result = flash128.restore_state(&state);
        assert!(result.is_err());
        assert_eq!(flash128.size(), FLASH_BANK_SIZE * 2);
        assert_eq!(flash128.read(0x3210), 0x99);
    }
}