openai_rust2/
chat.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Serialize, Deserialize, Debug, Clone)]
4pub enum ResponseFormat {
5    JsonObject,
6    Text,
7}
8
9#[derive(Serialize, Deserialize, Debug, Clone)]
10pub struct ImageGeneration {
11    pub quality: Option<String>,       // e.g., "standard", "hd"
12    pub size: Option<String>,          // e.g., "1024x1024"
13    pub output_format: Option<String>, // e.g., "base64", "url"
14}
15
16#[derive(Serialize, Debug, Clone)]
17pub struct ChatArguments {
18    pub model: String,
19    pub messages: Vec<Message>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub temperature: Option<f32>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub top_p: Option<f32>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub n: Option<u32>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub stream: Option<bool>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub stop: Option<String>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub max_tokens: Option<u32>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub presence_penalty: Option<f32>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub frequency_penalty: Option<f32>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub user: Option<String>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub response_format: Option<ResponseFormat>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub image_generation: Option<ImageGeneration>,
42    /// xAI Agent Tools API - server-side tools for agentic capabilities.
43    /// Includes: web_search, x_search, code_execution, collections_search, mcp.
44    /// See: https://docs.x.ai/docs/guides/tools/overview
45    #[serde(skip_serializing_if = "Option::is_none", rename = "tools")]
46    pub grok_tools: Option<Vec<GrokTool>>,
47}
48
49impl ChatArguments {
50    pub fn new(model: impl AsRef<str>, messages: Vec<Message>) -> ChatArguments {
51        ChatArguments {
52            model: model.as_ref().to_owned(),
53            messages,
54            temperature: None,
55            top_p: None,
56            n: None,
57            stream: None,
58            stop: None,
59            max_tokens: None,
60            presence_penalty: None,
61            frequency_penalty: None,
62            user: None,
63            response_format: None,
64            image_generation: None,
65            grok_tools: None,
66        }
67    }
68
69    /// Add xAI server-side tools for agentic capabilities.
70    /// Recommended model: `grok-4-1-fast` for best tool-calling performance.
71    pub fn with_grok_tools(mut self, tools: Vec<GrokTool>) -> Self {
72        self.grok_tools = Some(tools);
73        self
74    }
75}
76
77#[derive(Deserialize, Debug, Clone)]
78pub struct ChatCompletion {
79    #[serde(default)]
80    pub id: Option<String>,
81    pub created: u32,
82    #[serde(default)]
83    pub model: Option<String>,
84    #[serde(default)]
85    pub object: Option<String>,
86    pub choices: Vec<Choice>,
87    pub usage: Usage,
88}
89
90impl std::fmt::Display for ChatCompletion {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        write!(f, "{}", &self.choices[0].message.content)?;
93        Ok(())
94    }
95}
96
97pub mod stream {
98    use bytes::Bytes;
99    use futures_util::Stream;
100    use serde::Deserialize;
101    use std::pin::Pin;
102    use std::str;
103    use std::task::Poll;
104
105    #[derive(Deserialize, Debug, Clone)]
106    pub struct ChatCompletionChunk {
107        pub id: String,
108        pub created: u32,
109        pub model: String,
110        pub choices: Vec<Choice>,
111        pub system_fingerprint: Option<String>,
112    }
113
114    impl std::fmt::Display for ChatCompletionChunk {
115        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116            write!(
117                f,
118                "{}",
119                self.choices[0].delta.content.as_ref().unwrap_or(&"".into())
120            )?;
121            Ok(())
122        }
123    }
124
125    #[derive(Deserialize, Debug, Clone)]
126    pub struct Choice {
127        pub delta: ChoiceDelta,
128        pub index: u32,
129        pub finish_reason: Option<String>,
130    }
131
132    #[derive(Deserialize, Debug, Clone)]
133    pub struct ChoiceDelta {
134        pub content: Option<String>,
135    }
136
137    pub struct ChatCompletionChunkStream {
138        byte_stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>,
139        buf: String,
140    }
141
142    impl ChatCompletionChunkStream {
143        pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>) -> Self {
144            Self {
145                byte_stream: stream,
146                buf: String::new(),
147            }
148        }
149
150        fn deserialize_buf(
151            self: Pin<&mut Self>,
152            cx: &mut std::task::Context<'_>,
153        ) -> Option<anyhow::Result<ChatCompletionChunk>> {
154            let bufclone = self.buf.clone();
155            let mut chunks = bufclone.split("\n\n").peekable();
156            let first = chunks.next();
157            let second = chunks.peek();
158
159            match first {
160                Some(first) => match first.strip_prefix("data: ") {
161                    Some(chunk) => {
162                        if !chunk.ends_with("}") {
163                            None
164                        } else {
165                            if let Some(second) = second {
166                                if second.ends_with("}") {
167                                    cx.waker().wake_by_ref();
168                                }
169                            }
170                            self.get_mut().buf = chunks.collect::<Vec<_>>().join("\n\n");
171                            Some(
172                                serde_json::from_str::<ChatCompletionChunk>(chunk)
173                                    .map_err(|e| anyhow::anyhow!(e)),
174                            )
175                        }
176                    }
177                    None => None,
178                },
179                None => None,
180            }
181        }
182    }
183
184    impl Stream for ChatCompletionChunkStream {
185        type Item = anyhow::Result<ChatCompletionChunk>;
186
187        fn poll_next(
188            mut self: Pin<&mut Self>,
189            cx: &mut std::task::Context<'_>,
190        ) -> Poll<Option<Self::Item>> {
191            if let Some(chunk) = self.as_mut().deserialize_buf(cx) {
192                return Poll::Ready(Some(chunk));
193            }
194
195            match self.byte_stream.as_mut().poll_next(cx) {
196                Poll::Ready(bytes_option) => match bytes_option {
197                    Some(bytes_result) => match bytes_result {
198                        Ok(bytes) => {
199                            let data = str::from_utf8(&bytes)?.to_owned();
200                            self.buf = self.buf.clone() + &data;
201                            match self.deserialize_buf(cx) {
202                                Some(chunk) => Poll::Ready(Some(chunk)),
203                                None => {
204                                    cx.waker().wake_by_ref();
205                                    Poll::Pending
206                                }
207                            }
208                        }
209                        Err(e) => Poll::Ready(Some(Err(e.into()))),
210                    },
211                    None => Poll::Ready(None),
212                },
213                Poll::Pending => Poll::Pending,
214            }
215        }
216    }
217}
218
219#[derive(Deserialize, Debug, Clone)]
220pub struct Usage {
221    pub prompt_tokens: u32,
222    pub completion_tokens: u32,
223    pub total_tokens: u32,
224}
225
226#[derive(Deserialize, Debug, Clone)]
227pub struct Choice {
228    #[serde(default)]
229    pub index: Option<u32>,
230    pub message: Message,
231    pub finish_reason: String,
232}
233
234#[derive(Serialize, Deserialize, Debug, Clone)]
235pub struct Message {
236    pub role: String,
237    pub content: String,
238}
239
/// The author of a chat message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    System,
    Assistant,
    User,
}

