1pub mod aws;
2pub mod azure;
3pub mod config;
4mod digitalocean;
5pub mod gcp;
6mod hetzner;
7mod linode;
8mod proxmox;
9pub mod scaleway;
10pub mod sync;
11mod tailscale;
12mod upcloud;
13mod vultr;
14
15use std::sync::atomic::AtomicBool;
16
17use thiserror::Error;
18
/// A single host discovered through a provider API.
///
/// Equality compares every field, including `metadata`.
#[derive(Debug, Clone, PartialEq, Eq)]
// NOTE(review): presumably some fields are unread in certain build
// configurations; confirm before removing this allowance.
#[allow(dead_code)]
pub struct ProviderHost {
    /// Provider-specific identifier of the server.
    pub server_id: String,
    /// Server name as reported by the provider.
    pub name: String,
    /// Address of the host, stored as a string.
    pub ip: String,
    /// Tags/labels attached to the server.
    pub tags: Vec<String>,
    /// Extra key/value details; empty when built with [`ProviderHost::new`].
    pub metadata: Vec<(String, String)>,
}

impl ProviderHost {
    /// Creates a host from its identifier, name, address and tags.
    ///
    /// `metadata` starts empty; callers may push key/value pairs afterwards.
    // NOTE(review): allowance presumably needed because not every build
    // path calls this constructor — confirm.
    #[allow(dead_code)]
    pub fn new(server_id: String, name: String, ip: String, tags: Vec<String>) -> Self {
        Self {
            server_id,
            name,
            ip,
            tags,
            metadata: Vec::new(),
        }
    }
}
48
/// Errors that can occur while fetching hosts from a provider.
#[derive(Debug, Error)]
pub enum ProviderError {
    /// HTTP-level failure; the payload is a human-readable description
    /// (for example `"HTTP 500"` or a transport error message).
    #[error("HTTP error: {0}")]
    Http(String),
    /// The provider's response could not be parsed; the payload describes why.
    #[error("Failed to parse response: {0}")]
    Parse(String),
    /// Credentials were rejected (produced for HTTP 401/403 by `map_ureq_error`).
    #[error("Authentication failed. Check your API token.")]
    AuthFailed,
    /// The provider throttled the request (produced for HTTP 429 by `map_ureq_error`).
    #[error("Rate limited. Try again in a moment.")]
    RateLimited,
    /// Free-form failure; the payload is displayed verbatim.
    // NOTE(review): presumably used for failures of external commands — confirm.
    #[error("{0}")]
    Execute(String),
    /// The fetch was aborted.
    // NOTE(review): presumably returned when the `cancel` flag given to
    // `fetch_hosts_cancellable` is set — confirm per provider.
    #[error("Cancelled.")]
    Cancelled,
    /// Some of the work failed while the rest succeeded.
    ///
    /// `hosts` carries whatever was collected, and `failures` of `total`
    /// units failed.
    // NOTE(review): what a "unit" is (region, zone, page, ...) is decided by
    // each provider — confirm at the call sites.
    #[error("Partial result: {failures} of {total} failed")]
    PartialResult {
        hosts: Vec<ProviderHost>,
        failures: usize,
        total: usize,
    },
}
73
74pub trait Provider {
76 fn name(&self) -> &str;
78 fn short_label(&self) -> &str;
80 fn fetch_hosts_cancellable(
82 &self,
83 token: &str,
84 cancel: &AtomicBool,
85 ) -> Result<Vec<ProviderHost>, ProviderError>;
86 #[allow(dead_code)]
88 fn fetch_hosts(&self, token: &str) -> Result<Vec<ProviderHost>, ProviderError> {
89 self.fetch_hosts_cancellable(token, &AtomicBool::new(false))
90 }
91 fn fetch_hosts_with_progress(
93 &self,
94 token: &str,
95 cancel: &AtomicBool,
96 _progress: &dyn Fn(&str),
97 ) -> Result<Vec<ProviderHost>, ProviderError> {
98 self.fetch_hosts_cancellable(token, cancel)
99 }
100}
101
/// Identifiers of every supported provider.
///
/// Each entry has a matching arm in `get_provider` and in
/// `provider_display_name`; keep all three in sync when adding a provider.
// NOTE(review): the order is presumably the one shown to users in listings —
// confirm before reordering.
pub const PROVIDER_NAMES: &[&str] = &[
    "digitalocean",
    "vultr",
    "linode",
    "hetzner",
    "upcloud",
    "proxmox",
    "aws",
    "scaleway",
    "gcp",
    "azure",
    "tailscale",
];
116
117pub fn get_provider(name: &str) -> Option<Box<dyn Provider>> {
119 match name {
120 "digitalocean" => Some(Box::new(digitalocean::DigitalOcean)),
121 "vultr" => Some(Box::new(vultr::Vultr)),
122 "linode" => Some(Box::new(linode::Linode)),
123 "hetzner" => Some(Box::new(hetzner::Hetzner)),
124 "upcloud" => Some(Box::new(upcloud::UpCloud)),
125 "proxmox" => Some(Box::new(proxmox::Proxmox {
126 base_url: String::new(),
127 verify_tls: true,
128 })),
129 "aws" => Some(Box::new(aws::Aws {
130 regions: Vec::new(),
131 profile: String::new(),
132 })),
133 "scaleway" => Some(Box::new(scaleway::Scaleway { zones: Vec::new() })),
134 "gcp" => Some(Box::new(gcp::Gcp {
135 zones: Vec::new(),
136 project: String::new(),
137 })),
138 "azure" => Some(Box::new(azure::Azure {
139 subscriptions: Vec::new(),
140 })),
141 "tailscale" => Some(Box::new(tailscale::Tailscale)),
142 _ => None,
143 }
144}
145
146pub fn get_provider_with_config(
150 name: &str,
151 section: &config::ProviderSection,
152) -> Option<Box<dyn Provider>> {
153 match name {
154 "proxmox" => Some(Box::new(proxmox::Proxmox {
155 base_url: section.url.clone(),
156 verify_tls: section.verify_tls,
157 })),
158 "aws" => Some(Box::new(aws::Aws {
159 regions: section
160 .regions
161 .split(',')
162 .map(|s| s.trim().to_string())
163 .filter(|s| !s.is_empty())
164 .collect(),
165 profile: section.profile.clone(),
166 })),
167 "scaleway" => Some(Box::new(scaleway::Scaleway {
168 zones: section
169 .regions
170 .split(',')
171 .map(|s| s.trim().to_string())
172 .filter(|s| !s.is_empty())
173 .collect(),
174 })),
175 "gcp" => Some(Box::new(gcp::Gcp {
176 zones: section
177 .regions
178 .split(',')
179 .map(|s| s.trim().to_string())
180 .filter(|s| !s.is_empty())
181 .collect(),
182 project: section.project.clone(),
183 })),
184 "azure" => Some(Box::new(azure::Azure {
185 subscriptions: section
186 .regions
187 .split(',')
188 .map(|s| s.trim().to_string())
189 .filter(|s| !s.is_empty())
190 .collect(),
191 })),
192 _ => get_provider(name),
193 }
194}
195
/// Maps a provider identifier to its human-readable name.
///
/// Lookup is exact and case-sensitive. Unrecognised identifiers are returned
/// unchanged, so the result is always printable.
pub fn provider_display_name(name: &str) -> &str {
    const DISPLAY_NAMES: [(&str, &str); 11] = [
        ("digitalocean", "DigitalOcean"),
        ("vultr", "Vultr"),
        ("linode", "Linode"),
        ("hetzner", "Hetzner"),
        ("upcloud", "UpCloud"),
        ("proxmox", "Proxmox VE"),
        ("aws", "AWS EC2"),
        ("scaleway", "Scaleway"),
        ("gcp", "GCP"),
        ("azure", "Azure"),
        ("tailscale", "Tailscale"),
    ];

    for &(id, display) in DISPLAY_NAMES.iter() {
        if id == name {
            return display;
        }
    }
    name
}
213
214pub(crate) fn http_agent() -> ureq::Agent {
216 ureq::Agent::config_builder()
217 .timeout_global(Some(std::time::Duration::from_secs(30)))
218 .max_redirects(0)
219 .build()
220 .new_agent()
221}
222
223pub(crate) fn http_agent_insecure() -> Result<ureq::Agent, ProviderError> {
225 Ok(ureq::Agent::config_builder()
226 .timeout_global(Some(std::time::Duration::from_secs(30)))
227 .max_redirects(0)
228 .tls_config(
229 ureq::tls::TlsConfig::builder()
230 .provider(ureq::tls::TlsProvider::NativeTls)
231 .disable_verification(true)
232 .build(),
233 )
234 .build()
235 .new_agent())
236}
237
/// Removes a trailing `/<digits>` prefix-length suffix from an address.
///
/// Only the text after the LAST `/` is considered. It is stripped only when
/// it is non-empty and made entirely of ASCII digits. In every other case the
/// input is returned untouched: no slash, a trailing slash, letters, or mixed
/// content such as `/24abc`.
pub(crate) fn strip_cidr(ip: &str) -> &str {
    match ip.rsplit_once('/') {
        Some((address, prefix_len))
            if !prefix_len.is_empty() && prefix_len.bytes().all(|b| b.is_ascii_digit()) =>
        {
            address
        }
        _ => ip,
    }
}
250
251fn map_ureq_error(err: ureq::Error) -> ProviderError {
253 match err {
254 ureq::Error::StatusCode(code) => match code {
255 401 | 403 => ProviderError::AuthFailed,
256 429 => ProviderError::RateLimited,
257 _ => ProviderError::Http(format!("HTTP {}", code)),
258 },
259 other => ProviderError::Http(other.to_string()),
260 }
261}
262
263#[cfg(test)]
264mod tests {
265 use super::*;
266
267 #[test]
272 fn test_strip_cidr_ipv6_with_prefix() {
273 assert_eq!(strip_cidr("2600:3c00::1/128"), "2600:3c00::1");
274 assert_eq!(strip_cidr("2a01:4f8::1/64"), "2a01:4f8::1");
275 }
276
277 #[test]
278 fn test_strip_cidr_bare_ipv6() {
279 assert_eq!(strip_cidr("2600:3c00::1"), "2600:3c00::1");
280 }
281
282 #[test]
283 fn test_strip_cidr_ipv4_passthrough() {
284 assert_eq!(strip_cidr("1.2.3.4"), "1.2.3.4");
285 assert_eq!(strip_cidr("10.0.0.1/24"), "10.0.0.1");
286 }
287
288 #[test]
289 fn test_strip_cidr_empty() {
290 assert_eq!(strip_cidr(""), "");
291 }
292
293 #[test]
294 fn test_strip_cidr_slash_without_digits() {
295 assert_eq!(strip_cidr("path/to/something"), "path/to/something");
297 }
298
299 #[test]
300 fn test_strip_cidr_trailing_slash() {
301 assert_eq!(strip_cidr("1.2.3.4/"), "1.2.3.4/");
303 }
304
305 #[test]
310 fn test_get_provider_digitalocean() {
311 let p = get_provider("digitalocean").unwrap();
312 assert_eq!(p.name(), "digitalocean");
313 assert_eq!(p.short_label(), "do");
314 }
315
316 #[test]
317 fn test_get_provider_vultr() {
318 let p = get_provider("vultr").unwrap();
319 assert_eq!(p.name(), "vultr");
320 assert_eq!(p.short_label(), "vultr");
321 }
322
323 #[test]
324 fn test_get_provider_linode() {
325 let p = get_provider("linode").unwrap();
326 assert_eq!(p.name(), "linode");
327 assert_eq!(p.short_label(), "linode");
328 }
329
330 #[test]
331 fn test_get_provider_hetzner() {
332 let p = get_provider("hetzner").unwrap();
333 assert_eq!(p.name(), "hetzner");
334 assert_eq!(p.short_label(), "hetzner");
335 }
336
337 #[test]
338 fn test_get_provider_upcloud() {
339 let p = get_provider("upcloud").unwrap();
340 assert_eq!(p.name(), "upcloud");
341 assert_eq!(p.short_label(), "uc");
342 }
343
344 #[test]
345 fn test_get_provider_proxmox() {
346 let p = get_provider("proxmox").unwrap();
347 assert_eq!(p.name(), "proxmox");
348 assert_eq!(p.short_label(), "pve");
349 }
350
351 #[test]
352 fn test_get_provider_unknown_returns_none() {
353 assert!(get_provider("oracle").is_none());
354 assert!(get_provider("").is_none());
355 assert!(get_provider("DigitalOcean").is_none()); }
357
358 #[test]
359 fn test_get_provider_all_names_resolve() {
360 for name in PROVIDER_NAMES {
361 assert!(
362 get_provider(name).is_some(),
363 "Provider '{}' should resolve",
364 name
365 );
366 }
367 }
368
369 #[test]
374 fn test_get_provider_with_config_proxmox_uses_url() {
375 let section = config::ProviderSection {
376 provider: "proxmox".to_string(),
377 token: "user@pam!token=secret".to_string(),
378 alias_prefix: "pve-".to_string(),
379 user: String::new(),
380 identity_file: String::new(),
381 url: "https://pve.example.com:8006".to_string(),
382 verify_tls: false,
383 auto_sync: false,
384 profile: String::new(),
385 regions: String::new(),
386 project: String::new(),
387 };
388 let p = get_provider_with_config("proxmox", §ion).unwrap();
389 assert_eq!(p.name(), "proxmox");
390 }
391
392 #[test]
393 fn test_get_provider_with_config_non_proxmox_delegates() {
394 let section = config::ProviderSection {
395 provider: "digitalocean".to_string(),
396 token: "do-token".to_string(),
397 alias_prefix: "do-".to_string(),
398 user: String::new(),
399 identity_file: String::new(),
400 url: String::new(),
401 verify_tls: true,
402 auto_sync: true,
403 profile: String::new(),
404 regions: String::new(),
405 project: String::new(),
406 };
407 let p = get_provider_with_config("digitalocean", §ion).unwrap();
408 assert_eq!(p.name(), "digitalocean");
409 }
410
411 #[test]
412 fn test_get_provider_with_config_gcp_uses_project_and_zones() {
413 let section = config::ProviderSection {
414 provider: "gcp".to_string(),
415 token: "sa.json".to_string(),
416 alias_prefix: "gcp".to_string(),
417 user: String::new(),
418 identity_file: String::new(),
419 url: String::new(),
420 verify_tls: true,
421 auto_sync: true,
422 profile: String::new(),
423 regions: "us-central1-a, europe-west1-b".to_string(),
424 project: "my-project".to_string(),
425 };
426 let p = get_provider_with_config("gcp", §ion).unwrap();
427 assert_eq!(p.name(), "gcp");
428 }
429
430 #[test]
431 fn test_get_provider_with_config_unknown_returns_none() {
432 let section = config::ProviderSection {
433 provider: "oracle".to_string(),
434 token: String::new(),
435 alias_prefix: String::new(),
436 user: String::new(),
437 identity_file: String::new(),
438 url: String::new(),
439 verify_tls: true,
440 auto_sync: true,
441 profile: String::new(),
442 regions: String::new(),
443 project: String::new(),
444 };
445 assert!(get_provider_with_config("oracle", §ion).is_none());
446 }
447
448 #[test]
453 fn test_display_name_all_providers() {
454 assert_eq!(provider_display_name("digitalocean"), "DigitalOcean");
455 assert_eq!(provider_display_name("vultr"), "Vultr");
456 assert_eq!(provider_display_name("linode"), "Linode");
457 assert_eq!(provider_display_name("hetzner"), "Hetzner");
458 assert_eq!(provider_display_name("upcloud"), "UpCloud");
459 assert_eq!(provider_display_name("proxmox"), "Proxmox VE");
460 assert_eq!(provider_display_name("aws"), "AWS EC2");
461 assert_eq!(provider_display_name("scaleway"), "Scaleway");
462 assert_eq!(provider_display_name("gcp"), "GCP");
463 assert_eq!(provider_display_name("azure"), "Azure");
464 assert_eq!(provider_display_name("tailscale"), "Tailscale");
465 }
466
467 #[test]
468 fn test_display_name_unknown_returns_input() {
469 assert_eq!(provider_display_name("oracle"), "oracle");
470 assert_eq!(provider_display_name(""), "");
471 }
472
473 #[test]
478 fn test_provider_names_count() {
479 assert_eq!(PROVIDER_NAMES.len(), 11);
480 }
481
482 #[test]
483 fn test_provider_names_contains_all() {
484 assert!(PROVIDER_NAMES.contains(&"digitalocean"));
485 assert!(PROVIDER_NAMES.contains(&"vultr"));
486 assert!(PROVIDER_NAMES.contains(&"linode"));
487 assert!(PROVIDER_NAMES.contains(&"hetzner"));
488 assert!(PROVIDER_NAMES.contains(&"upcloud"));
489 assert!(PROVIDER_NAMES.contains(&"proxmox"));
490 assert!(PROVIDER_NAMES.contains(&"aws"));
491 assert!(PROVIDER_NAMES.contains(&"scaleway"));
492 assert!(PROVIDER_NAMES.contains(&"gcp"));
493 assert!(PROVIDER_NAMES.contains(&"azure"));
494 assert!(PROVIDER_NAMES.contains(&"tailscale"));
495 }
496
497 #[test]
502 fn test_provider_error_display_http() {
503 let err = ProviderError::Http("connection refused".to_string());
504 assert_eq!(format!("{}", err), "HTTP error: connection refused");
505 }
506
507 #[test]
508 fn test_provider_error_display_parse() {
509 let err = ProviderError::Parse("invalid JSON".to_string());
510 assert_eq!(format!("{}", err), "Failed to parse response: invalid JSON");
511 }
512
513 #[test]
514 fn test_provider_error_display_auth() {
515 let err = ProviderError::AuthFailed;
516 assert!(format!("{}", err).contains("Authentication failed"));
517 }
518
519 #[test]
520 fn test_provider_error_display_rate_limited() {
521 let err = ProviderError::RateLimited;
522 assert!(format!("{}", err).contains("Rate limited"));
523 }
524
525 #[test]
526 fn test_provider_error_display_cancelled() {
527 let err = ProviderError::Cancelled;
528 assert_eq!(format!("{}", err), "Cancelled.");
529 }
530
531 #[test]
532 fn test_provider_error_display_partial_result() {
533 let err = ProviderError::PartialResult {
534 hosts: vec![],
535 failures: 3,
536 total: 10,
537 };
538 assert!(format!("{}", err).contains("3 of 10 failed"));
539 }
540
541 #[test]
546 fn test_provider_host_construction() {
547 let host = ProviderHost::new(
548 "12345".to_string(),
549 "web-01".to_string(),
550 "1.2.3.4".to_string(),
551 vec!["prod".to_string(), "web".to_string()],
552 );
553 assert_eq!(host.server_id, "12345");
554 assert_eq!(host.name, "web-01");
555 assert_eq!(host.ip, "1.2.3.4");
556 assert_eq!(host.tags.len(), 2);
557 }
558
559 #[test]
560 fn test_provider_host_clone() {
561 let host = ProviderHost::new(
562 "1".to_string(),
563 "a".to_string(),
564 "1.1.1.1".to_string(),
565 vec![],
566 );
567 let cloned = host.clone();
568 assert_eq!(cloned.server_id, host.server_id);
569 assert_eq!(cloned.name, host.name);
570 }
571
572 #[test]
577 fn test_strip_cidr_ipv6_with_64() {
578 assert_eq!(strip_cidr("2a01:4f8::1/64"), "2a01:4f8::1");
579 }
580
581 #[test]
582 fn test_strip_cidr_ipv4_with_32() {
583 assert_eq!(strip_cidr("1.2.3.4/32"), "1.2.3.4");
584 }
585
586 #[test]
587 fn test_strip_cidr_ipv4_with_8() {
588 assert_eq!(strip_cidr("10.0.0.1/8"), "10.0.0.1");
589 }
590
591 #[test]
592 fn test_strip_cidr_just_slash() {
593 assert_eq!(strip_cidr("/"), "/");
595 }
596
597 #[test]
598 fn test_strip_cidr_slash_with_letters() {
599 assert_eq!(strip_cidr("10.0.0.1/abc"), "10.0.0.1/abc");
600 }
601
602 #[test]
603 fn test_strip_cidr_multiple_slashes() {
604 assert_eq!(strip_cidr("10.0.0.1/24/48"), "10.0.0.1/24");
606 }
607
608 #[test]
609 fn test_strip_cidr_ipv6_full_notation() {
610 assert_eq!(
611 strip_cidr("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128"),
612 "2001:0db8:85a3:0000:0000:8a2e:0370:7334"
613 );
614 }
615
616 #[test]
621 fn test_provider_error_debug_http() {
622 let err = ProviderError::Http("timeout".to_string());
623 let debug = format!("{:?}", err);
624 assert!(debug.contains("Http"));
625 assert!(debug.contains("timeout"));
626 }
627
628 #[test]
629 fn test_provider_error_debug_partial_result() {
630 let err = ProviderError::PartialResult {
631 hosts: vec![ProviderHost::new(
632 "1".to_string(),
633 "web".to_string(),
634 "1.2.3.4".to_string(),
635 vec![],
636 )],
637 failures: 2,
638 total: 5,
639 };
640 let debug = format!("{:?}", err);
641 assert!(debug.contains("PartialResult"));
642 assert!(debug.contains("failures: 2"));
643 }
644
645 #[test]
650 fn test_provider_host_empty_fields() {
651 let host = ProviderHost::new(String::new(), String::new(), String::new(), vec![]);
652 assert!(host.server_id.is_empty());
653 assert!(host.name.is_empty());
654 assert!(host.ip.is_empty());
655 }
656
657 #[test]
662 fn test_get_provider_with_config_all_providers() {
663 for &name in PROVIDER_NAMES {
664 let section = config::ProviderSection {
665 provider: name.to_string(),
666 token: "tok".to_string(),
667 alias_prefix: "test".to_string(),
668 user: String::new(),
669 identity_file: String::new(),
670 url: if name == "proxmox" {
671 "https://pve:8006".to_string()
672 } else {
673 String::new()
674 },
675 verify_tls: true,
676 auto_sync: true,
677 profile: String::new(),
678 regions: String::new(),
679 project: String::new(),
680 };
681 let p = get_provider_with_config(name, §ion);
682 assert!(
683 p.is_some(),
684 "get_provider_with_config({}) should return Some",
685 name
686 );
687 assert_eq!(p.unwrap().name(), name);
688 }
689 }
690
691 #[test]
696 fn test_provider_fetch_hosts_delegates_to_cancellable() {
697 let provider = get_provider("digitalocean").unwrap();
698 let result = provider.fetch_hosts("fake-token");
702 assert!(result.is_err()); }
704
705 #[test]
710 fn test_strip_cidr_digit_then_letters_not_stripped() {
711 assert_eq!(strip_cidr("10.0.0.1/24abc"), "10.0.0.1/24abc");
712 }
713
714 #[test]
719 fn test_provider_display_name_all() {
720 assert_eq!(provider_display_name("digitalocean"), "DigitalOcean");
721 assert_eq!(provider_display_name("vultr"), "Vultr");
722 assert_eq!(provider_display_name("linode"), "Linode");
723 assert_eq!(provider_display_name("hetzner"), "Hetzner");
724 assert_eq!(provider_display_name("upcloud"), "UpCloud");
725 assert_eq!(provider_display_name("proxmox"), "Proxmox VE");
726 assert_eq!(provider_display_name("aws"), "AWS EC2");
727 assert_eq!(provider_display_name("scaleway"), "Scaleway");
728 assert_eq!(provider_display_name("gcp"), "GCP");
729 assert_eq!(provider_display_name("azure"), "Azure");
730 assert_eq!(provider_display_name("tailscale"), "Tailscale");
731 }
732
733 #[test]
734 fn test_provider_display_name_unknown() {
735 assert_eq!(provider_display_name("oracle"), "oracle");
736 }
737
738 #[test]
743 fn test_get_provider_all_known() {
744 for name in PROVIDER_NAMES {
745 assert!(
746 get_provider(name).is_some(),
747 "get_provider({}) should return Some",
748 name
749 );
750 }
751 }
752
753 #[test]
754 fn test_get_provider_case_sensitive_and_unknown() {
755 assert!(get_provider("oracle").is_none());
756 assert!(get_provider("DigitalOcean").is_none()); assert!(get_provider("VULTR").is_none());
758 assert!(get_provider("").is_none());
759 }
760
761 #[test]
766 fn test_provider_names_has_all_eleven() {
767 assert_eq!(PROVIDER_NAMES.len(), 11);
768 assert!(PROVIDER_NAMES.contains(&"digitalocean"));
769 assert!(PROVIDER_NAMES.contains(&"proxmox"));
770 assert!(PROVIDER_NAMES.contains(&"aws"));
771 assert!(PROVIDER_NAMES.contains(&"scaleway"));
772 assert!(PROVIDER_NAMES.contains(&"azure"));
773 assert!(PROVIDER_NAMES.contains(&"tailscale"));
774 }
775
776 #[test]
781 fn test_provider_short_labels() {
782 let cases = [
783 ("digitalocean", "do"),
784 ("vultr", "vultr"),
785 ("linode", "linode"),
786 ("hetzner", "hetzner"),
787 ("upcloud", "uc"),
788 ("proxmox", "pve"),
789 ("aws", "aws"),
790 ("scaleway", "scw"),
791 ("gcp", "gcp"),
792 ("azure", "az"),
793 ("tailscale", "ts"),
794 ];
795 for (name, expected_label) in &cases {
796 let p = get_provider(name).unwrap();
797 assert_eq!(p.short_label(), *expected_label, "short_label for {}", name);
798 }
799 }
800
801 #[test]
806 fn test_http_agent_creates_agent() {
807 let _agent = http_agent();
809 }
810
811 #[test]
812 fn test_http_agent_insecure_creates_agent() {
813 let agent = http_agent_insecure();
815 assert!(agent.is_ok());
816 }
817
818 #[test]
823 fn test_map_ureq_error_401_is_auth_failed() {
824 let err = map_ureq_error(ureq::Error::StatusCode(401));
825 assert!(matches!(err, ProviderError::AuthFailed));
826 }
827
828 #[test]
829 fn test_map_ureq_error_403_is_auth_failed() {
830 let err = map_ureq_error(ureq::Error::StatusCode(403));
831 assert!(matches!(err, ProviderError::AuthFailed));
832 }
833
834 #[test]
835 fn test_map_ureq_error_429_is_rate_limited() {
836 let err = map_ureq_error(ureq::Error::StatusCode(429));
837 assert!(matches!(err, ProviderError::RateLimited));
838 }
839
840 #[test]
841 fn test_map_ureq_error_500_is_http() {
842 let err = map_ureq_error(ureq::Error::StatusCode(500));
843 match err {
844 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 500"),
845 other => panic!("expected Http, got {:?}", other),
846 }
847 }
848
849 #[test]
850 fn test_map_ureq_error_404_is_http() {
851 let err = map_ureq_error(ureq::Error::StatusCode(404));
852 match err {
853 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 404"),
854 other => panic!("expected Http, got {:?}", other),
855 }
856 }
857
858 #[test]
859 fn test_map_ureq_error_502_is_http() {
860 let err = map_ureq_error(ureq::Error::StatusCode(502));
861 match err {
862 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 502"),
863 other => panic!("expected Http, got {:?}", other),
864 }
865 }
866
867 #[test]
868 fn test_map_ureq_error_503_is_http() {
869 let err = map_ureq_error(ureq::Error::StatusCode(503));
870 match err {
871 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 503"),
872 other => panic!("expected Http, got {:?}", other),
873 }
874 }
875
876 #[test]
877 fn test_map_ureq_error_200_is_http() {
878 let err = map_ureq_error(ureq::Error::StatusCode(200));
880 match err {
881 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 200"),
882 other => panic!("expected Http, got {:?}", other),
883 }
884 }
885
886 #[test]
887 fn test_map_ureq_error_non_status_is_http() {
888 let err = map_ureq_error(ureq::Error::HostNotFound);
890 match err {
891 ProviderError::Http(msg) => assert!(!msg.is_empty()),
892 other => panic!("expected Http, got {:?}", other),
893 }
894 }
895
896 #[test]
897 fn test_map_ureq_error_all_auth_codes_covered() {
898 for code in [400, 402, 405, 406, 407, 408, 409, 410] {
900 let err = map_ureq_error(ureq::Error::StatusCode(code));
901 assert!(
902 matches!(err, ProviderError::Http(_)),
903 "status {} should be Http, not AuthFailed",
904 code
905 );
906 }
907 }
908
909 #[test]
910 fn test_map_ureq_error_only_429_is_rate_limited() {
911 for code in [428, 430, 431] {
913 let err = map_ureq_error(ureq::Error::StatusCode(code));
914 assert!(
915 !matches!(err, ProviderError::RateLimited),
916 "status {} should not be RateLimited",
917 code
918 );
919 }
920 }
921
922 #[test]
923 fn test_map_ureq_error_io_error() {
924 let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
925 let err = map_ureq_error(ureq::Error::Io(io_err));
926 match err {
927 ProviderError::Http(msg) => assert!(msg.contains("refused"), "got: {}", msg),
928 other => panic!("expected Http, got {:?}", other),
929 }
930 }
931
932 #[test]
933 fn test_map_ureq_error_timeout() {
934 let err = map_ureq_error(ureq::Error::Timeout(ureq::Timeout::Global));
935 match err {
936 ProviderError::Http(msg) => assert!(!msg.is_empty()),
937 other => panic!("expected Http, got {:?}", other),
938 }
939 }
940
941 #[test]
942 fn test_map_ureq_error_connection_failed() {
943 let err = map_ureq_error(ureq::Error::ConnectionFailed);
944 match err {
945 ProviderError::Http(msg) => assert!(!msg.is_empty()),
946 other => panic!("expected Http, got {:?}", other),
947 }
948 }
949
950 #[test]
951 fn test_map_ureq_error_bad_uri() {
952 let err = map_ureq_error(ureq::Error::BadUri("no scheme".to_string()));
953 match err {
954 ProviderError::Http(msg) => assert!(msg.contains("no scheme"), "got: {}", msg),
955 other => panic!("expected Http, got {:?}", other),
956 }
957 }
958
959 #[test]
960 fn test_map_ureq_error_too_many_redirects() {
961 let err = map_ureq_error(ureq::Error::TooManyRedirects);
962 match err {
963 ProviderError::Http(msg) => assert!(!msg.is_empty()),
964 other => panic!("expected Http, got {:?}", other),
965 }
966 }
967
968 #[test]
969 fn test_map_ureq_error_redirect_failed() {
970 let err = map_ureq_error(ureq::Error::RedirectFailed);
971 match err {
972 ProviderError::Http(msg) => assert!(!msg.is_empty()),
973 other => panic!("expected Http, got {:?}", other),
974 }
975 }
976
977 #[test]
978 fn test_map_ureq_error_all_status_codes_1xx_to_5xx() {
979 for code in [
981 100, 200, 201, 301, 302, 400, 401, 403, 404, 429, 500, 502, 503, 504,
982 ] {
983 let err = map_ureq_error(ureq::Error::StatusCode(code));
984 match code {
985 401 | 403 => assert!(
986 matches!(err, ProviderError::AuthFailed),
987 "status {} should be AuthFailed",
988 code
989 ),
990 429 => assert!(
991 matches!(err, ProviderError::RateLimited),
992 "status {} should be RateLimited",
993 code
994 ),
995 _ => assert!(
996 matches!(err, ProviderError::Http(_)),
997 "status {} should be Http",
998 code
999 ),
1000 }
1001 }
1002 }
1003
1004 #[test]
1010 fn test_http_get_json_response() {
1011 let mut server = mockito::Server::new();
1012 let mock = server
1013 .mock("GET", "/api/test")
1014 .with_status(200)
1015 .with_header("content-type", "application/json")
1016 .with_body(r#"{"name": "test-server", "id": 42}"#)
1017 .create();
1018
1019 let agent = http_agent();
1020 let mut resp = agent
1021 .get(&format!("{}/api/test", server.url()))
1022 .call()
1023 .unwrap();
1024
1025 #[derive(serde::Deserialize)]
1026 struct TestResp {
1027 name: String,
1028 id: u32,
1029 }
1030
1031 let body: TestResp = resp.body_mut().read_json().unwrap();
1032 assert_eq!(body.name, "test-server");
1033 assert_eq!(body.id, 42);
1034 mock.assert();
1035 }
1036
1037 #[test]
1038 fn test_http_get_with_bearer_header() {
1039 let mut server = mockito::Server::new();
1040 let mock = server
1041 .mock("GET", "/api/hosts")
1042 .match_header("Authorization", "Bearer my-secret-token")
1043 .with_status(200)
1044 .with_header("content-type", "application/json")
1045 .with_body(r#"{"hosts": []}"#)
1046 .create();
1047
1048 let agent = http_agent();
1049 let resp = agent
1050 .get(&format!("{}/api/hosts", server.url()))
1051 .header("Authorization", "Bearer my-secret-token")
1052 .call();
1053
1054 assert!(resp.is_ok());
1055 mock.assert();
1056 }
1057
1058 #[test]
1059 fn test_http_get_with_custom_header() {
1060 let mut server = mockito::Server::new();
1061 let mock = server
1062 .mock("GET", "/api/servers")
1063 .match_header("X-Auth-Token", "scw-token-123")
1064 .with_status(200)
1065 .with_header("content-type", "application/json")
1066 .with_body(r#"{"servers": []}"#)
1067 .create();
1068
1069 let agent = http_agent();
1070 let resp = agent
1071 .get(&format!("{}/api/servers", server.url()))
1072 .header("X-Auth-Token", "scw-token-123")
1073 .call();
1074
1075 assert!(resp.is_ok());
1076 mock.assert();
1077 }
1078
1079 #[test]
1080 fn test_http_401_maps_to_auth_failed() {
1081 let mut server = mockito::Server::new();
1082 let mock = server
1083 .mock("GET", "/api/test")
1084 .with_status(401)
1085 .with_body("Unauthorized")
1086 .create();
1087
1088 let agent = http_agent();
1089 let err = agent
1090 .get(&format!("{}/api/test", server.url()))
1091 .call()
1092 .unwrap_err();
1093
1094 let provider_err = map_ureq_error(err);
1095 assert!(matches!(provider_err, ProviderError::AuthFailed));
1096 mock.assert();
1097 }
1098
1099 #[test]
1100 fn test_http_403_maps_to_auth_failed() {
1101 let mut server = mockito::Server::new();
1102 let mock = server
1103 .mock("GET", "/api/test")
1104 .with_status(403)
1105 .with_body("Forbidden")
1106 .create();
1107
1108 let agent = http_agent();
1109 let err = agent
1110 .get(&format!("{}/api/test", server.url()))
1111 .call()
1112 .unwrap_err();
1113
1114 let provider_err = map_ureq_error(err);
1115 assert!(matches!(provider_err, ProviderError::AuthFailed));
1116 mock.assert();
1117 }
1118
1119 #[test]
1120 fn test_http_429_maps_to_rate_limited() {
1121 let mut server = mockito::Server::new();
1122 let mock = server
1123 .mock("GET", "/api/test")
1124 .with_status(429)
1125 .with_body("Too Many Requests")
1126 .create();
1127
1128 let agent = http_agent();
1129 let err = agent
1130 .get(&format!("{}/api/test", server.url()))
1131 .call()
1132 .unwrap_err();
1133
1134 let provider_err = map_ureq_error(err);
1135 assert!(matches!(provider_err, ProviderError::RateLimited));
1136 mock.assert();
1137 }
1138
1139 #[test]
1140 fn test_http_500_maps_to_http_error() {
1141 let mut server = mockito::Server::new();
1142 let mock = server
1143 .mock("GET", "/api/test")
1144 .with_status(500)
1145 .with_body("Internal Server Error")
1146 .create();
1147
1148 let agent = http_agent();
1149 let err = agent
1150 .get(&format!("{}/api/test", server.url()))
1151 .call()
1152 .unwrap_err();
1153
1154 let provider_err = map_ureq_error(err);
1155 match provider_err {
1156 ProviderError::Http(msg) => assert_eq!(msg, "HTTP 500"),
1157 other => panic!("expected Http, got {:?}", other),
1158 }
1159 mock.assert();
1160 }
1161
1162 #[test]
1163 fn test_http_post_form_encoding() {
1164 let mut server = mockito::Server::new();
1165 let mock = server
1166 .mock("POST", "/oauth/token")
1167 .match_header("content-type", "application/x-www-form-urlencoded")
1168 .match_body(
1169 "grant_type=client_credentials&client_id=my-app&client_secret=secret123&scope=api",
1170 )
1171 .with_status(200)
1172 .with_header("content-type", "application/json")
1173 .with_body(r#"{"access_token": "eyJ.abc.def"}"#)
1174 .create();
1175
1176 let agent = http_agent();
1177 let client_id = "my-app".to_string();
1178 let client_secret = "secret123".to_string();
1179 let mut resp = agent
1180 .post(&format!("{}/oauth/token", server.url()))
1181 .send_form([
1182 ("grant_type", "client_credentials"),
1183 ("client_id", client_id.as_str()),
1184 ("client_secret", client_secret.as_str()),
1185 ("scope", "api"),
1186 ])
1187 .unwrap();
1188
1189 #[derive(serde::Deserialize)]
1190 struct TokenResp {
1191 access_token: String,
1192 }
1193
1194 let body: TokenResp = resp.body_mut().read_json().unwrap();
1195 assert_eq!(body.access_token, "eyJ.abc.def");
1196 mock.assert();
1197 }
1198
1199 #[test]
1200 fn test_http_read_to_string() {
1201 let mut server = mockito::Server::new();
1202 let mock = server
1203 .mock("GET", "/api/xml")
1204 .with_status(200)
1205 .with_header("content-type", "text/xml")
1206 .with_body("<root><item>hello</item></root>")
1207 .create();
1208
1209 let agent = http_agent();
1210 let mut resp = agent
1211 .get(&format!("{}/api/xml", server.url()))
1212 .call()
1213 .unwrap();
1214
1215 let body = resp.body_mut().read_to_string().unwrap();
1216 assert_eq!(body, "<root><item>hello</item></root>");
1217 mock.assert();
1218 }
1219
1220 #[test]
1221 fn test_http_body_reader_with_take() {
1222 use std::io::Read;
1224
1225 let mut server = mockito::Server::new();
1226 let mock = server
1227 .mock("GET", "/download")
1228 .with_status(200)
1229 .with_body("binary-content-here-12345")
1230 .create();
1231
1232 let agent = http_agent();
1233 let mut resp = agent
1234 .get(&format!("{}/download", server.url()))
1235 .call()
1236 .unwrap();
1237
1238 let mut bytes = Vec::new();
1239 resp.body_mut()
1240 .as_reader()
1241 .take(1_048_576)
1242 .read_to_end(&mut bytes)
1243 .unwrap();
1244
1245 assert_eq!(bytes, b"binary-content-here-12345");
1246 mock.assert();
1247 }
1248
1249 #[test]
1250 fn test_http_body_reader_take_truncates() {
1251 use std::io::Read;
1253
1254 let mut server = mockito::Server::new();
1255 let mock = server
1256 .mock("GET", "/large")
1257 .with_status(200)
1258 .with_body("abcdefghijklmnopqrstuvwxyz")
1259 .create();
1260
1261 let agent = http_agent();
1262 let mut resp = agent
1263 .get(&format!("{}/large", server.url()))
1264 .call()
1265 .unwrap();
1266
1267 let mut bytes = Vec::new();
1268 resp.body_mut()
1269 .as_reader()
1270 .take(10) .read_to_end(&mut bytes)
1272 .unwrap();
1273
1274 assert_eq!(bytes, b"abcdefghij");
1275 mock.assert();
1276 }
1277
1278 #[test]
1279 fn test_http_no_redirects() {
1280 let mut server = mockito::Server::new();
1284 let redirect_mock = server
1285 .mock("GET", "/redirect")
1286 .with_status(302)
1287 .with_header("Location", "/target")
1288 .create();
1289 let target_mock = server.mock("GET", "/target").with_status(200).create();
1290
1291 let agent = http_agent();
1292 let resp = agent
1293 .get(&format!("{}/redirect", server.url()))
1294 .call()
1295 .unwrap();
1296
1297 assert_eq!(resp.status(), 302);
1298 redirect_mock.assert();
1299 target_mock.expect(0); }
1301
#[test]
fn test_http_invalid_json_returns_parse_error() {
    // Shape the client attempts to decode into; the body below is not JSON at all.
    #[derive(serde::Deserialize)]
    #[allow(dead_code)] // the field is only decoded, never read
    struct Target {
        name: String,
    }

    let mut srv = mockito::Server::new();
    let endpoint = srv
        .mock("GET", "/api/bad")
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body("this is not json")
        .create();
    let url = format!("{}/api/bad", srv.url());

    let agent = http_agent();
    let mut response = agent.get(&url).call().unwrap();

    let decoded = response.body_mut().read_json::<Target>();
    assert!(decoded.is_err());
    endpoint.assert();
}
1328
#[test]
fn test_http_empty_json_body_returns_parse_error() {
    // Shape the client attempts to decode into; an empty body cannot satisfy it.
    #[derive(serde::Deserialize)]
    #[allow(dead_code)] // the field is only decoded, never read
    struct Target {
        name: String,
    }

    let mut srv = mockito::Server::new();
    let endpoint = srv
        .mock("GET", "/api/empty")
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body("")
        .create();
    let url = format!("{}/api/empty", srv.url());

    let agent = http_agent();
    let mut response = agent.get(&url).call().unwrap();

    let decoded = response.body_mut().read_json::<Target>();
    assert!(decoded.is_err());
    endpoint.assert();
}
1355
#[test]
fn test_http_multiple_headers() {
    // Both values are sent by the client and required by the mock matcher.
    const AUTH: &str = "AWS4-HMAC-SHA256 cred=test";
    const AMZ_DATE: &str = "20260324T120000Z";

    let mut srv = mockito::Server::new();
    let endpoint = srv
        .mock("GET", "/api/aws")
        .match_header("Authorization", AUTH)
        .match_header("x-amz-date", AMZ_DATE)
        .with_status(200)
        .with_header("content-type", "text/xml")
        .with_body("<result/>")
        .create();
    let url = format!("{}/api/aws", srv.url());

    let agent = http_agent();
    let request = agent
        .get(&url)
        .header("Authorization", AUTH)
        .header("x-amz-date", AMZ_DATE);
    let mut response = request.call().unwrap();

    let text = response.body_mut().read_to_string().unwrap();
    assert_eq!(text, "<result/>");
    endpoint.assert();
}
1381
#[test]
fn test_http_connection_refused_maps_to_http_error() {
    // Assumes nothing listens on loopback port 1, so the request itself fails
    // before any HTTP response is received.
    let agent = http_agent();
    let transport_err = agent.get("http://127.0.0.1:1").call().unwrap_err();

    let mapped = map_ureq_error(transport_err);
    if let ProviderError::Http(msg) = &mapped {
        assert!(!msg.is_empty());
    } else {
        panic!("expected Http, got {:?}", mapped);
    }
}
1394
#[test]
fn test_http_nested_json_deserialization() {
    // Two-level payload: a list of hosts plus a metadata object.
    #[derive(serde::Deserialize)]
    #[allow(dead_code)]
    struct Host {
        id: String,
        name: String,
        ip: String,
    }
    #[derive(serde::Deserialize)]
    #[allow(dead_code)]
    struct Meta {
        total: u32,
    }
    #[derive(serde::Deserialize)]
    #[allow(dead_code)]
    struct Resp {
        data: Vec<Host>,
        meta: Meta,
    }

    const PAYLOAD: &str = r#"{
        "data": [
            {"id": "1", "name": "web-01", "ip": "1.2.3.4"},
            {"id": "2", "name": "web-02", "ip": "5.6.7.8"}
        ],
        "meta": {"total": 2}
    }"#;

    let mut srv = mockito::Server::new();
    let endpoint = srv
        .mock("GET", "/api/droplets")
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body(PAYLOAD)
        .create();
    let url = format!("{}/api/droplets", srv.url());

    let agent = http_agent();
    let mut response = agent.get(&url).call().unwrap();

    let Resp { data, meta } = response.body_mut().read_json::<Resp>().unwrap();
    assert_eq!(data.len(), 2);
    assert_eq!(data[0].name, "web-01");
    assert_eq!(data[1].ip, "5.6.7.8");
    assert_eq!(meta.total, 2);
    endpoint.assert();
}
1446
#[test]
fn test_http_xml_deserialization_with_quick_xml() {
    // Minimal EC2 DescribeInstances-style document: one reservation holding
    // one instance. Repeated `<item>` elements map onto `Vec` fields.
    #[derive(serde::Deserialize)]
    struct InstanceState {
        name: String,
    }
    #[derive(serde::Deserialize)]
    struct Instance {
        #[serde(rename = "instanceId")]
        instance_id: String,
        #[serde(rename = "instanceState")]
        instance_state: InstanceState,
    }
    #[derive(serde::Deserialize)]
    struct InstanceSet {
        item: Vec<Instance>,
    }
    #[derive(serde::Deserialize)]
    struct Reservation {
        #[serde(rename = "instancesSet")]
        instances_set: InstanceSet,
    }
    #[derive(serde::Deserialize)]
    struct ReservationSet {
        item: Vec<Reservation>,
    }
    #[derive(serde::Deserialize)]
    struct DescribeResp {
        #[serde(rename = "reservationSet")]
        reservation_set: ReservationSet,
    }

    let document = r#"<?xml version="1.0" encoding="UTF-8"?>
<DescribeInstancesResponse>
    <reservationSet>
        <item>
            <instancesSet>
                <item>
                    <instanceId>i-abc123</instanceId>
                    <instanceState><name>running</name></instanceState>
                </item>
            </instancesSet>
        </item>
    </reservationSet>
</DescribeInstancesResponse>"#;

    let mut srv = mockito::Server::new();
    let endpoint = srv
        .mock("GET", "/ec2")
        .with_status(200)
        .with_header("content-type", "text/xml")
        .with_body(document)
        .create();
    let url = format!("{}/ec2", srv.url());

    let agent = http_agent();
    let mut response = agent.get(&url).call().unwrap();
    let raw = response.body_mut().read_to_string().unwrap();

    let parsed: DescribeResp = quick_xml::de::from_str(&raw).unwrap();
    let first = &parsed.reservation_set.item[0].instances_set.item[0];
    assert_eq!(first.instance_id, "i-abc123");
    assert_eq!(first.instance_state.name, "running");
    endpoint.assert();
}
1520}