1#[cfg(feature = "xai")]
7use crate::{
8 chat::{ChatMessage, ChatProvider, ChatRole, StructuredOutputFormat},
9 completion::{CompletionProvider, CompletionRequest, CompletionResponse},
10 embedding::EmbeddingProvider,
11 error::LLMError,
12 models::ModelsProvider,
13 stt::SpeechToTextProvider,
14 tts::TextToSpeechProvider,
15 LLMProvider,
16};
17use crate::{
18 chat::{ChatResponse, Tool},
19 ToolCall,
20};
21use async_trait::async_trait;
22use futures::stream::Stream;
23use reqwest::Client;
24use serde::{Deserialize, Serialize};
25
/// Client for the X.AI (Grok) OpenAI-compatible HTTP API.
///
/// Stores per-request configuration consumed by the provider trait impls
/// below (`ChatProvider`, `EmbeddingProvider`, streaming, ...).
pub struct XAI {
    /// Bearer token for `api.x.ai`; an empty key makes requests fail with `AuthError`.
    pub api_key: String,
    /// Model identifier sent with every request (defaults to "grok-2-latest" in `new`).
    pub model: String,
    /// Optional cap on generated tokens (`max_tokens` in the request body).
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional system prompt, prepended as the first chat message.
    pub system: Option<String>,
    /// Request timeout in seconds; applied to the shared client and again per request.
    pub timeout_seconds: Option<u64>,
    /// Value of the `stream` flag in non-streaming `chat` requests (streaming
    /// itself goes through `chat_stream`, which forces `true`).
    pub stream: Option<bool>,
    /// Optional nucleus-sampling parameter.
    pub top_p: Option<f32>,
    /// Optional top-k sampling parameter.
    pub top_k: Option<u32>,
    /// Embedding encoding format; `embed` defaults this to "float" when unset.
    pub embedding_encoding_format: Option<String>,
    /// Optional dimensionality requested for returned embeddings.
    pub embedding_dimensions: Option<u32>,
    /// When set, chat requests ask for structured (JSON-schema) output.
    pub json_schema: Option<StructuredOutputFormat>,
    /// Live-search mode, forwarded verbatim as `search_parameters.mode`.
    pub xai_search_mode: Option<String>,
    /// Live-search source `type` (falls back to "web" in `chat`).
    pub xai_search_source_type: Option<String>,
    /// Websites to exclude from live search.
    pub xai_search_excluded_websites: Option<Vec<String>>,
    /// Maximum number of live-search results.
    pub xai_search_max_results: Option<u32>,
    /// Search window start date, passed through as a raw string — presumably
    /// ISO 8601; confirm against the X.AI API docs.
    pub xai_search_from_date: Option<String>,
    /// Search window end date (same format as `xai_search_from_date`).
    pub xai_search_to_date: Option<String>,
    /// Shared HTTP client, built once in `new`.
    client: Client,
}
70
71#[derive(Debug, Clone, serde::Serialize)]
73pub struct XaiSearchSource {
74 #[serde(rename = "type")]
76 pub source_type: String,
77 pub excluded_websites: Option<Vec<String>>,
79}
80
81#[derive(Debug, Clone, Default, serde::Serialize)]
83pub struct XaiSearchParameters {
84 pub mode: Option<String>,
86 pub sources: Option<Vec<XaiSearchSource>>,
88 pub max_search_results: Option<u32>,
90 pub from_date: Option<String>,
92 pub to_date: Option<String>,
94}
95
/// Borrowed chat message in the wire format X.AI expects
/// (`{"role": ..., "content": ...}`).
#[derive(Serialize)]
struct XAIChatMessage<'a> {
    /// One of "system", "user" or "assistant".
    role: &'a str,
    content: &'a str,
}
104
/// Request body for `POST /v1/chat/completions`.
///
/// Borrows from the `XAI` client config to avoid copies; `None` options are
/// omitted from the JSON via `skip_serializing_if`.
#[derive(Serialize)]
struct XAIChatRequest<'a> {
    model: &'a str,
    messages: Vec<XAIChatMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Always serialized, even when false.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    /// Present only when structured (JSON-schema) output was requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<XAIResponseFormat>,
    /// Live-search configuration; `None` for streaming requests.
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<&'a XaiSearchParameters>,
}
132
/// Top-level chat completion response; only `choices` is deserialized.
#[derive(Deserialize, Debug)]
struct XAIChatResponse {
    choices: Vec<XAIChatChoice>,
}
139
140impl std::fmt::Display for XAIChatResponse {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 write!(f, "{}", self.text().unwrap_or_default())
143 }
144}
145
146impl ChatResponse for XAIChatResponse {
147 fn text(&self) -> Option<String> {
148 self.choices.first().map(|c| c.message.content.clone())
149 }
150
151 fn tool_calls(&self) -> Option<Vec<ToolCall>> {
152 None
153 }
154}
155
/// One entry of the response `choices` array.
#[derive(Deserialize, Debug)]
struct XAIChatChoice {
    message: XAIChatMsg,
}

