1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use crate::error::ProviderError;
7use crate::llm::{
8 ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
9 ModelUsage,
10};
11
/// Base URL used when the configuration does not provide `api_base_url`.
const DEFAULT_API_BASE_URL: &str = "https://api.x.ai/v1";
/// Content of the placeholder user message that is inserted when the outgoing
/// message list is empty or does not begin with a system or user message.
const EMPTY_USER_CONTENT_FALLBACK: &str = " ";
14
/// Settings used to build chat completion requests.
///
/// `Debug` is implemented by hand so that the API key never appears in log or
/// panic output.
#[derive(Clone)]
pub struct GrokModelConfig {
    /// Secret sent as the bearer token on every request.
    pub api_key: String,
    /// Model identifier placed in the request's `model` field.
    pub model: String,
    /// Optional replacement for the default API base URL.
    pub api_base_url: Option<String>,
    /// Sampling temperature; omitted from the request when `None`.
    pub temperature: Option<f32>,
    /// Nucleus sampling parameter; omitted from the request when `None`.
    pub top_p: Option<f32>,
    /// Completion token limit; omitted from the request when `None`.
    pub max_tokens: Option<u32>,
}

impl GrokModelConfig {
    /// Creates a configuration for `model` authenticated with `api_key`.
    ///
    /// `max_tokens` starts at `Some(4096)`; every other optional field starts
    /// as `None`.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_base_url: None,
            temperature: None,
            top_p: None,
            max_tokens: Some(4096),
        }
    }
}

