use crate::{HopeAgent, LearningConfig, LearningEngine};
use serde::{Deserialize, Serialize};
use std::fs;
use std::io::{Read, Write};
use std::path::Path;
/// Errors that can be returned by the persistence helpers in this module.
#[derive(Debug)]
pub enum PersistenceError {
/// A file-system or other I/O operation failed; wraps the original error.
Io(std::io::Error),
/// Encoding a value failed. Also the variant produced when a
/// `serde_json::Error` is converted through `From`.
Serialization(String),
/// Decoding bytes back into a value failed; carries the decoder's message.
Deserialization(String),
/// The persisted data or directory layout was not usable
/// (for example, no checkpoint files were found).
InvalidFormat(String),
/// Compression-related failure. Not constructed anywhere in the visible
/// source; presumably reserved for a real compression backend.
Compression(String),
}
impl std::fmt::Display for PersistenceError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PersistenceError::Io(e) => write!(f, "IO error: {}", e),
PersistenceError::Serialization(e) => write!(f, "Serialization error: {}", e),
PersistenceError::Deserialization(e) => write!(f, "Deserialization error: {}", e),
PersistenceError::InvalidFormat(e) => write!(f, "Invalid format: {}", e),
PersistenceError::Compression(e) => write!(f, "Compression error: {}", e),
}
}
}
impl std::error::Error for PersistenceError {}
impl From<std::io::Error> for PersistenceError {
fn from(e: std::io::Error) -> Self {
PersistenceError::Io(e)
}
}
impl From<serde_json::Error> for PersistenceError {
fn from(e: serde_json::Error) -> Self {
PersistenceError::Serialization(e.to_string())
}
}
/// On-disk / in-memory encoding requested for persisted state.
///
/// NOTE: the serialization helpers in this file encode every variant with
/// `serde_json`; `Binary` and `MessagePack` currently differ from `Json`
/// only in that they never pretty-print.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PersistenceFormat {
/// JSON text (the default); may be pretty-printed.
#[default]
Json,
/// Nominally a binary encoding; emitted as compact JSON by this file.
Binary,
/// Nominally MessagePack; emitted as compact JSON by this file.
MessagePack,
}
/// Knobs controlling how state is encoded when saved and decoded when loaded.
///
/// The same options must be supplied on both sides of a round trip.
#[derive(Debug, Clone)]
pub struct PersistenceOptions {
/// Encoding to use.
pub format: PersistenceFormat,
/// Pretty-print the output; only consulted when `format` is `Json`.
pub pretty: bool,
/// Pass the encoded bytes through the module's compress/decompress helpers.
pub compress: bool,
}
/// Default options: pretty-printed, uncompressed JSON.
impl Default for PersistenceOptions {
    fn default() -> Self {
        let (format, pretty, compress) = (PersistenceFormat::Json, true, false);
        Self {
            format,
            pretty,
            compress,
        }
    }
}
impl PersistenceOptions {
    /// Preset aimed at small output: `Binary` format, no pretty-printing,
    /// compression enabled.
    pub fn compact() -> Self {
        let format = PersistenceFormat::Binary;
        Self {
            format,
            pretty: false,
            compress: true,
        }
    }

