Skip to main content

systemprompt_sync/
api_client.rs

1use std::time::Duration;
2
3use reqwest::{Client, StatusCode};
4use serde::de::DeserializeOwned;
5use serde::{Deserialize, Serialize};
6use systemprompt_models::net::{HTTP_CONNECT_TIMEOUT, HTTP_SYNC_DEPLOY_TIMEOUT};
7use tokio::time::sleep;
8
9use crate::error::{SyncError, SyncResult};
10
/// Backoff policy used when retrying failed sync file transfers.
///
/// The first retry waits `initial_delay`; after each retry the delay is
/// multiplied by `exponential_base` (saturating) and capped at `max_delay`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryConfig {
    /// Total number of request attempts, including the first one.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound applied to each grown delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each retry.
    pub exponential_base: u32,
}

impl Default for RetryConfig {
    /// Five attempts, starting at a 2 s delay that doubles up to 30 s.
    fn default() -> Self {
        Self {
            max_attempts: 5,
            initial_delay: Duration::from_secs(2),
            max_delay: Duration::from_secs(30),
            exponential_base: 2,
        }
    }
}
29
30#[derive(Clone, Debug)]
31pub struct SyncApiClient {
32    client: Client,
33    api_url: String,
34    token: String,
35    hostname: Option<String>,
36    sync_token: Option<String>,
37    retry_config: RetryConfig,
38}
39
40#[derive(Debug, Deserialize)]
41pub struct RegistryToken {
42    pub registry: String,
43    pub username: String,
44    pub token: String,
45}
46
47#[derive(Debug, Clone, Copy, Deserialize)]
48pub struct UploadResponse {
49    pub files_uploaded: usize,
50}
51
52#[derive(Debug, Deserialize)]
53pub struct DeployResponse {
54    pub status: String,
55    pub app_url: Option<String>,
56}
57
58impl SyncApiClient {
59    pub fn new(api_url: &str, token: &str) -> SyncResult<Self> {
60        Ok(Self {
61            client: Client::builder()
62                .connect_timeout(HTTP_CONNECT_TIMEOUT)
63                .timeout(HTTP_SYNC_DEPLOY_TIMEOUT)
64                .build()?,
65            api_url: api_url.to_string(),
66            token: token.to_string(),
67            hostname: None,
68            sync_token: None,
69            retry_config: RetryConfig::default(),
70        })
71    }
72
73    pub fn with_direct_sync(
74        mut self,
75        hostname: Option<String>,
76        sync_token: Option<String>,
77    ) -> Self {
78        self.hostname = hostname;
79        self.sync_token = sync_token;
80        self
81    }
82
83    fn direct_sync_credentials(&self) -> Option<(String, String)> {
84        match (&self.hostname, &self.sync_token) {
85            (Some(hostname), Some(token)) => {
86                let url = format!("https://{}/api/v1/sync/files", hostname);
87                Some((url, token.clone()))
88            },
89            _ => None,
90        }
91    }
92
93    fn calculate_next_delay(&self, current: Duration) -> Duration {
94        current
95            .saturating_mul(self.retry_config.exponential_base)
96            .min(self.retry_config.max_delay)
97    }
98
99    pub async fn upload_files(
100        &self,
101        tenant_id: &systemprompt_identifiers::TenantId,
102        data: Vec<u8>,
103    ) -> SyncResult<UploadResponse> {
104        let (url, token) = self.direct_sync_credentials().unwrap_or_else(|| {
105            (
106                format!("{}/api/v1/cloud/tenants/{}/files", self.api_url, tenant_id),
107                self.token.clone(),
108            )
109        });
110
111        let mut current_delay = self.retry_config.initial_delay;
112
113        for attempt in 1..=self.retry_config.max_attempts {
114            let response = self
115                .client
116                .post(&url)
117                .header("Authorization", format!("Bearer {}", token))
118                .header("Content-Type", "application/octet-stream")
119                .body(data.clone())
120                .send()
121                .await?;
122
123            match self.handle_json_response::<UploadResponse>(response).await {
124                Ok(upload) => return Ok(upload),
125                Err(error) if error.is_retryable() && attempt < self.retry_config.max_attempts => {
126                    tracing::warn!(
127                        attempt = attempt,
128                        max_attempts = self.retry_config.max_attempts,
129                        delay_ms = current_delay.as_millis() as u64,
130                        error = %error,
131                        "Retryable sync error, waiting before retry"
132                    );
133                    sleep(current_delay).await;
134                    current_delay = self.calculate_next_delay(current_delay);
135                },
136                Err(error) => return Err(error),
137            }
138        }
139
140        Err(SyncError::ApiError {
141            status: 503,
142            message: "Max retry attempts exceeded".to_string(),
143        })
144    }
145
146    pub async fn download_files(
147        &self,
148        tenant_id: &systemprompt_identifiers::TenantId,
149    ) -> SyncResult<Vec<u8>> {
150        let (url, token) = self.direct_sync_credentials().unwrap_or_else(|| {
151            (
152                format!("{}/api/v1/cloud/tenants/{}/files", self.api_url, tenant_id),
153                self.token.clone(),
154            )
155        });
156
157        let mut current_delay = self.retry_config.initial_delay;
158
159        for attempt in 1..=self.retry_config.max_attempts {
160            let response = self
161                .client
162                .get(&url)
163                .header("Authorization", format!("Bearer {}", token))
164                .send()
165                .await?;
166
167            match self.handle_binary_response(response).await {
168                Ok(data) => return Ok(data),
169                Err(error) if error.is_retryable() && attempt < self.retry_config.max_attempts => {
170                    tracing::warn!(
171                        attempt = attempt,
172                        max_attempts = self.retry_config.max_attempts,
173                        delay_ms = current_delay.as_millis() as u64,
174                        error = %error,
175                        "Retryable sync error, waiting before retry"
176                    );
177                    sleep(current_delay).await;
178                    current_delay = self.calculate_next_delay(current_delay);
179                },
180                Err(error) => return Err(error),
181            }
182        }
183
184        Err(SyncError::ApiError {
185            status: 503,
186            message: "Max retry attempts exceeded".to_string(),
187        })
188    }
189
190    pub async fn get_registry_token(
191        &self,
192        tenant_id: &systemprompt_identifiers::TenantId,
193    ) -> SyncResult<RegistryToken> {
194        let url = format!(
195            "{}/api/v1/cloud/tenants/{}/registry-token",
196            self.api_url, tenant_id
197        );
198        self.get(&url).await
199    }
200
201    pub async fn deploy(
202        &self,
203        tenant_id: &systemprompt_identifiers::TenantId,
204        image: &str,
205    ) -> SyncResult<DeployResponse> {
206        let url = format!("{}/api/v1/cloud/tenants/{}/deploy", self.api_url, tenant_id);
207        self.post(&url, &serde_json::json!({ "image": image }))
208            .await
209    }
210
211    pub async fn get_tenant_app_id(
212        &self,
213        tenant_id: &systemprompt_identifiers::TenantId,
214    ) -> SyncResult<String> {
215        #[derive(Deserialize)]
216        struct TenantInfo {
217            fly_app_name: Option<String>,
218        }
219        let url = format!("{}/api/v1/cloud/tenants/{}", self.api_url, tenant_id);
220        let info: TenantInfo = self.get(&url).await?;
221        info.fly_app_name.ok_or(SyncError::TenantNoApp)
222    }
223
224    pub async fn get_database_url(
225        &self,
226        tenant_id: &systemprompt_identifiers::TenantId,
227    ) -> SyncResult<String> {
228        #[derive(Deserialize)]
229        struct DatabaseInfo {
230            database_url: Option<String>,
231        }
232        let url = format!(
233            "{}/api/v1/cloud/tenants/{}/database",
234            self.api_url, tenant_id
235        );
236        let info: DatabaseInfo = self.get(&url).await?;
237        info.database_url.ok_or_else(|| SyncError::ApiError {
238            status: 404,
239            message: "Database URL not available for tenant".to_string(),
240        })
241    }
242
243    async fn get<T: DeserializeOwned>(&self, url: &str) -> SyncResult<T> {
244        let response = self
245            .client
246            .get(url)
247            .header("Authorization", format!("Bearer {}", self.token))
248            .send()
249            .await?;
250
251        self.handle_json_response(response).await
252    }
253
254    async fn post<T: DeserializeOwned, B: Serialize + Sync>(
255        &self,
256        url: &str,
257        body: &B,
258    ) -> SyncResult<T> {
259        let response = self
260            .client
261            .post(url)
262            .header("Authorization", format!("Bearer {}", self.token))
263            .json(body)
264            .send()
265            .await?;
266
267        self.handle_json_response(response).await
268    }
269
270    async fn handle_json_response<T: DeserializeOwned>(
271        &self,
272        response: reqwest::Response,
273    ) -> SyncResult<T> {
274        let status = response.status();
275        if status == StatusCode::UNAUTHORIZED {
276            return Err(SyncError::Unauthorized);
277        }
278        if !status.is_success() {
279            let message = response.text().await?;
280            return Err(SyncError::ApiError {
281                status: status.as_u16(),
282                message,
283            });
284        }
285        Ok(response.json().await?)
286    }
287
288    async fn handle_binary_response(&self, response: reqwest::Response) -> SyncResult<Vec<u8>> {
289        let status = response.status();
290        if !status.is_success() {
291            let message = response
292                .text()
293                .await
294                .unwrap_or_else(|e| format!("(body unreadable: {})", e));
295            return Err(SyncError::ApiError {
296                status: status.as_u16(),
297                message,
298            });
299        }
300        Ok(response.bytes().await?.to_vec())
301    }
302}