1use super::{
4 openai_get, openai_get_with_query, openai_post, ApiResponseOrError, Credentials,
5 RequestPagination, Usage,
6};
7use crate::openai_request_stream;
8use derive_builder::Builder;
9use futures_util::StreamExt;
10use reqwest::Method;
11use reqwest_eventsource::{CannotCloneRequestError, Event, EventSource};
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use tokio::sync::mpsc::{channel, Receiver, Sender};
16
/// A complete chat completion, as returned by `ChatCompletion::create` / `ChatCompletion::get`.
pub type ChatCompletion = ChatCompletionGeneric<ChatCompletionChoice>;

/// One chunk of a streamed chat completion, as yielded by `ChatCompletionDelta::create`.
/// Chunks can be folded together with `ChatCompletionDelta::merge` and then converted into a
/// [`ChatCompletion`] via `From`.
pub type ChatCompletionDelta = ChatCompletionGeneric<ChatCompletionChoiceDelta>;
22
/// Response envelope shared by complete completions and streamed chunks.
///
/// `C` is the per-choice type: [`ChatCompletionChoice`] for full responses,
/// [`ChatCompletionChoiceDelta`] for stream chunks. Every field except `usage` falls back to
/// its default (empty string / `0` / empty vector) when missing from the JSON.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionGeneric<C> {
    /// Completion identifier; stream chunks of one completion share it (see `merge`).
    #[serde(default)]
    pub id: String,
    #[serde(default)]
    pub object: String,
    #[serde(default)]
    pub created: u64,
    #[serde(default)]
    pub model: String,
    #[serde(default = "default_empty_vec")]
    pub choices: Vec<C>,
    /// `None` when the response carries no usage object.
    pub usage: Option<Usage>,
}
37
/// One choice of a complete [`ChatCompletion`].
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionChoice {
    /// Position of this choice within the response's `choices`.
    pub index: u64,
    /// Finish reason as reported by the server (required in this type).
    pub finish_reason: String,
    pub message: ChatCompletionMessage,
}

/// One choice inside a streamed [`ChatCompletionDelta`] chunk.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionChoiceDelta {
    /// Which choice this fragment belongs to; merging matches choices on this value.
    pub index: u64,
    /// `None` on chunks that do not report a finish reason.
    pub finish_reason: Option<String>,
    /// The message fragment carried by this chunk.
    pub delta: ChatCompletionMessageDelta,
}
51
/// `skip_serializing_if` predicate: `true` when there is no vector at all or it has no items.
fn is_none_or_empty_vec<T>(opt: &Option<Vec<T>>) -> bool {
    match opt {
        Some(items) => items.is_empty(),
        None => true,
    }
}
55
/// A single chat message, used both in requests and in complete responses.
///
/// `Default` yields a `User` message with every optional field unset, so tests can write
/// `ChatCompletionMessage { role, content, ..Default::default() }`.
#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq, Default)]
pub struct ChatCompletionMessage {
    pub role: ChatCompletionMessageRole,
    /// Serialized even when `None` (there is no `skip_serializing_if` on this field).
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<ChatCompletionFunctionCall>,
    /// Links a `Tool`-role message to the `ToolCall::id` it answers (see the tests).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Omitted from the JSON when `None` or empty.
    #[serde(skip_serializing_if = "is_none_or_empty_vec")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
83
84#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
86pub struct ChatCompletionMessageDelta {
87 pub role: Option<ChatCompletionMessageRole>,
89 pub content: Option<String>,
91 #[serde(skip_serializing_if = "Option::is_none")]
93 pub name: Option<String>,
94 #[serde(skip_serializing_if = "Option::is_none")]
98 pub function_call: Option<ChatCompletionFunctionCallDelta>,
99 #[serde(skip_serializing_if = "Option::is_none")]
102 pub tool_call_id: Option<String>,
103 #[serde(skip_serializing_if = "is_none_or_empty_vec")]
107 pub tool_calls: Option<Vec<ToolCall>>,
108}
109
/// A tool invocation attached to an assistant message.
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Identifier that a later `Tool`-role message refers to via `tool_call_id`.
    pub id: String,
    /// Serialized under the JSON key `type` (raw identifier because `type` is a keyword).
    pub r#type: String,
    pub function: ToolCallFunction,
}

