// tower_llm/validation/gen.rs
use async_openai::types::*;
use proptest::prelude::*;
/// Configuration for [`valid_conversation`], the proptest generator of chat transcripts.
///
/// NOTE(review): `valid_conversation` in this file reads only `must_have_system`,
/// `must_have_tool_calls` and `min_tool_calls`; the remaining fields are carried but not
/// enforced there — confirm whether another consumer relies on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratorConfig {
    /// Presumably the minimum number of messages per conversation.
    /// NOTE(review): not enforced by `valid_conversation`.
    pub min_messages: usize,
    /// Presumably the maximum number of messages per conversation.
    /// NOTE(review): not enforced by `valid_conversation`.
    pub max_messages: usize,
    /// When `true`, `valid_conversation` starts every transcript with a system message.
    pub must_have_system: bool,
    /// When `true`, `valid_conversation` gives every assistant turn tool calls instead of
    /// deciding per turn at random.
    pub must_have_tool_calls: bool,
    /// Number of tool calls emitted by each tool-call turn in `valid_conversation`;
    /// values below 1 are treated as 1.
    pub min_tool_calls: usize,
    /// Presumably whether developer-role messages may appear.
    /// NOTE(review): not read by `valid_conversation`.
    pub allow_developer: bool,
    /// Presumably whether tool replies must follow their calls in order.
    /// NOTE(review): not read by `valid_conversation`, which always emits replies in
    /// call order.
    pub enforce_tool_order: bool,
}

impl Default for GeneratorConfig {
    /// Message bounds 2 and 12, a mandatory system message, tool calls not forced,
    /// `min_tool_calls` of 0, developer messages disallowed, tool-order flag set.
    fn default() -> Self {
        Self {
            min_messages: 2,
            max_messages: 12,
            must_have_system: true,
            must_have_tool_calls: false,
            min_tool_calls: 0,
            allow_developer: false,
            enforce_tool_order: true,
        }
    }
}
30
31pub fn valid_conversation(
33 cfg: GeneratorConfig,
34) -> impl Strategy<Value = Vec<ChatCompletionRequestMessage>> {
35 let turn = (any::<bool>(), any::<bool>()); let turns = proptest::collection::vec(turn, 1..=3);
38 let sys_flag = proptest::strategy::Just(cfg.must_have_system);
39 (sys_flag, turns).prop_map(move |(must_sys, turns)| {
40 let mut msgs: Vec<ChatCompletionRequestMessage> = Vec::new();
41 if must_sys {
42 let sys = ChatCompletionRequestSystemMessageArgs::default()
43 .content("sys")
44 .build()
45 .unwrap();
46 msgs.push(sys.into());
47 }
48 let usr = ChatCompletionRequestUserMessageArgs::default()
50 .content("hi")
51 .build()
52 .unwrap();
53 msgs.push(usr.into());
54
55 let mut tool_id_counter = 1usize;
56 let last_turn = turns.len().saturating_sub(1);
57 for (idx, (with_tools, with_text)) in turns.into_iter().enumerate() {
58 if with_tools || cfg.must_have_tool_calls {
60 let min_calls = cfg.min_tool_calls.max(1);
61 let num_calls = std::cmp::max(min_calls, 1);
62 let mut calls: Vec<ChatCompletionMessageToolCall> = Vec::new();
63 for _ in 0..num_calls {
64 let id = format!("c{}", tool_id_counter);
65 tool_id_counter += 1;
66 calls.push(ChatCompletionMessageToolCall {
67 id: id.clone(),
68 r#type: ChatCompletionToolType::Function,
69 function: FunctionCall {
70 name: "tool".into(),
71 arguments: "{}".into(),
72 },
73 });
74 }
75 let asst = ChatCompletionRequestAssistantMessageArgs::default()
76 .content("")
77 .tool_calls(calls.clone())
78 .build()
79 .unwrap();
80 msgs.push(asst.into());
81 for tc in calls.into_iter() {
83 let t = ChatCompletionRequestToolMessageArgs::default()
84 .tool_call_id(tc.id)
85 .content("{}")
86 .build()
87 .unwrap();
88 msgs.push(t.into());
89 }
90 } else {
91 let content = if with_text { "ok" } else { "" };
92 let asst = ChatCompletionRequestAssistantMessageArgs::default()
93 .content(content)
94 .build()
95 .unwrap();
96 msgs.push(asst.into());
97 }
98 if idx != last_turn {
100 let u = ChatCompletionRequestUserMessageArgs::default()
101 .content("next")
102 .build()
103 .unwrap();
104 msgs.push(u.into());
105 }
106 }
107 msgs
108 })
109}