1use core::str;
2use std::borrow::Cow;
3
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 logprobs::{Logprob, Logprobs},
8 Stopping, StreamTask, Task,
9};
10
/// A single chat message, made up of the author's role and the text of the message.
///
/// Both fields are `Cow<'a, str>`, so a message may either borrow its text for `'a` or own it.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
    /// Role of the author of the message. The convenience constructors in this file use
    /// `"user"`, `"assistant"` and `"system"`.
    pub role: Cow<'a, str>,
    /// Text of the message.
    pub content: Cow<'a, str>,
}
16
17impl<'a> Message<'a> {
18 pub fn new(role: impl Into<Cow<'a, str>>, content: impl Into<Cow<'a, str>>) -> Self {
19 Self {
20 role: role.into(),
21 content: content.into(),
22 }
23 }
24 pub fn user(content: impl Into<Cow<'a, str>>) -> Self {
25 Self::new("user", content)
26 }
27 pub fn assistant(content: impl Into<Cow<'a, str>>) -> Self {
28 Self::new("assistant", content)
29 }
30 pub fn system(content: impl Into<Cow<'a, str>>) -> Self {
31 Self::new("system", content)
32 }
33}
34
/// A chat request: the conversation to send plus the settings that control generation.
pub struct TaskChat<'a> {
    /// Messages of the conversation; they are serialized into the request body as given.
    pub messages: Vec<Message<'a>>,
    /// Stopping criteria. Its `maximum_tokens` and `stop_sequences` become `max_tokens` and
    /// `stop` in the request body.
    pub stopping: Stopping<'a>,
    /// Sampling parameters forwarded to the request body.
    pub sampling: ChatSampling,
    /// Which log probabilities to request; determines the `logprobs` and `top_logprobs` fields
    /// of the request body.
    pub logprobs: Logprobs,
}
46
47impl<'a> TaskChat<'a> {
48 pub fn with_message(message: Message<'a>) -> Self {
51 Self::with_messages(vec![message])
52 }
53
54 pub fn with_messages(messages: Vec<Message<'a>>) -> Self {
57 TaskChat {
58 messages,
59 sampling: ChatSampling::default(),
60 stopping: Stopping::default(),
61 logprobs: Logprobs::No,
62 }
63 }
64
65 pub fn push_message(mut self, message: Message<'a>) -> Self {
67 self.messages.push(message);
68 self
69 }
70
71 pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
73 self.stopping.maximum_tokens = Some(maximum_tokens);
74 self
75 }
76}
77
/// Optional sampling parameters of a chat request.
///
/// Every field is optional; a field left at `None` is omitted from the serialized request body.
#[derive(Debug, Clone, PartialEq)]
pub struct ChatSampling {
    /// Forwarded as `temperature` in the request body when set.
    pub temperature: Option<f64>,
    /// Forwarded as `top_p` in the request body when set.
    pub top_p: Option<f64>,
    /// Forwarded as `frequency_penalty` in the request body when set.
    pub frequency_penalty: Option<f64>,
    /// Forwarded as `presence_penalty` in the request body when set.
    pub presence_penalty: Option<f64>,
}

impl ChatSampling {
    /// Sampling configuration with every parameter left unset, so none of them is sent with the
    /// request. This is also the value returned by [`Default::default`].
    pub const MOST_LIKELY: Self = ChatSampling {
        temperature: None,
        top_p: None,
        frequency_penalty: None,
        presence_penalty: None,
    };
}

