use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use super::snapshot::SessionId;
/// Identifier of a [`SessionGroup`].
///
/// A thin wrapper around a `String`. [`SessionGroupId::new`] generates an ID and
/// [`SessionGroupId::from_raw`] wraps an existing string unchanged. `Display` prints the
/// inner string as-is.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionGroupId(pub String);
impl SessionGroupId {
    /// Generates a new identifier of the form `g<n>`, where `<n>` is normally the current Unix
    /// time in milliseconds.
    ///
    /// IDs are strictly increasing within this process. Suppose the clock has not advanced past
    /// the previously issued value, for example when two groups are created in the same
    /// millisecond. Then `<n>` becomes one more than the previous value, so two calls never
    /// return the same ID, even from different threads. Uniqueness across separate processes is
    /// not guaranteed.
    ///
    /// If the system clock reports a time before the Unix epoch, the clock reading is treated
    /// as `0`.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        // Last numeric suffix handed out by this function in this process (0 = none yet).
        static LAST_ISSUED: AtomicU64 = AtomicU64::new(0);

        let now_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX))
            .unwrap_or(0);

        // Choose max(now, previous + 1) in a single atomic read-modify-write so that
        // concurrent callers in the same millisecond still receive distinct values.
        let next_after = |prev: u64| now_ms.max(prev.saturating_add(1));
        let prev = match LAST_ISSUED.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |prev| {
            Some(next_after(prev))
        }) {
            // The closure always returns `Some`, so this is always `Ok`; both arms carry the
            // previous value either way.
            Ok(prev) | Err(prev) => prev,
        };

        Self(format!("g{}", next_after(prev)))
    }

    /// Wraps `s` as an identifier unchanged, without any validation.
    pub fn from_raw(s: impl Into<String>) -> Self {
        Self(s.into())
    }
}
impl Default for SessionGroupId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for SessionGroupId {
    /// Writes the raw identifier string.
    ///
    /// Uses [`std::fmt::Formatter::pad`] so that width, fill, alignment and precision flags are
    /// honoured (e.g. `format!("{:<12}", id)` for tabular output), the same way `str` is
    /// displayed. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}
/// Phase recorded on each [`SessionGroup`].
///
/// Serialized in snake_case (`"bootstrap"`, `"release"`, `"validate"`) via the serde attribute
/// below. `Display` emits the same lowercase names and `FromStr` accepts them
/// case-insensitively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LearningPhase {
Bootstrap,
Release,
Validate,
}
impl std::fmt::Display for LearningPhase {
    /// Writes the lowercase phase name: `bootstrap`, `release` or `validate`.
    ///
    /// These are the same spellings serde uses for this enum (`rename_all = "snake_case"`).
    /// The name is emitted through [`std::fmt::Formatter::pad`], so width, fill and alignment
    /// flags such as `{:<10}` are honoured instead of being silently ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Bootstrap => "bootstrap",
            Self::Release => "release",
            Self::Validate => "validate",
        };
        f.pad(name)
    }
}
impl std::str::FromStr for LearningPhase {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"bootstrap" => Ok(Self::Bootstrap),
"release" => Ok(Self::Release),
"validate" => Ok(Self::Validate),
_ => Err(format!("Unknown phase: {}", s)),
}
}
}
/// Bookkeeping for a [`SessionGroup`]: the scenario, when the group was created and completed,
/// how many runs it aims for, and how many have succeeded or failed so far.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionGroupMetadata {
/// Scenario label supplied to [`SessionGroupMetadata::new`].
pub scenario: String,
/// Creation time in whole seconds since the Unix epoch (`0` if the clock read before the epoch).
pub created_at: u64,
/// Set by `mark_completed` to the Unix time in seconds at that moment; `None` until then.
pub completed_at: Option<u64>,
/// Number of runs the group aims to collect; `SessionGroup::is_target_reached` compares
/// successes + failures against it.
pub target_runs: usize,
/// Number of runs recorded via `record_success`.
pub success_count: usize,
/// Number of runs recorded via `record_failure`.
pub failure_count: usize,
/// Optional label set via `with_variant`; `None` by default.
pub variant: Option<String>,
}
impl SessionGroupMetadata {
    /// Current Unix time in whole seconds, or `0` when the system clock reports a time before
    /// the epoch.
    fn unix_now_secs() -> u64 {
        match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => 0,
        }
    }

    /// Creates metadata for `scenario` with a goal of `target_runs` runs.
    ///
    /// `created_at` is stamped with the current Unix time in seconds. Both counters start at
    /// zero, and neither `completed_at` nor `variant` is set.
    pub fn new(scenario: impl Into<String>, target_runs: usize) -> Self {
        let scenario = scenario.into();
        let created_at = Self::unix_now_secs();
        Self {
            scenario,
            created_at,
            completed_at: None,
            target_runs,
            success_count: 0,
            failure_count: 0,
            variant: None,
        }
    }

    /// Builder-style setter: returns the metadata with `variant` set to the given label,
    /// replacing any previous one.
    pub fn with_variant(self, variant: impl Into<String>) -> Self {
        Self {
            variant: Some(variant.into()),
            ..self
        }
    }

    /// Counts one more successful run.
    pub fn record_success(&mut self) {
        self.success_count += 1;
    }

    /// Counts one more failed run.
    pub fn record_failure(&mut self) {
        self.failure_count += 1;
    }

    /// Stamps `completed_at` with the current Unix time in seconds, overwriting any earlier
    /// value.
    pub fn mark_completed(&mut self) {
        self.completed_at = Some(Self::unix_now_secs());
    }

    /// Fraction of recorded runs that succeeded, in `[0.0, 1.0]`; `0.0` when nothing has been
    /// recorded yet.
    pub fn success_rate(&self) -> f64 {
        match self.completed_runs() {
            0 => 0.0,
            total => self.success_count as f64 / total as f64,
        }
    }

    /// Total number of recorded runs (successes plus failures).
    pub fn completed_runs(&self) -> usize {
        self.failure_count + self.success_count
    }
}
/// A set of sessions collected for one scenario in a single [`LearningPhase`], together with
/// their success/failure tallies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionGroup {
/// Identifier of this group.
pub id: SessionGroupId,
/// Phase this group belongs to.
pub phase: LearningPhase,
/// Sessions added via `add_session`, in insertion order.
pub session_ids: Vec<SessionId>,
/// Scenario, run target, timestamps and success/failure counters.
pub metadata: SessionGroupMetadata,
}
impl SessionGroup {
    /// Creates an empty group for `phase` that aims to collect `target_runs` sessions for
    /// `scenario`.
    ///
    /// A fresh [`SessionGroupId`] is generated, no sessions are attached yet, and the metadata
    /// starts with zero successes and failures.
    pub fn new(phase: LearningPhase, scenario: impl Into<String>, target_runs: usize) -> Self {
        Self {
            id: SessionGroupId::new(),
            phase,
            session_ids: Vec::new(),
            // Forward the caller's value as-is so it is converted to `String` exactly once.
            // Converting here and passing `&String` would clone the scenario a second time.
            metadata: SessionGroupMetadata::new(scenario, target_runs),
        }
    }

