Skip to main content

agent_sdk_rs/llm/
grok.rs

1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use crate::error::ProviderError;
7use crate::llm::{
8    ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
9    ModelUsage,
10};
11
/// Base URL used when `GrokModelConfig::api_base_url` is `None`.
const DEFAULT_API_BASE_URL: &str = "https://api.x.ai/v1";
/// Content of the placeholder user message that `ensure_non_empty_messages`
/// inserts when the transcript is empty or does not begin with a system/user
/// message.
const EMPTY_USER_CONTENT_FALLBACK: &str = " ";
14
/// Connection and sampling settings for a [`GrokModel`].
///
/// `Debug` is implemented by hand instead of derived so the API key never
/// shows up in logs, panic messages, or the `{:?}` output of types that embed
/// this config. All other fields are printed normally.
#[derive(Clone)]
pub struct GrokModelConfig {
    /// xAI API key, sent as a bearer token on every request.
    pub api_key: String,
    /// Model identifier sent in the request body (e.g. `grok-4-1-fast-reasoning`).
    pub model: String,
    /// Base URL override; `None` uses `https://api.x.ai/v1`. Trailing `/` are ignored.
    pub api_base_url: Option<String>,
    /// Forwarded as `temperature`; omitted from the request when `None`.
    pub temperature: Option<f32>,
    /// Forwarded as `top_p`; omitted from the request when `None`.
    pub top_p: Option<f32>,
    /// Forwarded as `max_tokens`; omitted from the request when `None`.
    pub max_tokens: Option<u32>,
}

impl GrokModelConfig {
    /// Creates a config for `model` authenticated with `api_key`.
    ///
    /// Uses the default base URL, leaves `temperature` and `top_p` unset, and
    /// caps `max_tokens` at 4096.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_base_url: None,
            temperature: None,
            top_p: None,
            max_tokens: Some(4096),
        }
    }
}

impl std::fmt::Debug for GrokModelConfig {
    /// Formats like the derived impl, except that `api_key` is replaced by a
    /// fixed `<redacted>` marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrokModelConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("api_base_url", &self.api_base_url)
            .field("temperature", &self.temperature)
            .field("top_p", &self.top_p)
            .field("max_tokens", &self.max_tokens)
            .finish()
    }
}
37
38#[derive(Debug, Clone)]
39pub struct GrokModel {
40    client: Client,
41    config: GrokModelConfig,
42}
43
44impl GrokModel {
45    pub fn new(config: GrokModelConfig) -> Result<Self, ProviderError> {
46        let client = Client::builder()
47            .build()
48            .map_err(|err| ProviderError::Request(err.to_string()))?;
49
50        Ok(Self { client, config })
51    }
52
53    pub fn from_env(model: impl Into<String>) -> Result<Self, ProviderError> {
54        let api_key = std::env::var("XAI_API_KEY")
55            .or_else(|_| std::env::var("GROK_API_KEY"))
56            .map_err(|_| {
57                ProviderError::Request("XAI_API_KEY (or GROK_API_KEY) is not set".to_string())
58            })?;
59
60        Self::new(GrokModelConfig::new(api_key, model))
61    }
62
63    fn endpoint(&self) -> String {
64        let base = self
65            .config
66            .api_base_url
67            .as_deref()
68            .unwrap_or(DEFAULT_API_BASE_URL)
69            .trim_end_matches('/');
70        format!("{base}/chat/completions")
71    }
72}
73
74#[async_trait]
75impl ChatModel for GrokModel {
76    async fn invoke(
77        &self,
78        messages: &[ModelMessage],
79        tools: &[ModelToolDefinition],
80        tool_choice: ModelToolChoice,
81    ) -> Result<ModelCompletion, ProviderError> {
82        let request = build_request(messages, tools, tool_choice, &self.config);
83
84        let response = self
85            .client
86            .post(self.endpoint())
87            .header("authorization", format!("Bearer {}", self.config.api_key))
88            .header("content-type", "application/json")
89            .json(&request)
90            .send()
91            .await
92            .map_err(|err| ProviderError::Request(err.to_string()))?;
93
94        if !response.status().is_success() {
95            return Err(ProviderError::Request(extract_api_error(response).await));
96        }
97
98        let payload = response
99            .json::<GrokChatCompletionResponse>()
100            .await
101            .map_err(|err| ProviderError::Response(err.to_string()))?;
102
103        normalize_response(payload)
104    }
105}
106
/// JSON body POSTed to the `/chat/completions` endpoint.
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`.
#[derive(Debug, Serialize)]
struct GrokChatCompletionRequest {
    model: String,
    messages: Vec<GrokRequestMessage>,
    // `build_request` leaves `tools` and `tool_choice` as `None` when no tools
    // were supplied, so neither key is sent in that case.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GrokToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<GrokToolChoicePayload>,
    // Sampling settings, copied from `GrokModelConfig` by `build_request`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}
122
/// One message in the request transcript.
///
/// Serialized as an object whose `role` key holds the lowercased variant name
/// (`"system"`, `"user"`, `"assistant"`, `"tool"`), alongside the variant's fields.
#[derive(Debug, Serialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum GrokRequestMessage {
    System {
        content: String,
    },
    User {
        content: String,
    },
    // Either field may be absent; `None` values are omitted from the JSON.
    // `to_grok_messages` never produces an assistant message with both absent.
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<GrokToolCall>>,
    },
    // Result of a tool call. `ensure_non_empty_messages` drops any whose
    // `tool_call_id` does not match an open call from the preceding assistant turn.
    Tool {
        tool_call_id: String,
        content: String,
    },
}
143
/// A tool advertised to the model. `type_` is serialized as `"type"` and is
/// always `"function"` in requests built by `build_request`.
#[derive(Debug, Serialize)]
struct GrokToolDefinition {
    #[serde(rename = "type")]
    type_: String,
    function: GrokToolFunctionDefinition,
}

/// Name, description and parameter schema of an advertised tool.
#[derive(Debug, Serialize)]
struct GrokToolFunctionDefinition {
    name: String,
    description: String,
    // Passed through unchanged from `ModelToolDefinition::parameters`
    // (a JSON Schema object in the unit tests).
    parameters: Value,
}
157
/// Value of the `tool_choice` request field.
///
/// Because of `#[serde(untagged)]`, `Mode` serializes as a bare string
/// (`build_request` uses `"auto"`, `"required"`, or `"none"`). `Specific`
/// serializes as `{"type": "function", "function": {"name": ...}}` and selects
/// one named tool.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum GrokToolChoicePayload {
    Mode(String),
    Specific {
        #[serde(rename = "type")]
        type_: String,
        function: GrokToolChoiceFunction,
    },
}

