1use crate::{IntoRequest, ToSchema};
2use derive_builder::Builder;
3use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
4use serde::{Deserialize, Serialize};
5use strum::{Display, EnumIter, EnumMessage, EnumString, VariantNames};
6
7#[derive(Debug, Clone, Serialize, Builder)]
8pub struct ChatCompletionRequest {
9 #[builder(setter(into))]
11 messages: Vec<ChatCompletionMessage>,
12 #[builder(default)]
14 model: ChatCompleteModel,
15 #[builder(default, setter(strip_option))]
17 #[serde(skip_serializing_if = "Option::is_none")]
18 frequency_penalty: Option<f32>,
19
20 #[builder(default, setter(strip_option))]
26 #[serde(skip_serializing_if = "Option::is_none")]
27 max_tokens: Option<usize>,
28 #[builder(default, setter(strip_option))]
30 #[serde(skip_serializing_if = "Option::is_none")]
31 n: Option<usize>,
32 #[builder(default, setter(strip_option))]
34 #[serde(skip_serializing_if = "Option::is_none")]
35 presence_penalty: Option<f32>,
36 #[builder(default, setter(strip_option))]
38 #[serde(skip_serializing_if = "Option::is_none")]
39 response_format: Option<ChatResponseFormatObject>,
40 #[builder(default, setter(strip_option))]
42 #[serde(skip_serializing_if = "Option::is_none")]
43 seed: Option<usize>,
44 #[builder(default, setter(strip_option))]
47 #[serde(skip_serializing_if = "Option::is_none")]
48 stop: Option<String>,
49 #[builder(default, setter(strip_option))]
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub stream: Option<bool>,
53 #[builder(default, setter(strip_option))]
55 #[serde(skip_serializing_if = "Option::is_none")]
56 temperature: Option<f32>,
57 #[builder(default, setter(strip_option))]
59 #[serde(skip_serializing_if = "Option::is_none")]
60 top_p: Option<f32>,
61 #[builder(default, setter(into))]
63 #[serde(skip_serializing_if = "Vec::is_empty")]
64 tools: Vec<Tool>,
65 #[builder(default, setter(strip_option))]
67 #[serde(skip_serializing_if = "Option::is_none")]
68 tool_choice: Option<ToolChoice>,
69 #[builder(default, setter(strip_option, into))]
71 #[serde(skip_serializing_if = "Option::is_none")]
72 user: Option<String>,
73}
74
75#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, EnumString, Display, VariantNames)]
76#[serde(rename_all = "snake_case")]
77pub enum ToolChoice {
78 #[default]
79 None,
80 Auto,
81 Function {
83 name: String,
84 },
85}
86
87#[derive(Debug, Clone, Serialize)]
88pub struct Tool {
89 r#type: ToolType,
91 function: FunctionInfo,
93}
94
95#[derive(Debug, Clone, Serialize)]
96pub struct FunctionInfo {
97 description: String,
99 name: String,
101 parameters: serde_json::Value,
103}
104
105#[derive(Debug, Clone, Serialize)]
106pub struct ChatResponseFormatObject {
107 r#type: ChatResponseFormat,
108}
109
110#[derive(
111 Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, EnumString, Display, VariantNames,
112)]
113#[serde(rename_all = "snake_case")]
114pub enum ChatResponseFormat {
115 Text,
116 #[default]
117 Json,
118}
119
120#[derive(Debug, Clone, Serialize, Display, VariantNames, EnumMessage)]
121#[serde(rename_all = "snake_case", tag = "role")]
122pub enum ChatCompletionMessage {
123 System(SystemMessage),
125 User(UserMessage),
127 Assistant(AssistantMessage),
129 Tool(ToolMessage),
131}
132
133#[derive(
134 Debug,
135 Clone,
136 Default,
137 PartialEq,
138 Eq,
139 Serialize,
140 Deserialize,
141 EnumString,
142 EnumIter,
143 Display,
144 VariantNames,
145 EnumMessage,
146)]
147pub enum ChatCompleteModel {
148 #[default]
150 #[serde(rename = "gpt-3.5-turbo-1106")]
151 #[strum(serialize = "gpt-3.5-turbo")]
152 Gpt3Turbo,
153 #[serde(rename = "gpt-3.5-turbo-instruct")]
155 #[strum(serialize = "gpt-3.5-turbo-instruct")]
156 Gpt3TurboInstruct,
157 #[serde(rename = "gpt-4-1106-preview")]
159 #[strum(serialize = "gpt-4-turbo")]
160 Gpt4Turbo,
161 #[serde(rename = "gpt-4-1106-vision-preview")]
163 #[strum(serialize = "gpt-4-turbo-vision")]
164 Gpt4TurboVision,
165
166 #[serde(rename = "deepseek-chat")]
167 #[strum(serialize = "deepseek-chat")]
168 DeepSeekChat,
169
170 #[serde(rename = "deepseek-reasoner")]
171 #[strum(serialize = "deepseek-reasoner")]
172 DeepSeekReasoner,
173
174 #[serde(untagged)]
175 Other(String),
176}
177
178#[derive(Debug, Clone, Serialize)]
179pub struct SystemMessage {
180 content: String,
182 #[serde(skip_serializing_if = "Option::is_none")]
184 name: Option<String>,
185}
186
187#[derive(Debug, Clone, Serialize)]
188pub struct UserMessage {
189 content: String,
191 #[serde(skip_serializing_if = "Option::is_none")]
193 name: Option<String>,
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct AssistantMessage {
198 #[serde(default)]
200 pub content: Option<String>,
201 #[serde(skip_serializing_if = "Option::is_none", default)]
203 pub name: Option<String>,
204 #[serde(skip_serializing_if = "Vec::is_empty", default)]
206 pub tool_calls: Vec<ToolCall>,
207}
208
209#[derive(Debug, Clone, Serialize)]
210pub struct ToolMessage {
211 content: String,
213 tool_call_id: String,
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct ToolCall {
219 pub id: String,
221 pub r#type: ToolType,
223 pub function: FunctionCall,
225}
226
227#[derive(Debug, Clone, Serialize, Deserialize)]
228pub struct FunctionCall {
229 pub name: String,
231 pub arguments: String,
233}
234
235#[derive(
236 Debug,
237 Clone,
238 Copy,
239 PartialEq,
240 Eq,
241 Default,
242 Serialize,
243 Deserialize,
244 EnumString,
245 Display,
246 VariantNames,
247)]
248#[serde(rename_all = "snake_case")]
249pub enum ToolType {
250 #[default]
251 Function,
252}
253
254#[derive(Debug, Clone, Deserialize)]
255pub struct ChatCompletionResponse {
256 pub id: String,
258 pub choices: Vec<ChatCompletionChoice>,
260 pub created: usize,
262 pub model: ChatCompleteModel,
264 pub system_fingerprint: Option<String>,
266 pub object: String,
268 pub usage: ChatCompleteUsage,
270}
271
272#[derive(Debug, Clone, Deserialize)]
273pub struct ChatCompletionChoice {
274 pub finish_reason: FinishReason,
276 pub index: usize,
278 pub message: AssistantMessage,
280}
281
282#[derive(Debug, Clone, Deserialize)]
283pub struct ChatCompleteUsage {
284 pub completion_tokens: usize,
286 pub prompt_tokens: usize,
288 pub total_tokens: usize,
290}
291
292#[derive(Deserialize, Clone, Debug)]
293pub struct Delta {
294 pub content: Option<String>,
295 pub reasoning_content: Option<String>,
296 pub role: Option<String>,
297}
298
299#[derive(Deserialize, Clone, Debug)]
300pub struct ChatStreamChoice {
301 pub delta: Delta,
302 pub finish_reason: Option<String>,
303 pub index: usize,
304 pub logprobs: Option<String>,
305}
306
307#[derive(Deserialize, Clone, Debug)]
308pub struct ChatStreamResponse {
309 pub choices: Vec<ChatStreamChoice>,
310 pub created: usize,
311 pub id: String,
312 pub model: String,
313 pub object: String,
314 pub system_fingerprint: Option<String>,
315}
316
317#[derive(
318 Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize, EnumString, Display, VariantNames,
319)]
320#[serde(rename_all = "snake_case")]
321pub enum FinishReason {
322 #[default]
323 Stop,
324 Length,
325 ContentFilter,
326 ToolCalls,
327}
328
329impl IntoRequest for ChatCompletionRequest {
330 fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder {
331 let url = format!("{}/chat/completions", base_url);
332 client.post(url).json(&self)
333 }
334}
335
336impl ChatCompletionRequest {
337 pub fn new(model: ChatCompleteModel, messages: impl Into<Vec<ChatCompletionMessage>>) -> Self {
338 ChatCompletionRequestBuilder::default()
339 .model(model)
340 .messages(messages)
341 .build()
342 .unwrap()
343 }
344
345 pub fn new_with_tools(
346 model: ChatCompleteModel,
347 messages: impl Into<Vec<ChatCompletionMessage>>,
348 tools: impl Into<Vec<Tool>>,
349 ) -> Self {
350 ChatCompletionRequestBuilder::default()
351 .model(model)
352 .messages(messages)
353 .tools(tools)
354 .build()
355 .unwrap()
356 }
357}
358
359impl ChatCompletionMessage {
360 pub fn new_system(content: impl Into<String>, name: &str) -> ChatCompletionMessage {
361 ChatCompletionMessage::System(SystemMessage {
362 content: content.into(),
363 name: Self::get_name(name),
364 })
365 }
366
367 pub fn new_user(content: impl Into<String>, name: &str) -> ChatCompletionMessage {
368 ChatCompletionMessage::User(UserMessage {
369 content: content.into(),
370 name: Self::get_name(name),
371 })
372 }
373
374 fn get_name(name: &str) -> Option<String> {
375 if name.is_empty() {
376 None
377 } else {
378 Some(name.into())
379 }
380 }
381}
382
383impl Tool {
384 pub fn new_function<T: ToSchema>(
385 name: impl Into<String>,
386 description: impl Into<String>,
387 ) -> Self {
388 let parameters = T::to_schema();
389 Self {
390 r#type: ToolType::Function,
391 function: FunctionInfo {
392 name: name.into(),
393 description: description.into(),
394 parameters,
395 },
396 }
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{SDK, ToSchema};
    use anyhow::Result;
    use schemars::JsonSchema;

    // The fixture types below carry plain `//` comments on purpose: the
    // `JsonSchema` derive copies `///` text into the generated schema.

    // Arguments of the sample `get_weather_forecast` tool.
    #[allow(dead_code)]
    #[derive(Debug, Clone, Deserialize, JsonSchema)]
    struct GetWeatherArgs {
        pub city: String,
        pub unit: TemperatureUnit,
    }

    // Temperature unit accepted by the sample weather tool.
    #[allow(dead_code)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, JsonSchema)]
    enum TemperatureUnit {
        #[default]
        Celsius,
        Fahrenheit,
    }

    // Canned result of the sample weather tool.
    #[derive(Debug, Clone)]
    struct GetWeatherResponse {
        temperature: f32,
        unit: TemperatureUnit,
    }

    // Arguments of the sample `explain_mood` tool.
    #[allow(dead_code)]
    #[derive(Debug, Deserialize, JsonSchema)]
    struct ExplainMoodArgs {
        pub name: String,
    }

    // Returns a fixed forecast that depends only on the requested unit.
    fn get_weather_forecast(args: GetWeatherArgs) -> GetWeatherResponse {
        match args.unit {
            TemperatureUnit::Celsius => GetWeatherResponse {
                temperature: 22.2,
                unit: TemperatureUnit::Celsius,
            },
            TemperatureUnit::Fahrenheit => GetWeatherResponse {
                temperature: 72.0,
                unit: TemperatureUnit::Fahrenheit,
            },
        }
    }

    // Pins the wire shape expected for a forced function choice. Ignored by
    // default; run it with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn chat_completion_request_tool_choice_function_serialize_should_work() {
        let req = ChatCompletionRequestBuilder::default()
            .tool_choice(ToolChoice::Function {
                name: "my_function".to_string(),
            })
            .messages(vec![])
            .build()
            .unwrap();
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "tool_choice": {
                    "type": "function",
                    "function": {
                        "name": "my_function"
                    }
                },
                // `model` is always serialized; the builder default is
                // `Gpt3Turbo`, whose wire name is "gpt-3.5-turbo-1106".
                "model": "gpt-3.5-turbo-1106",
                "messages": []
            })
        );
    }

    #[test]
    fn chat_completion_request_serialize_should_work() {
        let mut req = get_simple_completion_request();
        req.tool_choice = Some(ToolChoice::Auto);
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "tool_choice": "auto",
                "model": "gpt-3.5-turbo-1106",
                "messages": [{
                    "role": "system",
                    "content": "I can answer any question you ask me."
                }, {
                    "role": "user",
                    "content": "What is human life expectancy in the world?",
                    "name": "user1"
                }]
            })
        );
    }

    #[test]
    fn chat_completion_request_with_tools_serialize_should_work() {
        let req = get_tool_completion_request();
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "model": "gpt-3.5-turbo-1106",
                "messages": [{
                    "role": "system",
                    "content": "I can choose the right function for you."
                }, {
                    "role": "user",
                    "content": "What is the weather like in Boston?",
                    "name": "user1"
                }],
                "tools": [
                    {
                        "type": "function",
                        "function": {
                            "description": "Get the weather forecast for a city.",
                            "name": "get_weather_forecast",
                            "parameters": GetWeatherArgs::to_schema()
                        }
                    },
                    {
                        "type": "function",
                        "function": {
                            "description": "Explain the meaning of the given mood.",
                            "name": "explain_mood",
                            "parameters": ExplainMoodArgs::to_schema()
                        }
                    }
                ]
            })
        );
    }

    // Goes through the crate's `SDK`, hence ignored by default.
    #[tokio::test]
    #[ignore]
    async fn simple_chat_completion_should_work() -> Result<()> {
        let req = get_simple_completion_request();
        let res = SDK.chat_completion(req).await?;
        assert_eq!(res.model, ChatCompleteModel::Gpt3Turbo);
        assert_eq!(res.object, "chat.completion");
        assert_eq!(res.choices.len(), 1);
        let choice = &res.choices[0];
        assert_eq!(choice.finish_reason, FinishReason::Stop);
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.tool_calls.len(), 0);
        Ok(())
    }

    // Goes through the crate's `SDK`, hence ignored by default.
    #[tokio::test]
    #[ignore]
    async fn chat_completion_with_tools_should_work() -> Result<()> {
        let req = get_tool_completion_request();
        let res = SDK.chat_completion(req).await?;
        assert_eq!(res.model, ChatCompleteModel::Gpt3Turbo);
        assert_eq!(res.object, "chat.completion");
        assert_eq!(res.choices.len(), 1);
        let choice = &res.choices[0];
        assert_eq!(choice.finish_reason, FinishReason::ToolCalls);
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.content, None);
        assert_eq!(choice.message.tool_calls.len(), 1);
        let tool_call = &choice.message.tool_calls[0];
        assert_eq!(tool_call.function.name, "get_weather_forecast");
        let ret = get_weather_forecast(serde_json::from_str(&tool_call.function.arguments)?);
        assert_eq!(ret.unit, TemperatureUnit::Celsius);
        assert_eq!(ret.temperature, 22.2);
        Ok(())
    }

    // Two-message request (system + named user) without tools.
    fn get_simple_completion_request() -> ChatCompletionRequest {
        let messages = vec![
            ChatCompletionMessage::new_system("I can answer any question you ask me.", ""),
            ChatCompletionMessage::new_user("What is human life expectancy in the world?", "user1"),
        ];
        ChatCompletionRequest::new(ChatCompleteModel::Gpt3Turbo, messages)
    }

    // Two-message request offering the two sample function tools.
    fn get_tool_completion_request() -> ChatCompletionRequest {
        let messages = vec![
            ChatCompletionMessage::new_system("I can choose the right function for you.", ""),
            ChatCompletionMessage::new_user("What is the weather like in Boston?", "user1"),
        ];
        let tools = vec![
            Tool::new_function::<GetWeatherArgs>(
                "get_weather_forecast",
                "Get the weather forecast for a city.",
            ),
            Tool::new_function::<ExplainMoodArgs>(
                "explain_mood",
                "Explain the meaning of the given mood.",
            ),
        ];
        ChatCompletionRequest::new_with_tools(ChatCompleteModel::Gpt3Turbo, messages, tools)
    }
}