replic/
client.rs

1use reqwest::{
2    header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
3    Method, RequestBuilder, Response, StatusCode, Url,
4};
5use serde::{Deserialize, Serialize};
6
7use crate::{config::Config, error::Error};
8
9pub struct Client {
10    api_key: String,
11    base_url: Url,
12    http_client: reqwest::Client,
13}
14
15impl Client {
16    pub fn new(config: Config) -> Result<Self, Error> {
17        let mut headers = HeaderMap::new();
18        headers.insert(
19            AUTHORIZATION,
20            HeaderValue::from_str(format!("Bearer {}", config.api_key.as_str()).as_str())?,
21        );
22        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
23
24        let http_client = reqwest::Client::builder()
25            .default_headers(headers)
26            .build()?;
27
28        let base_url =
29            Url::parse(&config.base_url).map_err(|err| Error::UrlParse(err.to_string()))?;
30
31        Ok(Self {
32            api_key: config.api_key,
33            base_url,
34            http_client,
35        })
36    }
37
38    /// Get the api key.
39    pub fn api_key(&self) -> &str {
40        self.api_key.as_str()
41    }
42
43    /// Get the base URL.
44    pub fn base_url(&self) -> &str {
45        self.base_url.as_str()
46    }
47
48    /// Get the authenticated account.
49    pub async fn account(&self) -> Result<Account, Error> {
50        let response = self.request(Method::GET, "account")?.send().await?;
51        self.handle_response::<Account>(response).await
52    }
53
54    /// List collections of models.
55    pub async fn collections(&self) -> Result<ListCollections, Error> {
56        let response = self.request(Method::GET, "collections")?.send().await?;
57        self.handle_response::<ListCollections>(response).await
58    }
59
60    /// List collection of models.
61    pub async fn collection_models(
62        &self,
63        collection: String,
64    ) -> Result<ListCollectionModels, Error> {
65        let path = format!("collections/{}", collection);
66        let response = self.request(Method::GET, path.as_str())?.send().await?;
67        self.handle_response::<ListCollectionModels>(response).await
68    }
69
70    /// Get information about a deployment by name including the current release.
71    pub async fn deployment(&self, owner: String, name: String) -> Result<Deployment, Error> {
72        let path = format!("deployments/{}/{}", owner, name);
73        let response = self.request(Method::GET, path.as_str())?.send().await?;
74        self.handle_response::<Deployment>(response).await
75    }
76
77    /// List deployments associated with the current account, including the latest release configuration for each deployment.
78    pub async fn deployments(&self) -> Result<ListDeployments, Error> {
79        let response = self.request(Method::GET, "deployments")?.send().await?;
80        self.handle_response::<ListDeployments>(response).await
81    }
82
83    /// Create a new deployment.
84    pub async fn create_deployment(&self, payload: CreateDeployment) -> Result<Deployment, Error> {
85        let response = self
86            .request(Method::POST, "deployments")?
87            .json(&payload)
88            .send()
89            .await?;
90        self.handle_response::<Deployment>(response).await
91    }
92
93    /// Update a deployment.
94    pub async fn update_deployment(
95        &self,
96        owner: String,
97        payload: UpdateDeployment,
98    ) -> Result<Deployment, Error> {
99        let path = format!("deployments/{}/{}", owner, payload.name);
100        let response = self
101            .request(Method::PATCH, path.as_str())?
102            .json(&payload)
103            .send()
104            .await?;
105        self.handle_response::<Deployment>(response).await
106    }
107
108    /// Delete a deployment.
109    ///
110    /// Deployment deletion has some restrictions:
111    ///     - You can only delete deployments that have been offline and unused for at least 15 minutes.
112    pub async fn delete_deployment(&self, owner: String, name: String) -> Result<(), Error> {
113        let path = format!("deployments/{}/{}", owner, name);
114        let response = self.request(Method::DELETE, path.as_str())?.send().await?;
115        self.handle_response(response).await
116    }
117
118    /// Get a prediction.
119    pub async fn prediction(&self, prediction_id: String) -> Result<Prediction, Error> {
120        let path = format!("predictions/{}", prediction_id);
121        let response = self.request(Method::GET, path.as_str())?.send().await?;
122        self.handle_response::<Prediction>(response).await
123    }
124
125    /// List predictions.
126    pub async fn predictions(&self) -> Result<ListPredictions, Error> {
127        let response = self.request(Method::GET, "predictions")?.send().await?;
128        self.handle_response::<ListPredictions>(response).await
129    }
130
131    /// Create a prediction.
132    pub async fn create_prediction(&self, payload: CreatePrediction) -> Result<Prediction, Error> {
133        let response = self
134            .request(Method::POST, "predictions")?
135            .json(&payload)
136            .send()
137            .await?;
138        self.handle_response::<Prediction>(response).await
139    }
140
141    /// Create a prediction from an official model
142    pub async fn create_model_prediction(
143        &self,
144        payload: CreateModelPrediction,
145    ) -> Result<Prediction, Error> {
146        let path = format!("models/{}/{}/predictions", payload.owner, payload.name);
147        let response = self
148            .request(Method::POST, path.as_str())?
149            .json(&serde_json::json!({ "input": payload.input }))
150            .send()
151            .await?;
152        self.handle_response::<Prediction>(response).await
153    }
154
155    /// Cancel a prediction.
156    pub async fn cancel_prediction(&self, prediction_id: String) -> Result<(), Error> {
157        let path = format!("predictions/{}/cancel", prediction_id);
158        let response = self.request(Method::POST, path.as_str())?.send().await?;
159        self.handle_response(response).await
160    }
161
162    /// Get a training.
163    pub async fn training(&self, training_id: String) -> Result<Training, Error> {
164        let path = format!("trainings/{}", training_id);
165        let response = self.request(Method::GET, path.as_str())?.send().await?;
166        self.handle_response::<Training>(response).await
167    }
168
169    /// List trainings.
170    pub async fn trainings(&self) -> Result<ListTrainings, Error> {
171        let response = self.request(Method::GET, "trainings")?.send().await?;
172        self.handle_response::<ListTrainings>(response).await
173    }
174
175    /// Cancel a training.
176    pub async fn cancel_training(&self, training_id: String) -> Result<(), Error> {
177        let path = format!("trainings/{}/cancel", training_id);
178        let response = self.request(Method::POST, path.as_str())?.send().await?;
179        self.handle_response(response).await
180    }
181
182    /// List available hardware for models.
183    pub async fn hardware(&self) -> Result<Vec<Hardware>, Error> {
184        let response = self.request(Method::GET, "hardware")?.send().await?;
185        self.handle_response::<Vec<Hardware>>(response).await
186    }
187
188    /// List public models.
189    pub async fn public_models(&self) -> Result<ListPublicModels, Error> {
190        let response = self.request(Method::GET, "models")?.send().await?;
191        self.handle_response::<ListPublicModels>(response).await
192    }
193
194    /// Get model.
195    pub async fn model(
196        &self,
197        owner: impl Into<String>,
198        name: impl Into<String>,
199    ) -> Result<Model, Error> {
200        let path = format!("models/{}/{}", owner.into(), name.into());
201        let response = self.request(Method::GET, path.as_str())?.send().await?;
202        self.handle_response::<Model>(response).await
203    }
204
205    /// List model versions.
206    pub async fn model_versions(
207        &self,
208        owner: impl Into<String>,
209        name: impl Into<String>,
210    ) -> Result<ListModelVersions, Error> {
211        let path = format!("models/{}/{}/versions", owner.into(), name.into());
212        let response = self.request(Method::GET, path.as_str())?.send().await?;
213        self.handle_response::<ListModelVersions>(response).await
214    }
215
216    /// Get model version.
217    pub async fn model_version(
218        &self,
219        owner: impl Into<String>,
220        name: impl Into<String>,
221        version_id: impl Into<String>,
222    ) -> Result<ModelVersion, Error> {
223        let path = format!(
224            "models/{}/{}/versions/{}",
225            owner.into(),
226            name.into(),
227            version_id.into()
228        );
229        let response = self.request(Method::GET, path.as_str())?.send().await?;
230        self.handle_response::<ModelVersion>(response).await
231    }
232
233    /// Get WebHook default secret
234    pub async fn webhook_default_secret(&self) -> Result<WebHookSecret, Error> {
235        let response = self
236            .request(Method::GET, "webhooks/default/secret")?
237            .send()
238            .await?;
239        self.handle_response::<WebHookSecret>(response).await
240    }
241
242    fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, Error> {
243        let url = self
244            .base_url
245            .join(path)
246            .map_err(|err| Error::UrlParse(err.to_string()))?;
247        Ok(self.http_client.request(method, url))
248    }
249
250    async fn handle_response<T>(&self, response: Response) -> Result<T, Error>
251    where
252        T: serde::de::DeserializeOwned,
253    {
254        let status = response.status();
255        if status.is_success() | status.is_redirection() {
256            match response.json::<T>().await {
257                Ok(data) => Ok(data),
258                // TODO: this should be a serde error
259                Err(err) => Err(Error::HttpRequest(err)),
260            }
261        } else {
262            match status {
263                StatusCode::BAD_REQUEST => {
264                    let error_msg = response.text().await?;
265                    Err(Error::BadRequest(error_msg))
266                }
267                StatusCode::UNAUTHORIZED => {
268                    let error_msg = response.text().await?;
269                    Err(Error::Unauthorized(error_msg))
270                }
271                StatusCode::FORBIDDEN => {
272                    let error_msg = response.text().await?;
273                    Err(Error::Forbidden(error_msg))
274                }
275                StatusCode::TOO_MANY_REQUESTS => {
276                    let error_msg = response.text().await?;
277                    Err(Error::RateLimited(error_msg))
278                }
279                StatusCode::INTERNAL_SERVER_ERROR => {
280                    let error_msg = response.text().await?;
281                    Err(Error::InternalServerError(error_msg))
282                }
283                StatusCode::SERVICE_UNAVAILABLE => {
284                    let error_msg = response.text().await?;
285                    Err(Error::ServiceUnavailable(error_msg))
286                }
287                status => Err(Error::UnexpectedStatus(status)),
288            }
289        }
290    }
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct Account {
295    #[serde(rename = "type")]
296    pub kind: AccountKind,
297    pub username: String,
298    pub name: String,
299    pub github_url: String,
300}
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
303#[serde(rename_all = "lowercase")]
304pub enum AccountKind {
305    Organization,
306    User,
307}
308
309#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct Collection {
311    pub name: String,
312    pub slug: String,
313    pub description: String,
314}
315
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct ListCollections {
318    pub next: Option<String>,
319    pub previous: Option<String>,
320    pub results: Vec<Collection>,
321}
322
323#[derive(Debug, Clone, Serialize, Deserialize)]
324pub struct ListCollectionModels {
325    pub name: String,
326    pub slug: String,
327    pub description: String,
328    pub models: Vec<Model>,
329}
330
331#[derive(Debug, Clone, Serialize, Deserialize)]
332pub struct ListPublicModels {
333    pub next: Option<String>,
334    pub previous: Option<String>,
335    pub results: Vec<Model>,
336}
337
338#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct Model {
340    pub name: String,
341    pub description: Option<String>,
342    pub url: String,
343    pub owner: String,
344    pub visibility: ModelVisibility,
345    pub github_url: Option<String>,
346    pub paper_url: Option<String>,
347    pub license_url: Option<String>,
348    pub run_count: u64,
349    pub cover_image_url: Option<String>,
350}
351
352#[derive(Debug, Clone, Serialize, Deserialize)]
353#[serde(rename_all = "lowercase")]
354pub enum ModelVisibility {
355    Private,
356    Public,
357}
358
359#[derive(Debug, Clone, Serialize, Deserialize)]
360pub struct CreateDeployment {
361    /// The name of the deployment.
362    pub name: String,
363
364    /// The full name of the model that you want to deploy e.g. stability-ai/sdxl.
365    pub model: String,
366
367    /// The 64-character string ID of the model version that you want to deploy.
368    pub version: String,
369
370    /// The SKU for the hardware used to run the model.
371    pub hardware: String,
372
373    /// The maximum number of instances for scaling.
374    pub min_instances: u16,
375
376    /// The minimum number of instances for scaling.
377    pub max_instances: u16,
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize)]
381pub struct UpdateDeployment {
382    /// The name of the deployment.
383    pub name: String,
384
385    /// The full name of the model that you want to deploy e.g. stability-ai/sdxl.
386    pub model: String,
387
388    /// The 64-character string ID of the model version that you want to deploy.
389    pub version: String,
390
391    /// The SKU for the hardware used to run the model.
392    pub hardware: String,
393
394    /// The maximum number of instances for scaling.
395    pub min_instances: u16,
396
397    /// The minimum number of instances for scaling.
398    pub max_instances: u16,
399}
400
401#[derive(Debug, Clone, Serialize, Deserialize)]
402pub struct ListDeployments {
403    pub next: Option<String>,
404    pub previous: Option<String>,
405    pub results: Vec<Deployment>,
406}
407
408#[derive(Debug, Clone, Serialize, Deserialize)]
409pub struct Deployment {
410    pub owner: String,
411    pub name: String,
412    pub current_release: DeploymentRelease,
413}
414
415#[derive(Debug, Clone, Serialize, Deserialize)]
416pub struct DeploymentRelease {
417    pub number: u64,
418    pub model: String,
419    pub version: String,
420    pub created_at: String,
421    pub created_by: Account,
422    pub configuration: DeploymentConfiguration,
423}
424
425#[derive(Debug, Clone, Serialize, Deserialize)]
426pub struct DeploymentConfiguration {
427    pub hardware: String,
428    pub min_instances: u16,
429    pub max_instances: u16,
430}
431
432#[derive(Debug, Clone, Serialize, Deserialize)]
433pub struct CreatePrediction {
434    /// The ID of the model version to run.
435    pub version: String,
436
437    /// The model's input as a JSON object.
438    pub input: serde_json::Value,
439
440    /// An HTTPS URL for receiving a webhook when the prediction has new output.
441    ///
442    /// The webhook will be a POST request where the request body is the same as the response body of the get prediction operation.
443    ///
444    /// **Notes**:
445    ///     - Retries a few times in case of network problems.
446    ///     - It doesn't follow redirects.
447    pub webhook: Option<String>,
448
449    /// Events triggering webhook requests.
450    ///
451    /// **start**: immediately on prediction start
452    /// **output**: each time a prediction generates an output (note that predictions can generate multiple outputs)
453    /// **logs**: each time log output is generated by a prediction
454    /// **completed**: when the prediction reaches a terminal state (succeeded/canceled/failed)
455    ///
456    /// For example, if you only wanted requests to be sent at the start and end of the prediction, you would provide:
457    ///
458    /// ```json
459    /// {
460    ///     "version":
461    ///     "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
462    ///     "input": {
463    ///         "text": "Alice"
464    ///     },
465    ///     "webhook": "https://example.com/my-webhook",
466    ///     "webhook_events_filter": ["start", "completed"]
467    /// }
468    /// ```
469    pub webhook_event_filters: Option<Vec<WebHookEvent>>,
470}
471
472#[derive(Debug, Clone, Serialize, Deserialize)]
473pub struct CreateModelPrediction {
474    /// Model owner
475    pub owner: String,
476
477    /// Model name
478    pub name: String,
479
480    /// The model's input as a JSON object.
481    pub input: serde_json::Value,
482}
483
484#[derive(Debug, Clone, Serialize, Deserialize)]
485#[serde(rename_all = "lowercase")]
486pub enum WebHookEvent {
487    Start,
488    Output,
489    Logs,
490    Completed,
491}
492
493#[derive(Debug, Clone, Serialize, Deserialize)]
494pub struct ListPredictions {
495    pub next: Option<String>,
496    pub previous: Option<String>,
497    pub results: Vec<Prediction>,
498}
499
500#[derive(Debug, Clone, Serialize, Deserialize)]
501pub struct Prediction {
502    pub id: String,
503    pub model: String,
504    pub version: String,
505    pub input: Option<serde_json::Value>,
506    pub output: Option<serde_json::Value>,
507    pub source: Option<Source>,
508    pub metrics: Option<PredictionMetrics>,
509    pub status: PredictionStatus,
510    pub urls: PredictionUrls,
511    pub logs: Option<String>,
512    pub data_removed: Option<bool>,
513    pub created_at: String,
514    pub started_at: Option<String>,
515    pub completed_at: Option<String>,
516}
517
518#[derive(Debug, Clone, Serialize, Deserialize)]
519#[serde(rename_all = "lowercase")]
520pub enum PredictionStatus {
521    Starting,
522    Processing,
523    Succeeded,
524    Failed,
525    Canceled,
526}
527
528#[derive(Debug, Clone, Serialize, Deserialize)]
529pub struct PredictionMetrics {
530    pub predict_time: f64,
531}
532
533#[derive(Debug, Clone, Serialize, Deserialize)]
534pub struct PredictionUrls {
535    pub get: String,
536    pub cancel: String,
537}
538
539#[derive(Debug, Clone, Serialize, Deserialize)]
540#[serde(rename_all = "lowercase")]
541pub enum Source {
542    Web,
543    Api,
544}
545
546#[derive(Debug, Clone, Serialize, Deserialize)]
547pub struct ListTrainings {
548    pub next: Option<String>,
549    pub previous: Option<String>,
550    pub results: Vec<Training>,
551}
552
553#[derive(Debug, Clone, Serialize, Deserialize)]
554pub struct Training {
555    pub completed_at: String,
556    pub created_at: String,
557    pub id: String,
558    pub input: serde_json::Value,
559    pub metrics: TrainingMetrics,
560    pub output: TrainingOutput,
561    pub started_at: String,
562    pub source: Source,
563    pub status: String,
564    pub urls: TrainingUrls,
565    pub model: String,
566    pub version: String,
567}
568
569#[derive(Debug, Clone, Serialize, Deserialize)]
570pub struct TrainingMetrics {
571    pub predict_time: f64,
572}
573
574#[derive(Debug, Clone, Serialize, Deserialize)]
575pub struct TrainingOutput {
576    pub version: String,
577    pub weights: String,
578}
579
580#[derive(Debug, Clone, Serialize, Deserialize)]
581pub struct TrainingUrls {
582    pub get: String,
583    pub cancel: String,
584}
585
586#[derive(Debug, Clone, Serialize, Deserialize)]
587pub struct Hardware {
588    pub name: String,
589    pub sku: String,
590}
591
592#[derive(Debug, Clone, Serialize, Deserialize)]
593pub struct ListModelVersions {
594    pub next: Option<String>,
595    pub previous: Option<String>,
596    pub results: Vec<ModelVersion>,
597}
598
599#[derive(Debug, Clone, Serialize, Deserialize)]
600pub struct ModelVersion {
601    pub id: String,
602    pub created_at: String,
603    pub cog_version: String,
604}
605
606#[derive(Debug, Clone, Serialize, Deserialize)]
607pub struct WebHookSecret {
608    pub key: String,
609}