/// Names the tool selected by `GrokToolChoicePayload::Specific`.
#[derive(Debug, Serialize)]
struct GrokToolChoiceFunction {
    name: String,
}
173
/// A tool call in wire format. It is serialized when replaying assistant turns
/// in a request, and deserialized from response messages.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct GrokToolCall {
    id: String,
    #[serde(rename = "type")]
    type_: String,
    function: GrokToolCallFunction,
}

/// Function name plus its call arguments.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct GrokToolCallFunction {
    name: String,
    // A JSON-encoded *string*, not a JSON object. `to_grok_messages` encodes it
    // with `Value::to_string` and `normalize_response` parses it back.
    arguments: String,
}
187
188#[derive(Debug, Deserialize)]
189struct GrokChatCompletionResponse {
190    #[serde(default)]
191    choices: Vec<GrokChoice>,
192    usage: Option<GrokUsage>,
193}
194
195#[derive(Debug, Deserialize)]
196struct GrokChoice {
197    message: Option<GrokAssistantMessage>,
198}
199
200#[derive(Debug, Deserialize)]
201struct GrokAssistantMessage {
202    content: Option<String>,
203    #[serde(default)]
204    tool_calls: Vec<GrokToolCall>,
205    #[serde(default)]
206    reasoning_content: Option<String>,
207}
208
/// Token counts reported by the API. Every count is optional, and
/// `normalize_response` treats absent values as zero.
#[derive(Debug, Deserialize)]
struct GrokUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    // Reasoning tokens may appear here or nested in `completion_tokens_details`.
    // `normalize_response` prefers this top-level value when it is present.
    reasoning_tokens: Option<u32>,
    completion_tokens_details: Option<GrokCompletionTokenDetails>,
}

/// Nested breakdown of completion tokens; only the reasoning count is read.
#[derive(Debug, Deserialize)]
struct GrokCompletionTokenDetails {
    reasoning_tokens: Option<u32>,
}
221
/// Error body of the form `{"error": {...}}`, parsed by `extract_api_error`.
#[derive(Debug, Deserialize)]
struct GrokErrorEnvelope {
    error: GrokApiError,
}

