llm/backends/xai.rs

1//! X.AI API client implementation for chat and completion functionality.
2//!
3//! This module provides integration with X.AI's models through their API.
4//! It implements chat and completion capabilities using the X.AI API endpoints.
5
6#[cfg(feature = "xai")]
7use crate::{
8    chat::{ChatMessage, ChatProvider, ChatRole, StructuredOutputFormat},
9    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
10    embedding::EmbeddingProvider,
11    error::LLMError,
12    models::ModelsProvider,
13    stt::SpeechToTextProvider,
14    tts::TextToSpeechProvider,
15    LLMProvider,
16};
17use crate::{
18    chat::{ChatResponse, Tool},
19    ToolCall,
20};
21use async_trait::async_trait;
22use futures::stream::Stream;
23use reqwest::Client;
24use serde::{Deserialize, Serialize};
25
/// Client for interacting with X.AI's API.
///
/// This struct provides methods for making chat and completion requests to X.AI's language models.
/// It handles authentication, request configuration, and response parsing.
pub struct XAI {
    /// API key for authentication with X.AI services
    pub api_key: String,
    /// Model identifier to use for requests (e.g. "grok-2-latest");
    /// also reused as the model name for embedding requests
    pub model: String,
    /// Maximum number of tokens to generate in responses
    pub max_tokens: Option<u32>,
    /// Temperature parameter for controlling response randomness (0.0 to 1.0)
    pub temperature: Option<f32>,
    /// Optional system prompt, inserted as the first message of every conversation
    pub system: Option<String>,
    /// Request timeout duration in seconds (applied to the HTTP client and each request)
    pub timeout_seconds: Option<u64>,
    /// Whether to enable streaming responses (defaults to false when None)
    pub stream: Option<bool>,
    /// Top-p sampling parameter for controlling response diversity
    pub top_p: Option<f32>,
    /// Top-k sampling parameter for controlling response diversity
    pub top_k: Option<u32>,
    /// Embedding encoding format sent to the embeddings endpoint (defaults to "float")
    pub embedding_encoding_format: Option<String>,
    /// Requested embedding dimensions, forwarded as-is to the embeddings endpoint
    pub embedding_dimensions: Option<u32>,
    /// JSON schema for structured output; when set, requests use the
    /// "json_schema" response format
    pub json_schema: Option<StructuredOutputFormat>,
    /// XAI search mode (e.g. "auto") — assumes X.AI live-search semantics; TODO confirm
    pub xai_search_mode: Option<String>,
    /// XAI search source type ("web" or "news"); defaults to "web" when unset
    pub xai_search_source_type: Option<String>,
    /// Websites excluded from the configured search source
    pub xai_search_excluded_websites: Option<Vec<String>>,
    /// Maximum number of search results to request
    pub xai_search_max_results: Option<u32>,
    /// Start date for search results (format: "YYYY-MM-DD")
    pub xai_search_from_date: Option<String>,
    /// End date for search results (format: "YYYY-MM-DD")
    pub xai_search_to_date: Option<String>,
    /// HTTP client for making API requests
    client: Client,
}
70
71/// Search source configuration for search parameters
72#[derive(Debug, Clone, serde::Serialize)]
73pub struct XaiSearchSource {
74    /// Type of source: "web" or "news"
75    #[serde(rename = "type")]
76    pub source_type: String,
77    /// List of websites to exclude from this source
78    pub excluded_websites: Option<Vec<String>>,
79}
80
81/// Search parameters for LLM providers that support search functionality
82#[derive(Debug, Clone, Default, serde::Serialize)]
83pub struct XaiSearchParameters {
84    /// Search mode (e.g., "auto")
85    pub mode: Option<String>,
86    /// List of search sources with exclusions
87    pub sources: Option<Vec<XaiSearchSource>>,
88    /// Maximum number of search results to return
89    pub max_search_results: Option<u32>,
90    /// Start date for search results (format: "YYYY-MM-DD")
91    pub from_date: Option<String>,
92    /// End date for search results (format: "YYYY-MM-DD")
93    pub to_date: Option<String>,
94}
95
/// Individual message in an X.AI chat conversation.
///
/// Borrows both role and content from the caller to avoid allocation
/// while building the request payload.
#[derive(Serialize)]
struct XAIChatMessage<'a> {
    /// Role of the message sender ("user", "assistant", or "system")
    role: &'a str,
    /// Content of the message
    content: &'a str,
}
104
/// Request payload for X.AI's chat API endpoint.
///
/// `None` optionals are omitted from the serialized JSON via
/// `skip_serializing_if`, so only configured parameters reach the API.
#[derive(Serialize)]
struct XAIChatRequest<'a> {
    /// Model identifier to use
    model: &'a str,
    /// Array of conversation messages (system prompt first, if any)
    messages: Vec<XAIChatMessage<'a>>,
    /// Maximum tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Temperature parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Whether to stream the response (always serialized, even when false)
    stream: bool,
    /// Top-p sampling parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Top-k sampling parameter
    /// NOTE(review): OpenAI-compatible chat endpoints typically do not accept
    /// `top_k` — confirm X.AI honors it rather than rejecting the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    /// Structured-output format specification (e.g. a JSON schema)
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<XAIResponseFormat>,
    /// Search parameters for search functionality
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<&'a XaiSearchParameters>,
}
132
/// Response from X.AI's chat API endpoint.
#[derive(Deserialize, Debug)]
struct XAIChatResponse {
    /// Array of generated responses; only the first choice is read by
    /// `ChatResponse::text`
    choices: Vec<XAIChatChoice>,
}
139
140impl std::fmt::Display for XAIChatResponse {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        write!(f, "{}", self.text().unwrap_or_default())
143    }
144}
145
146impl ChatResponse for XAIChatResponse {
147    fn text(&self) -> Option<String> {
148        self.choices.first().map(|c| c.message.content.clone())
149    }
150
151    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
152        None
153    }
154}
155
/// Individual response choice from the chat API.
#[derive(Deserialize, Debug)]
struct XAIChatChoice {
    /// Message content and metadata
    message: XAIChatMsg,
}
162
/// Message content from a chat response.
#[derive(Deserialize, Debug)]
struct XAIChatMsg {
    /// Generated text content
    content: String,
}
169
/// Request payload for X.AI's embeddings endpoint.
#[derive(Debug, Serialize)]
struct XAIEmbeddingRequest<'a> {
    /// Model identifier used to produce the embeddings
    model: &'a str,
    /// Texts to embed, one embedding produced per entry
    input: Vec<String>,
    /// Encoding format for the returned vectors (e.g. "float"); omitted when None
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<&'a str>,
    /// Requested embedding dimensionality; omitted when None
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
179
/// A single embedding vector entry in the embeddings response.
#[derive(Deserialize)]
struct XAIEmbeddingData {
    /// The embedding vector for one input text
    embedding: Vec<f32>,
}
184
/// Response from X.AI's streaming chat API endpoint (one SSE event).
#[derive(Deserialize, Debug)]
struct XAIStreamResponse {
    /// Array of generated responses; only the first choice's delta is read
    choices: Vec<XAIStreamChoice>,
}
191
/// Individual response choice from the streaming chat API.
#[derive(Deserialize, Debug)]
struct XAIStreamChoice {
    /// Incremental delta content for this event
    delta: XAIStreamDelta,
}
198
/// Delta content from a streaming chat response.
#[derive(Deserialize, Debug)]
struct XAIStreamDelta {
    /// Generated text fragment; None for events carrying no new content
    /// (e.g. role-only or finish events)
    content: Option<String>,
}
205
/// Top-level response from X.AI's embeddings endpoint.
#[derive(Deserialize)]
struct XAIEmbeddingResponse {
    /// One entry per input text, in input order
    data: Vec<XAIEmbeddingData>,
}
210
211#[derive(Deserialize, Debug, Serialize)]
212enum XAIResponseType {
213    #[serde(rename = "text")]
214    Text,
215    #[serde(rename = "json_schema")]
216    JsonSchema,
217    #[serde(rename = "json_object")]
218    JsonObject,
219}
220
/// An object specifying the format that the model must output.
/// Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which ensures the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
/// Setting to `{ "type": "json_object" }` enables the older JSON mode, which ensures the message the model generates is valid JSON. Using `json_schema` is preferred for models that support it.
/// The structured outputs feature is only supported for the `grok-2-latest` model.
#[derive(Deserialize, Debug, Serialize)]
struct XAIResponseFormat {
    /// Discriminator serialized under the `type` key
    #[serde(rename = "type")]
    response_type: XAIResponseType,
    /// Schema payload; only meaningful when `response_type` is `JsonSchema`
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}
232
233impl XAI {
234    /// Creates a new X.AI client with the specified configuration.
235    ///
236    /// # Arguments
237    ///
238    /// * `api_key` - Authentication key for X.AI API access
239    /// * `model` - Model identifier (defaults to "grok-2-latest" if None)
240    /// * `max_tokens` - Maximum number of tokens to generate in responses
241    /// * `temperature` - Sampling temperature for controlling randomness
242    /// * `timeout_seconds` - Request timeout duration in seconds
243    /// * `system` - System prompt for providing context
244    /// * `stream` - Whether to enable streaming responses
245    /// * `top_p` - Top-p sampling parameter
246    /// * `top_k` - Top-k sampling parameter
247    /// * `json_schema` - JSON schema for structured output
248    /// * `search_parameters` - Search parameters for search functionality
249    ///
250    /// # Returns
251    ///
252    /// A configured X.AI client instance ready to make API requests.
253    #[allow(clippy::too_many_arguments)]
254    pub fn new(
255        api_key: impl Into<String>,
256        model: Option<String>,
257        max_tokens: Option<u32>,
258        temperature: Option<f32>,
259        timeout_seconds: Option<u64>,
260        system: Option<String>,
261        stream: Option<bool>,
262        top_p: Option<f32>,
263        top_k: Option<u32>,
264        embedding_encoding_format: Option<String>,
265        embedding_dimensions: Option<u32>,
266        json_schema: Option<StructuredOutputFormat>,
267        xai_search_mode: Option<String>,
268        xai_search_source_type: Option<String>,
269        xai_search_excluded_websites: Option<Vec<String>>,
270        xai_search_max_results: Option<u32>,
271        xai_search_from_date: Option<String>,
272        xai_search_to_date: Option<String>,
273    ) -> Self {
274        let mut builder = Client::builder();
275        if let Some(sec) = timeout_seconds {
276            builder = builder.timeout(std::time::Duration::from_secs(sec));
277        }
278        Self {
279            api_key: api_key.into(),
280            model: model.unwrap_or("grok-2-latest".to_string()),
281            max_tokens,
282            temperature,
283            system,
284            timeout_seconds,
285            stream,
286            top_p,
287            top_k,
288            embedding_encoding_format,
289            embedding_dimensions,
290            json_schema,
291            xai_search_mode,
292            xai_search_source_type,
293            xai_search_excluded_websites,
294            xai_search_max_results,
295            xai_search_from_date,
296            xai_search_to_date,
297            client: builder.build().expect("Failed to build reqwest Client"),
298        }
299    }
300}
301
302#[async_trait]
303impl ChatProvider for XAI {
304    /// Sends a chat request to the X.AI API and returns the response.
305    ///
306    /// # Arguments
307    ///
308    /// * `messages` - Array of chat messages representing the conversation
309    ///
310    /// # Returns
311    ///
312    /// The generated response text, or an error if the request fails.
313    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
314        if self.api_key.is_empty() {
315            return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
316        }
317
318        let mut xai_msgs: Vec<XAIChatMessage> = messages
319            .iter()
320            .map(|m| XAIChatMessage {
321                role: match m.role {
322                    ChatRole::User => "user",
323                    ChatRole::Assistant => "assistant",
324                },
325                content: &m.content,
326            })
327            .collect();
328
329        if let Some(system) = &self.system {
330            xai_msgs.insert(
331                0,
332                XAIChatMessage {
333                    role: "system",
334                    content: system,
335                },
336            );
337        }
338
339        // OpenAI's structured output has some [odd requirements](https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat&lang=curl#supported-schemas).
340        // There's currently no check for these, so we'll leave it up to the user to provide a valid schema.
341        // Unknown if XAI requires these too, but since it copies everything else from OpenAI, it's likely.
342        let response_format: Option<XAIResponseFormat> =
343            self.json_schema.as_ref().map(|s| XAIResponseFormat {
344                response_type: XAIResponseType::JsonSchema,
345                json_schema: Some(s.clone()),
346            });
347
348        let search_parameters = XaiSearchParameters {
349            mode: self.xai_search_mode.clone(),
350            sources: Some(vec![XaiSearchSource {
351                source_type: self.xai_search_source_type.clone().unwrap_or("web".to_string()),
352                excluded_websites: self.xai_search_excluded_websites.clone(),
353            }]),
354            max_search_results: self.xai_search_max_results.clone(),
355            from_date: self.xai_search_from_date.clone(),
356            to_date: self.xai_search_to_date.clone(),
357        };
358        
359        let body = XAIChatRequest {
360            model: &self.model,
361            messages: xai_msgs,
362            max_tokens: self.max_tokens,
363            temperature: self.temperature,
364            stream: self.stream.unwrap_or(false),
365            top_p: self.top_p,
366            top_k: self.top_k,
367            response_format,
368            search_parameters: Some(&search_parameters),
369        };
370
371        if log::log_enabled!(log::Level::Trace) {
372            if let Ok(json) = serde_json::to_string(&body) {
373                log::trace!("XAI request payload: {}", json);
374            }
375        }
376
377        let mut request = self
378            .client
379            .post("https://api.x.ai/v1/chat/completions")
380            .bearer_auth(&self.api_key)
381            .json(&body);
382
383        if let Some(timeout) = self.timeout_seconds {
384            request = request.timeout(std::time::Duration::from_secs(timeout));
385        }
386
387        let resp = request.send().await?;
388
389        log::debug!("XAI HTTP status: {}", resp.status());
390
391        let resp = resp.error_for_status()?;
392
393        let json_resp: XAIChatResponse = resp.json().await?;
394        Ok(Box::new(json_resp))
395    }
396
397    /// Sends a chat request to X.AI's API with tools.
398    ///
399    /// # Arguments
400    ///
401    /// * `messages` - The conversation history as a slice of chat messages
402    /// * `tools` - Optional slice of tools to use in the chat
403    ///
404    /// # Returns
405    ///
406    /// The provider's response text or an error
407    async fn chat_with_tools(
408        &self,
409        messages: &[ChatMessage],
410        _tools: Option<&[Tool]>,
411    ) -> Result<Box<dyn ChatResponse>, LLMError> {
412        // XAI doesn't support tools yet, fall back to regular chat
413        self.chat(messages).await
414    }
415
416    /// Sends a streaming chat request to X.AI's API.
417    ///
418    /// # Arguments
419    ///
420    /// * `messages` - Slice of chat messages representing the conversation
421    ///
422    /// # Returns
423    ///
424    /// A stream of text tokens or an error
425    async fn chat_stream(
426        &self,
427        messages: &[ChatMessage],
428    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError> {
429        if self.api_key.is_empty() {
430            return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
431        }
432
433        let mut xai_msgs: Vec<XAIChatMessage> = messages
434            .iter()
435            .map(|m| XAIChatMessage {
436                role: match m.role {
437                    ChatRole::User => "user",
438                    ChatRole::Assistant => "assistant",
439                },
440                content: &m.content,
441            })
442            .collect();
443
444        if let Some(system) = &self.system {
445            xai_msgs.insert(
446                0,
447                XAIChatMessage {
448                    role: "system",
449                    content: system,
450                },
451            );
452        }
453
454        let body = XAIChatRequest {
455            model: &self.model,
456            messages: xai_msgs,
457            max_tokens: self.max_tokens,
458            temperature: self.temperature,
459            stream: true,
460            top_p: self.top_p,
461            top_k: self.top_k,
462            response_format: None,
463            search_parameters: None,
464        };
465
466        let mut request = self
467            .client
468            .post("https://api.x.ai/v1/chat/completions")
469            .bearer_auth(&self.api_key)
470            .json(&body);
471
472        if let Some(timeout) = self.timeout_seconds {
473            request = request.timeout(std::time::Duration::from_secs(timeout));
474        }
475
476        let response = request.send().await?;
477
478        if !response.status().is_success() {
479            let status = response.status();
480            let error_text = response.text().await?;
481            return Err(LLMError::ResponseFormatError {
482                message: format!("X.AI API returned error status: {}", status),
483                raw_response: error_text,
484            });
485        }
486
487        Ok(crate::chat::create_sse_stream(response, parse_xai_sse_chunk))
488    }
489}
490
491#[async_trait]
492impl CompletionProvider for XAI {
493    /// Sends a completion request to X.AI's API.
494    ///
495    /// This functionality is currently not implemented.
496    ///
497    /// # Arguments
498    ///
499    /// * `_req` - The completion request parameters
500    ///
501    /// # Returns
502    ///
503    /// A placeholder response indicating the functionality is not implemented.
504    async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
505        Ok(CompletionResponse {
506            text: "X.AI completion not implemented.".into(),
507        })
508    }
509}
510
511#[async_trait]
512impl EmbeddingProvider for XAI {
513    async fn embed(&self, text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
514        if self.api_key.is_empty() {
515            return Err(LLMError::AuthError("Missing X.AI API key".into()));
516        }
517
518        let emb_format = self
519            .embedding_encoding_format
520            .clone()
521            .unwrap_or_else(|| "float".to_string());
522
523        let body = XAIEmbeddingRequest {
524            model: &self.model,
525            input: text,
526            encoding_format: Some(&emb_format),
527            dimensions: self.embedding_dimensions,
528        };
529
530        let resp = self
531            .client
532            .post("https://api.x.ai/v1/embeddings")
533            .bearer_auth(&self.api_key)
534            .json(&body)
535            .send()
536            .await?
537            .error_for_status()?;
538
539        let json_resp: XAIEmbeddingResponse = resp.json().await?;
540
541        let embeddings = json_resp.data.into_iter().map(|d| d.embedding).collect();
542        Ok(embeddings)
543    }
544}
545
546#[async_trait]
547impl SpeechToTextProvider for XAI {
548    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
549        Err(LLMError::ProviderError(
550            "XAI does not implement speech to text endpoint yet.".into(),
551        ))
552    }
553}
554
// Empty impl: X.AI offers no text-to-speech support, so the trait's
// default method implementations apply unchanged.
#[async_trait]
impl TextToSpeechProvider for XAI {}
557
// Empty impl: model listing is not implemented for X.AI; the trait's
// default method implementations apply unchanged.
#[async_trait]
impl ModelsProvider for XAI {}
560
561impl LLMProvider for XAI {}
562
563/// Parses a Server-Sent Events (SSE) chunk from X.AI's streaming API.
564///
565/// # Arguments
566///
567/// * `chunk` - The raw SSE chunk text
568///
569/// # Returns
570///
571/// * `Ok(Some(String))` - Content token if found
572/// * `Ok(None)` - If chunk should be skipped (e.g., ping, done signal)
573/// * `Err(LLMError)` - If parsing fails
574fn parse_xai_sse_chunk(chunk: &str) -> Result<Option<String>, LLMError> {
575    for line in chunk.lines() {
576        let line = line.trim();
577        
578        if line.starts_with("data: ") {
579            let data = &line[6..];
580            
581            if data == "[DONE]" {
582                return Ok(None);
583            }
584            
585            match serde_json::from_str::<XAIStreamResponse>(data) {
586                Ok(response) => {
587                    if let Some(choice) = response.choices.first() {
588                        if let Some(content) = &choice.delta.content {
589                            return Ok(Some(content.clone()));
590                        }
591                    }
592                    return Ok(None);
593                }
594                Err(_) => continue,
595            }
596        }
597    }
598    
599    Ok(None)
600}