1use serde::{Deserialize, Serialize};
2
3#[derive(Serialize, Deserialize, Debug, Clone)]
4pub enum ResponseFormat {
5 JsonObject,
6 Text,
7}
8
9#[derive(Serialize, Deserialize, Debug, Clone)]
10pub struct ImageGeneration {
11 pub quality: Option<String>, pub size: Option<String>, pub output_format: Option<String>, }
15
16#[derive(Serialize, Debug, Clone)]
17pub struct ChatArguments {
18 pub model: String,
19 pub messages: Vec<Message>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub temperature: Option<f32>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub top_p: Option<f32>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub n: Option<u32>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub stream: Option<bool>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub stop: Option<String>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub max_tokens: Option<u32>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub presence_penalty: Option<f32>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub frequency_penalty: Option<f32>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub user: Option<String>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub response_format: Option<ResponseFormat>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub image_generation: Option<ImageGeneration>,
42 #[serde(skip_serializing_if = "Option::is_none", rename = "server_tools")]
46 pub grok_tools: Option<Vec<GrokTool>>,
47}
48
49impl ChatArguments {
50 pub fn new(model: impl AsRef<str>, messages: Vec<Message>) -> ChatArguments {
51 ChatArguments {
52 model: model.as_ref().to_owned(),
53 messages,
54 temperature: None,
55 top_p: None,
56 n: None,
57 stream: None,
58 stop: None,
59 max_tokens: None,
60 presence_penalty: None,
61 frequency_penalty: None,
62 user: None,
63 response_format: None,
64 image_generation: None,
65 grok_tools: None,
66 }
67 }
68
69 pub fn with_grok_tools(mut self, tools: Vec<GrokTool>) -> Self {
72 self.grok_tools = Some(tools);
73 self
74 }
75}
76
77#[derive(Deserialize, Debug, Clone)]
78pub struct ChatCompletion {
79 #[serde(default)]
80 pub id: Option<String>,
81 pub created: u32,
82 #[serde(default)]
83 pub model: Option<String>,
84 #[serde(default)]
85 pub object: Option<String>,
86 pub choices: Vec<Choice>,
87 pub usage: Usage,
88}
89
90impl std::fmt::Display for ChatCompletion {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(f, "{}", &self.choices[0].message.content)?;
93 Ok(())
94 }
95}
96
97pub mod stream {
98 use bytes::Bytes;
99 use futures_util::Stream;
100 use serde::Deserialize;
101 use std::pin::Pin;
102 use std::str;
103 use std::task::Poll;
104
105 #[derive(Deserialize, Debug, Clone)]
106 pub struct ChatCompletionChunk {
107 pub id: String,
108 pub created: u32,
109 pub model: String,
110 pub choices: Vec<Choice>,
111 pub system_fingerprint: Option<String>,
112 }
113
114 impl std::fmt::Display for ChatCompletionChunk {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 write!(
117 f,
118 "{}",
119 self.choices[0].delta.content.as_ref().unwrap_or(&"".into())
120 )?;
121 Ok(())
122 }
123 }
124
125 #[derive(Deserialize, Debug, Clone)]
126 pub struct Choice {
127 pub delta: ChoiceDelta,
128 pub index: u32,
129 pub finish_reason: Option<String>,
130 }
131
132 #[derive(Deserialize, Debug, Clone)]
133 pub struct ChoiceDelta {
134 pub content: Option<String>,
135 }
136
137 pub struct ChatCompletionChunkStream {
138 byte_stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>,
139 buf: String,
140 }
141
142 impl ChatCompletionChunkStream {
143 pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>) -> Self {
144 Self {
145 byte_stream: stream,
146 buf: String::new(),
147 }
148 }
149
150 fn deserialize_buf(
151 self: Pin<&mut Self>,
152 cx: &mut std::task::Context<'_>,
153 ) -> Option<anyhow::Result<ChatCompletionChunk>> {
154 let bufclone = self.buf.clone();
155 let mut chunks = bufclone.split("\n\n").peekable();
156 let first = chunks.next();
157 let second = chunks.peek();
158
159 match first {
160 Some(first) => match first.strip_prefix("data: ") {
161 Some(chunk) => {
162 if !chunk.ends_with("}") {
163 None
164 } else {
165 if let Some(second) = second {
166 if second.ends_with("}") {
167 cx.waker().wake_by_ref();
168 }
169 }
170 self.get_mut().buf = chunks.collect::<Vec<_>>().join("\n\n");
171 Some(
172 serde_json::from_str::<ChatCompletionChunk>(chunk)
173 .map_err(|e| anyhow::anyhow!(e)),
174 )
175 }
176 }
177 None => None,
178 },
179 None => None,
180 }
181 }
182 }
183
184 impl Stream for ChatCompletionChunkStream {
185 type Item = anyhow::Result<ChatCompletionChunk>;
186
187 fn poll_next(
188 mut self: Pin<&mut Self>,
189 cx: &mut std::task::Context<'_>,
190 ) -> Poll<Option<Self::Item>> {
191 if let Some(chunk) = self.as_mut().deserialize_buf(cx) {
192 return Poll::Ready(Some(chunk));
193 }
194
195 match self.byte_stream.as_mut().poll_next(cx) {
196 Poll::Ready(bytes_option) => match bytes_option {
197 Some(bytes_result) => match bytes_result {
198 Ok(bytes) => {
199 let data = str::from_utf8(&bytes)?.to_owned();
200 self.buf = self.buf.clone() + &data;
201 match self.deserialize_buf(cx) {
202 Some(chunk) => Poll::Ready(Some(chunk)),
203 None => {
204 cx.waker().wake_by_ref();
205 Poll::Pending
206 }
207 }
208 }
209 Err(e) => Poll::Ready(Some(Err(e.into()))),
210 },
211 None => Poll::Ready(None),
212 },
213 Poll::Pending => Poll::Pending,
214 }
215 }
216 }
217}
218
219#[derive(Deserialize, Debug, Clone)]
220pub struct Usage {
221 pub prompt_tokens: u32,
222 pub completion_tokens: u32,
223 pub total_tokens: u32,
224}
225
226#[derive(Deserialize, Debug, Clone)]
227pub struct Choice {
228 #[serde(default)]
229 pub index: Option<u32>,
230 pub message: Message,
231 pub finish_reason: String,
232}
233
234#[derive(Serialize, Deserialize, Debug, Clone)]
235pub struct Message {
236 pub role: String,
237 pub content: String,
238}
239
/// Author of a chat message.
///
/// [`Role::as_str`] yields the lowercase name to place in a message's `role` string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    System,
    Assistant,
    User,
}

