use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::record::{ActionRecord, FromRecord, Record};
use crate::types::{GroupId, TaskId};
use crate::util::{epoch_millis, epoch_millis_for_ordering};
pub trait EpisodeTrait: Send + Sync {
fn id(&self) -> &EpisodeId;
fn learn_model_name(&self) -> &str;
fn task_id(&self) -> Option<TaskId>;
fn group_id(&self) -> Option<GroupId>;
fn outcome(&self) -> &Outcome;
fn is_success(&self) -> bool {
self.outcome().is_success()
}
fn scenario_name(&self) -> Option<&str>;
}
/// Identifier of an episode, made of a timestamp and a counter value.
///
/// Equality and hashing are derived, so both fields take part in them.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EpisodeId {
/// Timestamp component; `EpisodeId::new` fills it from `epoch_millis_for_ordering()`.
pub timestamp_ms: u64,
/// Counter component; `EpisodeId::new` fills it from a program-wide atomic counter.
pub counter: u32,
}
impl EpisodeId {
pub fn new() -> Self {
use std::sync::atomic::{AtomicU32, Ordering};
static COUNTER: AtomicU32 = AtomicU32::new(0);
Self {
timestamp_ms: epoch_millis_for_ordering(),
counter: COUNTER.fetch_add(1, Ordering::Relaxed),
}
}
pub fn from_parts(timestamp_ms: u64, counter: u32) -> Self {
Self {
timestamp_ms,
counter,
}
}
}
impl Default for EpisodeId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for EpisodeId {
    /// Renders the id as `<timestamp_ms>-<counter>`, with the counter printed as
    /// lower-case hexadecimal zero-padded to eight digits.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            timestamp_ms,
            counter,
        } = self;
        write!(f, "{}-{:08x}", timestamp_ms, counter)
    }
}
/// Result of an episode.
///
/// Serialized with serde's internally tagged representation: the variant name is
/// stored under the `"type"` key next to the variant's own fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Outcome {
    /// The episode succeeded.
    Success {
        /// Score attached to the success.
        score: f64,
    },
    /// The episode failed.
    Failure {
        /// Description of why it failed.
        reason: String,
    },
    /// The episode timed out.
    Timeout {
        /// Score reached before the timeout, if any.
        partial_score: Option<f64>,
    },
    /// No outcome is known; this is the `Default` value.
    #[default]
    Unknown,
}
impl Outcome {
    /// Builds a `Success` outcome carrying `score`.
    pub fn success(score: f64) -> Self {
        Self::Success { score }
    }

    /// Builds a `Success` outcome with a score of `1.0`.
    pub fn success_binary() -> Self {
        Self::success(1.0)
    }

    /// Builds a `Failure` outcome with the given reason.
    pub fn failure(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        Self::Failure { reason }
    }

    /// Builds a `Timeout` outcome with an optional partial score.
    pub fn timeout(partial_score: Option<f64>) -> Self {
        Self::Timeout { partial_score }
    }

    /// Returns `true` only for `Success`.
    pub fn is_success(&self) -> bool {
        match self {
            Self::Success { .. } => true,
            Self::Failure { .. } | Self::Timeout { .. } | Self::Unknown => false,
        }
    }

    /// Returns `true` for `Failure` and `Timeout`; `Unknown` is not a failure.
    pub fn is_failure(&self) -> bool {
        match self {
            Self::Failure { .. } | Self::Timeout { .. } => true,
            Self::Success { .. } | Self::Unknown => false,
        }
    }

    /// Returns `true` only for `Unknown`.
    pub fn is_unknown(&self) -> bool {
        match self {
            Self::Unknown => true,
            Self::Success { .. } | Self::Failure { .. } | Self::Timeout { .. } => false,
        }
    }

