// dspy_rs/adapter/chat.rs

1use anyhow::Result;
2use rig::tool::ToolDyn;
3use serde_json::{Value, json};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use super::Adapter;
8use crate::serde_utils::get_iter_from_value;
9use crate::{Cache, CallResult, Chat, Example, LM, Message, MetaSignature, Prediction};
10
/// Stateless adapter that renders a signature and its inputs as a chat
/// transcript delimited by `[[ ## field ## ]]` section markers, and parses model
/// responses written in that same layout.
#[derive(Debug, Default, Clone)]
pub struct ChatAdapter;
13
14fn get_type_hint(field: &Value) -> String {
15    let schema = &field["schema"];
16    let type_str = field["type"].as_str().unwrap_or("String");
17
18    // Check if schema exists and is not empty (either as string or object)
19    let has_schema = if let Some(s) = schema.as_str() {
20        !s.is_empty()
21    } else {
22        schema.is_object()
23    };
24
25    if !has_schema && type_str == "String" {
26        String::new()
27    } else {
28        format!(" (must be formatted as valid Rust {type_str})")
29    }
30}
31
32impl ChatAdapter {
33    fn get_field_attribute_list(
34        &self,
35        field_iter: impl Iterator<Item = (String, Value)>,
36    ) -> String {
37        let mut field_attributes = String::new();
38        for (i, (field_name, field)) in field_iter.enumerate() {
39            let data_type = field["type"].as_str().unwrap_or("String");
40            let desc = field["desc"].as_str().unwrap_or("");
41
42            field_attributes.push_str(format!("{}. `{field_name}` ({data_type})", i + 1).as_str());
43            if !desc.is_empty() {
44                field_attributes.push_str(format!(": {desc}").as_str());
45            }
46            field_attributes.push('\n');
47        }
48        field_attributes
49    }
50
51    fn get_field_structure(&self, field_iter: impl Iterator<Item = (String, Value)>) -> String {
52        let mut field_structure = String::new();
53        for (field_name, field) in field_iter {
54            let schema = &field["schema"];
55            let data_type = field["type"].as_str().unwrap_or("String");
56
57            // Handle schema as either string or JSON object
58            let schema_prompt = if let Some(s) = schema.as_str() {
59                if s.is_empty() && data_type == "String" {
60                    "".to_string()
61                } else if !s.is_empty() {
62                    format!("\t# note: the value you produce must adhere to the JSON schema: {s}")
63                } else {
64                    format!("\t# note: the value you produce must be a single {data_type} value")
65                }
66            } else if schema.is_object() || schema.is_array() {
67                // Convert JSON object/array to string for display
68                let schema_str = schema.to_string();
69                format!(
70                    "\t# note: the value you produce must adhere to the JSON schema: {schema_str}"
71                )
72            } else if data_type == "String" {
73                "".to_string()
74            } else {
75                format!("\t# note: the value you produce must be a single {data_type} value")
76            };
77
78            field_structure.push_str(
79                format!("[[ ## {field_name} ## ]]\n{field_name}{schema_prompt}\n\n").as_str(),
80            );
81        }
82        field_structure
83    }
84
85    fn format_system_message(&self, signature: &dyn MetaSignature) -> String {
86        let field_description = self.format_field_description(signature);
87        let field_structure = self.format_field_structure(signature);
88        let task_description = self.format_task_description(signature);
89
90        format!("{field_description}\n{field_structure}\n{task_description}")
91    }
92
93    fn format_field_description(&self, signature: &dyn MetaSignature) -> String {
94        let input_field_description =
95            self.get_field_attribute_list(get_iter_from_value(&signature.input_fields()));
96        let output_field_description =
97            self.get_field_attribute_list(get_iter_from_value(&signature.output_fields()));
98
99        format!(
100            "Your input fields are:\n{input_field_description}\nYour output fields are:\n{output_field_description}"
101        )
102    }
103
104    fn format_field_structure(&self, signature: &dyn MetaSignature) -> String {
105        let input_field_structure =
106            self.get_field_structure(get_iter_from_value(&signature.input_fields()));
107        let output_field_structure =
108            self.get_field_structure(get_iter_from_value(&signature.output_fields()));
109
110        format!(
111            "All interactions will be structured in the following way, with the appropriate values filled in.\n\n{input_field_structure}{output_field_structure}[[ ## completed ## ]]\n"
112        )
113    }
114
115    fn format_task_description(&self, signature: &dyn MetaSignature) -> String {
116        let instruction = if signature.instruction().is_empty() {
117            format!(
118                "Given the fields {}, produce the fields {}.",
119                signature
120                    .input_fields()
121                    .as_object()
122                    .unwrap()
123                    .keys()
124                    .map(|k| format!("`{k}`"))
125                    .collect::<Vec<String>>()
126                    .join(", "),
127                signature
128                    .output_fields()
129                    .as_object()
130                    .unwrap()
131                    .keys()
132                    .map(|k| format!("`{k}`"))
133                    .collect::<Vec<String>>()
134                    .join(", ")
135            )
136        } else {
137            signature.instruction().clone()
138        };
139
140        format!("In adhering to this structure, your objective is:\n\t{instruction}")
141    }
142
143    fn format_user_message(&self, signature: &dyn MetaSignature, inputs: &Example) -> String {
144        let mut input_str = String::new();
145        for (field_name, _) in get_iter_from_value(&signature.input_fields()) {
146            let field_value = inputs.get(field_name.as_str(), None);
147            // Extract the actual string value if it's a JSON string, otherwise use as is
148            let field_value_str = if let Some(s) = field_value.as_str() {
149                s.to_string()
150            } else {
151                field_value.to_string()
152            };
153
154            input_str
155                .push_str(format!("[[ ## {field_name} ## ]]\n{field_value_str}\n\n",).as_str());
156        }
157
158        let first_output_field = signature
159            .output_fields()
160            .as_object()
161            .unwrap()
162            .keys()
163            .next()
164            .unwrap()
165            .clone();
166        let first_output_field_value = signature
167            .output_fields()
168            .as_object()
169            .unwrap()
170            .get(&first_output_field)
171            .unwrap()
172            .clone();
173
174        let type_hint = get_type_hint(&first_output_field_value);
175
176        let mut user_message = format!(
177            "Respond with the corresponding output fields, starting with the field `{first_output_field}`{type_hint},"
178        );
179        for (field_name, field) in get_iter_from_value(&signature.output_fields()).skip(1) {
180            user_message
181                .push_str(format!(" then `{field_name}`{},", get_type_hint(&field)).as_str());
182        }
183        user_message.push_str(" and then ending with the marker for `completed`.");
184
185        format!("{input_str}{user_message}")
186    }
187
188    fn format_assistant_message(&self, signature: &dyn MetaSignature, outputs: &Example) -> String {
189        let mut assistant_message = String::new();
190        for (field_name, _) in get_iter_from_value(&signature.output_fields()) {
191            let field_value = outputs.get(field_name.as_str(), None);
192            // Extract the actual string value if it's a JSON string, otherwise use as is
193            let field_value_str = if let Some(s) = field_value.as_str() {
194                s.to_string()
195            } else {
196                field_value.to_string()
197            };
198
199            assistant_message
200                .push_str(format!("[[ ## {field_name} ## ]]\n{field_value_str}\n\n",).as_str());
201        }
202        assistant_message.push_str("[[ ## completed ## ]]\n");
203        assistant_message
204    }
205
206    fn format_demos(&self, signature: &dyn MetaSignature, demos: &Vec<Example>) -> Chat {
207        let mut chat = Chat::new(vec![]);
208
209        for demo in demos {
210            let user_message = self.format_user_message(signature, demo);
211            let assistant_message = self.format_assistant_message(signature, demo);
212            chat.push("user", &user_message);
213            chat.push("assistant", &assistant_message);
214        }
215
216        chat
217    }
218}
219
220#[async_trait::async_trait]
221impl Adapter for ChatAdapter {
222    fn format(&self, signature: &dyn MetaSignature, inputs: Example) -> Chat {
223        let system_message = self.format_system_message(signature);
224        let user_message = self.format_user_message(signature, &inputs);
225
226        let demos = signature.demos();
227        let demos = self.format_demos(signature, &demos);
228
229        let mut chat = Chat::new(vec![]);
230        chat.push("system", &system_message);
231        chat.push_all(&demos);
232        chat.push("user", &user_message);
233
234        chat
235    }
236
237    fn parse_response(
238        &self,
239        signature: &dyn MetaSignature,
240        response: Message,
241    ) -> HashMap<String, Value> {
242        let mut output = HashMap::new();
243
244        let response_content = response.content();
245
246        for (field_name, field) in get_iter_from_value(&signature.output_fields()) {
247            let field_value = response_content
248                .split(format!("[[ ## {field_name} ## ]]\n").as_str())
249                .nth(1);
250
251            if field_value.is_none() {
252                continue; // Skip field if not found in response
253            }
254            let field_value = field_value.unwrap();
255
256            let extracted_field = field_value.split("[[ ## ").nth(0).unwrap().trim();
257            let data_type = field["type"].as_str().unwrap();
258            let schema = &field["schema"];
259
260            // Check if schema exists (as string or object)
261            let has_schema = if let Some(s) = schema.as_str() {
262                !s.is_empty()
263            } else {
264                schema.is_object() || schema.is_array()
265            };
266
267            if !has_schema && data_type == "String" {
268                output.insert(field_name.clone(), json!(extracted_field));
269            } else {
270                output.insert(
271                    field_name.clone(),
272                    serde_json::from_str(extracted_field).unwrap(),
273                );
274            }
275        }
276
277        output
278    }
279
280    async fn call(
281        &self,
282        lm: Arc<LM>,
283        signature: &dyn MetaSignature,
284        inputs: Example,
285        tools: Vec<Arc<dyn ToolDyn>>,
286    ) -> Result<Prediction> {
287        // Check cache first (release lock immediately after checking)
288        if lm.cache
289            && let Some(cache) = lm.cache_handler.as_ref()
290        {
291            let cache_key = inputs.clone();
292            if let Some(cached) = cache.lock().await.get(cache_key).await? {
293                return Ok(cached);
294            }
295        }
296
297        let messages = self.format(signature, inputs.clone());
298        let response = lm.call(messages, tools).await?;
299        let prompt_str = response.chat.to_json().to_string();
300
301        let mut output = self.parse_response(signature, response.output);
302        if !response.tool_calls.is_empty() {
303            output.insert(
304                "tool_calls".to_string(),
305                response
306                    .tool_calls
307                    .into_iter()
308                    .map(|call| json!(call))
309                    .collect::<Value>(),
310            );
311            output.insert(
312                "tool_executions".to_string(),
313                response
314                    .tool_executions
315                    .into_iter()
316                    .map(|execution| json!(execution))
317                    .collect::<Value>(),
318            );
319        }
320
321        let prediction = Prediction {
322            data: output,
323            lm_usage: response.usage,
324        };
325
326        // Store in cache if enabled
327        if lm.cache
328            && let Some(cache) = lm.cache_handler.as_ref()
329        {
330            let (tx, rx) = tokio::sync::mpsc::channel(1);
331            let cache_clone = cache.clone();
332            let inputs_clone = inputs.clone();
333
334            // Spawn the cache insert operation to avoid deadlock
335            tokio::spawn(async move {
336                let _ = cache_clone.lock().await.insert(inputs_clone, rx).await;
337            });
338
339            // Send the result to the cache
340            tx.send(CallResult {
341                prompt: prompt_str,
342                prediction: prediction.clone(),
343            })
344            .await
345            .map_err(|_| anyhow::anyhow!("Failed to send to cache"))?;
346        }
347
348        Ok(prediction)
349    }
350}