/// Function name and arguments of a [`ToolCall`].
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCallFunction {
    pub name: String,
    /// Raw argument string, passed through unvalidated (presumably JSON-encoded; confirm).
    pub arguments: String,
}
131
/// A function the model may call, sent through the request's `functions` list.
#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatCompletionFunctionDefinition {
    pub name: String,
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Arbitrary JSON describing the parameters. The `chat_function` test passes a
    /// JSON-Schema object. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}
145
/// A complete function call produced by the model.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct ChatCompletionFunctionCall {
    pub name: String,
    /// Raw argument string; the `chat_function` test parses it as JSON.
    pub arguments: String,
}

/// A fragment of a function call received while streaming; merged pieces are converted
/// into a [`ChatCompletionFunctionCall`] via `From`.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct ChatCompletionFunctionCallDelta {
    pub name: Option<String>,
    /// Partial argument text; fragments are concatenated by `ChatCompletionChoiceDelta::merge`.
    pub arguments: Option<String>,
}
164
/// Author role of a chat message, (de)serialized as the lowercase variant name
/// (e.g. `"assistant"`). The `Default` impl elsewhere in this file returns `User`.
#[derive(Deserialize, Serialize, Debug, Clone, Copy, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ChatCompletionMessageRole {
    System,
    User,
    Assistant,
    Function,
    Tool,
    Developer,
}
175
/// JSON body for `POST chat/completions`, normally built via `ChatCompletion::builder`.
///
/// The generated `ChatCompletionBuilder` uses the owned pattern. Its setters accept
/// `impl Into<_>` and wrap optional fields in `Some` themselves. Only `model` and `messages`
/// are required. Every other field defaults to empty/`None` and is then left out of the
/// serialized JSON.
#[derive(Serialize, Builder, Debug, Clone)]
#[builder(derive(Clone, Debug, PartialEq))]
#[builder(pattern = "owned")]
#[builder(name = "ChatCompletionBuilder")]
#[builder(setter(strip_option, into))]
pub struct ChatCompletionRequest {
    /// Model identifier (e.g. `"gpt-4o"` in the tests).
    model: String,
    /// Messages sent with the request, serialized in `Vec` order.
    messages: Vec<ChatCompletionMessage>,
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u8>,
    /// Set to `Some(true)` by `ChatCompletionBuilder::create_stream`.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    /// Omitted when empty.
    #[builder(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u64>,
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    /// Map from token key (as a string) to bias value.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    /// Omitted when empty.
    #[builder(default)]
    #[serde(skip_serializing_if = "String::is_empty")]
    user: String,
    /// Functions the model may call; omitted when empty.
    #[builder(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    functions: Vec<ChatCompletionFunctionDefinition>,
    /// Raw JSON passed through unchanged as `function_call`.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<Value>,
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ChatCompletionResponseFormat>,
    /// Never serialized. It is handed to the HTTP helpers by `ChatCompletion::create` /
    /// `ChatCompletionDelta::create`. `None` presumably means "use default credentials";
    /// confirm against those helpers.
    #[serde(skip_serializing)]
    #[builder(default)]
    credentials: Option<Credentials>,
    /// Provider-specific extension (presumably for Venice's OpenAI-compatible API; confirm).
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    venice_parameters: Option<VeniceParameters>,
    /// The tests set this to `true` before reading the completion back with
    /// `ChatCompletion::get` / `ChatCompletionMessages::fetch`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub store: Option<bool>,
}
286
/// Extra request parameters serialized under `venice_parameters`.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct VeniceParameters {
    pub include_venice_system_prompt: bool,
}

