use crate::chat_compl::SearchParameters;
use crate::error::XaiError;
use crate::error::check_for_model_error;
use crate::traits::{ClientConfig, ResponsesFetcher};
use reqwest::Method;
use serde::{Deserialize, Serialize};
/// Input payload for a create-response request.
///
/// Serialized `untagged`: a bare JSON string is sent for [`Text`](Self::Text),
/// while a JSON array of message objects is sent for
/// [`Messages`](Self::Messages).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ResponseInput {
    /// A single free-form text prompt.
    Text(String),
    /// A list of structured message objects (exact shape defined by the API;
    /// kept as raw JSON values here).
    Messages(Vec<serde_json::Value>),
}
/// Reasoning configuration for models with extended-reasoning support.
///
/// Every field is optional and omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningConfig {
    /// Requested reasoning effort level (accepted string values are defined
    /// by the API — TODO confirm, e.g. "low"/"high").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    /// Whether/how to generate a reasoning summary (API-defined string).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generate_summary: Option<String>,
    /// Summary mode selector (API-defined string).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub summary: Option<String>,
}
/// Request body for creating a response.
///
/// Only `model` and `input` are required; every other field is optional and
/// skipped during serialization when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateResponseRequest {
    /// Model identifier to generate the response with.
    pub model: String,
    /// Prompt input: plain text or a message array.
    pub input: ResponseInput,
    /// System-level instructions for the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    /// Id of a prior response to continue from (multi-turn chaining).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub previous_response_id: Option<String>,
    /// Whether the server should persist the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
    /// Whether to stream the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Maximum number of output tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
    /// Maximum number of agentic turns (API-defined semantics — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_turns: Option<u32>,
    /// Extended-reasoning configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<ReasoningConfig>,
    /// Live-search parameters (shared with the chat-completions API).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_parameters: Option<SearchParameters>,
    /// Tool definitions, kept as raw JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
    /// Tool-choice directive (string or object, hence raw JSON).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<serde_json::Value>,
    /// Whether the model may call tools in parallel.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Text output configuration (e.g. response format), kept as raw JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<serde_json::Value>,
    /// Whether to return log probabilities for output tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    /// Number of top log probabilities to return per token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,
    /// End-user identifier forwarded to the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Context-management rules, kept as raw JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_management: Option<Vec<serde_json::Value>>,
    /// Additional output sections to include in the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include: Option<Vec<String>>,
    /// Requested service tier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    /// Truncation strategy name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncation: Option<String>,
    /// Whether to process the response in the background.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub background: Option<bool>,
    /// Key used to improve prompt-cache hit rates.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Arbitrary caller-supplied metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
