1pub mod aws;
2pub mod azure;
3pub mod config;
4mod digitalocean;
5pub mod gcp;
6mod hetzner;
7mod linode;
8pub mod oracle;
9mod proxmox;
10pub mod scaleway;
11pub mod sync;
12mod tailscale;
13mod upcloud;
14mod vultr;
15
16use std::sync::atomic::AtomicBool;
17
18use thiserror::Error;
19
/// One host as reported by a [`Provider`].
///
/// Instances are what `Provider::fetch_hosts_cancellable` (and the other
/// `fetch_hosts*` methods) return in their `Vec`.
#[derive(Debug, Clone)]
// NOTE(review): presumably some fields are not read outside tests in every
// build; confirm before removing this allow.
#[allow(dead_code)]
pub struct ProviderHost {
    /// Server identifier as a string (e.g. `"12345"`); presumably the
    /// provider's own ID for the machine. Confirm per provider.
    pub server_id: String,
    /// Host name as reported by the provider (e.g. `"web-01"`).
    pub name: String,
    /// Address of the host as a string (e.g. `"1.2.3.4"`).
    pub ip: String,
    /// Tags attached to the host.
    pub tags: Vec<String>,
    /// Extra `(String, String)` pairs (presumably key/value metadata).
    /// `ProviderHost::new` leaves this empty.
    pub metadata: Vec<(String, String)>,
}
35
36impl ProviderHost {
37 #[allow(dead_code)]
39 pub fn new(server_id: String, name: String, ip: String, tags: Vec<String>) -> Self {
40 Self {
41 server_id,
42 name,
43 ip,
44 tags,
45 metadata: Vec::new(),
46 }
47 }
48}
49
50#[derive(Debug, Error)]
52pub enum ProviderError {
53 #[error("HTTP error: {0}")]
54 Http(String),
55 #[error("Failed to parse response: {0}")]
56 Parse(String),
57 #[error("Authentication failed. Check your API token.")]
58 AuthFailed,
59 #[error("Rate limited. Try again in a moment.")]
60 RateLimited,
61 #[error("{0}")]
62 Execute(String),
63 #[error("Cancelled.")]
64 Cancelled,
65 #[error("Partial result: {failures} of {total} failed")]
68 PartialResult {
69 hosts: Vec<ProviderHost>,
70 failures: usize,
71 total: usize,
72 },
73}
74
75pub trait Provider {
77 fn name(&self) -> &str;
79 fn short_label(&self) -> &str;
81 fn fetch_hosts_cancellable(
83 &self,
84 token: &str,
85 cancel: &AtomicBool,
86 ) -> Result<Vec<ProviderHost>, ProviderError>;
87 #[allow(dead_code)]
89 fn fetch_hosts(&self, token: &str) -> Result<Vec<ProviderHost>, ProviderError> {
90 self.fetch_hosts_cancellable(token, &AtomicBool::new(false))
91 }
92 fn fetch_hosts_with_progress(
94 &self,
95 token: &str,
96 cancel: &AtomicBool,
97 _progress: &dyn Fn(&str),
98 ) -> Result<Vec<ProviderHost>, ProviderError> {
99 self.fetch_hosts_cancellable(token, cancel)
100 }
101}
102
/// Keys of every supported provider.
///
/// Each entry is accepted by `get_provider` and has a branded spelling in
/// `provider_display_name`. The order here is the order callers iterate in.
pub const PROVIDER_NAMES: &[&str] = &[
    // Token-authenticated VPS clouds.
    "digitalocean",
    "vultr",
    "linode",
    "hetzner",
    "upcloud",
    // Providers whose construction takes extra config (URL, regions, ...).
    "proxmox",
    "aws",
    "scaleway",
    "gcp",
    "azure",
    "tailscale",
    "oracle",
];
118
119pub fn get_provider(name: &str) -> Option<Box<dyn Provider>> {
121 match name {
122 "digitalocean" => Some(Box::new(digitalocean::DigitalOcean)),
123 "vultr" => Some(Box::new(vultr::Vultr)),
124 "linode" => Some(Box::new(linode::Linode)),
125 "hetzner" => Some(Box::new(hetzner::Hetzner)),
126 "upcloud" => Some(Box::new(upcloud::UpCloud)),
127 "proxmox" => Some(Box::new(proxmox::Proxmox {
128 base_url: String::new(),
129 verify_tls: true,
130 })),
131 "aws" => Some(Box::new(aws::Aws {
132 regions: Vec::new(),
133 profile: String::new(),
134 })),
135 "scaleway" => Some(Box::new(scaleway::Scaleway { zones: Vec::new() })),
136 "gcp" => Some(Box::new(gcp::Gcp {
137 zones: Vec::new(),
138 project: String::new(),
139 })),
140 "azure" => Some(Box::new(azure::Azure {
141 subscriptions: Vec::new(),
142 })),
143 "tailscale" => Some(Box::new(tailscale::Tailscale)),
144 "oracle" => Some(Box::new(oracle::Oracle {
145 regions: Vec::new(),
146 compartment: String::new(),
147 })),
148 _ => None,
149 }
150}
151
152pub fn get_provider_with_config(
156 name: &str,
157 section: &config::ProviderSection,
158) -> Option<Box<dyn Provider>> {
159 match name {
160 "proxmox" => Some(Box::new(proxmox::Proxmox {
161 base_url: section.url.clone(),
162 verify_tls: section.verify_tls,
163 })),
164 "aws" => Some(Box::new(aws::Aws {
165 regions: section
166 .regions
167 .split(',')
168 .map(|s| s.trim().to_string())
169 .filter(|s| !s.is_empty())
170 .collect(),
171 profile: section.profile.clone(),
172 })),
173 "scaleway" => Some(Box::new(scaleway::Scaleway {
174 zones: section
175 .regions
176 .split(',')
177 .map(|s| s.trim().to_string())
178 .filter(|s| !s.is_empty())
179 .collect(),
180 })),
181 "gcp" => Some(Box::new(gcp::Gcp {
182 zones: section
183 .regions
184 .split(',')
185 .map(|s| s.trim().to_string())
186 .filter(|s| !s.is_empty())
187 .collect(),
188 project: section.project.clone(),
189 })),
190 "azure" => Some(Box::new(azure::Azure {
191 subscriptions: section
192 .regions
193 .split(',')
194 .map(|s| s.trim().to_string())
195 .filter(|s| !s.is_empty())
196 .collect(),
197 })),
198 "oracle" => Some(Box::new(oracle::Oracle {
199 regions: section
200 .regions
201 .split(',')
202 .map(|s| s.trim().to_string())
203 .filter(|s| !s.is_empty())
204 .collect(),
205 compartment: section.compartment.clone(),
206 })),
207 _ => get_provider(name),
208 }
209}
210
/// Returns the human-readable name for a provider key.
///
/// Each known key maps to its branded spelling, e.g. `"proxmox"` becomes
/// `"Proxmox VE"` and `"aws"` becomes `"AWS EC2"`. Matching is exact and
/// case-sensitive; any other input is returned unchanged.
pub fn provider_display_name(name: &str) -> &str {
    const DISPLAY_NAMES: &[(&str, &str)] = &[
        ("digitalocean", "DigitalOcean"),
        ("vultr", "Vultr"),
        ("linode", "Linode"),
        ("hetzner", "Hetzner"),
        ("upcloud", "UpCloud"),
        ("proxmox", "Proxmox VE"),
        ("aws", "AWS EC2"),
        ("scaleway", "Scaleway"),
        ("gcp", "GCP"),
        ("azure", "Azure"),
        ("tailscale", "Tailscale"),
        ("oracle", "Oracle Cloud"),
    ];
    DISPLAY_NAMES
        .iter()
        .find(|&&(key, _)| key == name)
        .map_or(name, |&(_, display)| display)
}
229
230pub(crate) fn http_agent() -> ureq::Agent {
232 ureq::Agent::config_builder()
233 .timeout_global(Some(std::time::Duration::from_secs(30)))
234 .max_redirects(0)
235 .build()
236 .new_agent()
237}
238
239pub(crate) fn http_agent_insecure() -> Result<ureq::Agent, ProviderError> {
241 Ok(ureq::Agent::config_builder()
242 .timeout_global(Some(std::time::Duration::from_secs(30)))
243 .max_redirects(0)
244 .tls_config(
245 ureq::tls::TlsConfig::builder()
246 .provider(ureq::tls::TlsProvider::NativeTls)
247 .disable_verification(true)
248 .build(),
249 )
250 .build()
251 .new_agent())
252}
253
/// Removes a trailing CIDR prefix length (`/N`) from an address string.
///
/// Only the text after the LAST `/` is examined. It is stripped when it is
/// one or more ASCII digits, so `"10.0.0.1/24"` becomes `"10.0.0.1"` and
/// `"10.0.0.1/24/48"` becomes `"10.0.0.1/24"`. The digits are not
/// range-checked. Otherwise the input is returned unchanged: no slash, a
/// trailing `/` with nothing after it, or a suffix containing any non-digit.
pub(crate) fn strip_cidr(ip: &str) -> &str {
    match ip.rsplit_once('/') {
        Some((address, prefix_len))
            if !prefix_len.is_empty() && prefix_len.bytes().all(|b| b.is_ascii_digit()) =>
        {
            address
        }
        _ => ip,
    }
}
266
267fn map_ureq_error(err: ureq::Error) -> ProviderError {
269 match err {
270 ureq::Error::StatusCode(code) => match code {
271 401 | 403 => ProviderError::AuthFailed,
272 429 => ProviderError::RateLimited,
273 _ => ProviderError::Http(format!("HTTP {}", code)),
274 },
275 other => ProviderError::Http(other.to_string()),
276 }
277}
278
279#[cfg(test)]
280mod tests {
281 use super::*;
282
283 #[test]
288 fn test_strip_cidr_ipv6_with_prefix() {
289 assert_eq!(strip_cidr("2600:3c00::1/128"), "2600:3c00::1");
290 assert_eq!(strip_cidr("2a01:4f8::1/64"), "2a01:4f8::1");
291 }
292
293 #[test]
294 fn test_strip_cidr_bare_ipv6() {
295 assert_eq!(strip_cidr("2600:3c00::1"), "2600:3c00::1");
296 }
297
298 #[test]
299 fn test_strip_cidr_ipv4_passthrough() {
300 assert_eq!(strip_cidr("1.2.3.4"), "1.2.3.4");
301 assert_eq!(strip_cidr("10.0.0.1/24"), "10.0.0.1");
302 }
303
304 #[test]
305 fn test_strip_cidr_empty() {
306 assert_eq!(strip_cidr(""), "");
307 }
308
309 #[test]
310 fn test_strip_cidr_slash_without_digits() {
311 assert_eq!(strip_cidr("path/to/something"), "path/to/something");
313 }
314
315 #[test]
316 fn test_strip_cidr_trailing_slash() {
317 assert_eq!(strip_cidr("1.2.3.4/"), "1.2.3.4/");
319 }
320
321 #[test]
326 fn test_get_provider_digitalocean() {
327 let p = get_provider("digitalocean").unwrap();
328 assert_eq!(p.name(), "digitalocean");
329 assert_eq!(p.short_label(), "do");
330 }
331
332 #[test]
333 fn test_get_provider_vultr() {
334 let p = get_provider("vultr").unwrap();
335 assert_eq!(p.name(), "vultr");
336 assert_eq!(p.short_label(), "vultr");
337 }
338
339 #[test]
340 fn test_get_provider_linode() {
341 let p = get_provider("linode").unwrap();
342 assert_eq!(p.name(), "linode");
343 assert_eq!(p.short_label(), "linode");
344 }
345
346 #[test]
347 fn test_get_provider_hetzner() {
348 let p = get_provider("hetzner").unwrap();
349 assert_eq!(p.name(), "hetzner");
350 assert_eq!(p.short_label(), "hetzner");
351 }
352
353 #[test]
354 fn test_get_provider_upcloud() {
355 let p = get_provider("upcloud").unwrap();
356 assert_eq!(p.name(), "upcloud");
357 assert_eq!(p.short_label(), "uc");
358 }
359
360 #[test]
361 fn test_get_provider_proxmox() {
362 let p = get_provider("proxmox").unwrap();
363 assert_eq!(p.name(), "proxmox");
364 assert_eq!(p.short_label(), "pve");
365 }
366
367 #[test]
368 fn test_get_provider_unknown_returns_none() {
369 assert!(get_provider("unknown_provider").is_none());
370 assert!(get_provider("").is_none());
371 assert!(get_provider("DigitalOcean").is_none()); }
373
374 #[test]
375 fn test_get_provider_all_names_resolve() {
376 for name in PROVIDER_NAMES {
377 assert!(
378 get_provider(name).is_some(),
379 "Provider '{}' should resolve",
380 name
381 );
382 }
383 }
384
385 #[test]
390 fn test_get_provider_with_config_proxmox_uses_url() {
391 let section = config::ProviderSection {
392 provider: "proxmox".to_string(),
393 token: "user@pam!token=secret".to_string(),
394 alias_prefix: "pve-".to_string(),
395 user: String::new(),
396 identity_file: String::new(),
397 url: "https://pve.example.com:8006".to_string(),
398 verify_tls: false,
399 auto_sync: false,
400 profile: String::new(),
401 regions: String::new(),
402 project: String::new(),
403 compartment: String::new(),
404 };
405 let p = get_provider_with_config("proxmox", §ion).unwrap();
406 assert_eq!(p.name(), "proxmox");
407 }
408
409 #[test]
410 fn test_get_provider_with_config_non_proxmox_delegates() {
411 let section = config::ProviderSection {
412 provider: "digitalocean".to_string(),
413 token: "do-token".to_string(),
414 alias_prefix: "do-".to_string(),
415 user: String::new(),
416 identity_file: String::new(),
417 url: String::new(),
418 verify_tls: true,
419 auto_sync: true,
420 profile: String::new(),
421 regions: String::new(),
422 project: String::new(),
423 compartment: String::new(),
424 };
425 let p = get_provider_with_config("digitalocean", §ion).unwrap();
426 assert_eq!(p.name(), "digitalocean");
427 }
428
429 #[test]
430 fn test_get_provider_with_config_gcp_uses_project_and_zones() {
431 let section = config::ProviderSection {
432 provider: "gcp".to_string(),
433 token: "sa.json".to_string(),
434 alias_prefix: "gcp".to_string(),
435 user: String::new(),
436 identity_file: String::new(),
437 url: String::new(),
438 verify_tls: true,
439 auto_sync: true,
440 profile: String::new(),
441 regions: "us-central1-a, europe-west1-b".to_string(),
442 project: "my-project".to_string(),
443 compartment: String::new(),
444 };
445 let p = get_provider_with_config("gcp", §ion).unwrap();
446 assert_eq!(p.name(), "gcp");
447 }
448
449 #[test]
450 fn test_get_provider_with_config_unknown_returns_none() {
451 let section = config::ProviderSection {
452 provider: "unknown_provider".to_string(),
453 token: String::new(),
454 alias_prefix: String::new(),
455 user: String::new(),
456 identity_file: String::new(),
457 url: String::new(),
458 verify_tls: true,
459 auto_sync: true,
460 profile: String::new(),
461 regions: String::new(),
462 project: String::new(),
463 compartment: String::new(),
464 };
465 assert!(get_provider_with_config("unknown_provider", §ion).is_none());
466 }
467
468 #[test]
473 fn test_display_name_all_providers() {
474 assert_eq!(provider_display_name("digitalocean"), "DigitalOcean");
475 assert_eq!(provider_display_name("vultr"), "Vultr");
476 assert_eq!(provider_display_name("linode"), "Linode");
477 assert_eq!(provider_display_name("hetzner"), "Hetzner");
478 assert_eq!(provider_display_name("upcloud"), "UpCloud");
479 assert_eq!(provider_display_name("proxmox"), "Proxmox VE");
480 assert_eq!(provider_display_name("aws"), "AWS EC2");
481 assert_eq!(provider_display_name("scaleway"), "Scaleway");
482 assert_eq!(provider_display_name("gcp"), "GCP");
483 assert_eq!(provider_display_name("azure"), "Azure");
484 assert_eq!(provider_display_name("tailscale"), "Tailscale");
485 assert_eq!(provider_display_name("oracle"), "Oracle Cloud");
486 }
487
488 #[test]
489 fn test_display_name_unknown_returns_input() {
490 assert_eq!(
491 provider_display_name("unknown_provider"),
492 "unknown_provider"
493 );
494 assert_eq!(provider_display_name(""), "");
495 }
496
497 #[test]
502 fn test_provider_names_count() {
503 assert_eq!(PROVIDER_NAMES.len(), 12);
504 }
505
506 #[test]
507 fn test_provider_names_contains_all() {
508 assert!(PROVIDER_NAMES.contains(&"digitalocean"));
509 assert!(PROVIDER_NAMES.contains(&"vultr"));
510 assert!(PROVIDER_NAMES.contains(&"linode"));
511 assert!(PROVIDER_NAMES.contains(&"hetzner"));
512 assert!(PROVIDER_NAMES.contains(&"upcloud"));
513 assert!(PROVIDER_NAMES.contains(&"proxmox"));
514 assert!(PROVIDER_NAMES.contains(&"aws"));
515 assert!(PROVIDER_NAMES.contains(&"scaleway"));
516 assert!(PROVIDER_NAMES.contains(&"gcp"));
517 assert!(PROVIDER_NAMES.contains(&"azure"));
518 assert!(PROVIDER_NAMES.contains(&"tailscale"));
519 assert!(PROVIDER_NAMES.contains(&"oracle"));
520 }
521
522 #[test]
527 fn test_provider_error_display_http() {
528 let err = ProviderError::Http("connection refused".to_string());
529 assert_eq!(format!("{}", err), "HTTP error: connection refused");
530 }
531
532 #[test]
533 fn test_provider_error_display_parse() {
534 let err = ProviderError::Parse("invalid JSON".to_string());
535 assert_eq!(format!("{}", err), "Failed to parse response: invalid JSON");
536 }
537
538 #[test]
539 fn test_provider_error_display_auth() {
540 let err = ProviderError::AuthFailed;
541 assert!(format!("{}", err).contains("Authentication failed"));
542 }
543
544 #[test]
545 fn test_provider_error_display_rate_limited() {
546 let err = ProviderError::RateLimited;
547 assert!(format!("{}", err).contains("Rate limited"));
548 }
549
550 #[test]
551 fn test_provider_error_display_cancelled() {
552 let err = ProviderError::Cancelled;
553 assert_eq!(format!("{}", err), "Cancelled.");
554 }
555
556 #[test]
557 fn test_provider_error_display_partial_result() {
558 let err = ProviderError::PartialResult {
559 hosts: vec![],
560 failures: 3,
561 total: 10,
562 };
563 assert!(format!("{}", err).contains("3 of 10 failed"));
564 }
565
566 #[test]
571 fn test_provider_host_construction() {
572 let host = ProviderHost::new(
573 "12345".to_string(),
574 "web-01".to_string(),
575 "1.2.3.4".to_string(),
576 vec!["prod".to_string(), "web".to_string()],
577 );
578 assert_eq!(host.server_id, "12345");
579 assert_eq!(host.name, "web-01");
580 assert_eq!(host.ip, "1.2.3.4");
581 assert_eq!(host.tags.len(), 2);
582 }
583
584 #[test]
585 fn test_provider_host_clone() {
586 let host = ProviderHost::new(
587 "1".to_string(),
588 "a".to_string(),
589 "1.1.1.1".to_string(),
590 vec![],
591 );
592 let cloned = host.clone();
593 assert_eq!(cloned.server_id, host.server_id);
594 assert_eq!(cloned.name, host.name);
595 }
596
597 #[test]
602 fn test_strip_cidr_ipv6_with_64() {
603 assert_eq!(strip_cidr("2a01:4f8::1/64"), "2a01:4f8::1");
604 }
605
606 #[test]
607 fn test_strip_cidr_ipv4_with_32() {
608 assert_eq!(strip_cidr("1.2.3.4/32"), "1.2.3.4");
609 }
610
611 #[test]
612 fn test_strip_cidr_ipv4_with_8() {
613 assert_eq!(strip_cidr("10.0.0.1/8"), "10.0.0.1");
614 }
615
616 #[test]
617 fn test_strip_cidr_just_slash() {
618 assert_eq!(strip_cidr("/"), "/");
620 }
621
622 #[test]
623 fn test_strip_cidr_slash_with_letters() {
624 assert_eq!(strip_cidr("10.0.0.1/abc"), "10.0.0.1/abc");
625 }
626
627 #[test]
628 fn test_strip_cidr_multiple_slashes() {
629 assert_eq!(strip_cidr("10.0.0.1/24/48"), "10.0.0.1/24");
631 }
632
633 #[test]
634 fn test_strip_cidr_ipv6_full_notation() {
635 assert_eq!(
636 strip_cidr("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128"),
637 "2001:0db8:85a3:0000:0000:8a2e:0370:7334"
638 );
639 }
640
641 #[test]
646 fn test_provider_error_debug_http() {
647 let err = ProviderError::Http("timeout".to_string());
648 let debug = format!("{:?}", err);
649 assert!(debug.contains("Http"));
650 assert!(debug.contains("timeout"));
651 }
652
653 #[test]
654 fn test_provider_error_debug_partial_result() {
655 let err = ProviderError::PartialResult {
656 hosts: vec![ProviderHost::new(
657 "1".to_string(),
658 "web".to_string(),
659 "1.2.3.4".to_string(),
660 vec![],
661 )],
662 failures: 2,
663 total: 5,
664 };
665 let debug = format!("{:?}", err);
666 assert!(debug.contains("PartialResult"));
667 assert!(debug.contains("failures: 2"));
668 }
669
670 #[test]
675 fn test_provider_host_empty_fields() {
676 let host = ProviderHost::new(String::new(), String::new(), String::new(), vec![]);
677 assert!(host.server_id.is_empty());
678 assert!(host.name.is_empty());
679 assert!(host.ip.is_empty());
680 }
681
682 #[test]
687 fn test_get_provider_with_config_all_providers() {
688 for &name in PROVIDER_NAMES {
689 let section = config::ProviderSection {
690 provider: name.to_string(),
691 token: "tok".to_string(),
692 alias_prefix: "test".to_string(),
693 user: String::new(),
694 identity_file: String::new(),
695 url: if name == "proxmox" {
696 "https://pve:8006".to_string()
697 } else {
698 String::new()
699 },
700 verify_tls: true,
701 auto_sync: true,
702 profile: String::new(),
703 regions: String::new(),
704 project: String::new(),
705 compartment: String::new(),
706 };
707 let p = get_provider_with_config(name, §ion);
708 assert!(
709 p.is_some(),
710 "get_provider_with_config({}) should return Some",
711 name
712 );
713 assert_eq!(p.unwrap().name(), name);
714 }
715 }
716
717 #[test]
722 fn test_provider_fetch_hosts_delegates_to_cancellable() {
723 let provider = get_provider("digitalocean").unwrap();
724 let result = provider.fetch_hosts("fake-token");
728 assert!(result.is_err()); }
730
731 #[test]
736 fn test_strip_cidr_digit_then_letters_not_stripped() {
737 assert_eq!(strip_cidr("10.0.0.1/24abc"), "10.0.0.1/24abc");
738 }
739
740 #[test]
745 fn test_provider_display_name_all() {
746 assert_eq!(provider_display_name("digitalocean"), "DigitalOcean");
747 assert_eq!(provider_display_name("vultr"), "Vultr");
748 assert_eq!(provider_display_name("linode"), "Linode");
749 assert_eq!(provider_display_name("hetzner"), "Hetzner");
750 assert_eq!(provider_display_name("upcloud"), "UpCloud");
751 assert_eq!(provider_display_name("proxmox"), "Proxmox VE");
752 assert_eq!(provider_display_name("aws"), "AWS EC2");
753 assert_eq!(provider_display_name("scaleway"), "Scaleway");
754 assert_eq!(provider_display_name("gcp"), "GCP");
755 assert_eq!(provider_display_name("azure"), "Azure");
756 assert_eq!(provider_display_name("tailscale"), "Tailscale");
757 assert_eq!(provider_display_name("oracle"), "Oracle Cloud");
758 }
759
760 #[test]
761 fn test_provider_display_name_unknown() {
762 assert_eq!(
763 provider_display_name("unknown_provider"),
764 "unknown_provider"
765 );
766 }
767
768 #[test]
773 fn test_get_provider_all_known() {
774 for name in PROVIDER_NAMES {
775 assert!(
776 get_provider(name).is_some(),
777 "get_provider({}) should return Some",
778 name
779 );
780 }
781 }
782
783 #[test]
784 fn test_get_provider_case_sensitive_and_unknown() {
785 assert!(get_provider("unknown_provider").is_none());
786 assert!(get_provider("DigitalOcean").is_none()); assert!(get_provider("VULTR").is_none());
788 assert!(get_provider("").is_none());
789 }
790
791 #[test]
796 fn test_provider_names_has_all_twelve() {
797 assert_eq!(PROVIDER_NAMES.len(), 12);
798 assert!(PROVIDER_NAMES.contains(&"digitalocean"));
799 assert!(PROVIDER_NAMES.contains(&"proxmox"));
800 assert!(PROVIDER_NAMES.contains(&"aws"));
801 assert!(PROVIDER_NAMES.contains(&"scaleway"));
802 assert!(PROVIDER_NAMES.contains(&"azure"));
803 assert!(PROVIDER_NAMES.contains(&"tailscale"));
804 assert!(PROVIDER_NAMES.contains(&"oracle"));
805 }
806
807 #[test]
812 fn test_provider_short_labels() {
813 let cases = [
814 ("digitalocean", "do"),
815 ("vultr", "vultr"),
816 ("linode", "linode"),
817 ("hetzner", "hetzner"),
818 ("upcloud", "uc"),
819 ("proxmox", "pve"),
820 ("aws", "aws"),
821 ("scaleway", "scw"),
822 ("gcp", "gcp"),
823 ("azure", "az"),
824 ("tailscale", "ts"),
825 ];
826 for (name, expected_label) in &cases {
827 let p = get_provider(name).unwrap();
828 assert_eq!(p.short_label(), *expected_label, "short_label for {}", name);
829 }
830 }
831
832 #[test]
837 fn test_http_agent_creates_agent() {
838 let _agent = http_agent();
840 }
841
842 #[test]
843 fn test_http_agent_insecure_creates_agent() {
844 let agent = http_agent_insecure();
846 assert!(agent.is_ok());
847 }
848
849 #[test]
854 fn test_map_ureq_error_401_is_auth_failed() {
855 let err = map_ureq_error(ureq::Error::StatusCode(401));
856 assert!(matches!(err, ProviderError::AuthFailed));
857 }
858
859 #[test]
860 fn test_map_ureq_error_403_is_auth_failed() {
861 let err = map_ureq_error(ureq::Error::StatusCode(403));
862 assert!(matches!(err, ProviderError::AuthFailed));
863 }
864
865 #[test]
866 fn test_map_ureq_error_429_is_rate_limited() {
867 let err = map_ureq_error(ureq::Error::StatusCode(429));
868 assert!(matches!(err, ProviderError::RateLimited));
869 }
870
871 #[test]
872 fn test_map_ureq_error_500_is_http() {
873 let err = map_ureq_error(ureq::Error::StatusCode(500));
874 match err {
875 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 500"),
876 other => panic!("expected Http, got {:?}", other),
877 }
878 }
879
880 #[test]
881 fn test_map_ureq_error_404_is_http() {
882 let err = map_ureq_error(ureq::Error::StatusCode(404));
883 match err {
884 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 404"),
885 other => panic!("expected Http, got {:?}", other),
886 }
887 }
888
889 #[test]
890 fn test_map_ureq_error_502_is_http() {
891 let err = map_ureq_error(ureq::Error::StatusCode(502));
892 match err {
893 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 502"),
894 other => panic!("expected Http, got {:?}", other),
895 }
896 }
897
898 #[test]
899 fn test_map_ureq_error_503_is_http() {
900 let err = map_ureq_error(ureq::Error::StatusCode(503));
901 match err {
902 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 503"),
903 other => panic!("expected Http, got {:?}", other),
904 }
905 }
906
907 #[test]
908 fn test_map_ureq_error_200_is_http() {
909 let err = map_ureq_error(ureq::Error::StatusCode(200));
911 match err {
912 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 200"),
913 other => panic!("expected Http, got {:?}", other),
914 }
915 }
916
917 #[test]
918 fn test_map_ureq_error_non_status_is_http() {
919 let err = map_ureq_error(ureq::Error::HostNotFound);
921 match err {
922 ProviderError::Http(msg) => assert!(!msg.is_empty()),
923 other => panic!("expected Http, got {:?}", other),
924 }
925 }
926
927 #[test]
928 fn test_map_ureq_error_all_auth_codes_covered() {
929 for code in [400, 402, 405, 406, 407, 408, 409, 410] {
931 let err = map_ureq_error(ureq::Error::StatusCode(code));
932 assert!(
933 matches!(err, ProviderError::Http(_)),
934 "status {} should be Http, not AuthFailed",
935 code
936 );
937 }
938 }
939
940 #[test]
941 fn test_map_ureq_error_only_429_is_rate_limited() {
942 for code in [428, 430, 431] {
944 let err = map_ureq_error(ureq::Error::StatusCode(code));
945 assert!(
946 !matches!(err, ProviderError::RateLimited),
947 "status {} should not be RateLimited",
948 code
949 );
950 }
951 }
952
953 #[test]
954 fn test_map_ureq_error_io_error() {
955 let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
956 let err = map_ureq_error(ureq::Error::Io(io_err));
957 match err {
958 ProviderError::Http(msg) => assert!(msg.contains("refused"), "got: {}", msg),
959 other => panic!("expected Http, got {:?}", other),
960 }
961 }
962
963 #[test]
964 fn test_map_ureq_error_timeout() {
965 let err = map_ureq_error(ureq::Error::Timeout(ureq::Timeout::Global));
966 match err {
967 ProviderError::Http(msg) => assert!(!msg.is_empty()),
968 other => panic!("expected Http, got {:?}", other),
969 }
970 }
971
972 #[test]
973 fn test_map_ureq_error_connection_failed() {
974 let err = map_ureq_error(ureq::Error::ConnectionFailed);
975 match err {
976 ProviderError::Http(msg) => assert!(!msg.is_empty()),
977 other => panic!("expected Http, got {:?}", other),
978 }
979 }
980
981 #[test]
982 fn test_map_ureq_error_bad_uri() {
983 let err = map_ureq_error(ureq::Error::BadUri("no scheme".to_string()));
984 match err {
985 ProviderError::Http(msg) => assert!(msg.contains("no scheme"), "got: {}", msg),
986 other => panic!("expected Http, got {:?}", other),
987 }
988 }
989
990 #[test]
991 fn test_map_ureq_error_too_many_redirects() {
992 let err = map_ureq_error(ureq::Error::TooManyRedirects);
993 match err {
994 ProviderError::Http(msg) => assert!(!msg.is_empty()),
995 other => panic!("expected Http, got {:?}", other),
996 }
997 }
998
999 #[test]
1000 fn test_map_ureq_error_redirect_failed() {
1001 let err = map_ureq_error(ureq::Error::RedirectFailed);
1002 match err {
1003 ProviderError::Http(msg) => assert!(!msg.is_empty()),
1004 other => panic!("expected Http, got {:?}", other),
1005 }
1006 }
1007
1008 #[test]
1009 fn test_map_ureq_error_all_status_codes_1xx_to_5xx() {
1010 for code in [
1012 100, 200, 201, 301, 302, 400, 401, 403, 404, 429, 500, 502, 503, 504,
1013 ] {
1014 let err = map_ureq_error(ureq::Error::StatusCode(code));
1015 match code {
1016 401 | 403 => assert!(
1017 matches!(err, ProviderError::AuthFailed),
1018 "status {} should be AuthFailed",
1019 code
1020 ),
1021 429 => assert!(
1022 matches!(err, ProviderError::RateLimited),
1023 "status {} should be RateLimited",
1024 code
1025 ),
1026 _ => assert!(
1027 matches!(err, ProviderError::Http(_)),
1028 "status {} should be Http",
1029 code
1030 ),
1031 }
1032 }
1033 }
1034
1035 #[test]
1041 fn test_http_get_json_response() {
1042 let mut server = mockito::Server::new();
1043 let mock = server
1044 .mock("GET", "/api/test")
1045 .with_status(200)
1046 .with_header("content-type", "application/json")
1047 .with_body(r#"{"name": "test-server", "id": 42}"#)
1048 .create();
1049
1050 let agent = http_agent();
1051 let mut resp = agent
1052 .get(&format!("{}/api/test", server.url()))
1053 .call()
1054 .unwrap();
1055
1056 #[derive(serde::Deserialize)]
1057 struct TestResp {
1058 name: String,
1059 id: u32,
1060 }
1061
1062 let body: TestResp = resp.body_mut().read_json().unwrap();
1063 assert_eq!(body.name, "test-server");
1064 assert_eq!(body.id, 42);
1065 mock.assert();
1066 }
1067
1068 #[test]
1069 fn test_http_get_with_bearer_header() {
1070 let mut server = mockito::Server::new();
1071 let mock = server
1072 .mock("GET", "/api/hosts")
1073 .match_header("Authorization", "Bearer my-secret-token")
1074 .with_status(200)
1075 .with_header("content-type", "application/json")
1076 .with_body(r#"{"hosts": []}"#)
1077 .create();
1078
1079 let agent = http_agent();
1080 let resp = agent
1081 .get(&format!("{}/api/hosts", server.url()))
1082 .header("Authorization", "Bearer my-secret-token")
1083 .call();
1084
1085 assert!(resp.is_ok());
1086 mock.assert();
1087 }
1088
1089 #[test]
1090 fn test_http_get_with_custom_header() {
1091 let mut server = mockito::Server::new();
1092 let mock = server
1093 .mock("GET", "/api/servers")
1094 .match_header("X-Auth-Token", "scw-token-123")
1095 .with_status(200)
1096 .with_header("content-type", "application/json")
1097 .with_body(r#"{"servers": []}"#)
1098 .create();
1099
1100 let agent = http_agent();
1101 let resp = agent
1102 .get(&format!("{}/api/servers", server.url()))
1103 .header("X-Auth-Token", "scw-token-123")
1104 .call();
1105
1106 assert!(resp.is_ok());
1107 mock.assert();
1108 }
1109
1110 #[test]
1111 fn test_http_401_maps_to_auth_failed() {
1112 let mut server = mockito::Server::new();
1113 let mock = server
1114 .mock("GET", "/api/test")
1115 .with_status(401)
1116 .with_body("Unauthorized")
1117 .create();
1118
1119 let agent = http_agent();
1120 let err = agent
1121 .get(&format!("{}/api/test", server.url()))
1122 .call()
1123 .unwrap_err();
1124
1125 let provider_err = map_ureq_error(err);
1126 assert!(matches!(provider_err, ProviderError::AuthFailed));
1127 mock.assert();
1128 }
1129
1130 #[test]
1131 fn test_http_403_maps_to_auth_failed() {
1132 let mut server = mockito::Server::new();
1133 let mock = server
1134 .mock("GET", "/api/test")
1135 .with_status(403)
1136 .with_body("Forbidden")
1137 .create();
1138
1139 let agent = http_agent();
1140 let err = agent
1141 .get(&format!("{}/api/test", server.url()))
1142 .call()
1143 .unwrap_err();
1144
1145 let provider_err = map_ureq_error(err);
1146 assert!(matches!(provider_err, ProviderError::AuthFailed));
1147 mock.assert();
1148 }
1149
1150 #[test]
1151 fn test_http_429_maps_to_rate_limited() {
1152 let mut server = mockito::Server::new();
1153 let mock = server
1154 .mock("GET", "/api/test")
1155 .with_status(429)
1156 .with_body("Too Many Requests")
1157 .create();
1158
1159 let agent = http_agent();
1160 let err = agent
1161 .get(&format!("{}/api/test", server.url()))
1162 .call()
1163 .unwrap_err();
1164
1165 let provider_err = map_ureq_error(err);
1166 assert!(matches!(provider_err, ProviderError::RateLimited));
1167 mock.assert();
1168 }
1169
1170 #[test]
1171 fn test_http_500_maps_to_http_error() {
1172 let mut server = mockito::Server::new();
1173 let mock = server
1174 .mock("GET", "/api/test")
1175 .with_status(500)
1176 .with_body("Internal Server Error")
1177 .create();
1178
1179 let agent = http_agent();
1180 let err = agent
1181 .get(&format!("{}/api/test", server.url()))
1182 .call()
1183 .unwrap_err();
1184
1185 let provider_err = map_ureq_error(err);
1186 match provider_err {
1187 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 500"),
1188 other => panic!("expected Http, got {:?}", other),
1189 }
1190 mock.assert();
1191 }
1192
1193 #[test]
1194 fn test_http_post_form_encoding() {
1195 let mut server = mockito::Server::new();
1196 let mock = server
1197 .mock("POST", "/oauth/token")
1198 .match_header("content-type", "application/x-www-form-urlencoded")
1199 .match_body(
1200 "grant_type=client_credentials&client_id=my-app&client_secret=secret123&scope=api",
1201 )
1202 .with_status(200)
1203 .with_header("content-type", "application/json")
1204 .with_body(r#"{"access_token": "eyJ.abc.def"}"#)
1205 .create();
1206
1207 let agent = http_agent();
1208 let client_id = "my-app".to_string();
1209 let client_secret = "secret123".to_string();
1210 let mut resp = agent
1211 .post(&format!("{}/oauth/token", server.url()))
1212 .send_form([
1213 ("grant_type", "client_credentials"),
1214 ("client_id", client_id.as_str()),
1215 ("client_secret", client_secret.as_str()),
1216 ("scope", "api"),
1217 ])
1218 .unwrap();
1219
1220 #[derive(serde::Deserialize)]
1221 struct TokenResp {
1222 access_token: String,
1223 }
1224
1225 let body: TokenResp = resp.body_mut().read_json().unwrap();
1226 assert_eq!(body.access_token, "eyJ.abc.def");
1227 mock.assert();
1228 }
1229
1230 #[test]
1231 fn test_http_read_to_string() {
1232 let mut server = mockito::Server::new();
1233 let mock = server
1234 .mock("GET", "/api/xml")
1235 .with_status(200)
1236 .with_header("content-type", "text/xml")
1237 .with_body("<root><item>hello</item></root>")
1238 .create();
1239
1240 let agent = http_agent();
1241 let mut resp = agent
1242 .get(&format!("{}/api/xml", server.url()))
1243 .call()
1244 .unwrap();
1245
1246 let body = resp.body_mut().read_to_string().unwrap();
1247 assert_eq!(body, "<root><item>hello</item></root>");
1248 mock.assert();
1249 }
1250
1251 #[test]
1252 fn test_http_body_reader_with_take() {
1253 use std::io::Read;
1255
1256 let mut server = mockito::Server::new();
1257 let mock = server
1258 .mock("GET", "/download")
1259 .with_status(200)
1260 .with_body("binary-content-here-12345")
1261 .create();
1262
1263 let agent = http_agent();
1264 let mut resp = agent
1265 .get(&format!("{}/download", server.url()))
1266 .call()
1267 .unwrap();
1268
1269 let mut bytes = Vec::new();
1270 resp.body_mut()
1271 .as_reader()
1272 .take(1_048_576)
1273 .read_to_end(&mut bytes)
1274 .unwrap();
1275
1276 assert_eq!(bytes, b"binary-content-here-12345");
1277 mock.assert();
1278 }
1279
1280 #[test]
1281 fn test_http_body_reader_take_truncates() {
1282 use std::io::Read;
1284
1285 let mut server = mockito::Server::new();
1286 let mock = server
1287 .mock("GET", "/large")
1288 .with_status(200)
1289 .with_body("abcdefghijklmnopqrstuvwxyz")
1290 .create();
1291
1292 let agent = http_agent();
1293 let mut resp = agent
1294 .get(&format!("{}/large", server.url()))
1295 .call()
1296 .unwrap();
1297
1298 let mut bytes = Vec::new();
1299 resp.body_mut()
1300 .as_reader()
1301 .take(10) .read_to_end(&mut bytes)
1303 .unwrap();
1304
1305 assert_eq!(bytes, b"abcdefghij");
1306 mock.assert();
1307 }
1308
1309 #[test]
1310 fn test_http_no_redirects() {
1311 let mut server = mockito::Server::new();
1315 let redirect_mock = server
1316 .mock("GET", "/redirect")
1317 .with_status(302)
1318 .with_header("Location", "/target")
1319 .create();
1320 let target_mock = server.mock("GET", "/target").with_status(200).create();
1321
1322 let agent = http_agent();
1323 let resp = agent
1324 .get(&format!("{}/redirect", server.url()))
1325 .call()
1326 .unwrap();
1327
1328 assert_eq!(resp.status(), 302);
1329 redirect_mock.assert();
1330 target_mock.expect(0); }
1332
1333 #[test]
1334 fn test_http_invalid_json_returns_parse_error() {
1335 let mut server = mockito::Server::new();
1336 let mock = server
1337 .mock("GET", "/api/bad")
1338 .with_status(200)
1339 .with_header("content-type", "application/json")
1340 .with_body("this is not json")
1341 .create();
1342
1343 let agent = http_agent();
1344 let mut resp = agent
1345 .get(&format!("{}/api/bad", server.url()))
1346 .call()
1347 .unwrap();
1348
1349 #[derive(serde::Deserialize)]
1350 #[allow(dead_code)]
1351 struct Expected {
1352 name: String,
1353 }
1354
1355 let result: Result<Expected, _> = resp.body_mut().read_json();
1356 assert!(result.is_err());
1357 mock.assert();
1358 }
1359
1360 #[test]
1361 fn test_http_empty_json_body_returns_parse_error() {
1362 let mut server = mockito::Server::new();
1363 let mock = server
1364 .mock("GET", "/api/empty")
1365 .with_status(200)
1366 .with_header("content-type", "application/json")
1367 .with_body("")
1368 .create();
1369
1370 let agent = http_agent();
1371 let mut resp = agent
1372 .get(&format!("{}/api/empty", server.url()))
1373 .call()
1374 .unwrap();
1375
1376 #[derive(serde::Deserialize)]
1377 #[allow(dead_code)]
1378 struct Expected {
1379 name: String,
1380 }
1381
1382 let result: Result<Expected, _> = resp.body_mut().read_json();
1383 assert!(result.is_err());
1384 mock.assert();
1385 }
1386
1387 #[test]
1388 fn test_http_multiple_headers() {
1389 let mut server = mockito::Server::new();
1391 let mock = server
1392 .mock("GET", "/api/aws")
1393 .match_header("Authorization", "AWS4-HMAC-SHA256 cred=test")
1394 .match_header("x-amz-date", "20260324T120000Z")
1395 .with_status(200)
1396 .with_header("content-type", "text/xml")
1397 .with_body("<result/>")
1398 .create();
1399
1400 let agent = http_agent();
1401 let mut resp = agent
1402 .get(&format!("{}/api/aws", server.url()))
1403 .header("Authorization", "AWS4-HMAC-SHA256 cred=test")
1404 .header("x-amz-date", "20260324T120000Z")
1405 .call()
1406 .unwrap();
1407
1408 let body = resp.body_mut().read_to_string().unwrap();
1409 assert_eq!(body, "<result/>");
1410 mock.assert();
1411 }
1412
1413 #[test]
1414 fn test_http_connection_refused_maps_to_http_error() {
1415 let agent = http_agent();
1417 let err = agent.get("http://127.0.0.1:1").call().unwrap_err();
1418
1419 let provider_err = map_ureq_error(err);
1420 match provider_err {
1421 ProviderError::Http(msg) => assert!(!msg.is_empty()),
1422 other => panic!("expected Http, got {:?}", other),
1423 }
1424 }
1425
1426 #[test]
1427 fn test_http_nested_json_deserialization() {
1428 let mut server = mockito::Server::new();
1430 let mock = server
1431 .mock("GET", "/api/droplets")
1432 .with_status(200)
1433 .with_header("content-type", "application/json")
1434 .with_body(
1435 r#"{
1436 "data": [
1437 {"id": "1", "name": "web-01", "ip": "1.2.3.4"},
1438 {"id": "2", "name": "web-02", "ip": "5.6.7.8"}
1439 ],
1440 "meta": {"total": 2}
1441 }"#,
1442 )
1443 .create();
1444
1445 #[derive(serde::Deserialize)]
1446 #[allow(dead_code)]
1447 struct Host {
1448 id: String,
1449 name: String,
1450 ip: String,
1451 }
1452 #[derive(serde::Deserialize)]
1453 #[allow(dead_code)]
1454 struct Meta {
1455 total: u32,
1456 }
1457 #[derive(serde::Deserialize)]
1458 #[allow(dead_code)]
1459 struct Resp {
1460 data: Vec<Host>,
1461 meta: Meta,
1462 }
1463
1464 let agent = http_agent();
1465 let mut resp = agent
1466 .get(&format!("{}/api/droplets", server.url()))
1467 .call()
1468 .unwrap();
1469
1470 let body: Resp = resp.body_mut().read_json().unwrap();
1471 assert_eq!(body.data.len(), 2);
1472 assert_eq!(body.data[0].name, "web-01");
1473 assert_eq!(body.data[1].ip, "5.6.7.8");
1474 assert_eq!(body.meta.total, 2);
1475 mock.assert();
1476 }
1477
1478 #[test]
1479 fn test_http_xml_deserialization_with_quick_xml() {
1480 let mut server = mockito::Server::new();
1482 let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
1483 <DescribeInstancesResponse>
1484 <reservationSet>
1485 <item>
1486 <instancesSet>
1487 <item>
1488 <instanceId>i-abc123</instanceId>
1489 <instanceState><name>running</name></instanceState>
1490 </item>
1491 </instancesSet>
1492 </item>
1493 </reservationSet>
1494 </DescribeInstancesResponse>"#;
1495
1496 let mock = server
1497 .mock("GET", "/ec2")
1498 .with_status(200)
1499 .with_header("content-type", "text/xml")
1500 .with_body(xml)
1501 .create();
1502
1503 let agent = http_agent();
1504 let mut resp = agent.get(&format!("{}/ec2", server.url())).call().unwrap();
1505
1506 let body = resp.body_mut().read_to_string().unwrap();
1507 #[derive(serde::Deserialize)]
1509 struct InstanceState {
1510 name: String,
1511 }
1512 #[derive(serde::Deserialize)]
1513 struct Instance {
1514 #[serde(rename = "instanceId")]
1515 instance_id: String,
1516 #[serde(rename = "instanceState")]
1517 instance_state: InstanceState,
1518 }
1519 #[derive(serde::Deserialize)]
1520 struct InstanceSet {
1521 item: Vec<Instance>,
1522 }
1523 #[derive(serde::Deserialize)]
1524 struct Reservation {
1525 #[serde(rename = "instancesSet")]
1526 instances_set: InstanceSet,
1527 }
1528 #[derive(serde::Deserialize)]
1529 struct ReservationSet {
1530 item: Vec<Reservation>,
1531 }
1532 #[derive(serde::Deserialize)]
1533 struct DescribeResp {
1534 #[serde(rename = "reservationSet")]
1535 reservation_set: ReservationSet,
1536 }
1537
1538 let parsed: DescribeResp = quick_xml::de::from_str(&body).unwrap();
1539 assert_eq!(
1540 parsed.reservation_set.item[0].instances_set.item[0].instance_id,
1541 "i-abc123"
1542 );
1543 assert_eq!(
1544 parsed.reservation_set.item[0].instances_set.item[0]
1545 .instance_state
1546 .name,
1547 "running"
1548 );
1549 mock.assert();
1550 }
1551}