// rig/providers/xai/client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    http_client,
7};
8
// ================================================================
// xAI Client
// ================================================================
12
/// Field-less provider extension type for xAI.
///
/// It is plugged into the generic client as the extension parameter (see the
/// [`Client`] alias) and carries no state of its own.
#[derive(Clone, Copy, Debug, Default)]
pub struct XAiExt;
/// Field-less builder-side extension type for xAI.
///
/// Its `ProviderBuilder` implementation in this file names [`XAiExt`] as the
/// output type.
#[derive(Clone, Copy, Debug, Default)]
pub struct XAiExtBuilder;
17
18type XAiApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<XAiExt, H>;
21pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<XAiExtBuilder, XAiApiKey, H>;
22
/// Base URL of the xAI API; exposed to the client builder through
/// `ProviderBuilder::BASE_URL` further down in this file.
const XAI_BASE_URL: &str = "https://api.x.ai";
24
25impl Provider for XAiExt {
26    type Builder = XAiExtBuilder;
27
28    const VERIFY_PATH: &'static str = "/v1/api-key";
29
30    fn build<H>(
31        _: &client::ClientBuilder<Self::Builder, XAiApiKey, H>,
32    ) -> http_client::Result<Self> {
33        Ok(Self)
34    }
35}
36
37impl<H> Capabilities<H> for XAiExt {
38    type Completion = Capable<super::completion::CompletionModel<H>>;
39
40    type Embeddings = Nothing;
41    type Transcription = Nothing;
42    #[cfg(feature = "image")]
43    type ImageGeneration = Nothing;
44    #[cfg(feature = "audio")]
45    type AudioGeneration = Nothing;
46}
47
48impl DebugExt for XAiExt {}
49
50impl ProviderBuilder for XAiExtBuilder {
51    type Output = XAiExt;
52    type ApiKey = XAiApiKey;
53
54    const BASE_URL: &'static str = XAI_BASE_URL;
55}
56
57impl ProviderClient for Client {
58    type Input = String;
59
60    /// Create a new xAI client from the `XAI_API_KEY` environment variable.
61    /// Panics if the environment variable is not set.
62    fn from_env() -> Self {
63        let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
64        Self::new(&api_key).unwrap()
65    }
66
67    fn from_val(input: Self::Input) -> Self {
68        Self::new(&input).unwrap()
69    }
70}
71
72pub mod xai_api_types {
73    use serde::Deserialize;
74
75    impl ApiErrorResponse {
76        pub fn message(&self) -> String {
77            format!("Code `{}`: {}", self.code, self.error)
78        }
79    }
80
81    #[derive(Debug, Deserialize)]
82    pub struct ApiErrorResponse {
83        pub error: String,
84        pub code: String,
85    }
86
87    #[derive(Debug, Deserialize)]
88    #[serde(untagged)]
89    pub enum ApiResponse<T> {
90        Ok(T),
91        Error(ApiErrorResponse),
92    }
93}