1use std::time::Duration;
2
3use crate::domain::{
4 OpenAIAssistantResp, OpenAIMessageListResp, OpenAIMessageResp, OpenAIRunResp, OpenAIThreadResp,
5};
6use crate::enums::{OpenAIAssistantRole, OpenAIRunStatus};
7use crate::utils::sanitize_json_response;
8use crate::{constants::OPENAI_ASSISTANT_INSTRUCTIONS, models::OpenAIModels};
9use anyhow::{anyhow, Result};
10use log::error;
11use log::info;
12use reqwest::{header, Client};
13use schemars::{schema_for, JsonSchema};
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16use serde_json::{json, Value};
17use tokio::time;
18use tokio::time::timeout;
19
20#[derive(Deserialize, Serialize, Debug, Clone)]
28pub struct OpenAIAssistant {
29 id: Option<String>,
30 thread_id: Option<String>,
31 run_id: Option<String>,
32 model: OpenAIModels,
33 instructions: String,
34 debug: bool,
35 api_key: String,
36}
37
38impl OpenAIAssistant {
39 pub async fn new(model: OpenAIModels, open_ai_key: &str, debug: bool) -> Result<Self> {
41 let mut new_assistant = OpenAIAssistant {
42 id: None,
43 thread_id: None,
44 run_id: None,
45 model,
46 instructions: OPENAI_ASSISTANT_INSTRUCTIONS.to_string(),
47 debug,
48 api_key: open_ai_key.to_string(),
49 };
50
51 new_assistant.create_assistant().await?;
53
54 new_assistant
56 .add_message(OPENAI_ASSISTANT_INSTRUCTIONS, &Vec::new())
57 .await?;
58
59 Ok(new_assistant)
60 }
61
62 async fn create_assistant(&mut self) -> Result<()> {
66 let assistant_url = "https://api.openai.com/v1/assistants";
68
69 let code_interpreter = json!({
70 "type": "retrieval",
71 });
72 let assistant_body = json!({
73 "instructions": self.instructions.clone(),
74 "model": self.model.as_str(),
75 "tools": vec![code_interpreter],
76 });
77
78 let client = Client::new();
80
81 let response = client
82 .post(assistant_url)
83 .header(header::CONTENT_TYPE, "application/json")
84 .header("OpenAI-Beta", "assistants=v1")
85 .bearer_auth(&self.api_key)
86 .json(&assistant_body)
87 .send()
88 .await?;
89
90 let response_status = response.status();
91 let response_text = response.text().await?;
92
93 if self.debug {
94 info!(
95 "[debug] OpenAI Assistant API response: [{}] {:#?}",
96 &response_status, &response_text
97 );
98 }
99
100 let response_deser: OpenAIAssistantResp =
102 serde_json::from_str(&response_text).map_err(|error| {
103 error!(
104 "[OpenAIAssistant] Assistant API response serialization error: {}",
105 &error
106 );
107 anyhow!("Error: {}", error)
108 })?;
109
110 self.id = Some(response_deser.id);
112
113 Ok(())
114 }
115
116 pub async fn get_answer<T: JsonSchema + DeserializeOwned>(
120 mut self,
121 message: &str,
122 file_ids: &[String],
123 ) -> Result<T> {
124 let schema = schema_for!(T);
127 let schema_json: Value = serde_json::to_value(&schema)?;
128 let schema_string = serde_json::to_string(&schema_json).unwrap_or_default();
129
130 let schema_message = format!(
132 "Response should include only the data portion of a Json formatted as per the following schema: {}.
133 The response should only include well-formatted data, and not the schema itself.
134 Do not include any other words or characters, including the word 'json'. Only respond with the data.
135 You need to validate the Json before returning.",
136 schema_string
137 );
138 self.add_message(&schema_message, &Vec::new()).await?;
139
140 self.add_message(message, file_ids).await?;
142
143 self.start_run().await?;
145
146 let operation_timeout = Duration::from_secs(600); let poll_interval = Duration::from_secs(10);
149
150 let _result = timeout(operation_timeout, async {
151 let mut interval = time::interval(poll_interval);
152 loop {
153 interval.tick().await; match self.get_run_status().await {
155 Ok(resp) => match resp.status {
156 OpenAIRunStatus::Completed => {
158 break Ok(());
159 }
160 OpenAIRunStatus::RequiresAction
162 | OpenAIRunStatus::Cancelling
163 | OpenAIRunStatus::Cancelled
164 | OpenAIRunStatus::Failed
165 | OpenAIRunStatus::Expired => {
166 return Err(anyhow!("Failed to validate status of the run"));
167 }
168 _ => continue, },
170 Err(e) => return Err(e), }
172 }
173 })
174 .await?;
175
176 let messages = self.get_message_thread().await?;
178
179 messages
180 .into_iter()
181 .filter(|message| message.role == OpenAIAssistantRole::Assistant)
182 .find_map(|message| {
183 message.content.into_iter().find_map(|content| {
184 content.text.and_then(|text| {
185 let sanitized_text = sanitize_json_response(&text.value);
186 serde_json::from_str::<T>(&sanitized_text).ok()
187 })
188 })
189 })
190 .ok_or(anyhow!("No valid response form OpenAI Assistant found."))
191 }
192
193 pub async fn set_context<T: Serialize>(mut self, dataset_name: &str, data: &T) -> Result<Self> {
199 let serialized_data = if let Ok(json) = serde_json::to_string(&data) {
200 json
201 } else {
202 return Err(anyhow!("Unable serialize provided input data."));
203 };
204 let message = format!("'{dataset_name}'= {serialized_data}");
205 let file_ids = Vec::new();
206 self.add_message(&message, &file_ids).await?;
207 Ok(self)
208 }
209
210 async fn add_message(&mut self, message: &str, file_ids: &[String]) -> Result<()> {
214 let message = match file_ids.is_empty() {
216 false => json!({
217 "role": "user",
218 "content": message.to_string(),
219 "file_ids": file_ids.to_vec(),
220 }),
221 true => json!({
222 "role": "user",
223 "content": message.to_string(),
224 }),
225 };
226
227 match self.thread_id {
229 None => {
230 let body = json!({
231 "messages": vec![message],
232 });
233
234 self.create_thread(&body).await
235 }
236 Some(_) => self.add_message_thread(&message).await,
237 }
238 }
239
240 async fn create_thread(&mut self, body: &serde_json::Value) -> Result<()> {
244 let thread_url = "https://api.openai.com/v1/threads";
245
246 let client = Client::new();
248
249 let response = client
250 .post(thread_url)
251 .header(header::CONTENT_TYPE, "application/json")
252 .header("OpenAI-Beta", "assistants=v1")
253 .bearer_auth(&self.api_key)
254 .json(&body)
255 .send()
256 .await?;
257
258 let response_status = response.status();
259 let response_text = response.text().await?;
260
261 if self.debug {
262 info!(
263 "[debug] OpenAI Threads API response: [{}] {:#?}",
264 &response_status, &response_text
265 );
266 }
267
268 let response_deser: OpenAIThreadResp =
270 serde_json::from_str(&response_text).map_err(|error| {
271 error!(
272 "[OpenAIAssistant] Thread API response serialization error: {}",
273 &error
274 );
275 anyhow!("Error: {}", error)
276 })?;
277
278 self.thread_id = Some(response_deser.id);
280
281 Ok(())
282 }
283
284 async fn add_message_thread(&self, body: &serde_json::Value) -> Result<()> {
288 if self.thread_id.is_none() {
289 return Err(anyhow!("No active thread detected."));
290 }
291
292 let message_url = format!(
293 "https://api.openai.com/v1/threads/{}/messages",
294 self.thread_id.clone().unwrap_or_default()
295 );
296
297 let client = Client::new();
299
300 let response = client
301 .post(message_url)
302 .header(header::CONTENT_TYPE, "application/json")
303 .header("OpenAI-Beta", "assistants=v1")
304 .bearer_auth(&self.api_key)
305 .json(&body)
306 .send()
307 .await?;
308
309 let response_status = response.status();
310 let response_text = response.text().await?;
311
312 if self.debug {
313 info!(
314 "[debug] OpenAI Messages API response: [{}] {:#?}",
315 &response_status, &response_text
316 );
317 }
318
319 let _response_deser: OpenAIMessageResp =
321 serde_json::from_str(&response_text).map_err(|error| {
322 error!(
323 "[OpenAIAssistant] Messages API response serialization error: {}",
324 &error
325 );
326 anyhow!("Error: {}", error)
327 })?;
328
329 Ok(())
330 }
331
332 async fn get_message_thread(&self) -> Result<Vec<OpenAIMessageResp>> {
336 if self.thread_id.is_none() {
337 return Err(anyhow!("No active thread detected."));
338 }
339
340 let message_url = format!(
341 "https://api.openai.com/v1/threads/{}/messages",
342 self.thread_id.clone().unwrap_or_default()
343 );
344
345 let client = Client::new();
347
348 let response = client
349 .get(message_url)
350 .header(header::CONTENT_TYPE, "application/json")
351 .header("OpenAI-Beta", "assistants=v1")
352 .bearer_auth(&self.api_key)
353 .send()
354 .await?;
355
356 let response_status = response.status();
357 let response_text = response.text().await?;
358
359 if self.debug {
360 info!(
361 "[debug] OpenAI Messages API response: [{}] {:#?}",
362 &response_status, &response_text
363 );
364 }
365
366 let response_deser: OpenAIMessageListResp =
368 serde_json::from_str(&response_text).map_err(|error| {
369 error!(
370 "[OpenAIAssistant] Messages API response serialization error: {}",
371 &error
372 );
373 anyhow!("Error: {}", error)
374 })?;
375
376 Ok(response_deser.data)
377 }
378
379 async fn start_run(&mut self) -> Result<()> {
383 let assistant_id = if let Some(id) = self.id.clone() {
384 id
385 } else {
386 return Err(anyhow!("No active assistant detected."));
387 };
388
389 let thread_id = if let Some(id) = self.thread_id.clone() {
390 id
391 } else {
392 return Err(anyhow!("No active thread detected."));
393 };
394
395 let run_url = format!("https://api.openai.com/v1/threads/{}/runs", thread_id);
396
397 let body = json!({
398 "assistant_id": assistant_id,
399 });
400
401 let client = Client::new();
403
404 let response = client
405 .post(run_url)
406 .header(header::CONTENT_TYPE, "application/json")
407 .header("OpenAI-Beta", "assistants=v1")
408 .bearer_auth(&self.api_key)
409 .json(&body)
410 .send()
411 .await?;
412
413 let response_status = response.status();
414 let response_text = response.text().await?;
415
416 if self.debug {
417 info!(
418 "[debug] OpenAI Messages API response: [{}] {:#?}",
419 &response_status, &response_text
420 );
421 }
422
423 let response_deser: OpenAIRunResp =
425 serde_json::from_str(&response_text).map_err(|error| {
426 error!(
427 "[OpenAIAssistant] Run API response serialization error: {}",
428 &error
429 );
430 anyhow!("Error: {}", error)
431 })?;
432
433 self.run_id = Some(response_deser.id);
435
436 Ok(())
437 }
438
439 async fn get_run_status(&self) -> Result<OpenAIRunResp> {
443 let thread_id = if let Some(id) = self.thread_id.clone() {
444 id
445 } else {
446 return Err(anyhow!("No active thread detected."));
447 };
448
449 let run_id = if let Some(id) = self.run_id.clone() {
450 id
451 } else {
452 return Err(anyhow!("No active run detected."));
453 };
454
455 let run_url = format!("https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}");
456
457 let client = Client::new();
459
460 let response = client
461 .get(run_url)
462 .header(header::CONTENT_TYPE, "application/json")
463 .header("OpenAI-Beta", "assistants=v1")
464 .bearer_auth(&self.api_key)
465 .send()
466 .await?;
467
468 let response_status = response.status();
469 let response_text = response.text().await?;
470
471 if self.debug {
472 info!(
473 "[debug] OpenAI Run status API response: [{}] {:#?}",
474 &response_status, &response_text
475 );
476 }
477
478 let response_deser: OpenAIRunResp =
480 serde_json::from_str(&response_text).map_err(|error| {
481 error!(
482 "[OpenAIAssistant] Run API response serialization error: {}",
483 &error
484 );
485 anyhow!("Error: {}", error)
486 })?;
487
488 Ok(response_deser)
489 }
490}