1#[cfg(feature = "xai")]
7use crate::{
8 chat::{ChatMessage, ChatProvider, ChatResponse, ChatRole, StructuredOutputFormat, Tool, Usage},
9 completion::{CompletionProvider, CompletionRequest, CompletionResponse},
10 embedding::EmbeddingProvider,
11 error::LLMError,
12 models::ModelsProvider,
13 stt::SpeechToTextProvider,
14 tts::TextToSpeechProvider,
15 LLMProvider,
16};
17use crate::ToolCall;
18use async_trait::async_trait;
19use futures::stream::Stream;
20use reqwest::Client;
21use serde::{Deserialize, Serialize};
22
/// Client for the xAI (Grok) HTTP API implementing this crate's LLM provider traits.
///
/// Holds connection configuration plus optional generation and live-search
/// parameters that are forwarded on chat requests.
pub struct XAI {
    /// Bearer token for authentication; requests fail early with `AuthError` when empty.
    pub api_key: String,
    /// Model identifier sent with every request (used for both chat and embeddings).
    pub model: String,
    /// Optional cap on generated tokens per chat request.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional system prompt; prepended as the first chat message when set.
    pub system: Option<String>,
    /// Optional timeout in seconds, applied to the HTTP client and each request.
    pub timeout_seconds: Option<u64>,
    /// Optional nucleus-sampling parameter.
    pub top_p: Option<f32>,
    /// Optional top-k sampling parameter.
    pub top_k: Option<u32>,
    /// Optional embedding encoding format (defaults to "float" in `embed`).
    pub embedding_encoding_format: Option<String>,
    /// Optional dimensionality for returned embeddings.
    pub embedding_dimensions: Option<u32>,
    /// Optional structured-output schema; when set, `chat` requests JSON-schema responses.
    pub json_schema: Option<StructuredOutputFormat>,
    /// Optional live-search mode forwarded as `search_parameters.mode`.
    pub xai_search_mode: Option<String>,
    /// Optional search source type; `chat` defaults it to "web" when unset.
    pub xai_search_source_type: Option<String>,
    /// Optional websites to exclude from search results.
    pub xai_search_excluded_websites: Option<Vec<String>>,
    /// Optional cap on the number of search results.
    pub xai_search_max_results: Option<u32>,
    /// Optional start-date filter for search (string format per xAI API — TODO confirm).
    pub xai_search_from_date: Option<String>,
    /// Optional end-date filter for search (string format per xAI API — TODO confirm).
    pub xai_search_to_date: Option<String>,
    // Shared HTTP client, configured once in `new`.
    client: Client,
}
65
66#[derive(Debug, Clone, serde::Serialize)]
68pub struct XaiSearchSource {
69 #[serde(rename = "type")]
71 pub source_type: String,
72 pub excluded_websites: Option<Vec<String>>,
74}
75
76#[derive(Debug, Clone, Default, serde::Serialize)]
78pub struct XaiSearchParameters {
79 pub mode: Option<String>,
81 pub sources: Option<Vec<XaiSearchSource>>,
83 pub max_search_results: Option<u32>,
85 pub from_date: Option<String>,
87 pub to_date: Option<String>,
89}
90
/// Borrowed wire-format chat message ("role" + "content") for requests.
#[derive(Serialize)]
struct XAIChatMessage<'a> {
    // One of "system", "user", or "assistant".
    role: &'a str,
    content: &'a str,
}
99
/// Request body for `POST /v1/chat/completions`.
///
/// Optional fields are omitted from the JSON payload when `None`.
#[derive(Serialize)]
struct XAIChatRequest<'a> {
    model: &'a str,
    messages: Vec<XAIChatMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // true selects the SSE streaming response mode.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    // Structured-output request (JSON schema), only sent for non-streaming chat.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<XAIResponseFormat>,
    // Live-search configuration; borrowed so the caller keeps ownership.
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<&'a XaiSearchParameters>,
}
127
/// Response payload of a non-streaming chat completion.
#[derive(Deserialize, Debug)]
struct XAIChatResponse {
    choices: Vec<XAIChatChoice>,
    // Token accounting; may be absent from the response.
    usage: Option<Usage>,
}
136
137impl std::fmt::Display for XAIChatResponse {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 write!(f, "{}", self.text().unwrap_or_default())
140 }
141}
142
143impl ChatResponse for XAIChatResponse {
144 fn text(&self) -> Option<String> {
145 self.choices.first().map(|c| c.message.content.clone())
146 }
147
148 fn tool_calls(&self) -> Option<Vec<ToolCall>> {
149 None
150 }
151
152 fn usage(&self) -> Option<Usage> {
153 self.usage.clone()
154 }
155}
156
/// One completion choice of a non-streaming chat response.
#[derive(Deserialize, Debug)]
struct XAIChatChoice {
    message: XAIChatMsg,
}

/// The assistant message inside a chat choice.
#[derive(Deserialize, Debug)]
struct XAIChatMsg {
    content: String,
}
170
/// Request body for `POST /v1/embeddings`.
#[derive(Debug, Serialize)]
struct XAIEmbeddingRequest<'a> {
    model: &'a str,
    // Texts to embed, one embedding returned per entry.
    input: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
