use super::{LearnError, LearnModel};
use crate::learn::episode::{EpisodeContext, Outcome};
use crate::learn::record::DependencyGraphRecord;
use crate::learn::training::TrainingData;
/// Learn model that turns `DependencyGraphRecord`s into SFT training samples
/// for inferring the execution order of actions.
#[derive(Debug, Clone)]
pub struct DependencyGraphLearnModel {
    /// System prompt attached to every training sample this model produces.
    pub system_prompt: String,
}
impl Default for DependencyGraphLearnModel {
    /// Builds a model whose system prompt is an owned copy of
    /// [`DependencyGraphLearnModel::default_system_prompt`].
    fn default() -> Self {
        let system_prompt = String::from(Self::default_system_prompt());
        Self { system_prompt }
    }
}
impl DependencyGraphLearnModel {
    /// Creates a model that uses the built-in system prompt
    /// (same as `Default::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Builder-style setter: replaces the system prompt and hands the model back.
    pub fn with_system_prompt(self, prompt: impl Into<String>) -> Self {
        let mut updated = self;
        updated.system_prompt = prompt.into();
        updated
    }

    /// Built-in system prompt.
    ///
    /// It asks for the actions to be split into discover / not_discover
    /// categories and for one execution order per category.
    pub fn default_system_prompt() -> &'static str {
        const PROMPT: &str = r#"You are an expert at analyzing action dependencies for task execution.
Given a list of available actions with their descriptions, determine the correct execution order.
Actions are categorized as:
- NodeExpand (discover): Actions that discover new targets (e.g., Search, List)
- NodeStateChange (not_discover): Actions that modify state (e.g., Read, Submit)
Output the sorted order for each category based on:
1. Dependency relationships (e.g., "requires X first" in descriptions)
2. Logical execution flow (start actions first, terminal actions last)
Format your response as:
discover_order: [action1, action2, ...]
not_discover_order: [action1, action2, ...]"#;
        PROMPT
    }

    /// Returns a closure that yields the `(prompt, response)` pair from the
    /// first `DependencyGraphRecord` in an episode's context.
    ///
    /// The closure returns `None` when the context holds no such record.
    pub fn extractor() -> impl Fn(&crate::learn::episode::Episode) -> Option<(String, String)> {
        |episode| {
            let record = episode.context.first::<DependencyGraphRecord>()?;
            Some((record.prompt.clone(), record.response.clone()))
        }
    }
}
impl LearnModel for DependencyGraphLearnModel {
    /// Identifier of this learn model.
    fn name(&self) -> &str {
        "dependency_graph"
    }

    /// Human-readable description of what this model learns.
    fn objective(&self) -> &str {
        "Learn correct DependencyGraph inference from action descriptions"
    }

    /// Builds one episode per record yielded by
    /// `RecordStream::dependency_graphs`.
    ///
    /// Each episode's outcome comes from [`Self::evaluate`] on a context that
    /// holds just that record.
    fn build_episodes(
        &self,
        records: &[crate::learn::record::Record],
    ) -> Vec<crate::learn::episode::Episode> {
        use crate::learn::episode::Episode;
        use crate::learn::record::RecordStream;
        let stream = RecordStream::new(records);
        let mut episodes = Vec::new();
        for record in stream.dependency_graphs() {
            let context = EpisodeContext::new().with_record(record.clone());
            let outcome = self.evaluate(&context);
            let episode = Episode::builder()
                .learn_model("dependency_graph")
                .context(context)
                .outcome(outcome)
                .build();
            episodes.push(episode);
        }
        episodes
    }

    /// Scores a context by its `DependencyGraphRecord`s.
    ///
    /// Returns:
    /// - `Outcome::Unknown` when the context holds none.
    /// - A binary success when every record reports `is_success()`.
    /// - A failure otherwise.
    fn evaluate(&self, context: &EpisodeContext) -> Outcome {
        // Walk the records lazily; peeking once is enough to detect the empty
        // case, so there is no need to collect into a Vec first.
        let mut records = context.iter::<DependencyGraphRecord>().peekable();
        if records.peek().is_none() {
            return Outcome::Unknown;
        }
        if records.all(|r| r.is_success()) {
            Outcome::success_binary()
        } else {
            Outcome::failure("DependencyGraph inference failed")
        }
    }

