use anyhow::{Context, Result};
use futures_core::Stream;
use reqwest::header;
use serde::{Deserialize, Serialize};
/// Chat-completions endpoint that every request made by [`ChatGptClient`] is POSTed to.
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
/// Async client for the OpenAI chat-completions API.
///
/// Holds an HTTP client pre-configured with the API key (see [`ChatGptClient::new`]) and the
/// model name that is sent with every request.
#[derive(Clone)]
pub struct ChatGptClient {
    /// HTTP client whose default headers carry the `Authorization: Bearer …` token.
    client: reqwest::Client,
    /// Model identifier copied into each request's `model` field.
    model: String,
}
/// One entry in a conversation, serialized into a request's `messages` array.
///
/// Also used for the assistant message decoded from a non-streaming response. Fields that are
/// `None` are left out of the serialized JSON.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
    /// Author role; the constructors use `"system"`, `"user"`, `"assistant"` and `"tool"`.
    pub role: String,
    /// Text content; may be `None`, as in messages built by `Message::assistant_tool_calls`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool invocations carried by an assistant message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Set by `Message::tool_result` to identify the tool call whose result this message holds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
impl Message {
    /// Common constructor behind every public helper: copies `role` into an owned string and
    /// stores the optional parts exactly as given.
    fn from_parts(
        role: &str,
        content: Option<String>,
        tool_calls: Option<Vec<ToolCall>>,
        tool_call_id: Option<String>,
    ) -> Self {
        Self {
            role: String::from(role),
            content,
            tool_calls,
            tool_call_id,
        }
    }

    /// Builds a `"system"` message with the given text.
    pub fn system(content: impl Into<String>) -> Self {
        Self::from_parts("system", Some(content.into()), None, None)
    }

    /// Builds a `"user"` message with the given text.
    pub fn user(content: impl Into<String>) -> Self {
        Self::from_parts("user", Some(content.into()), None, None)
    }

