use anyhow::{Context, Result};
use uuid::Uuid;

use crate::{
    chat_context::ChatContext, chat_response::ChatResponse,
    function_specification::FunctionSpecification, message::Message,
};

const DEFAULT_MODEL: &str = "gpt-3.5-turbo-0613";
const URL: &str = "https://api.openai.com/v1/chat/completions";

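/// Builder for [`ChatGPT`]. Only the OpenAI API token is required; the model,
/// session id, and chat context fall back to defaults in [`ChatGPTBuilder::build`].
///
/// A minimal usage sketch (run inside an async context; the environment
/// variable name is illustrative, not part of this crate):
///
/// ```ignore
/// let mut chat = ChatGPTBuilder::new()
///     .openai_api_token(std::env::var("OPENAI_API_KEY")?)
///     .build()?;
/// let _response = chat.completion_managed("Hello!".to_string()).await?;
/// println!("{:?}", chat.last_content());
/// ```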
pub struct ChatGPTBuilder {
    model: Option<String>,
    openai_api_token: Option<String>,
    session_id: Option<String>,
    chat_context: Option<ChatContext>,
}

impl ChatGPTBuilder {
    pub fn new() -> Self {
        ChatGPTBuilder {
            model: None,
            openai_api_token: None,
            session_id: None,
            chat_context: None,
        }
    }

    pub fn model(mut self, model: String) -> Self {
        self.model = Some(model);
        self
    }

    pub fn openai_api_token(mut self, openai_api_token: String) -> Self {
        self.openai_api_token = Some(openai_api_token);
        self
    }

    pub fn session_id(mut self, session_id: String) -> Self {
        self.session_id = Some(session_id);
        self
    }

    pub fn chat_context(mut self, chat_context: ChatContext) -> Self {
        self.chat_context = Some(chat_context);
        self
    }

    pub fn build(self) -> Result<ChatGPT> {
        let client = reqwest::Client::new();
        let model = self.model.unwrap_or_else(|| DEFAULT_MODEL.to_string());
        let openai_api_token = self
            .openai_api_token
            .context("OpenAI API token is missing")?;
        let session_id = self
            .session_id
            .unwrap_or_else(|| Uuid::new_v4().to_string());
        let chat_context = self
            .chat_context
            .unwrap_or_else(|| ChatContext::new(model.clone()));

        Ok(ChatGPT {
            client,
            model,
            openai_api_token,
            session_id,
            chat_context,
        })
    }
}

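/// A chat session against the OpenAI chat completions endpoint. It owns the
/// HTTP client, the API token, a session id, and the [`ChatContext`] that
/// accumulates messages and function specifications.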
pub struct ChatGPT {
    client: reqwest::Client,
    pub model: String,
    openai_api_token: String,
    pub session_id: String,
    pub chat_context: ChatContext,
}

impl ChatGPT {
    pub fn new(
        client: reqwest::Client,
        model: String,
        openai_api_token: String,
        session_id: String,
        chat_context: ChatContext,
    ) -> Result<ChatGPT> {
        Ok(ChatGPT {
            client,
            model,
            openai_api_token,
            session_id,
            chat_context,
        })
    }

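    /// Sends the current chat context to the OpenAI chat completions endpoint
    /// and parses the body into a [`ChatResponse`]. The context itself is not
    /// updated; use the `*_updating_context` variants for that.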
    pub async fn completion(&mut self) -> Result<ChatResponse> {
        let response = self
            .client
            .post(URL)
            .bearer_auth(&self.openai_api_token)
            .header("Content-Type", "application/json")
            .body(self.chat_context.to_string())
            .send()
            .await
            .context(format!("Failed to receive the response from {}", URL))?
            .text()
            .await
            .context("Failed to retrieve the content of the response")?;

        let answer = parse_removing_newlines(response)?;
        Ok(answer)
    }

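    /// Convenience wrapper around
    /// [`Self::completion_with_user_content_updating_context`]: pushes
    /// `content` as a user message, requests a completion, and appends the
    /// assistant's reply to the context.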
    pub async fn completion_managed(&mut self, content: String) -> Result<ChatResponse> {
        self.completion_with_user_content_updating_context(content)
            .await
    }

    pub async fn completion_with_message(&mut self, message: Message) -> Result<ChatResponse> {
        self.push_message(message);
        self.completion().await
    }

    pub async fn completion_with_user_content(&mut self, content: String) -> Result<ChatResponse> {
        let message = Message::new_user_message(content);
        self.completion_with_message(message).await
    }

    pub async fn completion_with_user_content_updating_context(
        &mut self,
        content: String,
    ) -> Result<ChatResponse> {
        let message = Message::new_user_message(content);
        self.completion_with_message_updating_context(message).await
    }

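    /// Pushes `message`, requests a completion, and, if the response contains
    /// at least one choice, appends the last choice's message back to the
    /// chat context so the conversation history stays in sync.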
    pub async fn completion_with_message_updating_context(
        &mut self,
        message: Message,
    ) -> Result<ChatResponse> {
        self.push_message(message);
        let response = self.completion().await?;
        if let Some(choice) = response.choices.last() {
            self.push_message(choice.message.clone());
        }
        Ok(response)
    }

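    /// Appends a message to the chat context without sending a request.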
    pub fn push_message(&mut self, message: Message) {
        self.chat_context.push_message(message);
    }

    pub fn set_messages(&mut self, messages: Vec<Message>) {
        self.chat_context.set_messages(messages);
    }

    pub fn push_function(&mut self, function: FunctionSpecification) {
        self.chat_context.push_function(function);
    }

    pub fn set_functions(&mut self, functions: Vec<FunctionSpecification>) {
        self.chat_context.set_functions(functions);
    }

    pub fn last_content(&self) -> Option<String> {
        self.chat_context.last_content()
    }

    pub fn last_function(&self) -> Option<(String, String)> {
        self.chat_context.last_function_call()
    }
}

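// Strips raw newline characters from the response body before deserializing.
// The API has been observed to return unescaped newlines inside
// `function_call.arguments`, which would otherwise make the JSON invalid
// (see `test_fix_context_when_function_replied_with_content`).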
fn parse_removing_newlines(response: String) -> Result<ChatResponse> {
    let r = response.replace('\n', "");
    let response: ChatResponse = serde_json::from_str(&r).context(format!(
        "Could not parse the response. The object to parse: \n{}",
        r
    ))?;
    Ok(response)
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{function_specification::Parameters, message::FunctionCall};

    use super::*;

    #[test]
    fn test_chat_gpt_new() {
        let chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("123".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        assert_eq!(chat_gpt.session_id.len(), 36);
        assert_eq!(chat_gpt.chat_context.model, DEFAULT_MODEL);
        assert_eq!(chat_gpt.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_chat_gpt_new_with_everything() {
        let chat_gpt = ChatGPTBuilder::new()
            .session_id("session_id".to_string())
            .model("model".to_string())
            .openai_api_token("1234".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        assert_eq!(chat_gpt.session_id, "session_id");
        assert_eq!(chat_gpt.openai_api_token, "1234");
        assert_eq!(chat_gpt.chat_context.model, "model");
    }

    #[test]
    fn test_chat_gpt_push_message() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let message = Message::new_user_message("content".to_string());
        chat_gpt.push_message(message);
        assert_eq!(chat_gpt.chat_context.messages.len(), 1);
    }

    #[test]
    fn test_chat_gpt_set_message() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let message = Message::new_user_message("content".to_string());
        chat_gpt.set_messages(vec![message]);
        assert_eq!(chat_gpt.chat_context.messages.len(), 1);
    }

    #[test]
    fn test_chat_gpt_push_function() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let function = FunctionSpecification::new("function".to_string(), None, None);
        chat_gpt.push_function(function);
        assert_eq!(chat_gpt.chat_context.functions.len(), 1);
    }

    #[test]
    fn test_chat_gpt_set_function() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let function = FunctionSpecification::new(
            "function".to_string(),
            Some("Test function".to_string()),
            Some(Parameters {
                type_: "string".to_string(),
                properties: HashMap::new(),
                required: vec![],
            }),
        );
        chat_gpt.set_functions(vec![function]);
        assert_eq!(chat_gpt.chat_context.functions.len(), 1);

        let function = chat_gpt
            .chat_context
            .functions
            .first()
            .expect("Failed to get the function");
        assert_eq!(function.name, "function");
        assert_eq!(
            function
                .description
                .as_ref()
                .expect("Failed to get the description"),
            "Test function"
        );
        assert_eq!(
            function
                .parameters
                .as_ref()
                .expect("Failed to get the parameters")
                .type_,
            "string"
        );
    }

    #[test]
    fn test_parse_removing_newlines() {
        let r = r#"{
            "id": "chatcmpl-7Ut7jsNlTUO9k9L5kBF0uDAyG19pK",
            "object": "chat.completion",
            "created": 1687596091,
            "model": "gpt-3.5-turbo-0613",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": null,
                        "function_call": {
                            "name": "get_current_weather",
                            "arguments": "{\n \"location\": \"Madrid, Spain\"\n}"
                        }
                    },
                    "finish_reason": "function_call"
                }
            ],
            "usage": {
                "prompt_tokens": 90,
                "completion_tokens": 19,
                "total_tokens": 109
            }
        }"#
        .to_string();
        let response = parse_removing_newlines(r).expect("Failed to parse");
        let message = response
            .choices
            .first()
            .expect("There is no choice")
            .message
            .clone();

        assert_eq!(message.role, "assistant");
        assert_eq!(message.content, None);
        assert_eq!(message.name, None);
        assert_eq!(
            message.function_call,
            Some(FunctionCall {
                name: "get_current_weather".to_string(),
                arguments: "{\n \"location\": \"Madrid, Spain\"\n}".to_string(),
            })
        );
    }

    #[test]
    fn test_fix_context_when_function_replied_with_content() {
        let r = r#"{"id":"chatcmpl-7VneSVRn9qJ1crw3m0V0kmnCq8Pnn","object":"chat.completion","created":1687813384,"choices":[{"index":0,"message":{"role":"assistant","function_call":{"name":"completion_managed","arguments":"{
 \"content\": \"Hi, model!\"
}"}},"finish_reason":"function_call"}],"usage":{"prompt_tokens":61,"completion_tokens":18,"total_tokens":79}}"#.to_string();
        let response = parse_removing_newlines(r).expect("Failed to parse");
        let message = response
            .choices
            .last()
            .expect("There is no choice")
            .message
            .clone();

        assert_eq!(message.role, "assistant");
        assert_eq!(message.content, None);
        assert_eq!(message.name, None);
        assert_eq!(
            message.function_call,
            Some(FunctionCall {
                name: "completion_managed".to_string(),
                arguments: "{ \"content\": \"Hi, model!\"}".to_string(),
            })
        );
    }

    #[test]
    fn test_last_content() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let message = Message::new_user_message("content".to_string());
        chat_gpt.push_message(message);
        let message = Message::new_user_message("content2".to_string());
        chat_gpt.push_message(message);
        let message = Message::new_user_message("content3".to_string());
        chat_gpt.push_message(message);
        assert_eq!(chat_gpt.last_content(), Some("content3".to_string()));
    }

    #[test]
    fn test_last_content_empty() {
        let chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        assert_eq!(chat_gpt.last_content(), None);
    }

    #[test]
    fn test_last_function() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let mut msg = Message::new("function".to_string());
        msg.set_function_call(FunctionCall {
            name: "function".to_string(),
            arguments: "1".to_string(),
        });
        chat_gpt.push_message(msg);
        let mut msg = Message::new("function2".to_string());
        msg.set_function_call(FunctionCall {
            name: "function2".to_string(),
            arguments: "2".to_string(),
        });
        chat_gpt.push_message(msg);
        let mut msg = Message::new("function3".to_string());
        msg.set_function_call(FunctionCall {
            name: "function3".to_string(),
            arguments: "3".to_string(),
        });
        chat_gpt.push_message(msg);
        assert_eq!(
            chat_gpt.last_function(),
            Some(("function3".to_string(), "3".to_string()))
        );
    }
}