1use crate::types::Message;
2use crate::types::{FunctionCallPolicy, Tool};
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ChatCompletionRequest {
8 pub model: String,
9 pub messages: Vec<Message>,
10 pub temperature: Option<f32>,
11 pub max_tokens: Option<u32>,
12 pub stream: Option<bool>,
13 pub top_p: Option<f32>,
14 pub frequency_penalty: Option<f32>,
15 pub presence_penalty: Option<f32>,
16 pub functions: Option<Vec<Tool>>,
18 pub function_call: Option<FunctionCallPolicy>,
20 #[serde(skip_serializing_if = "Option::is_none")]
22 pub extensions: Option<Map<String, Value>>,
23}
24
25impl ChatCompletionRequest {
26 pub fn new(model: String, messages: Vec<Message>) -> Self {
27 Self {
28 model,
29 messages,
30 temperature: None,
31 max_tokens: None,
32 stream: None,
33 top_p: None,
34 frequency_penalty: None,
35 presence_penalty: None,
36 functions: None,
37 function_call: None,
38 extensions: None,
39 }
40 }
41
42 pub fn with_temperature(mut self, temperature: f32) -> Self {
43 self.temperature = Some(temperature);
44 self
45 }
46
47 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
48 self.max_tokens = Some(max_tokens);
49 self
50 }
51
52 pub fn with_functions(mut self, functions: Vec<Tool>) -> Self {
53 self.functions = Some(functions);
54 self
55 }
56
57 pub fn with_function_call(mut self, function_call: FunctionCallPolicy) -> Self {
58 self.function_call = Some(function_call);
59 self
60 }
61
62 pub fn with_extension(mut self, key: &str, value: serde_json::Value) -> Self {
64 let map = self.extensions.get_or_insert_with(Map::new);
65 map.insert(key.to_string(), value);
66 self
67 }
68
69 pub fn apply_extensions(&self, target: &mut serde_json::Value) {
70 if let (Some(ext), Some(obj)) = (&self.extensions, target.as_object_mut()) {
71 for (k, v) in ext {
72 obj.insert(k.clone(), v.clone());
73 }
74 }
75 }
76
77 pub fn ignore_previous(mut self) -> Self {
80 use crate::types::Role;
81 let mut new_msgs: Vec<Message> = self
82 .messages
83 .iter()
84 .filter(|m| matches!(m.role, Role::System))
85 .cloned()
86 .collect();
87 if let Some(last) = self
88 .messages
89 .iter()
90 .rev()
91 .find(|m| !matches!(m.role, Role::System))
92 {
93 new_msgs.push(last.clone());
94 }
95 self.messages = new_msgs;
96 self
97 }
98}