1use serde::Deserialize;
2use serde_yaml::Value;
3
4#[derive(Debug, Deserialize)]
5struct PromptType {
6 #[serde(rename = "type")]
7 prompt_type: String,
8}
9
10#[derive(Debug, Deserialize)]
11pub struct CompletionExampleColumn {
12 pub name: String,
13 pub values: Vec<String>,
14 pub test: Option<String>,
15}
16
17#[derive(Debug, Deserialize)]
18pub struct ChatExample {
19 pub input: String,
20 pub output: Option<String>,
21}
22
23#[derive(Debug, Deserialize)]
24pub struct Message {
25 pub input: String,
26 pub output: Option<String>,
27}
28
29#[derive(Debug, Deserialize)]
30pub struct Parameter {
31 pub name: String,
32 pub value: Value,
33}
34
35#[derive(Debug, Deserialize)]
36pub struct Completion {
37 #[serde(rename = "type")]
38 pub prompt_type: String,
39 pub vendor: String,
40 pub model: String,
41 pub prompt: String,
42 pub parameters: Option<Vec<Parameter>>,
43 pub examples: Option<Vec<CompletionExampleColumn>>,
44}
45
46pub fn find_parameter(
47 parameters: &Option<Vec<crate::prompt::Parameter>>,
48 name: &str,
49) -> Option<Value> {
50 if let Some(real_parameters) = parameters {
51 real_parameters
52 .iter()
53 .find(|p| p.name == name)
54 .map(|x| x.value.clone())
55 } else {
56 None
57 }
58}
59
60impl Completion {
61 pub fn example_count(&self) -> usize {
62 let mut max_length = 0;
63 if let Some(columns) = &self.examples {
64 for column in columns {
65 if column.values.len() > max_length {
66 max_length = column.values.len()
67 }
68 }
69 }
70 max_length
71 }
72
73 pub fn final_prompt(&self) -> String {
74 let mut prompt = self.prompt.clone();
75 prompt.push_str("\n\n");
76 if let Some(columns) = &self.examples {
77 for i in 0..self.example_count() {
78 for column in columns {
79 let line: String = format!(
80 "{}: {}\n",
81 column.name,
82 column.values.get(i).unwrap_or(&"".to_string())
83 );
84 prompt.push_str(&line);
85 }
86 prompt.push('\n');
87 }
88 for column in columns {
89 let line: String = format!(
90 "{}: {}\n",
91 column.name,
92 column.test.as_ref().unwrap_or(&"".to_string())
93 );
94 prompt.push_str(&line);
95 }
96 }
97 prompt.to_string()
98 }
99
100 pub fn find_parameter_as_i32(&self, name: &str) -> Option<i32> {
101 find_parameter(&self.parameters, name).map(|p| p.as_i64().unwrap() as i32)
102 }
103
104 pub fn find_parameter_as_f32(&self, name: &str) -> Option<f32> {
105 find_parameter(&self.parameters, name).map(|p| p.as_f64().unwrap() as f32)
106 }
107
108 pub fn find_parameter_as_str(&self, name: &str) -> Option<String> {
109 find_parameter(&self.parameters, name).map(|p| p.as_str().unwrap().to_string())
110 }
111
112 pub fn find_parameter_as_bool(&self, name: &str) -> Option<bool> {
113 find_parameter(&self.parameters, name).map(|p| p.as_bool().unwrap())
114 }
115}
116
117#[derive(Debug, Deserialize)]
118pub struct Chat {
119 #[serde(rename = "type")]
120 pub prompt_type: String,
121 pub vendor: String,
122 pub model: String,
123 pub parameters: Option<Vec<Parameter>>,
124 pub examples: Option<Vec<ChatExample>>,
125 pub context: Option<String>,
126 pub messages: Option<Vec<Message>>,
127}
128
129impl Chat {
130 pub fn find_parameter_as_i32(&self, name: &str) -> Option<i32> {
131 find_parameter(&self.parameters, name).map(|p| p.as_i64().unwrap() as i32)
132 }
133
134 pub fn find_parameter_as_f32(&self, name: &str) -> Option<f32> {
135 find_parameter(&self.parameters, name).map(|p| p.as_f64().unwrap() as f32)
136 }
137
138 pub fn find_parameter_as_str(&self, name: &str) -> Option<String> {
139 find_parameter(&self.parameters, name).map(|p| p.as_str().unwrap().to_string())
140 }
141
142 pub fn find_parameter_as_bool(&self, name: &str) -> Option<bool> {
143 find_parameter(&self.parameters, name).map(|p| p.as_bool().unwrap())
144 }
145}
146
147#[derive(Debug, Deserialize)]
148pub enum Prompt {
149 Completion(Completion),
150 Chat(Chat),
151 Unknown,
152}
153
154pub fn deserialize_prompt(yaml: &str) -> Prompt {
155 let prompt_type: PromptType = serde_yaml::from_str(&yaml).unwrap();
156 match prompt_type.prompt_type.as_str() {
157 "completion" => {
158 let completion: Completion = serde_yaml::from_str(&yaml).unwrap();
159 Prompt::Completion(completion)
160 }
161 "chat" => {
162 let chat: Chat = serde_yaml::from_str(&yaml).unwrap();
163 Prompt::Chat(chat)
164 }
165 _ => Prompt::Unknown,
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Completion prompt shared by the completion tests (kept flush-left so YAML
    /// indentation is unambiguous).
    const COMPLETION_YAML: &str = r#"
type: completion
vendor: google
model: text-bison
prompt: Write a hello world in java
parameters:
  - name: maxOutputTokens
    value: 256
  - name: temperature
    value: 0.4
examples:
  - name: input
    values:
      - a
      - b
    test: c
  - name: output
    values:
      - x
      - y
"#;

    const CHAT_YAML: &str = r#"
type: chat
vendor: google
model: chat-bison
parameters:
  - name: maxOutputTokens
    value: 256
  - name: temperature
    value: 0.4
examples:
  - input: who are u?
    output: I'm google
messages:
  - input: what's your name?
"#;

    /// Parses `yaml` and fails the test unless it is a completion prompt.
    fn expect_completion(yaml: &str) -> Completion {
        match deserialize_prompt(yaml) {
            Prompt::Completion(completion) => completion,
            other => panic!("Expected Prompt::Completion, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_prompt_completion() {
        let completion = expect_completion(COMPLETION_YAML);

        assert_eq!(completion.prompt_type, "completion");
        assert_eq!(completion.vendor, "google");
        assert_eq!(completion.model, "text-bison");
        assert_eq!(completion.prompt, "Write a hello world in java");

        // `.expect` (not `if let`) so a missing section fails instead of passing vacuously.
        let parameters = completion
            .parameters
            .as_ref()
            .expect("parameters should be present");
        assert_eq!(parameters.len(), 2);
        assert_eq!(parameters[0].name, "maxOutputTokens");
        assert_eq!(parameters[0].value, 256);
        assert_eq!(parameters[1].name, "temperature");
        assert_eq!(parameters[1].value, 0.4);

        let examples = completion
            .examples
            .as_ref()
            .expect("examples should be present");
        assert_eq!(examples.len(), 2);
        assert_eq!(examples[0].name, "input");
        assert_eq!(examples[1].name, "output");
        assert_eq!(examples[0].values, vec!["a", "b"]);
        assert_eq!(examples[0].test, Some("c".to_string()));
        assert_eq!(examples[1].values, vec!["x", "y"]);
        assert_eq!(examples[1].test, None);

        assert_eq!(completion.find_parameter_as_i32("maxOutputTokens"), Some(256));
        assert_eq!(completion.find_parameter_as_f32("temperature"), Some(0.4));
        assert_eq!(completion.find_parameter_as_str("missing"), None);
    }

    #[test]
    fn test_deserialize_prompt_chat() {
        let chat = match deserialize_prompt(CHAT_YAML) {
            Prompt::Chat(chat) => chat,
            other => panic!("Expected Prompt::Chat, got {:?}", other),
        };

        assert_eq!(chat.prompt_type, "chat");
        assert_eq!(chat.vendor, "google");
        assert_eq!(chat.model, "chat-bison");

        let parameters = chat
            .parameters
            .as_ref()
            .expect("parameters should be present");
        assert_eq!(parameters.len(), 2);
        assert_eq!(parameters[0].name, "maxOutputTokens");
        assert_eq!(parameters[0].value, 256);
        assert_eq!(parameters[1].name, "temperature");
        assert_eq!(parameters[1].value, 0.4);

        let examples = chat.examples.as_ref().expect("examples should be present");
        assert_eq!(examples.len(), 1);
        assert_eq!(examples[0].input, "who are u?");
        assert_eq!(examples[0].output, Some("I'm google".to_string()));

        assert_eq!(chat.context, None);

        let messages = chat.messages.as_ref().expect("messages should be present");
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].input, "what's your name?");
        assert_eq!(messages[0].output, None);

        assert_eq!(chat.find_parameter_as_i32("maxOutputTokens"), Some(256));
        assert_eq!(chat.find_parameter_as_bool("missing"), None);
    }

    #[test]
    fn test_deserialize_prompt_unknown() {
        let prompt = deserialize_prompt("type: unknown\n");
        assert!(
            matches!(prompt, Prompt::Unknown),
            "Expected Prompt::Unknown, got {:?}",
            prompt
        );
    }

    #[test]
    fn test_get_examples_count() {
        let completion = expect_completion(COMPLETION_YAML);

        // Spelled out with escapes so the trailing space after the empty `output:` value
        // (produced by the `name: value` line format) is explicit.
        let expected = "Write a hello world in java\n\n\
                        input: a\noutput: x\n\n\
                        input: b\noutput: y\n\n\
                        input: c\noutput: \n";

        assert_eq!(completion.example_count(), 2);
        assert_eq!(completion.final_prompt(), expected);
    }

    #[test]
    fn test_final_prompt_without_examples() {
        let completion =
            expect_completion("type: completion\nvendor: google\nmodel: text-bison\nprompt: Hi\n");

        assert_eq!(completion.example_count(), 0);
        assert_eq!(completion.final_prompt(), "Hi\n\n");
    }
}