// openai_safe/assistant.rs

1use std::time::Duration;
2
3use crate::domain::{
4    OpenAIAssistantResp, OpenAIMessageListResp, OpenAIMessageResp, OpenAIRunResp, OpenAIThreadResp,
5};
6use crate::enums::{OpenAIAssistantRole, OpenAIRunStatus};
7use crate::utils::sanitize_json_response;
8use crate::{constants::OPENAI_ASSISTANT_INSTRUCTIONS, models::OpenAIModels};
9use anyhow::{anyhow, Result};
10use log::error;
11use log::info;
12use reqwest::{header, Client};
13use schemars::{schema_for, JsonSchema};
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16use serde_json::{json, Value};
17use tokio::time;
18use tokio::time::timeout;
19
/// [OpenAI Docs](https://platform.openai.com/docs/assistants/overview)
///
/// The Assistants API allows you to build AI assistants within your own applications.
/// An Assistant has instructions and can leverage models, tools, and knowledge to respond to user queries.
/// The Assistants API currently supports three types of tools: Code Interpreter, Retrieval, and Function calling.
/// In the future, we plan to release more OpenAI-built tools, and allow you to provide
/// your own tools on our platform.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OpenAIAssistant {
    // Assistant ID returned by the API; None until `create_assistant` succeeds
    id: Option<String>,
    // Conversation thread ID; None until the first message creates the thread
    thread_id: Option<String>,
    // ID of the most recently started run; None until `start_run` succeeds
    run_id: Option<String>,
    // Model this assistant is created with
    model: OpenAIModels,
    // Instructions sent when the assistant is created and posted as the first message
    instructions: String,
    // When true, raw API responses are logged via `info!`
    debug: bool,
    // OpenAI API key, sent as a bearer token on every request
    api_key: String,
}
37
38impl OpenAIAssistant {
39    //Constructor
40    pub async fn new(model: OpenAIModels, open_ai_key: &str, debug: bool) -> Result<Self> {
41        let mut new_assistant = OpenAIAssistant {
42            id: None,
43            thread_id: None,
44            run_id: None,
45            model,
46            instructions: OPENAI_ASSISTANT_INSTRUCTIONS.to_string(),
47            debug,
48            api_key: open_ai_key.to_string(),
49        };
50
51        //Call OpenAI API to get an ID for the assistant
52        new_assistant.create_assistant().await?;
53
54        //Add first message thus initializing the thread
55        new_assistant
56            .add_message(OPENAI_ASSISTANT_INSTRUCTIONS, &Vec::new())
57            .await?;
58
59        Ok(new_assistant)
60    }
61
62    /*
63     * This function creates an Assistant and updates the ID of the OpenAIAssistant struct
64     */
65    async fn create_assistant(&mut self) -> Result<()> {
66        //Get the API url
67        let assistant_url = "https://api.openai.com/v1/assistants";
68
69        let code_interpreter = json!({
70            "type": "retrieval",
71        });
72        let assistant_body = json!({
73            "instructions": self.instructions.clone(),
74            "model": self.model.as_str(),
75            "tools": vec![code_interpreter],
76        });
77
78        //Make the API call
79        let client = Client::new();
80
81        let response = client
82            .post(assistant_url)
83            .header(header::CONTENT_TYPE, "application/json")
84            .header("OpenAI-Beta", "assistants=v1")
85            .bearer_auth(&self.api_key)
86            .json(&assistant_body)
87            .send()
88            .await?;
89
90        let response_status = response.status();
91        let response_text = response.text().await?;
92
93        if self.debug {
94            info!(
95                "[debug] OpenAI Assistant API response: [{}] {:#?}",
96                &response_status, &response_text
97            );
98        }
99
100        //Deserialize the string response into the Assistant object
101        let response_deser: OpenAIAssistantResp =
102            serde_json::from_str(&response_text).map_err(|error| {
103                error!(
104                    "[OpenAIAssistant] Assistant API response serialization error: {}",
105                    &error
106                );
107                anyhow!("Error: {}", error)
108            })?;
109
110        //Add correct ID to self
111        self.id = Some(response_deser.id);
112
113        Ok(())
114    }
115
116    /*
117     * This function performs all the orchestration needed to submit a prompt and get and answer
118     */
119    pub async fn get_answer<T: JsonSchema + DeserializeOwned>(
120        mut self,
121        message: &str,
122        file_ids: &[String],
123    ) -> Result<T> {
124        //Step 1: Instruct the Assistant to answer with the right Json format
125        //Output schema is extracted from the type parameter
126        let schema = schema_for!(T);
127        let schema_json: Value = serde_json::to_value(&schema)?;
128        let schema_string = serde_json::to_string(&schema_json).unwrap_or_default();
129
130        //We instruct Assistant to answer with that schema
131        let schema_message = format!(
132            "Response should include only the data portion of a Json formatted as per the following schema: {}. 
133            The response should only include well-formatted data, and not the schema itself.
134            Do not include any other words or characters, including the word 'json'. Only respond with the data. 
135            You need to validate the Json before returning.",
136            schema_string
137        );
138        self.add_message(&schema_message, &Vec::new()).await?;
139
140        //Step 2: Add user message and files to thread
141        self.add_message(message, file_ids).await?;
142
143        //Step 3: Kick off processing (aka Run)
144        self.start_run().await?;
145
146        //Step 4: Check in on the status of the run
147        let operation_timeout = Duration::from_secs(600); // Timeout for the whole operation
148        let poll_interval = Duration::from_secs(10);
149
150        let _result = timeout(operation_timeout, async {
151            let mut interval = time::interval(poll_interval);
152            loop {
153                interval.tick().await; // Wait for the next interval tick
154                match self.get_run_status().await {
155                    Ok(resp) => match resp.status {
156                        //Completed successfully. Time to get results.
157                        OpenAIRunStatus::Completed => {
158                            break Ok(());
159                        }
160                        //TODO: We will need better handling of requires_action
161                        OpenAIRunStatus::RequiresAction
162                        | OpenAIRunStatus::Cancelling
163                        | OpenAIRunStatus::Cancelled
164                        | OpenAIRunStatus::Failed
165                        | OpenAIRunStatus::Expired => {
166                            return Err(anyhow!("Failed to validate status of the run"));
167                        }
168                        _ => continue, // Keep polling if in_progress or queued
169                    },
170                    Err(e) => return Err(e), // Break on error
171                }
172            }
173        })
174        .await?;
175
176        //Step 5: Get all messages posted on the thread. This should now include response from the Assistant
177        let messages = self.get_message_thread().await?;
178
179        messages
180            .into_iter()
181            .filter(|message| message.role == OpenAIAssistantRole::Assistant)
182            .find_map(|message| {
183                message.content.into_iter().find_map(|content| {
184                    content.text.and_then(|text| {
185                        let sanitized_text = sanitize_json_response(&text.value);
186                        serde_json::from_str::<T>(&sanitized_text).ok()
187                    })
188                })
189            })
190            .ok_or(anyhow!("No valid response form OpenAI Assistant found."))
191    }
192
193    ///
194    /// This method can be used to provide data that will be used as context for the prompt.
195    /// Using this function you can provide multiple sets of context data by calling it multiple times. New values will be as messages to the thread
196    /// It accepts any struct that implements the Serialize trait.
197    ///
198    pub async fn set_context<T: Serialize>(mut self, dataset_name: &str, data: &T) -> Result<Self> {
199        let serialized_data = if let Ok(json) = serde_json::to_string(&data) {
200            json
201        } else {
202            return Err(anyhow!("Unable serialize provided input data."));
203        };
204        let message = format!("'{dataset_name}'= {serialized_data}");
205        let file_ids = Vec::new();
206        self.add_message(&message, &file_ids).await?;
207        Ok(self)
208    }
209
210    /*
211     * This function creates a Thread and updates the thread_id of the OpenAIAssistant struct
212     */
213    async fn add_message(&mut self, message: &str, file_ids: &[String]) -> Result<()> {
214        //Prepare the body that is to be send to OpenAI APIs
215        let message = match file_ids.is_empty() {
216            false => json!({
217                "role": "user",
218                "content": message.to_string(),
219                "file_ids": file_ids.to_vec(),
220            }),
221            true => json!({
222                "role": "user",
223                "content": message.to_string(),
224            }),
225        };
226
227        //If there is no thread_id we need to create one
228        match self.thread_id {
229            None => {
230                let body = json!({
231                    "messages": vec![message],
232                });
233
234                self.create_thread(&body).await
235            }
236            Some(_) => self.add_message_thread(&message).await,
237        }
238    }
239
240    /*
241     * This function creates a Thread and updates the thread_id of the OpenAIAssistant struct
242     */
243    async fn create_thread(&mut self, body: &serde_json::Value) -> Result<()> {
244        let thread_url = "https://api.openai.com/v1/threads";
245
246        //Make the API call
247        let client = Client::new();
248
249        let response = client
250            .post(thread_url)
251            .header(header::CONTENT_TYPE, "application/json")
252            .header("OpenAI-Beta", "assistants=v1")
253            .bearer_auth(&self.api_key)
254            .json(&body)
255            .send()
256            .await?;
257
258        let response_status = response.status();
259        let response_text = response.text().await?;
260
261        if self.debug {
262            info!(
263                "[debug] OpenAI Threads API response: [{}] {:#?}",
264                &response_status, &response_text
265            );
266        }
267
268        //Deserialize the string response into the Thread object
269        let response_deser: OpenAIThreadResp =
270            serde_json::from_str(&response_text).map_err(|error| {
271                error!(
272                    "[OpenAIAssistant] Thread API response serialization error: {}",
273                    &error
274                );
275                anyhow!("Error: {}", error)
276            })?;
277
278        //Add thread_id to self
279        self.thread_id = Some(response_deser.id);
280
281        Ok(())
282    }
283
284    /*
285     * This function adds a message to an existing thread
286     */
287    async fn add_message_thread(&self, body: &serde_json::Value) -> Result<()> {
288        if self.thread_id.is_none() {
289            return Err(anyhow!("No active thread detected."));
290        }
291
292        let message_url = format!(
293            "https://api.openai.com/v1/threads/{}/messages",
294            self.thread_id.clone().unwrap_or_default()
295        );
296
297        //Make the API call
298        let client = Client::new();
299
300        let response = client
301            .post(message_url)
302            .header(header::CONTENT_TYPE, "application/json")
303            .header("OpenAI-Beta", "assistants=v1")
304            .bearer_auth(&self.api_key)
305            .json(&body)
306            .send()
307            .await?;
308
309        let response_status = response.status();
310        let response_text = response.text().await?;
311
312        if self.debug {
313            info!(
314                "[debug] OpenAI Messages API response: [{}] {:#?}",
315                &response_status, &response_text
316            );
317        }
318
319        //Deserialize the string response into the Message object to confirm if there were any errors
320        let _response_deser: OpenAIMessageResp =
321            serde_json::from_str(&response_text).map_err(|error| {
322                error!(
323                    "[OpenAIAssistant] Messages API response serialization error: {}",
324                    &error
325                );
326                anyhow!("Error: {}", error)
327            })?;
328
329        Ok(())
330    }
331
332    /*
333     * This function gets all message posted to an existing thread
334     */
335    async fn get_message_thread(&self) -> Result<Vec<OpenAIMessageResp>> {
336        if self.thread_id.is_none() {
337            return Err(anyhow!("No active thread detected."));
338        }
339
340        let message_url = format!(
341            "https://api.openai.com/v1/threads/{}/messages",
342            self.thread_id.clone().unwrap_or_default()
343        );
344
345        //Make the API call
346        let client = Client::new();
347
348        let response = client
349            .get(message_url)
350            .header(header::CONTENT_TYPE, "application/json")
351            .header("OpenAI-Beta", "assistants=v1")
352            .bearer_auth(&self.api_key)
353            .send()
354            .await?;
355
356        let response_status = response.status();
357        let response_text = response.text().await?;
358
359        if self.debug {
360            info!(
361                "[debug] OpenAI Messages API response: [{}] {:#?}",
362                &response_status, &response_text
363            );
364        }
365
366        //Deserialize the string response into a vector of OpenAIMessageResp objects
367        let response_deser: OpenAIMessageListResp =
368            serde_json::from_str(&response_text).map_err(|error| {
369                error!(
370                    "[OpenAIAssistant] Messages API response serialization error: {}",
371                    &error
372                );
373                anyhow!("Error: {}", error)
374            })?;
375
376        Ok(response_deser.data)
377    }
378
379    /*
380     * This function starts an assistant run
381     */
382    async fn start_run(&mut self) -> Result<()> {
383        let assistant_id = if let Some(id) = self.id.clone() {
384            id
385        } else {
386            return Err(anyhow!("No active assistant detected."));
387        };
388
389        let thread_id = if let Some(id) = self.thread_id.clone() {
390            id
391        } else {
392            return Err(anyhow!("No active thread detected."));
393        };
394
395        let run_url = format!("https://api.openai.com/v1/threads/{}/runs", thread_id);
396
397        let body = json!({
398            "assistant_id": assistant_id,
399        });
400
401        //Make the API call
402        let client = Client::new();
403
404        let response = client
405            .post(run_url)
406            .header(header::CONTENT_TYPE, "application/json")
407            .header("OpenAI-Beta", "assistants=v1")
408            .bearer_auth(&self.api_key)
409            .json(&body)
410            .send()
411            .await?;
412
413        let response_status = response.status();
414        let response_text = response.text().await?;
415
416        if self.debug {
417            info!(
418                "[debug] OpenAI Messages API response: [{}] {:#?}",
419                &response_status, &response_text
420            );
421        }
422
423        //Deserialize the string response into the Message object to confirm if there were any errors
424        let response_deser: OpenAIRunResp =
425            serde_json::from_str(&response_text).map_err(|error| {
426                error!(
427                    "[OpenAIAssistant] Run API response serialization error: {}",
428                    &error
429                );
430                anyhow!("Error: {}", error)
431            })?;
432
433        //Update run_id
434        self.run_id = Some(response_deser.id);
435
436        Ok(())
437    }
438
439    /*
440     * This function checks the status of an assistant run
441     */
442    async fn get_run_status(&self) -> Result<OpenAIRunResp> {
443        let thread_id = if let Some(id) = self.thread_id.clone() {
444            id
445        } else {
446            return Err(anyhow!("No active thread detected."));
447        };
448
449        let run_id = if let Some(id) = self.run_id.clone() {
450            id
451        } else {
452            return Err(anyhow!("No active run detected."));
453        };
454
455        let run_url = format!("https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}");
456
457        //Make the API call
458        let client = Client::new();
459
460        let response = client
461            .get(run_url)
462            .header(header::CONTENT_TYPE, "application/json")
463            .header("OpenAI-Beta", "assistants=v1")
464            .bearer_auth(&self.api_key)
465            .send()
466            .await?;
467
468        let response_status = response.status();
469        let response_text = response.text().await?;
470
471        if self.debug {
472            info!(
473                "[debug] OpenAI Run status API response: [{}] {:#?}",
474                &response_status, &response_text
475            );
476        }
477
478        //Deserialize the string response into the Message object to confirm if there were any errors
479        let response_deser: OpenAIRunResp =
480            serde_json::from_str(&response_text).map_err(|error| {
481                error!(
482                    "[OpenAIAssistant] Run API response serialization error: {}",
483                    &error
484                );
485                anyhow!("Error: {}", error)
486            })?;
487
488        Ok(response_deser)
489    }
490}