1use crate::{IntoRequest, ToSchema};
2use derive_builder::Builder;
3use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
4use serde::{Deserialize, Serialize};
5use strum::{Display, EnumIter, EnumMessage, EnumString, EnumVariantNames};
6
7#[derive(Debug, Clone, Serialize, Builder)]
8pub struct ChatCompletionRequest {
9 #[builder(setter(into))]
11 messages: Vec<ChatCompletionMessage>,
12 #[builder(default)]
14 model: ChatCompleteModel,
15 #[builder(default, setter(strip_option))]
17 #[serde(skip_serializing_if = "Option::is_none")]
18 frequency_penalty: Option<f32>,
19
20 #[builder(default, setter(strip_option))]
26 #[serde(skip_serializing_if = "Option::is_none")]
27 max_tokens: Option<usize>,
28 #[builder(default, setter(strip_option))]
30 #[serde(skip_serializing_if = "Option::is_none")]
31 n: Option<usize>,
32 #[builder(default, setter(strip_option))]
34 #[serde(skip_serializing_if = "Option::is_none")]
35 presence_penalty: Option<f32>,
36 #[builder(default, setter(strip_option))]
38 #[serde(skip_serializing_if = "Option::is_none")]
39 response_format: Option<ChatResponseFormatObject>,
40 #[builder(default, setter(strip_option))]
42 #[serde(skip_serializing_if = "Option::is_none")]
43 seed: Option<usize>,
44 #[builder(default, setter(strip_option))]
47 #[serde(skip_serializing_if = "Option::is_none")]
48 stop: Option<String>,
49 #[builder(default, setter(strip_option))]
51 #[serde(skip_serializing_if = "Option::is_none")]
52 stream: Option<bool>,
53 #[builder(default, setter(strip_option))]
55 #[serde(skip_serializing_if = "Option::is_none")]
56 temperature: Option<f32>,
57 #[builder(default, setter(strip_option))]
59 #[serde(skip_serializing_if = "Option::is_none")]
60 top_p: Option<f32>,
61 #[builder(default, setter(into))]
63 #[serde(skip_serializing_if = "Vec::is_empty")]
64 tools: Vec<Tool>,
65 #[builder(default, setter(strip_option))]
67 #[serde(skip_serializing_if = "Option::is_none")]
68 tool_choice: Option<ToolChoice>,
69 #[builder(default, setter(strip_option, into))]
71 #[serde(skip_serializing_if = "Option::is_none")]
72 user: Option<String>,
73}
74
75#[derive(
76 Debug, Clone, Default, PartialEq, Eq, Serialize, EnumString, Display, EnumVariantNames,
77)]
78#[serde(rename_all = "snake_case")]
79pub enum ToolChoice {
80 #[default]
81 None,
82 Auto,
83 Function {
85 name: String,
86 },
87}
88
89#[derive(Debug, Clone, Serialize)]
90pub struct Tool {
91 r#type: ToolType,
93 function: FunctionInfo,
95}
96
97#[derive(Debug, Clone, Serialize)]
98pub struct FunctionInfo {
99 description: String,
101 name: String,
103 parameters: serde_json::Value,
105}
106
107#[derive(Debug, Clone, Serialize)]
108pub struct ChatResponseFormatObject {
109 r#type: ChatResponseFormat,
110}
111
112#[derive(
113 Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, EnumString, Display, EnumVariantNames,
114)]
115#[serde(rename_all = "snake_case")]
116pub enum ChatResponseFormat {
117 Text,
118 #[default]
119 Json,
120}
121
122#[derive(Debug, Clone, Serialize, Display, EnumVariantNames, EnumMessage)]
123#[serde(rename_all = "snake_case", tag = "role")]
124pub enum ChatCompletionMessage {
125 System(SystemMessage),
127 User(UserMessage),
129 Assistant(AssistantMessage),
131 Tool(ToolMessage),
133}
134
135#[derive(
136 Debug,
137 Clone,
138 Copy,
139 Default,
140 PartialEq,
141 Eq,
142 Serialize,
143 Deserialize,
144 EnumString,
145 EnumIter,
146 Display,
147 EnumVariantNames,
148 EnumMessage,
149)]
150
151pub enum ChatCompleteModel {
152 #[default]
154 #[serde(rename = "gpt-3.5-turbo-1106")]
155 #[strum(serialize = "gpt-3.5-turbo")]
156 Gpt3Turbo,
157 #[serde(rename = "gpt-3.5-turbo-instruct")]
159 #[strum(serialize = "gpt-3.5-turbo-instruct")]
160 Gpt3TurboInstruct,
161 #[serde(rename = "gpt-4-1106-preview")]
163 #[strum(serialize = "gpt-4-turbo")]
164 Gpt4Turbo,
165 #[serde(rename = "gpt-4-1106-vision-preview")]
167 #[strum(serialize = "gpt-4-turbo-vision")]
168 Gpt4TurboVision,
169}
170
171#[derive(Debug, Clone, Serialize)]
172pub struct SystemMessage {
173 content: String,
175 #[serde(skip_serializing_if = "Option::is_none")]
177 name: Option<String>,
178}
179
180#[derive(Debug, Clone, Serialize)]
181pub struct UserMessage {
182 content: String,
184 #[serde(skip_serializing_if = "Option::is_none")]
186 name: Option<String>,
187}
188
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct AssistantMessage {
191 #[serde(default)]
193 pub content: Option<String>,
194 #[serde(skip_serializing_if = "Option::is_none", default)]
196 pub name: Option<String>,
197 #[serde(skip_serializing_if = "Vec::is_empty", default)]
199 pub tool_calls: Vec<ToolCall>,
200}
201
202#[derive(Debug, Clone, Serialize)]
203pub struct ToolMessage {
204 content: String,
206 tool_call_id: String,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct ToolCall {
212 pub id: String,
214 pub r#type: ToolType,
216 pub function: FunctionCall,
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct FunctionCall {
222 pub name: String,
224 pub arguments: String,
226}
227
228#[derive(
229 Debug,
230 Clone,
231 Copy,
232 PartialEq,
233 Eq,
234 Default,
235 Serialize,
236 Deserialize,
237 EnumString,
238 Display,
239 EnumVariantNames,
240)]
241#[serde(rename_all = "snake_case")]
242pub enum ToolType {
243 #[default]
244 Function,
245}
246
247#[derive(Debug, Clone, Deserialize)]
248pub struct ChatCompletionResponse {
249 pub id: String,
251 pub choices: Vec<ChatCompletionChoice>,
253 pub created: usize,
255 pub model: ChatCompleteModel,
257 pub system_fingerprint: String,
259 pub object: String,
261 pub usage: ChatCompleteUsage,
263}
264
265#[derive(Debug, Clone, Deserialize)]
266pub struct ChatCompletionChoice {
267 pub finish_reason: FinishReason,
269 pub index: usize,
271 pub message: AssistantMessage,
273}
274
275#[derive(Debug, Clone, Deserialize)]
276pub struct ChatCompleteUsage {
277 pub completion_tokens: usize,
279 pub prompt_tokens: usize,
281 pub total_tokens: usize,
283}
284
285#[derive(
286 Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize, EnumString, Display, EnumVariantNames,
287)]
288#[serde(rename_all = "snake_case")]
289pub enum FinishReason {
290 #[default]
291 Stop,
292 Length,
293 ContentFilter,
294 ToolCalls,
295}
296
297impl IntoRequest for ChatCompletionRequest {
298 fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder {
299 let url = format!("{}/chat/completions", base_url);
300 client.post(url).json(&self)
301 }
302}
303
304impl ChatCompletionRequest {
305 pub fn new(model: ChatCompleteModel, messages: impl Into<Vec<ChatCompletionMessage>>) -> Self {
306 ChatCompletionRequestBuilder::default()
307 .model(model)
308 .messages(messages)
309 .build()
310 .unwrap()
311 }
312
313 pub fn new_with_tools(
314 model: ChatCompleteModel,
315 messages: impl Into<Vec<ChatCompletionMessage>>,
316 tools: impl Into<Vec<Tool>>,
317 ) -> Self {
318 ChatCompletionRequestBuilder::default()
319 .model(model)
320 .messages(messages)
321 .tools(tools)
322 .build()
323 .unwrap()
324 }
325}
326
327impl ChatCompletionMessage {
328 pub fn new_system(content: impl Into<String>, name: &str) -> ChatCompletionMessage {
329 ChatCompletionMessage::System(SystemMessage {
330 content: content.into(),
331 name: Self::get_name(name),
332 })
333 }
334
335 pub fn new_user(content: impl Into<String>, name: &str) -> ChatCompletionMessage {
336 ChatCompletionMessage::User(UserMessage {
337 content: content.into(),
338 name: Self::get_name(name),
339 })
340 }
341
342 fn get_name(name: &str) -> Option<String> {
343 if name.is_empty() {
344 None
345 } else {
346 Some(name.into())
347 }
348 }
349}
350
351impl Tool {
352 pub fn new_function<T: ToSchema>(
353 name: impl Into<String>,
354 description: impl Into<String>,
355 ) -> Self {
356 let parameters = T::to_schema();
357 Self {
358 r#type: ToolType::Function,
359 function: FunctionInfo {
360 name: name.into(),
361 description: description.into(),
362 parameters,
363 },
364 }
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToSchema, SDK};
    use anyhow::Result;
    use schemars::JsonSchema;

    // Test-only schema type: some fields are never read directly by the tests; it exists to
    // generate a JSON schema and to be deserialized from tool-call arguments.
    // (No `///` docs here: `JsonSchema` would copy them into the generated schema.)
    #[allow(dead_code)]
    #[derive(Debug, Clone, Deserialize, JsonSchema)]
    struct GetWeatherArgs {
        pub city: String,
        pub unit: TemperatureUnit,
    }

    // Test-only schema type; see note on `GetWeatherArgs`.
    #[allow(dead_code)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, JsonSchema)]
    enum TemperatureUnit {
        #[default]
        Celsius,
        Fahrenheit,
    }

    #[derive(Debug, Clone)]
    struct GetWeatherResponse {
        temperature: f32,
        unit: TemperatureUnit,
    }

    // Test-only schema type; `name` is never read directly.
    #[allow(dead_code)]
    #[derive(Debug, Deserialize, JsonSchema)]
    struct ExplainMoodArgs {
        pub name: String,
    }

    // Fake tool implementation returning a fixed temperature per unit.
    fn get_weather_forecast(args: GetWeatherArgs) -> GetWeatherResponse {
        match args.unit {
            TemperatureUnit::Celsius => GetWeatherResponse {
                temperature: 22.2,
                unit: TemperatureUnit::Celsius,
            },
            TemperatureUnit::Fahrenheit => GetWeatherResponse {
                temperature: 72.0,
                unit: TemperatureUnit::Fahrenheit,
            },
        }
    }

    #[test]
    #[ignore = "needs ToolChoice::Function to serialize as {\"type\": \"function\", \"function\": {...}}"]
    fn chat_completion_request_tool_choice_function_serialize_should_work() {
        let req = ChatCompletionRequestBuilder::default()
            .tool_choice(ToolChoice::Function {
                name: "my_function".to_string(),
            })
            .messages(vec![])
            .build()
            .unwrap();
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "tool_choice": {
                    "type": "function",
                    "function": {
                        "name": "my_function"
                    }
                },
                // `model` has no `skip_serializing_if`, so the builder default is always emitted.
                "model": "gpt-3.5-turbo-1106",
                "messages": []
            })
        );
    }

    #[test]
    fn chat_completion_request_serialize_should_work() {
        let mut req = get_simple_completion_request();
        req.tool_choice = Some(ToolChoice::Auto);
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "tool_choice": "auto",
                "model": "gpt-3.5-turbo-1106",
                "messages": [{
                    "role": "system",
                    "content": "I can answer any question you ask me."
                }, {
                    "role": "user",
                    "content": "What is human life expectancy in the world?",
                    "name": "user1"
                }]
            })
        );
    }

    #[test]
    fn chat_completion_request_with_tools_serialize_should_work() {
        let req = get_tool_completion_request();
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "model": "gpt-3.5-turbo-1106",
                "messages": [{
                    "role": "system",
                    "content": "I can choose the right function for you."
                }, {
                    "role": "user",
                    "content": "What is the weather like in Boston?",
                    "name": "user1"
                }],
                "tools": [
                    {
                        "type": "function",
                        "function": {
                            "description": "Get the weather forecast for a city.",
                            "name": "get_weather_forecast",
                            "parameters": GetWeatherArgs::to_schema()
                        }
                    },
                    {
                        "type": "function",
                        "function": {
                            "description": "Explain the meaning of the given mood.",
                            "name": "explain_mood",
                            "parameters": ExplainMoodArgs::to_schema()
                        }
                    }
                ]
            })
        );
    }

    // Calls the live API through `SDK`.
    #[tokio::test]
    async fn simple_chat_completion_should_work() -> Result<()> {
        let req = get_simple_completion_request();
        let res = SDK.chat_completion(req).await?;
        assert_eq!(res.model, ChatCompleteModel::Gpt3Turbo);
        assert_eq!(res.object, "chat.completion");
        assert_eq!(res.choices.len(), 1);
        let choice = &res.choices[0];
        assert_eq!(choice.finish_reason, FinishReason::Stop);
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.tool_calls.len(), 0);
        Ok(())
    }

    // Calls the live API through `SDK`, then runs the returned tool call locally.
    #[tokio::test]
    async fn chat_completion_with_tools_should_work() -> Result<()> {
        let req = get_tool_completion_request();
        let res = SDK.chat_completion(req).await?;
        assert_eq!(res.model, ChatCompleteModel::Gpt3Turbo);
        assert_eq!(res.object, "chat.completion");
        assert_eq!(res.choices.len(), 1);
        let choice = &res.choices[0];
        assert_eq!(choice.finish_reason, FinishReason::ToolCalls);
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.content, None);
        assert_eq!(choice.message.tool_calls.len(), 1);
        let tool_call = &choice.message.tool_calls[0];
        assert_eq!(tool_call.function.name, "get_weather_forecast");
        let ret = get_weather_forecast(serde_json::from_str(&tool_call.function.arguments)?);
        assert_eq!(ret.unit, TemperatureUnit::Celsius);
        assert_eq!(ret.temperature, 22.2);
        Ok(())
    }

    fn get_simple_completion_request() -> ChatCompletionRequest {
        let messages = vec![
            ChatCompletionMessage::new_system("I can answer any question you ask me.", ""),
            ChatCompletionMessage::new_user("What is human life expectancy in the world?", "user1"),
        ];
        ChatCompletionRequest::new(ChatCompleteModel::Gpt3Turbo, messages)
    }

    fn get_tool_completion_request() -> ChatCompletionRequest {
        let messages = vec![
            ChatCompletionMessage::new_system("I can choose the right function for you.", ""),
            ChatCompletionMessage::new_user("What is the weather like in Boston?", "user1"),
        ];
        let tools = vec![
            Tool::new_function::<GetWeatherArgs>(
                "get_weather_forecast",
                "Get the weather forecast for a city.",
            ),
            Tool::new_function::<ExplainMoodArgs>(
                "explain_mood",
                "Explain the meaning of the given mood.",
            ),
        ];
        ChatCompletionRequest::new_with_tools(ChatCompleteModel::Gpt3Turbo, messages, tools)
    }
}