1use serde::{Deserialize, Serialize};
2
3#[derive(Serialize, Deserialize, Debug, Clone)]
4pub enum ResponseFormat {
5 JsonObject,
6 Text,
7}
8
9#[derive(Serialize, Deserialize, Debug, Clone)]
10pub struct ImageGeneration {
11 pub quality: Option<String>, pub size: Option<String>, pub output_format: Option<String>, }
15
16#[derive(Serialize, Debug, Clone)]
17pub struct ChatArguments {
18 pub model: String,
19 pub messages: Vec<Message>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub temperature: Option<f32>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub top_p: Option<f32>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub n: Option<u32>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub stream: Option<bool>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub stop: Option<String>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub max_tokens: Option<u32>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub presence_penalty: Option<f32>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub frequency_penalty: Option<f32>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub user: Option<String>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub response_format: Option<ResponseFormat>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub image_generation: Option<ImageGeneration>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub search_parameters: Option<SearchParameters>, }
45
46impl ChatArguments {
47 pub fn new(model: impl AsRef<str>, messages: Vec<Message>) -> ChatArguments {
48 ChatArguments {
49 model: model.as_ref().to_owned(),
50 messages,
51 temperature: None,
52 top_p: None,
53 n: None,
54 stream: None,
55 stop: None,
56 max_tokens: None,
57 presence_penalty: None,
58 frequency_penalty: None,
59 user: None,
60 response_format: None,
61 image_generation: None,
62 search_parameters: None, }
64 }
65
66 pub fn with_search_parameters(mut self, params: SearchParameters) -> Self {
67 self.search_parameters = Some(params);
68 self
69 }
70}
71
72#[derive(Deserialize, Debug, Clone)]
73pub struct ChatCompletion {
74 #[serde(default)]
75 pub id: Option<String>,
76 pub created: u32,
77 #[serde(default)]
78 pub model: Option<String>,
79 #[serde(default)]
80 pub object: Option<String>,
81 pub choices: Vec<Choice>,
82 pub usage: Usage,
83}
84
85impl std::fmt::Display for ChatCompletion {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", &self.choices[0].message.content)?;
88 Ok(())
89 }
90}
91
92pub mod stream {
93 use bytes::Bytes;
94 use futures_util::Stream;
95 use serde::Deserialize;
96 use std::pin::Pin;
97 use std::str;
98 use std::task::Poll;
99
100 #[derive(Deserialize, Debug, Clone)]
101 pub struct ChatCompletionChunk {
102 pub id: String,
103 pub created: u32,
104 pub model: String,
105 pub choices: Vec<Choice>,
106 pub system_fingerprint: Option<String>,
107 }
108
109 impl std::fmt::Display for ChatCompletionChunk {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 write!(
112 f,
113 "{}",
114 self.choices[0].delta.content.as_ref().unwrap_or(&"".into())
115 )?;
116 Ok(())
117 }
118 }
119
120 #[derive(Deserialize, Debug, Clone)]
121 pub struct Choice {
122 pub delta: ChoiceDelta,
123 pub index: u32,
124 pub finish_reason: Option<String>,
125 }
126
127 #[derive(Deserialize, Debug, Clone)]
128 pub struct ChoiceDelta {
129 pub content: Option<String>,
130 }
131
132 pub struct ChatCompletionChunkStream {
133 byte_stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>,
134 buf: String,
135 }
136
137 impl ChatCompletionChunkStream {
138 pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>) -> Self {
139 Self {
140 byte_stream: stream,
141 buf: String::new(),
142 }
143 }
144
145 fn deserialize_buf(
146 self: Pin<&mut Self>,
147 cx: &mut std::task::Context<'_>,
148 ) -> Option<anyhow::Result<ChatCompletionChunk>> {
149 let bufclone = self.buf.clone();
150 let mut chunks = bufclone.split("\n\n").peekable();
151 let first = chunks.next();
152 let second = chunks.peek();
153
154 match first {
155 Some(first) => match first.strip_prefix("data: ") {
156 Some(chunk) => {
157 if !chunk.ends_with("}") {
158 None
159 } else {
160 if let Some(second) = second {
161 if second.ends_with("}") {
162 cx.waker().wake_by_ref();
163 }
164 }
165 self.get_mut().buf = chunks.collect::<Vec<_>>().join("\n\n");
166 Some(
167 serde_json::from_str::<ChatCompletionChunk>(&chunk)
168 .map_err(|e| anyhow::anyhow!(e)),
169 )
170 }
171 }
172 None => None,
173 },
174 None => None,
175 }
176 }
177 }
178
179 impl Stream for ChatCompletionChunkStream {
180 type Item = anyhow::Result<ChatCompletionChunk>;
181
182 fn poll_next(
183 mut self: Pin<&mut Self>,
184 cx: &mut std::task::Context<'_>,
185 ) -> Poll<Option<Self::Item>> {
186 match self.as_mut().deserialize_buf(cx) {
187 Some(chunk) => return Poll::Ready(Some(chunk)),
188 None => {}
189 };
190
191 match self.byte_stream.as_mut().poll_next(cx) {
192 Poll::Ready(bytes_option) => match bytes_option {
193 Some(bytes_result) => match bytes_result {
194 Ok(bytes) => {
195 let data = str::from_utf8(&bytes)?.to_owned();
196 self.buf = self.buf.clone() + &data;
197 match self.deserialize_buf(cx) {
198 Some(chunk) => Poll::Ready(Some(chunk)),
199 None => {
200 cx.waker().wake_by_ref();
201 Poll::Pending
202 }
203 }
204 }
205 Err(e) => Poll::Ready(Some(Err(e.into()))),
206 },
207 None => Poll::Ready(None),
208 },
209 Poll::Pending => Poll::Pending,
210 }
211 }
212 }
213}
214
215#[derive(Deserialize, Debug, Clone)]
216pub struct Usage {
217 pub prompt_tokens: u32,
218 pub completion_tokens: u32,
219 pub total_tokens: u32,
220}
221
222#[derive(Deserialize, Debug, Clone)]
223pub struct Choice {
224 #[serde(default)]
225 pub index: Option<u32>,
226 pub message: Message,
227 pub finish_reason: String,
228}
229
230#[derive(Serialize, Deserialize, Debug, Clone)]
231pub struct Message {
232 pub role: String,
233 pub content: String,
234}
235
/// Author of a chat message.
///
/// [`Message::role`] is a plain `String`. Use [`Role::as_str`] (or `to_string()`) to obtain
/// the lowercase role name for a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    System,
    Assistant,
    User,
}

