openai_gpt_client/
client.rs

use log::{debug, info, warn};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use reqwest::{Client, Error, Response};
use serde::{Deserialize, Serialize};
use serde_json::to_value;

use crate::chat::{ChatMessage, ChatRequest, ChatResponse};
use crate::model_variants::ModelId;
use crate::text_completion::{TextCompletionRequest, TextCompletionResponse};
use crate::GptError;

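/// Sampling presets applied to every request built by a client; each getter
/// returns the value this profile sends for the corresponding API parameter.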
#[derive(Debug, Clone, Copy)]
pub enum ClientProfile {
    Chat,
    Code,
}

impl ClientProfile {
    pub fn get_temperature(&self) -> f32 {
        match self {
            ClientProfile::Chat => 0.4,
            ClientProfile::Code => 0.7,
        }
    }

    pub fn get_top_p(&self) -> f32 {
        match self {
            ClientProfile::Chat => 0.9,
            ClientProfile::Code => 0.7,
        }
    }

    pub fn get_frequency_penalty(&self) -> f32 {
        match self {
            ClientProfile::Chat => 0.0,
            ClientProfile::Code => 0.2,
        }
    }

    pub fn get_presence_penalty(&self) -> f32 {
        match self {
            ClientProfile::Chat => 0.6,
            ClientProfile::Code => 0.0,
        }
    }

    pub fn get_stop(&self) -> Option<Stop> {
        None
    }
}

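/// Stop sequence(s) for a request. With `#[serde(untagged)]`, `Single`
/// serializes as a bare JSON string and `Multiple` as a JSON array, the two
/// shapes the API accepts for `stop`.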
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Stop {
    Single(String),
    Multiple(Vec<String>),
}

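/// Wrapper around a `reqwest::Client` whose default headers carry the OpenAI
/// `Authorization` token, plus the sampling profile applied to every request.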
pub struct OpenAiClient {
    client: Client,
    profile: ClientProfile,
}

impl OpenAiClient {
    const BASE_URL: &'static str = "https://api.openai.com/v1";

    /// Builds a client with the `Authorization` header pre-set.
    ///
    /// Panics if the API key is not a valid header value or if the
    /// underlying `reqwest` client cannot be constructed.
    pub fn new(api_key: &str, profile: ClientProfile) -> OpenAiClient {
        let headers = Self::build_headers(api_key);

        let client = Client::builder().default_headers(headers).build().unwrap();

        OpenAiClient { client, profile }
    }

    fn build_headers(api_key: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_str(&format!("Bearer {api_key}")).unwrap(),
        );
        headers
    }

    #[allow(dead_code)]
    async fn get_request(&self, endpoint: &str) -> Result<Response, Error> {
        let url = format!("{}/{endpoint}", Self::BASE_URL);

        self.client.get(&url).send().await
    }

    async fn post_request(
        &self,
        endpoint: &str,
        body: serde_json::Value,
    ) -> Result<Response, Error> {
        let url = format!("{}/{endpoint}", Self::BASE_URL);

        // Send the JSON body; the auth header is already set on the client.
        let response = self
            .client
            .post(&url)
            .header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
            .json(&body)
            .send()
            .await?;

        Ok(response)
    }

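    /// Sends a `POST /completions` request with this profile's sampling
    /// parameters and returns the first choice's text.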
    pub async fn completion(
        &self,
        model: ModelId,
        prompt: &str,
        max_tokens: u16,
    ) -> Result<String, GptError> {
        let request = TextCompletionRequest {
            model,
            prompt: prompt.to_owned(),
            max_tokens: Some(i32::from(max_tokens)),
            stop: self.profile.get_stop(),
            temperature: Some(self.profile.get_temperature() as f64),
            top_p: Some(self.profile.get_top_p() as f64),
            frequency_penalty: Some(self.profile.get_frequency_penalty() as f64),
            presence_penalty: Some(self.profile.get_presence_penalty() as f64),
            ..Default::default()
        };
        debug!("Request: {:?}", request);

        let response = self
            .post_request("completions", to_value(&request).unwrap())
            .await?;
        debug!("Response: {:?}", response);

        let body = response.text().await?;
        debug!("Body: {:?}", body);

        // Deserialize the response body as a TextCompletionResponse object
        let completion_response: TextCompletionResponse = serde_json::from_str(&body)?;

        // Return the text generated by the API; panics if `choices` is
        // missing or empty (e.g. the API returned an error payload).
        Ok(completion_response.choices.unwrap()[0].text.clone())
    }

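    /// Sends a `POST /chat/completions` request and returns the first
    /// choice's message, logging token usage when the response includes it.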
    pub async fn chat(
        &self,
        model: ModelId,
        max_tokens: u16,
        messages: Vec<ChatMessage>,
    ) -> Result<ChatMessage, GptError> {
        info!("Messages:\n{:#?}", messages);
        let request = ChatRequest {
            model,
            messages,
            max_tokens: Some(i32::from(max_tokens)),
            stop: self.profile.get_stop(),
            temperature: Some(self.profile.get_temperature() as f64),
            top_p: Some(self.profile.get_top_p() as f64),
            frequency_penalty: Some(self.profile.get_frequency_penalty() as f64),
            presence_penalty: Some(self.profile.get_presence_penalty() as f64),
            ..Default::default()
        };
        debug!("Request: {:?}", request);

        let response = self
            .post_request("chat/completions", to_value(&request).unwrap())
            .await?;
        debug!("Response: {:?}", response);

        let body = response.text().await?;
        debug!("Body: {:?}", body);

        // Deserialize the response body as a ChatResponse object
        let chat_response: ChatResponse = serde_json::from_str(&body)?;

        if let Some(usage) = chat_response.usage {
            // Rough cost estimate at a hard-coded rate of $0.002 per 1K
            // tokens; adjust if the model is priced differently.
            warn!(
                "Completion: {:?}, total ${}",
                usage,
                usage.total_tokens as f32 * 0.002 / 1000.0
            );
        }

        // Return the first generated message; panics if `choices` is empty.
        Ok(chat_response.choices[0].message.clone())
    }
}
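
// A minimal usage sketch, kept as a comment so this file compiles unchanged.
// Assumptions not taken from this file: a `ModelId::Gpt35Turbo` variant in
// `crate::model_variants`, string `role`/`content` fields on `ChatMessage`,
// and a tokio runtime in the calling binary.
//
// async fn example() -> Result<(), GptError> {
//     let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
//     let client = OpenAiClient::new(&api_key, ClientProfile::Chat);
//     let reply = client
//         .chat(
//             ModelId::Gpt35Turbo,
//             256,
//             vec![ChatMessage {
//                 role: "user".to_owned(),
//                 content: "Say hello".to_owned(),
//             }],
//         )
//         .await?;
//     println!("{}", reply.content);
//     Ok(())
// }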