llm/backends/xai.rs

//! X.AI API client implementation for chat and completion functionality.
//!
//! This module provides integration with X.AI's models through their API.
//! It implements chat and completion capabilities using the X.AI API endpoints.
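//!
//! A minimal usage sketch (ignored in doc tests; building `messages` depends on
//! the crate's `ChatMessage` API, and the key below is a placeholder):
//!
//! ```ignore
//! let client = XAI::new(
//!     "xai-api-key", None, None, None, None, None, None, None, None,
//!     None, None, None, None, None, None, None, None,
//! );
//! let response = client.chat(&messages).await?;
//! println!("{}", response.text().unwrap_or_default());
//! ```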

#[cfg(feature = "xai")]
use crate::{
    chat::{ChatMessage, ChatProvider, ChatResponse, ChatRole, StructuredOutputFormat, Tool, Usage},
    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
    embedding::EmbeddingProvider,
    error::LLMError,
    models::ModelsProvider,
    stt::SpeechToTextProvider,
    tts::TextToSpeechProvider,
    LLMProvider,
};
use crate::ToolCall;
use async_trait::async_trait;
use futures::stream::Stream;
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// Client for interacting with X.AI's API.
///
/// This struct provides methods for making chat and completion requests to X.AI's language models.
/// It handles authentication, request configuration, and response parsing.
pub struct XAI {
    /// API key for authentication with X.AI services
    pub api_key: String,
    /// Model identifier to use for requests (e.g. "grok-2-latest")
    pub model: String,
    /// Maximum number of tokens to generate in responses
    pub max_tokens: Option<u32>,
    /// Temperature parameter for controlling response randomness (0.0 to 1.0)
    pub temperature: Option<f32>,
    /// Optional system prompt to provide context
    pub system: Option<String>,
    /// Request timeout duration in seconds
    pub timeout_seconds: Option<u64>,
    /// Top-p sampling parameter for controlling response diversity
    pub top_p: Option<f32>,
    /// Top-k sampling parameter for controlling response diversity
    pub top_k: Option<u32>,
    /// Embedding encoding format
    pub embedding_encoding_format: Option<String>,
    /// Embedding dimensions
    pub embedding_dimensions: Option<u32>,
    /// JSON schema for structured output
    pub json_schema: Option<StructuredOutputFormat>,
    /// X.AI search mode (e.g. "auto")
    pub xai_search_mode: Option<String>,
    /// X.AI search source type ("web" or "news")
    pub xai_search_source_type: Option<String>,
    /// Websites to exclude from X.AI search results
    pub xai_search_excluded_websites: Option<Vec<String>>,
    /// Maximum number of X.AI search results
    pub xai_search_max_results: Option<u32>,
    /// Start date for X.AI search results (format: "YYYY-MM-DD")
    pub xai_search_from_date: Option<String>,
    /// End date for X.AI search results (format: "YYYY-MM-DD")
    pub xai_search_to_date: Option<String>,
    /// HTTP client for making API requests
    client: Client,
}

/// Search source configuration for search parameters
#[derive(Debug, Clone, serde::Serialize)]
pub struct XaiSearchSource {
    /// Type of source: "web" or "news"
    #[serde(rename = "type")]
    pub source_type: String,
    /// List of websites to exclude from this source
    #[serde(skip_serializing_if = "Option::is_none")]
    pub excluded_websites: Option<Vec<String>>,
}

/// Search parameters for LLM providers that support search functionality
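///
/// A sketch of the JSON this serializes to when fully populated (field values
/// are illustrative):
///
/// ```json
/// {
///   "mode": "auto",
///   "sources": [{ "type": "web", "excluded_websites": ["example.com"] }],
///   "max_search_results": 5,
///   "from_date": "2024-01-01",
///   "to_date": "2024-12-31"
/// }
/// ```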
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct XaiSearchParameters {
    /// Search mode (e.g., "auto")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<String>,
    /// List of search sources with exclusions
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sources: Option<Vec<XaiSearchSource>>,
    /// Maximum number of search results to return
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_search_results: Option<u32>,
    /// Start date for search results (format: "YYYY-MM-DD")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub from_date: Option<String>,
    /// End date for search results (format: "YYYY-MM-DD")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub to_date: Option<String>,
}

/// Individual message in an X.AI chat conversation.
#[derive(Serialize)]
struct XAIChatMessage<'a> {
    /// Role of the message sender (user, assistant, or system)
    role: &'a str,
    /// Content of the message
    content: &'a str,
}

/// Request payload for X.AI's chat API endpoint.
#[derive(Serialize)]
struct XAIChatRequest<'a> {
    /// Model identifier to use
    model: &'a str,
    /// Array of conversation messages
    messages: Vec<XAIChatMessage<'a>>,
    /// Maximum tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Temperature parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Whether to stream the response
    stream: bool,
    /// Top-p sampling parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Top-k sampling parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    /// Response format specification (e.g. JSON schema for structured output)
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<XAIResponseFormat>,
    /// Search parameters for search functionality
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<&'a XaiSearchParameters>,
}

/// Response from X.AI's chat API endpoint.
#[derive(Deserialize, Debug)]
struct XAIChatResponse {
    /// Array of generated responses
    choices: Vec<XAIChatChoice>,
    /// Usage metadata for the request
    usage: Option<Usage>,
}

impl std::fmt::Display for XAIChatResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.text().unwrap_or_default())
    }
}

impl ChatResponse for XAIChatResponse {
    fn text(&self) -> Option<String> {
        self.choices.first().map(|c| c.message.content.clone())
    }

    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
        None
    }

    fn usage(&self) -> Option<Usage> {
        self.usage.clone()
    }
}

/// Individual response choice from the chat API.
#[derive(Deserialize, Debug)]
struct XAIChatChoice {
    /// Message content and metadata
    message: XAIChatMsg,
}

/// Message content from a chat response.
#[derive(Deserialize, Debug)]
struct XAIChatMsg {
    /// Generated text content
    content: String,
}

/// Request payload for X.AI's embeddings API endpoint.
#[derive(Debug, Serialize)]
struct XAIEmbeddingRequest<'a> {
    /// Model identifier to use
    model: &'a str,
    /// Texts to embed
    input: Vec<String>,
    /// Encoding format for the returned embeddings (e.g. "float")
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<&'a str>,
    /// Number of dimensions for the returned embeddings
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}

/// A single embedding vector from the embeddings API.
#[derive(Deserialize)]
struct XAIEmbeddingData {
    embedding: Vec<f32>,
}

/// Response from X.AI's streaming chat API endpoint.
#[derive(Deserialize, Debug)]
struct XAIStreamResponse {
    /// Array of generated responses
    choices: Vec<XAIStreamChoice>,
}

/// Individual response choice from the streaming chat API.
#[derive(Deserialize, Debug)]
struct XAIStreamChoice {
    /// Delta content
    delta: XAIStreamDelta,
}

/// Delta content from a streaming chat response.
#[derive(Deserialize, Debug)]
struct XAIStreamDelta {
    /// Generated text content
    content: Option<String>,
}

/// Response from X.AI's embeddings API endpoint.
#[derive(Deserialize)]
struct XAIEmbeddingResponse {
    data: Vec<XAIEmbeddingData>,
}

/// Output format variants accepted by the chat API's `response_format` field.
#[derive(Deserialize, Debug, Serialize)]
enum XAIResponseType {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_schema")]
    JsonSchema,
    #[serde(rename = "json_object")]
    JsonObject,
}

/// An object specifying the format that the model must output.
///
/// Setting this to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs, which ensures the model's output matches your supplied JSON schema. Learn more in the [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
/// Setting it to `{ "type": "json_object" }` enables the older JSON mode, which only ensures the generated message is valid JSON. `json_schema` is preferred for models that support it.
/// The structured outputs feature is only supported for the `grok-2-latest` model.
#[derive(Deserialize, Debug, Serialize)]
struct XAIResponseFormat {
    #[serde(rename = "type")]
    response_type: XAIResponseType,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}

impl XAI {
    /// Creates a new X.AI client with the specified configuration.
    ///
    /// # Arguments
    ///
    /// * `api_key` - Authentication key for X.AI API access
    /// * `model` - Model identifier (defaults to "grok-2-latest" if `None`)
    /// * `max_tokens` - Maximum number of tokens to generate in responses
    /// * `temperature` - Sampling temperature for controlling randomness
    /// * `timeout_seconds` - Request timeout duration in seconds
    /// * `system` - System prompt for providing context
    /// * `top_p` - Top-p sampling parameter
    /// * `top_k` - Top-k sampling parameter
    /// * `embedding_encoding_format` - Encoding format for embedding requests
    /// * `embedding_dimensions` - Dimensions for embedding requests
    /// * `json_schema` - JSON schema for structured output
    /// * `xai_search_mode` - Search mode (e.g. "auto")
    /// * `xai_search_source_type` - Search source type ("web" or "news")
    /// * `xai_search_excluded_websites` - Websites to exclude from search results
    /// * `xai_search_max_results` - Maximum number of search results
    /// * `xai_search_from_date` - Start date for search results ("YYYY-MM-DD")
    /// * `xai_search_to_date` - End date for search results ("YYYY-MM-DD")
    ///
    /// # Returns
    ///
    /// A configured X.AI client instance ready to make API requests.
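    ///
    /// # Example
    ///
    /// A minimal construction sketch (ignored in doc tests; the key is a
    /// placeholder and the later optional arguments are left unset):
    ///
    /// ```ignore
    /// let client = XAI::new(
    ///     "xai-api-key",
    ///     Some("grok-2-latest".to_string()),
    ///     Some(512),  // max_tokens
    ///     Some(0.7),  // temperature
    ///     Some(30),   // timeout_seconds
    ///     None, None, None, None, None, None,
    ///     None, None, None, None, None, None,
    /// );
    /// ```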
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        api_key: impl Into<String>,
        model: Option<String>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        timeout_seconds: Option<u64>,
        system: Option<String>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        embedding_encoding_format: Option<String>,
        embedding_dimensions: Option<u32>,
        json_schema: Option<StructuredOutputFormat>,
        xai_search_mode: Option<String>,
        xai_search_source_type: Option<String>,
        xai_search_excluded_websites: Option<Vec<String>>,
        xai_search_max_results: Option<u32>,
        xai_search_from_date: Option<String>,
        xai_search_to_date: Option<String>,
    ) -> Self {
        let mut builder = Client::builder();
        if let Some(sec) = timeout_seconds {
            builder = builder.timeout(std::time::Duration::from_secs(sec));
        }
        Self {
            api_key: api_key.into(),
            model: model.unwrap_or_else(|| "grok-2-latest".to_string()),
            max_tokens,
            temperature,
            system,
            timeout_seconds,
            top_p,
            top_k,
            embedding_encoding_format,
            embedding_dimensions,
            json_schema,
            xai_search_mode,
            xai_search_source_type,
            xai_search_excluded_websites,
            xai_search_max_results,
            xai_search_from_date,
            xai_search_to_date,
            client: builder.build().expect("Failed to build reqwest Client"),
        }
    }
}

#[async_trait]
impl ChatProvider for XAI {
    /// Sends a chat request to the X.AI API and returns the response.
    ///
    /// # Arguments
    ///
    /// * `messages` - Array of chat messages representing the conversation
    ///
    /// # Returns
    ///
    /// The generated response text, or an error if the request fails.
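    ///
    /// # Example
    ///
    /// A usage sketch (ignored in doc tests; assumes `client` and `messages`
    /// were built as in the module docs):
    ///
    /// ```ignore
    /// let response = client.chat(&messages).await?;
    /// println!("{}", response.text().unwrap_or_default());
    /// println!("usage: {:?}", response.usage());
    /// ```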
    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
        }

        let mut xai_msgs: Vec<XAIChatMessage> = messages
            .iter()
            .map(|m| XAIChatMessage {
                role: match m.role {
                    ChatRole::User => "user",
                    ChatRole::Assistant => "assistant",
                },
                content: &m.content,
            })
            .collect();

        if let Some(system) = &self.system {
            xai_msgs.insert(
                0,
                XAIChatMessage {
                    role: "system",
                    content: system,
                },
            );
        }

        // OpenAI's structured output has some [odd requirements](https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat&lang=curl#supported-schemas),
        // and there is currently no check for them here, so it is up to the caller to provide a valid schema.
        // X.AI's API closely mirrors OpenAI's, so the same restrictions likely apply.
        let response_format: Option<XAIResponseFormat> =
            self.json_schema.as_ref().map(|s| XAIResponseFormat {
                response_type: XAIResponseType::JsonSchema,
                json_schema: Some(s.clone()),
            });

        // Only attach search parameters when at least one search option was
        // configured; otherwise every request would carry a default "web"
        // source and could enable live search unintentionally.
        let search_parameters = if self.xai_search_mode.is_some()
            || self.xai_search_source_type.is_some()
            || self.xai_search_excluded_websites.is_some()
            || self.xai_search_max_results.is_some()
            || self.xai_search_from_date.is_some()
            || self.xai_search_to_date.is_some()
        {
            Some(XaiSearchParameters {
                mode: self.xai_search_mode.clone(),
                sources: Some(vec![XaiSearchSource {
                    source_type: self
                        .xai_search_source_type
                        .clone()
                        .unwrap_or_else(|| "web".to_string()),
                    excluded_websites: self.xai_search_excluded_websites.clone(),
                }]),
                max_search_results: self.xai_search_max_results,
                from_date: self.xai_search_from_date.clone(),
                to_date: self.xai_search_to_date.clone(),
            })
        } else {
            None
        };

        let body = XAIChatRequest {
            model: &self.model,
            messages: xai_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: false,
            top_p: self.top_p,
            top_k: self.top_k,
            response_format,
            search_parameters: search_parameters.as_ref(),
        };

        if log::log_enabled!(log::Level::Trace) {
            if let Ok(json) = serde_json::to_string(&body) {
                log::trace!("XAI request payload: {}", json);
            }
        }

        let mut request = self
            .client
            .post("https://api.x.ai/v1/chat/completions")
            .bearer_auth(&self.api_key)
            .json(&body);

        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }

        let resp = request.send().await?;

        log::debug!("XAI HTTP status: {}", resp.status());

        let resp = resp.error_for_status()?;

        let json_resp: XAIChatResponse = resp.json().await?;
        Ok(Box::new(json_resp))
    }

    /// Sends a chat request to X.AI's API with tools.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation history as a slice of chat messages
    /// * `tools` - Optional slice of tools to use in the chat
    ///
    /// # Returns
    ///
    /// The provider's response text or an error
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        _tools: Option<&[Tool]>,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        // XAI doesn't support tools yet, fall back to regular chat
        self.chat(messages).await
    }

    /// Sends a streaming chat request to X.AI's API.
    ///
    /// # Arguments
    ///
    /// * `messages` - Slice of chat messages representing the conversation
    ///
    /// # Returns
    ///
    /// A stream of text tokens or an error
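    ///
    /// # Example
    ///
    /// A consumption sketch (ignored in doc tests; assumes `client` and
    /// `messages` exist as in the module docs):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = client.chat_stream(&messages).await?;
    /// while let Some(token) = stream.next().await {
    ///     print!("{}", token?);
    /// }
    /// ```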
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
    {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
        }

        let mut xai_msgs: Vec<XAIChatMessage> = messages
            .iter()
            .map(|m| XAIChatMessage {
                role: match m.role {
                    ChatRole::User => "user",
                    ChatRole::Assistant => "assistant",
                },
                content: &m.content,
            })
            .collect();

        if let Some(system) = &self.system {
            xai_msgs.insert(
                0,
                XAIChatMessage {
                    role: "system",
                    content: system,
                },
            );
        }

        let body = XAIChatRequest {
            model: &self.model,
            messages: xai_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: true,
            top_p: self.top_p,
            top_k: self.top_k,
            response_format: None,
            search_parameters: None,
        };

        let mut request = self
            .client
            .post("https://api.x.ai/v1/chat/completions")
            .bearer_auth(&self.api_key)
            .json(&body);

        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }

        let response = request.send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("X.AI API returned error status: {status}"),
                raw_response: error_text,
            });
        }

        Ok(crate::chat::create_sse_stream(
            response,
            parse_xai_sse_chunk,
        ))
    }
}

#[async_trait]
impl CompletionProvider for XAI {
    /// Sends a completion request to X.AI's API.
    ///
    /// This functionality is currently not implemented.
    ///
    /// # Arguments
    ///
    /// * `_req` - The completion request parameters
    ///
    /// # Returns
    ///
    /// A placeholder response indicating the functionality is not implemented.
    async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
        Ok(CompletionResponse {
            text: "X.AI completion not implemented.".into(),
        })
    }
}

#[async_trait]
impl EmbeddingProvider for XAI {
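    /// Generates embeddings for the given texts via X.AI's embeddings endpoint.
    ///
    /// A usage sketch (ignored in doc tests; assumes the configured `model`
    /// supports embeddings):
    ///
    /// ```ignore
    /// let vectors = client.embed(vec!["hello world".to_string()]).await?;
    /// assert_eq!(vectors.len(), 1);
    /// ```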
    async fn embed(&self, text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing X.AI API key".into()));
        }

        let emb_format = self
            .embedding_encoding_format
            .clone()
            .unwrap_or_else(|| "float".to_string());

        let body = XAIEmbeddingRequest {
            model: &self.model,
            input: text,
            encoding_format: Some(&emb_format),
            dimensions: self.embedding_dimensions,
        };

        let resp = self
            .client
            .post("https://api.x.ai/v1/embeddings")
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await?
            .error_for_status()?;

        let json_resp: XAIEmbeddingResponse = resp.json().await?;

        let embeddings = json_resp.data.into_iter().map(|d| d.embedding).collect();
        Ok(embeddings)
    }
}

#[async_trait]
impl SpeechToTextProvider for XAI {
    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
        Err(LLMError::ProviderError(
            "X.AI does not provide a speech-to-text endpoint yet.".into(),
        ))
    }
}

#[async_trait]
impl TextToSpeechProvider for XAI {}

#[async_trait]
impl ModelsProvider for XAI {}

impl LLMProvider for XAI {}

/// Parses a Server-Sent Events (SSE) chunk from X.AI's streaming API.
///
/// # Arguments
///
/// * `chunk` - The raw SSE chunk text
///
/// # Returns
///
/// * `Ok(Some(String))` - Content token if found
/// * `Ok(None)` - If chunk should be skipped (e.g., ping, done signal)
/// * `Err(LLMError)` - If parsing fails
fn parse_xai_sse_chunk(chunk: &str) -> Result<Option<String>, LLMError> {
    for line in chunk.lines() {
        let line = line.trim();

        if let Some(data) = line.strip_prefix("data: ") {
            if data == "[DONE]" {
                return Ok(None);
            }

            match serde_json::from_str::<XAIStreamResponse>(data) {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return Ok(Some(content.clone()));
                        }
                    }
                    return Ok(None);
                }
                Err(_) => continue,
            }
        }
    }

    Ok(None)
}
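
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sanity checks for the SSE parser. The chunk payloads below are
    // illustrative examples of the OpenAI-style frames the parser expects,
    // not captured X.AI traffic.
    #[test]
    fn parses_content_delta() {
        let chunk = r#"data: {"choices":[{"delta":{"content":"Hi"}}]}"#;
        assert_eq!(parse_xai_sse_chunk(chunk).unwrap(), Some("Hi".to_string()));
    }

    #[test]
    fn skips_done_signal() {
        assert_eq!(parse_xai_sse_chunk("data: [DONE]").unwrap(), None);
    }
}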