    /// Preset aimed at human inspection: pretty-printed, uncompressed JSON.
    pub fn readable() -> Self {
        let format = PersistenceFormat::Json;
        Self {
            format,
            pretty: true,
            compress: false,
        }
    }
}
/// Save/load interface for types that can be persisted to files or byte
/// buffers.
///
/// Each operation comes in two flavours: one taking explicit
/// `PersistenceOptions`, and a shorthand without options (the
/// implementations in this file use `PersistenceOptions::default()` for
/// the shorthand).
pub trait AgentPersistence: Sized {
/// Writes `self` to `path`.
fn save_to_file(&self, path: &Path) -> Result<(), PersistenceError>;
/// Writes `self` to `path`, encoded according to `options`.
fn save_to_file_with_options(
&self,
path: &Path,
options: &PersistenceOptions,
) -> Result<(), PersistenceError>;
/// Reads a value back from `path`.
fn load_from_file(path: &Path) -> Result<Self, PersistenceError>;
/// Reads a value back from `path`, decoding according to `options`.
fn load_from_file_with_options(
path: &Path,
options: &PersistenceOptions,
) -> Result<Self, PersistenceError>;
/// Encodes `self` into a byte vector.
///
/// The signature has no error channel; the implementations in this file
/// return an empty vector when encoding fails.
fn to_bytes(&self) -> Vec<u8>;
/// Encodes `self` into a byte vector according to `options`.
fn to_bytes_with_options(
&self,
options: &PersistenceOptions,
) -> Result<Vec<u8>, PersistenceError>;
/// Decodes a value from `bytes`.
fn from_bytes(bytes: &[u8]) -> Result<Self, PersistenceError>;
/// Decodes a value from `bytes` according to `options`.
fn from_bytes_with_options(
bytes: &[u8],
options: &PersistenceOptions,
) -> Result<Self, PersistenceError>;
}
impl AgentPersistence for HopeAgent {
fn save_to_file(&self, path: &Path) -> Result<(), PersistenceError> {
self.save_to_file_with_options(path, &PersistenceOptions::default())
}
fn save_to_file_with_options(
&self,
path: &Path,
options: &PersistenceOptions,
) -> Result<(), PersistenceError> {
let state = self.save_state();
let bytes = serialize_with_options(&state, options)?;
let mut file = fs::File::create(path)?;
file.write_all(&bytes)?;
log::info!("Saved agent state to {:?}", path);
Ok(())
}
fn load_from_file(path: &Path) -> Result<Self, PersistenceError> {
Self::load_from_file_with_options(path, &PersistenceOptions::default())
}
fn load_from_file_with_options(
path: &Path,
options: &PersistenceOptions,
) -> Result<Self, PersistenceError> {
let mut file = fs::File::open(path)?;
let mut bytes = Vec::new();
file.read_to_end(&mut bytes)?;
let state: crate::hope_agent::SerializedState = deserialize_with_options(&bytes, options)?;
let mut agent = HopeAgent::new(state.config.clone());
agent.load_state(state);
log::info!("Loaded agent state from {:?}", path);
Ok(agent)
}
fn to_bytes(&self) -> Vec<u8> {
self.to_bytes_with_options(&PersistenceOptions::default())
.unwrap_or_default()
}
fn to_bytes_with_options(
&self,
options: &PersistenceOptions,
) -> Result<Vec<u8>, PersistenceError> {
let state = self.save_state();
serialize_with_options(&state, options)
}
fn from_bytes(bytes: &[u8]) -> Result<Self, PersistenceError> {
Self::from_bytes_with_options(bytes, &PersistenceOptions::default())
}
fn from_bytes_with_options(
bytes: &[u8],
options: &PersistenceOptions,
) -> Result<Self, PersistenceError> {
let state: crate::hope_agent::SerializedState = deserialize_with_options(bytes, options)?;
let mut agent = HopeAgent::new(state.config.clone());
agent.load_state(state);
Ok(agent)
}
}
impl AgentPersistence for LearningEngine {
fn save_to_file(&self, path: &Path) -> Result<(), PersistenceError> {
self.save_to_file_with_options(path, &PersistenceOptions::default())
}
fn save_to_file_with_options(
&self,
path: &Path,
options: &PersistenceOptions,
) -> Result<(), PersistenceError> {
let bytes = serialize_with_options(self, options)?;
let mut file = fs::File::create(path)?;
file.write_all(&bytes)?;
log::info!("Saved learning engine to {:?}", path);
Ok(())
}
fn load_from_file(path: &Path) -> Result<Self, PersistenceError> {
Self::load_from_file_with_options(path, &PersistenceOptions::default())
}
fn load_from_file_with_options(
path: &Path,
options: &PersistenceOptions,
) -> Result<Self, PersistenceError> {
let mut file = fs::File::open(path)?;
let mut bytes = Vec::new();
file.read_to_end(&mut bytes)?;
let engine = deserialize_with_options(&bytes, options)?;
log::info!("Loaded learning engine from {:?}", path);
Ok(engine)
}
fn to_bytes(&self) -> Vec<u8> {
self.to_bytes_with_options(&PersistenceOptions::default())
.unwrap_or_default()
}
fn to_bytes_with_options(
&self,
options: &PersistenceOptions,
) -> Result<Vec<u8>, PersistenceError> {
serialize_with_options(self, options)
}
fn from_bytes(bytes: &[u8]) -> Result<Self, PersistenceError> {
Self::from_bytes_with_options(bytes, &PersistenceOptions::default())
}
fn from_bytes_with_options(
bytes: &[u8],
options: &PersistenceOptions,
) -> Result<Self, PersistenceError> {
deserialize_with_options(bytes, options)
}
}
/// Serializable summary of a `LearningEngine`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningSnapshot {
/// Configuration the engine was running with.
pub config: LearningConfig,
/// Value of the engine's `total_updates()` counter at snapshot time.
pub total_updates: u64,
/// Q-value entries; presumably `(state, action, value)` triples — TODO confirm.
/// The `From<&LearningEngine>` conversion in this file leaves this empty.
pub q_values: Vec<(String, String, f64)>,
/// Value of the engine's `total_episodes()` counter at snapshot time.
pub episode_count: u64,
}
/// Builds a snapshot from a live engine.
///
/// Copies the engine's config and its update/episode counters. `q_values`
/// is left empty: this conversion does not export the learned values.
impl From<&LearningEngine> for LearningSnapshot {
    fn from(engine: &LearningEngine) -> Self {
        let config = engine.config().clone();
        let total_updates = engine.total_updates();
        let episode_count = engine.total_episodes();
        Self {
            config,
            total_updates,
            q_values: Vec::new(),
            episode_count,
        }
    }
}
/// Encodes `value` into bytes as directed by `options`.
///
/// Every format is currently produced with `serde_json`: `Json` honours
/// `options.pretty`, while `Binary` and `MessagePack` always emit compact
/// JSON. When `options.compress` is set, the encoded bytes are then passed
/// through `compress_bytes`.
///
/// # Errors
/// Returns an error if JSON encoding or the compression step fails.
fn serialize_with_options<T: Serialize>(
    value: &T,
    options: &PersistenceOptions,
) -> Result<Vec<u8>, PersistenceError> {
    let encoded = match options.format {
        PersistenceFormat::Json if options.pretty => serde_json::to_vec_pretty(value)?,
        PersistenceFormat::Json | PersistenceFormat::Binary | PersistenceFormat::MessagePack => {
            serde_json::to_vec(value)?
        }
    };

    if !options.compress {
        return Ok(encoded);
    }
    compress_bytes(&encoded)
}
/// Decodes a value of type `T` from `bytes` as directed by `options`.
///
/// When `options.compress` is set the input is first passed through
/// `decompress_bytes`; otherwise it is decoded in place without copying.
/// Every format is currently decoded with `serde_json`, mirroring
/// `serialize_with_options`.
///
/// # Errors
/// Returns `PersistenceError::Deserialization` if the payload is not valid
/// JSON for `T`, or propagates any error from the decompression step.
fn deserialize_with_options<T: for<'de> Deserialize<'de>>(
    bytes: &[u8],
    options: &PersistenceOptions,
) -> Result<T, PersistenceError> {
    // Borrow the caller's buffer when no decompression is needed instead of
    // cloning the whole input.
    let payload: std::borrow::Cow<'_, [u8]> = if options.compress {
        std::borrow::Cow::Owned(decompress_bytes(bytes)?)
    } else {
        std::borrow::Cow::Borrowed(bytes)
    };

    match options.format {
        PersistenceFormat::Json | PersistenceFormat::Binary | PersistenceFormat::MessagePack => {
            serde_json::from_slice(&payload)
                .map_err(|e| PersistenceError::Deserialization(e.to_string()))
        }
    }
}
/// Placeholder "compression": returns `bytes` prefixed with the two marker
/// bytes `0x1F 0x8B`.
///
/// NOTE: no actual compression is performed; the payload is stored verbatim
/// after the marker. The function never returns an error.
fn compress_bytes(bytes: &[u8]) -> Result<Vec<u8>, PersistenceError> {
    let mut framed = Vec::with_capacity(bytes.len() + 2);
    framed.extend_from_slice(&[0x1F, 0x8B]);
    framed.extend_from_slice(bytes);
    Ok(framed)
}
/// Inverse of `compress_bytes`: strips a leading `0x1F 0x8B` marker if one
/// is present.
///
/// Input that does not start with the marker is returned unchanged, so
/// un-framed data is accepted as-is. The function never returns an error.
fn decompress_bytes(bytes: &[u8]) -> Result<Vec<u8>, PersistenceError> {
    let payload = match bytes {
        [0x1F, 0x8B, rest @ ..] => rest,
        _ => bytes,
    };
    Ok(payload.to_vec())
}
/// Writes periodic agent checkpoints into a directory and prunes old ones.
pub struct CheckpointManager {
// Directory that holds the `checkpoint_<step>.json` files.
checkpoint_dir: std::path::PathBuf,
// Upper bound on the number of checkpoint files kept after each save.
max_checkpoints: usize,
// Minimum number of steps between checkpoints (set to 1000 by `new`).
checkpoint_interval: u64,
// Step at which the most recent checkpoint was saved (0 before the first save).
last_checkpoint: u64,
}
impl CheckpointManager {
    /// Creates a manager that stores checkpoints in `checkpoint_dir` and
    /// keeps at most `max_checkpoints` of them.
    ///
    /// The checkpoint interval starts at 1000 steps; use `with_interval` to
    /// change it. The directory is not created until the first save.
    pub fn new(checkpoint_dir: &Path, max_checkpoints: usize) -> Self {
        Self {
            checkpoint_dir: checkpoint_dir.to_path_buf(),
            max_checkpoints,
            checkpoint_interval: 1000,
            last_checkpoint: 0,
        }
    }