/// The assistant message inside a choice; only `content` is kept.
#[derive(Deserialize, Debug)]
struct XAIChatMsg {
    content: String,
}
169
/// Request body for `POST /v1/embeddings`.
#[derive(Debug, Serialize)]
struct XAIEmbeddingRequest<'a> {
    model: &'a str,
    input: Vec<String>,
    /// e.g. "float"; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
179
/// One embedding vector in the embeddings response.
#[derive(Deserialize)]
struct XAIEmbeddingData {
    embedding: Vec<f32>,
}

/// One decoded SSE chunk of a streaming chat completion.
#[derive(Deserialize, Debug)]
struct XAIStreamResponse {
    choices: Vec<XAIStreamChoice>,
}

/// A streamed choice; carries an incremental `delta`.
#[derive(Deserialize, Debug)]
struct XAIStreamChoice {
    delta: XAIStreamDelta,
}

/// Incremental message text; `None` when a chunk carries no content.
#[derive(Deserialize, Debug)]
struct XAIStreamDelta {
    content: Option<String>,
}

/// Top-level embeddings response; one `XAIEmbeddingData` per input string.
#[derive(Deserialize)]
struct XAIEmbeddingResponse {
    data: Vec<XAIEmbeddingData>,
}
210
/// The `type` discriminator of a chat `response_format`.
#[derive(Deserialize, Debug, Serialize)]
enum XAIResponseType {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_schema")]
    JsonSchema,
    #[serde(rename = "json_object")]
    JsonObject,
}

/// `response_format` object attached to chat requests when structured
/// output is requested (only `JsonSchema` is produced by `chat`).
#[derive(Deserialize, Debug, Serialize)]
struct XAIResponseFormat {
    #[serde(rename = "type")]
    response_type: XAIResponseType,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}
