use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use super::session_group::{LearningPhase, SessionGroupId};
use super::{EpisodeTransitions, NgramStats, SelectionPerformance};
use crate::online_stats::ActionStats;
/// Version tag written into every snapshot; bump when the format changes.
pub const SNAPSHOT_VERSION: u32 = 1;
/// Unix time in whole seconds.
pub type Timestamp = u64;
/// Identifier of a single recorded session.
///
/// In practice the inner string is a zero-padded Unix timestamp
/// (see `LearningStore::generate_session_id`), so lexicographic order
/// matches chronological order.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionId(pub String);

impl SessionId {
    /// Interprets the id as a Unix timestamp in seconds.
    ///
    /// Returns `None` when the id is not a plain integer.
    pub fn timestamp(&self) -> Option<Timestamp> {
        match self.0.parse::<Timestamp>() {
            Ok(ts) => Some(ts),
            Err(_) => None,
        }
    }
}
/// Addressing scheme for stored snapshots.
///
/// Snapshots exist at three levels of aggregation: one global
/// accumulator, one aggregate per scenario, and one per session.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SnapshotKey {
    /// The single cross-scenario aggregate.
    Global,
    /// Per-scenario aggregate, keyed by scenario name.
    Scenario(String),
    /// A single recorded session within a scenario.
    Session {
        scenario: String,
        session_id: SessionId,
    },
}
impl std::fmt::Display for SnapshotKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Global => write!(f, "global"),
Self::Scenario(s) => write!(f, "scenario:{}", s),
Self::Session {
scenario,
session_id,
} => {
write!(f, "session:{}:{}", scenario, session_id.0)
}
}
}
}
/// Key/value persistence for [`LearningSnapshot`]s.
pub trait SnapshotStorage {
    /// Error type produced by the backing store.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Persists `snapshot` under `key`.
    fn save(&self, key: &SnapshotKey, snapshot: &LearningSnapshot) -> Result<(), Self::Error>;
    /// Loads the snapshot stored under `key`, or `Ok(None)` when absent.
    fn load(&self, key: &SnapshotKey) -> Result<Option<LearningSnapshot>, Self::Error>;
    /// Removes the entry under `key`; returns whether anything existed.
    fn delete(&self, key: &SnapshotKey) -> Result<bool, Self::Error>;
    /// Reports whether an entry exists under `key`.
    fn exists(&self, key: &SnapshotKey) -> Result<bool, Self::Error>;
}
/// Time-ordered queries over a scenario's per-session snapshots.
pub trait TimeSeriesQuery {
    /// Error type produced by the backing store.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Snapshots whose session timestamp lies in `[from, to]`
    /// (inclusive on both ends).
    fn query_range(
        &self,
        scenario: &str,
        from: Timestamp,
        to: Timestamp,
    ) -> Result<Vec<LearningSnapshot>, Self::Error>;
    /// Up to `limit` of the most recent snapshots.
    fn query_latest(
        &self,
        scenario: &str,
        limit: usize,
    ) -> Result<Vec<LearningSnapshot>, Self::Error>;
    /// All snapshots recorded at or after `since`.
    fn query_since(
        &self,
        scenario: &str,
        since: Timestamp,
    ) -> Result<Vec<LearningSnapshot>, Self::Error> {
        self.query_range(scenario, since, u64::MAX)
    }
    /// Session ids recorded for `scenario`.
    fn list_sessions(&self, scenario: &str) -> Result<Vec<SessionId>, Self::Error>;
}
/// A complete, serializable dump of the learned statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningSnapshot {
    /// Format version; written as [`SNAPSHOT_VERSION`] for new snapshots.
    pub version: u32,
    /// Descriptive metadata (scenario, counters, creation time).
    pub metadata: SnapshotMetadata,
    /// Success/failure transition and episode counters.
    pub episode_transitions: EpisodeTransitions,
    /// Trigram/quadgram `(success, failure)` counters.
    pub ngram_stats: NgramStats,
    /// Per-strategy selection performance counters.
    pub selection_performance: SelectionPerformance,
    /// Stats keyed by a string pair — presumably (context, action); verify against callers.
    ///
    /// NOTE(review): serde_json cannot serialize maps with non-string keys,
    /// so `to_json` will fail at runtime once this map is non-empty —
    /// confirm, and consider a `#[serde(with = ...)]` adapter.
    pub contextual_stats: HashMap<(String, String), ActionStats>,
    /// Stats keyed by action name.
    pub action_stats: HashMap<String, ActionStats>,
}
/// Descriptive metadata attached to every snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotMetadata {
    /// Scenario this snapshot belongs to, when known.
    pub scenario_name: Option<String>,
    /// Free-form description of the task, when provided.
    pub task_description: Option<String>,
    /// Creation time as Unix seconds.
    pub created_at: u64,
    /// Number of sessions folded into this snapshot (1 for a fresh one).
    pub session_count: u32,
    /// Total episodes observed across the folded sessions.
    pub total_episodes: u32,
    /// Total actions observed across the folded sessions.
    pub total_actions: u32,
    /// Learning phase; omitted from JSON when unset (keeps old files readable).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub phase: Option<LearningPhase>,
    /// Owning session group; omitted from JSON when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub group_id: Option<SessionGroupId>,
}
impl Default for SnapshotMetadata {
    /// Fresh metadata stamped with the current Unix time and a session
    /// count of one; every optional field starts empty.
    fn default() -> Self {
        // Falls back to 0 if the clock reads before the Unix epoch.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |d| d.as_secs());
        Self {
            scenario_name: None,
            task_description: None,
            created_at: now,
            session_count: 1,
            total_episodes: 0,
            total_actions: 0,
            phase: None,
            group_id: None,
        }
    }
}
impl SnapshotMetadata {
    /// Builder-style setter: records the scenario name.
    pub fn with_scenario(self, name: impl Into<String>) -> Self {
        Self {
            scenario_name: Some(name.into()),
            ..self
        }
    }
    /// Builder-style setter: records the task description.
    pub fn with_task(self, desc: impl Into<String>) -> Self {
        Self {
            task_description: Some(desc.into()),
            ..self
        }
    }
    /// Builder-style setter: records the learning phase.
    pub fn with_phase(self, phase: LearningPhase) -> Self {
        Self {
            phase: Some(phase),
            ..self
        }
    }
    /// Builder-style setter: records the session group id.
    pub fn with_group_id(self, group_id: SessionGroupId) -> Self {
        Self {
            group_id: Some(group_id),
            ..self
        }
    }
}
impl LearningSnapshot {
    /// A snapshot at the current format version with default metadata
    /// and no recorded statistics.
    pub fn empty() -> Self {
        Self {
            version: SNAPSHOT_VERSION,
            metadata: SnapshotMetadata::default(),
            episode_transitions: Default::default(),
            ngram_stats: Default::default(),
            selection_performance: Default::default(),
            contextual_stats: HashMap::default(),
            action_stats: HashMap::default(),
        }
    }
    /// Returns `self` with its metadata replaced by `metadata`.
    pub fn with_metadata(self, metadata: SnapshotMetadata) -> Self {
        Self { metadata, ..self }
    }
}
/// How counters from multiple snapshots are weighted when merged.
#[derive(Debug, Clone, Copy, Default)]
pub enum MergeStrategy {
    /// Sum counters with equal weight (the default).
    #[default]
    Additive,
    /// Exponentially down-weight older snapshots: a snapshot `k` positions
    /// before the newest gets weight `0.5^(k / half_life_sessions)`.
    TimeDecay {
        half_life_sessions: u32,
    },
    /// Weight each snapshot by `1 + success_episodes / total_episodes`.
    SuccessWeighted,
}
/// Snapshot storage backed by a directory tree of JSON files.
///
/// Layout under `base_dir`: `global_stats.json`,
/// `scenarios/<name>/stats.json`, and
/// `scenarios/<name>/sessions/<id>/stats.json`.
#[derive(Clone)]
pub struct FileSystemStorage {
    // Root directory; created by `new` if missing.
    base_dir: PathBuf,
}
impl FileSystemStorage {
    /// Creates a storage rooted at `base_dir`, creating the directory
    /// (and any missing parents) on the way.
    pub fn new(base_dir: impl AsRef<Path>) -> std::io::Result<Self> {
        let root = base_dir.as_ref().to_path_buf();
        fs::create_dir_all(&root)?;
        Ok(Self { base_dir: root })
    }
    /// Root directory that all snapshot files live under.
    pub fn base_dir(&self) -> &Path {
        self.base_dir.as_path()
    }
    /// Maps a snapshot key to its on-disk JSON file.
    ///
    /// NOTE(review): scenario and session names are used verbatim as path
    /// components — assumes callers pass sanitized names; verify.
    fn key_to_path(&self, key: &SnapshotKey) -> PathBuf {
        match key {
            SnapshotKey::Global => self.base_dir.join("global_stats.json"),
            SnapshotKey::Scenario(scenario) => {
                let mut path = self.base_dir.join("scenarios");
                path.push(scenario);
                path.push("stats.json");
                path
            }
            SnapshotKey::Session { scenario, session_id } => {
                let mut path = self.sessions_dir(scenario);
                path.push(&session_id.0);
                path.push("stats.json");
                path
            }
        }
    }
    /// Directory holding the per-session subdirectories of `scenario`.
    fn sessions_dir(&self, scenario: &str) -> PathBuf {
        let mut dir = self.base_dir.join("scenarios");
        dir.push(scenario);
        dir.push("sessions");
        dir
    }
}
impl SnapshotStorage for FileSystemStorage {
    type Error = std::io::Error;

