1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use crate::error::ProviderError;
7use crate::llm::{
8 ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
9 ModelUsage,
10};
11
/// Base URL of the xAI REST API, used whenever `GrokModelConfig::api_base_url` is `None`.
const DEFAULT_API_BASE_URL: &str = "https://api.x.ai/v1";
/// Placeholder user content (a single space) substituted when the sanitized
/// conversation would otherwise be empty or would not begin with a system or
/// user message.
const EMPTY_USER_CONTENT_FALLBACK: &str = " ";
14
/// Configuration for a Grok (xAI) chat model.
///
/// `Debug` is implemented by hand rather than derived so that the secret
/// `api_key` is redacted. A derived impl would print the key verbatim, and
/// it would leak into logs through `{:?}` on this config or on any type that
/// embeds it.
#[derive(Clone)]
pub struct GrokModelConfig {
    /// Secret xAI API key, sent as a bearer token on every request.
    pub api_key: String,
    /// Model identifier sent in the request body (e.g. `grok-4-1-fast-reasoning`).
    pub model: String,
    /// Override for the API base URL; `None` uses the public xAI endpoint.
    pub api_base_url: Option<String>,
    /// Sampling temperature; omitted from the request when `None`.
    pub temperature: Option<f32>,
    /// Top-p sampling parameter; omitted from the request when `None`.
    pub top_p: Option<f32>,
    /// Upper bound on generated tokens; omitted from the request when `None`.
    pub max_tokens: Option<u32>,
}

impl GrokModelConfig {
    /// Creates a config for `model` authenticated with `api_key`.
    ///
    /// No base-URL override, temperature or top-p is set, and `max_tokens`
    /// defaults to 4096.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_base_url: None,
            temperature: None,
            top_p: None,
            max_tokens: Some(4096),
        }
    }
}

impl std::fmt::Debug for GrokModelConfig {
    /// Formats every field except `api_key`, which is shown as `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrokModelConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("api_base_url", &self.api_base_url)
            .field("temperature", &self.temperature)
            .field("top_p", &self.top_p)
            .field("max_tokens", &self.max_tokens)
            .finish()
    }
}
37
38#[derive(Debug, Clone)]
39pub struct GrokModel {
40 client: Client,
41 config: GrokModelConfig,
42}
43
44impl GrokModel {
45 pub fn new(config: GrokModelConfig) -> Result<Self, ProviderError> {
46 let client = Client::builder()
47 .build()
48 .map_err(|err| ProviderError::Request(err.to_string()))?;
49
50 Ok(Self { client, config })
51 }
52
53 pub fn from_env(model: impl Into<String>) -> Result<Self, ProviderError> {
54 let api_key = std::env::var("XAI_API_KEY")
55 .or_else(|_| std::env::var("GROK_API_KEY"))
56 .map_err(|_| {
57 ProviderError::Request("XAI_API_KEY (or GROK_API_KEY) is not set".to_string())
58 })?;
59
60 Self::new(GrokModelConfig::new(api_key, model))
61 }
62
63 fn endpoint(&self) -> String {
64 let base = self
65 .config
66 .api_base_url
67 .as_deref()
68 .unwrap_or(DEFAULT_API_BASE_URL)
69 .trim_end_matches('/');
70 format!("{base}/chat/completions")
71 }
72}
73
74#[async_trait]
75impl ChatModel for GrokModel {
76 async fn invoke(
77 &self,
78 messages: &[ModelMessage],
79 tools: &[ModelToolDefinition],
80 tool_choice: ModelToolChoice,
81 ) -> Result<ModelCompletion, ProviderError> {
82 let request = build_request(messages, tools, tool_choice, &self.config);
83
84 let response = self
85 .client
86 .post(self.endpoint())
87 .header("authorization", format!("Bearer {}", self.config.api_key))
88 .header("content-type", "application/json")
89 .json(&request)
90 .send()
91 .await
92 .map_err(|err| ProviderError::Request(err.to_string()))?;
93
94 if !response.status().is_success() {
95 return Err(ProviderError::Request(extract_api_error(response).await));
96 }
97
98 let payload = response
99 .json::<GrokChatCompletionResponse>()
100 .await
101 .map_err(|err| ProviderError::Response(err.to_string()))?;
102
103 normalize_response(payload)
104 }
105}
106
/// JSON body sent to `{base}/chat/completions`.
///
/// Every `Option` field is left out of the serialized JSON when it is `None`.
#[derive(Debug, Serialize)]
struct GrokChatCompletionRequest {
    model: String,
    messages: Vec<GrokRequestMessage>,
    // `build_request` sets this to `None` when no tools are offered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GrokToolDefinition>>,
    // Only populated alongside `tools`; see `build_request`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<GrokToolChoicePayload>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}
122
/// One entry of the request's `messages` array.
///
/// Serialized as an internally tagged object: the variant name, lower-cased,
/// becomes the `role` field (`system`, `user`, `assistant`, `tool`), and the
/// variant's fields sit alongside it.
#[derive(Debug, Serialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum GrokRequestMessage {
    System {
        content: String,
    },
    User {
        content: String,
    },
    /// Assistant turn; `content` and `tool_calls` are omitted when `None`.
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<GrokToolCall>>,
    },
    /// Output of a tool, linked to the assistant's call via `tool_call_id`.
    Tool {
        tool_call_id: String,
        content: String,
    },
}
143
/// Entry of the request's `tools` array.
///
/// `type_` is serialized as `type`; `build_request` always sets it to `"function"`.
#[derive(Debug, Serialize)]
struct GrokToolDefinition {
    #[serde(rename = "type")]
    type_: String,
    function: GrokToolFunctionDefinition,
}

