1use std::{fmt, pin::Pin};
2
3use crate::{errors::GigaChatError, result::Result};
4use derive_builder::Builder;
5use futures::Stream;
6use log::debug;
7use serde::{Deserialize, Serialize};
8
9use crate::client::Client;
10
11#[derive(Clone, Serialize, Default, Debug, Builder)]
12#[builder(setter(into, strip_option), default)]
13pub struct ChatCompletionRequest {
14 pub model: String,
15 pub messages: Vec<ChatMessage>,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub temperature: Option<f32>,
18 #[serde(skip_serializing_if = "Option::is_none")]
19 pub top_p: Option<f32>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub n: Option<i64>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub stream: Option<bool>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub max_tokens: Option<i64>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub repetition_penalty: Option<f32>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub update_interval: Option<f32>,
30}
31
32impl From<ChatCompletionRequestBuilderError> for GigaChatError {
33 fn from(error: ChatCompletionRequestBuilderError) -> Self {
34 GigaChatError::SystemError(error.to_string())
35 }
36}
37
38#[derive(Builder, Debug, Clone, Serialize, Deserialize)]
39#[builder(setter(into, strip_option))]
40pub struct ChatMessage {
41 pub role: Option<Role>,
42 pub content: String,
43}
44
45impl From<ChatMessageBuilderError> for GigaChatError {
46 fn from(error: ChatMessageBuilderError) -> Self {
47 GigaChatError::SystemError(error.to_string())
48 }
49}
50
51#[derive(Clone, Serialize, Debug, Deserialize)]
52#[serde(rename_all = "lowercase")]
53pub enum Role {
54 System,
55 Assistant,
56 User,
57}
58
59impl fmt::Display for Role {
60 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61 write!(f, "{:?}", self)
62 }
63}
64
65#[derive(Clone, Deserialize)]
66pub struct ChatCompletionResponse {
67 pub choices: Vec<ChatChoice>,
68 pub created: i64,
69 pub model: String,
70 pub usage: Usage,
71 pub object: String,
72}
73
74#[derive(Clone, Deserialize)]
75pub struct ChatCompletionStreamResponse {
76 pub choices: Vec<ChatStreamChoice>,
77 pub created: i64,
78 pub model: String,
79 pub object: String,
80}
81
82#[derive(Clone, Deserialize)]
83pub struct ChatChoice {
84 pub message: ChatMessage,
85 pub index: u32,
86 pub finish_reason: String,
87}
88
89#[derive(Clone, Deserialize)]
90pub struct ChatStreamChoice {
91 pub delta: ChatMessage,
92 pub index: u32,
93 pub finish_reason: Option<String>,
94}
95
96#[derive(Clone, Deserialize)]
97pub struct Usage {
98 pub prompt_tokens: i32,
99 pub completion_tokens: i32,
100 pub total_tokens: i32,
101}
102
/// Chat-completion API (`/chat/completions`) bound to a [`Client`].
///
/// Its request methods take `self` by value, so each `Chat` value sends one request.
pub struct Chat {
    client: Client,
}
106
107impl Chat {
108 pub fn new(client: Client) -> Self {
109 Chat { client }
110 }
111
112 pub async fn completion(
113 self,
114 request: ChatCompletionRequest,
115 ) -> Result<ChatCompletionResponse> {
116 debug!("request:\n{}", serde_json::to_string_pretty(&request)?);
117
118 let response = self.client.post("/chat/completions", request).await?;
119
120 Ok(response)
121 }
122
123 pub async fn completion_stream(
124 self,
125 request: ChatCompletionRequest,
126 ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionStreamResponse>>>>> {
127 debug!("request:\n{}", serde_json::to_string_pretty(&request)?);
128
129 match request.stream {
130 Some(true) => (),
131 _ => {
132 return Err(GigaChatError::InvalidArgument(
133 "When stream is false, use Chat::completion".to_owned(),
134 ))
135 }
136 }
137
138 let response = self
139 .client
140 .post_stream("/chat/completions", request)
141 .await?;
142
143 Ok(response)
144 }
145}