use std::fs::OpenOptions;
use std::io::{BufRead, BufReader};
use arrrg::CommandLine;
use rand::prelude::*;
use policyai::data::{ConflictField, InjectableAction};
use policyai::{Field, OnConflict, PolicyType};
// Command-line options for the test-data generator.
//
// The four `String` options are file-system paths; the three `usize` options
// control how many data points are produced and how each one is composed.
#[derive(Clone, Default, Debug, arrrg_derive::CommandLine)]
struct Options {
// Path to a file read line by line in `main`; each line is deserialized as
// one JSON `DecidableSemanticInjection`.
#[arrrg(required, "The decidable semantic injections.")]
decidables: String,
// Path to a file read line by line in `main`; each line is deserialized as
// one JSON `InjectableAction`.
#[arrrg(required, "The actions.")]
actions: String,
// Number of iterations of the generation loop in `main`; every iteration
// prints one test data point.
#[arrrg(required, "This many texts will be selected to have policies applied.")]
samples: usize,
// Path to the file whose contents are handed to `PolicyType::parse`.
#[arrrg(required, "The policy type definition.")]
policy: String,
// Total number of policies attached to each generated data point.
#[arrrg(required, "This many policies will be selected per text.")]
policies: usize,
// Number of policies per data point whose prompt is drawn from the
// injection's positives; `main` asserts the injection has at least this many.
#[arrrg(required, "This many policies will be enforced to match per text.")]
matching: usize,
// Probability handed to `random_bool` to decide whether a sample becomes a
// conflict test case; `main` substitutes 0.0 when the option is absent.
#[arrrg(
optional,
"Rate of test cases that should contain conflicts (0.0 to 1.0)."
)]
conflict_rate: Option<f64>,
}
// Marker impl: `Eq` adds no methods, it only promises that the hand-written
// `PartialEq` for `Options` is a full equivalence relation. It is written out
// by hand because `f64` (inside `conflict_rate`) does not implement `Eq`.
// NOTE(review): the epsilon comparison used for `conflict_rate` is not
// reflexive for NaN and not transitive in general — confirm that is acceptable.
impl Eq for Options {}
/// Field-by-field equality for [`Options`].
///
/// Every field except `conflict_rate` is compared exactly. Two conflict rates
/// are considered equal when both are absent, or when both are present and
/// differ by less than `f64::EPSILON`.
impl PartialEq for Options {
    fn eq(&self, other: &Self) -> bool {
        // Floating-point rates get a tolerance instead of bitwise equality.
        let rates_agree = match (self.conflict_rate, other.conflict_rate) {
            (Some(lhs), Some(rhs)) => (lhs - rhs).abs() < f64::EPSILON,
            (None, None) => true,
            (Some(_), None) | (None, Some(_)) => false,
        };
        let mine = (
            &self.decidables,
            &self.actions,
            self.samples,
            &self.policy,
            self.policies,
            self.matching,
        );
        let theirs = (
            &other.decidables,
            &other.actions,
            other.samples,
            &other.policy,
            other.policies,
            other.matching,
        );
        mine == theirs && rates_agree
    }
}
/// Generates policy test data and prints it to stdout, one JSON document per
/// line.
///
/// The decidable semantic injections and the injectable actions are read as
/// line-delimited JSON, the policy type definition is parsed from its own
/// file, and then `options.samples` data points are produced. For each sample
/// a random injection is chosen and, with probability `conflict_rate`
/// (0.0 when the option is absent), a conflict test case is generated instead
/// of a normal one.
///
/// Invalid arguments or unusable input data are reported on stderr and the
/// process exits with status 13.
///
/// # Errors
///
/// Returns an I/O error when an input file cannot be opened or read, or when
/// a line fails to deserialize (the `serde_json` error is converted into an
/// `std::io::Error`).
///
/// # Panics
///
/// Panics when the policy type definition does not parse, or when the chosen
/// injection has fewer positives/negatives than the requested policy counts.
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
    let (options, free) =
        Options::from_command_line_relaxed("USAGE: policyai-generate-test-data [OPTIONS]");
    if !free.is_empty() {
        eprintln!("command takes no positional arguments");
        std::process::exit(13);
    }
    // Reject this up front: `policies - matching` would underflow otherwise.
    if options.matching > options.policies {
        eprintln!(
            "matching ({}) must not exceed policies ({})",
            options.matching, options.policies
        );
        std::process::exit(13);
    }
    // `random_bool` panics for probabilities outside [0, 1]; NaN also fails
    // this range check.
    let conflict_rate = options.conflict_rate.unwrap_or(0.0);
    if !(0.0..=1.0).contains(&conflict_rate) {
        eprintln!("conflict rate must be between 0.0 and 1.0, got {conflict_rate}");
        std::process::exit(13);
    }
    let non_matching = options.policies - options.matching;

    let mut semantic_injections = vec![];
    for line in open_for_reading(&options.decidables)?.lines() {
        let line = line?;
        let injection: policyai::data::DecidableSemanticInjection = serde_json::from_str(&line)?;
        semantic_injections.push(injection);
    }
    let mut actions = vec![];
    for line in open_for_reading(&options.actions)?.lines() {
        let line = line?;
        let action: policyai::data::InjectableAction = serde_json::from_str(&line)?;
        actions.push(action);
    }
    if options.samples > 0 && options.policies > 0 && actions.is_empty() {
        eprintln!("no actions were loaded from {}", options.actions);
        std::process::exit(13);
    }

    let policy_source = std::fs::read_to_string(&options.policy).map_err(|err| {
        std::io::Error::new(
            err.kind(),
            format!("could not read {}: {err}", options.policy),
        )
    })?;
    let policy_type =
        PolicyType::parse(&policy_source).expect("policy type definition failed to parse");

    // The set of agreement fields depends only on the policy type, so compute
    // it once instead of once per conflicting sample.
    let agreement_fields: Vec<&Field> = policy_type
        .fields
        .iter()
        .filter(|field| {
            matches!(
                field,
                Field::Bool {
                    on_conflict: OnConflict::Agreement,
                    ..
                } | Field::String {
                    on_conflict: OnConflict::Agreement,
                    ..
                } | Field::StringEnum {
                    on_conflict: OnConflict::Agreement,
                    ..
                } | Field::Number {
                    on_conflict: OnConflict::Agreement,
                    ..
                }
            )
        })
        .collect();

    let mut rng = rand::rng();
    for _ in 0..options.samples {
        let Some(injection) = semantic_injections.choose(&mut rng) else {
            eprintln!(
                "no semantic injections were loaded from {}",
                options.decidables
            );
            std::process::exit(13);
        };
        assert!(
            injection.positives.len() >= options.matching,
            "injection has {} positives but {} matching policies were requested",
            injection.positives.len(),
            options.matching,
        );
        assert!(
            injection.negatives.len() >= non_matching,
            "injection has {} negatives but {} non-matching policies were requested",
            injection.negatives.len(),
            non_matching,
        );
        if rng.random_bool(conflict_rate) {
            if agreement_fields.is_empty() {
                eprintln!("no fields require agreement; cannot generate conflicts");
                std::process::exit(13);
            }
            generate_conflict_test_case(
                &mut rng,
                injection,
                &actions,
                &policy_type,
                &options,
                &agreement_fields,
            );
        } else {
            generate_normal_test_case(&mut rng, injection, &actions, &policy_type, &options);
        }
    }
    Ok(())
}

