llm_api_rs/providers/
xai.rs

// xAI API provider
// https://docs.x.ai/docs/guides/chat
// https://console.x.ai/
4
5use crate::core::client::APIClient;
6use crate::core::{ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage};
7use crate::error::LlmApiError;
8use async_trait::async_trait;
9use reqwest::header;
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Serialize)]
13struct XaiChatRequest {
14    messages: Vec<XaiMessage>,
15    model: String,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    temperature: Option<f32>,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    max_tokens: Option<u32>,
20}
21
22#[derive(Debug, Serialize, Deserialize)]
23struct XaiMessage {
24    role: String,
25    content: Vec<XaiContent>,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
29struct XaiContent {
30    #[serde(rename = "type")]
31    content_type: String,
32    text: String,
33}
34
35#[derive(Debug, Deserialize)]
36struct XaiChatResponse {
37    id: String,
38    model: String,
39    choices: Vec<XaiChoice>,
40}
41
42#[derive(Debug, Deserialize)]
43struct XaiChoice {
44    message: XaiMessageResponse,
45    finish_reason: String,
46}
47
48#[derive(Debug, Serialize, Deserialize)]
49struct XaiMessageResponse {
50    role: String,
51    content: String,
52}
53
54pub struct XAI {
55    domain: String,
56    api_key: String,
57    client: APIClient,
58}
59
60impl XAI {
61    pub fn new(api_key: String) -> Self {
62        Self {
63            domain: "https://api.x.ai".to_string(),
64            api_key,
65            client: APIClient::new(),
66        }
67    }
68
69    fn convert_messages(messages: Vec<ChatMessage>) -> Vec<XaiMessage> {
70        messages
71            .into_iter()
72            .map(|msg| XaiMessage {
73                role: msg.role,
74                content: vec![XaiContent {
75                    content_type: "text".to_string(),
76                    text: msg.content,
77                }],
78            })
79            .collect()
80    }
81}
82
83#[async_trait]
84impl crate::providers::LlmProvider for XAI {
85    async fn chat_completion(
86        &self,
87        request: ChatCompletionRequest,
88    ) -> Result<ChatCompletionResponse, LlmApiError> {
89        let url = format!("{}/v1/chat/completions", self.domain);
90
91        let req = XaiChatRequest {
92            messages: Self::convert_messages(request.messages),
93            model: request.model,
94            temperature: request.temperature,
95            max_tokens: request.max_tokens,
96        };
97        let headers = vec![(header::AUTHORIZATION, format!("Bearer {}", self.api_key))];
98        let res: XaiChatResponse = self.client.send_request(url, headers, &req).await?;
99        Ok(ChatCompletionResponse {
100            id: res.id,
101            choices: res
102                .choices
103                .into_iter()
104                .map(|choice| ChatChoice {
105                    message: ChatMessage {
106                        role: choice.message.role,
107                        content: choice.message.content,
108                    },
109                    finish_reason: choice.finish_reason,
110                })
111                .collect(),
112            model: res.model,
113            usage: None,
114        })
115    }
116}