1use std::{
2 pin::Pin,
3 time::{SystemTime, UNIX_EPOCH},
4};
5
6use futures::{stream::StreamExt, Stream};
7use log::debug;
8use reqwest::Request;
9use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
10use serde::{de::DeserializeOwned, Serialize};
11use uuid::Uuid;
12
13use crate::{
14 api::{AccessToken, ErrorResponse},
15 config::GigaChatConfig,
16 errors::GigaChatError,
17 result::Result,
18};
19
20#[derive(Clone, Default)]
21pub struct Client {
22 http_client: reqwest::Client,
23 config: GigaChatConfig,
24 access_token: Option<AccessToken>,
25}
26
27impl Client {
28 pub fn new() -> Self {
29 Client {
30 http_client: reqwest::Client::new(),
31 ..Default::default()
32 }
33 }
34
35 pub fn with_config(config: GigaChatConfig) -> Self {
36 Client {
37 http_client: reqwest::Client::new(),
38 config,
39 ..Default::default()
40 }
41 }
42
43 async fn get_access_token(&mut self) -> Result<AccessToken> {
44 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis();
45
46 if let Some(access_token) = &self.access_token {
47 if now < access_token.expires_at as u128 {
48 return Ok(access_token.to_owned());
49 }
50 }
51
52 let new_access_token = self.retrive_access_token().await?;
53
54 self.access_token = Some(new_access_token.clone());
55
56 Ok(new_access_token)
57 }
58
59 fn _invalidate_access_token(mut self) -> Result<()> {
60 self.access_token = Default::default();
61
62 Ok(())
63 }
64
65 async fn retrive_access_token(&mut self) -> Result<AccessToken> {
66 let request_id = Uuid::new_v4();
67
68 let response = self
69 .http_client
70 .post(self.config.auth_url.clone())
71 .header("RqUID", request_id.to_string())
72 .header(
73 reqwest::header::CONTENT_TYPE,
74 "application/x-www-form-urlencoded",
75 )
76 .bearer_auth(self.config.auth_token.clone())
77 .body(format!("scope={}", self.config.scope))
78 .send()
79 .await?;
80
81 match response.error_for_status_ref() {
82 Ok(_) => (),
83 Err(error) => {
84 let error_response: ErrorResponse = response.json().await?;
85 log::error!("Error getting access token: {}", error);
86 return Err(GigaChatError::HttpError(format!(
87 "Error getting access token: {}",
88 error_response.message
89 )));
90 }
91 };
92
93 let access_token: AccessToken = response.json().await?;
94
95 Ok(access_token)
96 }
97
98 pub async fn get<O>(mut self, path: &str) -> Result<O>
99 where
100 O: DeserializeOwned,
101 {
102 let request = self
103 .http_client
104 .get(format!("{}{}", self.config.api_base_url, path))
105 .bearer_auth(self.get_access_token().await?.access_token)
106 .build()?;
107
108 self.execute(request).await
109 }
110
111 pub async fn post<I, O>(mut self, path: &str, body: I) -> Result<O>
112 where
113 I: Serialize,
114 O: DeserializeOwned,
115 {
116 let request = self
117 .http_client
118 .post(format!("{}{}", self.config.api_base_url, path))
119 .bearer_auth(self.get_access_token().await?.access_token)
120 .json(&body)
121 .build()?;
122
123 self.execute(request).await
124 }
125
126 pub async fn post_stream<I, O>(
127 mut self,
128 path: &str,
129 body: I,
130 ) -> Result<Pin<Box<dyn Stream<Item = Result<O>>>>>
131 where
132 I: Serialize,
133 O: DeserializeOwned + Send + 'static,
134 {
135 let request = self
136 .http_client
137 .post(format!("{}{}", self.config.api_base_url, path))
138 .bearer_auth(self.get_access_token().await?.access_token)
139 .json(&body)
140 .eventsource()
141 .unwrap();
142
143 Ok(self.stream(request).await)
144 }
145
146 pub async fn execute<R>(self, request: Request) -> Result<R>
147 where
148 R: DeserializeOwned,
149 {
150 let response = self.http_client.execute(request).await?;
151
152 match response.error_for_status_ref() {
153 Ok(_) => (),
154 Err(error) => {
155 log::error!("Error execute request: {}", error);
157 return Err(GigaChatError::HttpError(format!(
158 "Error execute request: {}",
159 error
160 )));
161 }
162 };
163
164 let response_text = response.text().await?;
165
166 debug!("response:\n{}", response_text);
167
168 let result: R = serde_json::from_str(&response_text)?;
169
170 Ok(result)
171 }
172
173 async fn stream<O>(
174 self,
175 mut event_source: EventSource,
176 ) -> Pin<Box<dyn Stream<Item = Result<O>>>>
177 where
178 O: DeserializeOwned + Send + 'static,
179 {
180 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
181
182 tokio::spawn(async move {
183 while let Some(event) = event_source.next().await {
184 match event {
185 Ok(event) => match event {
186 Event::Open => continue,
187 Event::Message(message) => {
188 let data = message.data;
189
190 if data == "[DONE]" {
191 break;
192 }
193
194 let result: Result<O> = serde_json::from_str(&data)
195 .map_err(|error| GigaChatError::StreamError(error.to_string()));
196
197 if let Err(error) = tx.send(result) {
198 log::error!("Error sending event: {}", error);
199 break;
200 }
201 }
202 },
203 Err(error) => {
204 log::error!("Error getting event: {}", error);
205 tx.send(Err(GigaChatError::StreamError(error.to_string())))
206 .unwrap();
207 break;
208 }
209 }
210 }
211
212 event_source.close();
213 });
214
215 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
216 }
217}