prompt_def/
prompt.rs

1use serde::Deserialize;
2use serde_yaml::Value;
3
/// Minimal view of a prompt document that captures only its `type` key.
///
/// `deserialize_prompt` parses a document into this struct first and uses the
/// value to decide which concrete prompt struct to deserialize next.
#[derive(Debug, Deserialize)]
struct PromptType {
    /// Value of the YAML `type` key (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    prompt_type: String,
}
9
/// One named column of examples attached to a completion prompt.
///
/// `Completion::final_prompt` renders row `i` of every column as a
/// `name: value` line, then appends one last row built from each column's
/// `test` value.
#[derive(Debug, Deserialize)]
pub struct CompletionExampleColumn {
    /// Label printed before each value of this column.
    pub name: String,
    /// Example values, one per example row. Columns may differ in length;
    /// missing cells are rendered as empty strings.
    pub values: Vec<String>,
    /// Value used for the trailing row; rendered as empty when absent.
    pub test: Option<String>,
}
16
/// One entry of a chat prompt's `examples` list.
#[derive(Debug, Deserialize)]
pub struct ChatExample {
    /// Example input text (required in the YAML).
    pub input: String,
    /// Example output text; `None` when the YAML entry has no `output` key.
    pub output: Option<String>,
}
22
/// One entry of a chat prompt's `messages` list.
#[derive(Debug, Deserialize)]
pub struct Message {
    /// Message input text (required in the YAML).
    pub input: String,
    /// Message output text; `None` when the YAML entry has no `output` key.
    pub output: Option<String>,
}
28
/// A named prompt parameter whose value is kept as an untyped YAML node.
///
/// `find_parameter` looks entries up by `name`; the `find_parameter_as_*`
/// methods on `Completion` and `Chat` convert `value` to a concrete type.
#[derive(Debug, Deserialize)]
pub struct Parameter {
    /// Key used to look the parameter up.
    pub name: String,
    /// Raw YAML value (number, string, bool, ...) as parsed by `serde_yaml`.
    pub value: Value,
}
34
/// A completion-style prompt definition as read from YAML.
///
/// `deserialize_prompt` produces this when the document's `type` is
/// `"completion"`.
#[derive(Debug, Deserialize)]
pub struct Completion {
    /// Value of the YAML `type` key (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    pub prompt_type: String,
    /// Vendor identifier taken verbatim from the YAML.
    pub vendor: String,
    /// Model identifier taken verbatim from the YAML.
    pub model: String,
    /// Base prompt text; `final_prompt` starts its output with this string.
    pub prompt: String,
    /// Optional list of named parameters, queried via `find_parameter_as_*`.
    pub parameters: Option<Vec<Parameter>>,
    /// Optional example columns appended to the prompt by `final_prompt`.
    pub examples: Option<Vec<CompletionExampleColumn>>,
}
45
46pub fn find_parameter(
47    parameters: &Option<Vec<crate::prompt::Parameter>>,
48    name: &str,
49) -> Option<Value> {
50    if let Some(real_parameters) = parameters {
51        real_parameters
52            .iter()
53            .find(|p| p.name == name)
54            .map(|x| x.value.clone())
55    } else {
56        None
57    }
58}
59
60impl Completion {
61    pub fn example_count(&self) -> usize {
62        let mut max_length = 0;
63        if let Some(columns) = &self.examples {
64            for column in columns {
65                if column.values.len() > max_length {
66                    max_length = column.values.len()
67                }
68            }
69        }
70        max_length
71    }
72
73    pub fn final_prompt(&self) -> String {
74        let mut prompt = self.prompt.clone();
75        prompt.push_str("\n\n");
76        if let Some(columns) = &self.examples {
77            for i in 0..self.example_count() {
78                for column in columns {
79                    let line: String = format!(
80                        "{}: {}\n",
81                        column.name,
82                        column.values.get(i).unwrap_or(&"".to_string())
83                    );
84                    prompt.push_str(&line);
85                }
86                prompt.push('\n');
87            }
88            for column in columns {
89                let line: String = format!(
90                    "{}: {}\n",
91                    column.name,
92                    column.test.as_ref().unwrap_or(&"".to_string())
93                );
94                prompt.push_str(&line);
95            }
96        }
97        prompt.to_string()
98    }
99
100    pub fn find_parameter_as_i32(&self, name: &str) -> Option<i32> {
101        find_parameter(&self.parameters, name).map(|p| p.as_i64().unwrap() as i32)
102    }
103
104    pub fn find_parameter_as_f32(&self, name: &str) -> Option<f32> {
105        find_parameter(&self.parameters, name).map(|p| p.as_f64().unwrap() as f32)
106    }
107
108    pub fn find_parameter_as_str(&self, name: &str) -> Option<String> {
109        find_parameter(&self.parameters, name).map(|p| p.as_str().unwrap().to_string())
110    }
111
112    pub fn find_parameter_as_bool(&self, name: &str) -> Option<bool> {
113        find_parameter(&self.parameters, name).map(|p| p.as_bool().unwrap())
114    }
115}
116
/// A chat-style prompt definition as read from YAML.
///
/// `deserialize_prompt` produces this when the document's `type` is `"chat"`.
#[derive(Debug, Deserialize)]
pub struct Chat {
    /// Value of the YAML `type` key (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    pub prompt_type: String,
    /// Vendor identifier taken verbatim from the YAML.
    pub vendor: String,
    /// Model identifier taken verbatim from the YAML.
    pub model: String,
    /// Optional list of named parameters, queried via `find_parameter_as_*`.
    pub parameters: Option<Vec<Parameter>>,
    /// Optional example exchanges.
    pub examples: Option<Vec<ChatExample>>,
    /// Optional free-form context string; `None` when the key is absent.
    pub context: Option<String>,
    /// Optional list of messages.
    pub messages: Option<Vec<Message>>,
}
128
129impl Chat {
130    pub fn find_parameter_as_i32(&self, name: &str) -> Option<i32> {
131        find_parameter(&self.parameters, name).map(|p| p.as_i64().unwrap() as i32)
132    }
133
134    pub fn find_parameter_as_f32(&self, name: &str) -> Option<f32> {
135        find_parameter(&self.parameters, name).map(|p| p.as_f64().unwrap() as f32)
136    }
137
138    pub fn find_parameter_as_str(&self, name: &str) -> Option<String> {
139        find_parameter(&self.parameters, name).map(|p| p.as_str().unwrap().to_string())
140    }
141
142    pub fn find_parameter_as_bool(&self, name: &str) -> Option<bool> {
143        find_parameter(&self.parameters, name).map(|p| p.as_bool().unwrap())
144    }
145}
146
/// A parsed prompt definition, discriminated by the document's `type` key.
///
/// `deserialize_prompt` constructs these variants by inspecting `type`
/// itself; it does not go through this enum's derived `Deserialize` impl.
#[derive(Debug, Deserialize)]
pub enum Prompt {
    /// Document whose `type` is `"completion"`.
    Completion(Completion),
    /// Document whose `type` is `"chat"`.
    Chat(Chat),
    /// Document whose `type` is any other string.
    Unknown,
}
153
154pub fn deserialize_prompt(yaml: &str) -> Prompt {
155    let prompt_type: PromptType = serde_yaml::from_str(&yaml).unwrap();
156    match prompt_type.prompt_type.as_str() {
157        "completion" => {
158            let completion: Completion = serde_yaml::from_str(&yaml).unwrap();
159            Prompt::Completion(completion)
160        }
161        "chat" => {
162            let chat: Chat = serde_yaml::from_str(&yaml).unwrap();
163            Prompt::Chat(chat)
164        }
165        _ => Prompt::Unknown,
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    // Completion fixture shared by the deserialization and rendering tests.
    const COMPLETION_YAML: &str = r#"  
            type: completion  
            vendor: google
            model: text-bison
            prompt: Write a hello world in java
            parameters:  
                - name: maxOutputTokens
                  value: 256
                - name: temperature
                  value: 0.4
            examples:  
                - name: input
                  values:
                    - a
                    - b
                  test: c
                - name: output
                  values:
                    - x
                    - y
        "#;

    // Unwraps the `Completion` variant or fails the test naming what was found.
    fn expect_completion(prompt: Prompt) -> Completion {
        match prompt {
            Prompt::Completion(completion) => completion,
            other => panic!("Expected Prompt::Completion, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_prompt_completion() {
        let completion = expect_completion(deserialize_prompt(COMPLETION_YAML));

        assert_eq!(completion.vendor, "google");
        assert_eq!(completion.model, "text-bison");
        assert_eq!(completion.prompt, "Write a hello world in java");

        // `expect` rather than `if let`: a missing section must fail the test
        // instead of silently skipping the assertions below.
        let parameters = completion
            .parameters
            .expect("fixture declares parameters");
        assert_eq!(parameters.len(), 2);
        assert_eq!(parameters[0].name, "maxOutputTokens");
        assert_eq!(parameters[0].value, 256);
        assert_eq!(parameters[1].name, "temperature");
        assert_eq!(parameters[1].value, 0.4);

        let examples = completion.examples.expect("fixture declares examples");
        assert_eq!(examples.len(), 2);
        assert_eq!(examples[0].name, "input");
        assert_eq!(examples[1].name, "output");
        assert_eq!(examples[0].values, vec!["a", "b"]);
        assert_eq!(examples[0].test, Some("c".to_string()));
        assert_eq!(examples[1].values, vec!["x", "y"]);
        assert_eq!(examples[1].test, None);
    }

    #[test]
    fn test_deserialize_prompt_chat() {
        let yaml = r#"  
            type: chat  
            vendor: google
            model: chat-bison 
            parameters:  
                - name: maxOutputTokens
                  value: 256
                - name: temperature
                  value: 0.4
            examples:  
                - input: who are u?
                  output: I'm google
            messages:  
                - input: what's your name?
        "#;

        let chat = match deserialize_prompt(yaml) {
            Prompt::Chat(chat) => chat,
            other => panic!("Expected Prompt::Chat, got {:?}", other),
        };

        assert_eq!(chat.vendor, "google");
        assert_eq!(chat.model, "chat-bison");

        let parameters = chat.parameters.expect("fixture declares parameters");
        assert_eq!(parameters.len(), 2);
        assert_eq!(parameters[0].name, "maxOutputTokens");
        assert_eq!(parameters[0].value, 256);
        assert_eq!(parameters[1].name, "temperature");
        assert_eq!(parameters[1].value, 0.4);

        let examples = chat.examples.expect("fixture declares examples");
        assert_eq!(examples.len(), 1);
        assert_eq!(examples[0].input, "who are u?");
        assert_eq!(examples[0].output, Some("I'm google".to_string()));

        assert_eq!(chat.context, None);

        let messages = chat.messages.expect("fixture declares messages");
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].input, "what's your name?");
        assert_eq!(messages[0].output, None);
    }

    #[test]
    fn test_deserialize_prompt_unknown() {
        let yaml = r#"  
            type: unknown  
        "#;

        match deserialize_prompt(yaml) {
            Prompt::Unknown => {}
            other => panic!("Expected Prompt::Unknown, got {:?}", other),
        }
    }

    #[test]
    fn test_get_examples_count() {
        let completion = expect_completion(deserialize_prompt(COMPLETION_YAML));

        // Spelled with explicit escapes so the trailing space after the last
        // `output:` cannot be lost to editor whitespace trimming.
        let final_prompt = concat!(
            "Write a hello world in java\n",
            "\n",
            "input: a\n",
            "output: x\n",
            "\n",
            "input: b\n",
            "output: y\n",
            "\n",
            "input: c\n",
            "output: \n",
        );

        assert_eq!(completion.example_count(), 2);
        assert_eq!(completion.final_prompt(), final_prompt);
    }
}
335}