    /// Builder-style setter for the number of steps between checkpoints.
    pub fn with_interval(mut self, interval: u64) -> Self {
        self.checkpoint_interval = interval;
        self
    }

    /// Returns `true` when at least `checkpoint_interval` steps have passed
    /// since the last saved checkpoint.
    ///
    /// A `current_step` earlier than the last checkpoint (for example after
    /// a step counter reset) counts as zero elapsed steps instead of
    /// underflowing.
    pub fn should_checkpoint(&self, current_step: u64) -> bool {
        current_step.saturating_sub(self.last_checkpoint) >= self.checkpoint_interval
    }

    /// Saves `agent` as `checkpoint_<step>.json` in the checkpoint
    /// directory (creating the directory if needed), records `step` as the
    /// last checkpoint, and prunes the oldest checkpoints beyond
    /// `max_checkpoints`.
    ///
    /// # Errors
    /// Returns an error if the directory cannot be created, the agent
    /// cannot be saved, or an old checkpoint cannot be removed.
    pub fn save_checkpoint(
        &mut self,
        agent: &HopeAgent,
        step: u64,
    ) -> Result<(), PersistenceError> {
        fs::create_dir_all(&self.checkpoint_dir)?;
        let checkpoint_path = self
            .checkpoint_dir
            .join(format!("checkpoint_{}.json", step));
        agent.save_to_file(&checkpoint_path)?;
        self.last_checkpoint = step;
        self.cleanup_old_checkpoints()?;
        log::info!("Saved checkpoint at step {}", step);
        Ok(())
    }

