alith_interface/llms/api/
client.rs1use super::error::map_serialization_error;
2use super::{
3 config::ApiConfigTrait,
4 error::{ClientError, WrappedError, map_deserialization_error},
5};
6use bytes::Bytes;
7use serde::{Serialize, de::DeserializeOwned};
8
9#[derive(Debug, Clone)]
10pub struct ApiClient<C: ApiConfigTrait> {
11 http_client: reqwest::Client,
12 pub config: C,
13 pub backoff: backoff::ExponentialBackoff,
14}
15
16impl<C: ApiConfigTrait> ApiClient<C> {
17 pub fn new(config: C) -> Self {
18 Self {
19 http_client: reqwest::Client::new(),
20 config,
21 backoff: backoff::ExponentialBackoffBuilder::new()
22 .with_max_elapsed_time(Some(std::time::Duration::from_secs(60)))
23 .build(),
24 }
25 }
26
27 pub async fn post<I, O>(&self, path: &str, request: I) -> Result<O, ClientError>
29 where
30 I: Serialize + std::fmt::Debug,
31 O: DeserializeOwned,
32 {
33 let request_maker = || async {
35 let serialized_request =
36 serde_json::to_string(&request).map_err(map_serialization_error)?;
37 crate::trace!("Serialized post request: {}", serialized_request);
38 let request_builder = self
39 .http_client
40 .post(self.config.url(path))
41 .headers(self.config.headers())
43 .header(reqwest::header::CONTENT_TYPE, "application/json")
44 .body(serialized_request);
45 Ok(request_builder.build()?)
47 };
48 self.execute(request_maker).await
49 }
50
51 pub async fn get<O>(&self, path: &str) -> Result<O, ClientError>
53 where
54 O: DeserializeOwned,
55 {
56 let request_maker = || async {
57 crate::trace!("Get request: {}", path);
58 let request_builder = self
59 .http_client
60 .get(self.config.url(path))
61 .headers(self.config.headers());
62
63 Ok(request_builder.build()?)
65 };
66 self.execute(request_maker).await
67 }
68
69 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, ClientError>
75 where
76 M: Fn() -> Fut,
77 Fut: core::future::Future<Output = Result<reqwest::Request, ClientError>>,
78 {
79 let client = self.http_client.clone();
80
81 backoff::future::retry(self.backoff.clone(), || async {
82 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
83 let response = client
84 .execute(request)
85 .await
86 .map_err(ClientError::Reqwest)
87 .map_err(backoff::Error::Permanent)?;
88
89 let status = response.status();
90 let bytes = response
91 .bytes()
92 .await
93 .map_err(ClientError::Reqwest)
94 .map_err(backoff::Error::Permanent)?;
95
96 if !status.is_success() {
98 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
99 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
100 .map_err(backoff::Error::Permanent)?;
101
102 if status.as_u16() == 429
103 && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
106 {
107 tracing::warn!("Rate limited: {}", wrapped_error.error.message);
109 return Err(backoff::Error::Transient {
110 err: ClientError::ApiError(wrapped_error.error),
111 retry_after: None,
112 });
113 } else if status.as_u16() == 503 {
114 return Err(backoff::Error::Transient {
115 err: ClientError::ServiceUnavailable {
116 message: wrapped_error.error.message,
117 },
118 retry_after: None,
119 });
120 } else {
121 return Err(backoff::Error::Permanent(ClientError::ApiError(
122 wrapped_error.error,
123 )));
124 }
125 }
126
127 Ok(bytes)
128 })
129 .await
130 }
131
132 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, ClientError>
138 where
139 O: DeserializeOwned,
140 M: Fn() -> Fut,
141 Fut: core::future::Future<Output = Result<reqwest::Request, ClientError>>,
142 {
143 let bytes = self.execute_raw(request_maker).await?;
144
145 let value: serde_json::Value =
147 serde_json::from_slice(&bytes).map_err(|e| map_deserialization_error(e, &bytes))?;
148
149 let pretty_json = serde_json::to_string_pretty(&value).map_err(map_serialization_error)?;
151 crate::trace!("Serialized response: {}", pretty_json);
152
153 let response: O =
155 serde_json::from_value(value).map_err(|e| map_deserialization_error(e, &bytes))?;
156
157 Ok(response)
158 }
159}