/// Fields of an API error object. Each one is optional, and `extract_api_error`
/// substitutes a fallback for any that are missing.
#[derive(Debug, Deserialize)]
struct GrokApiError {
    message: Option<String>,
    #[serde(rename = "type")]
    type_: Option<String>,
    // Kept as raw JSON because it may be a string or some other JSON value;
    // `extract_api_error` renders non-string values with `to_string`.
    code: Option<Value>,
}
234
235fn build_request(
236    messages: &[ModelMessage],
237    tools: &[ModelToolDefinition],
238    tool_choice: ModelToolChoice,
239    config: &GrokModelConfig,
240) -> GrokChatCompletionRequest {
241    let request_messages = ensure_non_empty_messages(to_grok_messages(messages));
242
243    let tools_payload = if tools.is_empty() {
244        None
245    } else {
246        Some(
247            tools
248                .iter()
249                .map(|tool| GrokToolDefinition {
250                    type_: "function".to_string(),
251                    function: GrokToolFunctionDefinition {
252                        name: tool.name.clone(),
253                        description: tool.description.clone(),
254                        parameters: tool.parameters.clone(),
255                    },
256                })
257                .collect::<Vec<_>>(),
258        )
259    };
260
261    let tool_choice_payload = if tools.is_empty() {
262        None
263    } else {
264        Some(match tool_choice {
265            ModelToolChoice::Auto => GrokToolChoicePayload::Mode("auto".to_string()),
266            ModelToolChoice::Required => GrokToolChoicePayload::Mode("required".to_string()),
267            ModelToolChoice::None => GrokToolChoicePayload::Mode("none".to_string()),
268            ModelToolChoice::Tool(name) => GrokToolChoicePayload::Specific {
269                type_: "function".to_string(),
270                function: GrokToolChoiceFunction { name },
271            },
272        })
273    };
274
275    GrokChatCompletionRequest {
276        model: config.model.clone(),
277        messages: request_messages,
278        tools: tools_payload,
279        tool_choice: tool_choice_payload,
280        temperature: config.temperature,
281        top_p: config.top_p,
282        max_tokens: config.max_tokens,
283    }
284}
285
286fn to_grok_messages(messages: &[ModelMessage]) -> Vec<GrokRequestMessage> {
287    let mut request_messages = Vec::new();
288
289    for message in messages {
290        match message {
291            ModelMessage::System(content) => {
292                if content.is_empty() {
293                    continue;
294                }
295                request_messages.push(GrokRequestMessage::System {
296                    content: content.clone(),
297                });
298            }
299            ModelMessage::User(content) => {
300                if content.is_empty() {
301                    continue;
302                }
303                request_messages.push(GrokRequestMessage::User {
304                    content: content.clone(),
305                });
306            }
307            ModelMessage::Assistant {
308                content,
309                tool_calls,
310            } => {
311                let serialized_tool_calls = tool_calls
312                    .iter()
313                    .map(|tool_call| GrokToolCall {
314                        id: tool_call.id.clone(),
315                        type_: "function".to_string(),
316                        function: GrokToolCallFunction {
317                            name: tool_call.name.clone(),
318                            arguments: tool_call.arguments.to_string(),
319                        },
320                    })
321                    .collect::<Vec<_>>();
322
323                let assistant_content = content.as_ref().filter(|text| !text.is_empty()).cloned();
324                if assistant_content.is_none() && serialized_tool_calls.is_empty() {
325                    continue;
326                }
327
328                request_messages.push(GrokRequestMessage::Assistant {
329                    content: assistant_content,
330                    tool_calls: if serialized_tool_calls.is_empty() {
331                        None
332                    } else {
333                        Some(serialized_tool_calls)
334                    },
335                });
336            }
337            ModelMessage::ToolResult {
338                tool_call_id,
339                tool_name: _,
340                content,
341                is_error,
342            } => {
343                let rendered = if *is_error {
344                    format!("Error: {content}")
345                } else {
346                    content.clone()
347                };
348
349                request_messages.push(GrokRequestMessage::Tool {
350                    tool_call_id: tool_call_id.clone(),
351                    content: rendered,
352                });
353            }
354        }
355    }
356
357    request_messages
358}
359
360fn ensure_non_empty_messages(mut messages: Vec<GrokRequestMessage>) -> Vec<GrokRequestMessage> {
361    let mut normalized = Vec::with_capacity(messages.len().saturating_add(1));
362    let mut pending_tool_call_ids = Vec::<String>::new();
363
364    for message in messages.drain(..) {
365        match message {
366            GrokRequestMessage::System { content } => {
367                pending_tool_call_ids.clear();
368                normalized.push(GrokRequestMessage::System { content });
369            }
370            GrokRequestMessage::User { content } => {
371                pending_tool_call_ids.clear();
372                normalized.push(GrokRequestMessage::User { content });
373            }
374            GrokRequestMessage::Assistant {
375                content,
376                tool_calls,
377            } => {
378                pending_tool_call_ids.clear();
379                if let Some(calls) = &tool_calls {
380                    pending_tool_call_ids.extend(calls.iter().map(|call| call.id.clone()));
381                }
382                normalized.push(GrokRequestMessage::Assistant {
383                    content,
384                    tool_calls,
385                });
386            }
387            GrokRequestMessage::Tool {
388                tool_call_id,
389                content,
390            } => {
391                if let Some(position) = pending_tool_call_ids
392                    .iter()
393                    .position(|id| id == &tool_call_id)
394                {
395                    pending_tool_call_ids.remove(position);
396                    normalized.push(GrokRequestMessage::Tool {
397                        tool_call_id,
398                        content,
399                    });
400                }
401            }
402        }
403    }
404
405    if normalized.is_empty() {
406        normalized.push(GrokRequestMessage::User {
407            content: EMPTY_USER_CONTENT_FALLBACK.to_string(),
408        });
409        return normalized;
410    }
411
412    let starts_with_valid_role = matches!(
413        normalized.first(),
414        Some(GrokRequestMessage::System { .. } | GrokRequestMessage::User { .. })
415    );
416    if !starts_with_valid_role {
417        normalized.insert(
418            0,
419            GrokRequestMessage::User {
420                content: EMPTY_USER_CONTENT_FALLBACK.to_string(),
421            },
422        );
423    }
424
425    normalized
426}
427
428fn normalize_response(
429    response: GrokChatCompletionResponse,
430) -> Result<ModelCompletion, ProviderError> {
431    let choice = response
432        .choices
433        .into_iter()
434        .next()
435        .ok_or_else(|| ProviderError::Response("grok response missing choices".to_string()))?;
436
437    let message = choice.message.ok_or_else(|| {
438        ProviderError::Response("grok response missing choice message".to_string())
439    })?;
440
441    let mut tool_calls = Vec::new();
442    for tool_call in message.tool_calls {
443        let arguments = if tool_call.function.arguments.trim().is_empty() {
444            json!({})
445        } else {
446            serde_json::from_str::<Value>(&tool_call.function.arguments).map_err(|err| {
447                ProviderError::Response(format!(
448                    "grok tool call arguments for '{}' are not valid JSON: {err}",
449                    tool_call.function.name
450                ))
451            })?
452        };
453
454        tool_calls.push(ModelToolCall {
455            id: tool_call.id,
456            name: tool_call.function.name,
457            arguments,
458        });
459    }
460
461    let usage = response.usage.map(|usage| ModelUsage {
462        // xAI may return reasoning tokens either top-level or nested in completion details.
463        // Prefer top-level when present to avoid undercounting and avoid double-counting.
464        input_tokens: usage.prompt_tokens.unwrap_or(0),
465        output_tokens: usage.completion_tokens.unwrap_or(0).saturating_add(
466            usage.reasoning_tokens.unwrap_or_else(|| {
467                usage
468                    .completion_tokens_details
469                    .and_then(|details| details.reasoning_tokens)
470                    .unwrap_or(0)
471            }),
472        ),
473    });
474
475    Ok(ModelCompletion {
476        text: message.content.filter(|text| !text.is_empty()),
477        thinking: message.reasoning_content.filter(|text| !text.is_empty()),
478        tool_calls,
479        usage,
480    })
481}
482
483async fn extract_api_error(response: reqwest::Response) -> String {
484    let status = response.status();
485    let body = response.text().await.unwrap_or_default();
486
487    if let Ok(parsed) = serde_json::from_str::<GrokErrorEnvelope>(&body) {
488        let code = parsed
489            .error
490            .code
491            .map(|value| match value {
492                Value::String(value) => value,
493                other => other.to_string(),
494            })
495            .unwrap_or_else(|| status.as_u16().to_string());
496        let error_type = parsed
497            .error
498            .type_
499            .unwrap_or_else(|| status.to_string().to_uppercase());
500        let message = parsed
501            .error
502            .message
503            .unwrap_or_else(|| "unknown xai api error".to_string());
504
505        return format!("xai api error {code} {error_type}: {message}");
506    }
507
508    if body.is_empty() {
509        format!("xai api request failed ({status})")
510    } else {
511        format!("xai api request failed ({status}): {body}")
512    }
513}
514
// Unit tests for request building and response normalization. They call the
// private helpers directly and make no network requests.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Minimal `lookup` tool with a single required string parameter `query`.
    fn tool_definition() -> ModelToolDefinition {
        ModelToolDefinition {
            name: "lookup".to_string(),
            description: "Look up something".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                },
                "required": ["query"],
                "additionalProperties": false
            }),
        }
    }

    // Full round trip: every message kind, a tool definition, a forced tool
    // choice, and sampling settings all land in the serialized JSON. Tool-call
    // arguments are sent as a JSON-encoded string.
    #[test]
    fn build_request_serializes_messages_tools_and_tool_choice() {
        let messages = vec![
            ModelMessage::System("You are helpful".to_string()),
            ModelMessage::User("Find docs".to_string()),
            ModelMessage::Assistant {
                content: Some("Calling tool".to_string()),
                tool_calls: vec![ModelToolCall {
                    id: "call_1".to_string(),
                    name: "lookup".to_string(),
                    arguments: json!({"query": "rust"}),
                }],
            },
            ModelMessage::ToolResult {
                tool_call_id: "call_1".to_string(),
                tool_name: "lookup".to_string(),
                content: "{\"result\":\"ok\"}".to_string(),
                is_error: false,
            },
        ];

        let mut config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");
        config.temperature = Some(0.2);
        config.max_tokens = Some(512);

        let request = build_request(
            &messages,
            &[tool_definition()],
            ModelToolChoice::Tool("lookup".to_string()),
            &config,
        );
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(value["messages"][0]["role"], "system");
        assert_eq!(value["messages"][0]["content"], "You are helpful");
        assert_eq!(value["messages"][2]["role"], "assistant");
        assert_eq!(
            value["messages"][2]["tool_calls"][0]["function"]["name"],
            "lookup"
        );
        assert_eq!(
            value["messages"][2]["tool_calls"][0]["function"]["arguments"],
            "{\"query\":\"rust\"}"
        );
        assert_eq!(value["messages"][3]["role"], "tool");
        assert_eq!(value["messages"][3]["tool_call_id"], "call_1");
        assert_eq!(value["tools"][0]["function"]["name"], "lookup");
        assert_eq!(value["tool_choice"]["type"], "function");
        assert_eq!(value["tool_choice"]["function"]["name"], "lookup");
        // f32 -> f64 widening makes exact equality unreliable; compare with a tolerance.
        assert!((value["temperature"].as_f64().unwrap_or_default() - 0.2).abs() < 1e-6);
        assert_eq!(value["max_tokens"], 512);
    }

    // An empty user message is dropped, and the resulting empty transcript is
    // replaced by the single-space placeholder. With no tools, neither `tools`
    // nor `tool_choice` is serialized.
    #[test]
    fn build_request_adds_fallback_content_for_empty_user_message() {
        let messages = vec![ModelMessage::User(String::new())];
        let config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(
            value["messages"].as_array().map(|values| values.len()),
            Some(1)
        );
        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], " ");
        assert!(value.get("tools").is_none());
        assert!(value.get("tool_choice").is_none());
    }

    // A tool result with no preceding assistant tool call is dropped, and the
    // now-empty transcript gets the placeholder user message.
    #[test]
    fn build_request_inserts_fallback_and_drops_orphan_tool_messages() {
        let messages = vec![ModelMessage::ToolResult {
            tool_call_id: "call_1".to_string(),
            tool_name: "lookup".to_string(),
            content: "result".to_string(),
            is_error: false,
        }];
        let config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(
            value["messages"].as_array().map(|values| values.len()),
            Some(1)
        );
        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], " ");
    }

    // Once the empty user message is dropped, the transcript would start with an
    // assistant turn, so the placeholder user message is prepended. The matching
    // tool result is kept.
    #[test]
    fn build_request_inserts_fallback_when_first_message_is_assistant() {
        let messages = vec![
            ModelMessage::User(String::new()),
            ModelMessage::Assistant {
                content: Some("Calling tool".to_string()),
                tool_calls: vec![ModelToolCall {
                    id: "call_1".to_string(),
                    name: "lookup".to_string(),
                    arguments: json!({"query": "rust"}),
                }],
            },
            ModelMessage::ToolResult {
                tool_call_id: "call_1".to_string(),
                tool_name: "lookup".to_string(),
                content: "{\"result\":\"ok\"}".to_string(),
                is_error: false,
            },
        ];
        let config = GrokModelConfig::new("key", "grok-4-1-fast-reasoning");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], " ");
        assert_eq!(value["messages"][1]["role"], "assistant");
        assert_eq!(value["messages"][2]["role"], "tool");
        assert_eq!(value["messages"][2]["tool_call_id"], "call_1");
    }

    // With no top-level reasoning count, the nested one is added:
    // output_tokens = 7 completion + 3 reasoning = 10.
    #[test]
    fn normalize_response_extracts_text_thinking_tool_calls_and_usage() {
        let response = GrokChatCompletionResponse {
            choices: vec![GrokChoice {
                message: Some(GrokAssistantMessage {
                    content: Some("answer".to_string()),
                    tool_calls: vec![GrokToolCall {
                        id: "call_x".to_string(),
                        type_: "function".to_string(),
                        function: GrokToolCallFunction {
                            name: "lookup".to_string(),
                            arguments: "{\"q\":\"rust\"}".to_string(),
                        },
                    }],
                    reasoning_content: Some("reasoning".to_string()),
                }),
            }],
            usage: Some(GrokUsage {
                prompt_tokens: Some(11),
                completion_tokens: Some(7),
                reasoning_tokens: None,
                completion_tokens_details: Some(GrokCompletionTokenDetails {
                    reasoning_tokens: Some(3),
                }),
            }),
        };

        let completion = normalize_response(response).expect("response normalizes");

        assert_eq!(completion.text.as_deref(), Some("answer"));
        assert_eq!(completion.thinking.as_deref(), Some("reasoning"));
        assert_eq!(completion.tool_calls.len(), 1);
        assert_eq!(completion.tool_calls[0].name, "lookup");
        assert_eq!(completion.tool_calls[0].id, "call_x");
        assert_eq!(
            completion.usage,
            Some(ModelUsage {
                input_tokens: 11,
                output_tokens: 10,
            })
        );
    }

    // The top-level reasoning count (4) wins over the nested one (3), and only
    // one of them is added: 7 + 4 = 11.
    #[test]
    fn normalize_response_prefers_top_level_reasoning_tokens() {
        let response = GrokChatCompletionResponse {
            choices: vec![GrokChoice {
                message: Some(GrokAssistantMessage {
                    content: Some("answer".to_string()),
                    tool_calls: Vec::new(),
                    reasoning_content: None,
                }),
            }],
            usage: Some(GrokUsage {
                prompt_tokens: Some(11),
                completion_tokens: Some(7),
                reasoning_tokens: Some(4),
                completion_tokens_details: Some(GrokCompletionTokenDetails {
                    reasoning_tokens: Some(3),
                }),
            }),
        };

        let completion = normalize_response(response).expect("response normalizes");

        assert_eq!(
            completion.usage,
            Some(ModelUsage {
                input_tokens: 11,
                output_tokens: 11,
            })
        );
    }

    // An empty `choices` list is reported as a `Response` error.
    #[test]
    fn normalize_response_requires_choices() {
        let err = normalize_response(GrokChatCompletionResponse {
            choices: Vec::new(),
            usage: None,
        })
        .expect_err("should fail");

        match err {
            ProviderError::Response(message) => {
                assert!(message.contains("missing choices"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    // Non-JSON tool-call arguments are reported as an error, not silently ignored.
    #[test]
    fn normalize_response_fails_on_invalid_tool_arguments() {
        let err = normalize_response(GrokChatCompletionResponse {
            choices: vec![GrokChoice {
                message: Some(GrokAssistantMessage {
                    content: None,
                    tool_calls: vec![GrokToolCall {
                        id: "call_x".to_string(),
                        type_: "function".to_string(),
                        function: GrokToolCallFunction {
                            name: "lookup".to_string(),
                            arguments: "{not json}".to_string(),
                        },
                    }],
                    reasoning_content: None,
                }),
            }],
            usage: None,
        })
        .expect_err("should fail");

        match err {
            ProviderError::Response(message) => {
                assert!(message.contains("not valid JSON"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }
}