    /// Returns the numeric score of this outcome.
    ///
    /// `Success` yields its score, `Timeout` yields its partial score or `0.0` when
    /// none was recorded, and `Failure` / `Unknown` yield `0.0`.
    pub fn score(&self) -> f64 {
        match *self {
            Self::Success { score } => score,
            Self::Timeout { partial_score } => partial_score.unwrap_or(0.0),
            Self::Failure { .. } | Self::Unknown => 0.0,
        }
    }
}
/// Collection of the `Record`s attached to an episode.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EpisodeContext {
/// Stored records; `push` and `with_record` append to the end of this vector.
pub records: Vec<Record>,
}
impl EpisodeContext {
    /// Creates an empty context (same as `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends a record, converting it into [`Record`] first.
    pub fn push(&mut self, record: impl Into<Record>) {
        let converted: Record = record.into();
        self.records.push(converted);
    }

    /// Consuming variant of [`EpisodeContext::push`], for chained construction.
    pub fn with_record(mut self, record: impl Into<Record>) -> Self {
        self.push(record);
        self
    }

    /// Iterates, in stored order, over the records for which `T::from_record`
    /// returns `Some`; all other records are skipped.
    pub fn iter<'a, T: FromRecord + 'a>(&'a self) -> impl Iterator<Item = &'a T> {
        let stored = self.records.iter();
        stored.filter_map(T::from_record)
    }

    /// Returns the first item [`EpisodeContext::iter`] would yield for `T`, if any.
    pub fn first<T: FromRecord>(&self) -> Option<&T> {
        let mut matching = self.iter::<T>();
        matching.next()
    }

    /// Number of stored records of any kind.
    pub fn len(&self) -> usize {
        self.records.len()
    }

    /// Returns `true` when no record of any kind is stored.
    pub fn is_empty(&self) -> bool {
        self.records.is_empty()
    }
}
/// Descriptive metadata attached to an episode.
///
/// Note that the derived `Default` leaves `created_at` at 0; `EpisodeMetadata::new`
/// is the constructor that stamps it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EpisodeMetadata {
/// Name of the strategy, if set (see `with_strategy`).
pub strategy_name: Option<String>,
/// Name of the scenario, if set (see `with_scenario`).
pub scenario_name: Option<String>,
/// Creation timestamp; `new` sets it from `epoch_millis()` (presumably epoch milliseconds).
pub created_at: u64,
/// Start timestamp, if recorded (see `with_duration`).
pub started_at: Option<u64>,
/// End timestamp, if recorded (see `with_duration`).
pub ended_at: Option<u64>,
/// Free-form key/value tags (see `with_tag`).
pub tags: HashMap<String, String>,
}
impl EpisodeMetadata {
    /// Creates metadata whose `created_at` is taken from `epoch_millis()`; every
    /// other field starts at its default (unset / empty).
    pub fn new() -> Self {
        Self {
            created_at: epoch_millis(),
            ..Self::default()
        }
    }

    /// Sets the strategy name.
    pub fn with_strategy(self, name: impl Into<String>) -> Self {
        Self {
            strategy_name: Some(name.into()),
            ..self
        }
    }

    /// Sets the scenario name.
    pub fn with_scenario(self, name: impl Into<String>) -> Self {
        Self {
            scenario_name: Some(name.into()),
            ..self
        }
    }

    /// Records both the start and the end timestamp.
    pub fn with_duration(self, start: u64, end: u64) -> Self {
        Self {
            started_at: Some(start),
            ended_at: Some(end),
            ..self
        }
    }

