dspy_rs/adapter/
chat.rs

1use anyhow::Result;
2use serde_json::{Value, json};
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use super::Adapter;
7use crate::serde_utils::get_iter_from_value;
8use crate::{Cache, CallResult, Chat, Example, LM, Message, MetaSignature, Prediction};
9
/// Adapter that renders a signature and its inputs as a chat conversation and
/// parses the model's reply back into named output fields.
///
/// Fields are delimited with `[[ ## name ## ]]` markers and a reply is expected to
/// end with a `[[ ## completed ## ]]` marker. The type carries no state, so it is
/// cheap to construct, clone, and print.
#[derive(Debug, Default, Clone)]
pub struct ChatAdapter;
12
13fn get_type_hint(field: &Value) -> String {
14    let schema = &field["schema"];
15    let type_str = field["type"].as_str().unwrap_or("String");
16
17    // Check if schema exists and is not empty (either as string or object)
18    let has_schema = if let Some(s) = schema.as_str() {
19        !s.is_empty()
20    } else {
21        schema.is_object()
22    };
23
24    if !has_schema && type_str == "String" {
25        String::new()
26    } else {
27        format!(" (must be formatted as valid Rust {type_str})")
28    }
29}
30
31impl ChatAdapter {
32    fn get_field_attribute_list(
33        &self,
34        field_iter: impl Iterator<Item = (String, Value)>,
35    ) -> String {
36        let mut field_attributes = String::new();
37        for (i, (field_name, field)) in field_iter.enumerate() {
38            let data_type = field["type"].as_str().unwrap_or("String");
39            let desc = field["desc"].as_str().unwrap_or("");
40
41            field_attributes.push_str(format!("{}. `{field_name}` ({data_type})", i + 1).as_str());
42            if !desc.is_empty() {
43                field_attributes.push_str(format!(": {desc}").as_str());
44            }
45            field_attributes.push('\n');
46        }
47        field_attributes
48    }
49
50    fn get_field_structure(&self, field_iter: impl Iterator<Item = (String, Value)>) -> String {
51        let mut field_structure = String::new();
52        for (field_name, field) in field_iter {
53            let schema = &field["schema"];
54            let data_type = field["type"].as_str().unwrap_or("String");
55
56            // Handle schema as either string or JSON object
57            let schema_prompt = if let Some(s) = schema.as_str() {
58                if s.is_empty() && data_type == "String" {
59                    "".to_string()
60                } else if !s.is_empty() {
61                    format!("\t# note: the value you produce must adhere to the JSON schema: {s}")
62                } else {
63                    format!("\t# note: the value you produce must be a single {data_type} value")
64                }
65            } else if schema.is_object() || schema.is_array() {
66                // Convert JSON object/array to string for display
67                let schema_str = schema.to_string();
68                format!(
69                    "\t# note: the value you produce must adhere to the JSON schema: {schema_str}"
70                )
71            } else if data_type == "String" {
72                "".to_string()
73            } else {
74                format!("\t# note: the value you produce must be a single {data_type} value")
75            };
76
77            field_structure.push_str(
78                format!("[[ ## {field_name} ## ]]\n{field_name}{schema_prompt}\n\n").as_str(),
79            );
80        }
81        field_structure
82    }
83
84    fn format_system_message(&self, signature: &dyn MetaSignature) -> String {
85        let field_description = self.format_field_description(signature);
86        let field_structure = self.format_field_structure(signature);
87        let task_description = self.format_task_description(signature);
88
89        format!("{field_description}\n{field_structure}\n{task_description}")
90    }
91
92    fn format_field_description(&self, signature: &dyn MetaSignature) -> String {
93        let input_field_description =
94            self.get_field_attribute_list(get_iter_from_value(&signature.input_fields()));
95        let output_field_description =
96            self.get_field_attribute_list(get_iter_from_value(&signature.output_fields()));
97
98        format!(
99            "Your input fields are:\n{input_field_description}\nYour output fields are:\n{output_field_description}"
100        )
101    }
102
103    fn format_field_structure(&self, signature: &dyn MetaSignature) -> String {
104        let input_field_structure =
105            self.get_field_structure(get_iter_from_value(&signature.input_fields()));
106        let output_field_structure =
107            self.get_field_structure(get_iter_from_value(&signature.output_fields()));
108
109        format!(
110            "All interactions will be structured in the following way, with the appropriate values filled in.\n\n{input_field_structure}{output_field_structure}[[ ## completed ## ]]\n"
111        )
112    }
113
114    fn format_task_description(&self, signature: &dyn MetaSignature) -> String {
115        let instruction = if signature.instruction().is_empty() {
116            format!(
117                "Given the fields {}, produce the fields {}.",
118                signature
119                    .input_fields()
120                    .as_object()
121                    .unwrap()
122                    .keys()
123                    .map(|k| format!("`{k}`"))
124                    .collect::<Vec<String>>()
125                    .join(", "),
126                signature
127                    .output_fields()
128                    .as_object()
129                    .unwrap()
130                    .keys()
131                    .map(|k| format!("`{k}`"))
132                    .collect::<Vec<String>>()
133                    .join(", ")
134            )
135        } else {
136            signature.instruction().clone()
137        };
138
139        format!("In adhering to this structure, your objective is:\n\t{instruction}")
140    }
141
142    fn format_user_message(&self, signature: &dyn MetaSignature, inputs: &Example) -> String {
143        let mut input_str = String::new();
144        for (field_name, _) in get_iter_from_value(&signature.input_fields()) {
145            let field_value = inputs.get(field_name.as_str(), None);
146            // Extract the actual string value if it's a JSON string, otherwise use as is
147            let field_value_str = if let Some(s) = field_value.as_str() {
148                s.to_string()
149            } else {
150                field_value.to_string()
151            };
152
153            input_str
154                .push_str(format!("[[ ## {field_name} ## ]]\n{field_value_str}\n\n",).as_str());
155        }
156
157        let first_output_field = signature
158            .output_fields()
159            .as_object()
160            .unwrap()
161            .keys()
162            .next()
163            .unwrap()
164            .clone();
165        let first_output_field_value = signature
166            .output_fields()
167            .as_object()
168            .unwrap()
169            .get(&first_output_field)
170            .unwrap()
171            .clone();
172
173        let type_hint = get_type_hint(&first_output_field_value);
174
175        let mut user_message = format!(
176            "Respond with the corresponding output fields, starting with the field `{first_output_field}`{type_hint},"
177        );
178        for (field_name, field) in get_iter_from_value(&signature.output_fields()).skip(1) {
179            user_message
180                .push_str(format!(" then `{field_name}`{},", get_type_hint(&field)).as_str());
181        }
182        user_message.push_str(" and then ending with the marker for `completed`.");
183
184        format!("{input_str}{user_message}")
185    }
186
187    fn format_assistant_message(&self, signature: &dyn MetaSignature, outputs: &Example) -> String {
188        let mut assistant_message = String::new();
189        for (field_name, _) in get_iter_from_value(&signature.output_fields()) {
190            let field_value = outputs.get(field_name.as_str(), None);
191            // Extract the actual string value if it's a JSON string, otherwise use as is
192            let field_value_str = if let Some(s) = field_value.as_str() {
193                s.to_string()
194            } else {
195                field_value.to_string()
196            };
197
198            assistant_message
199                .push_str(format!("[[ ## {field_name} ## ]]\n{field_value_str}\n\n",).as_str());
200        }
201        assistant_message.push_str("[[ ## completed ## ]]\n");
202        assistant_message
203    }
204
205    fn format_demos(&self, signature: &dyn MetaSignature, demos: &Vec<Example>) -> Chat {
206        let mut chat = Chat::new(vec![]);
207
208        for demo in demos {
209            let user_message = self.format_user_message(signature, demo);
210            let assistant_message = self.format_assistant_message(signature, demo);
211            chat.push("user", &user_message);
212            chat.push("assistant", &assistant_message);
213        }
214
215        chat
216    }
217}
218
219#[async_trait::async_trait]
220impl Adapter for ChatAdapter {
221    fn format(&self, signature: &dyn MetaSignature, inputs: Example) -> Chat {
222        let system_message = self.format_system_message(signature);
223        let user_message = self.format_user_message(signature, &inputs);
224
225        let demos = signature.demos();
226        let demos = self.format_demos(signature, &demos);
227
228        let mut chat = Chat::new(vec![]);
229        chat.push("system", &system_message);
230        chat.push_all(&demos);
231        chat.push("user", &user_message);
232
233        chat
234    }
235
236    fn parse_response(
237        &self,
238        signature: &dyn MetaSignature,
239        response: Message,
240    ) -> HashMap<String, Value> {
241        let mut output = HashMap::new();
242
243        let response_content = response.content();
244
245        for (field_name, field) in get_iter_from_value(&signature.output_fields()) {
246            let field_value = response_content
247                .split(format!("[[ ## {field_name} ## ]]\n").as_str())
248                .nth(1);
249
250            if field_value.is_none() {
251                continue; // Skip field if not found in response
252            }
253            let field_value = field_value.unwrap();
254
255            let extracted_field = field_value.split("[[ ## ").nth(0).unwrap().trim();
256            let data_type = field["type"].as_str().unwrap();
257            let schema = &field["schema"];
258
259            // Check if schema exists (as string or object)
260            let has_schema = if let Some(s) = schema.as_str() {
261                !s.is_empty()
262            } else {
263                schema.is_object() || schema.is_array()
264            };
265
266            if !has_schema && data_type == "String" {
267                output.insert(field_name.clone(), json!(extracted_field));
268            } else {
269                output.insert(
270                    field_name.clone(),
271                    serde_json::from_str(extracted_field).unwrap(),
272                );
273            }
274        }
275
276        output
277    }
278
279    async fn call(
280        &self,
281        lm: Arc<LM>,
282        signature: &dyn MetaSignature,
283        inputs: Example,
284    ) -> Result<Prediction> {
285        // Check cache first (release lock immediately after checking)
286        if lm.config.cache
287            && let Some(cache) = lm.cache_handler.as_ref()
288        {
289            let cache_key = inputs.clone();
290            if let Some(cached) = cache.lock().await.get(cache_key).await? {
291                return Ok(cached);
292            }
293        }
294
295        let messages = self.format(signature, inputs.clone());
296        let response = lm.call(messages).await?;
297        let prompt_str = response.chat.to_json().to_string();
298
299        let output = self.parse_response(signature, response.output);
300
301        let prediction = Prediction {
302            data: output,
303            lm_usage: response.usage,
304        };
305
306        // Store in cache if enabled
307        if lm.config.cache
308            && let Some(cache) = lm.cache_handler.as_ref()
309        {
310            let (tx, rx) = tokio::sync::mpsc::channel(1);
311            let cache_clone = cache.clone();
312            let inputs_clone = inputs.clone();
313
314            // Spawn the cache insert operation to avoid deadlock
315            tokio::spawn(async move {
316                let _ = cache_clone.lock().await.insert(inputs_clone, rx).await;
317            });
318
319            // Send the result to the cache
320            tx.send(CallResult {
321                prompt: prompt_str,
322                prediction: prediction.clone(),
323            })
324            .await
325            .map_err(|_| anyhow::anyhow!("Failed to send to cache"))?;
326        }
327
328        Ok(prediction)
329    }
330}