use serde::{Deserialize, Serialize};
use serde_valid::Validate;
use std::collections::HashMap;
/// A named, versioned collection of test cases.
///
/// Validation (through `serde_valid::Validate`) requires a non-empty `name`,
/// at least one entry in `test_cases`, and that every contained [`TestCase`]
/// passes its own validation rules.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct Dataset {
    /// Dataset name; must be non-empty to validate.
    #[validate(min_length = 1)]
    pub name: String,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Version string supplied by the caller; no format is enforced here.
    pub version: String,
    /// Test cases in this dataset; must hold at least one item to validate.
    // The bare `#[validate]` makes validation recurse into each `TestCase`.
    // With only `min_items`, the `min_length` rules on `TestCase::id` and
    // `TestCase::prompt` were never checked when validating a `Dataset`.
    #[validate(min_items = 1)]
    #[validate]
    pub test_cases: Vec<TestCase>,
    /// Optional dataset-level default settings.
    pub defaults: Option<DefaultConfig>,
    /// Optional free-form JSON metadata keyed by string.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// A single test case: an identifier, a prompt, and optional extra data.
///
/// Validation (through `serde_valid::Validate`) requires `id` and `prompt`
/// to be non-empty; the other fields carry no validation rules.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct TestCase {
/// Test case identifier; must be non-empty to validate.
#[validate(min_length = 1)]
pub id: String,
/// Optional category label; `Dataset::filter_by_category` matches it exactly.
pub category: Option<String>,
/// Prompt text; must be non-empty to validate.
#[validate(min_length = 1)]
pub prompt: String,
/// Optional string-to-string variables.
/// NOTE(review): presumably substituted into `prompt` placeholders — confirm against the consumer.
pub variables: Option<HashMap<String, String>>,
/// Optional expected output; how it is compared is not defined in this file.
pub expected: Option<String>,
/// Optional list of reference strings.
/// NOTE(review): presumably reference answers or sources — confirm against the consumer.
pub references: Option<Vec<String>>,
/// Optional per-test configuration.
pub config: Option<TestConfig>,
/// Optional free-form JSON metadata keyed by string.
pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Optional settings stored in `Dataset::defaults`.
///
/// Every field is optional, and a field that is `None` is omitted from the
/// serialized output (`skip_serializing_if = "Option::is_none"`).
/// NOTE(review): how these combine with a per-test `TestConfig` is not
/// defined in this file — confirm against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DefaultConfig {
/// Temperature setting (presumably the sampling temperature — TODO confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Token limit (presumably the maximum number of generated tokens — TODO confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<usize>,
/// `top_p` setting (presumably nucleus-sampling probability mass — TODO confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Stop strings (presumably sequences that end generation — TODO confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
}
/// Optional settings stored in `TestCase::config`.
///
/// Every field is optional, and a field that is `None` is omitted from the
/// serialized output (`skip_serializing_if = "Option::is_none"`).
/// NOTE(review): how these combine with the dataset-level `DefaultConfig` is
/// not defined in this file — confirm against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestConfig {
/// Model identifier (the tests in this file use `"gpt-4"`).
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Temperature setting (presumably the sampling temperature — TODO confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Token limit (presumably the maximum number of generated tokens — TODO confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<usize>,
/// `top_p` setting (presumably nucleus-sampling probability mass — TODO confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Stop strings (presumably sequences that end generation — TODO confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
}
// The `with_*` builders take `self` by value and hand it back, so ignoring a
// return value silently throws the dataset away; `#[must_use]` turns that
// mistake into a compiler warning.
impl Dataset {
    /// Creates a dataset with the given name and version.
    ///
    /// The description, defaults and metadata start as `None` and the list of
    /// test cases starts empty.
    #[must_use]
    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: None,
            version: version.into(),
            test_cases: Vec::new(),
            defaults: None,
            metadata: None,
        }
    }

    /// Sets the description, replacing any previous one, and returns the dataset.
    #[must_use]
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Appends `test_case` to the end of the dataset's test cases.
    pub fn add_test_case(&mut self, test_case: TestCase) {
        self.test_cases.push(test_case);
    }

    /// Sets the default configuration, replacing any previous one, and returns
    /// the dataset.
    #[must_use]
    pub fn with_defaults(mut self, defaults: DefaultConfig) -> Self {
        self.defaults = Some(defaults);
        self
    }

    /// Returns references to the test cases whose category equals `category`.
    ///
    /// The comparison is exact. Test cases without a category never match, and
    /// the result keeps the dataset's insertion order.
    #[must_use]
    pub fn filter_by_category(&self, category: &str) -> Vec<&TestCase> {
        self.test_cases
            .iter()
            .filter(|tc| tc.category.as_deref() == Some(category))
            .collect()
    }

    /// Returns the number of test cases in the dataset.
    #[must_use]
    pub fn len(&self) -> usize {
        self.test_cases.len()
    }

    /// Returns `true` when the dataset holds no test cases.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.test_cases.is_empty()
    }
}
// Every builder here (including `add_variable`, despite its mutator-like
// name) takes `self` by value and returns it, so a dropped return value loses
// the test case; `#[must_use]` turns that mistake into a compiler warning.
impl TestCase {
    /// Creates a test case with the given id and prompt; every optional field
    /// starts as `None`.
    #[must_use]
    pub fn new(id: impl Into<String>, prompt: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            category: None,
            prompt: prompt.into(),
            variables: None,
            expected: None,
            references: None,
            config: None,
            metadata: None,
        }
    }

    /// Sets the category, replacing any previous one, and returns the test case.
    #[must_use]
    pub fn with_category(mut self, category: impl Into<String>) -> Self {
        self.category = Some(category.into());
        self
    }

    /// Replaces the whole variable map with `variables` and returns the test case.
    #[must_use]
    pub fn with_variables(mut self, variables: HashMap<String, String>) -> Self {
        self.variables = Some(variables);
        self
    }

    /// Inserts one variable and returns the test case.
    ///
    /// The variable map is created on first use; an existing value stored
    /// under the same key is overwritten.
    #[must_use]
    pub fn add_variable(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.variables
            .get_or_insert_with(HashMap::new)
            .insert(key.into(), value.into());
        self
    }

    /// Sets the expected output, replacing any previous one, and returns the
    /// test case.
    #[must_use]
    pub fn with_expected(mut self, expected: impl Into<String>) -> Self {
        self.expected = Some(expected.into());
        self
    }

    /// Replaces the reference list with `references` and returns the test case.
    #[must_use]
    pub fn with_references(mut self, references: Vec<String>) -> Self {
        self.references = Some(references);
        self
    }

    /// Sets the per-test configuration, replacing any previous one, and
    /// returns the test case.
    #[must_use]
    pub fn with_config(mut self, config: TestConfig) -> Self {
        self.config = Some(config);
        self
    }
}
// The `with_*` builders take `self` by value and return it, so a dropped
// return value discards the configuration; `#[must_use]` turns that mistake
// into a compiler warning.
impl DefaultConfig {
    /// Creates a configuration with every field set to `None`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            temperature: None,
            max_tokens: None,
            top_p: None,
            stop: None,
        }
    }

    /// Sets `temperature` and returns the configuration.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `max_tokens` and returns the configuration.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets `top_p` and returns the configuration.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Replaces the stop list with `stop` and returns the configuration.
    #[must_use]
    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }
}
impl Default for DefaultConfig {
fn default() -> Self {
Self::new()
}
}
// The `with_*` builders take `self` by value and return it, so a dropped
// return value discards the configuration; `#[must_use]` turns that mistake
// into a compiler warning.
impl TestConfig {
    /// Creates a configuration with every field set to `None`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            model: None,
            temperature: None,
            max_tokens: None,
            top_p: None,
            stop: None,
        }
    }

    /// Sets `model` and returns the configuration.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets `temperature` and returns the configuration.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `max_tokens` and returns the configuration.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets `top_p` and returns the configuration.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Replaces the stop list with `stop` and returns the configuration.
    #[must_use]
    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }
}
impl Default for TestConfig {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Dataset` from a struct literal (bypassing the builders) so the
    /// validation tests control the name and test cases directly.
    fn dataset_with(name: &str, test_cases: Vec<TestCase>) -> Dataset {
        Dataset {
            name: name.to_string(),
            description: None,
            version: "1.0.0".to_string(),
            test_cases,
            defaults: None,
            metadata: None,
        }
    }

    /// A single well-formed test case used by the validation tests.
    fn one_case() -> Vec<TestCase> {
        vec![TestCase::new("test-1", "prompt")]
    }

    #[test]
    fn test_dataset_creation() {
        let built = Dataset::new("test-dataset", "1.0.0").with_description("Test description");

        assert_eq!(built.name, "test-dataset");
        assert_eq!(built.version, "1.0.0");
        assert_eq!(built.description.as_deref(), Some("Test description"));
        assert!(built.test_cases.is_empty());
    }

    #[test]
    fn test_dataset_validation_empty_name() {
        // An empty name violates `min_length = 1`.
        let outcome = dataset_with("", one_case()).validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn test_dataset_validation_no_test_cases() {
        // An empty test-case list violates `min_items = 1`.
        let outcome = dataset_with("test", Vec::new()).validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn test_dataset_validation_valid() {
        let outcome = dataset_with("test", one_case()).validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_test_case_creation() {
        let case = TestCase::new("test-1", "What is Rust?")
            .with_category("qa")
            .with_expected("Rust is a systems programming language");

        assert_eq!(case.id, "test-1");
        assert_eq!(case.prompt, "What is Rust?");
        assert_eq!(case.category.as_deref(), Some("qa"));
        assert!(case.expected.is_some());
    }

    #[test]
    fn test_test_case_with_variables() {
        let case = TestCase::new("test-1", "Explain {{topic}}").add_variable("topic", "ownership");

        let vars = case.variables.as_ref().expect("variables should be set");
        let topic = vars.get("topic").expect("topic should be present");
        assert_eq!(topic, "ownership");
    }

    #[test]
    fn test_filter_by_category() {
        let mut dataset = Dataset::new("test", "1.0.0");
        let specs = [("t1", "prompt1", "coding"), ("t2", "prompt2", "qa"), ("t3", "prompt3", "coding")];
        for (id, prompt, category) in specs {
            dataset.add_test_case(TestCase::new(id, prompt).with_category(category));
        }

        let matched = dataset.filter_by_category("coding");
        assert_eq!(matched.len(), 2);
    }

    #[test]
    fn test_default_config() {
        let cfg = DefaultConfig::new().with_temperature(0.7).with_max_tokens(500);

        assert_eq!(cfg.temperature, Some(0.7));
        assert_eq!(cfg.max_tokens, Some(500));
    }

    #[test]
    fn test_test_config() {
        let cfg = TestConfig::new().with_model("gpt-4").with_temperature(0.0);

        assert_eq!(cfg.model.as_deref(), Some("gpt-4"));
        assert_eq!(cfg.temperature, Some(0.0));
    }
}