    /// Adds a tag, replacing any previous value stored under the same key.
    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (tag_key, tag_value): (String, String) = (key.into(), value.into());
        self.tags.insert(tag_key, tag_value);
        self
    }

    /// Returns `ended_at - started_at` when both are set, otherwise `None`.
    ///
    /// The subtraction saturates, so an end earlier than the start yields `Some(0)`.
    pub fn duration_ms(&self) -> Option<u64> {
        let started = self.started_at?;
        let ended = self.ended_at?;
        Some(ended.saturating_sub(started))
    }
}
/// A recorded episode: identifier, learn model name, context records, outcome and metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Episode {
/// Identifier of this episode.
pub id: EpisodeId,
/// Name of the learn model this episode is associated with.
pub learn_model: String,
/// Explicit task id. Omitted from serialized output when `None`, and read back as
/// `None` when absent. `get_task_id` falls back to the first `ActionRecord` when unset.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub task_id: Option<TaskId>,
/// Explicit group id. Omitted from serialized output when `None`, and read back as
/// `None` when absent. `get_group_id` falls back to the first `ActionRecord` when unset.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub group_id: Option<GroupId>,
/// Records collected for this episode.
pub context: EpisodeContext,
/// Outcome of this episode.
pub outcome: Outcome,
/// Descriptive metadata (scenario, timestamps, tags, ...).
pub metadata: EpisodeMetadata,
}
impl Episode {
pub fn new(learn_model: impl Into<String>, outcome: Outcome) -> Self {
Self {
id: EpisodeId::new(),
learn_model: learn_model.into(),
task_id: None,
group_id: None,
context: EpisodeContext::default(),
outcome,
metadata: EpisodeMetadata::new(),
}
}
pub fn builder() -> EpisodeBuilder {
EpisodeBuilder::default()
}
pub fn is_success(&self) -> bool {
self.outcome.is_success()
}
pub fn worker_id(&self) -> Option<usize> {
self.context
.iter::<ActionRecord>()
.next()
.map(|a| a.worker_id)
}
pub fn get_task_id(&self) -> Option<TaskId> {
self.task_id.or_else(|| {
self.context
.iter::<ActionRecord>()
.next()
.map(|a| a.task_id)
})
}
pub fn get_group_id(&self) -> Option<GroupId> {
self.group_id.or_else(|| {
self.context
.iter::<ActionRecord>()
.next()
.and_then(|a| a.group_id)
})
}
}
impl EpisodeTrait for Episode {
fn id(&self) -> &EpisodeId {
&self.id
}
fn learn_model_name(&self) -> &str {
&self.learn_model
}
fn task_id(&self) -> Option<TaskId> {
self.get_task_id()
}
fn group_id(&self) -> Option<GroupId> {
self.get_group_id()
}
fn outcome(&self) -> &Outcome {
&self.outcome
}
fn scenario_name(&self) -> Option<&str> {
self.metadata.scenario_name.as_deref()
}
}
#[derive(Debug, Default)]
pub struct EpisodeBuilder {
id: Option<EpisodeId>,
learn_model: Option<String>,
task_id: Option<TaskId>,
group_id: Option<GroupId>,
context: EpisodeContext,
outcome: Option<Outcome>,
metadata: EpisodeMetadata,
}
impl EpisodeBuilder {
    /// Uses `id` instead of generating one in [`EpisodeBuilder::build`].
    pub fn id(self, id: EpisodeId) -> Self {
        Self {
            id: Some(id),
            ..self
        }
    }

    /// Sets the learn model name.
    pub fn learn_model(self, name: impl Into<String>) -> Self {
        Self {
            learn_model: Some(name.into()),
            ..self
        }
    }

    /// Sets an explicit task id.
    pub fn task_id(self, task_id: TaskId) -> Self {
        Self {
            task_id: Some(task_id),
            ..self
        }
    }

    /// Sets an explicit group id.
    pub fn group_id(self, group_id: GroupId) -> Self {
        Self {
            group_id: Some(group_id),
            ..self
        }
    }

    /// Appends one record to the context being built.
    pub fn record(mut self, record: impl Into<Record>) -> Self {
        self.context.records.push(record.into());
        self
    }

    /// Replaces the whole context, discarding records added so far.
    pub fn context(self, context: EpisodeContext) -> Self {
        Self { context, ..self }
    }

    /// Sets the outcome.
    pub fn outcome(self, outcome: Outcome) -> Self {
        Self {
            outcome: Some(outcome),
            ..self
        }
    }

    /// Sets the scenario name on the metadata being built.
    pub fn scenario(mut self, name: impl Into<String>) -> Self {
        self.metadata = self.metadata.with_scenario(name);
        self
    }

