1use std::{borrow::Cow, collections::BTreeMap};
5
6use error_stack::Report;
7use serde::{Deserialize, Serialize};
8use serde_with::{formats::PreferMany, serde_as, OneOrMany};
9use uuid::Uuid;
10
11use crate::providers::ProviderError;
12
/// A chat completion response, generic over the choice type.
///
/// `CHOICE` is [`ChatChoice`] for a complete response ([`SingleChatResponse`]) and
/// [`ChatChoiceDelta`] for one streamed chunk ([`StreamingChatResponse`]).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatResponse<CHOICE> {
    /// Creation time reported by the provider (presumably Unix seconds — confirm).
    /// `0` is treated as "not set yet" when collecting a stream.
    pub created: u64,
    /// The model that produced the response, if reported. Serialized as `null` when unset.
    pub model: Option<String>,
    /// Provider-reported system fingerprint, if any. Serialized as `null` when unset.
    pub system_fingerprint: Option<String>,
    /// The response choices.
    pub choices: Vec<CHOICE>,
    /// Token usage; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<UsageResponse>,
}
31
/// One chunk of a streamed response; each choice carries a partial message delta.
pub type StreamingChatResponse = ChatResponse<ChatChoiceDelta>;
/// A complete response, either returned in one piece or collected from streamed chunks.
pub type SingleChatResponse = ChatResponse<ChatChoice>;
36
37impl ChatResponse<ChatChoice> {
38 pub fn new_for_collection(num_choices: usize) -> Self {
40 SingleChatResponse {
41 created: 0,
42 model: None,
43 system_fingerprint: None,
44 choices: Vec::with_capacity(num_choices),
45 usage: Some(UsageResponse {
46 prompt_tokens: None,
47 completion_tokens: None,
48 total_tokens: None,
49 }),
50 }
51 }
52
53 pub fn merge_delta(&mut self, chunk: &ChatResponse<ChatChoiceDelta>) {
55 if self.created == 0 {
56 self.created = chunk.created;
57 }
58
59 if self.model.is_none() {
60 self.model = chunk.model.clone();
61 }
62
63 if self.system_fingerprint.is_none() {
64 self.system_fingerprint = chunk.system_fingerprint.clone();
65 }
66
67 if let Some(delta_usage) = chunk.usage.as_ref() {
68 if let Some(usage) = self.usage.as_mut() {
69 usage.merge(delta_usage);
70 } else {
71 self.usage = chunk.usage.clone();
72 }
73 }
74
75 for choice in chunk.choices.iter() {
76 if choice.index >= self.choices.len() {
77 let new_size = std::cmp::max(chunk.choices.len(), choice.index + 1);
80 self.choices.resize(new_size, ChatChoice::default());
81
82 for i in 0..self.choices.len() {
83 self.choices[i].index = i;
84 }
85 }
86
87 let c = &mut self.choices[choice.index];
88 c.message.add_delta(&choice.delta);
89
90 if let Some(finish) = choice.finish_reason.as_ref() {
91 c.finish_reason = finish.clone();
92 }
93 }
94 }
95}
96
97impl From<SingleChatResponse> for StreamingChatResponse {
99 fn from(value: SingleChatResponse) -> Self {
100 ChatResponse {
101 created: value.created,
102 model: value.model,
103 system_fingerprint: value.system_fingerprint,
104 choices: value
105 .choices
106 .into_iter()
107 .map(|c| ChatChoiceDelta {
108 index: c.index,
109 delta: c.message,
110 finish_reason: Some(c.finish_reason),
111 })
112 .collect(),
113 usage: value.usage,
114 }
115 }
116}
117
/// One choice of a complete chat response.
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct ChatChoice {
    /// Index of this choice. When collected from a stream, it equals the choice's
    /// position in `choices`.
    pub index: usize,
    /// The full message for this choice.
    pub message: ChatMessage,
    /// Why generation ended. Defaults to [`FinishReason::Stop`].
    pub finish_reason: FinishReason,
}
128
/// One choice within a streamed chunk, carrying an incremental piece of its message.
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct ChatChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: usize,
    /// Partial message to append to the accumulated message (see `ChatMessage::add_delta`).
    pub delta: ChatMessage,
    /// Present when the provider reports a finish reason in this chunk.
    pub finish_reason: Option<FinishReason>,
}
139
/// Why generation of a choice ended.
///
/// The named variants (de)serialize as snake_case strings (`"stop"`, `"length"`,
/// `"content_filter"`, `"tool_calls"`). Any other string is captured verbatim by the
/// untagged [`FinishReason::Other`] variant.
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Also the default, e.g. for choices created before any finish reason is known.
    #[default]
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
    /// Any reason not covered by the named variants, kept as-is.
    #[serde(untagged)]
    Other(Cow<'static, str>),
}
151
152impl FinishReason {
153 pub fn as_str(&self) -> &str {
154 match self {
155 FinishReason::Stop => "stop",
156 FinishReason::Length => "length",
157 FinishReason::ContentFilter => "content_filter",
158 FinishReason::ToolCalls => "tool_calls",
159 FinishReason::Other(reason) => reason.as_ref(),
160 }
161 }
162}
163
164impl std::fmt::Display for FinishReason {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 write!(f, "{}", self.as_str())
167 }
168}
169
/// A single chat message. It is also used as the delta type inside streamed chunks
/// (see `ChatMessage::add_delta`).
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct ChatMessage {
    /// Author role, e.g. `"system"`. Always serialized (as `null` when unset).
    pub role: Option<String>,
    /// Optional participant name; omitted when `None`. For providers without name
    /// support, `ChatRequest::merge_message_names` folds it into `content`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Text content. Always serialized (as `null` when unset).
    pub content: Option<String>,
    /// Tool calls attached to this message; omitted when empty, empty when absent on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Associated tool-call id, presumably the call a tool-result message answers
    /// (confirm); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
188
189impl ChatMessage {
190 pub fn add_delta(&mut self, delta: &ChatMessage) {
193 if self.role.is_none() {
194 self.role = delta.role.clone();
195 }
196 if self.name.is_none() {
197 self.name = delta.name.clone();
198 }
199
200 if self.tool_call_id.is_none() {
201 self.tool_call_id = delta.tool_call_id.clone();
202 }
203
204 match (&mut self.content, &delta.content) {
205 (Some(content), Some(new_content)) => content.push_str(new_content),
206 (None, Some(new_content)) => {
207 self.content = Some(new_content.clone());
208 }
209 _ => {}
210 }
211
212 for tool_call in &delta.tool_calls {
213 let Some(index) = tool_call.index else {
214 continue;
216 };
217 if self.tool_calls.len() <= index {
218 self.tool_calls.resize(
219 index + 1,
220 ToolCall {
221 index: None,
222 id: None,
223 typ: None,
224 function: ToolCallFunction {
225 name: None,
226 arguments: None,
227 },
228 },
229 );
230 }
231
232 self.tool_calls[index].merge_delta(tool_call);
233 }
234 }
235}
236
/// Token usage counters reported by a provider. Any counter may be missing.
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct UsageResponse {
    /// Tokens attributed to the prompt, if reported.
    pub prompt_tokens: Option<usize>,
    /// Tokens attributed to the completion, if reported.
    pub completion_tokens: Option<usize>,
    /// Total tokens, if reported.
    pub total_tokens: Option<usize>,
}
247
248impl UsageResponse {
249 pub fn is_empty(&self) -> bool {
251 self.prompt_tokens.is_none()
252 && self.completion_tokens.is_none()
253 && self.total_tokens.is_none()
254 }
255
256 pub fn merge(&mut self, other: &UsageResponse) {
259 if other.prompt_tokens.is_some() {
260 self.prompt_tokens = other.prompt_tokens;
261 }
262
263 if other.completion_tokens.is_some() {
264 self.completion_tokens = other.completion_tokens;
265 }
266
267 if other.total_tokens.is_some() {
268 self.total_tokens = other.total_tokens;
269 }
270 }
271}
272
/// Metadata about an outgoing request, sent as [`StreamingResponse::RequestInfo`].
#[derive(Debug, Clone, Serialize)]
pub struct RequestInfo {
    /// Identifier of this request.
    pub id: Uuid,
    /// Provider name.
    pub provider: String,
    /// Model name used for the request.
    pub model: String,
    /// Number of retries performed (presumably before the attempt that produced this
    /// response — confirm).
    pub num_retries: u32,
    /// Whether rate limiting was encountered.
    pub was_rate_limited: bool,
}
287
/// Metadata about a response, sent as [`StreamingResponse::ResponseInfo`].
#[derive(Debug, Clone, Serialize)]
pub struct ResponseInfo {
    /// Provider-specific metadata as raw JSON, if any.
    pub meta: Option<serde_json::Value>,
    /// Model name reported for the response.
    pub model: String,
}
296
/// An item carried by the [`StreamingResponseSender`] / [`StreamingResponseReceiver`] channel.
///
/// `Serialize` is only derived in test builds.
#[cfg_attr(test, derive(Serialize))]
#[derive(Debug, Clone)]
pub enum StreamingResponse {
    /// Metadata about the request.
    RequestInfo(RequestInfo),
    /// One chunk of a streamed response.
    Chunk(StreamingChatResponse),
    /// A complete response delivered in one piece.
    Single(SingleChatResponse),
    /// Metadata about the response.
    ResponseInfo(ResponseInfo),
}
311
/// Sending half of a `flume` channel carrying [`StreamingResponse`] items or provider errors.
pub type StreamingResponseSender = flume::Sender<Result<StreamingResponse, Report<ProviderError>>>;
/// Receiving half matching [`StreamingResponseSender`].
pub type StreamingResponseReceiver =
    flume::Receiver<Result<StreamingResponse, Report<ProviderError>>>;
317
/// Describes how [`ChatRequest::transform`] should adapt a request to a provider.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ChatRequestTransformation<'a> {
    /// Whether the provider accepts `name` on messages. When `false`, names are folded
    /// into message content.
    pub supports_message_name: bool,
    /// When `true`, the top-level `system` prompt is moved into `messages`. When `false`, a
    /// leading `"system"` message is moved up to the top-level `system` field.
    pub system_in_messages: bool,
    /// Prefix removed from the model name when the model name starts with it.
    pub strip_model_prefix: Option<Cow<'a, str>>,
}
332
333impl<'a> Default for ChatRequestTransformation<'a> {
334 fn default() -> Self {
336 Self {
337 supports_message_name: true,
338 system_in_messages: true,
339 strip_model_prefix: Default::default(),
340 }
341 }
342}
343
/// A provider-agnostic chat completion request (field names presumably mirror the OpenAI
/// chat-completions request — confirm against provider docs).
///
/// Every `Option` field except `max_tokens` is omitted from the serialized form when
/// `None`. `max_tokens` has no `skip_serializing_if` and therefore serializes as `null`.
/// NOTE(review): confirm that asymmetry is intentional.
#[serde_as]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ChatRequest {
    /// The conversation messages.
    pub messages: Vec<ChatMessage>,
    /// Top-level system prompt. [`ChatRequest::transform`] can move it into `messages`
    /// or fill it from a leading system message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Model name. [`ChatRequest::transform`] may strip a configured prefix from it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Frequency penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Logit bias map (presumably token id -> bias — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<BTreeMap<usize, f32>>,
    /// Whether to request log probabilities.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    /// Number of top log probabilities to request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,
    /// Maximum number of tokens to generate. Serialized as `null` when unset.
    pub max_tokens: Option<u32>,
    /// The `n` parameter (presumably the number of choices — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Presence penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Requested response format, passed through as raw JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<serde_json::Value>,
    /// Sampling seed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// Stop sequences. Deserializes from either a single string or a list (`OneOrMany`),
    /// serializes as a list (`PreferMany`), and is omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    #[serde_as(as = "OneOrMany<_, PreferMany>")]
    pub stop: Vec<String>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling `top_p`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Tools the model may call; omitted when empty, empty when absent on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Tool-choice setting, passed through as raw JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<serde_json::Value>,
    /// End-user identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Whether a streamed response is requested; `false` when absent on input.
    #[serde(default)]
    pub stream: bool,
    /// Streaming options; see [`StreamOptions`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
}
413
/// Options for streamed requests.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct StreamOptions {
    /// Whether to ask the provider to include usage in the stream (presumably in a final
    /// chunk — confirm).
    pub include_usage: bool,
}
421
422impl ChatRequest {
423 pub fn transform(&mut self, options: &ChatRequestTransformation) {
425 let stripped = options
426 .strip_model_prefix
427 .as_deref()
428 .zip(self.model.as_deref())
429 .and_then(|(prefix, model)| model.strip_prefix(prefix));
430 if let Some(stripped) = stripped {
431 self.model = Some(stripped.to_string());
432 }
433
434 if !options.supports_message_name {
435 self.merge_message_names();
436 }
437
438 if options.system_in_messages {
439 self.move_system_to_messages();
440 } else {
441 self.move_system_message_to_top_level();
442 }
443 }
444
445 pub fn merge_message_names(&mut self) {
448 for message in self.messages.iter_mut() {
449 if let Some(name) = message.name.take() {
450 message.content = message.content.as_deref().map(|c| format!("{name}: {c}"));
451 }
452 }
453 }
454
455 pub fn move_system_to_messages(&mut self) {
457 let system = self.system.take();
458 if let Some(system) = system {
459 self.messages = std::iter::once(ChatMessage {
460 role: Some("system".to_string()),
461 content: Some(system),
462 tool_calls: Vec::new(),
463 name: None,
464 tool_call_id: None,
465 })
466 .chain(self.messages.drain(..))
467 .collect();
468 }
469 }
470
471 pub fn move_system_message_to_top_level(&mut self) {
473 if self
474 .messages
475 .get(0)
476 .map(|m| m.role.as_deref().unwrap_or_default() == "system")
477 .unwrap_or(false)
478 {
479 let system = self.messages.remove(0);
480 self.system = system.content;
481 }
482 }
483}
484
/// A tool offered to the model in a [`ChatRequest`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Tool {
    /// Tool type, serialized as `type` (presumably `"function"` — confirm).
    #[serde(rename = "type")]
    pub typ: String,
    /// The function definition.
    pub function: FunctionTool,
}
494
/// A function definition inside a [`Tool`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FunctionTool {
    /// Function name.
    pub name: String,
    /// Optional description. Serialized as `null` when unset.
    pub description: Option<String>,
    /// Parameter definition as raw JSON (presumably a JSON Schema object — confirm).
    /// Serialized as `null` when unset.
    pub parameters: Option<serde_json::Value>,
}
505
/// A tool call on a message, either complete or as a streamed fragment.
///
/// While streaming, fragments of the same call share an `index` and are combined by
/// `ToolCall::merge_delta`. All `Option` fields are omitted from the serialized form
/// when `None`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ToolCall {
    /// Position of this call among the message's tool calls, used to match stream fragments.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
    /// Tool-call id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Call type, serialized as `type`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "type")]
    pub typ: Option<String>,
    /// The called function's name and arguments.
    pub function: ToolCallFunction,
}
522
523impl ToolCall {
524 fn merge_delta(&mut self, delta: &ToolCall) {
526 if self.index.is_none() {
527 self.index = delta.index;
528 }
529 if self.id.is_none() {
530 self.id = delta.id.clone();
531 }
532 if self.typ.is_none() {
533 self.typ = delta.typ.clone();
534 }
535 if self.function.name.is_none() {
536 self.function.name = delta.function.name.clone();
537 }
538
539 if self.function.arguments.is_none() {
540 self.function.arguments = delta.function.arguments.clone();
541 } else if delta.function.arguments.is_some() {
542 self.function
543 .arguments
544 .as_mut()
545 .unwrap()
546 .push_str(&delta.function.arguments.as_ref().unwrap());
547 }
548 }
549}
550
/// The function part of a [`ToolCall`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ToolCallFunction {
    /// Function name; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Arguments text (presumably JSON), accumulated across streamed fragments;
    /// omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
