cogni/
openai.rs

1//! Interactions with OpenAI APIs
2
3use std::time::Duration;
4
5use crate::Error;
6use chrono::serde::ts_seconds;
7use chrono::{DateTime, Utc};
8use derive_builder::Builder;
9use reqwest::StatusCode;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12
13/// Convienience Client for OpenAI Chat Completions API
14pub struct Client {
15    /// Inner client
16    client: reqwest::Client,
17    /// Default API Key
18    api_key: Option<String>,
19    /// Base URL for API Endpoint
20    base_url: String,
21}
22
23/// Requests for chat_completion
24/// Reference: <https://platform.openai.com/docs/api-reference/chat>
25#[derive(Builder, Default)]
26pub struct ChatCompletionRequest {
27    model: String,
28    messages: Vec<Message>,
29    temperature: f32,
30    timeout: Duration,
31}
32
33/// Responses from chat_completion
34/// Reference: <https://platform.openai.com/docs/api-reference/chat>
35#[derive(Builder, Default, Debug, Serialize, Deserialize)]
36pub struct ChatCompletion {
37    #[serde(with = "ts_seconds")]
38    pub created: DateTime<Utc>,
39    pub choices: Vec<Choice>,
40    pub model: String,
41    pub usage: Usage,
42}
43
44/// API Errors from OpenAI
45#[derive(Debug, Deserialize)]
46pub struct APIError {
47    pub message: String,
48    #[serde(rename = "type")]
49    pub error_type: String,
50    pub param: Option<String>,
51    pub code: Option<String>,
52}
53
54/// Wraps `APIError` for deserializing OpenAI Response
55#[derive(Debug, Deserialize)]
56struct APIErrorContainer {
57    error: APIError,
58}
59
60/// Messages in chat completion request and response
61#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
62pub struct Message {
63    pub role: Role,
64    pub content: String,
65}
66
67#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
68#[serde(rename_all = "lowercase")]
69pub enum Role {
70    System,
71    Assistant,
72    User,
73}
74
75#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq)]
76pub struct Usage {
77    pub prompt_tokens: u32,
78    pub completion_tokens: u32,
79    pub total_tokens: u32,
80}
81
82#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
83#[serde(rename_all = "snake_case")]
84pub enum FinishReason {
85    Stop,
86    Length,
87    FunctionCall,
88    ContentFilter,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
92pub struct Choice {
93    pub message: Message,
94    pub finish_reason: FinishReason,
95}
96
97impl Client {
98    pub fn new(api_key: Option<String>, base_url: String) -> Result<Self, Error> {
99        let client = reqwest::Client::builder()
100            .build()
101            .map_err(Error::FailedToFetch)?;
102        Ok(Self {
103            client,
104            api_key,
105            base_url,
106        })
107    }
108
109    pub async fn chat_complete(
110        &self,
111        request: &ChatCompletionRequest,
112    ) -> Result<ChatCompletion, Error> {
113        let api_key = &self.api_key.as_ref().ok_or(Error::NoAPIKey)?;
114        let model = &request.model;
115
116        let resp = self
117            .client
118            .post(self.chat_endpoint())
119            .bearer_auth(api_key)
120            .timeout(request.timeout)
121            .header("Content-Type", "application/json")
122            .json(&json!({
123                "model": model,
124                "messages": request.messages,
125                "temperature": request.temperature,
126            }))
127            .send()
128            .await
129            .map_err(Error::FailedToFetch)?;
130
131        match resp.status() {
132            StatusCode::OK => {
133                let res: ChatCompletion = resp.json().await.map_err(Error::FailedToFetch)?;
134                Ok(res)
135            }
136            _ => {
137                let error = resp
138                    .json::<APIErrorContainer>()
139                    .await
140                    .map_err(Error::FailedToFetch)?
141                    .error;
142                Err(Error::OpenAIError { error })
143            }
144        }
145    }
146
147    fn chat_endpoint(&self) -> String {
148        format!("{}{}", self.base_url, "/v1/chat/completions")
149    }
150}
151
152impl Message {
153    pub fn system(content: &str) -> Message {
154        Message {
155            role: Role::System,
156            content: content.to_string(),
157        }
158    }
159    pub fn user(content: &str) -> Message {
160        Message {
161            role: Role::User,
162            content: content.to_string(),
163        }
164    }
165    pub fn assistant(content: &str) -> Message {
166        Message {
167            role: Role::Assistant,
168            content: content.to_string(),
169        }
170    }
171}
172
173impl ChatCompletionRequest {
174    pub fn builder() -> ChatCompletionRequestBuilder {
175        ChatCompletionRequestBuilder::default()
176    }
177}
178
179impl ChatCompletion {
180    pub fn builder() -> ChatCompletionBuilder {
181        ChatCompletionBuilder::default()
182    }
183}
184
#[cfg(test)]
mod test {

    use super::*;
    use anyhow::Result;
    use chrono::TimeZone;

    /// A successful response payload deserializes into `ChatCompletion`
    /// with every field populated as expected.
    #[test]
    fn parse_chat_completion_response() -> Result<()> {
        let payload = r#"{
             "created": 1688413145,
             "model": "gpt-3.5-turbo-0613",
             "choices": [{
                 "index": 0,
                 "message": {
                     "role": "assistant",
                     "content": "Hello! How can I assist you today?"
                 },
                 "finish_reason": "stop"
             }],
             "usage": {
                 "prompt_tokens": 8,
                 "completion_tokens": 9,
                 "total_tokens": 17
             }
        }
        "#;

        let completion: ChatCompletion = serde_json::from_str(payload)?;

        let expected_created = Utc.timestamp_opt(1688413145, 0).unwrap();
        let expected_choices = vec![Choice {
            message: Message {
                role: Role::Assistant,
                content: "Hello! How can I assist you today?".to_string(),
            },
            finish_reason: FinishReason::Stop,
        }];
        let expected_usage = Usage {
            prompt_tokens: 8,
            completion_tokens: 9,
            total_tokens: 17,
        };

        assert_eq!(completion.created, expected_created);
        assert_eq!(completion.choices, expected_choices);
        assert_eq!(completion.model, "gpt-3.5-turbo-0613");
        assert_eq!(completion.usage, expected_usage);

        Ok(())
    }

    /// An error payload deserializes through `APIErrorContainer`, with
    /// `null` values mapped to `None`.
    #[test]
    fn parse_chat_completion_error() -> Result<()> {
        let payload = r#"{
            "error": {
                "message": "An error message",
                "type": "invalid_request_error",
                "param": null,
                "code": null
            }
        }
        "#;

        let APIErrorContainer { error } = serde_json::from_str(payload)?;

        assert_eq!(error.message, "An error message");
        assert_eq!(error.error_type, "invalid_request_error");
        assert_eq!(error.param, None);
        assert_eq!(error.code, None);

        Ok(())
    }
}