use crate::error::XaiError;
use crate::error::check_for_model_error;
use crate::traits::ChatCompletionsFetcher;
use crate::traits::ClientConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
/// One element of a multi-part message body (see `MessageContent::Parts`).
///
/// Serialized as an internally tagged JSON object: the `type` key holds the
/// snake_case variant name (`"text"` or `"image_url"`), and the remaining keys
/// come from the variant's fields.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// Text part, serialized as `{"type": "text", "text": ...}`.
Text { text: String },
/// Image part, serialized as `{"type": "image_url", "image_url": {...}}`.
ImageUrl { image_url: ImageUrl },
}
/// Image reference carried by the `ContentPart::ImageUrl` variant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageUrl {
/// The image URL string, serialized as-is.
/// NOTE(review): which URL forms the API accepts (http(s), data URLs) is not established here.
pub url: String,
/// Optional `detail` value; omitted from the JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
/// Body of a `Message`: either one plain string or a list of typed parts.
///
/// `#[serde(untagged)]` means the JSON carries no wrapper or discriminator: a
/// JSON string maps to `Text`, and a JSON array of part objects maps to
/// `Parts`. On deserialization the variants are tried in declaration order,
/// so keep `Text` first.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
/// Plain string content.
Text(String),
/// Multi-part content (text and/or images).
Parts(Vec<ContentPart>),
}
impl fmt::Display for ContentPart {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ContentPart::Text { text } => write!(f, "{}", text),
ContentPart::ImageUrl { image_url } => write!(f, "[Image: {}]", image_url.url),
}
}
}
impl fmt::Display for MessageContent {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MessageContent::Text(text) => write!(f, "{}", text),
MessageContent::Parts(parts) => {
for (i, part) in parts.iter().enumerate() {
if i > 0 {
write!(f, " ")?;
}
write!(f, "{}", part)?;
}
Ok(())
}
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: MessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub refusal: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<serde_json::Value>>,
}
impl Message {
    /// Creates a message whose body is a single plain-text string.
    ///
    /// `reasoning_content`, `refusal`, and `tool_calls` are all left unset
    /// (`None`), so they are omitted when the message is serialized.
    pub fn text(role: impl Into<String>, content: impl Into<String>) -> Self {
        let role: String = role.into();
        let content = MessageContent::Text(content.into());
        Self {
            role,
            content,
            reasoning_content: None,
            refusal: None,
            tool_calls: None,
        }
    }
}
/// Optional search settings sent as `ChatCompletionRequest::search_parameters`.
///
/// Every field is optional and is omitted from the JSON while `None`.
/// Values are passed through unvalidated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParameters {
// NOTE(review): the date format the API expects for from_date/to_date is not established here — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub from_date: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub to_date: Option<String>,
// Free-form mode string; accepted values are defined by the API, not by this type.
#[serde(skip_serializing_if = "Option::is_none")]
pub mode: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub return_citations: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_search_results: Option<u32>,
// Source descriptors kept as untyped JSON.
#[serde(skip_serializing_if = "Option::is_none")]
pub sources: Option<Vec<serde_json::Value>>,
}
/// Value of `ChatCompletionRequest::stream_options`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
/// Always serialized. NOTE(review): presumably asks the server to report usage on streamed responses — confirm.
pub include_usage: bool,
}
/// JSON body POSTed to the `chat/completions` endpoint.
///
/// Only `model`, `messages`, and `stream` are always serialized. Every `Option`
/// field is omitted from the JSON while it is `None`, so unset parameters fall
/// back to whatever the server uses. No value is validated on this side.
/// Usually constructed through `ChatCompletionsRequestBuilder`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<Message>,
// No skip attribute: `"stream": false` is sent explicitly.
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
// NOTE(review): keys are presumably token ids — confirm the key format the API expects.
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<u32, f32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub deferred: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
// Free-form string; accepted values are defined by the API.
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<String>,
// The remaining JSON-valued fields are passed through untyped.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub search_parameters: Option<SearchParameters>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub web_search_options: Option<serde_json::Value>,
}
/// Log-probability information attached to a `Choice`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
/// Per-token entries kept as untyped JSON; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<Vec<serde_json::Value>>,
}
/// One generated alternative inside a `ChatCompletionResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
/// Position of this choice within `ChatCompletionResponse::choices`.
pub index: u32,
/// The generated message.
pub message: Message,
// No skip attribute: re-serializes as `null` when `None`.
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<Logprobs>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTokensDetails {
pub text_tokens: u32,
pub audio_tokens: u32,
pub image_tokens: u32,
pub cached_tokens: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionTokensDetails {
pub reasoning_tokens: u32,
pub audio_tokens: u32,
pub accepted_prediction_tokens: u32,
pub rejected_prediction_tokens: u32,
}
/// Token accounting reported with a `ChatCompletionResponse`.
///
/// The three totals are required when deserializing. The optional extras are
/// omitted when re-serialized while `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<PromptTokensDetails>,
#[serde(skip_serializing_if = "Option::is_none")]
pub completion_tokens_details: Option<CompletionTokensDetails>,
// NOTE(review): the size of one "tick" in USD is not established here — confirm against the API docs.
#[serde(skip_serializing_if = "Option::is_none")]
pub cost_in_usd_ticks: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub num_sources_used: Option<u32>,
}
/// Decoded JSON body of a successful `chat/completions` call.
///
/// `id`, `object`, `created`, `model`, and `choices` must be present for
/// deserialization to succeed. The remaining fields are optional and are
/// omitted when re-serialized while `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
// NOTE(review): presumably a Unix timestamp in seconds — confirm.
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
// Citation and output-file entries are kept as untyped JSON.
#[serde(skip_serializing_if = "Option::is_none")]
pub citations: Option<Vec<serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_files: Option<Vec<serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub request_id: Option<String>,
}
/// Fluent builder for a `ChatCompletionRequest` that also holds the client
/// used to send it through the `ChatCompletionsFetcher` implementation.
#[derive(Debug, Clone)]
pub struct ChatCompletionsRequestBuilder<T: ClientConfig + Clone + Send + Sync> {
/// Client used to issue the HTTP request.
client: T,
/// Request being assembled; returned by `build`.
request: ChatCompletionRequest,
}
impl<T> ChatCompletionsRequestBuilder<T>
where
    T: ClientConfig + Clone + Send + Sync,
{
    /// Starts a builder for `model` with the conversation `messages`.
    ///
    /// Streaming starts disabled and every optional request field starts
    /// unset (`None`); the setters below fill them in one at a time.
    pub fn new(client: T, model: String, messages: Vec<Message>) -> Self {
        let request = ChatCompletionRequest {
            model,
            messages,
            temperature: None,
            max_tokens: None,
            max_completion_tokens: None,
            frequency_penalty: None,
            presence_penalty: None,
            n: None,
            stop: None,
            stream: false,
            logprobs: None,
            top_p: None,
            top_logprobs: None,
            seed: None,
            user: None,
            logit_bias: None,
            deferred: None,
            parallel_tool_calls: None,
            reasoning_effort: None,
            response_format: None,
            search_parameters: None,
            stream_options: None,
            tool_choice: None,
            tools: None,
            web_search_options: None,
        };
        Self { client, request }
    }

    /// Applies `edit` to the pending request and returns the builder, so each
    /// public setter below is a one-line field assignment.
    fn update_request(mut self, edit: impl FnOnce(&mut ChatCompletionRequest)) -> Self {
        edit(&mut self.request);
        self
    }

    /// Sets `temperature`.
    pub fn temperature(self, temperature: f32) -> Self {
        self.update_request(|req| req.temperature = Some(temperature))
    }

    /// Sets `max_tokens`.
    pub fn max_tokens(self, max_tokens: u32) -> Self {
        self.update_request(|req| req.max_tokens = Some(max_tokens))
    }

    /// Sets `max_completion_tokens`.
    pub fn max_completion_tokens(self, max_completion_tokens: u32) -> Self {
        self.update_request(|req| req.max_completion_tokens = Some(max_completion_tokens))
    }

    /// Sets `frequency_penalty`.
    pub fn frequency_penalty(self, frequency_penalty: f32) -> Self {
        self.update_request(|req| req.frequency_penalty = Some(frequency_penalty))
    }

    /// Sets `presence_penalty`.
    pub fn presence_penalty(self, presence_penalty: f32) -> Self {
        self.update_request(|req| req.presence_penalty = Some(presence_penalty))
    }

    /// Sets `n`.
    pub fn n(self, n: u32) -> Self {
        self.update_request(|req| req.n = Some(n))
    }

    /// Sets the `stop` sequences.
    pub fn stop(self, stop: Vec<String>) -> Self {
        self.update_request(|req| req.stop = Some(stop))
    }

    /// Sets the `stream` flag (always serialized, unlike the optional fields).
    pub fn stream(self, stream: bool) -> Self {
        self.update_request(|req| req.stream = stream)
    }

    /// Sets `logprobs`.
    pub fn logprobs(self, logprobs: bool) -> Self {
        self.update_request(|req| req.logprobs = Some(logprobs))
    }

    /// Sets `top_p`.
    pub fn top_p(self, top_p: f32) -> Self {
        self.update_request(|req| req.top_p = Some(top_p))
    }

    /// Sets `top_logprobs`.
    pub fn top_logprobs(self, top_logprobs: u32) -> Self {
        self.update_request(|req| req.top_logprobs = Some(top_logprobs))
    }

    /// Sets `seed`.
    pub fn seed(self, seed: u32) -> Self {
        self.update_request(|req| req.seed = Some(seed))
    }

    /// Sets `user`.
    pub fn user(self, user: String) -> Self {
        self.update_request(|req| req.user = Some(user))
    }

    /// Sets `logit_bias`.
    pub fn logit_bias(self, logit_bias: HashMap<u32, f32>) -> Self {
        self.update_request(|req| req.logit_bias = Some(logit_bias))
    }

    /// Sets `deferred`.
    pub fn deferred(self, deferred: bool) -> Self {
        self.update_request(|req| req.deferred = Some(deferred))
    }

    /// Sets `parallel_tool_calls`.
    pub fn parallel_tool_calls(self, parallel_tool_calls: bool) -> Self {
        self.update_request(|req| req.parallel_tool_calls = Some(parallel_tool_calls))
    }

    /// Sets `reasoning_effort` from anything convertible into a `String`.
    pub fn reasoning_effort(self, effort: impl Into<String>) -> Self {
        self.update_request(|req| req.reasoning_effort = Some(effort.into()))
    }

    /// Sets `response_format` (untyped JSON).
    pub fn response_format(self, response_format: serde_json::Value) -> Self {
        self.update_request(|req| req.response_format = Some(response_format))
    }

    /// Sets `search_parameters`.
    pub fn search_parameters(self, search_parameters: SearchParameters) -> Self {
        self.update_request(|req| req.search_parameters = Some(search_parameters))
    }

    /// Sets `stream_options`.
    pub fn stream_options(self, stream_options: StreamOptions) -> Self {
        self.update_request(|req| req.stream_options = Some(stream_options))
    }

    /// Sets `tool_choice` (untyped JSON).
    pub fn tool_choice(self, tool_choice: serde_json::Value) -> Self {
        self.update_request(|req| req.tool_choice = Some(tool_choice))
    }

    /// Sets `tools` (untyped JSON tool definitions).
    pub fn tools(self, tools: Vec<serde_json::Value>) -> Self {
        self.update_request(|req| req.tools = Some(tools))
    }

    /// Sets `web_search_options` (untyped JSON).
    pub fn web_search_options(self, web_search_options: serde_json::Value) -> Self {
        self.update_request(|req| req.web_search_options = Some(web_search_options))
    }

    /// Finishes the builder and returns the assembled request.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`: no validation is performed.
    pub fn build(self) -> Result<ChatCompletionRequest, XaiError> {
        let Self { request, .. } = self;
        Ok(request)
    }
}
impl<T> ChatCompletionsFetcher for ChatCompletionsRequestBuilder<T>
where
    T: ClientConfig + Clone + Send + Sync,
{
    /// POSTs `request` as JSON to `chat/completions` and decodes the reply.
    ///
    /// The `request` argument is what gets sent, not the builder's own pending
    /// request; callers typically obtain it from `build()`.
    ///
    /// # Errors
    ///
    /// - Propagates (via `?`) errors from preparing or sending the request and
    ///   from decoding a successful response body.
    /// - On a non-success status, returns the error from `check_for_model_error`
    ///   when it recognizes the body.
    /// - Otherwise returns `XaiError::Http`. It carries the body text, or a
    ///   message naming the HTTP status when the body is blank or unreadable.
    ///
    /// NOTE(review): the success path decodes a single JSON object, so a
    /// request with `stream: true` presumably fails to decode here — confirm.
    async fn create_chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, XaiError> {
        let response = self
            .client
            .request(reqwest::Method::POST, "chat/completions")?
            .json(&request)
            .send()
            .await?;

        // Captured before `text()` consumes the response, so the status is
        // still available for the error message below.
        let status = response.status();
        if status.is_success() {
            return Ok(response.json::<ChatCompletionResponse>().await?);
        }

        // Best effort: a body that cannot be read is treated as empty so the
        // caller still gets an HTTP error rather than a transport error.
        let error_body = response.text().await.unwrap_or_default();
        if let Some(model_error) = check_for_model_error(&error_body) {
            return Err(model_error);
        }
        if error_body.trim().is_empty() {
            // Previously this produced `XaiError::Http("")`, with no hint of
            // what went wrong; report the status instead.
            return Err(XaiError::Http(format!(
                "request failed with HTTP status {status}"
            )));
        }
        Err(XaiError::Http(error_body))
    }
}