use std::collections::HashMap;
use super::super::episode::{Episode, EpisodeContext, Outcome};
use super::super::record::{ActionRecord, Record, RecordStream};
use super::super::training::TrainingData;
use super::{LearnError, LearnModel};
use crate::types::TaskId;
/// System prompt used by [`WorkerTaskLearn::new`] unless overridden.
const DEFAULT_SYSTEM_PROMPT: &str =
    "You are an intelligent agent that diagnoses and resolves system issues.";

/// Minimum action count used by [`WorkerTaskLearn::new`] unless overridden.
const DEFAULT_MIN_ACTIONS: usize = 2;

/// Learn model for whole worker task sequences.
///
/// Holds the system prompt attached to generated training samples and the
/// minimum number of actions a sequence must contain to be used.
#[derive(Debug, Clone)]
pub struct WorkerTaskLearn {
    /// System prompt attached to every generated training sample.
    system_prompt: String,
    /// Minimum number of actions a sequence must contain to be used.
    min_actions: usize,
}

impl WorkerTaskLearn {
    /// Creates a learner with the default system prompt and a minimum of two actions.
    pub fn new() -> Self {
        Self {
            system_prompt: DEFAULT_SYSTEM_PROMPT.to_string(),
            min_actions: DEFAULT_MIN_ACTIONS,
        }
    }

    /// Replaces the system prompt used for generated training samples.
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Sets the minimum number of actions a sequence must contain to be used.
    #[must_use]
    pub fn with_min_actions(mut self, min: usize) -> Self {
        self.min_actions = min;
        self
    }
}

impl Default for WorkerTaskLearn {
    /// Equivalent to [`WorkerTaskLearn::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl LearnModel for WorkerTaskLearn {
    fn name(&self) -> &str {
        "worker_task"
    }

    fn objective(&self) -> &str {
        "Learn complete worker task sequences from start to done"
    }

    /// Scores a context by its last [`ActionRecord`].
    ///
    /// - empty context -> failure
    /// - last action is terminal and successful -> binary success
    /// - last action is terminal and failed -> failure ("Task failed")
    /// - otherwise (no action records, or last action not terminal) -> `Outcome::Unknown`
    fn evaluate(&self, context: &EpisodeContext) -> Outcome {
        if context.is_empty() {
            return Outcome::failure("Empty context: no actions to evaluate");
        }
        match context.iter::<ActionRecord>().last() {
            Some(action) if action.is_terminal() => {
                if action.success {
                    Outcome::success_binary()
                } else {
                    Outcome::failure("Task failed")
                }
            }
            _ => Outcome::Unknown,
        }
    }

    /// Groups action records by task and cuts each task's history into
    /// sequences that end at a terminal action.
    ///
    /// A sequence becomes an episode only if it holds at least `min_actions`
    /// actions, with the terminal action included. Actions after a task's last
    /// terminal action are discarded. Each episode gets the `group_id` of the
    /// task's first action, if present.
    ///
    /// Episodes are emitted in the order each task first appears in `records`,
    /// so the output is deterministic for a given input.
    fn build_episodes(&self, records: &[Record]) -> Vec<Episode> {
        let stream = RecordStream::new(records);

        // HashMap iteration order is unspecified; remember first-seen order so
        // the produced episodes are reproducible.
        let mut task_order: Vec<TaskId> = Vec::new();
        let mut task_actions: HashMap<TaskId, Vec<&ActionRecord>> = HashMap::new();
        for record in stream.actions() {
            task_actions
                .entry(record.task_id)
                .or_insert_with(|| {
                    task_order.push(record.task_id);
                    Vec::new()
                })
                .push(record);
        }

        let mut episodes = Vec::new();
        for task_id in task_order {
            // Every id in `task_order` was inserted above, so the default is never used.
            let task_records = task_actions.remove(&task_id).unwrap_or_default();
            let group_id = task_records.first().and_then(|r| r.group_id);
            let mut current_sequence: Vec<&ActionRecord> = Vec::new();
            for record in task_records {
                current_sequence.push(record);
                if !record.is_terminal() {
                    continue;
                }
                if current_sequence.len() >= self.min_actions {
                    let mut context = EpisodeContext::new();
                    for r in &current_sequence {
                        context.push((*r).clone());
                    }
                    let outcome = self.evaluate(&context);
                    let mut builder = Episode::builder()
                        .learn_model("worker_task")
                        .task_id(task_id)
                        .context(context)
                        .outcome(outcome);
                    if let Some(gid) = group_id {
                        builder = builder.group_id(gid);
                    }
                    episodes.push(builder.build());
                }
                // Start a fresh sequence after every terminal action, even if
                // the previous one was too short to keep.
                current_sequence.clear();
            }
        }
        episodes
    }