impl Role {
    /// Returns the lowercase role name used by the chat API (`"system"`, `"assistant"`,
    /// `"user"`).
    pub fn as_str(&self) -> &'static str {
        match self {
            Role::System => "system",
            Role::Assistant => "assistant",
            Role::User => "user",
        }
    }
}
245
246#[derive(Serialize, Debug, Clone)]
275pub struct GrokTool {
276 #[serde(rename = "type")]
278 pub tool_type: GrokToolType,
279 #[serde(skip_serializing_if = "Option::is_none")]
281 pub allowed_domains: Option<Vec<String>>,
282 #[serde(skip_serializing_if = "Option::is_none")]
284 pub from_date: Option<String>,
285 #[serde(skip_serializing_if = "Option::is_none")]
287 pub to_date: Option<String>,
288 #[serde(skip_serializing_if = "Option::is_none")]
290 pub collection_ids: Option<Vec<String>>,
291 #[serde(skip_serializing_if = "Option::is_none")]
293 pub server_url: Option<String>,
294}
295
296#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
298#[serde(rename_all = "snake_case")]
299pub enum GrokToolType {
300 WebSearch,
302 XSearch,
304 CodeExecution,
306 CollectionsSearch,
308 Mcp,
310}
311
312impl GrokTool {
313 pub fn web_search() -> Self {
316 Self {
317 tool_type: GrokToolType::WebSearch,
318 allowed_domains: None,
319 from_date: None,
320 to_date: None,
321 collection_ids: None,
322 server_url: None,
323 }
324 }
325
326 pub fn x_search() -> Self {
329 Self {
330 tool_type: GrokToolType::XSearch,
331 allowed_domains: None,
332 from_date: None,
333 to_date: None,
334 collection_ids: None,
335 server_url: None,
336 }
337 }
338
339 pub fn code_execution() -> Self {
342 Self {
343 tool_type: GrokToolType::CodeExecution,
344 allowed_domains: None,
345 from_date: None,
346 to_date: None,
347 collection_ids: None,
348 server_url: None,
349 }
350 }
351
352 pub fn collections_search(collection_ids: Vec<String>) -> Self {
355 Self {
356 tool_type: GrokToolType::CollectionsSearch,
357 allowed_domains: None,
358 from_date: None,
359 to_date: None,
360 collection_ids: Some(collection_ids),
361 server_url: None,
362 }
363 }
364
365 pub fn mcp(server_url: String) -> Self {
368 Self {
369 tool_type: GrokToolType::Mcp,
370 allowed_domains: None,
371 from_date: None,
372 to_date: None,
373 collection_ids: None,
374 server_url: Some(server_url),
375 }
376 }
377
378 pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
381 self.allowed_domains = Some(domains);
382 self
383 }
384
385 pub fn with_date_range(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
388 self.from_date = Some(from.into());
389 self.to_date = Some(to.into());
390 self
391 }
392}
393
394#[derive(Serialize, Debug, Clone)]
420pub struct ResponsesArguments {
421 pub model: String,
422 pub input: Vec<ResponsesMessage>,
423 #[serde(skip_serializing_if = "Option::is_none")]
424 pub tools: Option<Vec<GrokTool>>,
425 #[serde(skip_serializing_if = "Option::is_none")]
426 pub temperature: Option<f32>,
427 #[serde(skip_serializing_if = "Option::is_none")]
428 pub max_output_tokens: Option<u32>,
429}
430
431impl ResponsesArguments {
432 pub fn new(model: impl AsRef<str>, input: Vec<ResponsesMessage>) -> Self {
434 Self {
435 model: model.as_ref().to_owned(),
436 input,
437 tools: None,
438 temperature: None,
439 max_output_tokens: None,
440 }
441 }
442
443 pub fn with_tools(mut self, tools: Vec<GrokTool>) -> Self {
445 self.tools = Some(tools);
446 self
447 }
448
449 pub fn with_temperature(mut self, temperature: f32) -> Self {
451 self.temperature = Some(temperature);
452 self
453 }
454
455 pub fn with_max_output_tokens(mut self, max_tokens: u32) -> Self {
457 self.max_output_tokens = Some(max_tokens);
458 self
459 }
460}
461
462#[derive(Serialize, Deserialize, Debug, Clone)]
464pub struct ResponsesMessage {
465 pub role: String,
466 pub content: String,
467}
468
469#[derive(Deserialize, Debug, Clone)]
474pub struct ResponsesCompletion {
475 #[serde(default)]
476 pub id: Option<String>,
477 pub output: Vec<ResponsesOutputItem>,
479 #[serde(default)]
481 pub citations: Vec<String>,
482 pub usage: ResponsesUsage,
484}
485
486impl ResponsesCompletion {
487 pub fn get_text_content(&self) -> String {
489 self.output
490 .iter()
491 .filter_map(|item| {
492 if item.item_type == "message" {
493 item.content.as_ref().map(|contents| {
494 contents
495 .iter()
496 .filter_map(|c| {
497 if c.content_type == "output_text" {
498 c.text.clone()
499 } else {
500 None
501 }
502 })
503 .collect::<Vec<_>>()
504 .join("")
505 })
506 } else {
507 None
508 }
509 })
510 .collect::<Vec<_>>()
511 .join("")
512 }
513}
514
515impl std::fmt::Display for ResponsesCompletion {
516 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
517 write!(f, "{}", self.get_text_content())
518 }
519}
520
521#[derive(Deserialize, Debug, Clone)]
523pub struct ResponsesOutputItem {
524 #[serde(rename = "type")]
525 pub item_type: String,
526 #[serde(default)]
527 pub role: Option<String>,
528 #[serde(default)]
529 pub content: Option<Vec<ResponsesContent>>,
530}
531
532#[derive(Deserialize, Debug, Clone)]
534pub struct ResponsesContent {
535 #[serde(rename = "type")]
536 pub content_type: String,
537 #[serde(default)]
538 pub text: Option<String>,
539}
540
541#[derive(Deserialize, Debug, Clone)]
543pub struct ResponsesUsage {
544 #[serde(default)]
545 pub input_tokens: u32,
546 #[serde(default)]
547 pub output_tokens: u32,
548 #[serde(default)]
549 pub total_tokens: u32,
550}