/// Value of the request's `response_format`, serialized as `{"type": ...}`.
/// Construct it with `ChatCompletionResponseFormat::json_object()` or `::text()`.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatCompletionResponseFormat {
    #[serde(rename = "type")]
    typ: String,
}
298
299impl ChatCompletionResponseFormat {
300 pub fn json_object() -> Self {
301 ChatCompletionResponseFormat {
302 typ: "json_object".to_string(),
303 }
304 }
305
306 pub fn text() -> Self {
307 ChatCompletionResponseFormat {
308 typ: "text".to_string(),
309 }
310 }
311}
312
313impl<C> ChatCompletionGeneric<C> {
314 pub fn builder(
315 model: &str,
316 messages: impl Into<Vec<ChatCompletionMessage>>,
317 ) -> ChatCompletionBuilder {
318 ChatCompletionBuilder::create_empty()
319 .model(model)
320 .messages(messages)
321 }
322}
323
/// Parameters for listing the stored messages of a completion
/// (`GET chat/completions/{completion_id}/messages`).
///
/// Only `pagination` is serialized (flattened into the query). The other fields are used to
/// build the route and to authenticate.
#[derive(Serialize, Builder, Debug, Clone, Default)]
#[builder(derive(Clone, Debug, PartialEq))]
#[builder(pattern = "owned")]
#[builder(name = "ChatCompletionMessagesRequestBuilder")]
#[builder(setter(strip_option, into))]
pub struct ChatCompletionMessagesRequest {
    /// Required; interpolated into the route.
    #[serde(skip_serializing)]
    pub completion_id: String,

    #[builder(default)]
    #[serde(skip_serializing)]
    pub credentials: Option<Credentials>,

