alith_interface/llms/api/
client.rs

1use super::error::map_serialization_error;
2use super::{
3    config::ApiConfigTrait,
4    error::{ClientError, WrappedError, map_deserialization_error},
5};
6use bytes::Bytes;
7use serde::{Serialize, de::DeserializeOwned};
8
9#[derive(Debug, Clone)]
10pub struct ApiClient<C: ApiConfigTrait> {
11    http_client: reqwest::Client,
12    pub config: C,
13    pub backoff: backoff::ExponentialBackoff,
14}
15
16impl<C: ApiConfigTrait> ApiClient<C> {
17    pub fn new(config: C) -> Self {
18        Self {
19            http_client: reqwest::Client::new(),
20            config,
21            backoff: backoff::ExponentialBackoffBuilder::new()
22                .with_max_elapsed_time(Some(std::time::Duration::from_secs(60)))
23                .build(),
24        }
25    }
26
27    /// Make a POST request to {path} and deserialize the response body
28    pub async fn post<I, O>(&self, path: &str, request: I) -> Result<O, ClientError>
29    where
30        I: Serialize + std::fmt::Debug,
31        O: DeserializeOwned,
32    {
33        // Log the serialized request
34        let request_maker = || async {
35            let serialized_request =
36                serde_json::to_string(&request).map_err(map_serialization_error)?;
37            crate::trace!("Serialized post request: {}", serialized_request);
38            let request_builder = self
39                .http_client
40                .post(self.config.url(path))
41                // .query(&self.config.query())
42                .headers(self.config.headers())
43                .header(reqwest::header::CONTENT_TYPE, "application/json")
44                .body(serialized_request);
45            // crate::trace!("Serialized post request: {:?}", request_builder); // This will log API keys!
46            Ok(request_builder.build()?)
47        };
48        self.execute(request_maker).await
49    }
50
51    /// Make a GET request to {path} and deserialize the response body
52    pub async fn get<O>(&self, path: &str) -> Result<O, ClientError>
53    where
54        O: DeserializeOwned,
55    {
56        let request_maker = || async {
57            crate::trace!("Get request: {}", path);
58            let request_builder = self
59                .http_client
60                .get(self.config.url(path))
61                .headers(self.config.headers());
62
63            // crate::trace!("Serialized post request: {:?}", request_builder); // This will log API keys!
64            Ok(request_builder.build()?)
65        };
66        self.execute(request_maker).await
67    }
68
69    /// Execute a HTTP request and retry on rate limit
70    ///
71    /// request_maker serves one purpose: to be able to create request again
72    /// to retry API call after getting rate limited. request_maker is async because
73    /// reqwest::multipart::Form is created by async calls to read files for uploads.
74    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, ClientError>
75    where
76        M: Fn() -> Fut,
77        Fut: core::future::Future<Output = Result<reqwest::Request, ClientError>>,
78    {
79        let client = self.http_client.clone();
80
81        backoff::future::retry(self.backoff.clone(), || async {
82            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
83            let response = client
84                .execute(request)
85                .await
86                .map_err(ClientError::Reqwest)
87                .map_err(backoff::Error::Permanent)?;
88
89            let status = response.status();
90            let bytes = response
91                .bytes()
92                .await
93                .map_err(ClientError::Reqwest)
94                .map_err(backoff::Error::Permanent)?;
95
96            // Deserialize response body from either error object or actual response object
97            if !status.is_success() {
98                let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
99                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
100                    .map_err(backoff::Error::Permanent)?;
101
102                if status.as_u16() == 429
103                    // API returns 429 also when:
104                    // "You exceeded your current quota, please check your plan and billing details."
105                    && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
106                {
107                    // Rate limited retry...
108                    tracing::warn!("Rate limited: {}", wrapped_error.error.message);
109                    return Err(backoff::Error::Transient {
110                        err: ClientError::ApiError(wrapped_error.error),
111                        retry_after: None,
112                    });
113                } else if status.as_u16() == 503 {
114                    return Err(backoff::Error::Transient {
115                        err: ClientError::ServiceUnavailable {
116                            message: wrapped_error.error.message,
117                        },
118                        retry_after: None,
119                    });
120                } else {
121                    return Err(backoff::Error::Permanent(ClientError::ApiError(
122                        wrapped_error.error,
123                    )));
124                }
125            }
126
127            Ok(bytes)
128        })
129        .await
130    }
131
132    /// Execute a HTTP request and retry on rate limit
133    ///
134    /// request_maker serves one purpose: to be able to create request again
135    /// to retry API call after getting rate limited. request_maker is async because
136    /// reqwest::multipart::Form is created by async calls to read files for uploads.
137    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, ClientError>
138    where
139        O: DeserializeOwned,
140        M: Fn() -> Fut,
141        Fut: core::future::Future<Output = Result<reqwest::Request, ClientError>>,
142    {
143        let bytes = self.execute_raw(request_maker).await?;
144
145        // Deserialize once into a generic Value
146        let value: serde_json::Value =
147            serde_json::from_slice(&bytes).map_err(|e| map_deserialization_error(e, &bytes))?;
148
149        // Log the pretty-printed JSON
150        let pretty_json = serde_json::to_string_pretty(&value).map_err(map_serialization_error)?;
151        crate::trace!("Serialized response: {}", pretty_json);
152
153        // Convert the Value into the target type
154        let response: O =
155            serde_json::from_value(value).map_err(|e| map_deserialization_error(e, &bytes))?;
156
157        Ok(response)
158    }
159}