use crate::models::ModelId;
use crate::scenario::{Backend, Format, Modality, QaScenario, TraceLevel};
use proptest::prelude::*;
/// Strategy producing [`ModelId`]s drawn from a fixed catalogue of
/// `(org, name)` pairs.
///
/// Each model name is kept together with the organisation listed next to it,
/// so a generated id is always one of the catalogue entries. Drawing org and
/// name independently would produce mismatched combinations such as
/// `google/Llama-3.2-1B-Instruct`.
pub fn model_id_strategy() -> impl Strategy<Value = ModelId> {
    // (org, model name) pairs; keep both halves of each entry in sync.
    const KNOWN_MODELS: &[(&str, &str)] = &[
        ("Qwen", "Qwen2.5-Coder-1.5B-Instruct"),
        ("meta-llama", "Llama-3.2-1B-Instruct"),
        ("microsoft", "Phi-3-mini-4k-instruct"),
        ("google", "gemma-2b-it"),
        ("mistralai", "Mistral-7B-Instruct-v0.3"),
        ("deepseek-ai", "deepseek-coder-1.3b-instruct"),
        ("TinyLlama", "TinyLlama-1.1B-Chat-v1.0"),
    ];
    prop::sample::select(KNOWN_MODELS).prop_map(|(org, name)| ModelId::new(org, name))
}
/// Strategy picking one of the [`Modality`] variants `Run`, `Chat` or `Serve`,
/// each with equal weight.
pub fn modality_strategy() -> impl Strategy<Value = Modality> {
    prop_oneof![
        1 => Just(Modality::Run),
        1 => Just(Modality::Chat),
        1 => Just(Modality::Serve),
    ]
}
/// Strategy picking either the CPU or the GPU [`Backend`], with equal weight.
pub fn backend_strategy() -> impl Strategy<Value = Backend> {
    prop_oneof![
        1 => Just(Backend::Cpu),
        1 => Just(Backend::Gpu),
    ]
}
/// Strategy picking one of the [`Format`] variants `Gguf`, `SafeTensors` or
/// `Apr`, each with equal weight.
pub fn format_strategy() -> impl Strategy<Value = Format> {
    prop_oneof![
        1 => Just(Format::Gguf),
        1 => Just(Format::SafeTensors),
        1 => Just(Format::Apr),
    ]
}
/// Strategy picking one of the four [`TraceLevel`] variants (`None`, `Basic`,
/// `Layer`, `Payload`), each with equal weight.
pub fn trace_level_strategy() -> impl Strategy<Value = TraceLevel> {
    prop_oneof![
        1 => Just(TraceLevel::None),
        1 => Just(TraceLevel::Basic),
        1 => Just(TraceLevel::Layer),
        1 => Just(TraceLevel::Payload),
    ]
}
/// Strategy producing arithmetic questions shaped like `What is 12+34?`.
///
/// Both operands are integers in `1..100` (i.e. 1 to 99), and the operator is
/// one of `+`, `-` or `*`.
pub fn arithmetic_prompt_strategy() -> impl Strategy<Value = String> {
    let lhs = 1i32..100;
    let rhs = 1i32..100;
    let operator = prop::sample::select(vec!['+', '-', '*']);
    (lhs, rhs, operator).prop_map(|(x, y, sym)| format!("What is {x}{sym}{y}?"))
}
/// Strategy producing one of a small set of code-completion openers
/// (Python, Rust and JavaScript snippets), returned as an owned `String`.
pub fn code_prompt_strategy() -> impl Strategy<Value = String> {
    prop_oneof![
        Just("def fibonacci(n):"),
        Just("fn main() {"),
        Just("async function fetch() {"),
        Just("class Person:"),
        Just("impl Iterator for"),
        Just("pub struct Config {"),
    ]
    .prop_map(String::from)
}
/// Strategy producing adversarial or unusual prompts, all with equal weight:
/// empty and whitespace-only text, non-Latin scripts and emoji, a
/// 10 000-character string, and HTML/SQL injection-looking payloads.
pub fn edge_case_prompt_strategy() -> impl Strategy<Value = String> {
    prop_oneof![
        // Blank input.
        Just(String::new()),
        Just(" ".to_string()),
        Just("\n\n\n".to_string()),
        // Non-ASCII text.
        Just("你好世界".to_string()),
        Just("مرحبا بالعالم".to_string()),
        Just("🎉🚀💻".to_string()),
        // Very long input.
        Just("a".repeat(10000)),
        // Injection-shaped payloads.
        Just("<script>alert('xss')</script>".to_string()),
        Just("'; DROP TABLE users; --".to_string()),
    ]
}
pub fn any_prompt_strategy() -> impl Strategy<Value = String> {
prop_oneof![
3 => arithmetic_prompt_strategy(),
2 => code_prompt_strategy(),
1 => edge_case_prompt_strategy(),
2 => "[a-zA-Z0-9 ]{1,100}".prop_map(|s| s),
]
}
/// Strategy producing fully random [`QaScenario`]s.
///
/// Every field is drawn from its own strategy: model, modality, backend,
/// format, prompt, and a seed in `0..1000`.
pub fn scenario_strategy() -> impl Strategy<Value = QaScenario> {
    let seed = 0u64..1000;
    let parts = (
        model_id_strategy(),
        modality_strategy(),
        backend_strategy(),
        format_strategy(),
        any_prompt_strategy(),
        seed,
    );
    parts.prop_map(|(id, mode, be, fmt, text, s)| QaScenario::new(id, mode, be, fmt, text, s))
}
/// Strategy producing [`QaScenario`]s that all target the given `model`.
///
/// Modality, backend, format, prompt and the seed (`0..1000`) vary. `model` is
/// cloned into each generated scenario.
pub fn scenario_for_model_strategy(model: ModelId) -> impl Strategy<Value = QaScenario> {
    let seed = 0u64..1000;
    let parts = (
        modality_strategy(),
        backend_strategy(),
        format_strategy(),
        any_prompt_strategy(),
        seed,
    );
    parts.prop_map(move |(mode, be, fmt, text, s)| {
        let target = model.clone();
        QaScenario::new(target, mode, be, fmt, text, s)
    })
}
/// Strategy producing sampling temperatures.
///
/// Four equally weighted arms: the fixed values 0.0, 0.7 and 1.0, and any
/// value from the half-open range `[0.0, 2.0)`.
pub fn temperature_strategy() -> impl Strategy<Value = f32> {
    prop_oneof![
        1 => Just(0.0f32),
        1 => Just(0.7f32),
        1 => Just(1.0f32),
        1 => 0.0f32..2.0,
    ]
}
/// Strategy producing `max_tokens` limits.
///
/// Six equally weighted arms: the fixed values 1, 32, 128, 512 and 2048, and
/// any value from the half-open range `1..4096`.
pub fn max_tokens_strategy() -> impl Strategy<Value = u32> {
    prop_oneof![
        1 => Just(1u32),
        1 => Just(32u32),
        1 => Just(128u32),
        1 => Just(512u32),
        1 => Just(2048u32),
        1 => 1u32..4096,
    ]
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::strategy::ValueTree;
    use proptest::test_runner::TestRunner;

    #[test]
    fn test_model_id_strategy_generates_valid() {
        let mut runner = TestRunner::default();
        for _ in 0..100 {
            let model = model_id_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            assert!(!model.org.is_empty());
            assert!(!model.name.is_empty());
        }
    }

    #[test]
    fn test_scenario_strategy_generates_valid() {
        let mut runner = TestRunner::default();
        for _ in 0..100 {
            let scenario = scenario_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            assert!(!scenario.id.is_empty());
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 100,
            failure_persistence: Some(Box::new(proptest::test_runner::FileFailurePersistence::WithSource("proptest-regressions"))),
            ..ProptestConfig::default()
        })]

        #[test]
        fn prop_arithmetic_prompts_contain_operator(prompt in arithmetic_prompt_strategy()) {
            prop_assert!(
                prompt.contains('+') || prompt.contains('-') || prompt.contains('*'),
                "Prompt should contain arithmetic operator: {}", prompt
            );
        }

        #[test]
        fn prop_scenarios_have_valid_id(scenario in scenario_strategy()) {
            prop_assert!(!scenario.id.is_empty());
            prop_assert!(scenario.id.contains('_'));
        }

        #[test]
        fn prop_temperature_in_range(temp in temperature_strategy()) {
            prop_assert!(temp >= 0.0);
            prop_assert!(temp <= 2.0);
        }

        #[test]
        fn prop_max_tokens_positive(tokens in max_tokens_strategy()) {
            prop_assert!(tokens >= 1);
            prop_assert!(tokens <= 4096);
        }

        #[test]
        fn prop_scenario_command_is_valid(scenario in scenario_strategy()) {
            let cmd = scenario.to_command("model.gguf");
            prop_assert!(!cmd.is_empty());
            prop_assert!(cmd.contains("apr"));
        }
    }

    #[test]
    fn test_modality_strategy_generates_valid() {
        let mut runner = TestRunner::default();
        for _ in 0..50 {
            let modality = modality_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            assert!(matches!(
                modality,
                Modality::Run | Modality::Chat | Modality::Serve
            ));
        }
    }

    #[test]
    fn test_backend_strategy_generates_valid() {
        let mut runner = TestRunner::default();
        for _ in 0..50 {
            let backend = backend_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            assert!(matches!(backend, Backend::Cpu | Backend::Gpu));
        }
    }

    #[test]
    fn test_format_strategy_generates_valid() {
        let mut runner = TestRunner::default();
        for _ in 0..50 {
            let format = format_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            assert!(matches!(
                format,
                Format::Gguf | Format::SafeTensors | Format::Apr
            ));
        }
    }

    #[test]
    fn test_trace_level_strategy_generates_valid() {
        let mut runner = TestRunner::default();
        for _ in 0..50 {
            let level = trace_level_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            assert!(matches!(
                level,
                TraceLevel::None | TraceLevel::Basic | TraceLevel::Layer | TraceLevel::Payload
            ));
        }
    }

    #[test]
    fn test_code_prompt_strategy_generates_code() {
        let mut runner = TestRunner::default();
        for _ in 0..50 {
            let prompt = code_prompt_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            assert!(
                prompt.contains("def ")
                    || prompt.contains("fn ")
                    || prompt.contains("async ")
                    || prompt.contains("class ")
                    || prompt.contains("impl ")
                    || prompt.contains("pub ")
            );
        }
    }

    /// Every edge-case category (blank, non-ASCII, very long) must show up
    /// within 100 draws. Requiring all of them, rather than any one, makes the
    /// test meaningful. The deterministic runner keeps the outcome
    /// reproducible.
    #[test]
    fn test_edge_case_prompt_strategy_generates_edge_cases() {
        let mut runner = TestRunner::deterministic();
        let mut seen_empty = false;
        let mut seen_unicode = false;
        let mut seen_long = false;
        for _ in 0..100 {
            let prompt = edge_case_prompt_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            seen_empty |= prompt.trim().is_empty();
            seen_unicode |= !prompt.is_ascii();
            seen_long |= prompt.len() > 1000;
        }
        assert!(seen_empty, "no blank or whitespace-only prompt was generated");
        assert!(seen_unicode, "no non-ASCII prompt was generated");
        assert!(seen_long, "no long (>1000 bytes) prompt was generated");
    }

    #[test]
    fn test_scenario_for_model_strategy() {
        let model = ModelId::new("test", "model");
        let mut runner = TestRunner::default();
        for _ in 0..50 {
            let scenario = scenario_for_model_strategy(model.clone())
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            assert_eq!(scenario.model.org, "test");
            assert_eq!(scenario.model.name, "model");
        }
    }

    #[test]
    fn test_temperature_strategy_generates_valid_range() {
        let mut runner = TestRunner::default();
        for _ in 0..50 {
            let temp = temperature_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            assert!(temp >= 0.0);
            assert!(temp <= 2.0);
        }
    }

    #[test]
    fn test_max_tokens_strategy_generates_valid_range() {
        let mut runner = TestRunner::default();
        for _ in 0..50 {
            let tokens = max_tokens_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            assert!(tokens >= 1);
            assert!(tokens <= 4096);
        }
    }

    /// The mixed prompt strategy must yield arithmetic, code and other prompts.
    ///
    /// The previous `a || b || c` assertion could never fail because the
    /// `else` branch always set one flag. Classification now keys on the exact
    /// shapes the sub-strategies emit: arithmetic prompts are `What is ...?`
    /// (the free-form regex cannot produce `?`), and code prompts start with a
    /// keyword. This stops e.g. the SQL edge case's `--` from counting as
    /// arithmetic.
    #[test]
    fn test_any_prompt_strategy_generates_variety() {
        const CODE_PREFIXES: [&str; 6] = ["def ", "fn ", "async ", "class ", "impl ", "pub "];
        let mut runner = TestRunner::deterministic();
        let mut has_arithmetic = false;
        let mut has_code = false;
        let mut has_other = false;
        for _ in 0..100 {
            let prompt = any_prompt_strategy()
                .new_tree(&mut runner)
                .expect("Failed to generate")
                .current();
            if prompt.starts_with("What is ") && prompt.ends_with('?') {
                has_arithmetic = true;
            } else if CODE_PREFIXES.iter().any(|kw| prompt.starts_with(kw)) {
                has_code = true;
            } else {
                has_other = true;
            }
        }
        assert!(has_arithmetic, "no arithmetic prompt was generated");
        assert!(has_code, "no code prompt was generated");
        assert!(has_other, "no edge-case or free-form prompt was generated");
    }
}