    /// Writes the snapshot as JSON, creating parent directories first.
    fn save(&self, key: &SnapshotKey, snapshot: &LearningSnapshot) -> Result<(), Self::Error> {
        let path = self.key_to_path(key);
        if let Some(dir) = path.parent() {
            fs::create_dir_all(dir)?;
        }
        snapshot.save_json(&path)
    }

    /// Reads the snapshot back, or `Ok(None)` when no file exists.
    fn load(&self, key: &SnapshotKey) -> Result<Option<LearningSnapshot>, Self::Error> {
        let path = self.key_to_path(key);
        if path.exists() {
            LearningSnapshot::load_json(&path).map(Some)
        } else {
            Ok(None)
        }
    }

    /// Deletes the snapshot file; `Ok(false)` when there was none.
    fn delete(&self, key: &SnapshotKey) -> Result<bool, Self::Error> {
        let path = self.key_to_path(key);
        if !path.exists() {
            return Ok(false);
        }
        fs::remove_file(&path)?;
        Ok(true)
    }

    /// Whether a snapshot file exists for `key`.
    fn exists(&self, key: &SnapshotKey) -> Result<bool, Self::Error> {
        let present = self.key_to_path(key).exists();
        Ok(present)
    }
}
impl TimeSeriesQuery for FileSystemStorage {
    type Error = std::io::Error;