impl std::fmt::Debug for GrokModelConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrokModelConfig")
            // Never print the secret itself; a derived impl would leak it.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("api_base_url", &self.api_base_url)
            .field("temperature", &self.temperature)
            .field("top_p", &self.top_p)
            .field("max_tokens", &self.max_tokens)
            .finish()
    }
}
45
46#[derive(Debug, Clone)]
47pub struct GrokModel {
49 client: Client,
50 config: GrokModelConfig,
51}
52
53impl GrokModel {
54 pub fn new(config: GrokModelConfig) -> Result<Self, ProviderError> {
56 let client = Client::builder()
57 .build()
58 .map_err(|err| ProviderError::Request(err.to_string()))?;
59
60 Ok(Self { client, config })
61 }
62
63 pub fn from_env(model: impl Into<String>) -> Result<Self, ProviderError> {
65 let api_key = std::env::var("XAI_API_KEY")
66 .or_else(|_| std::env::var("GROK_API_KEY"))
67 .map_err(|_| {
68 ProviderError::Request("XAI_API_KEY (or GROK_API_KEY) is not set".to_string())
69 })?;
70
71 Self::new(GrokModelConfig::new(api_key, model))
72 }
73
74 fn endpoint(&self) -> String {
75 let base = self
76 .config
77 .api_base_url
78 .as_deref()
79 .unwrap_or(DEFAULT_API_BASE_URL)
80 .trim_end_matches('/');
81 format!("{base}/chat/completions")
82 }
83}
84
85#[async_trait]
86impl ChatModel for GrokModel {
87 async fn invoke(
88 &self,
89 messages: &[ModelMessage],
90 tools: &[ModelToolDefinition],
91 tool_choice: ModelToolChoice,
92 ) -> Result<ModelCompletion, ProviderError> {
93 let request = build_request(messages, tools, tool_choice, &self.config);
94
95 let response = self
96 .client
97 .post(self.endpoint())
98 .header("authorization", format!("Bearer {}", self.config.api_key))
99 .header("content-type", "application/json")
100 .json(&request)
101 .send()
102 .await
103 .map_err(|err| ProviderError::Request(err.to_string()))?;
104
105 if !response.status().is_success() {
106 return Err(ProviderError::Request(extract_api_error(response).await));
107 }
108
109 let payload = response
110 .json::<GrokChatCompletionResponse>()
111 .await
112 .map_err(|err| ProviderError::Response(err.to_string()))?;
113
114 normalize_response(payload)
115 }
116}
117
/// JSON body sent to the chat completions endpoint.
///
/// Every `Option` field is left out of the serialized payload when it is
/// `None`.
#[derive(Debug, Serialize)]
struct GrokChatCompletionRequest {
    /// Model identifier.
    model: String,
    /// Conversation messages, in order.
    messages: Vec<GrokRequestMessage>,
    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GrokToolDefinition>>,
    /// How the model may or must use the offered tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<GrokToolChoicePayload>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}
133
/// One conversation message in the request payload.
///
/// Serialized as an internally tagged object: the `role` field holds the
/// lowercased variant name (`system`, `user`, `assistant`, `tool`) and the
/// variant's fields sit alongside it.
#[derive(Debug, Serialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum GrokRequestMessage {
    /// System prompt text.
    System {
        content: String,
    },
    /// Text written by the user.
    User {
        content: String,
    },
    /// A prior assistant turn; each field is omitted when `None`.
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<GrokToolCall>>,
    },
    /// Result of a tool call, linked to the call by `tool_call_id`.
    Tool {
        tool_call_id: String,
        content: String,
    },
}
154
/// A tool offered to the model in the request's `tools` array.
#[derive(Debug, Serialize)]
struct GrokToolDefinition {
    /// Serialized as `type`; the trailing underscore avoids the Rust keyword.
    #[serde(rename = "type")]
    type_: String,
    /// Name, description and parameter schema of the tool.
    function: GrokToolFunctionDefinition,
}
161
/// Function description nested inside a tool definition.
#[derive(Debug, Serialize)]
struct GrokToolFunctionDefinition {
    name: String,
    description: String,
    /// Arbitrary JSON value describing the function's parameters.
    parameters: Value,
}
168
/// Value of the request's `tool_choice` field.
///
/// Untagged: `Mode` serializes as a bare JSON string, `Specific` as an object
/// with `type` and `function` keys.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum GrokToolChoicePayload {
    /// A mode keyword given as a plain string.
    Mode(String),
    /// Selects one particular function by name.
    Specific {
        #[serde(rename = "type")]
        type_: String,
        function: GrokToolChoiceFunction,
    },
}
179
/// Names the function selected by a specific tool choice.
#[derive(Debug, Serialize)]
struct GrokToolChoiceFunction {
    name: String,
}
184
/// A tool call, used both when serializing assistant messages into a request
/// and when deserializing the assistant message of a response.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct GrokToolCall {
    /// Identifier that tool result messages refer back to.
    id: String,
    /// Mapped to the JSON key `type`.
    #[serde(rename = "type")]
    type_: String,
    function: GrokToolCallFunction,
}
192
/// Function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct GrokToolCallFunction {
    name: String,
    /// Arguments as a string of JSON text rather than a nested JSON value.
    arguments: String,
}
198
/// Top-level body of a successful chat completions response.
#[derive(Debug, Deserialize)]
struct GrokChatCompletionResponse {
    /// Returned choices; an absent field deserializes to an empty list.
    #[serde(default)]
    choices: Vec<GrokChoice>,
    /// Token accounting, when the response includes it.
    usage: Option<GrokUsage>,
}
205
/// One entry of the response's `choices` array.
#[derive(Debug, Deserialize)]
struct GrokChoice {
    /// The assistant message for this choice, if present.
    message: Option<GrokAssistantMessage>,
}
210
211#[derive(Debug, Deserialize)]
212struct GrokAssistantMessage {
213 content: Option<String>,
214 #[serde(default)]
215 tool_calls: Vec<GrokToolCall>,
216 #[serde(default)]
217 reasoning_content: Option<String>,
218}
219
/// Token accounting reported with a response; every field is optional.
#[derive(Debug, Deserialize)]
struct GrokUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    /// Reasoning token count reported at the top level of `usage`.
    reasoning_tokens: Option<u32>,
    /// Nested details that may carry their own reasoning token count.
    completion_tokens_details: Option<GrokCompletionTokenDetails>,
}
227
/// Contents of `usage.completion_tokens_details`.
#[derive(Debug, Deserialize)]
struct GrokCompletionTokenDetails {
    reasoning_tokens: Option<u32>,
}
232
/// Error response body of the form `{"error": {...}}`.
#[derive(Debug, Deserialize)]
struct GrokErrorEnvelope {
    error: GrokApiError,
}
237
/// Details object inside an error envelope; every field is optional.
#[derive(Debug, Deserialize)]
struct GrokApiError {
    message: Option<String>,
    /// Mapped from the JSON key `type`.
    #[serde(rename = "type")]
    type_: Option<String>,
    /// Kept as a raw JSON value, so it accepts a string or any other JSON type.
    code: Option<Value>,
}
245
246fn build_request(
247 messages: &[ModelMessage],
248 tools: &[ModelToolDefinition],
249 tool_choice: ModelToolChoice,
250 config: &GrokModelConfig,
251) -> GrokChatCompletionRequest {
252 let request_messages = ensure_non_empty_messages(to_grok_messages(messages));
253
254 let tools_payload = if tools.is_empty() {
255 None
256 } else {
257 Some(
258 tools
259 .iter()
260 .map(|tool| GrokToolDefinition {
261 type_: "function".to_string(),
262 function: GrokToolFunctionDefinition {
263 name: tool.name.clone(),
264 description: tool.description.clone(),
265 parameters: tool.parameters.clone(),
266 },
267 })
268 .collect::<Vec<_>>(),
269 )
270 };
271
272 let tool_choice_payload = if tools.is_empty() {
273 None
274 } else {
275 Some(match tool_choice {
276 ModelToolChoice::Auto => GrokToolChoicePayload::Mode("auto".to_string()),
277 ModelToolChoice::Required => GrokToolChoicePayload::Mode("required".to_string()),
278 ModelToolChoice::None => GrokToolChoicePayload::Mode("none".to_string()),
279 ModelToolChoice::Tool(name) => GrokToolChoicePayload::Specific {
280 type_: "function".to_string(),
281 function: GrokToolChoiceFunction { name },
282 },
283 })
284 };
285
286 GrokChatCompletionRequest {
287 model: config.model.clone(),
288 messages: request_messages,
289 tools: tools_payload,
290 tool_choice: tool_choice_payload,
291 temperature: config.temperature,
292 top_p: config.top_p,
293 max_tokens: config.max_tokens,
294 }
295}
296
297fn to_grok_messages(messages: &[ModelMessage]) -> Vec<GrokRequestMessage> {
298 let mut request_messages = Vec::new();
299
300 for message in messages {
301 match message {
302 ModelMessage::System(content) => {
303 if content.is_empty() {
304 continue;
305 }
306 request_messages.push(GrokRequestMessage::System {
307 content: content.clone(),
308 });
309 }
310 ModelMessage::User(content) => {
311 if content.is_empty() {
312 continue;
313 }
314 request_messages.push(GrokRequestMessage::User {
315 content: content.clone(),
316 });
317 }
318 ModelMessage::Assistant {
319 content,
320 tool_calls,
321 } => {
322 let serialized_tool_calls = tool_calls
323 .iter()
324 .map(|tool_call| GrokToolCall {
325 id: tool_call.id.clone(),
326 type_: "function".to_string(),
327 function: GrokToolCallFunction {
328 name: tool_call.name.clone(),
329 arguments: tool_call.arguments.to_string(),
330 },
331 })
332 .collect::<Vec<_>>();
333
334 let assistant_content = content.as_ref().filter(|text| !text.is_empty()).cloned();
335 if assistant_content.is_none() && serialized_tool_calls.is_empty() {
336 continue;
337 }
338
339 request_messages.push(GrokRequestMessage::Assistant {
340 content: assistant_content,
341 tool_calls: if serialized_tool_calls.is_empty() {
342 None
343 } else {
344 Some(serialized_tool_calls)
345 },
346 });
347 }
348 ModelMessage::ToolResult {
349 tool_call_id,
350 tool_name: _,
351 content,
352 is_error,
353 } => {
354 let rendered = if *is_error {
355 format!("Error: {content}")
356 } else {
357 content.clone()
358 };
359
360 request_messages.push(GrokRequestMessage::Tool {
361 tool_call_id: tool_call_id.clone(),
362 content: rendered,
363 });
364 }
365 }
366 }
367
368 request_messages
369}
370
371fn ensure_non_empty_messages(mut messages: Vec<GrokRequestMessage>) -> Vec<GrokRequestMessage> {
372 let mut normalized = Vec::with_capacity(messages.len().saturating_add(1));
373 let mut pending_tool_call_ids = Vec::<String>::new();
374
375 for message in messages.drain(..) {
376 match message {
377 GrokRequestMessage::System { content } => {
378 pending_tool_call_ids.clear();
379 normalized.push(GrokRequestMessage::System { content });
380 }
381 GrokRequestMessage::User { content } => {
382 pending_tool_call_ids.clear();
383 normalized.push(GrokRequestMessage::User { content });
384 }
385 GrokRequestMessage::Assistant {
386 content,
387 tool_calls,
388 } => {
389 pending_tool_call_ids.clear();
390 if let Some(calls) = &tool_calls {
391 pending_tool_call_ids.extend(calls.iter().map(|call| call.id.clone()));
392 }
393 normalized.push(GrokRequestMessage::Assistant {
394 content,
395 tool_calls,
396 });
397 }
398 GrokRequestMessage::Tool {
399 tool_call_id,
400 content,
401 } => {
402 if let Some(position) = pending_tool_call_ids
403 .iter()
404 .position(|id| id == &tool_call_id)
405 {
406 pending_tool_call_ids.remove(position);
407 normalized.push(GrokRequestMessage::Tool {
408 tool_call_id,
409 content,
410 });
411 }
412 }
413 }
414 }
415
416 if normalized.is_empty() {
417 normalized.push(GrokRequestMessage::User {
418 content: EMPTY_USER_CONTENT_FALLBACK.to_string(),
419 });
420 return normalized;
421 }
422
423 let starts_with_valid_role = matches!(
424 normalized.first(),
425 Some(GrokRequestMessage::System { .. } | GrokRequestMessage::User { .. })
426 );
427 if !starts_with_valid_role {
428 normalized.insert(
429 0,
430 GrokRequestMessage::User {
431 content: EMPTY_USER_CONTENT_FALLBACK.to_string(),
432 },
433 );
434 }
435
436 normalized
437}
438
439fn normalize_response(
440 response: GrokChatCompletionResponse,
441) -> Result<ModelCompletion, ProviderError> {
442 let choice = response
443 .choices
444 .into_iter()
445 .next()
446 .ok_or_else(|| ProviderError::Response("grok response missing choices".to_string()))?;
447
448 let message = choice.message.ok_or_else(|| {
449 ProviderError::Response("grok response missing choice message".to_string())
450 })?;
451
452 let mut tool_calls = Vec::new();
453 for tool_call in message.tool_calls {
454 let arguments = if tool_call.function.arguments.trim().is_empty() {
455 json!({})
456 } else {
457 serde_json::from_str::<Value>(&tool_call.function.arguments).map_err(|err| {
458 ProviderError::Response(format!(
459 "grok tool call arguments for '{}' are not valid JSON: {err}",
460 tool_call.function.name
461 ))
462 })?
463 };
464
465 tool_calls.push(ModelToolCall {
466 id: tool_call.id,
467 name: tool_call.function.name,
468 arguments,
469 });
470 }
471
472 let usage = response.usage.map(|usage| ModelUsage {
473 input_tokens: usage.prompt_tokens.unwrap_or(0),
476 output_tokens: usage.completion_tokens.unwrap_or(0).saturating_add(
477 usage.reasoning_tokens.unwrap_or_else(|| {
478 usage
479 .completion_tokens_details
480 .and_then(|details| details.reasoning_tokens)
481 .unwrap_or(0)
482 }),
483 ),
484 });
485
486 Ok(ModelCompletion {
487 text: message.content.filter(|text| !text.is_empty()),
488 thinking: message.reasoning_content.filter(|text| !text.is_empty()),
489 tool_calls,
490 usage,
491 })
492}
493
494async fn extract_api_error(response: reqwest::Response) -> String {
495 let status = response.status();
496 let body = response.text().await.unwrap_or_default();
497
498 if let Ok(parsed) = serde_json::from_str::<GrokErrorEnvelope>(&body) {
499 let code = parsed
500 .error
501 .code
502 .map(|value| match value {
503 Value::String(value) => value,
504 other => other.to_string(),
505 })
506 .unwrap_or_else(|| status.as_u16().to_string());
507 let error_type = parsed
508 .error
509 .type_
510 .unwrap_or_else(|| status.to_string().to_uppercase());
511 let message = parsed
512 .error
513 .message
514 .unwrap_or_else(|| "unknown xai api error".to_string());
515
516 return format!("xai api error {code} {error_type}: {message}");
517 }
518
519 if body.is_empty() {
520 format!("xai api request failed ({status})")
521 } else {
522 format!("xai api request failed ({status}): {body}")
523 }
524}
525
// Unit tests for request construction and response normalization. None of
// them performs network I/O.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    // Shared fixture: a `lookup` tool that takes a single string `query`.
    fn tool_definition() -> ModelToolDefinition {
        ModelToolDefinition {
            name: "lookup".to_string(),
            description: "Look up something".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                },
                "required": ["query"],
                "additionalProperties": false
            }),
        }
    }

    // A full conversation (system, user, assistant tool call, tool result)
    // serializes with the expected roles, tool payloads and sampling fields.
    #[test]
    fn build_request_serializes_messages_tools_and_tool_choice() {
        let messages = vec![
            ModelMessage::System("You are helpful".to_string()),
            ModelMessage::User("Find docs".to_string()),
            ModelMessage::Assistant {
                content: Some("Calling tool".to_string()),
                tool_calls: vec![ModelToolCall {
                    id: "call_1".to_string(),
                    name: "lookup".to_string(),
                    arguments: json!({"query": "rust"}),
                }],
            },
            ModelMessage::ToolResult {
                tool_call_id: "call_1".to_string(),
                tool_name: "lookup".to_string(),
                content: "{\"result\":\"ok\"}".to_string(),
                is_error: false,
            },
        ];

        let mut config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");
        config.temperature = Some(0.2);
        config.max_tokens = Some(512);

        let request = build_request(
            &messages,
            &[tool_definition()],
            ModelToolChoice::Tool("lookup".to_string()),
            &config,
        );
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(value["messages"][0]["role"], "system");
        assert_eq!(value["messages"][0]["content"], "You are helpful");
        assert_eq!(value["messages"][2]["role"], "assistant");
        assert_eq!(
            value["messages"][2]["tool_calls"][0]["function"]["name"],
            "lookup"
        );
        // Arguments are sent as JSON text, not as a nested object.
        assert_eq!(
            value["messages"][2]["tool_calls"][0]["function"]["arguments"],
            "{\"query\":\"rust\"}"
        );
        assert_eq!(value["messages"][3]["role"], "tool");
        assert_eq!(value["messages"][3]["tool_call_id"], "call_1");
        assert_eq!(value["tools"][0]["function"]["name"], "lookup");
        assert_eq!(value["tool_choice"]["type"], "function");
        assert_eq!(value["tool_choice"]["function"]["name"], "lookup");
        // The temperature is an f32, so compare with a tolerance.
        assert!((value["temperature"].as_f64().unwrap_or_default() - 0.2).abs() < 1e-6);
        assert_eq!(value["max_tokens"], 512);
    }

    // An empty user message is dropped and replaced by the placeholder; with
    // no tools, `tools` and `tool_choice` are absent from the payload.
    #[test]
    fn build_request_adds_fallback_content_for_empty_user_message() {
        let messages = vec![ModelMessage::User(String::new())];
        let config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(
            value["messages"].as_array().map(|values| values.len()),
            Some(1)
        );
        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], " ");
        assert!(value.get("tools").is_none());
        assert!(value.get("tool_choice").is_none());
    }

    // A tool result with no preceding assistant tool call is discarded,
    // leaving only the placeholder user message.
    #[test]
    fn build_request_inserts_fallback_and_drops_orphan_tool_messages() {
        let messages = vec![ModelMessage::ToolResult {
            tool_call_id: "call_1".to_string(),
            tool_name: "lookup".to_string(),
            content: "result".to_string(),
            is_error: false,
        }];
        let config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(
            value["messages"].as_array().map(|values| values.len()),
            Some(1)
        );
        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], " ");
    }

    // Once the empty user message is dropped the conversation would start with
    // an assistant message, so a placeholder user message is put in front.
    #[test]
    fn build_request_inserts_fallback_when_first_message_is_assistant() {
        let messages = vec![
            ModelMessage::User(String::new()),
            ModelMessage::Assistant {
                content: Some("Calling tool".to_string()),
                tool_calls: vec![ModelToolCall {
                    id: "call_1".to_string(),
                    name: "lookup".to_string(),
                    arguments: json!({"query": "rust"}),
                }],
            },
            ModelMessage::ToolResult {
                tool_call_id: "call_1".to_string(),
                tool_name: "lookup".to_string(),
                content: "{\"result\":\"ok\"}".to_string(),
                is_error: false,
            },
        ];
        let config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], " ");
        assert_eq!(value["messages"][1]["role"], "assistant");
        assert_eq!(value["messages"][2]["role"], "tool");
        assert_eq!(value["messages"][2]["tool_call_id"], "call_1");
    }

    // With no top-level reasoning count, the nested one is added to the
    // completion tokens: 7 + 3 = 10.
    #[test]
    fn normalize_response_extracts_text_thinking_tool_calls_and_usage() {
        let response = GrokChatCompletionResponse {
            choices: vec![GrokChoice {
                message: Some(GrokAssistantMessage {
                    content: Some("answer".to_string()),
                    tool_calls: vec![GrokToolCall {
                        id: "call_x".to_string(),
                        type_: "function".to_string(),
                        function: GrokToolCallFunction {
                            name: "lookup".to_string(),
                            arguments: "{\"q\":\"rust\"}".to_string(),
                        },
                    }],
                    reasoning_content: Some("reasoning".to_string()),
                }),
            }],
            usage: Some(GrokUsage {
                prompt_tokens: Some(11),
                completion_tokens: Some(7),
                reasoning_tokens: None,
                completion_tokens_details: Some(GrokCompletionTokenDetails {
                    reasoning_tokens: Some(3),
                }),
            }),
        };

        let completion = normalize_response(response).expect("response normalizes");

        assert_eq!(completion.text.as_deref(), Some("answer"));
        assert_eq!(completion.thinking.as_deref(), Some("reasoning"));
        assert_eq!(completion.tool_calls.len(), 1);
        assert_eq!(completion.tool_calls[0].name, "lookup");
        assert_eq!(completion.tool_calls[0].id, "call_x");
        assert_eq!(
            completion.usage,
            Some(ModelUsage {
                input_tokens: 11,
                output_tokens: 10,
            })
        );
    }

    // When both counts are present the top-level one (4) is used instead of
    // the nested one (3): 7 + 4 = 11.
    #[test]
    fn normalize_response_prefers_top_level_reasoning_tokens() {
        let response = GrokChatCompletionResponse {
            choices: vec![GrokChoice {
                message: Some(GrokAssistantMessage {
                    content: Some("answer".to_string()),
                    tool_calls: Vec::new(),
                    reasoning_content: None,
                }),
            }],
            usage: Some(GrokUsage {
                prompt_tokens: Some(11),
                completion_tokens: Some(7),
                reasoning_tokens: Some(4),
                completion_tokens_details: Some(GrokCompletionTokenDetails {
                    reasoning_tokens: Some(3),
                }),
            }),
        };

        let completion = normalize_response(response).expect("response normalizes");

        assert_eq!(
            completion.usage,
            Some(ModelUsage {
                input_tokens: 11,
                output_tokens: 11,
            })
        );
    }

    // A response without any choice is rejected as a response error.
    #[test]
    fn normalize_response_requires_choices() {
        let err = normalize_response(GrokChatCompletionResponse {
            choices: Vec::new(),
            usage: None,
        })
        .expect_err("should fail");

        match err {
            ProviderError::Response(message) => {
                assert!(message.contains("missing choices"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    // Tool call arguments that are not valid JSON fail normalization.
    #[test]
    fn normalize_response_fails_on_invalid_tool_arguments() {
        let err = normalize_response(GrokChatCompletionResponse {
            choices: vec![GrokChoice {
                message: Some(GrokAssistantMessage {
                    content: None,
                    tool_calls: vec![GrokToolCall {
                        id: "call_x".to_string(),
                        type_: "function".to_string(),
                        function: GrokToolCallFunction {
                            name: "lookup".to_string(),
                            arguments: "{not json}".to_string(),
                        },
                    }],
                    reasoning_content: None,
                }),
            }],
            usage: None,
        })
        .expect_err("should fail");

        match err {
            ProviderError::Response(message) => {
                assert!(message.contains("not valid JSON"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }
}