async_gigachat/
client.rs

1use std::{
2    pin::Pin,
3    time::{SystemTime, UNIX_EPOCH},
4};
5
6use futures::{stream::StreamExt, Stream};
7use log::debug;
8use reqwest::Request;
9use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
10use serde::{de::DeserializeOwned, Serialize};
11use uuid::Uuid;
12
13use crate::{
14    api::{AccessToken, ErrorResponse},
15    config::GigaChatConfig,
16    errors::GigaChatError,
17    result::Result,
18};
19
20#[derive(Clone, Default)]
21pub struct Client {
22    http_client: reqwest::Client,
23    config: GigaChatConfig,
24    access_token: Option<AccessToken>,
25}
26
27impl Client {
28    pub fn new() -> Self {
29        Client {
30            http_client: reqwest::Client::new(),
31            ..Default::default()
32        }
33    }
34
35    pub fn with_config(config: GigaChatConfig) -> Self {
36        Client {
37            http_client: reqwest::Client::new(),
38            config,
39            ..Default::default()
40        }
41    }
42
43    async fn get_access_token(&mut self) -> Result<AccessToken> {
44        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis();
45
46        if let Some(access_token) = &self.access_token {
47            if now < access_token.expires_at as u128 {
48                return Ok(access_token.to_owned());
49            }
50        }
51
52        let new_access_token = self.retrive_access_token().await?;
53
54        self.access_token = Some(new_access_token.clone());
55
56        Ok(new_access_token)
57    }
58
59    fn _invalidate_access_token(mut self) -> Result<()> {
60        self.access_token = Default::default();
61
62        Ok(())
63    }
64
65    async fn retrive_access_token(&mut self) -> Result<AccessToken> {
66        let request_id = Uuid::new_v4();
67
68        let response = self
69            .http_client
70            .post(self.config.auth_url.clone())
71            .header("RqUID", request_id.to_string())
72            .header(
73                reqwest::header::CONTENT_TYPE,
74                "application/x-www-form-urlencoded",
75            )
76            .bearer_auth(self.config.auth_token.clone())
77            .body(format!("scope={}", self.config.scope))
78            .send()
79            .await?;
80
81        match response.error_for_status_ref() {
82            Ok(_) => (),
83            Err(error) => {
84                let error_response: ErrorResponse = response.json().await?;
85                log::error!("Error getting access token: {}", error);
86                return Err(GigaChatError::HttpError(format!(
87                    "Error getting access token: {}",
88                    error_response.message
89                )));
90            }
91        };
92
93        let access_token: AccessToken = response.json().await?;
94
95        Ok(access_token)
96    }
97
98    pub async fn get<O>(mut self, path: &str) -> Result<O>
99    where
100        O: DeserializeOwned,
101    {
102        let request = self
103            .http_client
104            .get(format!("{}{}", self.config.api_base_url, path))
105            .bearer_auth(self.get_access_token().await?.access_token)
106            .build()?;
107
108        self.execute(request).await
109    }
110
111    pub async fn post<I, O>(mut self, path: &str, body: I) -> Result<O>
112    where
113        I: Serialize,
114        O: DeserializeOwned,
115    {
116        let request = self
117            .http_client
118            .post(format!("{}{}", self.config.api_base_url, path))
119            .bearer_auth(self.get_access_token().await?.access_token)
120            .json(&body)
121            .build()?;
122
123        self.execute(request).await
124    }
125
126    pub async fn post_stream<I, O>(
127        mut self,
128        path: &str,
129        body: I,
130    ) -> Result<Pin<Box<dyn Stream<Item = Result<O>>>>>
131    where
132        I: Serialize,
133        O: DeserializeOwned + Send + 'static,
134    {
135        let request = self
136            .http_client
137            .post(format!("{}{}", self.config.api_base_url, path))
138            .bearer_auth(self.get_access_token().await?.access_token)
139            .json(&body)
140            .eventsource()
141            .unwrap();
142
143        Ok(self.stream(request).await)
144    }
145
146    pub async fn execute<R>(self, request: Request) -> Result<R>
147    where
148        R: DeserializeOwned,
149    {
150        let response = self.http_client.execute(request).await?;
151
152        match response.error_for_status_ref() {
153            Ok(_) => (),
154            Err(error) => {
155                // let error_response: ErrorResponse = response.json().await?;
156                log::error!("Error execute request: {}", error);
157                return Err(GigaChatError::HttpError(format!(
158                    "Error execute request: {}",
159                    error
160                )));
161            }
162        };
163
164        let response_text = response.text().await?;
165
166        debug!("response:\n{}", response_text);
167
168        let result: R = serde_json::from_str(&response_text)?;
169
170        Ok(result)
171    }
172
173    async fn stream<O>(
174        self,
175        mut event_source: EventSource,
176    ) -> Pin<Box<dyn Stream<Item = Result<O>>>>
177    where
178        O: DeserializeOwned + Send + 'static,
179    {
180        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
181
182        tokio::spawn(async move {
183            while let Some(event) = event_source.next().await {
184                match event {
185                    Ok(event) => match event {
186                        Event::Open => continue,
187                        Event::Message(message) => {
188                            let data = message.data;
189
190                            if data == "[DONE]" {
191                                break;
192                            }
193
194                            let result: Result<O> = serde_json::from_str(&data)
195                                .map_err(|error| GigaChatError::StreamError(error.to_string()));
196
197                            if let Err(error) = tx.send(result) {
198                                log::error!("Error sending event: {}", error);
199                                break;
200                            }
201                        }
202                    },
203                    Err(error) => {
204                        log::error!("Error getting event: {}", error);
205                        tx.send(Err(GigaChatError::StreamError(error.to_string())))
206                            .unwrap();
207                        break;
208                    }
209                }
210            }
211
212            event_source.close();
213        });
214
215        Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
216    }
217}