use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use crate::conversations::response::{Conversation, ConversationItemListResponse, ConversationListResponse, DeleteConversationResponse, InputItem};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
// API path segment for the Conversations endpoint, appended to the provider base URL.
const CONVERSATIONS_PATH: &str = "conversations";
/// Extra response fields that can be requested when listing conversation
/// items (sent as `include[]` query parameters; see [`Conversations::list_items`]).
///
/// The `#[serde(rename = ...)]` values are the wire names the API expects and
/// must stay in sync with [`ConversationInclude::as_str`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ConversationInclude {
    /// Sources consulted by a web search tool call.
    #[serde(rename = "web_search_call.action.sources")]
    WebSearchCallSources,
    /// Outputs produced by a code interpreter tool call.
    #[serde(rename = "code_interpreter_call.outputs")]
    CodeInterpreterCallOutputs,
    /// Results returned by a file search tool call.
    #[serde(rename = "file_search_call.results")]
    FileSearchCallResults,
    /// Image URLs embedded in message input images.
    #[serde(rename = "message.input_image.image_url")]
    MessageInputImageUrl,
    /// Encrypted reasoning content.
    #[serde(rename = "reasoning.encrypted_content")]
    ReasoningEncryptedContent,
}
impl ConversationInclude {
pub fn as_str(&self) -> &'static str {
match self {
ConversationInclude::WebSearchCallSources => "web_search_call.action.sources",
ConversationInclude::CodeInterpreterCallOutputs => "code_interpreter_call.outputs",
ConversationInclude::FileSearchCallResults => "file_search_call.results",
ConversationInclude::MessageInputImageUrl => "message.input_image.image_url",
ConversationInclude::ReasoningEncryptedContent => "reasoning.encrypted_content",
}
}
}
/// Request body for creating a conversation; fields that are `None` are
/// omitted from the serialized JSON entirely.
#[derive(Debug, Clone, Serialize)]
struct CreateConversationRequest {
    // Optional key/value metadata attached to the new conversation.
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<HashMap<String, String>>,
    // Optional initial items to seed the conversation with.
    #[serde(skip_serializing_if = "Option::is_none")]
    items: Option<Vec<InputItem>>,
}
/// Request body for updating a conversation's metadata. The map is always
/// serialized (no skip), so an empty map is sent as `{}`.
#[derive(Debug, Clone, Serialize)]
struct UpdateConversationRequest {
    metadata: HashMap<String, String>,
}
/// Request body for appending items to an existing conversation.
#[derive(Debug, Clone, Serialize)]
struct CreateItemsRequest {
    items: Vec<InputItem>,
}
/// Client for the Conversations API: create, retrieve, update, delete
/// conversations and manage their items.
pub struct Conversations {
    // Credential source and endpoint resolution (OpenAI, Azure, or custom URL).
    auth: AuthProvider,
    // Optional per-request timeout; `None` uses the HTTP client's default.
    timeout: Option<Duration>,
}
impl Conversations {
pub fn new() -> Result<Self> {
let auth = AuthProvider::openai_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_auth(auth: AuthProvider) -> Self {
Self { auth, timeout: None }
}
pub fn azure() -> Result<Self> {
let auth = AuthProvider::azure_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn detect_provider() -> Result<Self> {
let auth = AuthProvider::from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
let auth = AuthProvider::from_url_with_key(base_url, api_key);
Self { auth, timeout: None }
}
pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
let auth = AuthProvider::from_url(url)?;
Ok(Self { auth, timeout: None })
}
pub fn auth(&self) -> &AuthProvider {
&self.auth
}
pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
self.timeout = Some(timeout);
self
}
fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
let client = create_http_client(self.timeout)?;
let mut headers = request::header::HeaderMap::new();
self.auth.apply_headers(&mut headers)?;
headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
Ok((client, headers))
}
fn handle_error(status: request::StatusCode, content: &str) -> OpenAIToolError {
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(content) {
OpenAIToolError::Error(error_resp.error.message.unwrap_or_default())
} else {
OpenAIToolError::Error(format!("API error ({}): {}", status, content))
}
}
pub async fn create(&self, metadata: Option<HashMap<String, String>>, items: Option<Vec<InputItem>>) -> Result<Conversation> {
let (client, headers) = self.create_client()?;
let request_body = CreateConversationRequest { metadata, items };
let body = serde_json::to_string(&request_body)?;
let url = self.auth.endpoint(CONVERSATIONS_PATH);
let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
return Err(Self::handle_error(status, &content));
}
serde_json::from_str::<Conversation>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn retrieve(&self, conversation_id: &str) -> Result<Conversation> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}", self.auth.endpoint(CONVERSATIONS_PATH), conversation_id);
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
return Err(Self::handle_error(status, &content));
}
serde_json::from_str::<Conversation>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn update(&self, conversation_id: &str, metadata: HashMap<String, String>) -> Result<Conversation> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}", self.auth.endpoint(CONVERSATIONS_PATH), conversation_id);
let request_body = UpdateConversationRequest { metadata };
let body = serde_json::to_string(&request_body)?;
let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
return Err(Self::handle_error(status, &content));
}
serde_json::from_str::<Conversation>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn delete(&self, conversation_id: &str) -> Result<DeleteConversationResponse> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}", self.auth.endpoint(CONVERSATIONS_PATH), conversation_id);
let response = client.delete(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
return Err(Self::handle_error(status, &content));
}
serde_json::from_str::<DeleteConversationResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn create_items(&self, conversation_id: &str, items: Vec<InputItem>) -> Result<ConversationItemListResponse> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}/items", self.auth.endpoint(CONVERSATIONS_PATH), conversation_id);
let request_body = CreateItemsRequest { items };
let body = serde_json::to_string(&request_body)?;
let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
return Err(Self::handle_error(status, &content));
}
serde_json::from_str::<ConversationItemListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn list_items(
&self,
conversation_id: &str,
limit: Option<u32>,
after: Option<&str>,
order: Option<&str>,
include: Option<Vec<ConversationInclude>>,
) -> Result<ConversationItemListResponse> {
let (client, headers) = self.create_client()?;
let mut params = Vec::new();
if let Some(l) = limit {
params.push(format!("limit={}", l));
}
if let Some(a) = after {
params.push(format!("after={}", a));
}
if let Some(o) = order {
params.push(format!("order={}", o));
}
if let Some(inc) = include {
for i in inc {
params.push(format!("include[]={}", i.as_str()));
}
}
let url = if params.is_empty() {
format!("{}/{}/items", self.auth.endpoint(CONVERSATIONS_PATH), conversation_id)
} else {
format!("{}/{}/items?{}", self.auth.endpoint(CONVERSATIONS_PATH), conversation_id, params.join("&"))
};
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
return Err(Self::handle_error(status, &content));
}
serde_json::from_str::<ConversationItemListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn list(&self, limit: Option<u32>, after: Option<&str>) -> Result<ConversationListResponse> {
let (client, headers) = self.create_client()?;
let mut params = Vec::new();
if let Some(l) = limit {
params.push(format!("limit={}", l));
}
if let Some(a) = after {
params.push(format!("after={}", a));
}
let url = if params.is_empty() {
self.auth.endpoint(CONVERSATIONS_PATH)
} else {
format!("{}?{}", self.auth.endpoint(CONVERSATIONS_PATH), params.join("&"))
};
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
return Err(Self::handle_error(status, &content));
}
serde_json::from_str::<ConversationListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
}