// bep/providers/xai/completion.rs
use crate::{
    completion::{self, CompletionError},
    json_utils,
};

use serde_json::json;
use xai_api_types::{CompletionResponse, ToolDefinition};

use super::client::{xai_api_types::ApiResponse, Client};

/// Name of the `grok-beta` model, suitable for passing as the `model`
/// argument of [`CompletionModel::new`].
pub const GROK_BETA: &str = "grok-beta";

19#[derive(Clone)]
24pub struct CompletionModel {
25 client: Client,
26 pub model: String,
27}
28
29impl CompletionModel {
30 pub fn new(client: Client, model: &str) -> Self {
31 Self {
32 client,
33 model: model.to_string(),
34 }
35 }
36}
37
38impl completion::CompletionModel for CompletionModel {
39 type Response = CompletionResponse;
40
41 async fn completion(
42 &self,
43 mut completion_request: completion::CompletionRequest,
44 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
45 let mut messages = if let Some(preamble) = &completion_request.preamble {
46 vec![completion::Message {
47 role: "system".into(),
48 content: preamble.clone(),
49 }]
50 } else {
51 vec![]
52 };
53 messages.append(&mut completion_request.chat_history);
54
55 let prompt_with_context = completion_request.prompt_with_context();
56
57 messages.push(completion::Message {
58 role: "user".into(),
59 content: prompt_with_context,
60 });
61
62 let mut request = if completion_request.tools.is_empty() {
63 json!({
64 "model": self.model,
65 "messages": messages,
66 "temperature": completion_request.temperature,
67 })
68 } else {
69 json!({
70 "model": self.model,
71 "messages": messages,
72 "temperature": completion_request.temperature,
73 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
74 "tool_choice": "auto",
75 })
76 };
77
78 request = if let Some(params) = completion_request.additional_params {
79 json_utils::merge(request, params)
80 } else {
81 request
82 };
83
84 let response = self
85 .client
86 .post("/v1/chat/completions")
87 .json(&request)
88 .send()
89 .await?;
90
91 if response.status().is_success() {
92 match response.json::<ApiResponse<CompletionResponse>>().await? {
93 ApiResponse::Ok(completion) => completion.try_into(),
94 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
95 }
96 } else {
97 Err(CompletionError::ProviderError(response.text().await?))
98 }
99 }
100}
101
102pub mod xai_api_types {
103 use serde::{Deserialize, Serialize};
104
105 use crate::completion::{self, CompletionError};
106
107 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
108 type Error = CompletionError;
109
110 fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
111 match value.choices.as_slice() {
112 [Choice {
113 message:
114 Message {
115 content: Some(content),
116 ..
117 },
118 ..
119 }, ..] => Ok(completion::CompletionResponse {
120 choice: completion::ModelChoice::Message(content.to_string()),
121 raw_response: value,
122 }),
123 [Choice {
124 message:
125 Message {
126 tool_calls: Some(calls),
127 ..
128 },
129 ..
130 }, ..] => {
131 let call = calls.first().ok_or(CompletionError::ResponseError(
132 "Tool selection is empty".into(),
133 ))?;
134
135 Ok(completion::CompletionResponse {
136 choice: completion::ModelChoice::ToolCall(
137 call.function.name.clone(),
138 serde_json::from_str(&call.function.arguments)?,
139 ),
140 raw_response: value,
141 })
142 }
143 _ => Err(CompletionError::ResponseError(
144 "Response did not contain a message or tool call".into(),
145 )),
146 }
147 }
148 }
149
150 impl From<completion::ToolDefinition> for ToolDefinition {
151 fn from(tool: completion::ToolDefinition) -> Self {
152 Self {
153 r#type: "function".into(),
154 function: tool,
155 }
156 }
157 }
158
159 #[derive(Debug, Deserialize)]
160 pub struct ToolCall {
161 pub id: String,
162 pub r#type: String,
163 pub function: Function,
164 }
165
166 #[derive(Clone, Debug, Deserialize, Serialize)]
167 pub struct ToolDefinition {
168 pub r#type: String,
169 pub function: completion::ToolDefinition,
170 }
171
172 #[derive(Debug, Deserialize)]
173 pub struct Function {
174 pub name: String,
175 pub arguments: String,
176 }
177
178 #[derive(Debug, Deserialize)]
179 pub struct CompletionResponse {
180 pub id: String,
181 pub model: String,
182 pub choices: Vec<Choice>,
183 pub created: i64,
184 pub object: String,
185 pub system_fingerprint: String,
186 pub usage: Usage,
187 }
188
189 #[derive(Debug, Deserialize)]
190 pub struct Choice {
191 pub finish_reason: String,
192 pub index: i32,
193 pub message: Message,
194 }
195
196 #[derive(Debug, Deserialize)]
197 pub struct Message {
198 pub role: String,
199 pub content: Option<String>,
200 pub tool_calls: Option<Vec<ToolCall>>,
201 }
202
203 #[derive(Debug, Deserialize)]
204 pub struct Usage {
205 pub completion_tokens: i32,
206 pub prompt_tokens: i32,
207 pub total_tokens: i32,
208 }
209}