    /// Builds an `"assistant"` message with the given text.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::from_parts("assistant", Some(content.into()), None, None)
    }

    /// Builds an `"assistant"` message that carries only `tool_calls` and no text content.
    pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
        Self::from_parts("assistant", None, Some(tool_calls), None)
    }

    /// Builds a `"tool"` message holding `content`, tagged with the `tool_call_id` it answers.
    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        // Arguments are evaluated left to right, so `content` is converted before
        // `tool_call_id`, matching the field order of a struct literal.
        Self::from_parts("tool", Some(content.into()), None, Some(tool_call_id.into()))
    }
}
/// A tool invocation as it appears in an assistant message's `tool_calls` array.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
    /// Call identifier; presumably echoed back as `tool_call_id` in the matching tool-result
    /// message — confirm against callers.
    pub id: String,
    /// Serialized under the JSON key `"type"` (a Rust keyword, hence the rename).
    #[serde(rename = "type")]
    pub call_type: String,
    /// The function invocation requested.
    pub function: FunctionCall,
}
/// Name and arguments of a requested function invocation.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments kept as a raw string and not parsed here; presumably JSON text — verify
    /// before decoding.
    pub arguments: String,
}
/// A tool offered to the model in a request's `tools` array (serialize-only: it is sent,
/// never decoded).
#[derive(Debug, Serialize, Clone)]
pub struct ToolDefinition {
    /// Serialized under the JSON key `"type"`; presumably `"function"` — supplied by the caller.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function being offered.
    pub function: FunctionDefinition,
}
/// Declaration of a function that the model may call.
#[derive(Debug, Serialize, Clone)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Free-text description sent along with the definition.
    pub description: String,
    /// Parameter description passed through verbatim; assumed to be a JSON Schema object —
    /// confirm with callers.
    pub parameters: serde_json::Value,
}
/// JSON body POSTed to the chat-completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier.
    model: String,
    /// Messages sent with the request, in order.
    messages: Vec<Message>,
    /// Sampling temperature; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Requests a streamed response; omitted from the JSON when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    stream: bool,
    /// Tools offered to the model; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ToolDefinition>>,
}
/// Body of a non-streaming chat-completions response; only `choices` is read.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Completion alternatives; the client only ever uses the first one.
    choices: Vec<Choice>,
}
/// One completion alternative in a [`ChatResponse`].
#[derive(Debug, Deserialize)]
struct Choice {
    /// The generated message (text and/or tool calls).
    message: Message,
    /// Stop reason if the API supplied one; handed back to callers of `chat_with_tools`.
    finish_reason: Option<String>,
}
/// JSON payload of one `data:` line in a streamed response; only `choices` is read.
#[derive(Debug, Deserialize)]
struct StreamChunk {
    /// Per-choice incremental updates carried by this chunk.
    choices: Vec<StreamChoice>,
}
/// One choice within a [`StreamChunk`]; only its incremental `delta` is read.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    /// The increment for this choice.
    delta: Delta,
}
/// Incremental update for a streamed choice; only text content is read.
#[derive(Debug, Deserialize)]
struct Delta {
    /// Text fragment to append, when the chunk carries one.
    content: Option<String>,
}
/// Errors returned by [`ChatGptClient`]'s request methods.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
    /// The API answered with a non-2xx HTTP status; `body` holds whatever response text was
    /// read (possibly empty).
    #[error("OpenAI API error (HTTP {status}): {body}")]
    Api { status: u16, body: String },
    /// A `reqwest` failure: sending the request, reading the body, or decoding it as JSON.
    #[error(transparent)]
    Transport(#[from] reqwest::Error),
    /// Any other failure, e.g. a response that contains no choices.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
impl ChatGptClient {
pub fn new(api_key: &str, model: &str) -> Result<Self> {
let mut headers = header::HeaderMap::new();
let mut auth = header::HeaderValue::from_str(&format!("Bearer {api_key}"))
.context("invalid API key characters")?;
auth.set_sensitive(true);
headers.insert(header::AUTHORIZATION, auth);
let client = reqwest::Client::builder()
.default_headers(headers)
.build()
.context("failed to build HTTP client")?;
Ok(Self {
client,
model: model.to_owned(),
})
}
pub async fn chat(&self, messages: Vec<Message>) -> Result<String, LlmError> {
let request = ChatRequest {
model: self.model.clone(),
messages,
temperature: None,
stream: false,
tools: None,
};
let response = self.client.post(API_URL).json(&request).send().await?;
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
return Err(LlmError::Api {
status: status.as_u16(),
body,
});
}
let parsed: ChatResponse = response.json().await?;
Ok(parsed
.choices
.into_iter()
.next()
.and_then(|c| c.message.content)
.unwrap_or_default())
}
pub async fn chat_with_tools(
&self,
messages: Vec<Message>,
tools: Option<&[ToolDefinition]>,
) -> Result<(Message, Option<String>), LlmError> {
let request = ChatRequest {
model: self.model.clone(),
messages,
temperature: None,
stream: false,
tools: tools.map(|t| t.to_vec()),
};
let response = self.client.post(API_URL).json(&request).send().await?;
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
return Err(LlmError::Api {
status: status.as_u16(),
body,
});
}
let parsed: ChatResponse = response.json().await?;
let choice = parsed
.choices
.into_iter()
.next()
.ok_or_else(|| LlmError::Other(anyhow::anyhow!("no choices in response")))?;
Ok((choice.message, choice.finish_reason))
}
pub fn chat_stream(
&self,
messages: Vec<Message>,
) -> impl Stream<Item = Result<String, LlmError>> + Send {
let client = self.client.clone();
let model = self.model.clone();
async_stream::try_stream! {
let request = ChatRequest {
model,
messages,
temperature: None,
stream: true,
tools: None,
};
let mut response = client.post(API_URL).json(&request).send().await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let mut body = String::new();
while let Some(chunk) = response.chunk().await? {
body.push_str(&String::from_utf8_lossy(&chunk));
}
Err(LlmError::Api { status, body })?;
}
let mut buffer = String::new();
while let Some(chunk) = response.chunk().await? {
buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(pos) = buffer.find("\n\n") {
let event = buffer[..pos].to_owned();
buffer = buffer[pos + 2..].to_owned();
for line in event.lines() {
let data = match line.strip_prefix("data: ") {
Some(d) => d.trim(),
None => continue,
};
if data == "[DONE]" {
return;
}
if let Ok(parsed) = serde_json::from_str::<StreamChunk>(data) {
for choice in parsed.choices {
if let Some(content) = choice.delta.content {
yield content;
}
}
}
}
}
}
}
}
}