1use core::str;
2use std::borrow::Cow;
3
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    logprobs::{Logprob, Logprobs},
8    Stopping, StreamTask, Task,
9};
10
11#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
12pub struct Message<'a> {
13    pub role: Cow<'a, str>,
14    pub content: Cow<'a, str>,
15}
16
17impl<'a> Message<'a> {
18    pub fn new(role: impl Into<Cow<'a, str>>, content: impl Into<Cow<'a, str>>) -> Self {
19        Self {
20            role: role.into(),
21            content: content.into(),
22        }
23    }
24    pub fn user(content: impl Into<Cow<'a, str>>) -> Self {
25        Self::new("user", content)
26    }
27    pub fn assistant(content: impl Into<Cow<'a, str>>) -> Self {
28        Self::new("assistant", content)
29    }
30    pub fn system(content: impl Into<Cow<'a, str>>) -> Self {
31        Self::new("system", content)
32    }
33}
34
35pub struct TaskChat<'a> {
36    pub messages: Vec<Message<'a>>,
38    pub stopping: Stopping<'a>,
40    pub sampling: ChatSampling,
42    pub logprobs: Logprobs,
45}
46
47impl<'a> TaskChat<'a> {
48    pub fn with_message(message: Message<'a>) -> Self {
51        Self::with_messages(vec![message])
52    }
53
54    pub fn with_messages(messages: Vec<Message<'a>>) -> Self {
57        TaskChat {
58            messages,
59            sampling: ChatSampling::default(),
60            stopping: Stopping::default(),
61            logprobs: Logprobs::No,
62        }
63    }
64
65    pub fn push_message(mut self, message: Message<'a>) -> Self {
67        self.messages.push(message);
68        self
69    }
70
71    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
73        self.stopping.maximum_tokens = Some(maximum_tokens);
74        self
75    }
76
77    pub fn with_logprobs(mut self, logprobs: Logprobs) -> Self {
79        self.logprobs = logprobs;
80        self
81    }
82}
83
/// Sampling parameters of a chat request.
///
/// Every field is optional. A field left at `None` is omitted from the serialized request body,
/// so the serving side decides what to use instead.
// NOTE(review): the server-side defaults are not visible in this file.
#[derive(Debug, Clone, PartialEq)]
pub struct ChatSampling {
    /// Sent as the `temperature` request field when set.
    pub temperature: Option<f64>,
    /// Sent as the `top_p` request field when set.
    pub top_p: Option<f64>,
    /// Sent as the `frequency_penalty` request field when set.
    pub frequency_penalty: Option<f64>,
    /// Sent as the `presence_penalty` request field when set.
    pub presence_penalty: Option<f64>,
}

impl ChatSampling {
    /// Sampling with none of the parameters set.
    pub const MOST_LIKELY: Self = ChatSampling {
        temperature: None,
        top_p: None,
        frequency_penalty: None,
        presence_penalty: None,
    };
}