impl Role {
    /// Returns the lowercase role name: `"system"`, `"assistant"`, or `"user"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Role::System => "system",
            Role::Assistant => "assistant",
            Role::User => "user",
        }
    }
}

impl std::fmt::Display for Role {
    /// Writes the same lowercase name as [`Role::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
241
242#[derive(Serialize, Debug, Clone)]
244pub struct SearchParameters {
245 pub mode: SearchMode, #[serde(skip_serializing_if = "Option::is_none")]
247 pub return_citations: Option<bool>,
248 #[serde(skip_serializing_if = "Option::is_none")]
250 pub from_date: Option<String>,
251 #[serde(skip_serializing_if = "Option::is_none")]
253 pub to_date: Option<String>,
254}
255
256#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
257#[serde(rename_all = "lowercase")] pub enum SearchMode {
259 On,
260 Off,
261 Auto,
262}
263
264impl SearchParameters {
265 pub fn new(mode: SearchMode) -> Self {
266 Self {
267 mode,
268 return_citations: None,
269 from_date: None,
270 to_date: None,
271 }
272 }
273 pub fn with_citations(mut self, yes: bool) -> Self {
274 self.return_citations = Some(yes);
275 self
276 }
277 pub fn with_date_range_str(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
278 self.from_date = Some(from.into());
279 self.to_date = Some(to.into());
280 self
281 }
282}