    /// Converts a successful episode into an SFT sample.
    ///
    /// The prompt lists the episode's non-terminal action names. The response
    /// lists them again as an ordered `a -> b -> c` sequence.
    ///
    /// # Errors
    ///
    /// Returns [`LearnError::InvalidEpisode`] if:
    /// - the outcome is not a success;
    /// - the episode holds fewer than `min_actions` actions;
    /// - the episode has no non-terminal actions, which would produce an
    ///   empty sequence.
    fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError> {
        if !episode.outcome.is_success() {
            return Err(LearnError::InvalidEpisode(
                "Episode is not successful".into(),
            ));
        }
        let action_count = episode.context.iter::<ActionRecord>().count();
        if action_count < self.min_actions {
            return Err(LearnError::InvalidEpisode(format!(
                "Too few actions: {} < {}",
                action_count, self.min_actions
            )));
        }
        let actions: Vec<&str> = episode
            .context
            .iter::<ActionRecord>()
            .filter(|a| !a.is_terminal())
            .map(|a| a.action.as_str())
            .collect();
        if actions.is_empty() {
            return Err(LearnError::InvalidEpisode(
                "Episode has no non-terminal actions".into(),
            ));
        }
        let prompt = format!(
            "Diagnose and resolve the issue.\nAvailable actions: {}",
            actions.join(", ")
        );
        let response = format!("Execute the following sequence: {}", actions.join(" -> "));
        Ok(TrainingData::sft(&self.system_prompt, &prompt, &response)
            .with_episode_id(episode.id.to_string())
            .with_outcome_score(episode.outcome.score()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
    use crate::types::WorkerId;
    use std::time::Duration;

    /// Builds an `ActionEvent` for `task_id` with a 10ms duration and an empty
    /// action context. `success == false` yields a failure result with message `"error"`.
    fn make_action_with_task(
        tick: u64,
        worker_id: usize,
        action: &str,
        success: bool,
        task_id: TaskId,
    ) -> ActionEvent {
        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("error")
        };
        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .task_id(task_id)
            .result(result)
            .duration(Duration::from_millis(10))
            .context(ActionContext::new())
            .build()
    }

    /// Converts action events into `Record`s via `Record::from`.
    fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
        actions.iter().map(Record::from).collect()
    }

    #[test]
    fn test_worker_task_learn_build_episodes() {
        let learn = WorkerTaskLearn::new().with_min_actions(2);
        let task1_id = TaskId::new();
        let task2_id = TaskId::new();
        let actions = vec![
            make_action_with_task(1, 0, "CheckStatus", true, task1_id),
            make_action_with_task(2, 0, "ReadLogs", true, task1_id),
            make_action_with_task(3, 0, "done", true, task1_id),
            make_action_with_task(4, 1, "Grep", true, task2_id),
            make_action_with_task(5, 1, "done", false, task2_id),
        ];
        let records = make_records(&actions);
        let episodes = learn.build_episodes(&records);
        assert_eq!(episodes.len(), 2);

        let task1_ep = episodes
            .iter()
            .find(|ep| ep.task_id == Some(task1_id))
            .expect("task1 should produce an episode");
        assert!(task1_ep.outcome.is_success());

        let task2_ep = episodes
            .iter()
            .find(|ep| ep.task_id == Some(task2_id))
            .expect("task2 should produce an episode");
        assert!(task2_ep.outcome.is_failure());
    }

    #[test]
    fn test_worker_task_learn_build_episodes_skips_short_and_unterminated() {
        let learn = WorkerTaskLearn::new().with_min_actions(2);
        let short_task = TaskId::new();
        let open_task = TaskId::new();
        let actions = vec![
            // Terminated, but a single action is below `min_actions`.
            make_action_with_task(1, 0, "done", true, short_task),
            // Never reaches a terminal action, so no episode is emitted.
            make_action_with_task(2, 1, "CheckStatus", true, open_task),
            make_action_with_task(3, 1, "ReadLogs", true, open_task),
        ];
        let records = make_records(&actions);
        let episodes = learn.build_episodes(&records);
        assert!(episodes.is_empty());
    }

    #[test]
    fn test_worker_task_learn_evaluate_empty_context() {
        let learn = WorkerTaskLearn::new();
        assert!(learn.evaluate(&EpisodeContext::new()).is_failure());
    }

    #[test]
    fn test_worker_task_learn_convert_success_only() {
        let learn = WorkerTaskLearn::new();
        let failed_ep = Episode::builder()
            .learn_model("worker_task")
            .outcome(Outcome::failure("test"))
            .build();
        assert!(learn.convert(&failed_ep).is_err());

        let success_ep = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, 0, "Check").success(true))
            .record(ActionRecord::new(2, 0, "Fix").success(true))
            .record(ActionRecord::new(3, 0, "done").success(true))
            .outcome(Outcome::success_binary())
            .build();
        assert!(learn.convert(&success_ep).is_ok());
    }
}