232
233impl XAI {
234 #[allow(clippy::too_many_arguments)]
254 pub fn new(
255 api_key: impl Into<String>,
256 model: Option<String>,
257 max_tokens: Option<u32>,
258 temperature: Option<f32>,
259 timeout_seconds: Option<u64>,
260 system: Option<String>,
261 stream: Option<bool>,
262 top_p: Option<f32>,
263 top_k: Option<u32>,
264 embedding_encoding_format: Option<String>,
265 embedding_dimensions: Option<u32>,
266 json_schema: Option<StructuredOutputFormat>,
267 xai_search_mode: Option<String>,
268 xai_search_source_type: Option<String>,
269 xai_search_excluded_websites: Option<Vec<String>>,
270 xai_search_max_results: Option<u32>,
271 xai_search_from_date: Option<String>,
272 xai_search_to_date: Option<String>,
273 ) -> Self {
274 let mut builder = Client::builder();
275 if let Some(sec) = timeout_seconds {
276 builder = builder.timeout(std::time::Duration::from_secs(sec));
277 }
278 Self {
279 api_key: api_key.into(),
280 model: model.unwrap_or("grok-2-latest".to_string()),
281 max_tokens,
282 temperature,
283 system,
284 timeout_seconds,
285 stream,
286 top_p,
287 top_k,
288 embedding_encoding_format,
289 embedding_dimensions,
290 json_schema,
291 xai_search_mode,
292 xai_search_source_type,
293 xai_search_excluded_websites,
294 xai_search_max_results,
295 xai_search_from_date,
296 xai_search_to_date,
297 client: builder.build().expect("Failed to build reqwest Client"),
298 }
299 }
300}
301
302#[async_trait]
303impl ChatProvider for XAI {
304 async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
314 if self.api_key.is_empty() {
315 return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
316 }
317
318 let mut xai_msgs: Vec<XAIChatMessage> = messages
319 .iter()
320 .map(|m| XAIChatMessage {
321 role: match m.role {
322 ChatRole::User => "user",
323 ChatRole::Assistant => "assistant",
324 },
325 content: &m.content,
326 })
327 .collect();
328
329 if let Some(system) = &self.system {
330 xai_msgs.insert(
331 0,
332 XAIChatMessage {
333 role: "system",
334 content: system,
335 },
336 );
337 }
338
339 let response_format: Option<XAIResponseFormat> =
343 self.json_schema.as_ref().map(|s| XAIResponseFormat {
344 response_type: XAIResponseType::JsonSchema,
345 json_schema: Some(s.clone()),
346 });
347
348 let search_parameters = XaiSearchParameters {
349 mode: self.xai_search_mode.clone(),
350 sources: Some(vec![XaiSearchSource {
351 source_type: self.xai_search_source_type.clone().unwrap_or("web".to_string()),
352 excluded_websites: self.xai_search_excluded_websites.clone(),
353 }]),
354 max_search_results: self.xai_search_max_results.clone(),
355 from_date: self.xai_search_from_date.clone(),
356 to_date: self.xai_search_to_date.clone(),
357 };
358
359 let body = XAIChatRequest {
360 model: &self.model,
361 messages: xai_msgs,
362 max_tokens: self.max_tokens,
363 temperature: self.temperature,
364 stream: self.stream.unwrap_or(false),
365 top_p: self.top_p,
366 top_k: self.top_k,
367 response_format,
368 search_parameters: Some(&search_parameters),
369 };
370
371 if log::log_enabled!(log::Level::Trace) {
372 if let Ok(json) = serde_json::to_string(&body) {
373 log::trace!("XAI request payload: {}", json);
374 }
375 }
376
377 let mut request = self
378 .client
379 .post("https://api.x.ai/v1/chat/completions")
380 .bearer_auth(&self.api_key)
381 .json(&body);
382
383 if let Some(timeout) = self.timeout_seconds {
384 request = request.timeout(std::time::Duration::from_secs(timeout));
385 }
386
387 let resp = request.send().await?;
388
389 log::debug!("XAI HTTP status: {}", resp.status());
390
391 let resp = resp.error_for_status()?;
392
393 let json_resp: XAIChatResponse = resp.json().await?;
394 Ok(Box::new(json_resp))
395 }
396
397 async fn chat_with_tools(
408 &self,
409 messages: &[ChatMessage],
410 _tools: Option<&[Tool]>,
411 ) -> Result<Box<dyn ChatResponse>, LLMError> {
412 self.chat(messages).await
414 }
415
416 async fn chat_stream(
426 &self,
427 messages: &[ChatMessage],
428 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError> {
429 if self.api_key.is_empty() {
430 return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
431 }
432
433 let mut xai_msgs: Vec<XAIChatMessage> = messages
434 .iter()
435 .map(|m| XAIChatMessage {
436 role: match m.role {
437 ChatRole::User => "user",
438 ChatRole::Assistant => "assistant",
439 },
440 content: &m.content,
441 })
442 .collect();
443
444 if let Some(system) = &self.system {
445 xai_msgs.insert(
446 0,
447 XAIChatMessage {
448 role: "system",
449 content: system,
450 },
451 );
452 }
453
454 let body = XAIChatRequest {
455 model: &self.model,
456 messages: xai_msgs,
457 max_tokens: self.max_tokens,
458 temperature: self.temperature,
459 stream: true,
460 top_p: self.top_p,
461 top_k: self.top_k,
462 response_format: None,
463 search_parameters: None,
464 };
465
466 let mut request = self
467 .client
468 .post("https://api.x.ai/v1/chat/completions")
469 .bearer_auth(&self.api_key)
470 .json(&body);
471
472 if let Some(timeout) = self.timeout_seconds {
473 request = request.timeout(std::time::Duration::from_secs(timeout));
474 }
475
476 let response = request.send().await?;
477
478 if !response.status().is_success() {
479 let status = response.status();
480 let error_text = response.text().await?;
481 return Err(LLMError::ResponseFormatError {
482 message: format!("X.AI API returned error status: {}", status),
483 raw_response: error_text,
484 });
485 }
486
487 Ok(crate::chat::create_sse_stream(response, parse_xai_sse_chunk))
488 }
489}
490
491#[async_trait]
492impl CompletionProvider for XAI {
493 async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
505 Ok(CompletionResponse {
506 text: "X.AI completion not implemented.".into(),
507 })
508 }
509}
510
511#[async_trait]
512impl EmbeddingProvider for XAI {
513 async fn embed(&self, text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
514 if self.api_key.is_empty() {
515 return Err(LLMError::AuthError("Missing X.AI API key".into()));
516 }
517
518 let emb_format = self
519 .embedding_encoding_format
520 .clone()
521 .unwrap_or_else(|| "float".to_string());
522
523 let body = XAIEmbeddingRequest {
524 model: &self.model,
525 input: text,
526 encoding_format: Some(&emb_format),
527 dimensions: self.embedding_dimensions,
528 };
529
530 let resp = self
531 .client
532 .post("https://api.x.ai/v1/embeddings")
533 .bearer_auth(&self.api_key)
534 .json(&body)
535 .send()
536 .await?
537 .error_for_status()?;
538
539 let json_resp: XAIEmbeddingResponse = resp.json().await?;
540
541 let embeddings = json_resp.data.into_iter().map(|d| d.embedding).collect();
542 Ok(embeddings)
543 }
544}
545
546#[async_trait]
547impl SpeechToTextProvider for XAI {
548 async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
549 Err(LLMError::ProviderError(
550 "XAI does not implement speech to text endpoint yet.".into(),
551 ))
552 }
553}
554
// Marker impls: these rely entirely on the traits' default method
// implementations (defined elsewhere in the crate).
#[async_trait]
impl TextToSpeechProvider for XAI {}

#[async_trait]
impl ModelsProvider for XAI {}

// NOTE(review): `LLMProvider` appears to be the umbrella trait combining the
// capability traits implemented above — confirm in the crate root.
impl LLMProvider for XAI {}
562
563fn parse_xai_sse_chunk(chunk: &str) -> Result<Option<String>, LLMError> {
575 for line in chunk.lines() {
576 let line = line.trim();
577
578 if line.starts_with("data: ") {
579 let data = &line[6..];
580
581 if data == "[DONE]" {
582 return Ok(None);
583 }
584
585 match serde_json::from_str::<XAIStreamResponse>(data) {
586 Ok(response) => {
587 if let Some(choice) = response.choices.first() {
588 if let Some(content) = &choice.delta.content {
589 return Ok(Some(content.clone()));
590 }
591 }
592 return Ok(None);
593 }
594 Err(_) => continue,
595 }
596 }
597 }
598
599 Ok(None)
600}