    /// Loads the checkpoint with the highest step number.
    ///
    /// # Errors
    /// Returns `PersistenceError::InvalidFormat` when no checkpoint exists,
    /// or any error produced while listing or loading checkpoint files.
    pub fn load_latest_checkpoint(&self) -> Result<HopeAgent, PersistenceError> {
        let checkpoints = self.list_checkpoints()?;
        let latest = checkpoints
            .last()
            .ok_or_else(|| PersistenceError::InvalidFormat("No checkpoints found".to_string()))?;
        HopeAgent::load_from_file(latest)
    }

    /// Extracts the step number from a path named `checkpoint_<step>.json`.
    ///
    /// Returns `None` for any other file name, so files that were not
    /// written by `save_checkpoint` are never loaded or pruned.
    fn checkpoint_step(path: &Path) -> Option<u64> {
        path.file_name()?
            .to_str()?
            .strip_prefix("checkpoint_")?
            .strip_suffix(".json")?
            .parse()
            .ok()
    }

    /// Lists checkpoint files in ascending step order (oldest first).
    ///
    /// Ordering is by the numeric step parsed from the file name; a plain
    /// path sort would place `checkpoint_10.json` before
    /// `checkpoint_2.json`. A missing directory yields an empty list.
    fn list_checkpoints(&self) -> Result<Vec<std::path::PathBuf>, PersistenceError> {
        if !self.checkpoint_dir.exists() {
            return Ok(Vec::new());
        }

        let mut found: Vec<(u64, std::path::PathBuf)> = Vec::new();
        for entry in fs::read_dir(&self.checkpoint_dir)? {
            let path = entry?.path();
            if !path.is_file() {
                continue;
            }
            if let Some(step) = Self::checkpoint_step(&path) {
                found.push((step, path));
            }
        }

        // Tuples order by step first; the path only breaks ties between
        // differently spelled names for the same step (e.g. a leading zero).
        found.sort_unstable();
        Ok(found.into_iter().map(|(_, path)| path).collect())
    }