impl Role {
    /// Returns the lowercase wire name, suitable for `Message::role`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Role::System => "system",
            Role::Assistant => "assistant",
            Role::User => "user",
        }
    }
}

impl std::fmt::Display for Role {
    /// Formats the role as its lowercase wire name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
245
246// =============================================================================
247// xAI Agent Tools API
248// See: https://docs.x.ai/docs/guides/tools/overview
249// =============================================================================
250
251/// Represents a server-side tool available in xAI's Agent Tools API.
252///
253/// xAI provides agentic server-side tool calling where the model autonomously
254/// explores, searches, and executes code. The server handles the entire
255/// reasoning and tool-execution loop.
256///
257/// # Supported Models
258/// - `grok-4-1-fast` (recommended for agentic tool calling)
259/// - `grok-4-1-fast-non-reasoning`
260/// - `grok-4`, `grok-4-fast`, `grok-4-fast-non-reasoning`
261///
262/// # Example
263/// ```rust,no_run
264/// use openai_rust2::chat::GrokTool;
265///
266/// let tools = vec![
267///     GrokTool::web_search(),
268///     GrokTool::x_search(),
269///     GrokTool::code_execution(),
270///     GrokTool::collections_search(vec!["collection-id-1".into()]),
271///     GrokTool::mcp("https://my-mcp-server.com".into()),
272/// ];
273/// ```
274#[derive(Serialize, Debug, Clone)]
275pub struct GrokTool {
276    /// The type of tool: "web_search", "x_search", "code_execution", "collections_search", "mcp"
277    #[serde(rename = "type")]
278    pub tool_type: GrokToolType,
279    /// Restrict web search to specific domains (max 5). Only applies to web_search.
280    #[serde(skip_serializing_if = "Option::is_none")]
281    pub allowed_domains: Option<Vec<String>>,
282    /// Inclusive start date for search results (ISO8601: YYYY-MM-DD). Applies to web_search and x_search.
283    #[serde(skip_serializing_if = "Option::is_none")]
284    pub from_date: Option<String>,
285    /// Inclusive end date for search results (ISO8601: YYYY-MM-DD). Applies to web_search and x_search.
286    #[serde(skip_serializing_if = "Option::is_none")]
287    pub to_date: Option<String>,
288    /// Collection IDs to search. Required for collections_search tool.
289    #[serde(skip_serializing_if = "Option::is_none")]
290    pub collection_ids: Option<Vec<String>>,
291    /// MCP server URL. Required for mcp tool.
292    #[serde(skip_serializing_if = "Option::is_none")]
293    pub server_url: Option<String>,
294}
295
296/// The type of xAI server-side tool.
297#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
298#[serde(rename_all = "snake_case")]
299pub enum GrokToolType {
300    /// Real-time web search and page browsing
301    WebSearch,
302    /// Search X (Twitter) posts, users, and threads
303    XSearch,
304    /// Execute Python code for calculations and data analysis
305    CodeExecution,
306    /// Search uploaded document collections (knowledge bases)
307    CollectionsSearch,
308    /// Connect to external MCP servers for custom tools
309    Mcp,
310}
311
312impl GrokTool {
313    /// Create a web_search tool with default settings.
314    /// Allows the agent to search the web and browse pages.
315    pub fn web_search() -> Self {
316        Self {
317            tool_type: GrokToolType::WebSearch,
318            allowed_domains: None,
319            from_date: None,
320            to_date: None,
321            collection_ids: None,
322            server_url: None,
323        }
324    }
325
326    /// Create an x_search tool with default settings.
327    /// Allows the agent to search X posts, users, and threads.
328    pub fn x_search() -> Self {
329        Self {
330            tool_type: GrokToolType::XSearch,
331            allowed_domains: None,
332            from_date: None,
333            to_date: None,
334            collection_ids: None,
335            server_url: None,
336        }
337    }
338
339    /// Create a code_execution tool.
340    /// Allows the agent to execute Python code for calculations and data analysis.
341    pub fn code_execution() -> Self {
342        Self {
343            tool_type: GrokToolType::CodeExecution,
344            allowed_domains: None,
345            from_date: None,
346            to_date: None,
347            collection_ids: None,
348            server_url: None,
349        }
350    }
351
352    /// Create a collections_search tool with the specified collection IDs.
353    /// Allows the agent to search through uploaded knowledge bases.
354    pub fn collections_search(collection_ids: Vec<String>) -> Self {
355        Self {
356            tool_type: GrokToolType::CollectionsSearch,
357            allowed_domains: None,
358            from_date: None,
359            to_date: None,
360            collection_ids: Some(collection_ids),
361            server_url: None,
362        }
363    }
364
365    /// Create an MCP tool connecting to an external MCP server.
366    /// Allows the agent to access custom tools from the specified server.
367    pub fn mcp(server_url: String) -> Self {
368        Self {
369            tool_type: GrokToolType::Mcp,
370            allowed_domains: None,
371            from_date: None,
372            to_date: None,
373            collection_ids: None,
374            server_url: Some(server_url),
375        }
376    }
377
378    /// Restrict web search to specific domains (max 5).
379    /// Only applies to web_search tool.
380    pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
381        self.allowed_domains = Some(domains);
382        self
383    }
384
385    /// Set the date range for search results (ISO8601: YYYY-MM-DD).
386    /// Applies to web_search and x_search tools.
387    pub fn with_date_range(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
388        self.from_date = Some(from.into());
389        self.to_date = Some(to.into());
390        self
391    }
392}