1use anyhow::Result;
2use serde_json::{Value, json};
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use super::Adapter;
7use crate::serde_utils::get_iter_from_value;
8use crate::{Cache, CallResult, Chat, Example, LM, Message, MetaSignature, Prediction};
9
/// Adapter that renders a signature as a chat prompt using `[[ ## field ## ]]` section
/// markers, and parses the model's reply back into output fields using the same markers.
///
/// The adapter is a unit struct and carries no state of its own; everything it needs comes
/// from the signature, the inputs, and the LM passed to each call.
#[derive(Debug, Default, Clone)]
pub struct ChatAdapter;
12
13fn get_type_hint(field: &Value) -> String {
14 let schema = &field["schema"];
15 let type_str = field["type"].as_str().unwrap_or("String");
16
17 let has_schema = if let Some(s) = schema.as_str() {
19 !s.is_empty()
20 } else {
21 schema.is_object()
22 };
23
24 if !has_schema && type_str == "String" {
25 String::new()
26 } else {
27 format!(" (must be formatted as valid Rust {type_str})")
28 }
29}
30
31impl ChatAdapter {
32 fn get_field_attribute_list(
33 &self,
34 field_iter: impl Iterator<Item = (String, Value)>,
35 ) -> String {
36 let mut field_attributes = String::new();
37 for (i, (field_name, field)) in field_iter.enumerate() {
38 let data_type = field["type"].as_str().unwrap_or("String");
39 let desc = field["desc"].as_str().unwrap_or("");
40
41 field_attributes.push_str(format!("{}. `{field_name}` ({data_type})", i + 1).as_str());
42 if !desc.is_empty() {
43 field_attributes.push_str(format!(": {desc}").as_str());
44 }
45 field_attributes.push('\n');
46 }
47 field_attributes
48 }
49
50 fn get_field_structure(&self, field_iter: impl Iterator<Item = (String, Value)>) -> String {
51 let mut field_structure = String::new();
52 for (field_name, field) in field_iter {
53 let schema = &field["schema"];
54 let data_type = field["type"].as_str().unwrap_or("String");
55
56 let schema_prompt = if let Some(s) = schema.as_str() {
58 if s.is_empty() && data_type == "String" {
59 "".to_string()
60 } else if !s.is_empty() {
61 format!("\t# note: the value you produce must adhere to the JSON schema: {s}")
62 } else {
63 format!("\t# note: the value you produce must be a single {data_type} value")
64 }
65 } else if schema.is_object() || schema.is_array() {
66 let schema_str = schema.to_string();
68 format!(
69 "\t# note: the value you produce must adhere to the JSON schema: {schema_str}"
70 )
71 } else if data_type == "String" {
72 "".to_string()
73 } else {
74 format!("\t# note: the value you produce must be a single {data_type} value")
75 };
76
77 field_structure.push_str(
78 format!("[[ ## {field_name} ## ]]\n{field_name}{schema_prompt}\n\n").as_str(),
79 );
80 }
81 field_structure
82 }
83
84 fn format_system_message(&self, signature: &dyn MetaSignature) -> String {
85 let field_description = self.format_field_description(signature);
86 let field_structure = self.format_field_structure(signature);
87 let task_description = self.format_task_description(signature);
88
89 format!("{field_description}\n{field_structure}\n{task_description}")
90 }
91
92 fn format_field_description(&self, signature: &dyn MetaSignature) -> String {
93 let input_field_description =
94 self.get_field_attribute_list(get_iter_from_value(&signature.input_fields()));
95 let output_field_description =
96 self.get_field_attribute_list(get_iter_from_value(&signature.output_fields()));
97
98 format!(
99 "Your input fields are:\n{input_field_description}\nYour output fields are:\n{output_field_description}"
100 )
101 }
102
103 fn format_field_structure(&self, signature: &dyn MetaSignature) -> String {
104 let input_field_structure =
105 self.get_field_structure(get_iter_from_value(&signature.input_fields()));
106 let output_field_structure =
107 self.get_field_structure(get_iter_from_value(&signature.output_fields()));
108
109 format!(
110 "All interactions will be structured in the following way, with the appropriate values filled in.\n\n{input_field_structure}{output_field_structure}[[ ## completed ## ]]\n"
111 )
112 }
113
114 fn format_task_description(&self, signature: &dyn MetaSignature) -> String {
115 let instruction = if signature.instruction().is_empty() {
116 format!(
117 "Given the fields {}, produce the fields {}.",
118 signature
119 .input_fields()
120 .as_object()
121 .unwrap()
122 .keys()
123 .map(|k| format!("`{k}`"))
124 .collect::<Vec<String>>()
125 .join(", "),
126 signature
127 .output_fields()
128 .as_object()
129 .unwrap()
130 .keys()
131 .map(|k| format!("`{k}`"))
132 .collect::<Vec<String>>()
133 .join(", ")
134 )
135 } else {
136 signature.instruction().clone()
137 };
138
139 format!("In adhering to this structure, your objective is:\n\t{instruction}")
140 }
141
142 fn format_user_message(&self, signature: &dyn MetaSignature, inputs: &Example) -> String {
143 let mut input_str = String::new();
144 for (field_name, _) in get_iter_from_value(&signature.input_fields()) {
145 let field_value = inputs.get(field_name.as_str(), None);
146 let field_value_str = if let Some(s) = field_value.as_str() {
148 s.to_string()
149 } else {
150 field_value.to_string()
151 };
152
153 input_str
154 .push_str(format!("[[ ## {field_name} ## ]]\n{field_value_str}\n\n",).as_str());
155 }
156
157 let first_output_field = signature
158 .output_fields()
159 .as_object()
160 .unwrap()
161 .keys()
162 .next()
163 .unwrap()
164 .clone();
165 let first_output_field_value = signature
166 .output_fields()
167 .as_object()
168 .unwrap()
169 .get(&first_output_field)
170 .unwrap()
171 .clone();
172
173 let type_hint = get_type_hint(&first_output_field_value);
174
175 let mut user_message = format!(
176 "Respond with the corresponding output fields, starting with the field `{first_output_field}`{type_hint},"
177 );
178 for (field_name, field) in get_iter_from_value(&signature.output_fields()).skip(1) {
179 user_message
180 .push_str(format!(" then `{field_name}`{},", get_type_hint(&field)).as_str());
181 }
182 user_message.push_str(" and then ending with the marker for `completed`.");
183
184 format!("{input_str}{user_message}")
185 }
186
187 fn format_assistant_message(&self, signature: &dyn MetaSignature, outputs: &Example) -> String {
188 let mut assistant_message = String::new();
189 for (field_name, _) in get_iter_from_value(&signature.output_fields()) {
190 let field_value = outputs.get(field_name.as_str(), None);
191 let field_value_str = if let Some(s) = field_value.as_str() {
193 s.to_string()
194 } else {
195 field_value.to_string()
196 };
197
198 assistant_message
199 .push_str(format!("[[ ## {field_name} ## ]]\n{field_value_str}\n\n",).as_str());
200 }
201 assistant_message.push_str("[[ ## completed ## ]]\n");
202 assistant_message
203 }
204
205 fn format_demos(&self, signature: &dyn MetaSignature, demos: &Vec<Example>) -> Chat {
206 let mut chat = Chat::new(vec![]);
207
208 for demo in demos {
209 let user_message = self.format_user_message(signature, demo);
210 let assistant_message = self.format_assistant_message(signature, demo);
211 chat.push("user", &user_message);
212 chat.push("assistant", &assistant_message);
213 }
214
215 chat
216 }
217}
218
219#[async_trait::async_trait]
220impl Adapter for ChatAdapter {
221 fn format(&self, signature: &dyn MetaSignature, inputs: Example) -> Chat {
222 let system_message = self.format_system_message(signature);
223 let user_message = self.format_user_message(signature, &inputs);
224
225 let demos = signature.demos();
226 let demos = self.format_demos(signature, &demos);
227
228 let mut chat = Chat::new(vec![]);
229 chat.push("system", &system_message);
230 chat.push_all(&demos);
231 chat.push("user", &user_message);
232
233 chat
234 }
235
236 fn parse_response(
237 &self,
238 signature: &dyn MetaSignature,
239 response: Message,
240 ) -> HashMap<String, Value> {
241 let mut output = HashMap::new();
242
243 let response_content = response.content();
244
245 for (field_name, field) in get_iter_from_value(&signature.output_fields()) {
246 let field_value = response_content
247 .split(format!("[[ ## {field_name} ## ]]\n").as_str())
248 .nth(1);
249
250 if field_value.is_none() {
251 continue; }
253 let field_value = field_value.unwrap();
254
255 let extracted_field = field_value.split("[[ ## ").nth(0).unwrap().trim();
256 let data_type = field["type"].as_str().unwrap();
257 let schema = &field["schema"];
258
259 let has_schema = if let Some(s) = schema.as_str() {
261 !s.is_empty()
262 } else {
263 schema.is_object() || schema.is_array()
264 };
265
266 if !has_schema && data_type == "String" {
267 output.insert(field_name.clone(), json!(extracted_field));
268 } else {
269 output.insert(
270 field_name.clone(),
271 serde_json::from_str(extracted_field).unwrap(),
272 );
273 }
274 }
275
276 output
277 }
278
279 async fn call(
280 &self,
281 lm: Arc<LM>,
282 signature: &dyn MetaSignature,
283 inputs: Example,
284 ) -> Result<Prediction> {
285 if lm.config.cache
287 && let Some(cache) = lm.cache_handler.as_ref()
288 {
289 let cache_key = inputs.clone();
290 if let Some(cached) = cache.lock().await.get(cache_key).await? {
291 return Ok(cached);
292 }
293 }
294
295 let messages = self.format(signature, inputs.clone());
296 let response = lm.call(messages).await?;
297 let prompt_str = response.chat.to_json().to_string();
298
299 let output = self.parse_response(signature, response.output);
300
301 let prediction = Prediction {
302 data: output,
303 lm_usage: response.usage,
304 };
305
306 if lm.config.cache
308 && let Some(cache) = lm.cache_handler.as_ref()
309 {
310 let (tx, rx) = tokio::sync::mpsc::channel(1);
311 let cache_clone = cache.clone();
312 let inputs_clone = inputs.clone();
313
314 tokio::spawn(async move {
316 let _ = cache_clone.lock().await.insert(inputs_clone, rx).await;
317 });
318
319 tx.send(CallResult {
321 prompt: prompt_str,
322 prediction: prediction.clone(),
323 })
324 .await
325 .map_err(|_| anyhow::anyhow!("Failed to send to cache"))?;
326 }
327
328 Ok(prediction)
329 }
330}