    /// Adds a tag to the metadata being built, replacing any value under the same key.
    pub fn tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata = self.metadata.with_tag(key, value);
        self
    }

    /// Replaces the whole metadata, discarding any scenario or tags set so far.
    pub fn metadata(self, metadata: EpisodeMetadata) -> Self {
        Self { metadata, ..self }
    }

    /// Finishes the builder.
    ///
    /// Unset parts fall back to: a freshly generated id, the learn model name
    /// `"unknown"`, and `Outcome::Unknown`.
    pub fn build(self) -> Episode {
        let Self {
            id,
            learn_model,
            task_id,
            group_id,
            context,
            outcome,
            metadata,
        } = self;

        Episode {
            id: id.unwrap_or_default(),
            learn_model: learn_model.unwrap_or_else(|| "unknown".to_string()),
            task_id,
            group_id,
            context,
            outcome: outcome.unwrap_or(Outcome::Unknown),
            metadata,
        }
    }
}
// Unit tests for episode construction, outcome helpers, context iteration and serde
// round-tripping. Two tests also exercise the `ActionRecord` / `LlmCallRecord` helpers
// that episodes store.
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
use crate::learn::record::LlmCallRecord;
use crate::types::WorkerId;
// Builds an `ActionEvent` for `worker_id` at `tick` with a 50 ms duration, selection
// logic "UCB1" and previous action "PrevAction". `success` selects between a success
// result and a failure result carrying "test error".
fn make_action_event(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
let result = if success {
ActionEventResult::success()
} else {
ActionEventResult::failure("test error")
};
ActionEventBuilder::new(tick, WorkerId(worker_id), action)
.result(result)
.duration(Duration::from_millis(50))
.context(
ActionContext::new()
.with_selection_logic("UCB1")
.with_previous_action("PrevAction"),
)
.build()
}
// Converting an `ActionEvent` into an `ActionRecord` carries over tick, worker id,
// action name, success flag, duration and the two context fields.
#[test]
fn test_action_record_from_action_event() {
let event = make_action_event(10, 1, "CheckStatus", true);
let record = ActionRecord::from(&event);
assert_eq!(record.tick, 10);
assert_eq!(record.worker_id, 1);
assert_eq!(record.action, "CheckStatus");
assert!(record.success);
assert_eq!(record.duration_ms, 50);
assert_eq!(record.selection_logic, Some("UCB1".to_string()));
assert_eq!(record.previous_action, Some("PrevAction".to_string()));
}
// The builder keeps action records in insertion order and applies the learn model,
// outcome and scenario it was given.
#[test]
fn test_episode_builder_with_actions() {
let event1 = make_action_event(1, 0, "Grep", true);
let event2 = make_action_event(2, 0, "Read", true);
let event3 = make_action_event(3, 0, "done", true);
let episode = Episode::builder()
.learn_model("worker_task")
.record(ActionRecord::from(&event1))
.record(ActionRecord::from(&event2))
.record(ActionRecord::from(&event3))
.outcome(Outcome::success_binary())
.scenario("troubleshooting")
.build();
assert_eq!(episode.learn_model, "worker_task");
assert_eq!(episode.context.iter::<ActionRecord>().count(), 3);
let actions: Vec<&str> = episode
.context
.iter::<ActionRecord>()
.map(|a| a.action.as_str())
.collect();
assert_eq!(actions, vec!["Grep", "Read", "done"]);
assert!(episode.is_success());
assert_eq!(
episode.metadata.scenario_name,
Some("troubleshooting".to_string())
);
}
// An `LlmCallRecord` added through the builder is retrievable through the typed
// `iter` / `first` accessors of the context.
#[test]
fn test_episode_builder_with_llm_call() {
let llm_record = LlmCallRecord::new("decide", "qwen2.5")
.prompt("What action?")
.response("CheckStatus")
.latency_ms(150)
.worker_id(0);
let episode = Episode::builder()
.learn_model("llm_call")
.record(llm_record.clone())
.outcome(Outcome::success(0.9))
.build();
assert_eq!(episode.learn_model, "llm_call");
assert_eq!(episode.context.iter::<LlmCallRecord>().count(), 1);
let llm_call = episode.context.first::<LlmCallRecord>().unwrap();
assert_eq!(llm_call.prompt, "What action?");
assert_eq!(llm_call.response, "CheckStatus");
}
// Checks `is_success` / `is_failure` / `score` for each variant: a timeout counts as a
// failure but still reports its partial score, and `Unknown` is neither success nor failure.
#[test]
fn test_outcome_variants() {
assert!(Outcome::success(1.0).is_success());
assert!(!Outcome::success(1.0).is_failure());
assert_eq!(Outcome::success(0.8).score(), 0.8);
assert!(!Outcome::failure("test").is_success());
assert!(Outcome::failure("test").is_failure());
assert_eq!(Outcome::failure("test").score(), 0.0);
assert!(!Outcome::timeout(Some(0.5)).is_success());
assert!(Outcome::timeout(Some(0.5)).is_failure());
assert_eq!(Outcome::timeout(Some(0.5)).score(), 0.5);
assert!(!Outcome::Unknown.is_success());
assert!(!Outcome::Unknown.is_failure());
}
// `EpisodeContext::iter` yields pushed records in insertion order and composes with
// ordinary iterator adaptors such as `filter`.
#[test]
fn test_episode_context_iter() {
let mut context = EpisodeContext::new();
context.push(ActionRecord::new(1, 0, "A").success(true));
context.push(ActionRecord::new(2, 0, "B").success(true));
context.push(ActionRecord::new(3, 0, "C").success(false));
assert_eq!(context.iter::<ActionRecord>().count(), 3);
let success_count = context.iter::<ActionRecord>().filter(|a| a.success).count();
assert_eq!(success_count, 2);
let actions: Vec<&str> = context
.iter::<ActionRecord>()
.map(|a| a.action.as_str())
.collect();
assert_eq!(actions, vec!["A", "B", "C"]);
}
// Round-trips an episode through JSON and checks that the learn model, the action
// record and the outcome survive.
#[test]
fn test_episode_serialization() {
let episode = Episode::builder()
.learn_model("worker_task")
.record(ActionRecord::new(1, 0, "CheckStatus").success(true))
.outcome(Outcome::success_binary())
.build();
let json = serde_json::to_string(&episode).unwrap();
assert!(json.contains("\"learn_model\":\"worker_task\""));
assert!(json.contains("\"action\":\"CheckStatus\""));
let restored: Episode = serde_json::from_str(&json).unwrap();
assert_eq!(restored.learn_model, "worker_task");
assert_eq!(restored.context.iter::<ActionRecord>().count(), 1);
assert!(restored.is_success());
}
// Exercises the `LlmCallRecord` builder methods; a record built with `.error(..)` is
// reported as not successful.
#[test]
fn test_llm_call_record_builder() {
let record = LlmCallRecord::new("decide", "qwen2.5")
.prompt("prompt")
.response("response")
.endpoint("http://localhost:11434")
.lora("adapter1")
.latency_ms(100)
.worker_id(5);
assert_eq!(record.call_type, "decide");
assert_eq!(record.model, "qwen2.5");
assert_eq!(record.prompt, "prompt");
assert_eq!(record.response, "response");
assert_eq!(record.lora, Some("adapter1".to_string()));
assert_eq!(record.worker_id, Some(5));
assert!(record.is_success());
let error_record = LlmCallRecord::new("decide", "model").error("timeout");
assert!(!error_record.is_success());
}
// An explicit id and a caller-supplied metadata value are passed through `build()` unchanged.
#[test]
fn test_episode_builder_with_id_and_metadata() {
let custom_id = EpisodeId::from_parts(12345, 1);
let mut custom_metadata = EpisodeMetadata::new();
custom_metadata.scenario_name = Some("custom-scenario".to_string());
custom_metadata
.tags
.insert("key".to_string(), "value".to_string());
let episode = Episode::builder()
.id(custom_id.clone())
.learn_model("test")
.metadata(custom_metadata)
.outcome(Outcome::Unknown)
.build();
assert_eq!(episode.id, custom_id);
assert_eq!(
episode.metadata.scenario_name,
Some("custom-scenario".to_string())
);
assert_eq!(episode.metadata.tags.get("key"), Some(&"value".to_string()));
}
}