/// Name, description and parameter description of a function tool.
///
/// `parameters` is passed through unchanged from `ModelToolDefinition::parameters`
/// (a JSON Schema object in this file's tests).
#[derive(Debug, Serialize)]
struct GrokToolFunctionDefinition {
    name: String,
    description: String,
    parameters: Value,
}
157
/// Value of the request's `tool_choice` field.
///
/// Because the enum is `untagged`, `Mode` serializes as a bare string
/// (`"auto"`, `"required"` or `"none"` as produced by `build_request`).
/// `Specific` serializes as `{"type": "function", "function": {"name": ...}}`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum GrokToolChoicePayload {
    Mode(String),
    Specific {
        #[serde(rename = "type")]
        type_: String,
        function: GrokToolChoiceFunction,
    },
}

/// Names the single function the model is forced to call.
#[derive(Debug, Serialize)]
struct GrokToolChoiceFunction {
    name: String,
}
173
/// A function call made by the assistant.
///
/// This type is serialized when earlier assistant turns are echoed back in a
/// request, and deserialized from response messages.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct GrokToolCall {
    id: String,
    #[serde(rename = "type")]
    type_: String,
    function: GrokToolCallFunction,
}

/// Called function's name plus its arguments.
///
/// `arguments` is a JSON-encoded *string*, not a nested object: it is built
/// with `Value::to_string` and parsed with `serde_json::from_str`.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct GrokToolCallFunction {
    name: String,
    arguments: String,
}
187
/// Decoded chat-completions response body.
///
/// Fields not declared here are ignored (serde's default). `choices` becomes
/// an empty list when the field is absent.
#[derive(Debug, Deserialize)]
struct GrokChatCompletionResponse {
    #[serde(default)]
    choices: Vec<GrokChoice>,
    usage: Option<GrokUsage>,
}

/// One completion choice; `normalize_response` reads only the first.
#[derive(Debug, Deserialize)]
struct GrokChoice {
    message: Option<GrokAssistantMessage>,
}
199
200#[derive(Debug, Deserialize)]
201struct GrokAssistantMessage {
202 content: Option<String>,
203 #[serde(default)]
204 tool_calls: Vec<GrokToolCall>,
205 #[serde(default)]
206 reasoning_content: Option<String>,
207}
208
/// Token accounting from the response's `usage` object; every count is optional.
///
/// Reasoning tokens may appear at the top level (`reasoning_tokens`) or under
/// `completion_tokens_details`. `normalize_response` prefers the top-level
/// value and adds it to `completion_tokens`.
#[derive(Debug, Deserialize)]
struct GrokUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    reasoning_tokens: Option<u32>,
    completion_tokens_details: Option<GrokCompletionTokenDetails>,
}

