1use serde::{Deserialize, Serialize};
2
3#[derive(Serialize, Deserialize, Debug, Clone)]
4pub enum ResponseFormat {
5 JsonObject,
6 Text,
7}
8
9#[derive(Serialize, Deserialize, Debug, Clone)]
10pub struct ImageGeneration {
11 pub quality: Option<String>, pub size: Option<String>, pub output_format: Option<String>, }
15
16#[derive(Serialize, Debug, Clone)]
17pub struct ChatArguments {
18 pub model: String,
19 pub messages: Vec<Message>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub temperature: Option<f32>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub top_p: Option<f32>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub n: Option<u32>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub stream: Option<bool>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub stop: Option<String>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub max_tokens: Option<u32>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub presence_penalty: Option<f32>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub frequency_penalty: Option<f32>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub user: Option<String>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub response_format: Option<ResponseFormat>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub image_generation: Option<ImageGeneration>,
42 #[serde(skip_serializing_if = "Option::is_none", rename = "server_tools")]
46 pub grok_tools: Option<Vec<GrokTool>>,
47 #[serde(skip_serializing_if = "Option::is_none")]
52 pub tools: Option<Vec<OpenAITool>>,
53}
54
55impl ChatArguments {
56 pub fn new(model: impl AsRef<str>, messages: Vec<Message>) -> ChatArguments {
57 ChatArguments {
58 model: model.as_ref().to_owned(),
59 messages,
60 temperature: None,
61 top_p: None,
62 n: None,
63 stream: None,
64 stop: None,
65 max_tokens: None,
66 presence_penalty: None,
67 frequency_penalty: None,
68 user: None,
69 response_format: None,
70 image_generation: None,
71 grok_tools: None,
72 tools: None,
73 }
74 }
75
76 pub fn with_grok_tools(mut self, tools: Vec<GrokTool>) -> Self {
79 self.grok_tools = Some(tools);
80 self
81 }
82
83 pub fn with_openai_tools(mut self, tools: Vec<OpenAITool>) -> Self {
87 self.tools = Some(tools);
88 self
89 }
90}
91
92#[derive(Deserialize, Debug, Clone)]
93pub struct ChatCompletion {
94 #[serde(default)]
95 pub id: Option<String>,
96 pub created: u32,
97 #[serde(default)]
98 pub model: Option<String>,
99 #[serde(default)]
100 pub object: Option<String>,
101 pub choices: Vec<Choice>,
102 pub usage: Usage,
103}
104
105impl std::fmt::Display for ChatCompletion {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 write!(f, "{}", &self.choices[0].message.content)?;
108 Ok(())
109 }
110}
111
112pub mod stream {
113 use bytes::Bytes;
114 use futures_util::Stream;
115 use serde::Deserialize;
116 use std::pin::Pin;
117 use std::str;
118 use std::task::Poll;
119
120 #[derive(Deserialize, Debug, Clone)]
121 pub struct ChatCompletionChunk {
122 pub id: String,
123 pub created: u32,
124 pub model: String,
125 pub choices: Vec<Choice>,
126 pub system_fingerprint: Option<String>,
127 }
128
129 impl std::fmt::Display for ChatCompletionChunk {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 write!(
132 f,
133 "{}",
134 self.choices[0].delta.content.as_ref().unwrap_or(&"".into())
135 )?;
136 Ok(())
137 }
138 }
139
140 #[derive(Deserialize, Debug, Clone)]
141 pub struct Choice {
142 pub delta: ChoiceDelta,
143 pub index: u32,
144 pub finish_reason: Option<String>,
145 }
146
147 #[derive(Deserialize, Debug, Clone)]
148 pub struct ChoiceDelta {
149 pub content: Option<String>,
150 }
151
152 pub struct ChatCompletionChunkStream {
153 byte_stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>,
154 buf: String,
155 }
156
157 impl ChatCompletionChunkStream {
158 pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>) -> Self {
159 Self {
160 byte_stream: stream,
161 buf: String::new(),
162 }
163 }
164
165 fn deserialize_buf(
166 self: Pin<&mut Self>,
167 cx: &mut std::task::Context<'_>,
168 ) -> Option<anyhow::Result<ChatCompletionChunk>> {
169 let bufclone = self.buf.clone();
170 let mut chunks = bufclone.split("\n\n").peekable();
171 let first = chunks.next();
172 let second = chunks.peek();
173
174 match first {
175 Some(first) => match first.strip_prefix("data: ") {
176 Some(chunk) => {
177 if !chunk.ends_with("}") {
178 None
179 } else {
180 if let Some(second) = second {
181 if second.ends_with("}") {
182 cx.waker().wake_by_ref();
183 }
184 }
185 self.get_mut().buf = chunks.collect::<Vec<_>>().join("\n\n");
186 Some(
187 serde_json::from_str::<ChatCompletionChunk>(chunk)
188 .map_err(|e| anyhow::anyhow!(e)),
189 )
190 }
191 }
192 None => None,
193 },
194 None => None,
195 }
196 }
197 }
198
199 impl Stream for ChatCompletionChunkStream {
200 type Item = anyhow::Result<ChatCompletionChunk>;
201
202 fn poll_next(
203 mut self: Pin<&mut Self>,
204 cx: &mut std::task::Context<'_>,
205 ) -> Poll<Option<Self::Item>> {
206 if let Some(chunk) = self.as_mut().deserialize_buf(cx) {
207 return Poll::Ready(Some(chunk));
208 }
209
210 match self.byte_stream.as_mut().poll_next(cx) {
211 Poll::Ready(bytes_option) => match bytes_option {
212 Some(bytes_result) => match bytes_result {
213 Ok(bytes) => {
214 let data = str::from_utf8(&bytes)?.to_owned();
215 self.buf = self.buf.clone() + &data;
216 match self.deserialize_buf(cx) {
217 Some(chunk) => Poll::Ready(Some(chunk)),
218 None => {
219 cx.waker().wake_by_ref();
220 Poll::Pending
221 }
222 }
223 }
224 Err(e) => Poll::Ready(Some(Err(e.into()))),
225 },
226 None => Poll::Ready(None),
227 },
228 Poll::Pending => Poll::Pending,
229 }
230 }
231 }
232}
233
234#[derive(Deserialize, Debug, Clone)]
235pub struct Usage {
236 pub prompt_tokens: u32,
237 pub completion_tokens: u32,
238 pub total_tokens: u32,
239}
240
241#[derive(Deserialize, Debug, Clone)]
242pub struct Choice {
243 #[serde(default)]
244 pub index: Option<u32>,
245 pub message: Message,
246 pub finish_reason: String,
247}
248
249#[derive(Serialize, Deserialize, Debug, Clone)]
250pub struct Message {
251 pub role: String,
252 pub content: String,
253}
254
/// Author of a chat message.
///
/// [`Role::as_str`] (also used by `Display` and `Into<String>`) yields the lowercase name used
/// as a message's `role` string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    System,
    Assistant,
    User,
}

