use std::collections::{HashMap, HashSet, VecDeque};
use std::time::{Instant, SystemTime, UNIX_EPOCH};

use rlmesh_proto::env::v1::EpisodeMetadata;
use rlmesh_proto::spaces::v1::MetaMap;
use uuid::Uuid;
/// Current wall-clock time as nanoseconds since the Unix epoch.
///
/// Saturates to `i64::MAX` when the nanosecond count does not fit in an `i64`, and yields `0`
/// when the system clock reports a time earlier than the epoch.
fn unix_nanos_now() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_nanos()).unwrap_or(i64::MAX),
        Err(_) => 0,
    }
}
struct Episode {
id: String,
seed: Option<i64>,
env_index: i32,
step_count: i64,
cumulative_reward: f64,
start_time: Instant,
start_timestamp_ns: i64,
}
impl Episode {
fn new(env_index: i32, seed: Option<i64>) -> Self {
let start_time = Instant::now();
let start_timestamp_ns = unix_nanos_now();
Self {
id: Uuid::new_v4().to_string(),
seed,
env_index,
step_count: 0,
cumulative_reward: 0.0,
start_time,
start_timestamp_ns,
}
}
fn record_step(&mut self, reward: f64) {
self.step_count += 1;
self.cumulative_reward += reward;
}
fn complete(
self,
terminated: bool,
truncated: bool,
final_info: Option<MetaMap>,
) -> EpisodeMetadata {
let end_timestamp_ns = unix_nanos_now();
let duration_ms = self.start_time.elapsed().as_millis() as i64;
EpisodeMetadata {
episode_id: self.id,
seed: self.seed,
env_index: self.env_index,
step_count: self.step_count,
cumulative_reward: self.cumulative_reward,
terminated,
truncated,
start_timestamp_ns: self.start_timestamp_ns,
end_timestamp_ns,
duration_ms,
final_info,
}
}
}
/// Upper bound on buffered interrupted episodes; once the buffer is full, the oldest entry is
/// evicted (and counted as dropped) for each newly interrupted episode.
const MAX_INTERRUPTED_EPISODES: usize = 100;
/// Lifecycle state of one environment lane, as reported by `EpisodeTracker::lane_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LaneState {
/// The lane has an active episode (this wins even if an autoreset is also pending).
Active,
/// The lane has no active episode but was flagged via `expect_autoreset`.
PendingAutoreset,
/// The lane has neither an active episode nor a pending autoreset.
Idle,
}
pub struct EpisodeTracker {
active: HashMap<i32, Episode>,
pending_autoreset: HashSet<i32>,
interrupted: Vec<EpisodeMetadata>,
interrupted_dropped: u64,
}
impl EpisodeTracker {
pub fn new() -> Self {
Self {
active: HashMap::new(),
pending_autoreset: HashSet::new(),
interrupted: Vec::new(),
interrupted_dropped: 0,
}
}
fn push_interrupted(&mut self, metadata: EpisodeMetadata) {
if self.interrupted.len() >= MAX_INTERRUPTED_EPISODES {
self.interrupted.remove(0);
self.interrupted_dropped = self.interrupted_dropped.saturating_add(1);
}
self.interrupted.push(metadata);
}
fn take_interrupted(&mut self) -> Vec<EpisodeMetadata> {
if self.interrupted_dropped > 0 {
tracing::warn!(
dropped = self.interrupted_dropped,
cap = MAX_INTERRUPTED_EPISODES,
"interrupted-episode buffer overflowed; oldest interrupted episodes were \
dropped before they could be delivered (a client looping Reset without Step \
grows this buffer); their accounting is lost"
);
self.interrupted_dropped = 0;
}
std::mem::take(&mut self.interrupted)
}
pub fn start_episode(&mut self, env_index: i32, seed: Option<i64>) -> String {
let episode = Episode::new(env_index, seed);
let episode_id = episode.id.clone();
self.pending_autoreset.remove(&env_index);
if let Some(old_episode) = self.active.insert(env_index, episode) {
tracing::debug!(
"Episode {} for env {} interrupted by a new episode; completing as truncated",
old_episode.id,
env_index
);
self.push_interrupted(old_episode.complete(false, true, None));
}
episode_id
}
pub fn drain_interrupted(&mut self) -> Vec<EpisodeMetadata> {
self.take_interrupted()
}
pub fn record_step(&mut self, env_index: i32, reward: f64) {
if let Some(episode) = self.active.get_mut(&env_index) {
episode.record_step(reward);
} else {
tracing::warn!(
"Attempted to record step for env {} with no active episode",
env_index
);
}
}
pub fn complete_episode(
&mut self,
env_index: i32,
terminated: bool,
truncated: bool,
final_info: Option<MetaMap>,
) -> Option<EpisodeMetadata> {
let episode = self.active.remove(&env_index)?;
Some(episode.complete(terminated, truncated, final_info))
}
pub fn complete_all(&mut self, reason: &str) -> Vec<EpisodeMetadata> {
tracing::info!(
"Completing all {} active episodes: {}",
self.active.len(),
reason
);
let mut completed = self.take_interrupted();
for (_env_index, episode) in self.active.drain() {
let metadata = episode.complete(false, true, None);
completed.push(metadata);
}
self.pending_autoreset.clear();
completed
}
pub fn active_episode_id(&self, env_index: i32) -> Option<&str> {
self.active
.get(&env_index)
.map(|episode| episode.id.as_str())
}
pub fn lane_state(&self, env_index: i32) -> LaneState {
if self.active.contains_key(&env_index) {
LaneState::Active
} else if self.pending_autoreset.contains(&env_index) {
LaneState::PendingAutoreset
} else {
LaneState::Idle
}
}
pub fn expect_autoreset(&mut self, env_index: i32) {
self.pending_autoreset.insert(env_index);
}
}
impl Default for EpisodeTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of lanes that currently have an active episode.
    fn active_count(tracker: &EpisodeTracker) -> usize {
        tracker.active.len()
    }

    #[test]
    fn test_episode_lifecycle() {
        let mut tracker = EpisodeTracker::new();
        let ep_id = tracker.start_episode(0, Some(42));
        assert_eq!(active_count(&tracker), 1);
        tracker.record_step(0, 1.0);
        tracker.record_step(0, 2.5);
        let metadata = tracker.complete_episode(0, true, false, None).unwrap();
        assert_eq!(metadata.episode_id, ep_id);
        assert_eq!(metadata.seed, Some(42));
        assert_eq!(metadata.env_index, 0);
        assert_eq!(metadata.step_count, 2);
        assert_eq!(metadata.cumulative_reward, 3.5);
        assert!(metadata.terminated);
        assert!(!metadata.truncated);
        assert_eq!(active_count(&tracker), 0);
    }

    #[test]
    fn completed_episode_reports_non_negative_duration() {
        let mut tracker = EpisodeTracker::new();
        tracker.start_episode(0, None);
        let meta = tracker.complete_episode(0, true, false, None).unwrap();
        assert!(meta.start_timestamp_ns > 0);
        assert!(meta.duration_ms >= 0);
    }

    #[test]
    fn lane_state_tracks_active_pending_and_idle() {
        let mut tracker = EpisodeTracker::new();
        assert_eq!(tracker.lane_state(0), LaneState::Idle);
        tracker.start_episode(0, None);
        assert_eq!(tracker.lane_state(0), LaneState::Active);
        tracker.complete_episode(0, true, false, None);
        assert_eq!(tracker.lane_state(0), LaneState::Idle);
        tracker.start_episode(0, None);
        tracker.complete_episode(0, true, false, None);
        tracker.expect_autoreset(0);
        assert_eq!(tracker.lane_state(0), LaneState::PendingAutoreset);
        tracker.start_episode(0, None);
        assert_eq!(tracker.lane_state(0), LaneState::Active);
    }

    #[test]
    fn active_episode_takes_precedence_over_pending_autoreset() {
        let mut tracker = EpisodeTracker::new();
        tracker.start_episode(0, None);
        tracker.expect_autoreset(0);
        assert_eq!(tracker.lane_state(0), LaneState::Active);
    }

    #[test]
    fn complete_all_clears_pending_autoreset() {
        let mut tracker = EpisodeTracker::new();
        tracker.start_episode(0, None);
        tracker.complete_episode(0, true, false, None);
        tracker.expect_autoreset(0);
        assert_eq!(tracker.lane_state(0), LaneState::PendingAutoreset);
        let _ = tracker.complete_all("close");
        assert_eq!(tracker.lane_state(0), LaneState::Idle);
    }

    #[test]
    fn interrupted_episode_is_completed_as_truncated_and_drained_once() {
        let mut tracker = EpisodeTracker::new();
        let first = tracker.start_episode(0, Some(7));
        tracker.record_step(0, 1.5);
        tracker.record_step(0, 2.5);
        let second = tracker.start_episode(0, None);
        assert_ne!(first, second);
        assert_eq!(active_count(&tracker), 1);
        let interrupted = tracker.drain_interrupted();
        assert_eq!(interrupted.len(), 1);
        assert_eq!(interrupted[0].episode_id, first);
        assert_eq!(interrupted[0].step_count, 2);
        assert_eq!(interrupted[0].cumulative_reward, 4.0);
        assert!(!interrupted[0].terminated);
        assert!(interrupted[0].truncated);
        assert!(tracker.drain_interrupted().is_empty());
    }

    #[test]
    fn complete_all_includes_undrained_interrupted_episodes() {
        let mut tracker = EpisodeTracker::new();
        let first = tracker.start_episode(0, Some(1));
        tracker.record_step(0, 1.0);
        let second = tracker.start_episode(0, None);
        let all = tracker.complete_all("client close");
        // Everything `complete_all` reports was force-ended, hence truncated.
        assert!(all.iter().all(|m| m.truncated && !m.terminated));
        let mut expected = vec![first, second];
        expected.sort();
        let mut got: Vec<String> = all.iter().map(|m| m.episode_id.clone()).collect();
        got.sort();
        assert_eq!(got, expected);
        assert!(tracker.drain_interrupted().is_empty());
    }

    #[test]
    fn interrupted_buffer_is_bounded_under_repeated_reset_without_step() {
        let mut tracker = EpisodeTracker::new();
        let resets = MAX_INTERRUPTED_EPISODES * 5;
        for _ in 0..resets {
            tracker.start_episode(0, None);
            assert!(
                tracker.interrupted.len() <= MAX_INTERRUPTED_EPISODES,
                "interrupted buffer exceeded its cap"
            );
        }
        assert_eq!(tracker.interrupted.len(), MAX_INTERRUPTED_EPISODES);
        assert!(tracker.interrupted_dropped > 0);
        let drained = tracker.drain_interrupted();
        assert_eq!(drained.len(), MAX_INTERRUPTED_EPISODES);
        assert_eq!(tracker.interrupted_dropped, 0);
        assert!(tracker.drain_interrupted().is_empty());
    }

    #[test]
    fn interrupted_buffer_evicts_oldest_and_keeps_delivery_order() {
        let mut tracker = EpisodeTracker::new();
        let extra = 3;
        let ids: Vec<String> = (0..MAX_INTERRUPTED_EPISODES + extra + 1)
            .map(|_| tracker.start_episode(0, None))
            .collect();
        // Every start but the last interrupts its predecessor (the newest stays active); the
        // `extra` oldest interruptions are evicted and the rest come back oldest first.
        let drained: Vec<String> = tracker
            .drain_interrupted()
            .into_iter()
            .map(|m| m.episode_id)
            .collect();
        assert_eq!(drained.as_slice(), &ids[extra..ids.len() - 1]);
    }

    #[test]
    fn test_vectorized_episodes() {
        let mut tracker = EpisodeTracker::new();
        let ep0 = tracker.start_episode(0, Some(100));
        let ep1 = tracker.start_episode(1, Some(200));
        let _ep2 = tracker.start_episode(2, Some(300));
        assert_eq!(active_count(&tracker), 3);
        tracker.record_step(0, 1.0);
        tracker.record_step(1, 2.0);
        tracker.record_step(2, 3.0);
        let meta1 = tracker.complete_episode(1, true, false, None).unwrap();
        assert_eq!(meta1.episode_id, ep1);
        assert_eq!(active_count(&tracker), 2);
        let meta0 = tracker.complete_episode(0, false, true, None).unwrap();
        assert_eq!(meta0.episode_id, ep0);
        assert!(!meta0.terminated);
        assert!(meta0.truncated);
        assert_eq!(active_count(&tracker), 1);
        tracker.complete_episode(2, true, false, None);
        assert_eq!(active_count(&tracker), 0);
    }

    #[test]
    fn test_complete_all() {
        let mut tracker = EpisodeTracker::new();
        tracker.start_episode(0, Some(1));
        tracker.start_episode(1, Some(2));
        tracker.start_episode(2, Some(3));
        tracker.record_step(0, 1.0);
        tracker.record_step(1, 2.0);
        let interrupted = tracker.complete_all("test cancellation");
        assert_eq!(interrupted.len(), 3);
        assert_eq!(active_count(&tracker), 0);
        for meta in interrupted {
            assert!(!meta.terminated);
            assert!(meta.truncated);
        }
    }

    #[test]
    fn unseeded_episode_leaves_seed_unset_not_fabricated_zero() {
        let mut tracker = EpisodeTracker::new();
        tracker.start_episode(0, None);
        let meta = tracker.complete_episode(0, true, false, None).unwrap();
        assert_eq!(meta.seed, None);
        tracker.start_episode(1, Some(0));
        let meta = tracker.complete_episode(1, true, false, None).unwrap();
        assert_eq!(meta.seed, Some(0));
    }
}