async_gigachat/
chat.rs

1use std::{fmt, pin::Pin};
2
3use crate::{errors::GigaChatError, result::Result};
4use derive_builder::Builder;
5use futures::Stream;
6use log::debug;
7use serde::{Deserialize, Serialize};
8
9use crate::client::Client;
10
11#[derive(Clone, Serialize, Default, Debug, Builder)]
12#[builder(setter(into, strip_option), default)]
13pub struct ChatCompletionRequest {
14    pub model: String,
15    pub messages: Vec<ChatMessage>,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub temperature: Option<f32>,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub top_p: Option<f32>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub n: Option<i64>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub stream: Option<bool>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub max_tokens: Option<i64>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub repetition_penalty: Option<f32>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub update_interval: Option<f32>,
30}
31
32impl From<ChatCompletionRequestBuilderError> for GigaChatError {
33    fn from(error: ChatCompletionRequestBuilderError) -> Self {
34        GigaChatError::SystemError(error.to_string())
35    }
36}
37
38#[derive(Builder, Debug, Clone, Serialize, Deserialize)]
39#[builder(setter(into, strip_option))]
40pub struct ChatMessage {
41    pub role: Option<Role>,
42    pub content: String,
43}
44
45impl From<ChatMessageBuilderError> for GigaChatError {
46    fn from(error: ChatMessageBuilderError) -> Self {
47        GigaChatError::SystemError(error.to_string())
48    }
49}
50
51#[derive(Clone, Serialize, Debug, Deserialize)]
52#[serde(rename_all = "lowercase")]
53pub enum Role {
54    System,
55    Assistant,
56    User,
57}
58
59impl fmt::Display for Role {
60    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61        write!(f, "{:?}", self)
62    }
63}
64
65#[derive(Clone, Deserialize)]
66pub struct ChatCompletionResponse {
67    pub choices: Vec<ChatChoice>,
68    pub created: i64,
69    pub model: String,
70    pub usage: Usage,
71    pub object: String,
72}
73
74#[derive(Clone, Deserialize)]
75pub struct ChatCompletionStreamResponse {
76    pub choices: Vec<ChatStreamChoice>,
77    pub created: i64,
78    pub model: String,
79    pub object: String,
80}
81
82#[derive(Clone, Deserialize)]
83pub struct ChatChoice {
84    pub message: ChatMessage,
85    pub index: u32,
86    pub finish_reason: String,
87}
88
89#[derive(Clone, Deserialize)]
90pub struct ChatStreamChoice {
91    pub delta: ChatMessage,
92    pub index: u32,
93    pub finish_reason: Option<String>,
94}
95
96#[derive(Clone, Deserialize)]
97pub struct Usage {
98    pub prompt_tokens: i32,
99    pub completion_tokens: i32,
100    pub total_tokens: i32,
101}
102
103pub struct Chat {
104    client: Client,
105}
106
107impl Chat {
108    pub fn new(client: Client) -> Self {
109        Chat { client }
110    }
111
112    pub async fn completion(
113        self,
114        request: ChatCompletionRequest,
115    ) -> Result<ChatCompletionResponse> {
116        debug!("request:\n{}", serde_json::to_string_pretty(&request)?);
117
118        let response = self.client.post("/chat/completions", request).await?;
119
120        Ok(response)
121    }
122
123    pub async fn completion_stream(
124        self,
125        request: ChatCompletionRequest,
126    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionStreamResponse>>>>> {
127        debug!("request:\n{}", serde_json::to_string_pretty(&request)?);
128
129        match request.stream {
130            Some(true) => (),
131            _ => {
132                return Err(GigaChatError::InvalidArgument(
133                    "When stream is false, use Chat::completion".to_owned(),
134                ))
135            }
136        }
137
138        let response = self
139            .client
140            .post_stream("/chat/completions", request)
141            .await?;
142
143        Ok(response)
144    }
145}