bep/providers/xai/
completion.rs

// ================================================================
//! xAI Completion Integration
//! From [xAI Reference](https://docs.x.ai/api/endpoints#chat-completions)
// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    json_utils,
9};
10
11use serde_json::json;
12use xai_api_types::{CompletionResponse, ToolDefinition};
13
14use super::client::{xai_api_types::ApiResponse, Client};
15
/// Model identifier string for the `grok-beta` completion model.
pub const GROK_BETA: &str = "grok-beta";
18
// =================================================================
// Bep Implementation Types
// =================================================================
22
23#[derive(Clone)]
24pub struct CompletionModel {
25    client: Client,
26    pub model: String,
27}
28
29impl CompletionModel {
30    pub fn new(client: Client, model: &str) -> Self {
31        Self {
32            client,
33            model: model.to_string(),
34        }
35    }
36}
37
38impl completion::CompletionModel for CompletionModel {
39    type Response = CompletionResponse;
40
41    async fn completion(
42        &self,
43        mut completion_request: completion::CompletionRequest,
44    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
45        let mut messages = if let Some(preamble) = &completion_request.preamble {
46            vec![completion::Message {
47                role: "system".into(),
48                content: preamble.clone(),
49            }]
50        } else {
51            vec![]
52        };
53        messages.append(&mut completion_request.chat_history);
54
55        let prompt_with_context = completion_request.prompt_with_context();
56
57        messages.push(completion::Message {
58            role: "user".into(),
59            content: prompt_with_context,
60        });
61
62        let mut request = if completion_request.tools.is_empty() {
63            json!({
64                "model": self.model,
65                "messages": messages,
66                "temperature": completion_request.temperature,
67            })
68        } else {
69            json!({
70                "model": self.model,
71                "messages": messages,
72                "temperature": completion_request.temperature,
73                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
74                "tool_choice": "auto",
75            })
76        };
77
78        request = if let Some(params) = completion_request.additional_params {
79            json_utils::merge(request, params)
80        } else {
81            request
82        };
83
84        let response = self
85            .client
86            .post("/v1/chat/completions")
87            .json(&request)
88            .send()
89            .await?;
90
91        if response.status().is_success() {
92            match response.json::<ApiResponse<CompletionResponse>>().await? {
93                ApiResponse::Ok(completion) => completion.try_into(),
94                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
95            }
96        } else {
97            Err(CompletionError::ProviderError(response.text().await?))
98        }
99    }
100}
101
102pub mod xai_api_types {
103    use serde::{Deserialize, Serialize};
104
105    use crate::completion::{self, CompletionError};
106
107    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
108        type Error = CompletionError;
109
110        fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
111            match value.choices.as_slice() {
112                [Choice {
113                    message:
114                        Message {
115                            content: Some(content),
116                            ..
117                        },
118                    ..
119                }, ..] => Ok(completion::CompletionResponse {
120                    choice: completion::ModelChoice::Message(content.to_string()),
121                    raw_response: value,
122                }),
123                [Choice {
124                    message:
125                        Message {
126                            tool_calls: Some(calls),
127                            ..
128                        },
129                    ..
130                }, ..] => {
131                    let call = calls.first().ok_or(CompletionError::ResponseError(
132                        "Tool selection is empty".into(),
133                    ))?;
134
135                    Ok(completion::CompletionResponse {
136                        choice: completion::ModelChoice::ToolCall(
137                            call.function.name.clone(),
138                            serde_json::from_str(&call.function.arguments)?,
139                        ),
140                        raw_response: value,
141                    })
142                }
143                _ => Err(CompletionError::ResponseError(
144                    "Response did not contain a message or tool call".into(),
145                )),
146            }
147        }
148    }
149
150    impl From<completion::ToolDefinition> for ToolDefinition {
151        fn from(tool: completion::ToolDefinition) -> Self {
152            Self {
153                r#type: "function".into(),
154                function: tool,
155            }
156        }
157    }
158
159    #[derive(Debug, Deserialize)]
160    pub struct ToolCall {
161        pub id: String,
162        pub r#type: String,
163        pub function: Function,
164    }
165
166    #[derive(Clone, Debug, Deserialize, Serialize)]
167    pub struct ToolDefinition {
168        pub r#type: String,
169        pub function: completion::ToolDefinition,
170    }
171
172    #[derive(Debug, Deserialize)]
173    pub struct Function {
174        pub name: String,
175        pub arguments: String,
176    }
177
178    #[derive(Debug, Deserialize)]
179    pub struct CompletionResponse {
180        pub id: String,
181        pub model: String,
182        pub choices: Vec<Choice>,
183        pub created: i64,
184        pub object: String,
185        pub system_fingerprint: String,
186        pub usage: Usage,
187    }
188
189    #[derive(Debug, Deserialize)]
190    pub struct Choice {
191        pub finish_reason: String,
192        pub index: i32,
193        pub message: Message,
194    }
195
196    #[derive(Debug, Deserialize)]
197    pub struct Message {
198        pub role: String,
199        pub content: Option<String>,
200        pub tool_calls: Option<Vec<ToolCall>>,
201    }
202
203    #[derive(Debug, Deserialize)]
204    pub struct Usage {
205        pub completion_tokens: i32,
206        pub prompt_tokens: i32,
207        pub total_tokens: i32,
208    }
209}