    /// Loads every session whose id parses as a timestamp in `[from, to]`.
    /// Sessions with non-numeric ids are skipped.
    fn query_range(
        &self,
        scenario: &str,
        from: Timestamp,
        to: Timestamp,
    ) -> Result<Vec<LearningSnapshot>, Self::Error> {
        let mut snapshots = Vec::new();
        for session_id in self.list_sessions(scenario)? {
            let in_range = session_id
                .timestamp()
                .map_or(false, |ts| ts >= from && ts <= to);
            if !in_range {
                continue;
            }
            let key = SnapshotKey::Session {
                scenario: scenario.to_string(),
                session_id,
            };
            if let Some(snapshot) = self.load(&key)? {
                snapshots.push(snapshot);
            }
        }
        Ok(snapshots)
    }

    /// Loads up to `limit` sessions, newest first.
    fn query_latest(
        &self,
        scenario: &str,
        limit: usize,
    ) -> Result<Vec<LearningSnapshot>, Self::Error> {
        let mut sessions = self.list_sessions(scenario)?;
        // Ids are zero-padded timestamps, so reverse lexicographic
        // order is reverse chronological order.
        sessions.sort_by(|a, b| b.0.cmp(&a.0));
        sessions.truncate(limit);
        let mut snapshots = Vec::with_capacity(sessions.len());
        for session_id in sessions {
            let key = SnapshotKey::Session {
                scenario: scenario.to_string(),
                session_id,
            };
            if let Some(snapshot) = self.load(&key)? {
                snapshots.push(snapshot);
            }
        }
        Ok(snapshots)
    }