/// Nested breakdown of completion tokens; only the reasoning count is read.
#[derive(Debug, Deserialize)]
struct GrokCompletionTokenDetails {
    reasoning_tokens: Option<u32>,
}
221
/// Error body of the form `{"error": {"message": ..., "type": ..., "code": ...}}`,
/// parsed by `extract_api_error`.
///
/// NOTE(review): if the API instead returns `error` as a plain string (e.g.
/// `{"code": "...", "error": "..."}`), parsing this envelope fails. In that
/// case `extract_api_error` reports the raw body instead. Confirm which shape
/// xAI sends.
#[derive(Debug, Deserialize)]
struct GrokErrorEnvelope {
    error: GrokApiError,
}

/// Inner error object; every field is optional.
#[derive(Debug, Deserialize)]
struct GrokApiError {
    message: Option<String>,
    #[serde(rename = "type")]
    type_: Option<String>,
    // May be a JSON string or another JSON value (e.g. a number).
    code: Option<Value>,
}
234
235fn build_request(
236 messages: &[ModelMessage],
237 tools: &[ModelToolDefinition],
238 tool_choice: ModelToolChoice,
239 config: &GrokModelConfig,
240) -> GrokChatCompletionRequest {
241 let request_messages = ensure_non_empty_messages(to_grok_messages(messages));
242
243 let tools_payload = if tools.is_empty() {
244 None
245 } else {
246 Some(
247 tools
248 .iter()
249 .map(|tool| GrokToolDefinition {
250 type_: "function".to_string(),
251 function: GrokToolFunctionDefinition {
252 name: tool.name.clone(),
253 description: tool.description.clone(),
254 parameters: tool.parameters.clone(),
255 },
256 })
257 .collect::<Vec<_>>(),
258 )
259 };
260
261 let tool_choice_payload = if tools.is_empty() {
262 None
263 } else {
264 Some(match tool_choice {
265 ModelToolChoice::Auto => GrokToolChoicePayload::Mode("auto".to_string()),
266 ModelToolChoice::Required => GrokToolChoicePayload::Mode("required".to_string()),
267 ModelToolChoice::None => GrokToolChoicePayload::Mode("none".to_string()),
268 ModelToolChoice::Tool(name) => GrokToolChoicePayload::Specific {
269 type_: "function".to_string(),
270 function: GrokToolChoiceFunction { name },
271 },
272 })
273 };
274
275 GrokChatCompletionRequest {
276 model: config.model.clone(),
277 messages: request_messages,
278 tools: tools_payload,
279 tool_choice: tool_choice_payload,
280 temperature: config.temperature,
281 top_p: config.top_p,
282 max_tokens: config.max_tokens,
283 }
284}
285
286fn to_grok_messages(messages: &[ModelMessage]) -> Vec<GrokRequestMessage> {
287 let mut request_messages = Vec::new();
288
289 for message in messages {
290 match message {
291 ModelMessage::System(content) => {
292 if content.is_empty() {
293 continue;
294 }
295 request_messages.push(GrokRequestMessage::System {
296 content: content.clone(),
297 });
298 }
299 ModelMessage::User(content) => {
300 if content.is_empty() {
301 continue;
302 }
303 request_messages.push(GrokRequestMessage::User {
304 content: content.clone(),
305 });
306 }
307 ModelMessage::Assistant {
308 content,
309 tool_calls,
310 } => {
311 let serialized_tool_calls = tool_calls
312 .iter()
313 .map(|tool_call| GrokToolCall {
314 id: tool_call.id.clone(),
315 type_: "function".to_string(),
316 function: GrokToolCallFunction {
317 name: tool_call.name.clone(),
318 arguments: tool_call.arguments.to_string(),
319 },
320 })
321 .collect::<Vec<_>>();
322
323 let assistant_content = content.as_ref().filter(|text| !text.is_empty()).cloned();
324 if assistant_content.is_none() && serialized_tool_calls.is_empty() {
325 continue;
326 }
327
328 request_messages.push(GrokRequestMessage::Assistant {
329 content: assistant_content,
330 tool_calls: if serialized_tool_calls.is_empty() {
331 None
332 } else {
333 Some(serialized_tool_calls)
334 },
335 });
336 }
337 ModelMessage::ToolResult {
338 tool_call_id,
339 tool_name: _,
340 content,
341 is_error,
342 } => {
343 let rendered = if *is_error {
344 format!("Error: {content}")
345 } else {
346 content.clone()
347 };
348
349 request_messages.push(GrokRequestMessage::Tool {
350 tool_call_id: tool_call_id.clone(),
351 content: rendered,
352 });
353 }
354 }
355 }
356
357 request_messages
358}
359
360fn ensure_non_empty_messages(mut messages: Vec<GrokRequestMessage>) -> Vec<GrokRequestMessage> {
361 let mut normalized = Vec::with_capacity(messages.len().saturating_add(1));
362 let mut pending_tool_call_ids = Vec::<String>::new();
363
364 for message in messages.drain(..) {
365 match message {
366 GrokRequestMessage::System { content } => {
367 pending_tool_call_ids.clear();
368 normalized.push(GrokRequestMessage::System { content });
369 }
370 GrokRequestMessage::User { content } => {
371 pending_tool_call_ids.clear();
372 normalized.push(GrokRequestMessage::User { content });
373 }
374 GrokRequestMessage::Assistant {
375 content,
376 tool_calls,
377 } => {
378 pending_tool_call_ids.clear();
379 if let Some(calls) = &tool_calls {
380 pending_tool_call_ids.extend(calls.iter().map(|call| call.id.clone()));
381 }
382 normalized.push(GrokRequestMessage::Assistant {
383 content,
384 tool_calls,
385 });
386 }
387 GrokRequestMessage::Tool {
388 tool_call_id,
389 content,
390 } => {
391 if let Some(position) = pending_tool_call_ids
392 .iter()
393 .position(|id| id == &tool_call_id)
394 {
395 pending_tool_call_ids.remove(position);
396 normalized.push(GrokRequestMessage::Tool {
397 tool_call_id,
398 content,
399 });
400 }
401 }
402 }
403 }
404
405 if normalized.is_empty() {
406 normalized.push(GrokRequestMessage::User {
407 content: EMPTY_USER_CONTENT_FALLBACK.to_string(),
408 });
409 return normalized;
410 }
411
412 let starts_with_valid_role = matches!(
413 normalized.first(),
414 Some(GrokRequestMessage::System { .. } | GrokRequestMessage::User { .. })
415 );
416 if !starts_with_valid_role {
417 normalized.insert(
418 0,
419 GrokRequestMessage::User {
420 content: EMPTY_USER_CONTENT_FALLBACK.to_string(),
421 },
422 );
423 }
424
425 normalized
426}
427
428fn normalize_response(
429 response: GrokChatCompletionResponse,
430) -> Result<ModelCompletion, ProviderError> {
431 let choice = response
432 .choices
433 .into_iter()
434 .next()
435 .ok_or_else(|| ProviderError::Response("grok response missing choices".to_string()))?;
436
437 let message = choice.message.ok_or_else(|| {
438 ProviderError::Response("grok response missing choice message".to_string())
439 })?;
440
441 let mut tool_calls = Vec::new();
442 for tool_call in message.tool_calls {
443 let arguments = if tool_call.function.arguments.trim().is_empty() {
444 json!({})
445 } else {
446 serde_json::from_str::<Value>(&tool_call.function.arguments).map_err(|err| {
447 ProviderError::Response(format!(
448 "grok tool call arguments for '{}' are not valid JSON: {err}",
449 tool_call.function.name
450 ))
451 })?
452 };
453
454 tool_calls.push(ModelToolCall {
455 id: tool_call.id,
456 name: tool_call.function.name,
457 arguments,
458 });
459 }
460
461 let usage = response.usage.map(|usage| ModelUsage {
462 input_tokens: usage.prompt_tokens.unwrap_or(0),
465 output_tokens: usage.completion_tokens.unwrap_or(0).saturating_add(
466 usage.reasoning_tokens.unwrap_or_else(|| {
467 usage
468 .completion_tokens_details
469 .and_then(|details| details.reasoning_tokens)
470 .unwrap_or(0)
471 }),
472 ),
473 });
474
475 Ok(ModelCompletion {
476 text: message.content.filter(|text| !text.is_empty()),
477 thinking: message.reasoning_content.filter(|text| !text.is_empty()),
478 tool_calls,
479 usage,
480 })
481}
482
483async fn extract_api_error(response: reqwest::Response) -> String {
484 let status = response.status();
485 let body = response.text().await.unwrap_or_default();
486
487 if let Ok(parsed) = serde_json::from_str::<GrokErrorEnvelope>(&body) {
488 let code = parsed
489 .error
490 .code
491 .map(|value| match value {
492 Value::String(value) => value,
493 other => other.to_string(),
494 })
495 .unwrap_or_else(|| status.as_u16().to_string());
496 let error_type = parsed
497 .error
498 .type_
499 .unwrap_or_else(|| status.to_string().to_uppercase());
500 let message = parsed
501 .error
502 .message
503 .unwrap_or_else(|| "unknown xai api error".to_string());
504
505 return format!("xai api error {code} {error_type}: {message}");
506 }
507
508 if body.is_empty() {
509 format!("xai api request failed ({status})")
510 } else {
511 format!("xai api request failed ({status}): {body}")
512 }
513}
514
// Unit tests for request building (message conversion, history sanitization,
// tool payloads) and response normalization (text, thinking, tool calls, usage).
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    // Shared fixture: a single "lookup" tool taking a required string `query`.
    fn tool_definition() -> ModelToolDefinition {
        ModelToolDefinition {
            name: "lookup".to_string(),
            description: "Look up something".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                },
                "required": ["query"],
                "additionalProperties": false
            }),
        }
    }

    // Full round trip. A system/user/assistant-with-call/tool-result history,
    // a forced tool choice and sampling overrides all appear in the body.
    #[test]
    fn build_request_serializes_messages_tools_and_tool_choice() {
        let messages = vec![
            ModelMessage::System("You are helpful".to_string()),
            ModelMessage::User("Find docs".to_string()),
            ModelMessage::Assistant {
                content: Some("Calling tool".to_string()),
                tool_calls: vec![ModelToolCall {
                    id: "call_1".to_string(),
                    name: "lookup".to_string(),
                    arguments: json!({"query": "rust"}),
                }],
            },
            ModelMessage::ToolResult {
                tool_call_id: "call_1".to_string(),
                tool_name: "lookup".to_string(),
                content: "{\"result\":\"ok\"}".to_string(),
                is_error: false,
            },
        ];

        let mut config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");
        config.temperature = Some(0.2);
        config.max_tokens = Some(512);

        let request = build_request(
            &messages,
            &[tool_definition()],
            ModelToolChoice::Tool("lookup".to_string()),
            &config,
        );
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(value["messages"][0]["role"], "system");
        assert_eq!(value["messages"][0]["content"], "You are helpful");
        assert_eq!(value["messages"][2]["role"], "assistant");
        assert_eq!(
            value["messages"][2]["tool_calls"][0]["function"]["name"],
            "lookup"
        );
        // Arguments travel as a JSON-encoded string, not as a nested object.
        assert_eq!(
            value["messages"][2]["tool_calls"][0]["function"]["arguments"],
            "{\"query\":\"rust\"}"
        );
        assert_eq!(value["messages"][3]["role"], "tool");
        assert_eq!(value["messages"][3]["tool_call_id"], "call_1");
        assert_eq!(value["tools"][0]["function"]["name"], "lookup");
        assert_eq!(value["tool_choice"]["type"], "function");
        assert_eq!(value["tool_choice"]["function"]["name"], "lookup");
        // `temperature` is an f32, so compare with a tolerance after widening to f64.
        assert!((value["temperature"].as_f64().unwrap_or_default() - 0.2).abs() < 1e-6);
        assert_eq!(value["max_tokens"], 512);
    }

    // An empty user message is dropped, so the placeholder becomes the only
    // message. With no tools, `tools` and `tool_choice` are omitted entirely.
    #[test]
    fn build_request_adds_fallback_content_for_empty_user_message() {
        let messages = vec![ModelMessage::User(String::new())];
        let config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(
            value["messages"].as_array().map(|values| values.len()),
            Some(1)
        );
        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], " ");
        assert!(value.get("tools").is_none());
        assert!(value.get("tool_choice").is_none());
    }

    // A tool result with no preceding assistant call is dropped, leaving only
    // the placeholder user message.
    #[test]
    fn build_request_inserts_fallback_and_drops_orphan_tool_messages() {
        let messages = vec![ModelMessage::ToolResult {
            tool_call_id: "call_1".to_string(),
            tool_name: "lookup".to_string(),
            content: "result".to_string(),
            is_error: false,
        }];
        let config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(
            value["messages"].as_array().map(|values| values.len()),
            Some(1)
        );
        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], " ");
    }

    // Dropping the empty user message would leave the assistant turn first,
    // so a placeholder user message is prepended. The matching tool result
    // is kept.
    #[test]
    fn build_request_inserts_fallback_when_first_message_is_assistant() {
        let messages = vec![
            ModelMessage::User(String::new()),
            ModelMessage::Assistant {
                content: Some("Calling tool".to_string()),
                tool_calls: vec![ModelToolCall {
                    id: "call_1".to_string(),
                    name: "lookup".to_string(),
                    arguments: json!({"query": "rust"}),
                }],
            },
            ModelMessage::ToolResult {
                tool_call_id: "call_1".to_string(),
                tool_name: "lookup".to_string(),
                content: "{\"result\":\"ok\"}".to_string(),
                is_error: false,
            },
        ];
        let config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], " ");
        assert_eq!(value["messages"][1]["role"], "assistant");
        assert_eq!(value["messages"][2]["role"], "tool");
        assert_eq!(value["messages"][2]["tool_call_id"], "call_1");
    }

    // With no top-level reasoning count, the nested one is used:
    // output_tokens = completion (7) + completion_tokens_details.reasoning (3).
    #[test]
    fn normalize_response_extracts_text_thinking_tool_calls_and_usage() {
        let response = GrokChatCompletionResponse {
            choices: vec![GrokChoice {
                message: Some(GrokAssistantMessage {
                    content: Some("answer".to_string()),
                    tool_calls: vec![GrokToolCall {
                        id: "call_x".to_string(),
                        type_: "function".to_string(),
                        function: GrokToolCallFunction {
                            name: "lookup".to_string(),
                            arguments: "{\"q\":\"rust\"}".to_string(),
                        },
                    }],
                    reasoning_content: Some("reasoning".to_string()),
                }),
            }],
            usage: Some(GrokUsage {
                prompt_tokens: Some(11),
                completion_tokens: Some(7),
                reasoning_tokens: None,
                completion_tokens_details: Some(GrokCompletionTokenDetails {
                    reasoning_tokens: Some(3),
                }),
            }),
        };

        let completion = normalize_response(response).expect("response normalizes");

        assert_eq!(completion.text.as_deref(), Some("answer"));
        assert_eq!(completion.thinking.as_deref(), Some("reasoning"));
        assert_eq!(completion.tool_calls.len(), 1);
        assert_eq!(completion.tool_calls[0].name, "lookup");
        assert_eq!(completion.tool_calls[0].id, "call_x");
        assert_eq!(
            completion.usage,
            Some(ModelUsage {
                input_tokens: 11,
                output_tokens: 10,
            })
        );
    }

    // The top-level reasoning count (4) wins over the nested one (3): 7 + 4.
    #[test]
    fn normalize_response_prefers_top_level_reasoning_tokens() {
        let response = GrokChatCompletionResponse {
            choices: vec![GrokChoice {
                message: Some(GrokAssistantMessage {
                    content: Some("answer".to_string()),
                    tool_calls: Vec::new(),
                    reasoning_content: None,
                }),
            }],
            usage: Some(GrokUsage {
                prompt_tokens: Some(11),
                completion_tokens: Some(7),
                reasoning_tokens: Some(4),
                completion_tokens_details: Some(GrokCompletionTokenDetails {
                    reasoning_tokens: Some(3),
                }),
            }),
        };

        let completion = normalize_response(response).expect("response normalizes");

        assert_eq!(
            completion.usage,
            Some(ModelUsage {
                input_tokens: 11,
                output_tokens: 11,
            })
        );
    }

    // A response with no choices is a Response error, not an empty completion.
    #[test]
    fn normalize_response_requires_choices() {
        let err = normalize_response(GrokChatCompletionResponse {
            choices: Vec::new(),
            usage: None,
        })
        .expect_err("should fail");

        match err {
            ProviderError::Response(message) => {
                assert!(message.contains("missing choices"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    // Malformed tool-call arguments are surfaced as an error, not silently replaced.
    #[test]
    fn normalize_response_fails_on_invalid_tool_arguments() {
        let err = normalize_response(GrokChatCompletionResponse {
            choices: vec![GrokChoice {
                message: Some(GrokAssistantMessage {
                    content: None,
                    tool_calls: vec![GrokToolCall {
                        id: "call_x".to_string(),
                        type_: "function".to_string(),
                        function: GrokToolCallFunction {
                            name: "lookup".to_string(),
                            arguments: "{not json}".to_string(),
                        },
                    }],
                    reasoning_content: None,
                }),
            }],
            usage: None,
        })
        .expect_err("should fail");

        match err {
            ProviderError::Response(message) => {
                assert!(message.contains("not valid JSON"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }
}