impl Role {
    /// Returns the lowercase role name: `"system"`, `"assistant"` or `"user"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Role::System => "system",
            Role::Assistant => "assistant",
            Role::User => "user",
        }
    }
}

impl std::fmt::Display for Role {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl From<Role> for String {
    fn from(role: Role) -> Self {
        role.as_str().to_owned()
    }
}
260
261#[derive(Serialize, Debug, Clone)]
290pub struct GrokTool {
291 #[serde(rename = "type")]
293 pub tool_type: GrokToolType,
294 #[serde(skip_serializing_if = "Option::is_none")]
296 pub allowed_domains: Option<Vec<String>>,
297 #[serde(skip_serializing_if = "Option::is_none")]
299 pub from_date: Option<String>,
300 #[serde(skip_serializing_if = "Option::is_none")]
302 pub to_date: Option<String>,
303 #[serde(skip_serializing_if = "Option::is_none")]
305 pub collection_ids: Option<Vec<String>>,
306 #[serde(skip_serializing_if = "Option::is_none")]
308 pub server_url: Option<String>,
309}
310
311#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
313#[serde(rename_all = "snake_case")]
314pub enum GrokToolType {
315 WebSearch,
317 XSearch,
319 CodeExecution,
321 CollectionsSearch,
323 Mcp,
325}
326
327impl GrokTool {
328 pub fn web_search() -> Self {
331 Self {
332 tool_type: GrokToolType::WebSearch,
333 allowed_domains: None,
334 from_date: None,
335 to_date: None,
336 collection_ids: None,
337 server_url: None,
338 }
339 }
340
341 pub fn x_search() -> Self {
344 Self {
345 tool_type: GrokToolType::XSearch,
346 allowed_domains: None,
347 from_date: None,
348 to_date: None,
349 collection_ids: None,
350 server_url: None,
351 }
352 }
353
354 pub fn code_execution() -> Self {
357 Self {
358 tool_type: GrokToolType::CodeExecution,
359 allowed_domains: None,
360 from_date: None,
361 to_date: None,
362 collection_ids: None,
363 server_url: None,
364 }
365 }
366
367 pub fn collections_search(collection_ids: Vec<String>) -> Self {
370 Self {
371 tool_type: GrokToolType::CollectionsSearch,
372 allowed_domains: None,
373 from_date: None,
374 to_date: None,
375 collection_ids: Some(collection_ids),
376 server_url: None,
377 }
378 }
379
380 pub fn mcp(server_url: String) -> Self {
383 Self {
384 tool_type: GrokToolType::Mcp,
385 allowed_domains: None,
386 from_date: None,
387 to_date: None,
388 collection_ids: None,
389 server_url: Some(server_url),
390 }
391 }
392
393 pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
396 self.allowed_domains = Some(domains);
397 self
398 }
399
400 pub fn with_date_range(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
403 self.from_date = Some(from.into());
404 self.to_date = Some(to.into());
405 self
406 }
407}
408
409#[derive(Serialize, Debug, Clone)]
435pub struct ResponsesArguments {
436 pub model: String,
437 pub input: Vec<ResponsesMessage>,
438 #[serde(skip_serializing_if = "Option::is_none")]
439 pub tools: Option<Vec<GrokTool>>,
440 #[serde(skip_serializing_if = "Option::is_none")]
441 pub temperature: Option<f32>,
442 #[serde(skip_serializing_if = "Option::is_none")]
443 pub max_output_tokens: Option<u32>,
444}
445
446impl ResponsesArguments {
447 pub fn new(model: impl AsRef<str>, input: Vec<ResponsesMessage>) -> Self {
449 Self {
450 model: model.as_ref().to_owned(),
451 input,
452 tools: None,
453 temperature: None,
454 max_output_tokens: None,
455 }
456 }
457
458 pub fn with_tools(mut self, tools: Vec<GrokTool>) -> Self {
460 self.tools = Some(tools);
461 self
462 }
463
464 pub fn with_temperature(mut self, temperature: f32) -> Self {
466 self.temperature = Some(temperature);
467 self
468 }
469
470 pub fn with_max_output_tokens(mut self, max_tokens: u32) -> Self {
472 self.max_output_tokens = Some(max_tokens);
473 self
474 }
475}
476
477#[derive(Serialize, Deserialize, Debug, Clone)]
479pub struct ResponsesMessage {
480 pub role: String,
481 pub content: String,
482}
483
484#[derive(Deserialize, Debug, Clone)]
489pub struct ResponsesCompletion {
490 #[serde(default)]
491 pub id: Option<String>,
492 pub output: Vec<ResponsesOutputItem>,
494 #[serde(default)]
496 pub citations: Vec<String>,
497 pub usage: ResponsesUsage,
499}
500
501impl ResponsesCompletion {
502 pub fn get_text_content(&self) -> String {
504 self.output
505 .iter()
506 .filter_map(|item| {
507 if item.item_type == "message" {
508 item.content.as_ref().map(|contents| {
509 contents
510 .iter()
511 .filter_map(|c| {
512 if c.content_type == "output_text" {
513 c.text.clone()
514 } else {
515 None
516 }
517 })
518 .collect::<Vec<_>>()
519 .join("")
520 })
521 } else {
522 None
523 }
524 })
525 .collect::<Vec<_>>()
526 .join("")
527 }
528}
529
530impl std::fmt::Display for ResponsesCompletion {
531 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
532 write!(f, "{}", self.get_text_content())
533 }
534}
535
536#[derive(Deserialize, Debug, Clone)]
538pub struct ResponsesOutputItem {
539 #[serde(rename = "type")]
540 pub item_type: String,
541 #[serde(default)]
542 pub role: Option<String>,
543 #[serde(default)]
544 pub content: Option<Vec<ResponsesContent>>,
545}
546
547#[derive(Deserialize, Debug, Clone)]
549pub struct ResponsesContent {
550 #[serde(rename = "type")]
551 pub content_type: String,
552 #[serde(default)]
553 pub text: Option<String>,
554}
555
556#[derive(Deserialize, Debug, Clone)]
558pub struct ResponsesUsage {
559 #[serde(default)]
560 pub input_tokens: u32,
561 #[serde(default)]
562 pub output_tokens: u32,
563 #[serde(default)]
564 pub total_tokens: u32,
565}
566
567#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
574#[serde(rename_all = "snake_case")]
575pub enum OpenAIToolType {
576 WebSearch,
578 FileSearch,
580 CodeInterpreter,
582}
583
584#[derive(Serialize, Deserialize, Debug, Clone)]
586pub struct UserLocation {
587 #[serde(skip_serializing_if = "Option::is_none")]
589 pub country: Option<String>,
590 #[serde(skip_serializing_if = "Option::is_none")]
592 pub city: Option<String>,
593 #[serde(skip_serializing_if = "Option::is_none")]
595 pub region: Option<String>,
596 #[serde(skip_serializing_if = "Option::is_none")]
598 pub timezone: Option<String>,
599}
600
601#[derive(Serialize, Debug, Clone)]
628pub struct OpenAITool {
629 #[serde(rename = "type")]
631 pub tool_type: OpenAIToolType,
632 #[serde(skip_serializing_if = "Option::is_none")]
636 pub search_context_size: Option<String>,
637 #[serde(skip_serializing_if = "Option::is_none")]
640 pub user_location: Option<UserLocation>,
641 #[serde(skip_serializing_if = "Option::is_none")]
644 pub max_num_results: Option<u32>,
645}
646
647impl OpenAITool {
648 pub fn web_search() -> Self {
651 Self {
652 tool_type: OpenAIToolType::WebSearch,
653 search_context_size: None,
654 user_location: None,
655 max_num_results: None,
656 }
657 }
658
659 pub fn file_search() -> Self {
662 Self {
663 tool_type: OpenAIToolType::FileSearch,
664 search_context_size: None,
665 user_location: None,
666 max_num_results: None,
667 }
668 }
669
670 pub fn code_interpreter() -> Self {
673 Self {
674 tool_type: OpenAIToolType::CodeInterpreter,
675 search_context_size: None,
676 user_location: None,
677 max_num_results: None,
678 }
679 }
680
681 pub fn with_search_context_size(mut self, size: impl Into<String>) -> Self {
684 self.search_context_size = Some(size.into());
685 self
686 }
687
688 pub fn with_user_location(mut self, location: UserLocation) -> Self {
690 self.user_location = Some(location);
691 self
692 }
693
694 pub fn with_max_num_results(mut self, max_results: u32) -> Self {
696 self.max_num_results = Some(max_results);
697 self
698 }
699}
700
701#[derive(Serialize, Debug, Clone)]
721pub struct OpenAIResponsesArguments {
722 pub model: String,
723 pub input: Vec<ResponsesMessage>,
724 #[serde(skip_serializing_if = "Option::is_none")]
725 pub tools: Option<Vec<OpenAITool>>,
726 #[serde(skip_serializing_if = "Option::is_none")]
727 pub temperature: Option<f32>,
728 #[serde(skip_serializing_if = "Option::is_none")]
729 pub max_output_tokens: Option<u32>,
730}
731
732impl OpenAIResponsesArguments {
733 pub fn new(model: impl AsRef<str>, input: Vec<ResponsesMessage>) -> Self {
735 Self {
736 model: model.as_ref().to_owned(),
737 input,
738 tools: None,
739 temperature: None,
740 max_output_tokens: None,
741 }
742 }
743
744 pub fn with_tools(mut self, tools: Vec<OpenAITool>) -> Self {
746 self.tools = Some(tools);
747 self
748 }
749
750 pub fn with_temperature(mut self, temperature: f32) -> Self {
752 self.temperature = Some(temperature);
753 self
754 }
755
756 pub fn with_max_output_tokens(mut self, max_tokens: u32) -> Self {
758 self.max_output_tokens = Some(max_tokens);
759 self
760 }
761}