/// Opens `path` read-only behind a buffered reader, naming the path in the
/// returned error when the open fails.
fn open_for_reading(path: &str) -> Result<BufReader<std::fs::File>, std::io::Error> {
    OpenOptions::new()
        .read(true)
        .open(path)
        .map(BufReader::new)
        .map_err(|err| std::io::Error::new(err.kind(), format!("could not open {path}: {err}")))
}
/// Builds one conflict-free test data point for `injection` and prints it to
/// stdout as a single line of JSON.
///
/// The first `options.matching` policies pair a prompt drawn from
/// `injection.positives` with an action that does not contradict the values
/// accumulated in `expected` so far; the fields of each such action are merged
/// into `expected`. The remaining policies, up to `options.policies`, pair a
/// prompt drawn from `injection.negatives` with an arbitrary action and leave
/// `expected` untouched. Fields still missing from `expected` are filled from
/// `policy_type.default_value()`, and the policies are shuffled before
/// printing.
///
/// When a required prompt or action cannot be drawn (an empty list, or no
/// action compatible with `expected`), the reason is written to stderr and the
/// process exits with status 13 rather than retrying forever.
fn generate_normal_test_case(
    rng: &mut impl Rng,
    injection: &policyai::data::DecidableSemanticInjection,
    actions: &[InjectableAction],
    policy_type: &PolicyType,
    options: &Options,
) {
    /// True when `action` is a JSON object and none of its fields contradicts
    /// a value already present in `output`.
    fn is_compatible_action(output: &serde_json::Value, action: &InjectableAction) -> bool {
        let serde_json::Value::Object(obj) = &action.action else {
            return false;
        };
        obj.iter().all(|(key, value)| match output.get(key) {
            Some(seen) => seen == value,
            None => true,
        })
    }

    /// Reports why generation cannot proceed and terminates the process.
    fn give_up(reason: &str) -> ! {
        eprintln!("{reason}");
        std::process::exit(13);
    }

    let mut policies = vec![];
    let mut expected = serde_json::json!({});
    for index in 0..options.policies {
        let (prompt, action) = if index < options.matching {
            let Some(prompt) = injection.positives.choose(rng) else {
                give_up("cannot generate matching policies: the injection has no positives");
            };
            // Choosing uniformly among the compatible actions yields the same
            // distribution as retrying random picks, but terminates when no
            // compatible action exists.
            let compatible: Vec<&InjectableAction> = actions
                .iter()
                .filter(|candidate| is_compatible_action(&expected, candidate))
                .collect();
            let Some(&action) = compatible.choose(rng) else {
                give_up(
                    "cannot generate matching policies: no action is compatible with the expected output",
                );
            };
            if let serde_json::Value::Object(obj) = &action.action {
                for (key, value) in obj {
                    expected[key] = value.clone();
                }
            }
            (prompt, action)
        } else {
            let Some(prompt) = injection.negatives.choose(rng) else {
                give_up("cannot generate non-matching policies: the injection has no negatives");
            };
            let Some(action) = actions.choose(rng) else {
                give_up("cannot generate non-matching policies: no actions are available");
            };
            (prompt, action)
        };
        let inject = &action.inject;
        policies.push(policyai::Policy {
            r#type: policy_type.clone(),
            prompt: format!("<match>{prompt}</match><action>{inject}</action>"),
            action: action.action.clone(),
        });
    }
    // Fields no matching policy touched take the policy type's default.
    if let Some(map) = policy_type.default_value().as_object() {
        if let Some(as_obj) = expected.as_object_mut() {
            for (key, value) in map.iter() {
                as_obj.entry(key.clone()).or_insert_with(|| value.clone());
            }
        }
    }
    // Shuffle so a policy's position does not reveal whether it matches.
    policies.shuffle(rng);
    println!(
        "{}",
        serde_json::to_string(&policyai::data::TestDataPoint {
            text: injection.text.clone(),
            policies,
            expected: Some(expected),
            conflicts: None,
        })
        .unwrap()
    );
}
/// Builds one test data point containing agreement conflicts for `injection`
/// and prints it to stdout as a single line of JSON.
///
/// `expected` is seeded with every non-null field default of `policy_type`.
/// A random non-empty subset of `agreement_fields` is then selected; for each
/// selected field, between two and four policies (bounded by
/// `options.matching` and by the room left under `options.policies`) are
/// created whose actions assign pairwise different values to that field, each
/// paired with a prompt from `injection.positives`. A `ConflictField` is
/// recorded only when at least two such policies could be created. The
/// remaining slots are filled with policies built from `injection.negatives`,
/// and the policies are shuffled before printing.
///
/// When a required prompt or action cannot be drawn, or `agreement_fields` is
/// empty, the reason is written to stderr and the process exits with
/// status 13.
fn generate_conflict_test_case(
    rng: &mut impl Rng,
    injection: &policyai::data::DecidableSemanticInjection,
    actions: &[InjectableAction],
    policy_type: &PolicyType,
    options: &Options,
    agreement_fields: &[&Field],
) {
    /// Reports why generation cannot proceed and terminates the process.
    fn give_up(reason: &str) -> ! {
        eprintln!("{reason}");
        std::process::exit(13);
    }

    /// Draws up to `wanted` actions that each set `field_name` to a value no
    /// previously drawn action uses. Returns fewer when the actions do not
    /// offer that many distinct values, so the search always terminates.
    fn pick_distinct_actions<'a>(
        rng: &mut impl Rng,
        actions: &'a [InjectableAction],
        field_name: &str,
        wanted: usize,
    ) -> Vec<&'a InjectableAction> {
        // Every action that sets the field, paired with the value it sets.
        let mut pool: Vec<(&serde_json::Value, &InjectableAction)> = actions
            .iter()
            .filter_map(|action| Some((action.action.as_object()?.get(field_name)?, action)))
            .collect();
        let mut chosen = Vec::new();
        while chosen.len() < wanted {
            let Some(&(value, action)) = pool.choose(rng) else {
                break;
            };
            chosen.push(action);
            // Drop every action that would repeat the value just used.
            pool.retain(|(other, _)| *other != value);
        }
        chosen
    }

    if agreement_fields.is_empty() {
        give_up("no fields require agreement; cannot generate conflicts");
    }
    let mut policies = vec![];
    let mut conflicts = vec![];
    let mut expected = serde_json::json!({});
    for field in policy_type.fields.iter() {
        let name = field.name();
        let default = field.default_value();
        if !default.is_null() {
            expected[name] = default;
        }
    }
    // A conflict needs at least two differing values. Clamping keeps the
    // sampled range non-empty even when fewer than two matches were requested.
    let max_values = options.matching.clamp(2, 4);
    let num_conflict_fields = rng.random_range(1..=agreement_fields.len());
    let selected_fields: Vec<&Field> = agreement_fields
        .choose_multiple(rng, num_conflict_fields)
        .cloned()
        .collect();
    for field in selected_fields {
        let field_name = field.name();
        let room = options.policies.saturating_sub(policies.len());
        let wanted = rng.random_range(2..=max_values).min(room);
        let chosen = pick_distinct_actions(rng, actions, field_name, wanted);
        if chosen.len() < 2 {
            // Fewer than two distinct values cannot disagree; skip the field
            // instead of recording a conflict that does not exist.
            continue;
        }
        conflicts.push(ConflictField {
            conflict_type: "agreement".to_string(),
            field_name: field_name.to_string(),
        });
        for action in chosen {
            let Some(prompt) = injection.positives.choose(rng) else {
                give_up("cannot generate conflicting policies: the injection has no positives");
            };
            let inject = &action.inject;
            policies.push(policyai::Policy {
                r#type: policy_type.clone(),
                prompt: format!("<match>{prompt}</match><action>{inject}</action>"),
                action: action.action.clone(),
            });
        }
    }
    if conflicts.is_empty() {
        eprintln!("warning: could not construct a conflict for any selected field");
    }
    while policies.len() < options.policies {
        let Some(prompt) = injection.negatives.choose(rng) else {
            give_up("cannot generate filler policies: the injection has no negatives");
        };
        let Some(action) = actions.choose(rng) else {
            give_up("cannot generate filler policies: no actions are available");
        };
        let inject = &action.inject;
        // Use the same prompt layout as every other policy so that fillers
        // cannot be told apart by their formatting alone.
        policies.push(policyai::Policy {
            r#type: policy_type.clone(),
            prompt: format!("<match>{prompt}</match><action>{inject}</action>"),
            action: action.action.clone(),
        });
    }
    // Shuffle so a policy's position does not reveal its role.
    policies.shuffle(rng);
    println!(
        "{}",
        serde_json::to_string(&policyai::data::TestDataPoint {
            text: injection.text.clone(),
            policies,
            expected: Some(expected),
            conflicts: Some(conflicts),
        })
        .unwrap()
    );
}