openai_chat/
chat_completion.rs

//! Mainly uses the models gpt-3.5-turbo and gpt-3.5-turbo-0301.
//! API: POST https://api.openai.com/v1/chat/completions
use openai_api_client::{ClientError, Usage};
use reqwest::{header, ClientBuilder};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, time::Duration};

use crate::error;
static URL: &str = "https://api.openai.com/v1/chat/completions";
static TEXT_GPT35_TURBO: &str = "gpt-3.5-turbo";
#[derive(Deserialize, Serialize)]
pub struct ChatCompletionsParams {
    pub model: String,
    /// \[
    ///     {"role": "system", "content": "You are a helpful assistant."},
    ///     {"role": "user", "content": "Who won the world series in 2020?"},
    ///     {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
    ///     {"role": "user", "content": "Where was it played?"}
    /// \]
    ///
    /// or simply \[ {"role": "user", "content": "Who won the world series in 2020?"} \]
    pub messages: Vec<HashMap<String, String>>,
    pub temperature: f32,
    /// By default, the number of tokens the model can return will be (4096 - prompt tokens).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    pub top_p: f32,
    pub frequency_penalty: f32,
    pub presence_penalty: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    pub n: u32,
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl ChatCompletionsParams {
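    /// Builds request params from `messages`; all other fields come from
    /// [`Default::default`] (model `gpt-3.5-turbo`, temperature 1.0, one
    /// choice, no streaming). Override fields with struct-update syntax, e.g.
    /// `ChatCompletionsParams { max_tokens: Some(256), ..ChatCompletionsParams::new(msgs) }`.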
    pub fn new(messages: Vec<HashMap<String, String>>) -> Self {
        Self {
            messages,
            ..Default::default()
        }
    }
}

impl Default for ChatCompletionsParams {
    fn default() -> Self {
        ChatCompletionsParams {
            model: TEXT_GPT35_TURBO.to_string(),
            messages: Vec::new(),
            temperature: 1.0,
            top_p: 1.0,
            n: 1,
            stream: false,
            stop: None,
            max_tokens: None,
            frequency_penalty: 0.0,
            presence_penalty: 0.0,
            logit_bias: None,
            user: None,
        }
    }
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ChatCompletionsResponse {
    pub id: String,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatCompletionsChoice>,
    pub usage: Usage,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct ChatCompletionsChoice {
    pub index: u32,
    pub message: HashMap<String, String>,
    pub finish_reason: String,
}
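/// POSTs the JSON-serialized request `body` to the chat completions endpoint
/// with a bearer `api_key` and a 60-second timeout, returning the raw
/// response bytes.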
async fn request(body: String, api_key: &str) -> std::result::Result<Vec<u8>, error::Error> {
    let mut header = header::HeaderMap::new();
    header.insert("Content-Type", "application/json".parse().unwrap());
    header.insert(
        "Authorization",
        format!("Bearer {api_key}").parse().unwrap(),
    );
    let client = ClientBuilder::new().default_headers(header).build()?;
    let response = client
        .post(URL)
        .timeout(Duration::from_secs(60))
        .body(body)
        .send()
        .await
        .map_err(|e| ClientError::NetworkError(format!("{e:?}")))?
        .bytes()
        .await
        .map_err(|e| ClientError::NetworkError(format!("{e:?}")))?;

    Ok(response.to_vec())
}

/// Uses model `gpt-3.5-turbo` to generate a chat completion.
///
/// This is a relatively thin wrapper around the API. The prompt is sent as a
/// single message with role `user`; no other roles are involved.
///
/// Note: only the `content` field of the reply is returned; the `role` field is omitted.
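///
/// A minimal usage sketch (assumes a Tokio runtime and a valid API key in
/// `api_key`; marked `ignore` since it needs network access):
///
/// ```ignore
/// let answer = chat_completions("Who won the world series in 2020?", &api_key).await?;
/// println!("{answer}");
/// ```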
pub async fn chat_completions(
    prompt: &str,
    api_key: &str,
) -> std::result::Result<String, error::Error> {
    let mut msg = HashMap::new();
    msg.insert("role".to_string(), "user".to_string());
    msg.insert("content".to_string(), prompt.to_string());
    let req = ChatCompletionsParams::new(vec![msg]);
    let body = serde_json::to_string(&req)?;
    let response = request(body, api_key).await?;
    let mut completions_response: ChatCompletionsResponse = serde_json::from_slice(&response)?;
    // Avoid panicking if `choices` comes back empty; fall back to an empty string.
    Ok(completions_response
        .choices
        .get_mut(0)
        .and_then(|choice| choice.message.remove("content"))
        .unwrap_or_default())
}
/// Supports contextual chat: pass the prompt as a vector of hashmaps, each
/// containing `role` and `content` keys.
///
/// Uses model `gpt-3.5-turbo` to generate a chat completion.
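///
/// A minimal multi-turn sketch (assumes a Tokio runtime and a valid API key in
/// `api_key`; marked `ignore` since it needs network access):
///
/// ```ignore
/// let mut history = Vec::new();
/// for (role, content) in [
///     ("user", "Who won the world series in 2020?"),
///     ("assistant", "The Los Angeles Dodgers won the World Series in 2020."),
///     ("user", "Where was it played?"),
/// ] {
///     let mut msg = HashMap::new();
///     msg.insert("role".to_string(), role.to_string());
///     msg.insert("content".to_string(), content.to_string());
///     history.push(msg);
/// }
/// let answer = chat_completions_full(history, &api_key).await?;
/// ```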
pub async fn chat_completions_full(
    prompt: Vec<HashMap<String, String>>,
    api_key: &str,
) -> std::result::Result<String, error::Error> {
    let req = ChatCompletionsParams::new(prompt);
    let body = serde_json::to_string(&req)?;
    let response = request(body, api_key).await?;
    let mut completions_response: ChatCompletionsResponse = serde_json::from_slice(&response)?;
    // Avoid panicking if `choices` comes back empty; fall back to an empty string.
    Ok(completions_response
        .choices
        .get_mut(0)
        .and_then(|choice| choice.message.remove("content"))
        .unwrap_or_default())
}