1use anyhow::Result;
2use rig::tool::ToolDyn;
3use serde_json::{Value, json};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use super::Adapter;
8use crate::serde_utils::get_iter_from_value;
9use crate::{Cache, CallResult, Chat, Example, LM, Message, MetaSignature, Prediction};
10
/// Stateless adapter that turns a signature plus inputs into a chat transcript whose
/// fields are delimited by `[[ ## field ## ]]` markers, and parses replies in that format.
#[derive(Clone, Default)]
pub struct ChatAdapter;
13
14fn get_type_hint(field: &Value) -> String {
15 let schema = &field["schema"];
16 let type_str = field["type"].as_str().unwrap_or("String");
17
18 let has_schema = if let Some(s) = schema.as_str() {
20 !s.is_empty()
21 } else {
22 schema.is_object()
23 };
24
25 if !has_schema && type_str == "String" {
26 String::new()
27 } else {
28 format!(" (must be formatted as valid Rust {type_str})")
29 }
30}
31
32impl ChatAdapter {
33 fn get_field_attribute_list(
34 &self,
35 field_iter: impl Iterator<Item = (String, Value)>,
36 ) -> String {
37 let mut field_attributes = String::new();
38 for (i, (field_name, field)) in field_iter.enumerate() {
39 let data_type = field["type"].as_str().unwrap_or("String");
40 let desc = field["desc"].as_str().unwrap_or("");
41
42 field_attributes.push_str(format!("{}. `{field_name}` ({data_type})", i + 1).as_str());
43 if !desc.is_empty() {
44 field_attributes.push_str(format!(": {desc}").as_str());
45 }
46 field_attributes.push('\n');
47 }
48 field_attributes
49 }
50
51 fn get_field_structure(&self, field_iter: impl Iterator<Item = (String, Value)>) -> String {
52 let mut field_structure = String::new();
53 for (field_name, field) in field_iter {
54 let schema = &field["schema"];
55 let data_type = field["type"].as_str().unwrap_or("String");
56
57 let schema_prompt = if let Some(s) = schema.as_str() {
59 if s.is_empty() && data_type == "String" {
60 "".to_string()
61 } else if !s.is_empty() {
62 format!("\t# note: the value you produce must adhere to the JSON schema: {s}")
63 } else {
64 format!("\t# note: the value you produce must be a single {data_type} value")
65 }
66 } else if schema.is_object() || schema.is_array() {
67 let schema_str = schema.to_string();
69 format!(
70 "\t# note: the value you produce must adhere to the JSON schema: {schema_str}"
71 )
72 } else if data_type == "String" {
73 "".to_string()
74 } else {
75 format!("\t# note: the value you produce must be a single {data_type} value")
76 };
77
78 field_structure.push_str(
79 format!("[[ ## {field_name} ## ]]\n{field_name}{schema_prompt}\n\n").as_str(),
80 );
81 }
82 field_structure
83 }
84
85 fn format_system_message(&self, signature: &dyn MetaSignature) -> String {
86 let field_description = self.format_field_description(signature);
87 let field_structure = self.format_field_structure(signature);
88 let task_description = self.format_task_description(signature);
89
90 format!("{field_description}\n{field_structure}\n{task_description}")
91 }
92
93 fn format_field_description(&self, signature: &dyn MetaSignature) -> String {
94 let input_field_description =
95 self.get_field_attribute_list(get_iter_from_value(&signature.input_fields()));
96 let output_field_description =
97 self.get_field_attribute_list(get_iter_from_value(&signature.output_fields()));
98
99 format!(
100 "Your input fields are:\n{input_field_description}\nYour output fields are:\n{output_field_description}"
101 )
102 }
103
104 fn format_field_structure(&self, signature: &dyn MetaSignature) -> String {
105 let input_field_structure =
106 self.get_field_structure(get_iter_from_value(&signature.input_fields()));
107 let output_field_structure =
108 self.get_field_structure(get_iter_from_value(&signature.output_fields()));
109
110 format!(
111 "All interactions will be structured in the following way, with the appropriate values filled in.\n\n{input_field_structure}{output_field_structure}[[ ## completed ## ]]\n"
112 )
113 }
114
115 fn format_task_description(&self, signature: &dyn MetaSignature) -> String {
116 let instruction = if signature.instruction().is_empty() {
117 format!(
118 "Given the fields {}, produce the fields {}.",
119 signature
120 .input_fields()
121 .as_object()
122 .unwrap()
123 .keys()
124 .map(|k| format!("`{k}`"))
125 .collect::<Vec<String>>()
126 .join(", "),
127 signature
128 .output_fields()
129 .as_object()
130 .unwrap()
131 .keys()
132 .map(|k| format!("`{k}`"))
133 .collect::<Vec<String>>()
134 .join(", ")
135 )
136 } else {
137 signature.instruction().clone()
138 };
139
140 format!("In adhering to this structure, your objective is:\n\t{instruction}")
141 }
142
143 fn format_user_message(&self, signature: &dyn MetaSignature, inputs: &Example) -> String {
144 let mut input_str = String::new();
145 for (field_name, _) in get_iter_from_value(&signature.input_fields()) {
146 let field_value = inputs.get(field_name.as_str(), None);
147 let field_value_str = if let Some(s) = field_value.as_str() {
149 s.to_string()
150 } else {
151 field_value.to_string()
152 };
153
154 input_str
155 .push_str(format!("[[ ## {field_name} ## ]]\n{field_value_str}\n\n",).as_str());
156 }
157
158 let first_output_field = signature
159 .output_fields()
160 .as_object()
161 .unwrap()
162 .keys()
163 .next()
164 .unwrap()
165 .clone();
166 let first_output_field_value = signature
167 .output_fields()
168 .as_object()
169 .unwrap()
170 .get(&first_output_field)
171 .unwrap()
172 .clone();
173
174 let type_hint = get_type_hint(&first_output_field_value);
175
176 let mut user_message = format!(
177 "Respond with the corresponding output fields, starting with the field `{first_output_field}`{type_hint},"
178 );
179 for (field_name, field) in get_iter_from_value(&signature.output_fields()).skip(1) {
180 user_message
181 .push_str(format!(" then `{field_name}`{},", get_type_hint(&field)).as_str());
182 }
183 user_message.push_str(" and then ending with the marker for `completed`.");
184
185 format!("{input_str}{user_message}")
186 }
187
188 fn format_assistant_message(&self, signature: &dyn MetaSignature, outputs: &Example) -> String {
189 let mut assistant_message = String::new();
190 for (field_name, _) in get_iter_from_value(&signature.output_fields()) {
191 let field_value = outputs.get(field_name.as_str(), None);
192 let field_value_str = if let Some(s) = field_value.as_str() {
194 s.to_string()
195 } else {
196 field_value.to_string()
197 };
198
199 assistant_message
200 .push_str(format!("[[ ## {field_name} ## ]]\n{field_value_str}\n\n",).as_str());
201 }
202 assistant_message.push_str("[[ ## completed ## ]]\n");
203 assistant_message
204 }
205
206 fn format_demos(&self, signature: &dyn MetaSignature, demos: &Vec<Example>) -> Chat {
207 let mut chat = Chat::new(vec![]);
208
209 for demo in demos {
210 let user_message = self.format_user_message(signature, demo);
211 let assistant_message = self.format_assistant_message(signature, demo);
212 chat.push("user", &user_message);
213 chat.push("assistant", &assistant_message);
214 }
215
216 chat
217 }
218}
219
220#[async_trait::async_trait]
221impl Adapter for ChatAdapter {
222 fn format(&self, signature: &dyn MetaSignature, inputs: Example) -> Chat {
223 let system_message = self.format_system_message(signature);
224 let user_message = self.format_user_message(signature, &inputs);
225
226 let demos = signature.demos();
227 let demos = self.format_demos(signature, &demos);
228
229 let mut chat = Chat::new(vec![]);
230 chat.push("system", &system_message);
231 chat.push_all(&demos);
232 chat.push("user", &user_message);
233
234 chat
235 }
236
237 fn parse_response(
238 &self,
239 signature: &dyn MetaSignature,
240 response: Message,
241 ) -> HashMap<String, Value> {
242 let mut output = HashMap::new();
243
244 let response_content = response.content();
245
246 for (field_name, field) in get_iter_from_value(&signature.output_fields()) {
247 let field_value = response_content
248 .split(format!("[[ ## {field_name} ## ]]\n").as_str())
249 .nth(1);
250
251 if field_value.is_none() {
252 continue; }
254 let field_value = field_value.unwrap();
255
256 let extracted_field = field_value.split("[[ ## ").nth(0).unwrap().trim();
257 let data_type = field["type"].as_str().unwrap();
258 let schema = &field["schema"];
259
260 let has_schema = if let Some(s) = schema.as_str() {
262 !s.is_empty()
263 } else {
264 schema.is_object() || schema.is_array()
265 };
266
267 if !has_schema && data_type == "String" {
268 output.insert(field_name.clone(), json!(extracted_field));
269 } else {
270 output.insert(
271 field_name.clone(),
272 serde_json::from_str(extracted_field).unwrap(),
273 );
274 }
275 }
276
277 output
278 }
279
280 async fn call(
281 &self,
282 lm: Arc<LM>,
283 signature: &dyn MetaSignature,
284 inputs: Example,
285 tools: Vec<Arc<dyn ToolDyn>>,
286 ) -> Result<Prediction> {
287 if lm.cache
289 && let Some(cache) = lm.cache_handler.as_ref()
290 {
291 let cache_key = inputs.clone();
292 if let Some(cached) = cache.lock().await.get(cache_key).await? {
293 return Ok(cached);
294 }
295 }
296
297 let messages = self.format(signature, inputs.clone());
298 let response = lm.call(messages, tools).await?;
299 let prompt_str = response.chat.to_json().to_string();
300
301 let mut output = self.parse_response(signature, response.output);
302 if !response.tool_calls.is_empty() {
303 output.insert(
304 "tool_calls".to_string(),
305 response
306 .tool_calls
307 .into_iter()
308 .map(|call| json!(call))
309 .collect::<Value>(),
310 );
311 output.insert(
312 "tool_executions".to_string(),
313 response
314 .tool_executions
315 .into_iter()
316 .map(|execution| json!(execution))
317 .collect::<Value>(),
318 );
319 }
320
321 let prediction = Prediction {
322 data: output,
323 lm_usage: response.usage,
324 };
325
326 if lm.cache
328 && let Some(cache) = lm.cache_handler.as_ref()
329 {
330 let (tx, rx) = tokio::sync::mpsc::channel(1);
331 let cache_clone = cache.clone();
332 let inputs_clone = inputs.clone();
333
334 tokio::spawn(async move {
336 let _ = cache_clone.lock().await.insert(inputs_clone, rx).await;
337 });
338
339 tx.send(CallResult {
341 prompt: prompt_str,
342 prediction: prediction.clone(),
343 })
344 .await
345 .map_err(|_| anyhow::anyhow!("Failed to send to cache"))?;
346 }
347
348 Ok(prediction)
349 }
350}