use std::collections::HashMap;
use super::super::episode::{Episode, EpisodeContext, Outcome};
use super::super::record::{ActionRecord, Record};
use super::super::training::TrainingData;
use super::{system_events, LearnError, LearnModel};
use crate::types::TaskId;
/// Learn model that mines successful worker action sequences from episodes
/// and converts them into supervised fine-tuning (SFT) training samples.
pub struct WorkerDecisionSequenceLearn {
/// System prompt prepended to every generated training sample.
system_prompt: String,
/// Minimum number of successful, non-system actions an episode needs
/// to be considered a success.
min_actions: usize,
/// Action vocabulary advertised to the model in generated prompts.
available_actions: Vec<String>,
/// Action names treated as infrastructure events and excluded from
/// learned sequences (seeded from `system_events::DEFAULT_SYSTEM_EVENTS`).
system_events: Vec<String>,
}
impl WorkerDecisionSequenceLearn {
    /// Builds a model with the default prompt, a five-action diagnostic
    /// vocabulary, and a minimum sequence length of 3.
    pub fn new() -> Self {
        let default_actions = ["CheckStatus", "ReadLogs", "AnalyzeMetrics", "Diagnose", "Restart"];
        Self {
            system_prompt: "You are an intelligent agent that diagnoses and resolves system issues. \
                Given a context and available actions, determine the optimal action sequence."
                .to_string(),
            min_actions: 3,
            available_actions: default_actions.iter().map(|name| name.to_string()).collect(),
            system_events: system_events::DEFAULT_SYSTEM_EVENTS
                .iter()
                .map(|name| name.to_string())
                .collect(),
        }
    }

    /// Replaces the system prompt used for generated training samples.
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Sets the minimum number of successful actions required per episode.
    pub fn with_min_actions(mut self, min: usize) -> Self {
        self.min_actions = min;
        self
    }

    /// Overrides the action vocabulary advertised in prompts.
    pub fn with_available_actions(mut self, actions: Vec<String>) -> Self {
        self.available_actions = actions;
        self
    }

    /// Registers one more action name to treat as a system event.
    pub fn with_system_event(mut self, event: impl Into<String>) -> Self {
        self.system_events.push(event.into());
        self
    }

    /// Returns true when `action` matches a known system-event name.
    fn is_system_event(&self, action: &str) -> bool {
        self.system_events.iter().map(String::as_str).any(|name| name == action)
    }
}
impl Default for WorkerDecisionSequenceLearn {
/// Equivalent to [`WorkerDecisionSequenceLearn::new`], so the standard
/// configuration is picked up by `Default`-based construction.
fn default() -> Self {
Self::new()
}
}
impl LearnModel for WorkerDecisionSequenceLearn {
    fn name(&self) -> &str {
        "worker_decision_sequence"
    }

    fn objective(&self) -> &str {
        "Learn successful action sequences for problem resolution"
    }

    /// Scores an episode context: success (score 1.0) iff it holds at least
    /// `min_actions` successful, non-system actions; failure otherwise.
    fn evaluate(&self, context: &EpisodeContext) -> Outcome {
        if context.is_empty() {
            return Outcome::failure("Empty context: no actions to evaluate");
        }
        // Only the number of qualifying actions matters, so count lazily
        // instead of collecting into a temporary Vec.
        let successful = context
            .iter::<ActionRecord>()
            .filter(|a| a.success && !self.is_system_event(&a.action))
            .count();
        if successful >= self.min_actions {
            Outcome::success(1.0)
        } else {
            Outcome::failure(format!(
                "Insufficient successful actions: {} < {}",
                successful, self.min_actions
            ))
        }
    }

    /// Groups successful, non-system action records by task and emits one
    /// episode per task that reaches the `min_actions` threshold.
    fn build_episodes(&self, records: &[Record]) -> Vec<Episode> {
        let mut task_actions: HashMap<TaskId, Vec<&ActionRecord>> = HashMap::new();
        for record in records.iter().filter_map(Record::as_action) {
            if record.success && !self.is_system_event(&record.action) {
                task_actions.entry(record.task_id).or_default().push(record);
            }
        }
        let mut episodes = Vec::new();
        for (task_id, successful_actions) in task_actions {
            if successful_actions.len() < self.min_actions {
                continue;
            }
            // Inherit the group id from the first action, when present.
            let group_id = successful_actions.first().and_then(|r| r.group_id);
            let mut context = EpisodeContext::new();
            for action in &successful_actions {
                context.push((*action).clone());
            }
            let outcome = self.evaluate(&context);
            let mut builder = Episode::builder()
                // Use the trait accessor so the model name lives in one place
                // (previously this duplicated the "worker_decision_sequence"
                // literal from `name()`).
                .learn_model(self.name())
                .task_id(task_id)
                .context(context)
                .outcome(outcome);
            if let Some(gid) = group_id {
                builder = builder.group_id(gid);
            }
            episodes.push(builder.build());
        }
        episodes
    }