561
#[cfg(test)]
mod tests {
    use super::FinishReason;

    /// Every variant serializes to its wire name.
    #[test]
    fn finish_reason_serialization() {
        let cases = [
            (FinishReason::Stop, "stop"),
            (FinishReason::Length, "length"),
            (FinishReason::ContentFilter, "content_filter"),
            (FinishReason::ToolCalls, "tool_calls"),
            (FinishReason::Other("custom_reason".into()), "custom_reason"),
        ];

        for (finish_reason, expected_str) in cases {
            let serialized = serde_json::to_value(&finish_reason).unwrap();
            assert_eq!(serialized, serde_json::json!(expected_str));
        }
    }

    /// Known names deserialize to their dedicated variant rather than to `Other`, and unknown
    /// names are preserved in `Other`.
    #[test]
    fn finish_reason_deserialization() {
        let cases = [
            ("stop", FinishReason::Stop),
            ("length", FinishReason::Length),
            ("content_filter", FinishReason::ContentFilter),
            ("tool_calls", FinishReason::ToolCalls),
            ("custom_reason", FinishReason::Other("custom_reason".into())),
        ];

        for (json_str, expected_enum) in cases {
            let deserialized: FinishReason =
                serde_json::from_value(serde_json::json!(json_str)).unwrap();
            // `FinishReason` has no `PartialEq`, and comparing `as_str()` alone would accept
            // `Other("stop")` in place of `Stop`, so check the variant too.
            assert_eq!(
                std::mem::discriminant(&deserialized),
                std::mem::discriminant(&expected_enum),
                "wrong variant for {json_str:?}"
            );
            assert_eq!(deserialized.as_str(), expected_enum.as_str());
        }
    }

    /// `Display` output matches `as_str` for named and custom reasons.
    #[test]
    fn finish_reason_display_matches_as_str() {
        let reasons = [
            FinishReason::Stop,
            FinishReason::ToolCalls,
            FinishReason::Other("custom_reason".into()),
        ];

        for reason in reasons {
            assert_eq!(reason.to_string(), reason.as_str());
        }
    }
}