use serde::{Deserialize, Serialize};
/// Benchmark domain a task set belongs to.
///
/// Derives `Hash` in addition to `Eq` so it can key a `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Domain {
    /// The `"retail"` domain.
    Retail,
    /// The `"airline"` domain.
    Airline,
}
impl Domain {
    /// Returns the lowercase name of the domain (`"retail"` or `"airline"`).
    ///
    /// This is a `const fn`, so it is also usable in constant contexts.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Retail => "retail",
            Self::Airline => "airline",
        }
    }
}
/// A single benchmark task as read from the task JSON file.
///
/// Only the keys this crate consumes are typed. Any other top-level key is
/// swallowed by a private catch-all map instead of causing an error.
#[derive(Debug, Clone, Deserialize)]
pub struct Task {
    /// Task identifier; a JSON string (e.g. `"0"`).
    pub id: String,
    /// What the simulated user is instructed to do.
    pub user_scenario: UserScenario,
    /// Grading criteria; `None` when the key is absent or `null`.
    pub evaluation_criteria: Option<EvaluationCriteria>,
    // Catch-all for every top-level key not modelled above. It stays private,
    // and the leading underscore marks it as intentionally unread.
    #[serde(flatten)]
    _rest: serde_json::Map<String, serde_json::Value>,
}
/// The user-side setup of a task: instructions plus an optional persona.
#[derive(Debug, Clone, Deserialize)]
pub struct UserScenario {
    /// Instructions, either structured or plain text.
    pub instructions: UserInstructions,
    /// Optional persona description; `None` when the key is missing.
    #[serde(default)]
    pub persona: Option<String>,
}
/// Instructions handed to the simulated user, in one of two JSON shapes.
///
/// Deserialized untagged, trying variants in declaration order. A JSON object
/// is parsed as [`StructuredUserInstructions`]; a JSON string becomes
/// [`UserInstructions::Plain`].
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum UserInstructions {
    /// Object form with named sections.
    Structured(StructuredUserInstructions),
    /// Free-form instruction text.
    Plain(String),
}
/// Object-form user instructions split into named sections.
///
/// `domain`, `reason_for_call` and `task_instructions` are required.
/// `known_info` and `unknown_info` fall back to `None` when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct StructuredUserInstructions {
    /// Domain name exactly as written in the data (e.g. `"retail"`).
    pub domain: String,
    /// Why the user is getting in touch.
    pub reason_for_call: String,
    /// The concrete instructions the user should follow.
    pub task_instructions: String,
    /// Information the user is given up front, if any.
    #[serde(default)]
    pub known_info: Option<String>,
    /// Information the user does not have, if any.
    #[serde(default)]
    pub unknown_info: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EvaluationCriteria {
#[serde(default)]
pub actions: Vec<Action>,
#[serde(default)]
pub reward_basis: Vec<String>,
}
/// One expected tool call in the evaluation criteria.
///
/// Only `action_id` and `name` are required. Every other field has a fallback
/// when missing: `requestor` becomes `"assistant"`, `arguments` an empty map,
/// and `compare_args`/`info` become `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Action {
    /// Identifier of this action within the task.
    pub action_id: String,
    /// Who issues the call; defaults to `"assistant"`.
    #[serde(default = "default_requestor")]
    pub requestor: String,
    /// Tool name (e.g. `"cancel_pending_order"`).
    pub name: String,
    /// Tool arguments as an arbitrary JSON object.
    #[serde(default)]
    pub arguments: serde_json::Map<String, serde_json::Value>,
    /// Argument names to compare; `None` when not specified.
    #[serde(default)]
    pub compare_args: Option<Vec<String>>,
    /// Optional free-text note about the action.
    #[serde(default)]
    pub info: Option<String>,
}
/// Serde fallback for `Action::requestor` when the key is absent.
fn default_requestor() -> String {
    String::from("assistant")
}
#[cfg(test)]
mod tests {
    use super::*;
    // `r##` is required because the JSON contains the sequence `"#`
    // (e.g. `"#W1234567"`), which would terminate an `r#"..."#` literal.
    const RETAIL_FIXTURE: &str = r##"[
  {
    "id": "0",
    "user_scenario": {
      "instructions": {
        "domain": "retail",
        "reason_for_call": "I need to cancel an order",
        "task_instructions": "Cancel order #W1234567",
        "known_info": "Order id: #W1234567"
      },
      "persona": "Impatient customer"
    },
    "evaluation_criteria": {
      "actions": [
        {
          "action_id": "a1",
          "requestor": "assistant",
          "name": "cancel_pending_order",
          "arguments": {"order_id": "#W1234567", "reason": "no_longer_needed"},
          "compare_args": ["order_id", "reason"]
        }
      ],
      "reward_basis": ["ACTION"]
    }
  },
  {
    "id": "1",
    "user_scenario": {
      "instructions": "Plain string instructions for a simple task"
    },
    "evaluation_criteria": {
      "actions": [],
      "reward_basis": ["ACTION"]
    }
  }
]"##;
    #[test]
    fn parse_structured_instructions() {
        let tasks: Vec<Task> = serde_json::from_str(RETAIL_FIXTURE).unwrap();
        assert_eq!(tasks.len(), 2);
        assert_eq!(tasks[0].id, "0");
        assert_eq!(
            tasks[0].user_scenario.persona.as_deref(),
            Some("Impatient customer")
        );
        match &tasks[0].user_scenario.instructions {
            UserInstructions::Structured(s) => {
                assert_eq!(s.domain, "retail");
                assert_eq!(s.reason_for_call, "I need to cancel an order");
                assert!(s.known_info.is_some());
                // Absent optional section must come back as `None`.
                assert_eq!(s.unknown_info, None);
            }
            UserInstructions::Plain(_) => panic!("expected structured"),
        }
    }
    #[test]
    fn parse_plain_instructions() {
        let tasks: Vec<Task> = serde_json::from_str(RETAIL_FIXTURE).unwrap();
        match &tasks[1].user_scenario.instructions {
            UserInstructions::Plain(s) => assert!(s.contains("Plain string")),
            UserInstructions::Structured(_) => panic!("expected plain"),
        }
        assert!(tasks[1].user_scenario.persona.is_none());
    }
    #[test]
    fn parse_evaluation_criteria() {
        let tasks: Vec<Task> = serde_json::from_str(RETAIL_FIXTURE).unwrap();
        let criteria = tasks[0].evaluation_criteria.as_ref().unwrap();
        assert_eq!(criteria.actions.len(), 1);
        assert_eq!(criteria.actions[0].name, "cancel_pending_order");
        assert_eq!(
            criteria.actions[0].compare_args,
            Some(vec!["order_id".to_owned(), "reason".to_owned()])
        );
        assert_eq!(criteria.reward_basis, vec!["ACTION".to_owned()]);
    }
    #[test]
    fn metadata_roundtrip() {
        let tasks: Vec<Task> = serde_json::from_str(RETAIL_FIXTURE).unwrap();
        let criteria = tasks[0].evaluation_criteria.as_ref().unwrap();
        let value = serde_json::to_value(criteria).unwrap();
        let back: EvaluationCriteria = serde_json::from_value(value).unwrap();
        assert_eq!(back.actions.len(), 1);
        assert_eq!(back.actions[0].name, "cancel_pending_order");
        assert_eq!(back.actions[0].arguments, criteria.actions[0].arguments);
        assert_eq!(back.actions[0].compare_args, criteria.actions[0].compare_args);
        assert_eq!(back.reward_basis, criteria.reward_basis);
    }
    #[test]
    fn action_defaults_applied_when_fields_missing() {
        let action: Action =
            serde_json::from_str(r#"{"action_id": "a1", "name": "get_user_details"}"#).unwrap();
        assert_eq!(action.requestor, "assistant");
        assert!(action.arguments.is_empty());
        assert_eq!(action.compare_args, None);
        assert_eq!(action.info, None);
    }
    #[test]
    fn missing_optional_task_fields_are_none() {
        let task: Task =
            serde_json::from_str(r#"{"id": "x", "user_scenario": {"instructions": "do it"}}"#)
                .unwrap();
        assert_eq!(task.id, "x");
        assert!(task.evaluation_criteria.is_none());
        assert!(task.user_scenario.persona.is_none());
    }
    #[test]
    fn unknown_top_level_fields_are_captured() {
        let task: Task = serde_json::from_str(
            r#"{"id": "x", "description": {"purpose": "p"}, "user_scenario": {"instructions": "do it"}}"#,
        )
        .unwrap();
        assert_eq!(task.id, "x");
        assert!(task._rest.contains_key("description"));
        assert!(!task._rest.contains_key("id"));
    }
    #[test]
    fn domain_as_str() {
        assert_eq!(Domain::Retail.as_str(), "retail");
        assert_eq!(Domain::Airline.as_str(), "airline");
    }
}