    /// Converts an episode into an SFT sample.
    ///
    /// The sample uses this model's system prompt plus the prompt/response of
    /// the episode's first `DependencyGraphRecord`. The scenario falls back to
    /// `"unknown"` when the episode metadata has no scenario name.
    ///
    /// # Errors
    /// Returns `LearnError::MissingData` if the episode context contains no
    /// `DependencyGraphRecord`.
    fn convert(
        &self,
        episode: &crate::learn::episode::Episode,
    ) -> Result<TrainingData, LearnError> {
        let record = episode
            .context
            .first::<DependencyGraphRecord>()
            .ok_or_else(|| LearnError::MissingData("DependencyGraphRecord".into()))?;
        let scenario = episode
            .metadata
            .scenario_name
            .as_deref()
            .unwrap_or("unknown");
        Ok(
            TrainingData::sft(&self.system_prompt, &record.prompt, &record.response)
                .with_scenario(scenario)
                .with_episode_id(episode.id.to_string()),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::episode::Episode;
    use crate::learn::record::{ActionRecord, Record};

    /// Runs `build_episodes` on a freshly constructed default model.
    fn episodes_for(records: Vec<Record>) -> Vec<Episode> {
        DependencyGraphLearnModel::new().build_episodes(&records)
    }

    #[test]
    fn test_default_system_prompt() {
        let prompt = DependencyGraphLearnModel::new().system_prompt;
        assert!(prompt.contains("action dependencies"));
    }

    #[test]
    fn test_extractor() {
        let context = EpisodeContext::new().with_record(
            DependencyGraphRecord::new("test-model")
                .prompt("test prompt")
                .response("test response")
                .discover_order(vec!["A".into()]),
        );
        let episode = Episode::builder()
            .learn_model("dependency_graph")
            .context(context)
            .build();
        let extract = DependencyGraphLearnModel::extractor();
        let (prompt, response) =
            extract(&episode).expect("extractor should find the DependencyGraphRecord");
        assert_eq!(prompt, "test prompt");
        assert_eq!(response, "test response");
    }

    #[test]
    fn test_build_episodes_success() {
        let ok = DependencyGraphRecord::new("test-model")
            .prompt("What is the dependency order?")
            .response("discover_order: [A, B]")
            .discover_order(vec!["A".into(), "B".into()]);
        let episodes = episodes_for(vec![Record::from(ok)]);
        assert_eq!(episodes.len(), 1);
        assert!(episodes[0].outcome.is_success());
    }

    #[test]
    fn test_build_episodes_failure() {
        let broken = DependencyGraphRecord::new("test-model")
            .prompt("What is the dependency order?")
            .response("parse error")
            .error("Failed to parse response");
        let episodes = episodes_for(vec![Record::from(broken)]);
        assert_eq!(episodes.len(), 1);
        assert!(episodes[0].outcome.is_failure());
    }

    #[test]
    fn test_build_episodes_mixed() {
        let episodes = episodes_for(vec![
            Record::from(DependencyGraphRecord::new("model").discover_order(vec!["A".into()])),
            Record::from(DependencyGraphRecord::new("model").error("error")),
        ]);
        assert_eq!(episodes.len(), 2);
        let (successes, failures) = episodes.iter().fold((0usize, 0usize), |(s, f), ep| {
            (
                s + usize::from(ep.outcome.is_success()),
                f + usize::from(ep.outcome.is_failure()),
            )
        });
        assert_eq!(successes, 1);
        assert_eq!(failures, 1);
    }

    #[test]
    fn test_build_episodes_ignores_other_records() {
        let episodes = episodes_for(vec![
            Record::from(ActionRecord::new(1, 0, "CheckStatus").success(true)),
            Record::from(DependencyGraphRecord::new("model").discover_order(vec!["A".into()])),
        ]);
        assert_eq!(episodes.len(), 1);
    }
}