runpod_client/
lib.rs

1use reqwest::Client;
2use serde::{Deserialize, Serialize};
3use serde_json::json;
4
/// URL of the RunPod GraphQL endpoint; `graphql_query` POSTs every request to it.
const RUNPOD_ENDPOINT: &str = "https://api.runpod.io/graphql";
6
/// Main client struct for interacting with RunPod.
///
/// Holds the HTTP client used for every request together with the API key
/// that is sent along with each call.
pub struct RunpodClient {
    /// HTTP client used to send the GraphQL requests.
    http_client: Client,
    /// API key appended to the endpoint URL as the `api_key` query parameter.
    api_key: String,
}
12
13impl RunpodClient {
14    /// Construct a new RunpodClient with your API key.
15    pub fn new(api_key: impl Into<String>) -> Self {
16        RunpodClient {
17            http_client: Client::new(),
18            api_key: api_key.into(),
19        }
20    }
21
22    /// Low-level function to execute a GraphQL query or mutation.
23    /// You normally wouldn't call this directly; instead, use helper methods.
24    async fn graphql_query<T: for<'de> Deserialize<'de>>(
25        &self,
26        graphql_body: &serde_json::Value,
27    ) -> Result<T, reqwest::Error> {
28        let url = format!("{}?api_key={}", RUNPOD_ENDPOINT, self.api_key);
29
30        let response = self
31            .http_client
32            .post(&url)
33            .json(graphql_body)
34            .send()
35            .await?
36            .error_for_status()?
37            .json::<T>()
38            .await?;
39
40        Ok(response)
41    }
42
43    // ---------------------------------------------------------------------
44    // 1) Create an On-Demand Pod
45    // ---------------------------------------------------------------------
46    pub async fn create_on_demand_pod(
47        &self,
48        req: CreateOnDemandPodRequest,
49    ) -> Result<PodCreateResponseData, reqwest::Error> {
50        // Build the GraphQL mutation string.
51        let query = format!(
52            r#"
53            mutation {{
54                podFindAndDeployOnDemand(input: {{
55                    cloudType: {cloud_type},
56                    gpuCount: {gpu_count},
57                    volumeInGb: {volume_in_gb},
58                    containerDiskInGb: {container_disk_in_gb},
59                    minVcpuCount: {min_vcpu_count},
60                    minMemoryInGb: {min_memory_in_gb},
61                    gpuTypeId: "{gpu_type_id}",
62                    name: "{name}",
63                    imageName: "{image_name}",
64                    dockerArgs: "{docker_args}",
65                    ports: "{ports}",
66                    volumeMountPath: "{volume_mount_path}",
67                    env: [{env}]
68                }}) {{
69                    id
70                    imageName
71                    env {{ key value }}
72                    machineId
73                    machine {{ podHostId }}
74                }}
75            }}
76            "#,
77            cloud_type = req.cloud_type,
78            gpu_count = req.gpu_count,
79            volume_in_gb = req.volume_in_gb,
80            container_disk_in_gb = req.container_disk_in_gb,
81            min_vcpu_count = req.min_vcpu_count,
82            min_memory_in_gb = req.min_memory_in_gb,
83            gpu_type_id = req.gpu_type_id,
84            name = req.name,
85            image_name = req.image_name,
86            docker_args = req.docker_args,
87            ports = req.ports,
88            volume_mount_path = req.volume_mount_path,
89            env = env_to_string(&req.env),
90        );
91
92        let body = json!({ "query": query });
93
94        // Execute the request
95        let resp: GraphQLResponse<PodCreateResponse> = self.graphql_query(&body).await?;
96        Ok(PodCreateResponseData {
97            data: resp.data.map(|d| d.pod_find_and_deploy_on_demand),
98            errors: resp.errors,
99        })
100    }
101
102    // ---------------------------------------------------------------------
103    // 2) Create a Spot (Interruptible) Pod
104    // ---------------------------------------------------------------------
105    pub async fn create_spot_pod(
106        &self,
107        req: CreateSpotPodRequest,
108    ) -> Result<PodCreateResponseData, reqwest::Error> {
109        let query = format!(
110            r#"
111            mutation {{
112                podRentInterruptable(input: {{
113                    bidPerGpu: {bid_per_gpu},
114                    cloudType: {cloud_type},
115                    gpuCount: {gpu_count},
116                    volumeInGb: {volume_in_gb},
117                    containerDiskInGb: {container_disk_in_gb},
118                    minVcpuCount: {min_vcpu_count},
119                    minMemoryInGb: {min_memory_in_gb},
120                    gpuTypeId: "{gpu_type_id}",
121                    name: "{name}",
122                    imageName: "{image_name}",
123                    dockerArgs: "{docker_args}",
124                    ports: "{ports}",
125                    volumeMountPath: "{volume_mount_path}",
126                    env: [{env}]
127                }}) {{
128                    id
129                    imageName
130                    env {{ key value }}
131                    machineId
132                    machine {{ podHostId }}
133                }}
134            }}
135            "#,
136            bid_per_gpu = req.bid_per_gpu,
137            cloud_type = req.cloud_type,
138            gpu_count = req.gpu_count,
139            volume_in_gb = req.volume_in_gb,
140            container_disk_in_gb = req.container_disk_in_gb,
141            min_vcpu_count = req.min_vcpu_count,
142            min_memory_in_gb = req.min_memory_in_gb,
143            gpu_type_id = req.gpu_type_id,
144            name = req.name,
145            image_name = req.image_name,
146            docker_args = req.docker_args,
147            ports = req.ports,
148            volume_mount_path = req.volume_mount_path,
149            env = env_to_string(&req.env),
150        );
151
152        let body = json!({ "query": query });
153
154        let resp: GraphQLResponse<PodCreateResponse> = self.graphql_query(&body).await?;
155        Ok(PodCreateResponseData {
156            data: resp.data.map(|d| d.pod_rent_interruptable),
157            errors: resp.errors,
158        })
159    }
160
161    // ---------------------------------------------------------------------
162    // 3) Start (Resume) a Pod (On-Demand)
163    // ---------------------------------------------------------------------
164    pub async fn start_on_demand_pod(
165        &self,
166        pod_id: &str,
167        gpu_count: i32,
168    ) -> Result<PodStartResponseData, reqwest::Error> {
169        let query = format!(
170            r#"
171            mutation {{
172                podResume(input: {{
173                    podId: "{pod_id}",
174                    gpuCount: {gpu_count}
175                }}) {{
176                    id
177                    desiredStatus
178                    imageName
179                    env {{ key value }}
180                    machineId
181                    machine {{ podHostId }}
182                }}
183            }}
184            "#,
185            pod_id = pod_id,
186            gpu_count = gpu_count
187        );
188        let body = json!({ "query": query });
189
190        let resp: GraphQLResponse<PodStartResponse> = self.graphql_query(&body).await?;
191        Ok(PodStartResponseData {
192            data: resp.data.and_then(|d| d.pod_resume),
193            errors: resp.errors,
194        })
195    }
196
197    // ---------------------------------------------------------------------
198    // 4) Start (Resume) a Pod (Spot)
199    // ---------------------------------------------------------------------
200    pub async fn start_spot_pod(
201        &self,
202        pod_id: &str,
203        bid_per_gpu: f64,
204        gpu_count: i32,
205    ) -> Result<PodStartResponseData, reqwest::Error> {
206        let query = format!(
207            r#"
208            mutation {{
209                podBidResume(input: {{
210                    podId: "{pod_id}",
211                    bidPerGpu: {bid_per_gpu},
212                    gpuCount: {gpu_count}
213                }}) {{
214                    id
215                    desiredStatus
216                    imageName
217                    env {{ key value }}
218                    machineId
219                    machine {{ podHostId }}
220                }}
221            }}
222            "#,
223            pod_id = pod_id,
224            bid_per_gpu = bid_per_gpu,
225            gpu_count = gpu_count
226        );
227        let body = json!({ "query": query });
228
229        let resp: GraphQLResponse<PodStartResponse> = self.graphql_query(&body).await?;
230        Ok(PodStartResponseData {
231            data: resp.data.and_then(|d| d.pod_bid_resume),
232            errors: resp.errors,
233        })
234    }
235
236    // ---------------------------------------------------------------------
237    // 5) Stop a Pod
238    // ---------------------------------------------------------------------
239    pub async fn stop_pod(&self, pod_id: &str) -> Result<PodStopResponseData, reqwest::Error> {
240        let query = format!(
241            r#"
242            mutation {{
243                podStop(input: {{
244                    podId: "{pod_id}"
245                }}) {{
246                    id
247                    desiredStatus
248                }}
249            }}
250            "#,
251            pod_id = pod_id
252        );
253        let body = json!({ "query": query });
254
255        let resp: GraphQLResponse<PodStopResponse> = self.graphql_query(&body).await?;
256        Ok(PodStopResponseData {
257            data: resp.data.map(|d| d.pod_stop),
258            errors: resp.errors,
259        })
260    }
261
262    // ---------------------------------------------------------------------
263    // 6) List all Pods
264    // ---------------------------------------------------------------------
265    pub async fn list_pods(&self) -> Result<PodsListResponseData, reqwest::Error> {
266        let query = r#"
267            query Pods {
268                myself {
269                    pods {
270                        id
271                        name
272                        runtime {
273                            uptimeInSeconds
274                            ports {
275                                ip
276                                isIpPublic
277                                privatePort
278                                publicPort
279                                type
280                            }
281                            gpus {
282                                id
283                                gpuUtilPercent
284                                memoryUtilPercent
285                            }
286                            container {
287                                cpuPercent
288                                memoryPercent
289                            }
290                        }
291                    }
292                }
293            }
294        "#;
295        let body = json!({ "query": query });
296
297        let resp: GraphQLResponse<PodsListResponse> = self.graphql_query(&body).await?;
298        Ok(PodsListResponseData {
299            data: resp.data.map(|d| d.myself),
300            errors: resp.errors,
301        })
302    }
303
304    // ---------------------------------------------------------------------
305    // 7) Get Pod by ID
306    // ---------------------------------------------------------------------
307    pub async fn get_pod(&self, pod_id: &str) -> Result<PodInfoResponseData, reqwest::Error> {
308        let query = format!(
309            r#"
310            query Pod {{
311                pod(input: {{
312                    podId: "{pod_id}"
313                }}) {{
314                    id
315                    name
316                    runtime {{
317                        uptimeInSeconds
318                        ports {{
319                            ip
320                            isIpPublic
321                            privatePort
322                            publicPort
323                            type
324                        }}
325                        gpus {{
326                            id
327                            gpuUtilPercent
328                            memoryUtilPercent
329                        }}
330                        container {{
331                            cpuPercent
332                            memoryPercent
333                        }}
334                    }}
335                }}
336            }}
337            "#,
338            pod_id = pod_id
339        );
340        let body = json!({ "query": query });
341
342        let resp: GraphQLResponse<PodInfoResponse> = self.graphql_query(&body).await?;
343        Ok(PodInfoResponseData {
344            data: resp.data.map(|d| d.pod),
345            errors: resp.errors,
346        })
347    }
348
349    // ---------------------------------------------------------------------
350    // 8) List GPU Types
351    // ---------------------------------------------------------------------
352    pub async fn list_gpu_types(&self) -> Result<GPUTypesListResponseData, reqwest::Error> {
353        let query = r#"
354            query GpuTypes {
355                gpuTypes {
356                    id
357                    displayName
358                    memoryInGb
359                }
360            }
361        "#;
362        let body = json!({ "query": query });
363
364        let resp: GraphQLResponse<GPUTypesListResponse> = self.graphql_query(&body).await?;
365        Ok(GPUTypesListResponseData {
366            data: resp.data.map(|d| d.gpu_types),
367            errors: resp.errors,
368        })
369    }
370
371    // ---------------------------------------------------------------------
372    // 9) Get GPU Type by ID
373    // ---------------------------------------------------------------------
374    pub async fn get_gpu_type(
375        &self,
376        gpu_type_id: &str,
377    ) -> Result<GPUTypeResponseData, reqwest::Error> {
378        let query = format!(
379            r#"
380            query GpuTypes {{
381                gpuTypes(input: {{
382                    id: "{gpu_type_id}"
383                }}) {{
384                    id
385                    displayName
386                    memoryInGb
387                    secureCloud
388                    communityCloud
389                    lowestPrice(input: {{gpuCount: 1}}) {{
390                        minimumBidPrice
391                        uninterruptablePrice
392                    }}
393                }}
394            }}
395            "#,
396            gpu_type_id = gpu_type_id
397        );
398        let body = json!({ "query": query });
399
400        let resp: GraphQLResponse<GPUTypesExtendedResponse> = self.graphql_query(&body).await?;
401        Ok(GPUTypeResponseData {
402            data: resp.data.map(|d| d.gpu_types),
403            errors: resp.errors,
404        })
405    }
406}
407
408// ---------------------------------------------------------------------
409// Helper to turn env Vec<EnvVar> into a GraphQL list string
410// ---------------------------------------------------------------------
411fn env_to_string(env: &[EnvVar]) -> String {
412    env.iter()
413        .map(|env_var| {
414            format!(
415                r#"{{ key: "{}", value: "{}" }}"#,
416                env_var.key, env_var.value
417            )
418        })
419        .collect::<Vec<String>>()
420        .join(", ")
421}
422
423// ---------------------------------------------------------------------
424// GraphQL request/response data structures
425// ---------------------------------------------------------------------
426
427#[derive(Debug, Serialize, Deserialize, Default)]
428pub struct GraphQLResponse<T> {
429    pub data: Option<T>,
430    pub errors: Option<Vec<GraphQLError>>,
431}
432
433#[derive(Debug, Serialize, Deserialize, Default)]
434pub struct GraphQLError {
435    pub message: String,
436    // You can add more fields if needed (e.g., locations, etc.)
437}
438
439// ---------------------------------------------------------------------
440// 1) Create On-Demand Pod
441// ---------------------------------------------------------------------
442#[derive(Debug, Serialize, Deserialize, Default)]
443pub struct CreateOnDemandPodRequest {
444    pub cloud_type: String,        // e.g. "ALL"
445    pub gpu_count: i32,            // e.g. 1
446    pub volume_in_gb: i32,         // e.g. 40
447    pub container_disk_in_gb: i32, // e.g. 40
448    pub min_vcpu_count: i32,       // e.g. 2
449    pub min_memory_in_gb: i32,     // e.g. 15
450    pub gpu_type_id: String,       // e.g. "NVIDIA RTX A6000"
451    pub name: String,              // e.g. "RunPod Tensorflow"
452    pub image_name: String,        // e.g. "runpod/tensorflow"
453    pub docker_args: String,       // e.g. ""
454    pub ports: String,             // e.g. "8888/http"
455    pub volume_mount_path: String, // e.g. "/workspace"
456    pub env: Vec<EnvVar>,
457}
458
459#[derive(Debug, Serialize, Deserialize, Default)]
460pub struct PodCreateResponse {
461    #[serde(rename = "podFindAndDeployOnDemand")]
462    pub pod_find_and_deploy_on_demand: PodInfoMinimal,
463    #[serde(rename = "podRentInterruptable", default)]
464    pub pod_rent_interruptable: PodInfoMinimal,
465}
466
467#[derive(Debug, Serialize, Deserialize, Default)]
468pub struct PodCreateResponseData {
469    pub data: Option<PodInfoMinimal>, // Because we can get either from OnDemand or Spot creation
470    pub errors: Option<Vec<GraphQLError>>,
471}
472
473// ---------------------------------------------------------------------
474// 2) Create Spot Pod
475// ---------------------------------------------------------------------
476#[derive(Debug, Serialize, Deserialize, Default)]
477pub struct CreateSpotPodRequest {
478    pub bid_per_gpu: f64,          // e.g. 0.2
479    pub cloud_type: String,        // e.g. "SECURE"
480    pub gpu_count: i32,            // e.g. 1
481    pub volume_in_gb: i32,         // e.g. 40
482    pub container_disk_in_gb: i32, // e.g. 40
483    pub min_vcpu_count: i32,       // e.g. 2
484    pub min_memory_in_gb: i32,     // e.g. 15
485    pub gpu_type_id: String,       // e.g. "NVIDIA RTX A6000"
486    pub name: String,              // e.g. "RunPod Pytorch"
487    pub image_name: String,        // e.g. "runpod/pytorch"
488    pub docker_args: String,       // e.g. ""
489    pub ports: String,             // e.g. "8888/http"
490    pub volume_mount_path: String, // e.g. "/workspace"
491    pub env: Vec<EnvVar>,
492}
493
494// ---------------------------------------------------------------------
495// 3) 4) Start Pod response
496// ---------------------------------------------------------------------
497#[derive(Debug, Serialize, Deserialize, Default)]
498pub struct PodStartResponse {
499    #[serde(rename = "podResume", default)]
500    pub pod_resume: Option<PodInfoMinimal>,
501
502    #[serde(rename = "podBidResume", default)]
503    pub pod_bid_resume: Option<PodInfoMinimal>,
504}
505
506#[derive(Debug, Serialize, Deserialize, Default)]
507pub struct PodStartResponseData {
508    pub data: Option<PodInfoMinimal>,
509    pub errors: Option<Vec<GraphQLError>>,
510}
511
512// ---------------------------------------------------------------------
513// 5) Stop Pod
514// ---------------------------------------------------------------------
515#[derive(Debug, Serialize, Deserialize, Default)]
516pub struct PodStopResponse {
517    #[serde(rename = "podStop")]
518    pub pod_stop: PodInfoMinimalStop,
519}
520
521#[derive(Debug, Serialize, Deserialize, Default)]
522pub struct PodStopResponseData {
523    pub data: Option<PodInfoMinimalStop>,
524    pub errors: Option<Vec<GraphQLError>>,
525}
526
527// ---------------------------------------------------------------------
528// 6) List Pods
529// ---------------------------------------------------------------------
530#[derive(Debug, Serialize, Deserialize, Default)]
531pub struct PodsListResponse {
532    pub myself: MyselfPods,
533}
534
535#[derive(Debug, Serialize, Deserialize, Default)]
536pub struct MyselfPods {
537    pub pods: Vec<PodInfoFull>,
538}
539
540#[derive(Debug, Serialize, Deserialize, Default)]
541pub struct PodsListResponseData {
542    pub data: Option<MyselfPods>,
543    pub errors: Option<Vec<GraphQLError>>,
544}
545
546// ---------------------------------------------------------------------
547// 7) Get Pod by ID
548// ---------------------------------------------------------------------
549#[derive(Debug, Serialize, Deserialize, Default)]
550pub struct PodInfoResponse {
551    pub pod: PodInfoFull,
552}
553
554#[derive(Debug, Serialize, Deserialize, Default)]
555pub struct PodInfoResponseData {
556    pub data: Option<PodInfoFull>,
557    pub errors: Option<Vec<GraphQLError>>,
558}
559
560// ---------------------------------------------------------------------
561// 8) List GPU Types
562// ---------------------------------------------------------------------
563#[derive(Debug, Serialize, Deserialize, Default)]
564pub struct GPUTypesListResponse {
565    #[serde(rename = "gpuTypes")]
566    pub gpu_types: Vec<GpuTypeMinimal>,
567}
568
569#[derive(Debug, Serialize, Deserialize, Default)]
570pub struct GPUTypesListResponseData {
571    pub data: Option<Vec<GpuTypeMinimal>>,
572    pub errors: Option<Vec<GraphQLError>>,
573}
574
575// ---------------------------------------------------------------------
576// 9) Get GPU Type by ID
577// ---------------------------------------------------------------------
578#[derive(Debug, Serialize, Deserialize, Default)]
579pub struct GPUTypesExtendedResponse {
580    #[serde(rename = "gpuTypes")]
581    pub gpu_types: Vec<GpuTypeExtended>,
582}
583
584#[derive(Debug, Serialize, Deserialize, Default)]
585pub struct GPUTypeResponseData {
586    pub data: Option<Vec<GpuTypeExtended>>,
587    pub errors: Option<Vec<GraphQLError>>,
588}
589
590// ---------------------------------------------------------------------
591// Common Data Structures
592// ---------------------------------------------------------------------
593#[derive(Debug, Serialize, Deserialize, Default)]
594pub struct PodInfoMinimal {
595    pub id: String,
596    #[serde(rename = "imageName")]
597    pub image_name: String,
598    pub env: Vec<EnvVar>,
599    #[serde(rename = "machineId")]
600    pub machine_id: String,
601    pub machine: MachineHost,
602}
603
604#[derive(Debug, Serialize, Deserialize, Default)]
605pub struct PodInfoMinimalStop {
606    pub id: String,
607    #[serde(rename = "desiredStatus")]
608    pub desired_status: Option<String>,
609}
610
611#[derive(Debug, Serialize, Deserialize, Default)]
612pub struct EnvVar {
613    pub key: String,
614    pub value: String,
615}
616
617#[derive(Debug, Serialize, Deserialize, Default)]
618pub struct MachineHost {
619    #[serde(rename = "podHostId")]
620    pub pod_host_id: Option<String>,
621}
622
623#[derive(Debug, Serialize, Deserialize, Default)]
624pub struct PodInfoFull {
625    pub id: String,
626    pub name: String,
627    pub runtime: Option<PodRuntime>,
628}
629
630#[derive(Debug, Serialize, Deserialize, Default)]
631pub struct PodRuntime {
632    #[serde(rename = "uptimeInSeconds")]
633    pub uptime_in_seconds: Option<i64>,
634    pub ports: Option<Vec<PortInfo>>,
635    pub gpus: Option<Vec<GpuInfo>>,
636    pub container: Option<ContainerInfo>,
637}
638
639#[derive(Debug, Serialize, Deserialize, Default)]
640pub struct PortInfo {
641    pub ip: Option<String>,
642    #[serde(rename = "isIpPublic")]
643    pub is_ip_public: Option<bool>,
644    #[serde(rename = "privatePort")]
645    pub private_port: Option<i32>,
646    #[serde(rename = "publicPort")]
647    pub public_port: Option<i32>,
648    #[serde(rename = "type")]
649    pub port_type: Option<String>,
650}
651
652#[derive(Debug, Serialize, Deserialize, Default)]
653pub struct GpuInfo {
654    pub id: Option<String>,
655    #[serde(rename = "gpuUtilPercent")]
656    pub gpu_util_percent: Option<f64>,
657    #[serde(rename = "memoryUtilPercent")]
658    pub memory_util_percent: Option<f64>,
659}
660
661#[derive(Debug, Serialize, Deserialize, Default)]
662pub struct ContainerInfo {
663    #[serde(rename = "cpuPercent")]
664    pub cpu_percent: Option<f64>,
665    #[serde(rename = "memoryPercent")]
666    pub memory_percent: Option<f64>,
667}
668
669#[derive(Debug, Serialize, Deserialize, Default)]
670pub struct GpuTypeMinimal {
671    pub id: String,
672    #[serde(rename = "displayName")]
673    pub display_name: String,
674    #[serde(rename = "memoryInGb")]
675    pub memory_in_gb: Option<i32>,
676}
677
678#[derive(Debug, Serialize, Deserialize, Default)]
679pub struct GpuTypeExtended {
680    pub id: String,
681    #[serde(rename = "displayName")]
682    pub display_name: String,
683    #[serde(rename = "memoryInGb")]
684    pub memory_in_gb: Option<i32>,
685    #[serde(rename = "secureCloud")]
686    pub secure_cloud: Option<bool>,
687    #[serde(rename = "communityCloud")]
688    pub community_cloud: Option<bool>,
689    pub lowest_price: Option<LowestPrice>,
690}
691
692#[derive(Debug, Serialize, Deserialize, Default)]
693pub struct LowestPrice {
694    #[serde(rename = "minimumBidPrice")]
695    pub minimum_bid_price: Option<f64>,
696    #[serde(rename = "uninterruptablePrice")]
697    pub uninterruptable_price: Option<f64>,
698}