llm_api_rs/providers/
xai.rs1use crate::core::client::APIClient;
6use crate::core::{ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage};
7use crate::error::LlmApiError;
8use async_trait::async_trait;
9use reqwest::header;
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Serialize)]
13struct XaiChatRequest {
14 messages: Vec<XaiMessage>,
15 model: String,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 temperature: Option<f32>,
18 #[serde(skip_serializing_if = "Option::is_none")]
19 max_tokens: Option<u32>,
20}
21
22#[derive(Debug, Serialize, Deserialize)]
23struct XaiMessage {
24 role: String,
25 content: Vec<XaiContent>,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
29struct XaiContent {
30 #[serde(rename = "type")]
31 content_type: String,
32 text: String,
33}
34
35#[derive(Debug, Deserialize)]
36struct XaiChatResponse {
37 id: String,
38 model: String,
39 choices: Vec<XaiChoice>,
40}
41
42#[derive(Debug, Deserialize)]
43struct XaiChoice {
44 message: XaiMessageResponse,
45 finish_reason: String,
46}
47
48#[derive(Debug, Serialize, Deserialize)]
49struct XaiMessageResponse {
50 role: String,
51 content: String,
52}
53
54pub struct XAI {
55 domain: String,
56 api_key: String,
57 client: APIClient,
58}
59
60impl XAI {
61 pub fn new(api_key: String) -> Self {
62 Self {
63 domain: "https://api.x.ai".to_string(),
64 api_key,
65 client: APIClient::new(),
66 }
67 }
68
69 fn convert_messages(messages: Vec<ChatMessage>) -> Vec<XaiMessage> {
70 messages
71 .into_iter()
72 .map(|msg| XaiMessage {
73 role: msg.role,
74 content: vec![XaiContent {
75 content_type: "text".to_string(),
76 text: msg.content,
77 }],
78 })
79 .collect()
80 }
81}
82
83#[async_trait]
84impl crate::providers::LlmProvider for XAI {
85 async fn chat_completion(
86 &self,
87 request: ChatCompletionRequest,
88 ) -> Result<ChatCompletionResponse, LlmApiError> {
89 let url = format!("{}/v1/chat/completions", self.domain);
90
91 let req = XaiChatRequest {
92 messages: Self::convert_messages(request.messages),
93 model: request.model,
94 temperature: request.temperature,
95 max_tokens: request.max_tokens,
96 };
97 let headers = vec![(header::AUTHORIZATION, format!("Bearer {}", self.api_key))];
98 let res: XaiChatResponse = self.client.send_request(url, headers, &req).await?;
99 Ok(ChatCompletionResponse {
100 id: res.id,
101 choices: res
102 .choices
103 .into_iter()
104 .map(|choice| ChatChoice {
105 message: ChatMessage {
106 role: choice.message.role,
107 content: choice.message.content,
108 },
109 finish_reason: choice.finish_reason,
110 })
111 .collect(),
112 model: res.model,
113 usage: None,
114 })
115 }
116}