    /// Deletes the oldest checkpoints until at most `max_checkpoints`
    /// remain. With a limit of 0 this removes every checkpoint.
    fn cleanup_old_checkpoints(&self) -> Result<(), PersistenceError> {
        let checkpoints = self.list_checkpoints()?;
        let excess = checkpoints.len().saturating_sub(self.max_checkpoints);
        for old_checkpoint in checkpoints.iter().take(excess) {
            fs::remove_file(old_checkpoint)?;
            log::debug!("Removed old checkpoint: {:?}", old_checkpoint);
        }
        Ok(())
    }
}
// Round-trip and checkpoint tests. They write real files under the system
// temporary directory and remove them again on a best-effort basis.
#[cfg(test)]
mod tests {
use super::*;
use crate::{HopeAgent, Observation};
use std::path::PathBuf;
// Builds a path inside the system temp dir with a fixed, test-specific name.
fn temp_path(name: &str) -> PathBuf {
let mut path = std::env::temp_dir();
path.push(format!("hope_agents_test_{}", name));
path
}
// Saving to a file and loading it back preserves the step counter.
#[test]
fn test_save_and_load_hope_agent() {
let mut agent = HopeAgent::with_default_config();
for i in 0..5 {
let obs = Observation::sensor("temp", 20.0 + i as f64);
agent.step(obs);
}
let path = temp_path("agent_save_load.json");
agent.save_to_file(&path).unwrap();
assert!(path.exists());
let loaded_agent = HopeAgent::load_from_file(&path).unwrap();
assert_eq!(
loaded_agent.get_statistics().total_steps,
agent.get_statistics().total_steps
);
// Cleanup is best-effort; a failure here must not fail the test.
let _ = fs::remove_file(&path);
}
// The `compact` preset round-trips through a file when the same options
// are used for saving and loading.
#[test]
fn test_save_with_different_options() {
let agent = HopeAgent::with_default_config();
let path = temp_path("agent_compact.bin");
let options = PersistenceOptions::compact();
agent.save_to_file_with_options(&path, &options).unwrap();
assert!(path.exists());
let _loaded = HopeAgent::load_from_file_with_options(&path, &options).unwrap();
let _ = fs::remove_file(&path);
}
// In-memory round trip with default options preserves the step counter.
#[test]
fn test_to_bytes_and_from_bytes() {
let mut agent = HopeAgent::with_default_config();
let obs = Observation::sensor("temp", 25.0);
agent.step(obs);
let bytes = agent.to_bytes();
assert!(!bytes.is_empty());
let loaded_agent = HopeAgent::from_bytes(&bytes).unwrap();
assert_eq!(
loaded_agent.get_statistics().total_steps,
agent.get_statistics().total_steps
);
}
// A freshly constructed learning engine can be saved and loaded again.
#[test]
fn test_learning_engine_persistence() {
let engine = LearningEngine::new(LearningConfig::default());
let path = temp_path("learning_engine.json");
engine.save_to_file(&path).unwrap();
assert!(path.exists());
let _loaded_engine = LearningEngine::load_from_file(&path).unwrap();
let _ = fs::remove_file(&path);
}
// Interval check plus three saves that stay within the retention limit.
#[test]
fn test_checkpoint_manager() {
let checkpoint_dir = temp_path("checkpoints");
let mut manager = CheckpointManager::new(&checkpoint_dir, 3).with_interval(10);
let agent = HopeAgent::with_default_config();
assert!(manager.should_checkpoint(10));
assert!(!manager.should_checkpoint(5));
manager.save_checkpoint(&agent, 10).unwrap();
manager.save_checkpoint(&agent, 20).unwrap();
manager.save_checkpoint(&agent, 30).unwrap();
assert!(checkpoint_dir.exists());
let _ = fs::remove_dir_all(&checkpoint_dir);
}
// Saving four checkpoints with a limit of two leaves exactly two files.
#[test]
fn test_checkpoint_cleanup() {
let checkpoint_dir = temp_path("checkpoints_cleanup");
let mut manager = CheckpointManager::new(&checkpoint_dir, 2).with_interval(1);
let agent = HopeAgent::with_default_config();
manager.save_checkpoint(&agent, 1).unwrap();
manager.save_checkpoint(&agent, 2).unwrap();
manager.save_checkpoint(&agent, 3).unwrap();
manager.save_checkpoint(&agent, 4).unwrap();
let checkpoints = manager.list_checkpoints().unwrap();
assert_eq!(checkpoints.len(), 2);
let _ = fs::remove_dir_all(&checkpoint_dir);
}
// In-memory round trip with the compression flag enabled.
#[test]
fn test_roundtrip_with_compression() {
let agent = HopeAgent::with_default_config();
let options = PersistenceOptions {
format: PersistenceFormat::Json,
pretty: false,
compress: true,
};
let bytes = agent.to_bytes_with_options(&options).unwrap();
let loaded = HopeAgent::from_bytes_with_options(&bytes, &options).unwrap();
assert_eq!(
loaded.get_statistics().total_steps,
agent.get_statistics().total_steps
);
}
// Loading from a path that does not exist reports an error.
#[test]
fn test_persistence_error_handling() {
let invalid_path = PathBuf::from("/invalid/path/that/does/not/exist/agent.json");
let result = HopeAgent::load_from_file(&invalid_path);
assert!(result.is_err());
}
}