rig/providers/xai/
completion.rs1use crate::{
7 completion::{self, CompletionError},
8 json_utils,
9 providers::openai::Message,
10};
11
12use serde_json::json;
13use xai_api_types::{CompletionResponse, ToolDefinition};
14
15use super::client::{xai_api_types::ApiResponse, Client};
16
/// xAI model name `grok-beta`, usable as the `model` argument of `CompletionModel::new`.
pub const GROK_BETA: &str = "grok-beta";
20#[derive(Clone)]
25pub struct CompletionModel {
26 client: Client,
27 pub model: String,
28}
29
30impl CompletionModel {
31 pub fn new(client: Client, model: &str) -> Self {
32 Self {
33 client,
34 model: model.to_string(),
35 }
36 }
37}
38
39impl completion::CompletionModel for CompletionModel {
40 type Response = CompletionResponse;
41
42 #[cfg_attr(feature = "worker", worker::send)]
43 async fn completion(
44 &self,
45 completion_request: completion::CompletionRequest,
46 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
47 let mut full_history: Vec<Message> = match &completion_request.preamble {
49 Some(preamble) => {
50 if preamble.is_empty() {
51 vec![]
52 } else {
53 vec![Message::system(preamble)]
54 }
55 }
56 None => vec![],
57 };
58
59 let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
61
62 let chat_history: Vec<Message> = completion_request
64 .chat_history
65 .into_iter()
66 .map(|message| message.try_into())
67 .collect::<Result<Vec<Vec<Message>>, _>>()?
68 .into_iter()
69 .flatten()
70 .collect();
71
72 full_history.extend(chat_history);
74 full_history.extend(prompt);
75
76 let mut request = if completion_request.tools.is_empty() {
77 json!({
78 "model": self.model,
79 "messages": full_history,
80 "temperature": completion_request.temperature,
81 })
82 } else {
83 json!({
84 "model": self.model,
85 "messages": full_history,
86 "temperature": completion_request.temperature,
87 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
88 "tool_choice": "auto",
89 })
90 };
91
92 request = if let Some(params) = completion_request.additional_params {
93 json_utils::merge(request, params)
94 } else {
95 request
96 };
97
98 let response = self
99 .client
100 .post("/v1/chat/completions")
101 .json(&request)
102 .send()
103 .await?;
104
105 if response.status().is_success() {
106 match response.json::<ApiResponse<CompletionResponse>>().await? {
107 ApiResponse::Ok(completion) => completion.try_into(),
108 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
109 }
110 } else {
111 Err(CompletionError::ProviderError(response.text().await?))
112 }
113 }
114}
115
116pub mod xai_api_types {
117 use serde::{Deserialize, Serialize};
118
119 use crate::completion::{self, CompletionError};
120 use crate::providers::openai::{AssistantContent, Message};
121 use crate::OneOrMany;
122
123 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
124 type Error = CompletionError;
125
126 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
127 let choice = response.choices.first().ok_or_else(|| {
128 CompletionError::ResponseError("Response contained no choices".to_owned())
129 })?;
130 let content = match &choice.message {
131 Message::Assistant {
132 content,
133 tool_calls,
134 ..
135 } => {
136 let mut content = content
137 .iter()
138 .map(|c| match c {
139 AssistantContent::Text { text } => {
140 completion::AssistantContent::text(text)
141 }
142 AssistantContent::Refusal { refusal } => {
143 completion::AssistantContent::text(refusal)
144 }
145 })
146 .collect::<Vec<_>>();
147
148 content.extend(
149 tool_calls
150 .iter()
151 .map(|call| {
152 completion::AssistantContent::tool_call(
153 &call.function.name,
154 &call.function.name,
155 call.function.arguments.clone(),
156 )
157 })
158 .collect::<Vec<_>>(),
159 );
160 Ok(content)
161 }
162 _ => Err(CompletionError::ResponseError(
163 "Response did not contain a valid message or tool call".into(),
164 )),
165 }?;
166
167 let choice = OneOrMany::many(content).map_err(|_| {
168 CompletionError::ResponseError(
169 "Response contained no message or tool call (empty)".to_owned(),
170 )
171 })?;
172
173 Ok(completion::CompletionResponse {
174 choice,
175 raw_response: response,
176 })
177 }
178 }
179
180 impl From<completion::ToolDefinition> for ToolDefinition {
181 fn from(tool: completion::ToolDefinition) -> Self {
182 Self {
183 r#type: "function".into(),
184 function: tool,
185 }
186 }
187 }
188
189 #[derive(Clone, Debug, Deserialize, Serialize)]
190 pub struct ToolDefinition {
191 pub r#type: String,
192 pub function: completion::ToolDefinition,
193 }
194
195 #[derive(Debug, Deserialize)]
196 pub struct Function {
197 pub name: String,
198 pub arguments: String,
199 }
200
201 #[derive(Debug, Deserialize)]
202 pub struct CompletionResponse {
203 pub id: String,
204 pub model: String,
205 pub choices: Vec<Choice>,
206 pub created: i64,
207 pub object: String,
208 pub system_fingerprint: String,
209 pub usage: Usage,
210 }
211
212 #[derive(Debug, Deserialize)]
213 pub struct Choice {
214 pub finish_reason: String,
215 pub index: i32,
216 pub message: Message,
217 }
218
219 #[derive(Debug, Deserialize)]
220 pub struct Usage {
221 pub completion_tokens: i32,
222 pub prompt_tokens: i32,
223 pub total_tokens: i32,
224 }
225}