    #[builder(default)]
    #[serde(flatten)]
    pub pagination: RequestPagination,
}
341
/// One page of stored messages returned by `ChatCompletionMessages::fetch`.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionMessages {
    pub data: Vec<ChatCompletionMessage>,
    pub object: String,
    /// `None` when the page is empty (see the pagination test).
    pub first_id: Option<String>,
    pub last_id: Option<String>,
    /// Whether more pages follow this one.
    pub has_more: bool,
}
351
352impl ChatCompletion {
353 pub async fn create(request: ChatCompletionRequest) -> ApiResponseOrError<Self> {
354 let credentials_opt = request.credentials.clone();
355 openai_post("chat/completions", &request, credentials_opt).await
356 }
357
358 pub async fn get(id: &str, credentials: Credentials) -> ApiResponseOrError<Self> {
360 let route = format!("chat/completions/{}", id);
361 openai_get(route.as_str(), Some(credentials)).await
362 }
363}
364
365impl ChatCompletionDelta {
366 pub async fn create(
367 request: ChatCompletionRequest,
368 ) -> Result<Receiver<Self>, CannotCloneRequestError> {
369 let credentials_opt = request.credentials.clone();
370 let stream = openai_request_stream(
371 Method::POST,
372 "chat/completions",
373 |r| r.json(&request),
374 credentials_opt,
375 )
376 .await?;
377 let (tx, rx) = channel::<Self>(32);
378 tokio::spawn(forward_deserialized_chat_response_stream(stream, tx));
379 Ok(rx)
380 }
381
382 pub fn merge(
384 &mut self,
385 other: ChatCompletionDelta,
386 ) -> Result<(), ChatCompletionDeltaMergeError> {
387 if other.id.ne(&self.id) {
388 return Err(ChatCompletionDeltaMergeError::DifferentCompletionIds);
389 }
390 for other_choice in other.choices.iter() {
391 for choice in self.choices.iter_mut() {
392 if choice.index != other_choice.index {
393 continue;
394 }
395 choice.merge(other_choice)?;
396 }
397 }
398 Ok(())
399 }
400}
401
402impl ChatCompletionChoiceDelta {
403 pub fn merge(
404 &mut self,
405 other: &ChatCompletionChoiceDelta,
406 ) -> Result<(), ChatCompletionDeltaMergeError> {
407 if self.index != other.index {
408 return Err(ChatCompletionDeltaMergeError::DifferentCompletionChoiceIndices);
409 }
410 if self.delta.role.is_none() {
411 if let Some(other_role) = other.delta.role {
412 self.delta.role = Some(other_role);
414 }
415 }
416 if self.delta.name.is_none() {
417 if let Some(other_name) = &other.delta.name {
418 self.delta.name = Some(other_name.clone());
420 }
421 }
422 match self.delta.content.as_mut() {
424 Some(content) => {
425 match &other.delta.content {
426 Some(other_content) => {
427 content.push_str(other_content)
429 }
430 None => {}
431 }
432 }
433 None => {
434 match &other.delta.content {
435 Some(other_content) => {
436 self.delta.content = Some(other_content.clone());
438 }
439 None => {}
440 }
441 }
442 };
443
444 match self.delta.function_call.as_mut() {
448 Some(function_call) => {
449 match &other.delta.function_call {
450 Some(other_function_call) => {
451 match (&mut function_call.arguments, &other_function_call.arguments) {
453 (Some(function_call), Some(other_function_call)) => {
454 function_call.push_str(&other_function_call);
455 }
456 (None, Some(other_function_call)) => {
457 function_call.arguments = Some(other_function_call.clone());
458 }
459 _ => {}
460 }
461 }
462 None => {}
463 }
464 }
465 None => {
466 match &other.delta.function_call {
467 Some(other_function_call) => {
468 self.delta.function_call = Some(other_function_call.clone());
470 }
471 None => {}
472 }
473 }
474 };
475 Ok(())
476 }
477}
478
479impl From<ChatCompletionDelta> for ChatCompletion {
480 fn from(delta: ChatCompletionDelta) -> Self {
481 ChatCompletion {
482 id: delta.id,
483 object: delta.object,
484 created: delta.created,
485 model: delta.model,
486 usage: delta.usage,
487 choices: delta
488 .choices
489 .iter()
490 .map(|choice| ChatCompletionChoice {
491 index: choice.index,
492 finish_reason: clone_default_unwrapped_option_string(&choice.finish_reason),
493 message: ChatCompletionMessage {
494 role: choice
495 .delta
496 .role
497 .unwrap_or_else(|| ChatCompletionMessageRole::System),
498 content: choice.delta.content.clone(),
499 name: choice.delta.name.clone(),
500 function_call: choice.delta.function_call.clone().map(|f| f.into()),
501 tool_call_id: None,
502 tool_calls: Some(Vec::new()),
503 },
504 })
505 .collect(),
506 }
507 }
508}
509
510impl From<ChatCompletionFunctionCallDelta> for ChatCompletionFunctionCall {
511 fn from(delta: ChatCompletionFunctionCallDelta) -> Self {
512 ChatCompletionFunctionCall {
513 name: delta.name.unwrap_or("".to_string()),
514 arguments: delta.arguments.unwrap_or_default(),
515 }
516 }
517}
518
519impl ChatCompletionMessages {
520 pub fn builder(completion_id: String) -> ChatCompletionMessagesRequestBuilder {
522 ChatCompletionMessagesRequestBuilder::create_empty()
523 .completion_id(completion_id.to_string())
524 }
525
526 pub async fn fetch(
528 request: ChatCompletionMessagesRequest,
529 ) -> ApiResponseOrError<ChatCompletionMessages> {
530 let route = format!("chat/completions/{}/messages", request.completion_id);
531 let credentials = request.credentials.clone();
532 openai_get_with_query(route.as_str(), &request, credentials).await
533 }
534}
535
/// Reasons why two streamed chunks could not be merged.
///
/// Also derives `Clone`, `Copy`, `PartialEq` and `Eq` (additive), so callers can compare
/// and copy errors, e.g. `assert_eq!(err, ChatCompletionDeltaMergeError::DifferentCompletionIds)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChatCompletionDeltaMergeError {
    /// The chunks belong to different completions (their `id`s differ).
    DifferentCompletionIds,
    /// The choice fragments have different `index` values.
    DifferentCompletionChoiceIndices,
    /// Not returned by the merge methods in this file; kept for API compatibility.
    FunctionCallArgumentTypeMismatch,
}

