//! `openai_gpt_rs` — `args.rs`: request-argument builder types.

1use crate::{
2    chat::Message,
3    models::{ChatModels, CompletionModels, EditModels},
4};
5
/// Arguments for a text-completion request.
#[derive(Debug)]
pub struct CompletionArgs {
    /// Text the model should complete.
    pub prompt: String,
    /// Model identifier sent to the API (e.g. "text-davinci-003").
    pub model: String,
    /// Text appended after the generated completion.
    pub suffix: String,
    /// Upper bound on the number of tokens to generate.
    pub max_tokens: u32,
    /// How many completions to request.
    pub n: usize,
    /// Sampling temperature.
    pub temperature: f64,
}
15
16impl Default for CompletionArgs {
17    /// Create a new CompletionArgs struct with default values
18    fn default() -> Self {
19        Self {
20            prompt: "".to_string(),
21            model: "text-davinci-003".to_string(),
22            suffix: "".to_string(),
23            max_tokens: 32,
24            n: 1,
25            temperature: 1.0,
26        }
27    }
28}
29
30impl CompletionArgs {
31    /// Set the prompt for the completion
32    pub fn prompt<T>(&mut self, prompt: T) -> &mut Self
33    where
34        T: ToString,
35    {
36        self.prompt = prompt.to_string();
37        self
38    }
39
40    /// Set the model to use for the completion
41    pub fn model(&mut self, model: CompletionModels) -> &mut Self {
42        self.model = model.to_string();
43        self
44    }
45
46    /// Set the suffix for the completion
47    pub fn suffix<T>(&mut self, suffix: T) -> &mut Self
48    where
49        T: ToString,
50    {
51        self.suffix = suffix.to_string();
52        self
53    }
54
55    /// Set the maximum number of tokens for the completion
56    pub fn max_tokens(&mut self, max_tokens: u32) -> &mut Self {
57        self.max_tokens = max_tokens;
58        self
59    }
60
61    /// Set the number of completions to return
62    pub fn n(&mut self, n: usize) -> &mut Self {
63        self.n = n;
64        self
65    }
66
67    /// Set the temperature for the completion
68    pub fn temperature(&mut self, temperature: f64) -> &mut Self {
69        self.temperature = temperature;
70        self
71    }
72}
73
/// Arguments for an edit request.
#[derive(Debug)]
pub struct EditArgs {
    /// Model identifier sent to the API (e.g. "text-davinci-edit-001").
    pub model: String,
    /// Text to be edited.
    pub input: String,
    /// Instruction describing how to edit the input.
    pub instruction: String,
    /// How many edits to request.
    pub n: usize,
    /// Sampling temperature.
    pub temperature: f64,
    /// Nucleus-sampling probability mass.
    pub top_p: f64,
}
83
84impl Default for EditArgs {
85    /// Create a new EditArgs struct with default values
86    fn default() -> Self {
87        Self {
88            model: "text-davinci-edit-001".to_string(),
89            input: "".to_string(),
90            instruction: "".to_string(),
91            n: 1,
92            temperature: 1.0,
93            top_p: 1.0,
94        }
95    }
96}
97
98impl EditArgs {
99    /// Set the model to use for the edit
100    pub fn model(&mut self, model: EditModels) -> &mut Self {
101        self.model = model.to_string();
102        self
103    }
104
105    /// Set the input for the edit
106    pub fn input<T>(&mut self, input: T) -> &mut Self
107    where
108        T: ToString,
109    {
110        self.input = input.to_string();
111        self
112    }
113
114    /// Set the instruction for the edit
115    pub fn instruction<T>(&mut self, instruction: T) -> &mut Self
116    where
117        T: ToString,
118    {
119        self.instruction = instruction.to_string();
120        self
121    }
122
123    /// Set the number of edits to return
124    pub fn n(&mut self, n: usize) -> &mut Self {
125        self.n = n;
126        self
127    }
128
129    /// Set the temperature for the edit
130    pub fn temperature(&mut self, temperature: f64) -> &mut Self {
131        self.temperature = temperature;
132        self
133    }
134
135    /// Set the top_p for the edit
136    pub fn top_p(&mut self, top_p: f64) -> &mut Self {
137        self.top_p = top_p;
138        self
139    }
140}
141
/// Output resolutions accepted by the image-generation endpoint.
pub enum ImageSize {
    /// 256x256 pixels.
    Small,
    /// 512x512 pixels.
    Medium,
    /// 1024x1024 pixels.
    Big,
}

impl std::fmt::Display for ImageSize {
    /// Render the size as the "WIDTHxHEIGHT" string the API expects.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            ImageSize::Small => "256x256",
            ImageSize::Medium => "512x512",
            ImageSize::Big => "1024x1024",
        })
    }
}
159
/// Format in which the API returns generated image data.
pub enum ImageResponseFormat {
    /// A URL pointing at the generated image.
    Url,
    /// The image bytes embedded in the JSON response, base64-encoded.
    B64Json,
}

