1use crate::error::MoldError;
4use anyhow::Result;
5use reqwest::{Client, StatusCode};
6use serde::{Deserialize, Serialize};
7use std::time::Duration;
8
/// Base URL used for the Lambda API when the settings carry no endpoint.
pub const DEFAULT_ENDPOINT: &str = "https://cloud.lambda.ai/api/v1";
/// Name of the environment variable that is consulted first when resolving the API key.
pub const API_KEY_ENV: &str = "LAMBDA_API_KEY";
/// Container image repository used when the settings carry no override.
pub const DEFAULT_IMAGE_REPOSITORY: &str = "ghcr.io/utensils/mold";
12
/// User-facing configuration for the Lambda integration.
///
/// Every field has a serde default, so an empty document deserializes
/// successfully. Optional fields that are `None` are omitted when serializing.
///
/// NOTE(review): the derived `Debug` prints `api_key` verbatim; prefer
/// `redacted_debug` when the value may end up in logs.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LambdaSettings {
    /// API key stored in configuration. `resolved_api_key` consults the
    /// `LAMBDA_API_KEY` environment variable before this value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// API base URL; defaults to `Some(DEFAULT_ENDPOINT)` when absent from the input.
    #[serde(
        default = "default_endpoint_opt",
        skip_serializing_if = "Option::is_none"
    )]
    pub endpoint: Option<String>,
    /// Container image repository; defaults to `Some(DEFAULT_IMAGE_REPOSITORY)`
    /// when absent from the input.
    #[serde(
        default = "default_image_repository_opt",
        skip_serializing_if = "Option::is_none"
    )]
    pub image_repository: Option<String>,
    /// Presumably the name of the SSH key to attach to launched instances —
    /// the consumer is outside this file's visible code.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ssh_key_name: Option<String>,
    /// Presumably the local path of the matching private key — confirm with the caller.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ssh_private_key_path: Option<String>,
    /// Prefix used by `filesystem_name`; that function falls back to `"mold"` when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub filesystem_prefix: Option<String>,
    /// Filesystem mount path; defaults to `/data/mold`.
    #[serde(default = "default_filesystem_mount_path")]
    pub filesystem_mount_path: String,
    /// Hourly USD amount, default `5.0`. Presumably the price above which a
    /// launch needs confirmation — verify against the caller.
    #[serde(default = "default_confirm_hourly_usd")]
    pub confirm_hourly_usd: f64,
    /// Local port, default `7680`. Its consumer is outside this file's visible code.
    #[serde(default = "default_local_port")]
    pub local_port: u16,
}
40
41impl Default for LambdaSettings {
42 fn default() -> Self {
43 Self {
44 api_key: None,
45 endpoint: default_endpoint_opt(),
46 image_repository: default_image_repository_opt(),
47 ssh_key_name: None,
48 ssh_private_key_path: None,
49 filesystem_prefix: None,
50 filesystem_mount_path: default_filesystem_mount_path(),
51 confirm_hourly_usd: default_confirm_hourly_usd(),
52 local_port: default_local_port(),
53 }
54 }
55}
56
57fn default_endpoint_opt() -> Option<String> {
58 Some(DEFAULT_ENDPOINT.to_string())
59}
60
61fn default_image_repository_opt() -> Option<String> {
62 Some(DEFAULT_IMAGE_REPOSITORY.to_string())
63}
64
/// Serde default for `LambdaSettings::filesystem_mount_path`.
fn default_filesystem_mount_path() -> String {
    String::from("/data/mold")
}
68
/// Serde default for `LambdaSettings::confirm_hourly_usd`.
fn default_confirm_hourly_usd() -> f64 {
    const DEFAULT_CONFIRM_HOURLY_USD: f64 = 5.0;
    DEFAULT_CONFIRM_HOURLY_USD
}
72
/// Serde default for `LambdaSettings::local_port`.
fn default_local_port() -> u16 {
    const DEFAULT_LOCAL_PORT: u16 = 7680;
    DEFAULT_LOCAL_PORT
}
76
77impl LambdaSettings {
78 pub fn resolved_api_key(&self) -> Option<String> {
79 std::env::var(API_KEY_ENV)
80 .ok()
81 .filter(|s| !s.is_empty())
82 .or_else(|| self.api_key.clone())
83 }
84
85 pub fn endpoint(&self) -> &str {
86 self.endpoint.as_deref().unwrap_or(DEFAULT_ENDPOINT)
87 }
88
89 pub fn image_repository(&self) -> &str {
90 self.image_repository
91 .as_deref()
92 .unwrap_or(DEFAULT_IMAGE_REPOSITORY)
93 }
94
95 pub fn redacted_debug(&self) -> String {
96 format!(
97 "LambdaSettings {{ api_key: {}, endpoint: {:?}, image_repository: {:?}, \
98 ssh_key_name: {:?}, ssh_private_key_path: {:?}, filesystem_prefix: {:?}, \
99 filesystem_mount_path: {:?}, confirm_hourly_usd: {}, local_port: {} }}",
100 if self.api_key.is_some() {
101 "Some(\"<redacted>\")"
102 } else {
103 "None"
104 },
105 self.endpoint,
106 self.image_repository,
107 self.ssh_key_name,
108 self.ssh_private_key_path,
109 self.filesystem_prefix,
110 self.filesystem_mount_path,
111 self.confirm_hourly_usd,
112 self.local_port,
113 )
114 }
115}
116
/// Response envelope whose `data` member is a list of `T`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ApiList<T> {
    /// Listed items; a missing `data` member deserializes to an empty vector.
    #[serde(default)]
    pub data: Vec<T>,
}
122
/// Response envelope whose `data` member is a single `T`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ApiItem<T> {
    /// The wrapped value; required, so a response without `data` fails to decode.
    pub data: T,
}
127
/// A region as it appears in API payloads.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
pub struct Region {
    /// Region identifier (the tests in this file use values such as `us-west-1`).
    pub name: String,
    /// Free-form description; empty when the payload omits it.
    #[serde(default)]
    pub description: String,
}
134
/// Hardware figures nested under an instance type's `specs` key.
///
/// Every field falls back to its type's default (zero / empty string) when
/// the payload omits it.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
pub struct InstanceTypeSpecs {
    /// Number of GPUs.
    #[serde(default)]
    pub gpus: u32,
    /// GPU description; `decode_instance_types_body` backfills this from the
    /// parent `InstanceType::gpu_description` when it is empty.
    #[serde(default)]
    pub gpu_description: String,
    /// Memory size (GiB, going by the field name).
    #[serde(default)]
    pub memory_gib: u32,
    /// Storage size (GiB, going by the field name).
    #[serde(default)]
    pub storage_gib: u32,
    /// Number of virtual CPUs.
    #[serde(default)]
    pub vcpus: u32,
}
148
/// An instance type offered by the API.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
pub struct InstanceType {
    /// Instance type identifier (e.g. `gpu_1x_a10` in this file's tests).
    pub name: String,
    /// Free-form description; empty when omitted.
    #[serde(default)]
    pub description: String,
    /// GPU description at the top level of the payload; empty when omitted.
    #[serde(default)]
    pub gpu_description: String,
    /// Hourly price in cents; `AvailabilityRow` divides it by 100 to get USD.
    #[serde(default)]
    pub price_cents_per_hour: u32,
    /// Hardware figures; defaults to an all-zero/empty value when omitted.
    #[serde(default)]
    pub specs: InstanceTypeSpecs,
    /// Regions with capacity. `decode_instance_types_body` overwrites this
    /// with the list carried by the surrounding offering object.
    #[serde(default)]
    pub regions_with_capacity_available: Vec<Region>,
}
163
/// An SSH key record returned by the API; all three fields are required.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
pub struct SshKey {
    /// Identifier of the key record.
    pub id: String,
    /// Name of the key.
    pub name: String,
    /// Public key material.
    pub public_key: String,
}
170
/// A filesystem record returned by the API.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
pub struct Filesystem {
    /// Identifier of the filesystem.
    pub id: String,
    /// Name of the filesystem.
    pub name: String,
    /// Mount point; empty when the payload omits it.
    #[serde(default)]
    pub mount_point: String,
    /// Region the filesystem belongs to, when the payload provides one.
    #[serde(default)]
    pub region: Option<Region>,
    /// Bytes in use, when the payload provides the figure.
    #[serde(default)]
    pub bytes_used: Option<u64>,
}
182
/// A key/value tag attached to an instance.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
pub struct Tag {
    /// Tag key.
    pub key: String,
    /// Tag value.
    pub value: String,
}
188
/// An instance record returned by the API.
///
/// Only `id` is required; every other field falls back to its default when
/// the payload omits it.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
pub struct Instance {
    /// Identifier of the instance.
    pub id: String,
    /// Instance name, if any.
    #[serde(default)]
    pub name: Option<String>,
    /// Status string as reported by the API; empty when omitted.
    #[serde(default)]
    pub status: String,
    /// IP address, if reported.
    #[serde(default)]
    pub ip: Option<String>,
    /// Private IP address, if reported.
    #[serde(default)]
    pub private_ip: Option<String>,
    /// Instance type details, if reported.
    #[serde(default)]
    pub instance_type: Option<InstanceType>,
    /// Region details, if reported.
    #[serde(default)]
    pub region: Option<Region>,
    /// Names of the SSH keys associated with the instance.
    #[serde(default)]
    pub ssh_key_names: Vec<String>,
    /// Names of the filesystems associated with the instance.
    #[serde(default)]
    pub file_system_names: Vec<String>,
    /// Tags attached to the instance.
    #[serde(default)]
    pub tags: Vec<Tag>,
}
211
/// Request body sent by `LambdaClient::create_ssh_key`.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct CreateSshKeyRequest {
    /// Name for the new key.
    pub name: String,
    /// Public key material to register.
    pub public_key: String,
}
217
/// Request body sent by `LambdaClient::create_filesystem`.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct CreateFilesystemRequest {
    /// Name for the new filesystem.
    pub name: String,
    /// Target region. Serialized as `region` (not `region_name`); a test in
    /// this file pins that key.
    pub region: String,
}
223
/// Request body sent by `LambdaClient::launch_instance`.
///
/// Empty vectors and `None` options marked with `skip_serializing_if` are
/// left out of the serialized body.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct LaunchInstancesRequest {
    /// Region to launch in.
    pub region_name: String,
    /// Instance type to launch.
    pub instance_type_name: String,
    /// Names of SSH keys to attach; always serialized, even when empty.
    pub ssh_key_names: Vec<String>,
    /// Names of filesystems to attach; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub file_system_names: Vec<String>,
    /// Mount specifications for attached filesystems; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub file_system_mounts: Vec<FilesystemMount>,
    /// Optional hostname; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hostname: Option<String>,
    /// Instance name; always serialized.
    pub name: String,
    /// Optional image selection; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image: Option<LaunchImage>,
    /// Optional user data (`build_launch_request` passes a cloud-init
    /// document here); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_data: Option<String>,
    /// Tags to attach; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tags: Vec<Tag>,
}
243
/// Payload returned by the launch operation.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
pub struct InstanceLaunchResponse {
    /// Identifiers of the launched instances; empty when the payload omits them.
    #[serde(default)]
    pub instance_ids: Vec<String>,
}
249
/// Wire shape of the instance-types response: `data` is a map rather than a list.
#[derive(Debug, Clone, Deserialize)]
struct InstanceTypesResponse {
    /// Offerings keyed by string (the instance type name in this file's test
    /// fixture). A `BTreeMap` iterates in key order, so decoding yields a
    /// deterministic ordering.
    #[serde(default)]
    data: std::collections::BTreeMap<String, InstanceTypeOffering>,
}
255
/// One entry of the instance-types map: the type plus where it has capacity.
#[derive(Debug, Clone, Deserialize)]
struct InstanceTypeOffering {
    /// The instance type description.
    instance_type: InstanceType,
    /// Regions with capacity for this type; empty when omitted.
    #[serde(default)]
    regions_with_capacity_available: Vec<Region>,
}
262
/// Mount specification for one filesystem in a launch request.
///
/// `build_launch_request` fills exactly one of `file_system_name` /
/// `file_system_id`; the unset one is omitted from the serialized body.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FilesystemMount {
    /// Path at which the filesystem is mounted.
    pub mount_point: String,
    /// Filesystem referenced by name; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_system_name: Option<String>,
    /// Filesystem referenced by id; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_system_id: Option<String>,
}
271
/// Image selection for a launch request, identified by id.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LaunchImage {
    /// Identifier of the image to boot.
    pub id: String,
}
276
/// Borrowed inputs consumed by `build_launch_request`.
pub struct LaunchRequestInput<'a> {
    /// Region to launch in.
    pub region_name: &'a str,
    /// Instance type to launch.
    pub instance_type_name: &'a str,
    /// Name of the single SSH key to attach.
    pub ssh_key_name: &'a str,
    /// Name of the filesystem to attach.
    pub filesystem_name: &'a str,
    /// Filesystem id; when present the mount references the id instead of the name.
    pub filesystem_id: Option<&'a str>,
    /// Mount point for the filesystem.
    pub filesystem_mount_path: &'a str,
    /// Name given to the instance.
    pub instance_name: &'a str,
    /// Optional image id.
    pub image_id: Option<&'a str>,
    /// User data passed through verbatim.
    pub user_data: &'a str,
}
288
289pub fn build_launch_request(input: LaunchRequestInput<'_>) -> LaunchInstancesRequest {
290 LaunchInstancesRequest {
291 region_name: input.region_name.to_string(),
292 instance_type_name: input.instance_type_name.to_string(),
293 ssh_key_names: vec![input.ssh_key_name.to_string()],
294 file_system_names: vec![input.filesystem_name.to_string()],
295 file_system_mounts: vec![FilesystemMount {
296 mount_point: input.filesystem_mount_path.to_string(),
297 file_system_name: input
298 .filesystem_id
299 .is_none()
300 .then(|| input.filesystem_name.to_string()),
301 file_system_id: input.filesystem_id.map(str::to_string),
302 }],
303 hostname: None,
304 name: input.instance_name.to_string(),
305 image: input.image_id.map(|id| LaunchImage { id: id.to_string() }),
306 user_data: Some(input.user_data.to_string()),
307 tags: vec![Tag {
308 key: "managed-by".to_string(),
309 value: "mold".to_string(),
310 }],
311 }
312}
313
/// Flattened, display-oriented view of one instance type.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct AvailabilityRow {
    /// Instance type name.
    pub instance_type: String,
    /// Region name; `from_instance_type` uses the first region with capacity,
    /// or an empty string when there is none.
    pub region: String,
    /// GPU description taken from the instance type's specs.
    pub gpu_description: String,
    /// Number of GPUs.
    pub gpu_count: u32,
    /// Generation slots; `from_instance_type` sets this equal to the GPU count.
    pub generation_slots: u32,
    /// Hourly price in USD (cents divided by 100).
    pub price_per_hour_usd: f64,
    /// Memory size from the specs.
    pub memory_gib: u32,
    /// Storage size from the specs.
    pub storage_gib: u32,
    /// Either `<repository>:<tag>` or the literal `unsupported: linux/arm64 host`.
    pub image: String,
}
326
327impl AvailabilityRow {
328 pub fn from_instance_type(
329 instance_type: &InstanceType,
330 image_repository: &str,
331 version: &str,
332 ) -> Self {
333 let region = instance_type
334 .regions_with_capacity_available
335 .first()
336 .map(|r| r.name.clone())
337 .unwrap_or_default();
338 let image = if gpu_uses_unsupported_linux_arm64(&instance_type.specs.gpu_description)
339 || gpu_uses_unsupported_linux_arm64(&instance_type.name)
340 {
341 "unsupported: linux/arm64 host".to_string()
342 } else {
343 let tag = image_tag_for_gpu(&instance_type.specs.gpu_description, version);
344 format!("{image_repository}:{tag}")
345 };
346 Self {
347 instance_type: instance_type.name.clone(),
348 region,
349 gpu_description: instance_type.specs.gpu_description.clone(),
350 gpu_count: instance_type.specs.gpus,
351 generation_slots: instance_type.specs.gpus,
352 price_per_hour_usd: instance_type.price_cents_per_hour as f64 / 100.0,
353 memory_gib: instance_type.specs.memory_gib,
354 storage_gib: instance_type.specs.storage_gib,
355 image,
356 }
357 }
358}
359
/// Picks the image tag for a GPU description.
///
/// The description is lowercased (ASCII only) and checked against marker
/// substrings group by group; the first group with a match decides the tag.
/// Descriptions that match nothing get `latest`. `_version` is accepted but
/// not used.
pub fn image_tag_for_gpu(gpu_description: &str, _version: &str) -> String {
    // Order matters: earlier groups win when a description matches several.
    const TAG_MARKERS: [(&str, &[&str]); 3] = [
        ("latest-sm80", &["a100", "a10", "a40", "rtx 30", "3090"]),
        ("latest-sm90", &["h100", "h200", "gh"]),
        ("latest-sm120", &["b200", "5090", "blackwell"]),
    ];

    let haystack = gpu_description.to_ascii_lowercase();
    TAG_MARKERS
        .iter()
        .find(|(_, markers)| markers.iter().any(|marker| haystack.contains(*marker)))
        .map_or("latest", |(tag, _)| *tag)
        .to_string()
}
377
/// Reports whether the text contains `gh200`, ignoring ASCII case.
///
/// Callers pass either a GPU description or an instance type name.
pub fn gpu_uses_unsupported_linux_arm64(gpu_description: &str) -> bool {
    const MARKER: &[u8] = b"gh200";
    // The marker is pure ASCII, so a byte-window scan matches exactly where a
    // lowercase-then-`contains` search would.
    gpu_description
        .as_bytes()
        .windows(MARKER.len())
        .any(|window| window.eq_ignore_ascii_case(MARKER))
}
381
382pub fn filesystem_name(settings: &LambdaSettings, region: &str) -> String {
383 let prefix = settings.filesystem_prefix.as_deref().unwrap_or("mold");
384 format!("{prefix}-{region}")
385}
386
/// Values substituted into the cloud-init template by `render_cloud_init`.
#[derive(Debug, Clone)]
pub struct CloudInitOptions {
    /// Container image reference used for `docker pull` and `docker run`.
    pub image: String,
    /// Host path bind-mounted to `/workspace` inside the container.
    pub mount_path: String,
    /// Path of the env file passed to `docker run --env-file`; the template
    /// creates it and restricts it to mode 600.
    pub env_file: String,
}
393
394pub fn render_cloud_init(opts: &CloudInitOptions) -> String {
395 format!(
396 r#"#cloud-config
397write_files:
398 - path: /etc/systemd/system/mold-lambda.service
399 permissions: '0644'
400 content: |
401 [Unit]
402 Description=mold Lambda container
403 After=docker.service network-online.target
404 Wants=network-online.target
405
406 [Service]
407 Restart=always
408 RestartSec=10
409 ExecStartPre=-/usr/bin/docker rm -f mold
410 ExecStartPre=/usr/bin/docker pull {image}
411 ExecStart=/usr/bin/docker run --name mold --gpus all --restart unless-stopped --env-file {env_file} -e MOLD_PORT=7680 -p 127.0.0.1:7680:7680 -v {mount_path}:/workspace {image}
412 ExecStop=/usr/bin/docker stop mold
413
414 [Install]
415 WantedBy=multi-user.target
416runcmd:
417 - [ mkdir, -p, /etc/mold ]
418 - [ sh, -c, "touch {env_file} && chmod 600 {env_file}" ]
419 - [ systemctl, daemon-reload ]
420 - [ systemctl, enable, --now, mold-lambda.service ]
421"#,
422 image = opts.image,
423 mount_path = opts.mount_path,
424 env_file = opts.env_file,
425 )
426}
427
/// HTTP client for the Lambda API.
///
/// `Debug` is not derived, so the API key cannot be printed through a debug dump.
#[derive(Clone)]
pub struct LambdaClient {
    /// Underlying reqwest client.
    client: Client,
    /// Base URL with trailing slashes removed; request paths are appended to it.
    endpoint: String,
    /// API key, sent as the basic-auth username with an empty password.
    api_key: String,
}
434
435impl LambdaClient {
436 pub fn from_settings(settings: &LambdaSettings) -> Result<Self> {
437 let api_key = settings.resolved_api_key().ok_or_else(|| {
438 MoldError::Config("missing Lambda API key; set LAMBDA_API_KEY or lambda.api_key".into())
439 })?;
440 Ok(Self {
441 client: Client::builder().timeout(Duration::from_secs(60)).build()?,
442 endpoint: settings.endpoint().trim_end_matches('/').to_string(),
443 api_key,
444 })
445 }
446
447 pub fn new(endpoint: impl Into<String>, api_key: impl Into<String>) -> Self {
448 Self {
449 client: Client::new(),
450 endpoint: endpoint.into().trim_end_matches('/').to_string(),
451 api_key: api_key.into(),
452 }
453 }
454
455 async fn get_list<T: for<'de> Deserialize<'de> + Default>(&self, path: &str) -> Result<Vec<T>> {
456 let resp = self
457 .client
458 .get(format!("{}{}", self.endpoint, path))
459 .basic_auth(&self.api_key, Some(""))
460 .send()
461 .await?;
462 decode_list(resp).await
463 }
464
465 async fn post_item<B: Serialize, T: for<'de> Deserialize<'de>>(
466 &self,
467 path: &str,
468 body: &B,
469 ) -> Result<T> {
470 let resp = self
471 .client
472 .post(format!("{}{}", self.endpoint, path))
473 .basic_auth(&self.api_key, Some(""))
474 .json(body)
475 .send()
476 .await?;
477 decode_item(resp).await
478 }
479
480 pub async fn list_instance_types(&self) -> Result<Vec<InstanceType>> {
481 let resp = self
482 .client
483 .get(format!("{}/instance-types", self.endpoint))
484 .basic_auth(&self.api_key, Some(""))
485 .send()
486 .await?;
487 if !resp.status().is_success() {
488 return Err(lambda_error(resp).await.into());
489 }
490 decode_instance_types_body(&resp.text().await?)
491 }
492
493 pub async fn list_instances(&self) -> Result<Vec<Instance>> {
494 self.get_list("/instances").await
495 }
496
497 pub async fn get_instance(&self, id: &str) -> Result<Instance> {
498 let resp = self
499 .client
500 .get(format!("{}/instances/{id}", self.endpoint))
501 .basic_auth(&self.api_key, Some(""))
502 .send()
503 .await?;
504 decode_item(resp).await
505 }
506
507 pub async fn launch_instance(
508 &self,
509 req: &LaunchInstancesRequest,
510 ) -> Result<InstanceLaunchResponse> {
511 self.post_item("/instance-operations/launch", req).await
512 }
513
514 pub async fn terminate_instance(&self, id: &str) -> Result<()> {
515 let body = serde_json::json!({ "instance_ids": [id] });
516 let resp = self
517 .client
518 .post(format!("{}/instance-operations/terminate", self.endpoint))
519 .basic_auth(&self.api_key, Some(""))
520 .json(&body)
521 .send()
522 .await?;
523 ensure_success(resp).await
524 }
525
526 pub async fn list_ssh_keys(&self) -> Result<Vec<SshKey>> {
527 self.get_list("/ssh-keys").await
528 }
529
530 pub async fn create_ssh_key(&self, req: &CreateSshKeyRequest) -> Result<SshKey> {
531 self.post_item("/ssh-keys", req).await
532 }
533
534 pub async fn list_filesystems(&self) -> Result<Vec<Filesystem>> {
535 self.get_list("/file-systems").await
536 }
537
538 pub async fn create_filesystem(&self, req: &CreateFilesystemRequest) -> Result<Filesystem> {
539 self.post_item("/filesystems", req).await
540 }
541
542 pub async fn delete_filesystem(&self, id: &str) -> Result<()> {
543 let resp = self
544 .client
545 .delete(format!("{}/filesystems/{id}", self.endpoint))
546 .basic_auth(&self.api_key, Some(""))
547 .send()
548 .await?;
549 ensure_success(resp).await
550 }
551}
552
553async fn decode_list<T: for<'de> Deserialize<'de> + Default>(
554 resp: reqwest::Response,
555) -> Result<Vec<T>> {
556 if !resp.status().is_success() {
557 return Err(lambda_error(resp).await.into());
558 }
559 Ok(resp.json::<ApiList<T>>().await?.data)
560}
561
562async fn decode_item<T: for<'de> Deserialize<'de>>(resp: reqwest::Response) -> Result<T> {
563 if !resp.status().is_success() {
564 return Err(lambda_error(resp).await.into());
565 }
566 Ok(resp.json::<ApiItem<T>>().await?.data)
567}
568
569async fn ensure_success(resp: reqwest::Response) -> Result<()> {
570 if !resp.status().is_success() {
571 return Err(lambda_error(resp).await.into());
572 }
573 Ok(())
574}
575
576async fn lambda_error(resp: reqwest::Response) -> MoldError {
577 let status = resp.status();
578 let body = resp.text().await.unwrap_or_default();
579 let message = if status == StatusCode::UNAUTHORIZED {
580 "Lambda API authentication failed".to_string()
581 } else {
582 format!(
583 "Lambda API request failed with {status}: {}",
584 truncate(&body)
585 )
586 };
587 MoldError::Config(message)
588}
589
/// Limits `s` to 400 characters, appending `…` when anything was cut.
///
/// The limit counts `char`s, not bytes, so multi-byte text is never split
/// inside a character.
fn truncate(s: &str) -> String {
    const MAX: usize = 400;
    // The byte offset of character number MAX + 1 (if it exists) is exactly
    // where the first MAX characters end.
    match s.char_indices().nth(MAX) {
        None => s.to_owned(),
        Some((cut, _)) => {
            let mut clipped = String::with_capacity(cut + '…'.len_utf8());
            clipped.push_str(&s[..cut]);
            clipped.push('…');
            clipped
        }
    }
}
599
600pub fn decode_instance_types_body(body: &str) -> Result<Vec<InstanceType>> {
601 let response: InstanceTypesResponse = serde_json::from_str(body)?;
602 Ok(response
603 .data
604 .into_values()
605 .map(|offering| {
606 let mut instance_type = offering.instance_type;
607 if instance_type.specs.gpu_description.is_empty() {
608 instance_type.specs.gpu_description = instance_type.gpu_description.clone();
609 }
610 instance_type.regions_with_capacity_available =
611 offering.regions_with_capacity_available;
612 instance_type
613 })
614 .collect())
615}
616
// Unit tests for the pure helpers in this module; none of them touch the network.
#[cfg(test)]
mod tests {
    use super::*;