impl Default for ChatSampling {
    /// Returns [`ChatSampling::MOST_LIKELY`].
    fn default() -> Self {
        Self::MOST_LIKELY
    }
}
123
/// Token accounting deserialized from the `usage` object of a chat response.
#[derive(Debug, PartialEq, Deserialize)]
pub struct Usage {
    /// Value of `prompt_tokens` in the response; presumably the number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// Value of `completion_tokens` in the response; presumably the number of generated tokens.
    pub completion_tokens: u32,
}
129
/// Result of a non-streaming chat request.
#[derive(Debug, PartialEq)]
pub struct ChatOutput {
    /// The message taken from the response choice.
    pub message: Message<'static>,
    /// The `finish_reason` reported for the response choice.
    pub finish_reason: String,
    /// Log probability distributions taken from the response choice; empty if the response
    /// carried no log probabilities.
    pub logprobs: Vec<Distribution>,
    /// Token usage reported by the response.
    pub usage: Usage,
}
139
140impl ChatOutput {
141 pub fn new(
142 message: Message<'static>,
143 finish_reason: String,
144 logprobs: Vec<Distribution>,
145 usage: Usage,
146 ) -> Self {
147 Self {
148 message,
149 finish_reason,
150 logprobs,
151 usage,
152 }
153 }
154}
155
/// One entry of the `choices` array of a non-streaming chat response.
#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseChoice {
    /// The message of this choice.
    pub message: Message<'static>,
    /// The `finish_reason` reported for this choice.
    pub finish_reason: String,
    /// Log probability information of this choice, if the response contains any.
    pub logprobs: Option<LogprobContent>,
}
162
/// The `logprobs` object of a response choice; wraps the list found under its `content` key.
///
/// The derived `Default` is an empty list.
#[derive(Deserialize, Debug, PartialEq, Default)]
pub struct LogprobContent {
    /// The distributions listed under `content`.
    content: Vec<Distribution>,
}
167
/// One entry of the log probability `content` list of a chat response.
#[derive(Deserialize, Debug, PartialEq)]
pub struct Distribution {
    /// `Logprob` whose fields are read directly from this entry's JSON object (flattened).
    /// NOTE(review): judging by the name, this is the token that was actually sampled - confirm.
    #[serde(flatten)]
    pub sampled: Logprob,
    /// The `Logprob`s listed under the `top_logprobs` key of this entry.
    #[serde(rename = "top_logprobs")]
    pub top: Vec<Logprob>,
}
178
/// Body of a non-streaming chat response.
#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseChat {
    /// The choices returned by the service. `Task::body_to_output` for `TaskChat` takes the last
    /// entry and requires the list to be non-empty.
    choices: Vec<ResponseChoice>,
    /// Token usage reported for the request.
    usage: Usage,
}
184
/// JSON body of a chat completion request. It borrows all of its data from the model name and
/// the task it is built from.
///
/// Optional values that are `None`, an empty `stop` list and flags that are `false` are left out
/// of the serialized body instead of being sent explicitly.
#[derive(Serialize)]
struct ChatBody<'a> {
    /// Name of the model the request is addressed to.
    pub model: &'a str,
    /// The conversation to send.
    messages: &'a [Message<'a>],
    /// Taken from the task's `stopping.maximum_tokens`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Taken from the task's `stopping.stop_sequences`; omitted when empty.
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    pub stop: &'a [&'a str],
    /// Taken from the task's sampling settings; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Taken from the task's sampling settings; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Taken from the task's sampling settings; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    /// Taken from the task's sampling settings; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    /// Set by `with_streaming` to request a streamed response; omitted when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    /// Whether log probabilities are requested; omitted when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub logprobs: bool,
    /// Value derived from `Logprobs::Top`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,
}
217
218impl<'a> ChatBody<'a> {
219 pub fn new(model: &'a str, task: &'a TaskChat) -> Self {
220 let TaskChat {
221 messages,
222 stopping:
223 Stopping {
224 maximum_tokens,
225 stop_sequences,
226 },
227 sampling:
228 ChatSampling {
229 temperature,
230 top_p,
231 frequency_penalty,
232 presence_penalty,
233 },
234 logprobs,
235 } = task;
236
237 Self {
238 model,
239 messages,
240 max_tokens: *maximum_tokens,
241 stop: stop_sequences,
242 temperature: *temperature,
243 top_p: *top_p,
244 frequency_penalty: *frequency_penalty,
245 presence_penalty: *presence_penalty,
246 stream: false,
247 logprobs: logprobs.logprobs(),
248 top_logprobs: logprobs.top_logprobs(),
249 }
250 }
251
252 pub fn with_streaming(mut self) -> Self {
253 self.stream = true;
254 self
255 }
256}
257
258impl Task for TaskChat<'_> {
259 type Output = ChatOutput;
260
261 type ResponseBody = ResponseChat;
262
263 fn build_request(
264 &self,
265 client: &reqwest::Client,
266 base: &str,
267 model: &str,
268 ) -> reqwest::RequestBuilder {
269 let body = ChatBody::new(model, self);
270 client.post(format!("{base}/chat/completions")).json(&body)
271 }
272
273 fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
274 let ResponseChoice {
275 message,
276 finish_reason,
277 logprobs,
278 } = response.choices.pop().unwrap();
279 ChatOutput::new(
280 message,
281 finish_reason,
282 logprobs.unwrap_or_default().content,
283 response.usage,
284 )
285 }
286}
287
288#[derive(Deserialize)]
289pub struct StreamMessage {
290 pub role: Option<String>,
293 pub content: String,
296}
297
298#[derive(Deserialize)]
300pub struct ChatStreamChunk {
301 pub finish_reason: Option<String>,
304 pub delta: StreamMessage,
306}
307
308#[derive(Deserialize)]
311pub struct ChatEvent {
312 pub choices: Vec<ChatStreamChunk>,
313}
314
315impl StreamTask for TaskChat<'_> {
316 type Output = ChatStreamChunk;
317
318 type ResponseBody = ChatEvent;
319
320 fn build_request(
321 &self,
322 client: &reqwest::Client,
323 base: &str,
324 model: &str,
325 ) -> reqwest::RequestBuilder {
326 let body = ChatBody::new(model, self).with_streaming();
327 client.post(format!("{base}/chat/completions")).json(&body)
328 }
329
330 fn body_to_output(mut response: Self::ResponseBody) -> Self::Output {
331 response
334 .choices
335 .pop()
336 .expect("There must always be at least one choice")
337 }
338}
339
340impl Logprobs {
341 pub fn logprobs(self) -> bool {
343 match self {
344 Logprobs::No => false,
345 Logprobs::Sampled | Logprobs::Top(_) => true,
346 }
347 }
348
349 pub fn top_logprobs(self) -> Option<u8> {
351 match self {
352 Logprobs::No | Logprobs::Sampled => None,
353 Logprobs::Top(n) => Some(n),
354 }
355 }
356}