/// Breakdown of input-token accounting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseInputTokensDetails {
    /// Number of input tokens served from the prompt cache.
    pub cached_tokens: u32,
}
/// Breakdown of output-token accounting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseOutputTokensDetails {
    /// Number of output tokens spent on reasoning.
    pub reasoning_tokens: u32,
}
/// Per-tool invocation counts for server-side tools used by a response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSideToolUsageDetails {
    /// Number of code-interpreter invocations.
    pub code_interpreter_calls: u32,
    /// Number of document-search invocations.
    pub document_search_calls: u32,
    /// Number of file-search invocations.
    pub file_search_calls: u32,
    /// Number of MCP tool invocations.
    pub mcp_calls: u32,
    /// Number of web-search invocations.
    pub web_search_calls: u32,
    /// Number of X (Twitter) search invocations.
    pub x_search_calls: u32,
}
/// Token-usage and billing details attached to a completed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseUsage {
    /// Total input (prompt) tokens consumed.
    pub input_tokens: u32,
    /// Total output (completion) tokens generated.
    pub output_tokens: u32,
    /// Sum of input and output tokens.
    pub total_tokens: u32,
    /// Cache-related breakdown of input tokens.
    pub input_tokens_details: ResponseInputTokensDetails,
    /// Reasoning-related breakdown of output tokens.
    pub output_tokens_details: ResponseOutputTokensDetails,
    /// Number of live-search sources consulted.
    pub num_sources_used: u32,
    /// Total server-side tool invocations.
    pub num_server_side_tools_used: u32,
    /// Cost in USD "ticks" (API-defined unit — TODO confirm scale).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_in_usd_ticks: Option<i64>,
    /// Cost in nano-USD (presumably 1e-9 USD — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_in_nano_usd: Option<i64>,
    /// Per-tool usage counts, when server-side tools were invoked.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server_side_tool_usage_details: Option<ServerSideToolUsageDetails>,
}
/// A response object as returned by the Responses API.
///
/// NOTE(review): fields without `Option` are treated as always present in the
/// API payload; deserialization fails if the server omits one — verify
/// against the live API schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseObject {
    /// Unique response identifier.
    pub id: String,
    /// Object type marker returned by the API.
    pub object: String,
    /// Model that produced the response.
    pub model: String,
    /// Lifecycle status string (API-defined values).
    pub status: String,
    /// Creation time (presumably a Unix timestamp — TODO confirm units).
    pub created_at: u64,
    /// Completion time, present once the response has finished.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<u64>,
    /// Output items (messages, tool calls, …) as raw JSON values.
    pub output: Vec<serde_json::Value>,
    /// Token-usage/billing details, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<ResponseUsage>,
    /// Instructions echoed back from the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    /// Id of the prior response this one continued from.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub previous_response_id: Option<String>,
    /// Whether the response is persisted server-side.
    pub store: bool,
    /// Whether parallel tool calls were enabled.
    pub parallel_tool_calls: bool,
    /// Effective tool-choice directive (raw JSON).
    pub tool_choice: serde_json::Value,
    /// Tool definitions in effect (raw JSON).
    pub tools: Vec<serde_json::Value>,
    /// Text output configuration in effect (raw JSON).
    pub text: serde_json::Value,
    /// Truncation strategy in effect.
    pub truncation: String,
    /// Service tier the request was processed under.
    pub service_tier: String,
    /// Number of top log probabilities returned per token.
    pub top_logprobs: u32,
    /// Frequency penalty in effect.
    pub frequency_penalty: f64,
    /// Presence penalty in effect.
    pub presence_penalty: f64,
    /// Caller-supplied metadata echoed back (raw JSON).
    pub metadata: serde_json::Value,
    /// Sampling temperature in effect, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling value in effect, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Output-token cap in effect, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
    /// Reasoning configuration in effect, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<ReasoningConfig>,
    /// End-user identifier echoed back.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Safety identifier, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_identifier: Option<String>,
    /// Prompt-cache key echoed back.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Details about why the response is incomplete, if applicable (raw JSON).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub incomplete_details: Option<serde_json::Value>,
    /// Error details, if the response failed (raw JSON).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<serde_json::Value>,
    /// Whether the response was processed in the background.
    pub background: bool,
    /// Cap on tool calls in effect, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tool_calls: Option<u32>,
}
/// Response body returned when a stored response is deleted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeleteResponseObject {
    /// Identifier of the deleted response.
    pub id: String,
    /// Object type marker returned by the API.
    pub object: String,
    /// Whether the deletion succeeded.
    pub deleted: bool,
}
/// Fluent builder that accumulates a [`CreateResponseRequest`] and issues
/// Responses API calls through a [`ClientConfig`] implementation.
#[derive(Debug, Clone)]
pub struct ResponsesRequestBuilder<T: ClientConfig + Clone + Send + Sync> {
    /// HTTP client/configuration used to send requests.
    client: T,
    /// Request payload being accumulated by the setter methods.
    request: CreateResponseRequest,
}
impl<T> ResponsesRequestBuilder<T>
where
    T: ClientConfig + Clone + Send + Sync,
{
    /// Creates a new builder for the given model and input.
    ///
    /// All optional request fields start out unset (`None`) and can be
    /// filled in via the chained setters below.
    pub fn new(client: T, model: impl Into<String>, input: ResponseInput) -> Self {
        Self {
            client,
            request: CreateResponseRequest {
                model: model.into(),
                input,
                instructions: None,
                previous_response_id: None,
                store: None,
                stream: None,
                temperature: None,
                top_p: None,
                max_output_tokens: None,
                max_turns: None,
                reasoning: None,
                search_parameters: None,
                tools: None,
                tool_choice: None,
                parallel_tool_calls: None,
                text: None,
                logprobs: None,
                top_logprobs: None,
                user: None,
                context_management: None,
                include: None,
                service_tier: None,
                truncation: None,
                background: None,
                prompt_cache_key: None,
                metadata: None,
            },
        }
    }

    /// Sets system-level instructions for the model.
    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
        self.request.instructions = Some(instructions.into());
        self
    }

    /// Chains this request onto a previously created response.
    pub fn previous_response_id(mut self, id: impl Into<String>) -> Self {
        self.request.previous_response_id = Some(id.into());
        self
    }

    /// Sets whether the server should persist the response.
    pub fn store(mut self, store: bool) -> Self {
        self.request.store = Some(store);
        self
    }

    /// Sets whether the response should be streamed.
    pub fn stream(mut self, stream: bool) -> Self {
        self.request.stream = Some(stream);
        self
    }

    /// Sets the sampling temperature.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.request.temperature = Some(temperature);
        self
    }

    /// Sets the nucleus-sampling probability mass.
    pub fn top_p(mut self, top_p: f32) -> Self {
        self.request.top_p = Some(top_p);
        self
    }

    /// Caps the number of output tokens.
    pub fn max_output_tokens(mut self, max_output_tokens: u32) -> Self {
        self.request.max_output_tokens = Some(max_output_tokens);
        self
    }

    /// Caps the number of agentic turns.
    pub fn max_turns(mut self, max_turns: u32) -> Self {
        self.request.max_turns = Some(max_turns);
        self
    }

    /// Sets the extended-reasoning configuration.
    pub fn reasoning(mut self, reasoning: ReasoningConfig) -> Self {
        self.request.reasoning = Some(reasoning);
        self
    }

    /// Sets the live-search parameters.
    pub fn search_parameters(mut self, search_parameters: SearchParameters) -> Self {
        self.request.search_parameters = Some(search_parameters);
        self
    }

    /// Sets the tool definitions available to the model.
    pub fn tools(mut self, tools: Vec<serde_json::Value>) -> Self {
        self.request.tools = Some(tools);
        self
    }

    /// Sets the tool-choice directive (string or object, as raw JSON).
    pub fn tool_choice(mut self, tool_choice: serde_json::Value) -> Self {
        self.request.tool_choice = Some(tool_choice);
        self
    }

    /// Sets whether the model may call tools in parallel.
    pub fn parallel_tool_calls(mut self, parallel_tool_calls: bool) -> Self {
        self.request.parallel_tool_calls = Some(parallel_tool_calls);
        self
    }

    /// Sets the text output configuration (e.g. response format), as raw JSON.
    pub fn text(mut self, text: serde_json::Value) -> Self {
        self.request.text = Some(text);
        self
    }

    /// Enables or disables log probabilities on output tokens.
    pub fn logprobs(mut self, logprobs: bool) -> Self {
        self.request.logprobs = Some(logprobs);
        self
    }

    /// Sets how many top log probabilities to return per token.
    pub fn top_logprobs(mut self, top_logprobs: u32) -> Self {
        self.request.top_logprobs = Some(top_logprobs);
        self
    }

    /// Sets the end-user identifier forwarded to the API.
    pub fn user(mut self, user: impl Into<String>) -> Self {
        self.request.user = Some(user.into());
        self
    }

    /// Sets context-management rules, as raw JSON values.
    pub fn context_management(mut self, context_management: Vec<serde_json::Value>) -> Self {
        self.request.context_management = Some(context_management);
        self
    }

    /// Requests additional output sections to include in the response.
    pub fn include(mut self, include: Vec<String>) -> Self {
        self.request.include = Some(include);
        self
    }

    /// Sets the requested service tier.
    pub fn service_tier(mut self, service_tier: impl Into<String>) -> Self {
        self.request.service_tier = Some(service_tier.into());
        self
    }

    /// Sets the truncation strategy.
    pub fn truncation(mut self, truncation: impl Into<String>) -> Self {
        self.request.truncation = Some(truncation.into());
        self
    }

    /// Sets whether the response should be processed in the background.
    pub fn background(mut self, background: bool) -> Self {
        self.request.background = Some(background);
        self
    }

    /// Sets the prompt-cache key used to improve cache hit rates.
    pub fn prompt_cache_key(mut self, key: impl Into<String>) -> Self {
        self.request.prompt_cache_key = Some(key.into());
        self
    }

    /// Attaches arbitrary caller-supplied metadata to the request.
    pub fn metadata(mut self, metadata: serde_json::Value) -> Self {
        self.request.metadata = Some(metadata);
        self
    }

    /// Validates the accumulated request and returns it.
    ///
    /// # Errors
    ///
    /// Returns [`XaiError::Validation`] if the model name is empty or
    /// whitespace-only.
    pub fn build(self) -> Result<CreateResponseRequest, XaiError> {
        if self.request.model.trim().is_empty() {
            return Err(XaiError::Validation("model cannot be empty".to_string()));
        }
        Ok(self.request)
    }
}
impl<T> ResponsesFetcher for ResponsesRequestBuilder<T>
where
T: ClientConfig + Clone + Send + Sync,
{
async fn create_response(
&self,
request: CreateResponseRequest,
) -> Result<ResponseObject, XaiError> {
let response = self
.client
.request(Method::POST, "responses")?
.json(&request)
.send()
.await?;
if response.status().is_success() {
let obj = response.json::<ResponseObject>().await?;
Ok(obj)
} else {
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());
if let Some(model_error) = check_for_model_error(&error_body) {
return Err(model_error);
}
Err(XaiError::Http(error_body))
}
}
async fn get_response(&self, response_id: &str) -> Result<ResponseObject, XaiError> {
let url = format!("responses/{}", response_id);
let response = self.client.request(Method::GET, &url)?.send().await?;
if response.status().is_success() {
let obj = response.json::<ResponseObject>().await?;
Ok(obj)
} else {
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());
if let Some(model_error) = check_for_model_error(&error_body) {
return Err(model_error);
}
Err(XaiError::Http(error_body))
}
}
async fn delete_response(&self, response_id: &str) -> Result<DeleteResponseObject, XaiError> {
let url = format!("responses/{}", response_id);
let response = self.client.request(Method::DELETE, &url)?.send().await?;
if response.status().is_success() {
let obj = response.json::<DeleteResponseObject>().await?;
Ok(obj)
} else {
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());
if let Some(model_error) = check_for_model_error(&error_body) {
return Err(model_error);
}
Err(XaiError::Http(error_body))
}
}
}