    // The instance-types payload is a map keyed by type name; decoding must
    // flatten it and backfill `specs.gpu_description` from the top level.
    #[test]
    fn instance_types_decode_lambda_map_shape() {
        let body = r#"{
 "data": {
 "gpu_1x_a10": {
 "instance_type": {
 "name": "gpu_1x_a10",
 "description": "1x A10",
 "gpu_description": "A10",
 "price_cents_per_hour": 75,
 "specs": {
 "vcpus": 30,
 "memory_gib": 200,
 "storage_gib": 1400,
 "gpus": 1
 }
 },
 "regions_with_capacity_available": [
 {"name": "us-west-1", "description": "California"}
 ]
 }
 }
 }"#;

        let decoded = decode_instance_types_body(body).unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].name, "gpu_1x_a10");
        assert_eq!(decoded[0].specs.gpu_description, "A10");
        assert_eq!(
            decoded[0].regions_with_capacity_available[0].name,
            "us-west-1"
        );
    }

    // An empty TOML document must yield the documented defaults, and a fully
    // populated value must survive a serialize/deserialize round trip.
    #[test]
    fn lambda_settings_toml_roundtrip_and_defaults() {
        let settings: LambdaSettings = toml::from_str("").unwrap();
        assert_eq!(settings.endpoint.as_deref(), Some(DEFAULT_ENDPOINT));
        assert_eq!(
            settings.image_repository.as_deref(),
            Some(DEFAULT_IMAGE_REPOSITORY)
        );
        assert_eq!(settings.filesystem_mount_path, "/data/mold");
        assert_eq!(settings.confirm_hourly_usd, 5.0);
        assert_eq!(settings.local_port, 7680);

        let original = LambdaSettings {
            api_key: Some("secret".into()),
            endpoint: Some("http://localhost:9999".into()),
            image_repository: Some("ghcr.io/example/mold".into()),
            ssh_key_name: Some("mold-key".into()),
            ssh_private_key_path: Some("~/.ssh/mold_lambda_ed25519".into()),
            filesystem_prefix: Some("mold".into()),
            filesystem_mount_path: "/mnt/mold".into(),
            confirm_hourly_usd: 9.5,
            local_port: 7777,
        };
        let encoded = toml::to_string(&original).unwrap();
        let decoded: LambdaSettings = toml::from_str(&encoded).unwrap();
        assert_eq!(decoded.api_key, original.api_key);
        assert_eq!(decoded.filesystem_mount_path, "/mnt/mold");
        assert_eq!(decoded.local_port, 7777);
    }

    // The environment variable wins over the configured key. The shared lock
    // serializes tests that mutate process-wide environment state.
    // NOTE(review): the variable is only removed after the assertion, so a
    // failing assertion leaves it set for later tests.
    #[test]
    fn auth_prefers_lambda_api_key_env_over_config() {
        let _guard = crate::test_support::ENV_LOCK.lock().unwrap();
        std::env::set_var(API_KEY_ENV, "from-env");
        let settings = LambdaSettings {
            api_key: Some("from-config".into()),
            ..Default::default()
        };
        assert_eq!(settings.resolved_api_key().as_deref(), Some("from-env"));
        std::env::remove_var(API_KEY_ENV);
    }

    // Pins the GPU-description-to-tag mapping, including the `latest` fallback.
    #[test]
    fn image_tag_maps_gpu_generations() {
        assert_eq!(
            image_tag_for_gpu("NVIDIA A100-SXM4-80GB", "0.10.0"),
            "latest-sm80"
        );
        assert_eq!(image_tag_for_gpu("NVIDIA L40S", "0.10.0"), "latest");
        assert_eq!(
            image_tag_for_gpu("NVIDIA H100 PCIe", "0.10.0"),
            "latest-sm90"
        );
        assert_eq!(image_tag_for_gpu("NVIDIA B200", "0.10.0"), "latest-sm120");
    }

    // GH200 must be detected from either a GPU description or a type name.
    #[test]
    fn gh200_is_not_supported_by_published_linux_images() {
        assert!(gpu_uses_unsupported_linux_arm64("GH200 (96 GB)"));
        assert!(gpu_uses_unsupported_linux_arm64("gpu_1x_gh200"));
        assert!(!gpu_uses_unsupported_linux_arm64("NVIDIA H100 PCIe"));
        assert!(!gpu_uses_unsupported_linux_arm64("NVIDIA A100-SXM4-80GB"));
    }

    // A GH200 row carries the unsupported marker instead of an image reference.
    #[test]
    fn availability_marks_gh200_as_unsupported() {
        let ty = InstanceType {
            name: "gpu_1x_gh200".into(),
            description: "1x GH200".into(),
            gpu_description: "GH200 (96 GB)".into(),
            price_cents_per_hour: 229,
            specs: InstanceTypeSpecs {
                gpus: 1,
                gpu_description: "GH200 (96 GB)".into(),
                memory_gib: 432,
                storage_gib: 4096,
                ..Default::default()
            },
            regions_with_capacity_available: vec![Region {
                name: "us-east-3".into(),
                description: "Austin".into(),
            }],
        };
        let row = AvailabilityRow::from_instance_type(&ty, "ghcr.io/utensils/mold", "0.10.0");
        assert_eq!(row.image, "unsupported: linux/arm64 host");
    }

    // Generation slots mirror the GPU count; the price is converted from cents to USD.
    #[test]
    fn availability_row_uses_gpu_count_as_generation_slots() {
        let ty = InstanceType {
            name: "gpu_8x_h100".into(),
            description: "8x H100".into(),
            gpu_description: "NVIDIA H100".into(),
            price_cents_per_hour: 15920,
            specs: InstanceTypeSpecs {
                gpus: 8,
                gpu_description: "NVIDIA H100".into(),
                memory_gib: 1800,
                storage_gib: 200,
                ..Default::default()
            },
            regions_with_capacity_available: vec![Region {
                name: "us-east-1".into(),
                description: "Virginia".into(),
            }],
        };
        let row = AvailabilityRow::from_instance_type(&ty, "ghcr.io/utensils/mold", "0.10.0");
        assert_eq!(row.generation_slots, 8);
        assert_eq!(row.image, "ghcr.io/utensils/mold:latest-sm90");
        assert_eq!(row.price_per_hour_usd, 159.20);
    }

    // The name is `<prefix>-<region>`, with `mold` as the default prefix.
    #[test]
    fn filesystem_name_defaults_to_prefix_region() {
        let settings = LambdaSettings::default();
        assert_eq!(filesystem_name(&settings, "us-west-1"), "mold-us-west-1");
        let custom = LambdaSettings {
            filesystem_prefix: Some("team-mold".into()),
            ..Default::default()
        };
        assert_eq!(filesystem_name(&custom, "us-east-1"), "team-mold-us-east-1");
    }

    // Pins the JSON keys of the serialized launch request.
    #[test]
    fn launch_request_contains_expected_shape() {
        let req = build_launch_request(LaunchRequestInput {
            region_name: "us-west-1",
            instance_type_name: "gpu_1x_a10",
            ssh_key_name: "mold-laptop",
            filesystem_name: "mold-us-west-1",
            filesystem_id: None,
            filesystem_mount_path: "/data/mold",
            instance_name: "mold-us-west-1",
            image_id: None,
            user_data: "#cloud-config\n",
        });
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(json["region_name"], "us-west-1");
        assert_eq!(json["ssh_key_names"], serde_json::json!(["mold-laptop"]));
        assert_eq!(
            json["file_system_names"],
            serde_json::json!(["mold-us-west-1"])
        );
        assert_eq!(json["file_system_mounts"][0]["mount_point"], "/data/mold");
        assert_eq!(json["tags"][0]["key"], "managed-by");
        assert_eq!(json["tags"][0]["value"], "mold");
    }

    // With a filesystem id the mount references the id and drops the name.
    #[test]
    fn launch_request_uses_filesystem_id_when_available() {
        let req = build_launch_request(LaunchRequestInput {
            region_name: "us-west-1",
            instance_type_name: "gpu_1x_a10",
            ssh_key_name: "mold-laptop",
            filesystem_name: "mold-us-west-1",
            filesystem_id: Some("fs-123"),
            filesystem_mount_path: "/data/mold",
            instance_name: "mold-us-west-1",
            image_id: None,
            user_data: "#cloud-config\n",
        });
        let mount = &req.file_system_mounts[0];
        assert_eq!(mount.file_system_id.as_deref(), Some("fs-123"));
        assert!(mount.file_system_name.is_none());
    }

    // The create-filesystem body uses `region`, not `region_name`.
    #[test]
    fn create_filesystem_request_uses_lambda_region_field() {
        let req = CreateFilesystemRequest {
            name: "mold-us-east-1".into(),
            region: "us-east-1".into(),
        };
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(json["name"], "mold-us-east-1");
        assert_eq!(json["region"], "us-east-1");
        assert!(json.get("region_name").is_none());
    }

    // The rendered unit publishes the port on loopback only and embeds no token names.
    #[test]
    fn cloud_init_keeps_service_private_and_omits_secrets_by_default() {
        let rendered = render_cloud_init(&CloudInitOptions {
            image: "ghcr.io/utensils/mold:0.10.0-sm90".into(),
            mount_path: "/data/mold".into(),
            env_file: "/etc/mold/lambda.env".into(),
        });
        assert!(rendered.contains("-p 127.0.0.1:7680:7680"));
        assert!(rendered.contains("-v /data/mold:/workspace"));
        assert!(rendered.contains("--gpus all"));
        assert!(rendered.contains("ghcr.io/utensils/mold:0.10.0-sm90"));
        assert!(!rendered.contains("HF_TOKEN"));
        assert!(!rendered.contains("CIVITAI_TOKEN"));
    }
}