1use reqwest::{
2 header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
3 Method, RequestBuilder, Response, StatusCode, Url,
4};
5use serde::{Deserialize, Serialize};
6
7use crate::{config::Config, error::Error};
8
9pub struct Client {
10 api_key: String,
11 base_url: Url,
12 http_client: reqwest::Client,
13}
14
15impl Client {
16 pub fn new(config: Config) -> Result<Self, Error> {
17 let mut headers = HeaderMap::new();
18 headers.insert(
19 AUTHORIZATION,
20 HeaderValue::from_str(format!("Bearer {}", config.api_key.as_str()).as_str())?,
21 );
22 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
23
24 let http_client = reqwest::Client::builder()
25 .default_headers(headers)
26 .build()?;
27
28 let base_url =
29 Url::parse(&config.base_url).map_err(|err| Error::UrlParse(err.to_string()))?;
30
31 Ok(Self {
32 api_key: config.api_key,
33 base_url,
34 http_client,
35 })
36 }
37
38 pub fn api_key(&self) -> &str {
40 self.api_key.as_str()
41 }
42
43 pub fn base_url(&self) -> &str {
45 self.base_url.as_str()
46 }
47
48 pub async fn account(&self) -> Result<Account, Error> {
50 let response = self.request(Method::GET, "account")?.send().await?;
51 self.handle_response::<Account>(response).await
52 }
53
54 pub async fn collections(&self) -> Result<ListCollections, Error> {
56 let response = self.request(Method::GET, "collections")?.send().await?;
57 self.handle_response::<ListCollections>(response).await
58 }
59
60 pub async fn collection_models(
62 &self,
63 collection: String,
64 ) -> Result<ListCollectionModels, Error> {
65 let path = format!("collections/{}", collection);
66 let response = self.request(Method::GET, path.as_str())?.send().await?;
67 self.handle_response::<ListCollectionModels>(response).await
68 }
69
70 pub async fn deployment(&self, owner: String, name: String) -> Result<Deployment, Error> {
72 let path = format!("deployments/{}/{}", owner, name);
73 let response = self.request(Method::GET, path.as_str())?.send().await?;
74 self.handle_response::<Deployment>(response).await
75 }
76
77 pub async fn deployments(&self) -> Result<ListDeployments, Error> {
79 let response = self.request(Method::GET, "deployments")?.send().await?;
80 self.handle_response::<ListDeployments>(response).await
81 }
82
83 pub async fn create_deployment(&self, payload: CreateDeployment) -> Result<Deployment, Error> {
85 let response = self
86 .request(Method::POST, "deployments")?
87 .json(&payload)
88 .send()
89 .await?;
90 self.handle_response::<Deployment>(response).await
91 }
92
93 pub async fn update_deployment(
95 &self,
96 owner: String,
97 payload: UpdateDeployment,
98 ) -> Result<Deployment, Error> {
99 let path = format!("deployments/{}/{}", owner, payload.name);
100 let response = self
101 .request(Method::PATCH, path.as_str())?
102 .json(&payload)
103 .send()
104 .await?;
105 self.handle_response::<Deployment>(response).await
106 }
107
108 pub async fn delete_deployment(&self, owner: String, name: String) -> Result<(), Error> {
113 let path = format!("deployments/{}/{}", owner, name);
114 let response = self.request(Method::DELETE, path.as_str())?.send().await?;
115 self.handle_response(response).await
116 }
117
118 pub async fn prediction(&self, prediction_id: String) -> Result<Prediction, Error> {
120 let path = format!("predictions/{}", prediction_id);
121 let response = self.request(Method::GET, path.as_str())?.send().await?;
122 self.handle_response::<Prediction>(response).await
123 }
124
125 pub async fn predictions(&self) -> Result<ListPredictions, Error> {
127 let response = self.request(Method::GET, "predictions")?.send().await?;
128 self.handle_response::<ListPredictions>(response).await
129 }
130
131 pub async fn create_prediction(&self, payload: CreatePrediction) -> Result<Prediction, Error> {
133 let response = self
134 .request(Method::POST, "predictions")?
135 .json(&payload)
136 .send()
137 .await?;
138 self.handle_response::<Prediction>(response).await
139 }
140
141 pub async fn create_model_prediction(
143 &self,
144 payload: CreateModelPrediction,
145 ) -> Result<Prediction, Error> {
146 let path = format!("models/{}/{}/predictions", payload.owner, payload.name);
147 let response = self
148 .request(Method::POST, path.as_str())?
149 .json(&serde_json::json!({ "input": payload.input }))
150 .send()
151 .await?;
152 self.handle_response::<Prediction>(response).await
153 }
154
155 pub async fn cancel_prediction(&self, prediction_id: String) -> Result<(), Error> {
157 let path = format!("predictions/{}/cancel", prediction_id);
158 let response = self.request(Method::POST, path.as_str())?.send().await?;
159 self.handle_response(response).await
160 }
161
162 pub async fn training(&self, training_id: String) -> Result<Training, Error> {
164 let path = format!("trainings/{}", training_id);
165 let response = self.request(Method::GET, path.as_str())?.send().await?;
166 self.handle_response::<Training>(response).await
167 }
168
169 pub async fn trainings(&self) -> Result<ListTrainings, Error> {
171 let response = self.request(Method::GET, "trainings")?.send().await?;
172 self.handle_response::<ListTrainings>(response).await
173 }
174
175 pub async fn cancel_training(&self, training_id: String) -> Result<(), Error> {
177 let path = format!("trainings/{}/cancel", training_id);
178 let response = self.request(Method::POST, path.as_str())?.send().await?;
179 self.handle_response(response).await
180 }
181
182 pub async fn hardware(&self) -> Result<Vec<Hardware>, Error> {
184 let response = self.request(Method::GET, "hardware")?.send().await?;
185 self.handle_response::<Vec<Hardware>>(response).await
186 }
187
188 pub async fn public_models(&self) -> Result<ListPublicModels, Error> {
190 let response = self.request(Method::GET, "models")?.send().await?;
191 self.handle_response::<ListPublicModels>(response).await
192 }
193
194 pub async fn model(
196 &self,
197 owner: impl Into<String>,
198 name: impl Into<String>,
199 ) -> Result<Model, Error> {
200 let path = format!("models/{}/{}", owner.into(), name.into());
201 let response = self.request(Method::GET, path.as_str())?.send().await?;
202 self.handle_response::<Model>(response).await
203 }
204
205 pub async fn model_versions(
207 &self,
208 owner: impl Into<String>,
209 name: impl Into<String>,
210 ) -> Result<ListModelVersions, Error> {
211 let path = format!("models/{}/{}/versions", owner.into(), name.into());
212 let response = self.request(Method::GET, path.as_str())?.send().await?;
213 self.handle_response::<ListModelVersions>(response).await
214 }
215
216 pub async fn model_version(
218 &self,
219 owner: impl Into<String>,
220 name: impl Into<String>,
221 version_id: impl Into<String>,
222 ) -> Result<ModelVersion, Error> {
223 let path = format!(
224 "models/{}/{}/versions/{}",
225 owner.into(),
226 name.into(),
227 version_id.into()
228 );
229 let response = self.request(Method::GET, path.as_str())?.send().await?;
230 self.handle_response::<ModelVersion>(response).await
231 }
232
233 pub async fn webhook_default_secret(&self) -> Result<WebHookSecret, Error> {
235 let response = self
236 .request(Method::GET, "webhooks/default/secret")?
237 .send()
238 .await?;
239 self.handle_response::<WebHookSecret>(response).await
240 }
241
242 fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, Error> {
243 let url = self
244 .base_url
245 .join(path)
246 .map_err(|err| Error::UrlParse(err.to_string()))?;
247 Ok(self.http_client.request(method, url))
248 }
249
250 async fn handle_response<T>(&self, response: Response) -> Result<T, Error>
251 where
252 T: serde::de::DeserializeOwned,
253 {
254 let status = response.status();
255 if status.is_success() | status.is_redirection() {
256 match response.json::<T>().await {
257 Ok(data) => Ok(data),
258 Err(err) => Err(Error::HttpRequest(err)),
260 }
261 } else {
262 match status {
263 StatusCode::BAD_REQUEST => {
264 let error_msg = response.text().await?;
265 Err(Error::BadRequest(error_msg))
266 }
267 StatusCode::UNAUTHORIZED => {
268 let error_msg = response.text().await?;
269 Err(Error::Unauthorized(error_msg))
270 }
271 StatusCode::FORBIDDEN => {
272 let error_msg = response.text().await?;
273 Err(Error::Forbidden(error_msg))
274 }
275 StatusCode::TOO_MANY_REQUESTS => {
276 let error_msg = response.text().await?;
277 Err(Error::RateLimited(error_msg))
278 }
279 StatusCode::INTERNAL_SERVER_ERROR => {
280 let error_msg = response.text().await?;
281 Err(Error::InternalServerError(error_msg))
282 }
283 StatusCode::SERVICE_UNAVAILABLE => {
284 let error_msg = response.text().await?;
285 Err(Error::ServiceUnavailable(error_msg))
286 }
287 status => Err(Error::UnexpectedStatus(status)),
288 }
289 }
290 }
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct Account {
295 #[serde(rename = "type")]
296 pub kind: AccountKind,
297 pub username: String,
298 pub name: String,
299 pub github_url: String,
300}
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
303#[serde(rename_all = "lowercase")]
304pub enum AccountKind {
305 Organization,
306 User,
307}
308
309#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct Collection {
311 pub name: String,
312 pub slug: String,
313 pub description: String,
314}
315
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct ListCollections {
318 pub next: Option<String>,
319 pub previous: Option<String>,
320 pub results: Vec<Collection>,
321}
322
323#[derive(Debug, Clone, Serialize, Deserialize)]
324pub struct ListCollectionModels {
325 pub name: String,
326 pub slug: String,
327 pub description: String,
328 pub models: Vec<Model>,
329}
330
331#[derive(Debug, Clone, Serialize, Deserialize)]
332pub struct ListPublicModels {
333 pub next: Option<String>,
334 pub previous: Option<String>,
335 pub results: Vec<Model>,
336}
337
338#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct Model {
340 pub name: String,
341 pub description: Option<String>,
342 pub url: String,
343 pub owner: String,
344 pub visibility: ModelVisibility,
345 pub github_url: Option<String>,
346 pub paper_url: Option<String>,
347 pub license_url: Option<String>,
348 pub run_count: u64,
349 pub cover_image_url: Option<String>,
350}
351
352#[derive(Debug, Clone, Serialize, Deserialize)]
353#[serde(rename_all = "lowercase")]
354pub enum ModelVisibility {
355 Private,
356 Public,
357}
358
359#[derive(Debug, Clone, Serialize, Deserialize)]
360pub struct CreateDeployment {
361 pub name: String,
363
364 pub model: String,
366
367 pub version: String,
369
370 pub hardware: String,
372
373 pub min_instances: u16,
375
376 pub max_instances: u16,
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize)]
381pub struct UpdateDeployment {
382 pub name: String,
384
385 pub model: String,
387
388 pub version: String,
390
391 pub hardware: String,
393
394 pub min_instances: u16,
396
397 pub max_instances: u16,
399}
400
401#[derive(Debug, Clone, Serialize, Deserialize)]
402pub struct ListDeployments {
403 pub next: Option<String>,
404 pub previous: Option<String>,
405 pub results: Vec<Deployment>,
406}
407
408#[derive(Debug, Clone, Serialize, Deserialize)]
409pub struct Deployment {
410 pub owner: String,
411 pub name: String,
412 pub current_release: DeploymentRelease,
413}
414
415#[derive(Debug, Clone, Serialize, Deserialize)]
416pub struct DeploymentRelease {
417 pub number: u64,
418 pub model: String,
419 pub version: String,
420 pub created_at: String,
421 pub created_by: Account,
422 pub configuration: DeploymentConfiguration,
423}
424
425#[derive(Debug, Clone, Serialize, Deserialize)]
426pub struct DeploymentConfiguration {
427 pub hardware: String,
428 pub min_instances: u16,
429 pub max_instances: u16,
430}
431
432#[derive(Debug, Clone, Serialize, Deserialize)]
433pub struct CreatePrediction {
434 pub version: String,
436
437 pub input: serde_json::Value,
439
440 pub webhook: Option<String>,
448
449 pub webhook_event_filters: Option<Vec<WebHookEvent>>,
470}
471
472#[derive(Debug, Clone, Serialize, Deserialize)]
473pub struct CreateModelPrediction {
474 pub owner: String,
476
477 pub name: String,
479
480 pub input: serde_json::Value,
482}
483
484#[derive(Debug, Clone, Serialize, Deserialize)]
485#[serde(rename_all = "lowercase")]
486pub enum WebHookEvent {
487 Start,
488 Output,
489 Logs,
490 Completed,
491}
492
493#[derive(Debug, Clone, Serialize, Deserialize)]
494pub struct ListPredictions {
495 pub next: Option<String>,
496 pub previous: Option<String>,
497 pub results: Vec<Prediction>,
498}
499
500#[derive(Debug, Clone, Serialize, Deserialize)]
501pub struct Prediction {
502 pub id: String,
503 pub model: String,
504 pub version: String,
505 pub input: Option<serde_json::Value>,
506 pub output: Option<serde_json::Value>,
507 pub source: Option<Source>,
508 pub metrics: Option<PredictionMetrics>,
509 pub status: PredictionStatus,
510 pub urls: PredictionUrls,
511 pub logs: Option<String>,
512 pub data_removed: Option<bool>,
513 pub created_at: String,
514 pub started_at: Option<String>,
515 pub completed_at: Option<String>,
516}
517
518#[derive(Debug, Clone, Serialize, Deserialize)]
519#[serde(rename_all = "lowercase")]
520pub enum PredictionStatus {
521 Starting,
522 Processing,
523 Succeeded,
524 Failed,
525 Canceled,
526}
527
528#[derive(Debug, Clone, Serialize, Deserialize)]
529pub struct PredictionMetrics {
530 pub predict_time: f64,
531}
532
533#[derive(Debug, Clone, Serialize, Deserialize)]
534pub struct PredictionUrls {
535 pub get: String,
536 pub cancel: String,
537}
538
539#[derive(Debug, Clone, Serialize, Deserialize)]
540#[serde(rename_all = "lowercase")]
541pub enum Source {
542 Web,
543 Api,
544}
545
546#[derive(Debug, Clone, Serialize, Deserialize)]
547pub struct ListTrainings {
548 pub next: Option<String>,
549 pub previous: Option<String>,
550 pub results: Vec<Training>,
551}
552
553#[derive(Debug, Clone, Serialize, Deserialize)]
554pub struct Training {
555 pub completed_at: String,
556 pub created_at: String,
557 pub id: String,
558 pub input: serde_json::Value,
559 pub metrics: TrainingMetrics,
560 pub output: TrainingOutput,
561 pub started_at: String,
562 pub source: Source,
563 pub status: String,
564 pub urls: TrainingUrls,
565 pub model: String,
566 pub version: String,
567}
568
569#[derive(Debug, Clone, Serialize, Deserialize)]
570pub struct TrainingMetrics {
571 pub predict_time: f64,
572}
573
574#[derive(Debug, Clone, Serialize, Deserialize)]
575pub struct TrainingOutput {
576 pub version: String,
577 pub weights: String,
578}
579
580#[derive(Debug, Clone, Serialize, Deserialize)]
581pub struct TrainingUrls {
582 pub get: String,
583 pub cancel: String,
584}
585
586#[derive(Debug, Clone, Serialize, Deserialize)]
587pub struct Hardware {
588 pub name: String,
589 pub sku: String,
590}
591
592#[derive(Debug, Clone, Serialize, Deserialize)]
593pub struct ListModelVersions {
594 pub next: Option<String>,
595 pub previous: Option<String>,
596 pub results: Vec<ModelVersion>,
597}
598
599#[derive(Debug, Clone, Serialize, Deserialize)]
600pub struct ModelVersion {
601 pub id: String,
602 pub created_at: String,
603 pub cog_version: String,
604}
605
606#[derive(Debug, Clone, Serialize, Deserialize)]
607pub struct WebHookSecret {
608 pub key: String,
609}