openai_rust2/
chat.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Serialize, Deserialize, Debug, Clone)]
4pub enum ResponseFormat {
5    JsonObject,
6    Text,
7}
8
9#[derive(Serialize, Deserialize, Debug, Clone)]
10pub struct ImageGeneration {
11    pub quality: Option<String>,       // e.g., "standard", "hd"
12    pub size: Option<String>,          // e.g., "1024x1024"
13    pub output_format: Option<String>, // e.g., "base64", "url"
14}
15
16#[derive(Serialize, Debug, Clone)]
17pub struct ChatArguments {
18    pub model: String,
19    pub messages: Vec<Message>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub temperature: Option<f32>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub top_p: Option<f32>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub n: Option<u32>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub stream: Option<bool>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub stop: Option<String>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub max_tokens: Option<u32>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub presence_penalty: Option<f32>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub frequency_penalty: Option<f32>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub user: Option<String>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub response_format: Option<ResponseFormat>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub image_generation: Option<ImageGeneration>,
42    /// xAI Agent Tools API - server-side tools for agentic capabilities.
43    /// Includes: web_search, x_search, code_execution, collections_search, mcp.
44    /// See: https://docs.x.ai/docs/guides/tools/overview
45    #[serde(skip_serializing_if = "Option::is_none", rename = "server_tools")]
46    pub grok_tools: Option<Vec<GrokTool>>,
47}
48
49impl ChatArguments {
50    pub fn new(model: impl AsRef<str>, messages: Vec<Message>) -> ChatArguments {
51        ChatArguments {
52            model: model.as_ref().to_owned(),
53            messages,
54            temperature: None,
55            top_p: None,
56            n: None,
57            stream: None,
58            stop: None,
59            max_tokens: None,
60            presence_penalty: None,
61            frequency_penalty: None,
62            user: None,
63            response_format: None,
64            image_generation: None,
65            grok_tools: None,
66        }
67    }
68
69    /// Add xAI server-side tools for agentic capabilities.
70    /// Recommended model: `grok-4-1-fast` for best tool-calling performance.
71    pub fn with_grok_tools(mut self, tools: Vec<GrokTool>) -> Self {
72        self.grok_tools = Some(tools);
73        self
74    }
75}
76
/// Body of a non-streamed Chat Completions response.
#[derive(Deserialize, Debug, Clone)]
pub struct ChatCompletion {
    /// Response identifier; `None` when the server omits it.
    #[serde(default)]
    pub id: Option<String>,
    /// Creation time as reported by the server.
    // NOTE(review): presumably a Unix timestamp in seconds — confirm; a `u32`
    // can only hold such values until the year 2106.
    pub created: u32,
    /// Model that produced the response; `None` when the server omits it.
    #[serde(default)]
    pub model: Option<String>,
    /// Object type tag; `None` when the server omits it.
    #[serde(default)]
    pub object: Option<String>,
    /// Candidate completions; the `Display` impl renders the first one.
    pub choices: Vec<Choice>,
    /// Token usage. Required: a response body without `usage` fails to
    /// deserialize.
    pub usage: Usage,
}
89
90impl std::fmt::Display for ChatCompletion {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        write!(f, "{}", &self.choices[0].message.content)?;
93        Ok(())
94    }
95}
96
97pub mod stream {
98    use bytes::Bytes;
99    use futures_util::Stream;
100    use serde::Deserialize;
101    use std::pin::Pin;
102    use std::str;
103    use std::task::Poll;
104
105    #[derive(Deserialize, Debug, Clone)]
106    pub struct ChatCompletionChunk {
107        pub id: String,
108        pub created: u32,
109        pub model: String,
110        pub choices: Vec<Choice>,
111        pub system_fingerprint: Option<String>,
112    }
113
114    impl std::fmt::Display for ChatCompletionChunk {
115        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116            write!(
117                f,
118                "{}",
119                self.choices[0].delta.content.as_ref().unwrap_or(&"".into())
120            )?;
121            Ok(())
122        }
123    }
124
125    #[derive(Deserialize, Debug, Clone)]
126    pub struct Choice {
127        pub delta: ChoiceDelta,
128        pub index: u32,
129        pub finish_reason: Option<String>,
130    }
131
132    #[derive(Deserialize, Debug, Clone)]
133    pub struct ChoiceDelta {
134        pub content: Option<String>,
135    }
136
137    pub struct ChatCompletionChunkStream {
138        byte_stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>,
139        buf: String,
140    }
141
142    impl ChatCompletionChunkStream {
143        pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>) -> Self {
144            Self {
145                byte_stream: stream,
146                buf: String::new(),
147            }
148        }
149
150        fn deserialize_buf(
151            self: Pin<&mut Self>,
152            cx: &mut std::task::Context<'_>,
153        ) -> Option<anyhow::Result<ChatCompletionChunk>> {
154            let bufclone = self.buf.clone();
155            let mut chunks = bufclone.split("\n\n").peekable();
156            let first = chunks.next();
157            let second = chunks.peek();
158
159            match first {
160                Some(first) => match first.strip_prefix("data: ") {
161                    Some(chunk) => {
162                        if !chunk.ends_with("}") {
163                            None
164                        } else {
165                            if let Some(second) = second {
166                                if second.ends_with("}") {
167                                    cx.waker().wake_by_ref();
168                                }
169                            }
170                            self.get_mut().buf = chunks.collect::<Vec<_>>().join("\n\n");
171                            Some(
172                                serde_json::from_str::<ChatCompletionChunk>(chunk)
173                                    .map_err(|e| anyhow::anyhow!(e)),
174                            )
175                        }
176                    }
177                    None => None,
178                },
179                None => None,
180            }
181        }
182    }
183
184    impl Stream for ChatCompletionChunkStream {
185        type Item = anyhow::Result<ChatCompletionChunk>;
186
187        fn poll_next(
188            mut self: Pin<&mut Self>,
189            cx: &mut std::task::Context<'_>,
190        ) -> Poll<Option<Self::Item>> {
191            if let Some(chunk) = self.as_mut().deserialize_buf(cx) {
192                return Poll::Ready(Some(chunk));
193            }
194
195            match self.byte_stream.as_mut().poll_next(cx) {
196                Poll::Ready(bytes_option) => match bytes_option {
197                    Some(bytes_result) => match bytes_result {
198                        Ok(bytes) => {
199                            let data = str::from_utf8(&bytes)?.to_owned();
200                            self.buf = self.buf.clone() + &data;
201                            match self.deserialize_buf(cx) {
202                                Some(chunk) => Poll::Ready(Some(chunk)),
203                                None => {
204                                    cx.waker().wake_by_ref();
205                                    Poll::Pending
206                                }
207                            }
208                        }
209                        Err(e) => Poll::Ready(Some(Err(e.into()))),
210                    },
211                    None => Poll::Ready(None),
212                },
213                Poll::Pending => Poll::Pending,
214            }
215        }
216    }
217}
218
219#[derive(Deserialize, Debug, Clone)]
220pub struct Usage {
221    pub prompt_tokens: u32,
222    pub completion_tokens: u32,
223    pub total_tokens: u32,
224}
225
/// One candidate completion within a [`ChatCompletion`].
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
    /// Index of this choice as reported by the server; `None` when omitted.
    #[serde(default)]
    pub index: Option<u32>,
    /// The generated message.
    pub message: Message,
    /// Why generation stopped, as reported by the server.
    // NOTE(review): this is a required `String`, so a missing or JSON `null`
    // finish_reason fails deserialization of the whole response — confirm that
    // every targeted server always sends a string here.
    pub finish_reason: String,
}
233
234#[derive(Serialize, Deserialize, Debug, Clone)]
235pub struct Message {
236    pub role: String,
237    pub content: String,
238}
239
/// Author role of a chat [`Message`].
///
/// [`Role::as_str`] (or `to_string()`) yields the lowercase value to store in
/// [`Message::role`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    System,
    Assistant,
    User,
}

