use anyhow::{Context, Result};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
/// Author of a chat message.
///
/// Serialized in lowercase (`"system"`, `"user"`, `"assistant"`). Being a
/// fieldless enum it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
}
/// One message in a chat conversation: who authored it and its text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Author of the message.
    pub role: Role,
    /// Text body of the message.
    pub content: String,
}

impl ChatMessage {
    /// Creates a message with the given `role` and `content`.
    ///
    /// `content` accepts anything convertible into a `String`, such as `&str`.
    pub fn new(role: Role, content: impl Into<String>) -> Self {
        let content: String = content.into();
        Self { role, content }
    }
}
/// Request body posted to the chat completions endpoint.
///
/// Field order is the order the fields appear in the serialized JSON.
#[derive(Serialize, Debug)]
struct ChatCompletionRequest {
    /// Model identifier, sent verbatim.
    model: String,
    /// Conversation so far, in order.
    messages: Vec<ChatMessage>,
}

/// Response body of the chat completions endpoint.
///
/// Only `choices` is read; any other fields in the JSON are ignored by serde.
#[derive(Deserialize, Debug)]
struct ChatCompletionResponse {
    choices: Vec<Choice>,
}

/// A single candidate completion inside [`ChatCompletionResponse`].
#[derive(Deserialize, Debug)]
struct Choice {
    message: ChatMessage,
}
/// Minimal async client for the OpenAI chat completions API.
#[derive(Clone)]
pub struct OpenAIClient {
    /// Secret API key sent as a bearer token; redacted in `Debug` output.
    api_key: String,
    /// Underlying HTTP client used for every request.
    client: reqwest::Client,
}

impl std::fmt::Debug for OpenAIClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the key: `{:?}` output routinely ends up in logs.
        f.debug_struct("OpenAIClient")
            .field("api_key", &"<redacted>")
            .field("client", &self.client)
            .finish()
    }
}
impl OpenAIClient {
pub fn new(api_key: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
client: reqwest::Client::new(),
}
}
fn create_headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap(),
);
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers
}
pub async fn chat_completion(
&self,
model: impl Into<String>,
messages: Vec<ChatMessage>,
) -> Result<String> {
let request = ChatCompletionRequest {
model: model.into(),
messages,
};
let response: ChatCompletionResponse = self
.client
.post("https://api.openai.com/v1/chat/completions")
.headers(self.create_headers())
.json(&request)
.send()
.await?
.json()
.await?;
Ok(response
.choices
.first()
.map(|choice| choice.message.content.clone())
.unwrap_or_default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_chat_message() {
        // Destructure so each field is checked independently.
        let ChatMessage { role, content } = ChatMessage::new(Role::User, "Hello");
        assert!(matches!(role, Role::User), "expected a user message");
        assert_eq!(content.as_str(), "Hello");
    }
}