impl Default for ChatSampling {
    /// Returns [`ChatSampling::MOST_LIKELY`].
    fn default() -> Self {
        Self::MOST_LIKELY
    }
}
129
130#[derive(Debug, PartialEq, Deserialize)]
131pub struct Usage {
132    pub prompt_tokens: u32,
133    pub completion_tokens: u32,
134}
135
136#[derive(Debug, PartialEq)]
137pub struct ChatOutput {
138    pub message: Message<'static>,
139    pub finish_reason: String,
140    pub logprobs: Vec<Distribution>,
143    pub usage: Usage,
144}
145
146impl ChatOutput {
147    pub fn new(
148        message: Message<'static>,
149        finish_reason: String,
150        logprobs: Vec<Distribution>,
151        usage: Usage,
152    ) -> Self {
153        Self {
154            message,
155            finish_reason,
156            logprobs,
157            usage,
158        }
159    }
160}
161
162#[derive(Deserialize, Debug, PartialEq)]
163pub struct ResponseChoice {
164    pub message: Message<'static>,
165    pub finish_reason: String,
166    pub logprobs: Option<LogprobContent>,
167}
168
169#[derive(Deserialize, Debug, PartialEq, Default)]
170pub struct LogprobContent {
171    content: Vec<Distribution>,
172}
173
174#[derive(Deserialize, Debug, PartialEq)]
176pub struct Distribution {
177    #[serde(flatten)]
179    pub sampled: Logprob,
180    #[serde(rename = "top_logprobs")]
182    pub top: Vec<Logprob>,
183}
184
185#[derive(Deserialize, Debug, PartialEq)]
186pub struct ChatResponse {
187    choices: Vec<ResponseChoice>,
188    usage: Usage,
189}
190
191#[derive(Serialize)]
193struct StreamOptions {
194    include_usage: bool,
198}
199
200#[derive(Serialize)]
201struct ChatBody<'a> {
202    pub model: &'a str,
204    messages: &'a [Message<'a>],
206    #[serde(skip_serializing_if = "Option::is_none")]
208    pub max_tokens: Option<u32>,
209    #[serde(skip_serializing_if = "<[_]>::is_empty")]
210    pub stop: &'a [&'a str],
211    #[serde(skip_serializing_if = "Option::is_none")]
215    pub temperature: Option<f64>,
216    #[serde(skip_serializing_if = "Option::is_none")]
219    pub top_p: Option<f64>,
220    #[serde(skip_serializing_if = "Option::is_none")]
221    pub frequency_penalty: Option<f64>,
222    #[serde(skip_serializing_if = "Option::is_none")]
223    pub presence_penalty: Option<f64>,
224    #[serde(skip_serializing_if = "std::ops::Not::not")]
226    pub stream: bool,
227    #[serde(skip_serializing_if = "std::ops::Not::not")]
228    pub logprobs: bool,
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub top_logprobs: Option<u8>,
231    #[serde(skip_serializing_if = "Option::is_none")]
232    pub stream_options: Option<StreamOptions>,
233}
234
235impl<'a> ChatBody<'a> {
236    pub fn new(model: &'a str, task: &'a TaskChat) -> Self {
237        let TaskChat {
238            messages,
239            stopping:
240                Stopping {
241                    maximum_tokens,
242                    stop_sequences,
243                },
244            sampling:
245                ChatSampling {
246                    temperature,
247                    top_p,
248                    frequency_penalty,
249                    presence_penalty,
250                },
251            logprobs,
252        } = task;
253
254        Self {
255            model,
256            messages,
257            max_tokens: *maximum_tokens,
258            stop: stop_sequences,
259            temperature: *temperature,
260            top_p: *top_p,
261            frequency_penalty: *frequency_penalty,
262            presence_penalty: *presence_penalty,
263            stream: false,
264            logprobs: logprobs.logprobs(),
265            top_logprobs: logprobs.top_logprobs(),
266            stream_options: None,
267        }
268    }
269
270    pub fn with_streaming(mut self) -> Self {
271        self.stream = true;
272        self.stream_options = Some(StreamOptions {
275            include_usage: true,
276        });
277        self
278    }
279}
280
281impl Task for TaskChat<'_> {
282    type Output = ChatOutput;
283
284    type ResponseBody = ChatResponse;
285
286    fn build_request(
287        &self,
288        client: &reqwest::Client,
289        base: &str,
290        model: &str,
291    ) -> reqwest::RequestBuilder {
292        let body = ChatBody::new(model, self);
293        client.post(format!("{base}/chat/completions")).json(&body)
294    }
295
296    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
297        let ResponseChoice {
298            message,
299            finish_reason,
300            logprobs,
301        } = response.choices.pop().unwrap();
302        ChatOutput::new(
303            message,
304            finish_reason,
305            logprobs.unwrap_or_default().content,
306            response.usage,
307        )
308    }
309}
310
311#[derive(Debug, Deserialize)]
312pub struct StreamMessage {
313    pub role: Option<String>,
316    pub content: String,
319}
320
321#[derive(Debug, Deserialize)]
323#[serde(untagged)]
324pub enum DeserializedChatChunk {
325    Delta {
326        delta: StreamMessage,
328        logprobs: Option<LogprobContent>,
329    },
330    Finished {
332        finish_reason: String,
334    },
335}
336
337#[derive(Deserialize)]
343pub struct StreamChatResponse {
344    pub choices: Vec<DeserializedChatChunk>,
345    pub usage: Option<Usage>,
346}
347
348#[derive(Debug)]
349pub enum ChatEvent {
350    Delta {
351        content: String,
354        logprobs: Vec<Distribution>,
356    },
357    Finished {
359        reason: String,
361    },
362    Summary { usage: Usage },
364}
365
366impl StreamTask for TaskChat<'_> {
367    type Output = ChatEvent;
368
369    type ResponseBody = StreamChatResponse;
370
371    fn build_request(
372        &self,
373        client: &reqwest::Client,
374        base: &str,
375        model: &str,
376    ) -> reqwest::RequestBuilder {
377        let body = ChatBody::new(model, self).with_streaming();
378        client.post(format!("{base}/chat/completions")).json(&body)
379    }
380
381    fn body_to_output(&self, mut response: Self::ResponseBody) -> Option<Self::Output> {
382        if let Some(usage) = response.usage {
383            Some(ChatEvent::Summary { usage })
384        } else {
385            let chunk = response
388                .choices
389                .pop()
390                .expect("There must always be at least one choice");
391
392            match chunk {
393                DeserializedChatChunk::Delta {
395                    delta: StreamMessage { role: Some(_), .. },
396                    ..
397                } => None,
398                DeserializedChatChunk::Delta {
399                    delta:
400                        StreamMessage {
401                            role: None,
402                            content,
403                        },
404                    logprobs,
405                } => Some(ChatEvent::Delta {
406                    content,
407                    logprobs: logprobs.unwrap_or_default().content,
408                }),
409                DeserializedChatChunk::Finished { finish_reason } => Some(ChatEvent::Finished {
410                    reason: finish_reason,
411                }),
412            }
413        }
414    }
415}
416
417impl Logprobs {
418    pub fn logprobs(self) -> bool {
420        match self {
421            Logprobs::No => false,
422            Logprobs::Sampled | Logprobs::Top(_) => true,
423        }
424    }
425
426    pub fn top_logprobs(self) -> Option<u8> {
428        match self {
429            Logprobs::No | Logprobs::Sampled => None,
430            Logprobs::Top(n) => Some(n),
431        }
432    }
433}