1use std::time::Duration;
4
5use crate::Error;
6use chrono::serde::ts_seconds;
7use chrono::{DateTime, Utc};
8use derive_builder::Builder;
9use reqwest::StatusCode;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12
13pub struct Client {
15 client: reqwest::Client,
17 api_key: Option<String>,
19 base_url: String,
21}
22
23#[derive(Builder, Default)]
26pub struct ChatCompletionRequest {
27 model: String,
28 messages: Vec<Message>,
29 temperature: f32,
30 timeout: Duration,
31}
32
33#[derive(Builder, Default, Debug, Serialize, Deserialize)]
36pub struct ChatCompletion {
37 #[serde(with = "ts_seconds")]
38 pub created: DateTime<Utc>,
39 pub choices: Vec<Choice>,
40 pub model: String,
41 pub usage: Usage,
42}
43
44#[derive(Debug, Deserialize)]
46pub struct APIError {
47 pub message: String,
48 #[serde(rename = "type")]
49 pub error_type: String,
50 pub param: Option<String>,
51 pub code: Option<String>,
52}
53
54#[derive(Debug, Deserialize)]
56struct APIErrorContainer {
57 error: APIError,
58}
59
60#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
62pub struct Message {
63 pub role: Role,
64 pub content: String,
65}
66
67#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
68#[serde(rename_all = "lowercase")]
69pub enum Role {
70 System,
71 Assistant,
72 User,
73}
74
75#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq)]
76pub struct Usage {
77 pub prompt_tokens: u32,
78 pub completion_tokens: u32,
79 pub total_tokens: u32,
80}
81
82#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
83#[serde(rename_all = "snake_case")]
84pub enum FinishReason {
85 Stop,
86 Length,
87 FunctionCall,
88 ContentFilter,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
92pub struct Choice {
93 pub message: Message,
94 pub finish_reason: FinishReason,
95}
96
97impl Client {
98 pub fn new(api_key: Option<String>, base_url: String) -> Result<Self, Error> {
99 let client = reqwest::Client::builder()
100 .build()
101 .map_err(Error::FailedToFetch)?;
102 Ok(Self {
103 client,
104 api_key,
105 base_url,
106 })
107 }
108
109 pub async fn chat_complete(
110 &self,
111 request: &ChatCompletionRequest,
112 ) -> Result<ChatCompletion, Error> {
113 let api_key = &self.api_key.as_ref().ok_or(Error::NoAPIKey)?;
114 let model = &request.model;
115
116 let resp = self
117 .client
118 .post(self.chat_endpoint())
119 .bearer_auth(api_key)
120 .timeout(request.timeout)
121 .header("Content-Type", "application/json")
122 .json(&json!({
123 "model": model,
124 "messages": request.messages,
125 "temperature": request.temperature,
126 }))
127 .send()
128 .await
129 .map_err(Error::FailedToFetch)?;
130
131 match resp.status() {
132 StatusCode::OK => {
133 let res: ChatCompletion = resp.json().await.map_err(Error::FailedToFetch)?;
134 Ok(res)
135 }
136 _ => {
137 let error = resp
138 .json::<APIErrorContainer>()
139 .await
140 .map_err(Error::FailedToFetch)?
141 .error;
142 Err(Error::OpenAIError { error })
143 }
144 }
145 }
146
147 fn chat_endpoint(&self) -> String {
148 format!("{}{}", self.base_url, "/v1/chat/completions")
149 }
150}
151
152impl Message {
153 pub fn system(content: &str) -> Message {
154 Message {
155 role: Role::System,
156 content: content.to_string(),
157 }
158 }
159 pub fn user(content: &str) -> Message {
160 Message {
161 role: Role::User,
162 content: content.to_string(),
163 }
164 }
165 pub fn assistant(content: &str) -> Message {
166 Message {
167 role: Role::Assistant,
168 content: content.to_string(),
169 }
170 }
171}
172
173impl ChatCompletionRequest {
174 pub fn builder() -> ChatCompletionRequestBuilder {
175 ChatCompletionRequestBuilder::default()
176 }
177}
178
179impl ChatCompletion {
180 pub fn builder() -> ChatCompletionBuilder {
181 ChatCompletionBuilder::default()
182 }
183}
184
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::Result;
    use chrono::TimeZone;

    /// A full success payload decodes into `ChatCompletion`, ignoring extra keys like `index`.
    #[test]
    fn parse_chat_completion_response() -> Result<()> {
        let data = r#"{
            "created": 1688413145,
            "model": "gpt-3.5-turbo-0613",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello! How can I assist you today?"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 8,
                "completion_tokens": 9,
                "total_tokens": 17
            }
        }
        "#;

        let resp = serde_json::from_str::<ChatCompletion>(data)?;

        assert_eq!(resp.created, Utc.timestamp_opt(1688413145, 0).unwrap());
        assert_eq!(
            resp.choices,
            vec![Choice {
                message: Message {
                    role: Role::Assistant,
                    content: "Hello! How can I assist you today?".to_string()
                },
                finish_reason: FinishReason::Stop
            }]
        );
        assert_eq!(resp.model, "gpt-3.5-turbo-0613");
        assert_eq!(
            resp.usage,
            Usage {
                prompt_tokens: 8,
                completion_tokens: 9,
                total_tokens: 17,
            }
        );

        Ok(())
    }

    /// The error envelope decodes, mapping JSON `null` to `None`.
    #[test]
    fn parse_chat_completion_error() -> Result<()> {
        let data = r#"{
            "error": {
                "message": "An error message",
                "type": "invalid_request_error",
                "param": null,
                "code": null
            }
        }
        "#;

        let resp = serde_json::from_str::<APIErrorContainer>(data)?.error;

        assert_eq!(resp.message, "An error message");
        assert_eq!(resp.error_type, "invalid_request_error");
        assert_eq!(resp.param, None);
        assert_eq!(resp.code, None);

        Ok(())
    }

    /// Roles go over the wire as lowercase strings, and messages keep field order.
    #[test]
    fn serialize_message_roles() -> Result<()> {
        assert_eq!(
            serde_json::to_string(&Message::user("hi"))?,
            r#"{"role":"user","content":"hi"}"#
        );
        assert_eq!(serde_json::to_string(&Role::System)?, r#""system""#);
        assert_eq!(serde_json::to_string(&Role::Assistant)?, r#""assistant""#);
        assert_eq!(serde_json::from_str::<Role>(r#""user""#)?, Role::User);
        Ok(())
    }

    /// Every finish reason is parsed from its snake_case wire name.
    #[test]
    fn parse_finish_reasons() -> Result<()> {
        let cases = [
            (r#""stop""#, FinishReason::Stop),
            (r#""length""#, FinishReason::Length),
            (r#""function_call""#, FinishReason::FunctionCall),
            (r#""content_filter""#, FinishReason::ContentFilter),
        ];
        for (raw, expected) in &cases {
            assert_eq!(&serde_json::from_str::<FinishReason>(raw)?, expected);
        }
        Ok(())
    }
}