    /// Converts a successful episode into an SFT training sample whose
    /// response spells out the observed action sequence.
    ///
    /// # Errors
    /// Returns `LearnError::InvalidEpisode` when the episode's outcome is not
    /// a success or it contains fewer than `min_actions` actions.
    fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError> {
        if !episode.outcome.is_success() {
            return Err(LearnError::InvalidEpisode(
                "Episode is not successful".into(),
            ));
        }
        let actions: Vec<&str> = episode
            .context
            .iter::<ActionRecord>()
            .map(|a| a.action.as_str())
            .collect();
        if actions.len() < self.min_actions {
            return Err(LearnError::InvalidEpisode(format!(
                "Too few actions: {} < {}",
                actions.len(),
                self.min_actions
            )));
        }
        let available = self.available_actions.join(", ");
        let prompt = format!(
            "Current context: default\n\
             Available actions: {}\n\n\
             What is the best sequence of actions to resolve this issue?",
            available
        );
        let action_sequence = actions.join(" -> ");
        let response = format!(
            "Based on the context, the optimal action sequence is: {}",
            action_sequence
        );
        Ok(TrainingData::sft(&self.system_prompt, &prompt, &response)
            .with_episode_id(episode.id.to_string())
            .with_outcome_score(episode.outcome.score()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
    use crate::types::WorkerId;
    use std::time::Duration;

    /// Builds an `ActionEvent` bound to `task_id` with the given outcome.
    fn make_action_with_task(
        tick: u64,
        worker_id: usize,
        action: &str,
        success: bool,
        task_id: TaskId,
    ) -> ActionEvent {
        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .task_id(task_id)
            .result(if success {
                ActionEventResult::success()
            } else {
                ActionEventResult::failure("error")
            })
            .duration(Duration::from_millis(10))
            .context(ActionContext::new())
            .build()
    }

    /// Wraps a slice of events into generic `Record`s.
    fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
        actions.iter().map(Record::from).collect()
    }

    #[test]
    fn test_worker_decision_sequence_learn_build_episodes() {
        let learn = WorkerDecisionSequenceLearn::new().with_min_actions(3);
        let task_id = TaskId::new();
        let events = vec![
            make_action_with_task(1, 0, "CheckStatus", true, task_id),
            make_action_with_task(2, 0, "ReadLogs", true, task_id),
            make_action_with_task(3, 0, "Restart", true, task_id),
            // System events should be excluded from the learned sequence.
            make_action_with_task(4, 0, "tick_end", true, task_id),
            make_action_with_task(5, 0, "done", true, task_id),
        ];
        let records = make_records(&events);
        let episodes = learn.build_episodes(&records);
        assert_eq!(episodes.len(), 1);
        assert!(episodes[0].outcome.is_success());
        assert_eq!(episodes[0].task_id, Some(task_id));
        // Only the three non-system actions survive.
        assert_eq!(episodes[0].context.iter::<ActionRecord>().count(), 3);
    }

    #[test]
    fn test_worker_decision_sequence_learn_filters_failed_actions() {
        let learn = WorkerDecisionSequenceLearn::new().with_min_actions(3);
        let task_id = TaskId::new();
        let events = vec![
            make_action_with_task(1, 0, "CheckStatus", true, task_id),
            make_action_with_task(2, 0, "ReadLogs", false, task_id),
            make_action_with_task(3, 0, "Restart", true, task_id),
        ];
        let records = make_records(&events);
        // Two successes fall below the min_actions = 3 threshold.
        let episodes = learn.build_episodes(&records);
        assert_eq!(episodes.len(), 0);
    }

    #[test]
    fn test_worker_decision_sequence_learn_convert() {
        let learn = WorkerDecisionSequenceLearn::new().with_available_actions(
            ["A", "B", "C"].iter().map(|s| s.to_string()).collect(),
        );
        let episode = Episode::builder()
            .learn_model("worker_decision_sequence")
            .record(ActionRecord::new(1, 0, "A").success(true))
            .record(ActionRecord::new(2, 0, "B").success(true))
            .record(ActionRecord::new(3, 0, "C").success(true))
            .outcome(Outcome::success(1.0))
            .build();
        let result = learn.convert(&episode);
        assert!(result.is_ok());
        let training_data = result.unwrap();
        assert!(training_data.prompt.contains("Available actions: A, B, C"));
        assert!(training_data.chosen.contains("A -> B -> C"));
    }
}