impl std::fmt::Display for ChatCompletionDeltaMergeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            ChatCompletionDeltaMergeError::DifferentCompletionIds => "Different completion IDs",
            ChatCompletionDeltaMergeError::DifferentCompletionChoiceIndices => {
                "Different completion choice indices"
            }
            ChatCompletionDeltaMergeError::FunctionCallArgumentTypeMismatch => {
                "Function call argument type mismatch"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for ChatCompletionDeltaMergeError {}
560
561async fn forward_deserialized_chat_response_stream(
562 mut stream: EventSource,
563 tx: Sender<ChatCompletionDelta>,
564) -> anyhow::Result<()> {
565 while let Some(event) = stream.next().await {
566 let event = event?;
567 match event {
568 Event::Message(event) => {
569 let completion = serde_json::from_str::<ChatCompletionDelta>(&event.data)?;
570 tx.send(completion).await?;
571 }
572 _ => {}
573 }
574 }
575 Ok(())
576}
577
578impl ChatCompletionBuilder {
579 pub async fn create(self) -> ApiResponseOrError<ChatCompletion> {
580 ChatCompletion::create(self.build().unwrap()).await
581 }
582
583 pub async fn create_stream(
584 mut self,
585 ) -> Result<Receiver<ChatCompletionDelta>, CannotCloneRequestError> {
586 self.stream = Some(Some(true));
587 ChatCompletionDelta::create(self.build().unwrap()).await
588 }
589}
590
591impl ChatCompletionMessagesRequestBuilder {
592 pub async fn fetch(self) -> ApiResponseOrError<ChatCompletionMessages> {
594 ChatCompletionMessages::fetch(self.build().unwrap()).await
595 }
596}
597
/// Returns an owned copy of the contained string, or an empty `String` for `None`.
fn clone_default_unwrapped_option_string(string: &Option<String>) -> String {
    string.as_deref().unwrap_or_default().to_owned()
}
604
605impl Default for ChatCompletionMessageRole {
606 fn default() -> Self {
607 Self::User
608 }
609}
610
/// Serde `default` provider for `ChatCompletionGeneric::choices`: an empty vector.
fn default_empty_vec<C>() -> Vec<C> {
    vec![]
}
614
#[cfg(test)]
mod tests {
    //! NOTE(review): every `#[tokio::test]` here calls the live API with credentials from the
    //! environment (`Credentials::from_env`, after loading `.env`). Many assert on exact model
    //! output, so they need network access and may break when model behaviour changes. Only
    //! `builder_clone_and_eq` runs offline.

    use super::*;
    use dotenvy::dotenv;
    use std::time::Duration;
    use tokio::time::sleep;

    #[tokio::test]
    async fn chat() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_completion = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("Hello!".to_string()),
                name: None,
                function_call: None,
                tool_call_id: None,
                tool_calls: Some(Vec::new()),
            }],
        )
        .temperature(0.0)
        .response_format(ChatCompletionResponseFormat::text())
        .credentials(credentials)
        .create()
        .await
        .unwrap();

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "Hello! How can I assist you today?"
        );
    }

    #[tokio::test]
    async fn chat_seed() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_completion = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some(
                    "What type of seed does Mr. England sow in the song? Reply with 1 word."
                        .to_string(),
                ),
                name: None,
                function_call: None,
                tool_call_id: None,
                tool_calls: Some(Vec::new()),
            }],
        )
        .temperature(0.0)
        .seed(1337u64)
        .credentials(credentials)
        .create()
        .await
        .unwrap();

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "Love"
        );
    }

    #[tokio::test]
    async fn chat_stream() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_stream = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("Hello!".to_string()),
                name: None,
                function_call: None,
                tool_call_id: None,
                tool_calls: Some(Vec::new()),
            }],
        )
        .temperature(0.0)
        .credentials(credentials)
        .create_stream()
        .await
        .unwrap();

        let chat_completion = stream_to_completion(chat_stream).await;

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "Hello! How can I assist you today?"
        );
    }

    #[tokio::test]
    async fn chat_function() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_stream = ChatCompletion::builder(
            "gpt-4o",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("What is the weather in Boston?".to_string()),
                name: None,
                function_call: None,
                tool_call_id: None,
                tool_calls: Some(Vec::new()),
            }],
        )
        .functions([ChatCompletionFunctionDefinition {
            description: Some("Get the current weather in a given location.".to_string()),
            name: "get_current_weather".to_string(),
            parameters: Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state to get the weather for. (eg: San Francisco, CA)"
                    }
                },
                "required": ["location"]
            })),
        }])
        .temperature(0.2)
        .credentials(credentials)
        .create_stream()
        .await
        .unwrap();

        let chat_completion = stream_to_completion(chat_stream).await;
        let function_call = chat_completion
            .choices
            .first()
            .unwrap()
            .message
            .function_call
            .as_ref()
            .unwrap();

        assert_eq!(function_call.name, "get_current_weather");
        assert_eq!(
            serde_json::from_str::<Value>(&function_call.arguments).unwrap(),
            serde_json::json!({
                "location": "Boston, MA"
            }),
        );
    }

    #[tokio::test]
    async fn chat_response_format_json() {
        dotenv().ok();
        let credentials = Credentials::from_env();
        let chat_completion = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("Write an example JSON for a JWT header using RS256".to_string()),
                name: None,
                function_call: None,
                tool_call_id: None,
                tool_calls: Some(Vec::new()),
            }],
        )
        .temperature(0.0)
        .seed(1337u64)
        .response_format(ChatCompletionResponseFormat::json_object())
        .credentials(credentials)
        .create()
        .await
        .unwrap();
        let response_string = chat_completion
            .choices
            .first()
            .unwrap()
            .message
            .content
            .as_ref()
            .unwrap();

        #[derive(Deserialize, Eq, PartialEq, Debug)]
        struct Response {
            alg: String,
            typ: String,
        }

        let response = serde_json::from_str::<Response>(response_string).unwrap();
        assert_eq!(
            response,
            Response {
                alg: "RS256".to_owned(),
                typ: "JWT".to_owned()
            }
        );
    }

    #[test]
    fn builder_clone_and_eq() {
        let builder_a = ChatCompletion::builder("gpt-4", [])
            .temperature(0.0)
            .seed(65u64);
        let builder_b = builder_a.clone();
        let builder_c = builder_b.clone().temperature(1.0);
        let builder_d = ChatCompletionBuilder::default();
        assert_eq!(builder_a, builder_b);
        assert_ne!(builder_a, builder_c);
        assert_ne!(builder_b, builder_c);
        assert_ne!(builder_a, builder_d);
        assert_ne!(builder_c, builder_d);
    }

    /// Drains `chat_stream`, merging every chunk into the first one, and converts the result.
    async fn stream_to_completion(
        mut chat_stream: Receiver<ChatCompletionDelta>,
    ) -> ChatCompletion {
        let mut merged: Option<ChatCompletionDelta> = None;
        while let Some(delta) = chat_stream.recv().await {
            match merged.as_mut() {
                Some(c) => {
                    c.merge(delta).unwrap();
                }
                None => merged = Some(delta),
            };
        }
        merged.unwrap().into()
    }

    #[tokio::test]
    async fn chat_tool_response_completion() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_completion = ChatCompletion::builder(
            "gpt-4o-mini",
            [
                ChatCompletionMessage {
                    role: ChatCompletionMessageRole::User,
                    content: Some(
                        "What's 0.9102847*28456? \
                        reply in plain text, \
                        round the number to to 2 decimals \
                        and reply with the result number only, \
                        with no full stop at the end"
                            .to_string(),
                    ),
                    name: None,
                    function_call: None,
                    tool_call_id: None,
                    tool_calls: Some(Vec::new()),
                },
                ChatCompletionMessage {
                    role: ChatCompletionMessageRole::Assistant,
                    content: Some("Let me calculate that for you.".to_string()),
                    name: None,
                    function_call: None,
                    tool_call_id: None,
                    tool_calls: Some(vec![ToolCall {
                        id: "the_tool_call".to_string(),
                        r#type: "function".to_string(),
                        function: ToolCallFunction {
                            name: "mul".to_string(),
                            arguments: "not_required_to_be_valid_here".to_string(),
                        },
                    }]),
                },
                ChatCompletionMessage {
                    role: ChatCompletionMessageRole::Tool,
                    content: Some("the result is 25903.061423199997".to_string()),
                    name: None,
                    function_call: None,
                    tool_call_id: Some("the_tool_call".to_owned()),
                    tool_calls: Some(Vec::new()),
                },
            ],
        )
        .temperature(0.0)
        .seed(1337u64)
        .credentials(credentials)
        .create()
        .await
        .unwrap();

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "25903.06"
        );
    }

    #[tokio::test]
    async fn get_completion() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_completion = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("Hello!".to_string()),
                ..Default::default()
            }],
        )
        .credentials(credentials.clone())
        .store(true)
        .create()
        .await
        .unwrap();

        // Give the server time to persist the stored completion before reading it back.
        sleep(Duration::from_secs(7)).await;

        let retrieved_completion = ChatCompletion::get(&chat_completion.id, credentials.clone())
            .await
            .unwrap();

        assert_eq!(retrieved_completion, chat_completion);
    }

    #[tokio::test]
    async fn get_completion_non_existent() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        match ChatCompletion::get("non_existent_id", credentials.clone()).await {
            Ok(_) => panic!("Expected error"),
            Err(e) => assert_eq!(e.code, Some("not_found".to_string())),
        }
    }

    #[tokio::test]
    async fn get_completion_messages() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let user_message = ChatCompletionMessage {
            role: ChatCompletionMessageRole::User,
            content: Some("Tell me a short joke".to_string()),
            ..Default::default()
        };

        let chat_completion = ChatCompletion::builder("gpt-3.5-turbo", [user_message.clone()])
            .credentials(credentials.clone())
            .store(true)
            .create()
            .await
            .unwrap();

        // Give the server time to persist the stored completion before reading it back.
        sleep(Duration::from_secs(7)).await;

        let retrieved_messages = ChatCompletionMessages::builder(chat_completion.id)
            .credentials(credentials.clone())
            .fetch()
            .await
            .unwrap();

        assert_eq!(retrieved_messages.data, vec![user_message]);
        assert!(!retrieved_messages.has_more);
    }

    #[tokio::test]
    async fn get_completion_messages_with_pagination() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let user_message = ChatCompletionMessage {
            role: ChatCompletionMessageRole::User,
            content: Some("Tell me a short joke".to_string()),
            ..Default::default()
        };

        let chat_completion = ChatCompletion::builder("gpt-3.5-turbo", [user_message.clone()])
            .credentials(credentials.clone())
            .store(true)
            .create()
            .await
            .unwrap();

        // Give the server time to persist the stored completion before reading it back.
        sleep(Duration::from_secs(7)).await;

        let retrieved_messages1 = ChatCompletionMessages::builder(chat_completion.id.clone())
            .credentials(credentials.clone())
            .pagination(RequestPagination {
                limit: Some(1),
                ..Default::default()
            })
            .fetch()
            .await
            .unwrap();

        assert_eq!(retrieved_messages1.data, vec![user_message]);
        assert!(!retrieved_messages1.has_more);
        assert!(retrieved_messages1.first_id.is_some());
        assert!(retrieved_messages1.last_id.is_some());

        let retrieved_messages2 = ChatCompletionMessages::builder(chat_completion.id.clone())
            .credentials(credentials.clone())
            .pagination(RequestPagination {
                limit: Some(1),
                after: Some(retrieved_messages1.first_id.unwrap()),
                ..Default::default()
            })
            .fetch()
            .await
            .unwrap();

        assert!(retrieved_messages2.data.is_empty());
        assert!(!retrieved_messages2.has_more);
        assert!(retrieved_messages2.first_id.is_none());
        assert!(retrieved_messages2.last_id.is_none());
    }
}