impl std::fmt::Display for ImageResponseFormat {
    /// Render the variant as the exact `response_format` string the OpenAI
    /// API accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let str = match self {
            Self::Url => "url",
            // Fixed: the API parameter value is "b64_json" (with an
            // underscore); "b64json" is rejected by the endpoint.
            Self::B64Json => "b64_json",
        };

        write!(f, "{}", str)
    }
}
175
/// Arguments for an image-generation request.
#[derive(Debug)]
pub struct ImageArgs {
    /// Text description of the desired image.
    pub prompt: String,
    /// How many images to request.
    pub n: usize,
    /// Image size as a "WIDTHxHEIGHT" string (see `ImageSize`).
    pub size: String,
    /// Response format string (see `ImageResponseFormat`).
    pub response_format: String,
}
183
184impl Default for ImageArgs {
185    /// Create a new ImageArgs struct with default values
186    fn default() -> Self {
187        Self {
188            prompt: "".to_string(),
189            n: 1,
190            size: ImageSize::Medium.to_string(),
191            response_format: ImageResponseFormat::Url.to_string(),
192        }
193    }
194}
195
196impl ImageArgs {
197    /// Set the prompt for the image
198    pub fn prompt<T>(&mut self, prompt: T) -> &mut Self
199    where
200        T: ToString,
201    {
202        self.prompt = prompt.to_string();
203        self
204    }
205
206    /// Set the number of images to return
207    pub fn n(&mut self, n: usize) -> &mut Self {
208        self.n = n;
209        self
210    }
211
212    /// Set the size of the images to return
213    pub fn size(&mut self, size: ImageSize) -> &mut Self {
214        self.size = size.to_string();
215        self
216    }
217
218    /// Set the response format for the images to return
219    pub fn response_format(&mut self, response_format: ImageResponseFormat) -> &mut Self {
220        self.response_format = response_format.to_string();
221        self
222    }
223}
224
/// Arguments for a chat-completion request.
#[derive(Debug)]
pub struct ChatArgs {
    /// Model identifier sent to the API (e.g. "gpt-3.5-turbo").
    pub model: String,
    /// Conversation history sent with the request.
    pub messages: Vec<Message>,
    /// How many chat completions to request.
    // NOTE(review): the other arg structs use usize for n; i32 kept here to
    // preserve the existing public interface.
    pub n: i32,
    /// Sampling temperature.
    pub temperature: f64,
    /// Nucleus-sampling probability mass.
    pub top_p: f64,
    /// Upper bound on the number of tokens to generate.
    pub max_tokens: u32,
    /// Presence penalty applied to new tokens.
    pub presence_penalty: f64,
    /// Frequency penalty applied to new tokens.
    pub frequency_penalty: f64,
}
236
237impl Default for ChatArgs {
238    /// Default chat arguments
239    fn default() -> Self {
240        Self {
241            model: "gpt-3.5-turbo".to_string(),
242            messages: vec![],
243            n: 1,
244            temperature: 1.0,
245            top_p: 1.0,
246            max_tokens: 32,
247            presence_penalty: 0.0,
248            frequency_penalty: 0.0,
249        }
250    }
251}
252
253impl ChatArgs {
254    /// Set the model to use
255    pub fn model(&mut self, model: ChatModels) -> &mut Self {
256        self.model = model.to_string();
257        self
258    }
259
260    /// Set the messages to use
261    pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
262        self.messages = messages;
263        self
264    }
265
266    /// Set the number of messages to return
267    pub fn n(&mut self, n: i32) -> &mut Self {
268        self.n = n;
269        self
270    }
271
272    /// Set the temperature to use
273    pub fn temperature(&mut self, temperature: f64) -> &mut Self {
274        self.temperature = temperature;
275        self
276    }
277
278    /// Set the top_p to use
279    pub fn top_p(&mut self, top_p: f64) -> &mut Self {
280        self.top_p = top_p;
281        self
282    }
283
284    /// Set the max_tokens to use
285    pub fn max_tokens(&mut self, max_tokens: u32) -> &mut Self {
286        self.max_tokens = max_tokens;
287        self
288    }
289
290    /// Set the presence_penalty to use
291    pub fn presence_penalty(&mut self, presence_penalty: f64) -> &mut Self {
292        self.presence_penalty = presence_penalty;
293        self
294    }
295
296    /// Set the frequency_penalty to use
297    pub fn frequency_penalty(&mut self, frequency_penalty: f64) -> &mut Self {
298        self.frequency_penalty = frequency_penalty;
299        self
300    }
301}