Skip to main content

agent_tools_interface/core/
xai.rs

//! xAI agentic handler — constructs proper `/responses` requests for Grok tools.
//!
//! xAI's API is NOT a simple REST endpoint. It accepts a request of the form:
//! ```json
//! POST /v1/responses
//! { "model": "grok-3-mini", "tools": [{"type": "web_search"}], "input": "query" }
//! ```
//! The model calls the listed tools internally and returns agentic results.
9use reqwest::Client;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::time::Duration;
13
14use crate::core::http::HttpError;
15use crate::core::keyring::Keyring;
16use crate::core::manifest::{Provider, Tool};
17
18const XAI_TIMEOUT_SECS: u64 = 90;
19
20/// Execute an xAI tool via the /responses agentic API.
21pub async fn execute_xai_tool(
22    provider: &Provider,
23    tool: &Tool,
24    args: &HashMap<String, Value>,
25    keyring: &Keyring,
26) -> Result<Value, HttpError> {
27    let key_name = provider
28        .auth_key_name
29        .as_deref()
30        .ok_or_else(|| HttpError::MissingKey("auth_key_name not set for xAI".into()))?;
31    let api_key = keyring
32        .get(key_name)
33        .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
34
35    let query = args
36        .get("query")
37        .and_then(|v| v.as_str())
38        .unwrap_or("latest news");
39
40    let model = args
41        .get("model")
42        .and_then(|v| v.as_str())
43        .unwrap_or("grok-4-fast-non-reasoning");
44
45    // Map tool name to xAI tool types
46    let xai_tools = map_tool_types(&tool.name);
47
48    // Build input prompt — for trending, prefix with "trending" context
49    let input = if tool.name == "xai_trending_search" {
50        format!("What are the trending topics and discussions about: {query}")
51    } else {
52        query.to_string()
53    };
54
55    let request_body = serde_json::json!({
56        "model": model,
57        "tools": xai_tools,
58        "input": input,
59    });
60
61    let url = format!("{}/responses", provider.base_url.trim_end_matches('/'));
62
63    let client = Client::builder()
64        .timeout(Duration::from_secs(XAI_TIMEOUT_SECS))
65        .build()?;
66
67    let response = client
68        .post(&url)
69        .bearer_auth(api_key)
70        .json(&request_body)
71        .send()
72        .await?;
73
74    let status = response.status();
75    if !status.is_success() {
76        let body = response.text().await.unwrap_or_else(|_| "empty".into());
77        return Err(HttpError::ApiError {
78            status: status.as_u16(),
79            body,
80        });
81    }
82
83    let body: Value = response
84        .json()
85        .await
86        .map_err(|e| HttpError::ParseError(format!("failed to parse xAI response: {e}")))?;
87
88    // Extract useful content from the agentic response.
89    // The response has an "output" array with items of different types.
90    // We extract text content and search result URLs.
91    Ok(extract_xai_results(&body))
92}
93
94/// Map ATI tool name to xAI tool type objects.
95fn map_tool_types(tool_name: &str) -> Vec<Value> {
96    match tool_name {
97        "xai_web_search" => vec![serde_json::json!({"type": "web_search"})],
98        "xai_x_search" | "xai_trending_search" => {
99            vec![serde_json::json!({"type": "x_search"})]
100        }
101        "xai_combined_search" => vec![
102            serde_json::json!({"type": "web_search"}),
103            serde_json::json!({"type": "x_search"}),
104        ],
105        _ => vec![serde_json::json!({"type": "web_search"})],
106    }
107}
108
109/// Extract meaningful results from xAI's agentic response format.
110///
111/// The output array contains items like:
112/// - `{"type": "message", "content": [{"type": "output_text", "text": "..."}]}`
113/// - `{"type": "web_search_call", "action": {"query": "...", "sources": [...]}}}`
114fn extract_xai_results(body: &Value) -> Value {
115    let output = match body.get("output").and_then(|o| o.as_array()) {
116        Some(arr) => arr,
117        None => return body.clone(),
118    };
119
120    let mut text_content = Vec::new();
121    let mut search_queries = Vec::new();
122    let mut annotations = Vec::new();
123
124    for item in output {
125        let item_type = item.get("type").and_then(|t| t.as_str()).unwrap_or("");
126
127        if item_type == "message" {
128            if let Some(content) = item.get("content").and_then(|c| c.as_array()) {
129                for block in content {
130                    let block_type = block.get("type").and_then(|t| t.as_str()).unwrap_or("");
131                    if block_type == "output_text" || block_type == "text" {
132                        if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
133                            text_content.push(text.to_string());
134                        }
135                    }
136                    // Collect URL citations from annotations
137                    if let Some(annots) = block.get("annotations").and_then(|a| a.as_array()) {
138                        for ann in annots {
139                            if ann.get("type").and_then(|t| t.as_str()) == Some("url_citation") {
140                                annotations.push(ann.clone());
141                            }
142                        }
143                    }
144                }
145            }
146        } else if item_type.ends_with("_call") {
147            // web_search_call, x_search_call, etc.
148            if let Some(action) = item.get("action") {
149                if let Some(query) = action.get("query").and_then(|q| q.as_str()) {
150                    search_queries.push(query.to_string());
151                }
152            }
153        }
154    }
155
156    serde_json::json!({
157        "text": text_content.join("\n\n"),
158        "citations": annotations,
159        "search_queries": search_queries,
160        "raw_output_count": output.len(),
161    })
162}