impl Role {
    /// Returns the lowercase role name: `"system"`, `"assistant"` or `"user"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Role::System => "system",
            Role::Assistant => "assistant",
            Role::User => "user",
        }
    }
}

impl std::fmt::Display for Role {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
245
246// =============================================================================
247// xAI Agent Tools API
248// See: https://docs.x.ai/docs/guides/tools/overview
249// =============================================================================
250
251/// Represents a server-side tool available in xAI's Agent Tools API.
252///
253/// xAI provides agentic server-side tool calling where the model autonomously
254/// explores, searches, and executes code. The server handles the entire
255/// reasoning and tool-execution loop.
256///
257/// # Supported Models
258/// - `grok-4-1-fast` (recommended for agentic tool calling)
259/// - `grok-4-1-fast-non-reasoning`
260/// - `grok-4`, `grok-4-fast`, `grok-4-fast-non-reasoning`
261///
262/// # Example
263/// ```rust,no_run
264/// use openai_rust2::chat::GrokTool;
265///
266/// let tools = vec![
267///     GrokTool::web_search(),
268///     GrokTool::x_search(),
269///     GrokTool::code_execution(),
270///     GrokTool::collections_search(vec!["collection-id-1".into()]),
271///     GrokTool::mcp("https://my-mcp-server.com".into()),
272/// ];
273/// ```
274#[derive(Serialize, Debug, Clone)]
275pub struct GrokTool {
276    /// The type of tool: "web_search", "x_search", "code_execution", "collections_search", "mcp"
277    #[serde(rename = "type")]
278    pub tool_type: GrokToolType,
279    /// Restrict web search to specific domains (max 5). Only applies to web_search.
280    #[serde(skip_serializing_if = "Option::is_none")]
281    pub allowed_domains: Option<Vec<String>>,
282    /// Inclusive start date for search results (ISO8601: YYYY-MM-DD). Applies to web_search and x_search.
283    #[serde(skip_serializing_if = "Option::is_none")]
284    pub from_date: Option<String>,
285    /// Inclusive end date for search results (ISO8601: YYYY-MM-DD). Applies to web_search and x_search.
286    #[serde(skip_serializing_if = "Option::is_none")]
287    pub to_date: Option<String>,
288    /// Collection IDs to search. Required for collections_search tool.
289    #[serde(skip_serializing_if = "Option::is_none")]
290    pub collection_ids: Option<Vec<String>>,
291    /// MCP server URL. Required for mcp tool.
292    #[serde(skip_serializing_if = "Option::is_none")]
293    pub server_url: Option<String>,
294}
295
296/// The type of xAI server-side tool.
297#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
298#[serde(rename_all = "snake_case")]
299pub enum GrokToolType {
300    /// Real-time web search and page browsing
301    WebSearch,
302    /// Search X (Twitter) posts, users, and threads
303    XSearch,
304    /// Execute Python code for calculations and data analysis
305    CodeExecution,
306    /// Search uploaded document collections (knowledge bases)
307    CollectionsSearch,
308    /// Connect to external MCP servers for custom tools
309    Mcp,
310}
311
312impl GrokTool {
313    /// Create a web_search tool with default settings.
314    /// Allows the agent to search the web and browse pages.
315    pub fn web_search() -> Self {
316        Self {
317            tool_type: GrokToolType::WebSearch,
318            allowed_domains: None,
319            from_date: None,
320            to_date: None,
321            collection_ids: None,
322            server_url: None,
323        }
324    }
325
326    /// Create an x_search tool with default settings.
327    /// Allows the agent to search X posts, users, and threads.
328    pub fn x_search() -> Self {
329        Self {
330            tool_type: GrokToolType::XSearch,
331            allowed_domains: None,
332            from_date: None,
333            to_date: None,
334            collection_ids: None,
335            server_url: None,
336        }
337    }
338
339    /// Create a code_execution tool.
340    /// Allows the agent to execute Python code for calculations and data analysis.
341    pub fn code_execution() -> Self {
342        Self {
343            tool_type: GrokToolType::CodeExecution,
344            allowed_domains: None,
345            from_date: None,
346            to_date: None,
347            collection_ids: None,
348            server_url: None,
349        }
350    }
351
352    /// Create a collections_search tool with the specified collection IDs.
353    /// Allows the agent to search through uploaded knowledge bases.
354    pub fn collections_search(collection_ids: Vec<String>) -> Self {
355        Self {
356            tool_type: GrokToolType::CollectionsSearch,
357            allowed_domains: None,
358            from_date: None,
359            to_date: None,
360            collection_ids: Some(collection_ids),
361            server_url: None,
362        }
363    }
364
365    /// Create an MCP tool connecting to an external MCP server.
366    /// Allows the agent to access custom tools from the specified server.
367    pub fn mcp(server_url: String) -> Self {
368        Self {
369            tool_type: GrokToolType::Mcp,
370            allowed_domains: None,
371            from_date: None,
372            to_date: None,
373            collection_ids: None,
374            server_url: Some(server_url),
375        }
376    }
377
378    /// Restrict web search to specific domains (max 5).
379    /// Only applies to web_search tool.
380    pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
381        self.allowed_domains = Some(domains);
382        self
383    }
384
385    /// Set the date range for search results (ISO8601: YYYY-MM-DD).
386    /// Applies to web_search and x_search tools.
387    pub fn with_date_range(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
388        self.from_date = Some(from.into());
389        self.to_date = Some(to.into());
390        self
391    }
392}
393
394// =============================================================================
395// xAI Responses API
396// See: https://docs.x.ai/docs/guides/tools/search-tools
397// The Responses API is a separate endpoint (/v1/responses) for agentic tool calling.
398// =============================================================================
399
/// Request arguments for xAI's Responses API endpoint (/v1/responses).
///
/// This API provides agentic tool calling where the model autonomously
/// explores, searches, and executes code. Unlike the Chat Completions API,
/// the Responses API uses `input` instead of `messages` and `tools` instead
/// of `server_tools`.
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`.
///
/// # Example
/// ```rust,no_run
/// use openai_rust2::chat::{ResponsesArguments, ResponsesMessage, GrokTool};
///
/// let args = ResponsesArguments::new(
///     "grok-4-1-fast-reasoning",
///     vec![ResponsesMessage {
///         role: "user".to_string(),
///         content: "What is the current price of Bitcoin?".to_string(),
///     }],
/// ).with_tools(vec![GrokTool::web_search()]);
/// ```
#[derive(Serialize, Debug, Clone)]
pub struct ResponsesArguments {
    /// Model identifier.
    pub model: String,
    /// Conversation input, serialized as `input`.
    pub input: Vec<ResponsesMessage>,
    /// Server-side tools, serialized as `tools`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GrokTool>>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Upper bound on generated tokens, serialized as `max_output_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
}
430
431impl ResponsesArguments {
432    /// Create new ResponsesArguments for the xAI Responses API.
433    pub fn new(model: impl AsRef<str>, input: Vec<ResponsesMessage>) -> Self {
434        Self {
435            model: model.as_ref().to_owned(),
436            input,
437            tools: None,
438            temperature: None,
439            max_output_tokens: None,
440        }
441    }
442
443    /// Add tools for agentic capabilities.
444    pub fn with_tools(mut self, tools: Vec<GrokTool>) -> Self {
445        self.tools = Some(tools);
446        self
447    }
448
449    /// Set the temperature for response generation.
450    pub fn with_temperature(mut self, temperature: f32) -> Self {
451        self.temperature = Some(temperature);
452        self
453    }
454
455    /// Set the maximum output tokens.
456    pub fn with_max_output_tokens(mut self, max_tokens: u32) -> Self {
457        self.max_output_tokens = Some(max_tokens);
458        self
459    }
460}
461
/// Message format for the Responses API input array.
///
/// A `role` string (e.g. `"user"`) paired with the message text.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ResponsesMessage {
    /// Author role, e.g. `"user"`.
    pub role: String,
    /// Message text.
    pub content: String,
}
468
/// Response from xAI's Responses API.
///
/// The Responses API returns a different format from Chat Completions,
/// including citations for sources used during agentic search.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponsesCompletion {
    /// Response identifier; `None` when the server omits it.
    #[serde(default)]
    pub id: Option<String>,
    /// The output content items from the model
    pub output: Vec<ResponsesOutputItem>,
    /// Citations for sources used during search (URLs); empty when absent.
    #[serde(default)]
    pub citations: Vec<String>,
    /// Token usage statistics.
    // NOTE(review): `usage` itself is required (no `#[serde(default)]`), so a
    // response body without it fails to deserialize — confirm the server always
    // sends it.
    pub usage: ResponsesUsage,
}
485
486impl ResponsesCompletion {
487    /// Extract the text content from the response output.
488    pub fn get_text_content(&self) -> String {
489        self.output
490            .iter()
491            .filter_map(|item| {
492                if item.item_type == "message" {
493                    item.content.as_ref().map(|contents| {
494                        contents
495                            .iter()
496                            .filter_map(|c| {
497                                if c.content_type == "output_text" {
498                                    c.text.clone()
499                                } else {
500                                    None
501                                }
502                            })
503                            .collect::<Vec<_>>()
504                            .join("")
505                    })
506                } else {
507                    None
508                }
509            })
510            .collect::<Vec<_>>()
511            .join("")
512    }
513}
514
515impl std::fmt::Display for ResponsesCompletion {
516    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
517        write!(f, "{}", self.get_text_content())
518    }
519}
520
/// An output item in the Responses API response.
///
/// `ResponsesCompletion::get_text_content` only reads items whose `type` is
/// `"message"`; other item types are deserialized but not used there.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponsesOutputItem {
    /// Item kind, from the JSON `type` field (e.g. `"message"`).
    #[serde(rename = "type")]
    pub item_type: String,
    /// Author role, when present.
    #[serde(default)]
    pub role: Option<String>,
    /// Content parts, when present.
    #[serde(default)]
    pub content: Option<Vec<ResponsesContent>>,
}
531
/// Content within a Responses API output item.
///
/// `ResponsesCompletion::get_text_content` only reads parts whose `type` is
/// `"output_text"`.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponsesContent {
    /// Part kind, from the JSON `type` field (e.g. `"output_text"`).
    #[serde(rename = "type")]
    pub content_type: String,
    /// Text payload, when present.
    #[serde(default)]
    pub text: Option<String>,
}
540
/// Token usage for Responses API.
///
/// Each counter defaults to `0` when the server omits it.
#[derive(Deserialize, Debug, Clone)]
pub struct ResponsesUsage {
    /// Input token count, as reported by the server.
    #[serde(default)]
    pub input_tokens: u32,
    /// Output token count, as reported by the server.
    #[serde(default)]
    pub output_tokens: u32,
    /// Total token count, as reported by the server.
    #[serde(default)]
    pub total_tokens: u32,
}