    /// Lists session directories of `scenario`, sorted ascending by id.
    /// Returns an empty list when the scenario has no sessions yet.
    fn list_sessions(&self, scenario: &str) -> Result<Vec<SessionId>, Self::Error> {
        let dir = self.sessions_dir(scenario);
        if !dir.exists() {
            return Ok(Vec::new());
        }
        let mut sessions = Vec::new();
        for entry in fs::read_dir(dir)? {
            let entry = entry?;
            if !entry.file_type()?.is_dir() {
                continue;
            }
            // Non-UTF-8 directory names are silently ignored.
            if let Some(name) = entry.file_name().to_str() {
                sessions.push(SessionId(name.to_owned()));
            }
        }
        sessions.sort_by(|a, b| a.0.cmp(&b.0));
        Ok(sessions)
    }
}
/// High-level facade over [`FileSystemStorage`] that keeps per-session,
/// per-scenario, and global snapshots in sync.
#[derive(Clone)]
pub struct LearningStore {
    // Filesystem-backed snapshot storage.
    storage: FileSystemStorage,
}
impl LearningStore {
    /// Opens (or creates) a learning store rooted at `base_dir`.
    pub fn new(base_dir: impl AsRef<Path>) -> std::io::Result<Self> {
        let storage = FileSystemStorage::new(base_dir)?;
        Ok(Self { storage })
    }
    /// Default on-disk location: `<platform data dir>/swarm-engine/learning`,
    /// falling back to the current directory when no data dir is known.
    pub fn default_path() -> PathBuf {
        dirs::data_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join("swarm-engine")
            .join("learning")
    }
    /// Direct access to the underlying filesystem storage.
    pub fn storage(&self) -> &FileSystemStorage {
        &self.storage
    }
    /// Loads the global aggregate; `NotFound` if it was never written.
    pub fn load_global(&self) -> std::io::Result<LearningSnapshot> {
        self.storage.load(&SnapshotKey::Global)?.ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::NotFound, "global stats not found")
        })
    }
    /// Overwrites the global aggregate.
    pub fn save_global(&self, snapshot: &LearningSnapshot) -> std::io::Result<()> {
        self.storage.save(&SnapshotKey::Global, snapshot)
    }
    /// Loads a scenario aggregate; `NotFound` if it was never written.
    pub fn load_scenario(&self, scenario: &str) -> std::io::Result<LearningSnapshot> {
        self.storage
            .load(&SnapshotKey::Scenario(scenario.to_string()))?
            .ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::NotFound, "scenario stats not found")
            })
    }
    /// Overwrites a scenario aggregate.
    pub fn save_scenario(
        &self,
        scenario: &str,
        snapshot: &LearningSnapshot,
    ) -> std::io::Result<()> {
        self.storage
            .save(&SnapshotKey::Scenario(scenario.to_string()), snapshot)
    }
    /// Persists `snapshot` as a new session of `scenario`, writes a
    /// `meta.json` sidecar with just the metadata, and additively folds
    /// the snapshot into the scenario-level and global aggregates.
    ///
    /// Returns the freshly generated session id.
    pub fn save_session(
        &self,
        scenario: &str,
        snapshot: &LearningSnapshot,
    ) -> std::io::Result<SessionId> {
        let session_id = self.generate_session_id(scenario)?;
        let key = SnapshotKey::Session {
            scenario: scenario.to_string(),
            session_id: session_id.clone(),
        };
        self.storage.save(&key, snapshot)?;
        // Metadata-only sidecar next to stats.json, for cheap inspection.
        let meta_path = self.storage.key_to_path(&key).with_file_name("meta.json");
        let meta_json = serde_json::to_string_pretty(&snapshot.metadata)?;
        fs::write(meta_path, meta_json)?;
        self.merge_into_scenario(scenario, snapshot)?;
        self.merge_into_global(snapshot)?;
        Ok(session_id)
    }
    /// Session ids recorded for `scenario`, oldest first.
    pub fn list_sessions(&self, scenario: &str) -> std::io::Result<Vec<SessionId>> {
        self.storage.list_sessions(scenario)
    }
    /// Loads one session's snapshot; `NotFound` if it does not exist.
    pub fn load_session(
        &self,
        scenario: &str,
        session_id: &SessionId,
    ) -> std::io::Result<LearningSnapshot> {
        let key = SnapshotKey::Session {
            scenario: scenario.to_string(),
            session_id: session_id.clone(),
        };
        self.storage
            .load(&key)?
            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "session not found"))
    }
    /// Session snapshots whose timestamps fall in `[from, to]`.
    pub fn query_range(
        &self,
        scenario: &str,
        from: Timestamp,
        to: Timestamp,
    ) -> std::io::Result<Vec<LearningSnapshot>> {
        self.storage.query_range(scenario, from, to)
    }
    /// Up to `limit` of the most recent session snapshots.
    pub fn query_latest(
        &self,
        scenario: &str,
        limit: usize,
    ) -> std::io::Result<Vec<LearningSnapshot>> {
        self.storage.query_latest(scenario, limit)
    }
    /// Merges `snapshots` into one using `strategy`; see [`merge_snapshots`].
    pub fn merge(
        &self,
        snapshots: &[LearningSnapshot],
        strategy: MergeStrategy,
    ) -> LearningSnapshot {
        merge_snapshots(snapshots, strategy)
    }
    /// Produces a fresh, unused session id for `scenario`.
    ///
    /// Ids are zero-padded Unix timestamps so lexicographic order matches
    /// chronological order. When a session with the current second already
    /// exists (two saves within the same second), the timestamp is bumped
    /// until a free id is found. Previously such a collision silently
    /// overwrote the earlier session and double-merged its counts into
    /// the scenario and global aggregates.
    fn generate_session_id(&self, scenario: &str) -> std::io::Result<SessionId> {
        let mut timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        loop {
            let session_id = SessionId(format!("{:010}", timestamp));
            let key = SnapshotKey::Session {
                scenario: scenario.to_string(),
                session_id: session_id.clone(),
            };
            if !self.storage.exists(&key)? {
                return Ok(session_id);
            }
            timestamp += 1;
        }
    }
    /// Additively folds `snapshot` into the scenario aggregate, creating
    /// it on first use.
    fn merge_into_scenario(
        &self,
        scenario: &str,
        snapshot: &LearningSnapshot,
    ) -> std::io::Result<()> {
        let existing = self
            .storage
            .load(&SnapshotKey::Scenario(scenario.to_string()))?;
        let merged = match existing {
            Some(existing) => {
                merge_snapshots(&[existing, snapshot.clone()], MergeStrategy::Additive)
            }
            None => snapshot.clone(),
        };
        self.save_scenario(scenario, &merged)
    }
    /// Additively folds `snapshot` into the global aggregate, creating
    /// it on first use.
    fn merge_into_global(&self, snapshot: &LearningSnapshot) -> std::io::Result<()> {
        let existing = self.storage.load(&SnapshotKey::Global)?;
        let merged = match existing {
            Some(existing) => {
                merge_snapshots(&[existing, snapshot.clone()], MergeStrategy::Additive)
            }
            None => snapshot.clone(),
        };
        self.save_global(&merged)
    }
    /// Loads the offline model for `scenario`; `NotFound` when absent,
    /// `InvalidData` when the file does not parse.
    pub fn load_offline_model(
        &self,
        scenario: &str,
    ) -> std::io::Result<super::offline::OfflineModel> {
        let path = self.offline_model_path(scenario);
        if !path.exists() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "offline model not found",
            ));
        }
        let json = fs::read_to_string(&path)?;
        serde_json::from_str(&json)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
    }
    /// Writes the offline model for `scenario` as pretty JSON.
    pub fn save_offline_model(
        &self,
        scenario: &str,
        model: &super::offline::OfflineModel,
    ) -> std::io::Result<()> {
        let path = self.offline_model_path(scenario);
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
        let json = serde_json::to_string_pretty(model)?;
        fs::write(path, json)
    }
    /// Runs the offline analyzer over the latest `session_limit` sessions
    /// and persists the resulting model. `NotFound` when no sessions exist.
    pub fn run_offline_learning(
        &self,
        scenario: &str,
        session_limit: usize,
    ) -> std::io::Result<super::offline::OfflineModel> {
        let snapshots = self.query_latest(scenario, session_limit)?;
        if snapshots.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "no sessions found for offline learning",
            ));
        }
        let analyzer = super::offline::OfflineAnalyzer::new(&snapshots);
        let model = analyzer.analyze();
        self.save_offline_model(scenario, &model)?;
        Ok(model)
    }
    /// Path of the offline model file for `scenario`.
    fn offline_model_path(&self, scenario: &str) -> PathBuf {
        self.storage
            .base_dir()
            .join("scenarios")
            .join(scenario)
            .join("offline_model.json")
    }
}
pub fn merge_snapshots(
snapshots: &[LearningSnapshot],
strategy: MergeStrategy,
) -> LearningSnapshot {
if snapshots.is_empty() {
return LearningSnapshot::empty();
}
if snapshots.len() == 1 {
return snapshots[0].clone();
}
let mut result = LearningSnapshot::empty();
let weights: Vec<f64> = match strategy {
MergeStrategy::Additive => vec![1.0; snapshots.len()],
MergeStrategy::TimeDecay { half_life_sessions } => {
let half_life = half_life_sessions as f64;
snapshots
.iter()
.enumerate()
.map(|(i, _)| {
let age = (snapshots.len() - 1 - i) as f64;
0.5_f64.powf(age / half_life)
})
.collect()
}
MergeStrategy::SuccessWeighted => snapshots
.iter()
.map(|s| {
let total = s.metadata.total_episodes as f64;
let success = s.episode_transitions.success_episodes as f64;
if total == 0.0 {
1.0
} else {
1.0 + success / total
}
})
.collect(),
};
result.metadata = SnapshotMetadata {
scenario_name: snapshots
.last()
.and_then(|s| s.metadata.scenario_name.clone()),
task_description: snapshots
.last()
.and_then(|s| s.metadata.task_description.clone()),
created_at: SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0),
session_count: snapshots.iter().map(|s| s.metadata.session_count).sum(),
total_episodes: snapshots.iter().map(|s| s.metadata.total_episodes).sum(),
total_actions: snapshots.iter().map(|s| s.metadata.total_actions).sum(),
phase: None,
group_id: None,
};
for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
for (key, &count) in &snapshot.episode_transitions.success_transitions {
let weighted_count = (count as f64 * weight).round() as u32;
*result
.episode_transitions
.success_transitions
.entry(key.clone())
.or_default() += weighted_count;
}
for (key, &count) in &snapshot.episode_transitions.failure_transitions {
let weighted_count = (count as f64 * weight).round() as u32;
*result
.episode_transitions
.failure_transitions
.entry(key.clone())
.or_default() += weighted_count;
}
result.episode_transitions.success_episodes +=
(snapshot.episode_transitions.success_episodes as f64 * weight).round() as u32;
result.episode_transitions.failure_episodes +=
(snapshot.episode_transitions.failure_episodes as f64 * weight).round() as u32;
}
for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
let entry = result
.ngram_stats
.trigrams
.entry(key.clone())
.or_insert((0, 0));
entry.0 += (success as f64 * weight).round() as u32;
entry.1 += (failure as f64 * weight).round() as u32;
}
for (key, &(success, failure)) in &snapshot.ngram_stats.quadgrams {
let entry = result
.ngram_stats
.quadgrams
.entry(key.clone())
.or_insert((0, 0));
entry.0 += (success as f64 * weight).round() as u32;
entry.1 += (failure as f64 * weight).round() as u32;
}
}
for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
for (key, stats) in &snapshot.contextual_stats {
let entry = result.contextual_stats.entry(key.clone()).or_default();
entry.visits += (stats.visits as f64 * weight).round() as u32;
entry.successes += (stats.successes as f64 * weight).round() as u32;
entry.failures += (stats.failures as f64 * weight).round() as u32;
}
}
for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
for (key, stats) in &snapshot.action_stats {
let entry = result.action_stats.entry(key.clone()).or_default();
entry.visits += (stats.visits as f64 * weight).round() as u32;
entry.successes += (stats.successes as f64 * weight).round() as u32;
entry.failures += (stats.failures as f64 * weight).round() as u32;
}
}
for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
for (strat, stats) in &snapshot.selection_performance.strategy_stats {
let entry = result
.selection_performance
.strategy_stats
.entry(strat.clone())
.or_default();
entry.visits += (stats.visits as f64 * weight).round() as u32;
entry.successes += (stats.successes as f64 * weight).round() as u32;
entry.failures += (stats.failures as f64 * weight).round() as u32;
entry.episodes_success += (stats.episodes_success as f64 * weight).round() as u32;
entry.episodes_failure += (stats.episodes_failure as f64 * weight).round() as u32;
}
}
result
}
impl LearningSnapshot {
    /// Serializes the snapshot as pretty-printed JSON.
    pub fn to_json(&self) -> serde_json::Result<String> {
        serde_json::to_string_pretty(self)
    }
    /// Parses a snapshot from a JSON string.
    pub fn from_json(json: &str) -> serde_json::Result<Self> {
        serde_json::from_str(json)
    }
    /// Writes the snapshot to `path` as JSON.
    pub fn save_json(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
        fs::write(path, self.to_json()?)
    }
    /// Reads a snapshot back from a JSON file at `path`.
    ///
    /// Malformed JSON is reported as `ErrorKind::InvalidData`.
    pub fn load_json(path: impl AsRef<Path>) -> std::io::Result<Self> {
        let contents = fs::read_to_string(path)?;
        Self::from_json(&contents)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Round-trips an empty snapshot through JSON.
    #[test]
    fn test_snapshot_serialization() {
        let original = LearningSnapshot::empty()
            .with_metadata(SnapshotMetadata::default().with_scenario("test"));
        let restored = LearningSnapshot::from_json(&original.to_json().unwrap()).unwrap();
        assert_eq!(restored.version, SNAPSHOT_VERSION);
        assert_eq!(restored.metadata.scenario_name, Some("test".to_string()));
    }

    /// Scenario-level save followed by load returns the same metadata.
    #[test]
    fn test_learning_store_save_load() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();
        let snapshot = LearningSnapshot::empty()
            .with_metadata(SnapshotMetadata::default().with_scenario("troubleshooting"));
        store.save_scenario("troubleshooting", &snapshot).unwrap();
        let loaded = store.load_scenario("troubleshooting").unwrap();
        assert_eq!(
            loaded.metadata.scenario_name,
            Some("troubleshooting".to_string())
        );
    }

