openai_chat/
chat_completion.rs1use openai_api_client::{ClientError, Usage};
4use reqwest::{header, ClientBuilder};
5use serde::{Deserialize, Serialize};
6use std::{collections::HashMap, time::Duration};
7
8use crate::error;
/// Endpoint that every request in this module is POSTed to.
const URL: &str = "https://api.openai.com/v1/chat/completions";

/// Model identifier used by [`ChatCompletionsParams::default`].
const TEXT_GPT35_TURBO: &str = "gpt-3.5-turbo";
11#[derive(Deserialize, Serialize)]
12pub struct ChatCompletionsParams {
13 pub model: String,
14 pub messages: Vec<HashMap<String, String>>,
28 pub temperature: u32,
29 #[serde(skip_serializing_if = "Option::is_none")]
31 pub max_tokens: Option<u32>,
32 pub top_p: f32,
33 pub frequency_penalty: f32,
34 pub presence_penalty: f32,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub stop: Option<Vec<String>>,
37 pub n: u32,
38 pub stream: bool,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub logit_bias: Option<HashMap<String, i32>>,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 pub user: Option<String>,
43}
44
45impl ChatCompletionsParams {
46 pub fn new(messages: Vec<HashMap<String, String>>) -> Self {
47 Self {
48 messages,
49 ..Default::default()
50 }
51 }
52}
53
54impl Default for ChatCompletionsParams {
55 fn default() -> Self {
56 ChatCompletionsParams {
57 model: TEXT_GPT35_TURBO.to_string(),
58 messages: Vec::new(),
59 temperature: 1,
60 top_p: 1.0,
61 n: 1,
62 stream: false,
63 stop: None,
64 max_tokens: None,
65 frequency_penalty: 0.0,
66 presence_penalty: 0.0,
67 logit_bias: None,
68 user: None,
69 }
70 }
71}
72
73#[derive(Deserialize, Serialize, Debug)]
74pub struct ChatCompletionsResponse {
75 pub id: String,
76 pub object: String,
77 pub created: u32,
78 pub model: String,
79 pub choices: Vec<ChatCompletionsChoice>,
80 pub usage: Usage,
81}
82#[derive(Deserialize, Serialize, Debug)]
83pub struct ChatCompletionsChoice {
84 pub index: u32,
85 pub message: HashMap<String, String>,
86 pub finish_reason: String,
87}
88async fn request(body: String, api_key: &str) -> std::result::Result<Vec<u8>, error::Errpr> {
89 let mut header = header::HeaderMap::new();
90 header.insert("Content-Type", "application/json".parse().unwrap());
91 header.insert(
92 "Authorization",
93 format!("Bearer {api_key}").parse().unwrap(),
94 );
95 let client = ClientBuilder::new().default_headers(header).build()?;
96 let response = client
97 .post(URL)
98 .timeout(Duration::from_secs(60))
99 .body(body)
100 .send()
101 .await
102 .map_err(|e| ClientError::NetworkError(format!("{e:?}")))?
103 .bytes()
104 .await
105 .map_err(|e| ClientError::NetworkError(format!("{e:?}")))?;
106
107 Ok(response.to_vec())
108}
109
110pub async fn chat_completions(
117 prompt: &str,
118 api_key: &str,
119) -> std::result::Result<String, error::Errpr> {
120 let mut msg = HashMap::new();
121 msg.insert("role".to_string(), "user".to_string());
122 msg.insert("content".to_string(), prompt.to_string());
123 let req = ChatCompletionsParams::new(vec![msg]);
124 let body = serde_json::to_string(&req)?;
125 let response = request(body, api_key).await?;
126 let mut completions_response: ChatCompletionsResponse = serde_json::from_slice(&response)?;
127 Ok(completions_response.choices[0]
128 .message
129 .remove("content")
130 .unwrap_or("".to_string()))
131}
132pub async fn chat_completions_full(
137 prompt: Vec<HashMap<String, String>>,
138 api_key: &str,
139) -> std::result::Result<String, error::Errpr> {
140 let req = ChatCompletionsParams::new(prompt);
141 let body = serde_json::to_string(&req)?;
142 let response = request(body, api_key).await?;
143 let mut completions_response: ChatCompletionsResponse = serde_json::from_slice(&response)?;
144 Ok(completions_response.choices[0]
145 .message
146 .remove("content")
147 .unwrap_or("".to_string()))
148}