180
/// One embedding vector from the embeddings response.
#[derive(Deserialize)]
struct XAIEmbeddingData {
    embedding: Vec<f32>,
}
185
/// One parsed SSE event of a streaming chat response.
#[derive(Deserialize, Debug)]
struct XAIStreamResponse {
    choices: Vec<XAIStreamChoice>,
}

/// One choice within a streaming event.
#[derive(Deserialize, Debug)]
struct XAIStreamChoice {
    delta: XAIStreamDelta,
}

/// Incremental delta of a streamed choice.
#[derive(Deserialize, Debug)]
struct XAIStreamDelta {
    // Absent for events that carry no text (e.g. role-only deltas).
    content: Option<String>,
}

/// Response payload of `POST /v1/embeddings`.
#[derive(Deserialize)]
struct XAIEmbeddingResponse {
    data: Vec<XAIEmbeddingData>,
}
211
212#[derive(Deserialize, Debug, Serialize)]
213enum XAIResponseType {
214 #[serde(rename = "text")]
215 Text,
216 #[serde(rename = "json_schema")]
217 JsonSchema,
218 #[serde(rename = "json_object")]
219 JsonObject,
220}
221
/// The `response_format` object of a chat request (structured output).
#[derive(Deserialize, Debug, Serialize)]
struct XAIResponseFormat {
    #[serde(rename = "type")]
    response_type: XAIResponseType,
    // Only sent when `response_type` is `JsonSchema`.
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}
233
234impl XAI {
235 #[allow(clippy::too_many_arguments)]
255 pub fn new(
256 api_key: impl Into<String>,
257 model: Option<String>,
258 max_tokens: Option<u32>,
259 temperature: Option<f32>,
260 timeout_seconds: Option<u64>,
261 system: Option<String>,
262 top_p: Option<f32>,
263 top_k: Option<u32>,
264 embedding_encoding_format: Option<String>,
265 embedding_dimensions: Option<u32>,
266 json_schema: Option<StructuredOutputFormat>,
267 xai_search_mode: Option<String>,
268 xai_search_source_type: Option<String>,
269 xai_search_excluded_websites: Option<Vec<String>>,
270 xai_search_max_results: Option<u32>,
271 xai_search_from_date: Option<String>,
272 xai_search_to_date: Option<String>,
273 ) -> Self {
274 let mut builder = Client::builder();
275 if let Some(sec) = timeout_seconds {
276 builder = builder.timeout(std::time::Duration::from_secs(sec));
277 }
278 Self {
279 api_key: api_key.into(),
280 model: model.unwrap_or("grok-2-latest".to_string()),
281 max_tokens,
282 temperature,
283 system,
284 timeout_seconds,
285 top_p,
286 top_k,
287 embedding_encoding_format,
288 embedding_dimensions,
289 json_schema,
290 xai_search_mode,
291 xai_search_source_type,
292 xai_search_excluded_websites,
293 xai_search_max_results,
294 xai_search_from_date,
295 xai_search_to_date,
296 client: builder.build().expect("Failed to build reqwest Client"),
297 }
298 }
299}
300
301#[async_trait]
302impl ChatProvider for XAI {
303 async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
313 if self.api_key.is_empty() {
314 return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
315 }
316
317 let mut xai_msgs: Vec<XAIChatMessage> = messages
318 .iter()
319 .map(|m| XAIChatMessage {
320 role: match m.role {
321 ChatRole::User => "user",
322 ChatRole::Assistant => "assistant",
323 },
324 content: &m.content,
325 })
326 .collect();
327
328 if let Some(system) = &self.system {
329 xai_msgs.insert(
330 0,
331 XAIChatMessage {
332 role: "system",
333 content: system,
334 },
335 );
336 }
337
338 let response_format: Option<XAIResponseFormat> =
342 self.json_schema.as_ref().map(|s| XAIResponseFormat {
343 response_type: XAIResponseType::JsonSchema,
344 json_schema: Some(s.clone()),
345 });
346
347 let search_parameters = XaiSearchParameters {
348 mode: self.xai_search_mode.clone(),
349 sources: Some(vec![XaiSearchSource {
350 source_type: self
351 .xai_search_source_type
352 .clone()
353 .unwrap_or("web".to_string()),
354 excluded_websites: self.xai_search_excluded_websites.clone(),
355 }]),
356 max_search_results: self.xai_search_max_results,
357 from_date: self.xai_search_from_date.clone(),
358 to_date: self.xai_search_to_date.clone(),
359 };
360
361 let body = XAIChatRequest {
362 model: &self.model,
363 messages: xai_msgs,
364 max_tokens: self.max_tokens,
365 temperature: self.temperature,
366 stream: false,
367 top_p: self.top_p,
368 top_k: self.top_k,
369 response_format,
370 search_parameters: Some(&search_parameters),
371 };
372
373 if log::log_enabled!(log::Level::Trace) {
374 if let Ok(json) = serde_json::to_string(&body) {
375 log::trace!("XAI request payload: {}", json);
376 }
377 }
378
379 let mut request = self
380 .client
381 .post("https://api.x.ai/v1/chat/completions")
382 .bearer_auth(&self.api_key)
383 .json(&body);
384
385 if let Some(timeout) = self.timeout_seconds {
386 request = request.timeout(std::time::Duration::from_secs(timeout));
387 }
388
389 let resp = request.send().await?;
390
391 log::debug!("XAI HTTP status: {}", resp.status());
392
393 let resp = resp.error_for_status()?;
394
395 let json_resp: XAIChatResponse = resp.json().await?;
396 Ok(Box::new(json_resp))
397 }
398
399 async fn chat_with_tools(
410 &self,
411 messages: &[ChatMessage],
412 _tools: Option<&[Tool]>,
413 ) -> Result<Box<dyn ChatResponse>, LLMError> {
414 self.chat(messages).await
416 }
417
418 async fn chat_stream(
428 &self,
429 messages: &[ChatMessage],
430 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
431 {
432 if self.api_key.is_empty() {
433 return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
434 }
435
436 let mut xai_msgs: Vec<XAIChatMessage> = messages
437 .iter()
438 .map(|m| XAIChatMessage {
439 role: match m.role {
440 ChatRole::User => "user",
441 ChatRole::Assistant => "assistant",
442 },
443 content: &m.content,
444 })
445 .collect();
446
447 if let Some(system) = &self.system {
448 xai_msgs.insert(
449 0,
450 XAIChatMessage {
451 role: "system",
452 content: system,
453 },
454 );
455 }
456
457 let body = XAIChatRequest {
458 model: &self.model,
459 messages: xai_msgs,
460 max_tokens: self.max_tokens,
461 temperature: self.temperature,
462 stream: true,
463 top_p: self.top_p,
464 top_k: self.top_k,
465 response_format: None,
466 search_parameters: None,
467 };
468
469 let mut request = self
470 .client
471 .post("https://api.x.ai/v1/chat/completions")
472 .bearer_auth(&self.api_key)
473 .json(&body);
474
475 if let Some(timeout) = self.timeout_seconds {
476 request = request.timeout(std::time::Duration::from_secs(timeout));
477 }
478
479 let response = request.send().await?;
480
481 if !response.status().is_success() {
482 let status = response.status();
483 let error_text = response.text().await?;
484 return Err(LLMError::ResponseFormatError {
485 message: format!("X.AI API returned error status: {status}"),
486 raw_response: error_text,
487 });
488 }
489
490 Ok(crate::chat::create_sse_stream(
491 response,
492 parse_xai_sse_chunk,
493 ))
494 }
495}
496
497#[async_trait]
498impl CompletionProvider for XAI {
499 async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
511 Ok(CompletionResponse {
512 text: "X.AI completion not implemented.".into(),
513 })
514 }
515}
516
517#[async_trait]
518impl EmbeddingProvider for XAI {
519 async fn embed(&self, text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
520 if self.api_key.is_empty() {
521 return Err(LLMError::AuthError("Missing X.AI API key".into()));
522 }
523
524 let emb_format = self
525 .embedding_encoding_format
526 .clone()
527 .unwrap_or_else(|| "float".to_string());
528
529 let body = XAIEmbeddingRequest {
530 model: &self.model,
531 input: text,
532 encoding_format: Some(&emb_format),
533 dimensions: self.embedding_dimensions,
534 };
535
536 let resp = self
537 .client
538 .post("https://api.x.ai/v1/embeddings")
539 .bearer_auth(&self.api_key)
540 .json(&body)
541 .send()
542 .await?
543 .error_for_status()?;
544
545 let json_resp: XAIEmbeddingResponse = resp.json().await?;
546
547 let embeddings = json_resp.data.into_iter().map(|d| d.embedding).collect();
548 Ok(embeddings)
549 }
550}
551
552#[async_trait]
553impl SpeechToTextProvider for XAI {
554 async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
555 Err(LLMError::ProviderError(
556 "XAI does not implement speech to text endpoint yet.".into(),
557 ))
558 }
559}
560
// Empty impl: any text-to-speech behavior comes from the trait's default methods.
#[async_trait]
impl TextToSpeechProvider for XAI {}

// Empty impl: any model-listing behavior comes from the trait's default methods.
#[async_trait]
impl ModelsProvider for XAI {}

// Marker impl aggregating the capability traits into a usable provider.
impl LLMProvider for XAI {}
568
569fn parse_xai_sse_chunk(chunk: &str) -> Result<Option<String>, LLMError> {
581 for line in chunk.lines() {
582 let line = line.trim();
583
584 if let Some(data) = line.strip_prefix("data: ") {
585 if data == "[DONE]" {
586 return Ok(None);
587 }
588
589 match serde_json::from_str::<XAIStreamResponse>(data) {
590 Ok(response) => {
591 if let Some(choice) = response.choices.first() {
592 if let Some(content) = &choice.delta.content {
593 return Ok(Some(content.clone()));
594 }
595 }
596 return Ok(None);
597 }
598 Err(_) => continue,
599 }
600 }
601 }
602
603 Ok(None)
604}