    /// Builder-style setter: returns the group with its metadata's variant label set.
    pub fn with_variant(mut self, variant: impl Into<String>) -> Self {
        self.metadata = self.metadata.with_variant(variant);
        self
    }

    /// Appends `session_id` to the group and records it as a success or a failure in the
    /// metadata counters.
    pub fn add_session(&mut self, session_id: SessionId, success: bool) {
        self.session_ids.push(session_id);
        if success {
            self.metadata.record_success();
        } else {
            self.metadata.record_failure();
        }
    }

    /// Stamps the metadata's `completed_at` with the current time.
    pub fn mark_completed(&mut self) {
        self.metadata.mark_completed();
    }

    /// Fraction of recorded sessions that succeeded; `0.0` when none have been recorded.
    pub fn success_rate(&self) -> f64 {
        self.metadata.success_rate()
    }

    /// Returns `true` once successes plus failures reach `target_runs`. This happens
    /// immediately when `target_runs` is `0`.
    pub fn is_target_reached(&self) -> bool {
        self.metadata.completed_runs() >= self.metadata.target_runs
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated IDs (via `new` or `default`) are `g` followed by a purely numeric suffix.
    #[test]
    fn test_session_group_id_generation() {
        let ids = [
            SessionGroupId::new(),
            SessionGroupId::new(),
            SessionGroupId::default(),
        ];
        for id in ids.iter() {
            assert!(id.0.starts_with('g'), "missing 'g' prefix: {}", id);
            // Slicing at 1 is safe: the prefix check above guarantees a 1-byte 'g'.
            let suffix = &id.0[1..];
            assert!(
                !suffix.is_empty() && suffix.bytes().all(|b| b.is_ascii_digit()),
                "non-numeric suffix: {}",
                id
            );
        }
    }

    /// `from_raw` wraps the string unchanged and `Display` prints it verbatim.
    #[test]
    fn test_session_group_id_raw_and_display() {
        let id = SessionGroupId::from_raw("custom-id");
        assert_eq!(id.0, "custom-id");
        assert_eq!(id.to_string(), "custom-id");
    }

    #[test]
    fn test_learning_phase_display() {
        assert_eq!(LearningPhase::Bootstrap.to_string(), "bootstrap");
        assert_eq!(LearningPhase::Release.to_string(), "release");
        assert_eq!(LearningPhase::Validate.to_string(), "validate");
    }

    #[test]
    fn test_learning_phase_parse() {
        assert_eq!(
            "bootstrap".parse::<LearningPhase>().unwrap(),
            LearningPhase::Bootstrap
        );
        assert_eq!(
            "RELEASE".parse::<LearningPhase>().unwrap(),
            LearningPhase::Release
        );
        assert_eq!(
            "Validate".parse::<LearningPhase>().unwrap(),
            LearningPhase::Validate
        );
        assert_eq!(
            "unknown".parse::<LearningPhase>(),
            Err("Unknown phase: unknown".to_string())
        );
    }

    /// Every phase survives a `Display` -> `FromStr` round trip.
    #[test]
    fn test_learning_phase_round_trip() {
        let phases = [
            LearningPhase::Bootstrap,
            LearningPhase::Release,
            LearningPhase::Validate,
        ];
        for phase in phases.iter() {
            assert_eq!(phase.to_string().parse::<LearningPhase>(), Ok(*phase));
        }
    }

    #[test]
    fn test_session_group_success_rate() {
        let mut group = SessionGroup::new(LearningPhase::Bootstrap, "test", 10);
        assert_eq!(group.success_rate(), 0.0);

        for (i, success) in [true, true, true, false, false].iter().enumerate() {
            group.add_session(SessionId((i + 1).to_string()), *success);
        }
        // Compare floats with a tolerance rather than exact equality.
        assert!((group.success_rate() - 0.6).abs() < 1e-12);
        assert_eq!(group.session_ids.len(), 5);
        assert_eq!(group.metadata.completed_runs(), 5);
        assert!(!group.is_target_reached());

        for i in 6..=10 {
            group.add_session(SessionId(i.to_string()), true);
        }
        assert!(group.is_target_reached());
        assert!((group.success_rate() - 0.8).abs() < 1e-12);
    }

    /// Variant and completion timestamp are carried through to the metadata, and a zero target
    /// counts as reached immediately.
    #[test]
    fn test_session_group_variant_and_completion() {
        let mut group =
            SessionGroup::new(LearningPhase::Validate, "scenario-a", 0).with_variant("v2");
        assert_eq!(group.metadata.scenario, "scenario-a");
        assert_eq!(group.metadata.variant.as_deref(), Some("v2"));
        assert!(group.metadata.completed_at.is_none());
        assert!(group.is_target_reached());

        group.mark_completed();
        assert!(group.metadata.completed_at.is_some());
    }
}