    /// Additive merge sums transition counts and episode totals.
    #[test]
    fn test_merge_additive() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();
        let edge = ("A".to_string(), "B".to_string());
        let mut first = LearningSnapshot::empty();
        first
            .episode_transitions
            .success_transitions
            .insert(edge.clone(), 5);
        first.metadata.total_episodes = 10;
        let mut second = LearningSnapshot::empty();
        second
            .episode_transitions
            .success_transitions
            .insert(edge.clone(), 3);
        second.metadata.total_episodes = 5;
        let merged = store.merge(&[first, second], MergeStrategy::Additive);
        assert_eq!(
            merged.episode_transitions.success_transitions.get(&edge),
            Some(&8)
        );
        assert_eq!(merged.metadata.total_episodes, 15);
    }

    /// With a one-session half-life the older snapshot is halved:
    /// 100 * 0.5 + 100 * 1.0 = 150.
    #[test]
    fn test_merge_time_decay() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();
        let edge = ("A".to_string(), "B".to_string());
        let mut older = LearningSnapshot::empty();
        older
            .episode_transitions
            .success_transitions
            .insert(edge.clone(), 100);
        let mut newer = LearningSnapshot::empty();
        newer
            .episode_transitions
            .success_transitions
            .insert(edge.clone(), 100);
        let merged = store.merge(
            &[older, newer],
            MergeStrategy::TimeDecay {
                half_life_sessions: 1,
            },
        );
        assert_eq!(
            merged.episode_transitions.success_transitions.get(&edge),
            Some(&150)
        );
    }

    /// Saving a session yields a non-empty id that list_sessions reports.
    #[test]
    fn test_session_management() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();
        let snapshot = LearningSnapshot::empty()
            .with_metadata(SnapshotMetadata::default().with_scenario("test"));
        let session_id = store.save_session("test", &snapshot).unwrap();
        assert!(!session_id.0.is_empty());
        let sessions = store.list_sessions("test").unwrap();
        assert_eq!(sessions, vec![session_id]);
    }
}