1use color_eyre::{
95 eyre::{self, eyre, WrapErr},
96 Report,
97};
98use educe::Educe;
99use itertools::Itertools;
100use rusoto_core::credential::{DefaultCredentialsProvider, ProvideAwsCredentials};
101use rusoto_core::request::HttpClient;
102pub use rusoto_core::Region;
103use rusoto_ec2::Ec2;
104use std::collections::HashMap;
105use std::future::Future;
106use std::io::Write;
107use std::pin::Pin;
108use std::sync::Arc;
109use std::time;
110use tracing::instrument;
111use tracing_futures::Instrument;
112
/// How EC2 instances are requested.
#[derive(Debug, Clone)]
// NOTE(review): presumably kept non-Copy so future variants may carry non-Copy data.
#[allow(missing_copy_implementations)]
#[non_exhaustive]
pub enum LaunchMode {
    /// Spot instances with a fixed block duration.
    DefinedDuration {
        /// Block duration in hours (converted to minutes when the request is made).
        hours: usize,
    },
    /// Like `DefinedDuration`, but fall back to on-demand instances if the spot
    /// requests cannot be fulfilled.
    TrySpot {
        /// Block duration in hours (converted to minutes when the request is made).
        hours: usize,
    },
    /// Regular on-demand instances.
    OnDemand,
}

impl LaunchMode {
    /// Clamp a requested spot block duration into the supported 1..=6 hour range.
    fn clamp_spot_hours(hours: usize) -> usize {
        hours.max(1).min(6)
    }

    /// Defined-duration spot instances lasting `hours`, clamped to 1..=6.
    pub fn duration_spot(hours: usize) -> Self {
        LaunchMode::DefinedDuration {
            hours: Self::clamp_spot_hours(hours),
        }
    }

    /// Like [`LaunchMode::duration_spot`], but falls back to on-demand instances
    /// when the spot requests fail. `hours` is clamped to 1..=6.
    pub fn try_duration_spot(hours: usize) -> Self {
        LaunchMode::TrySpot {
            hours: Self::clamp_spot_hours(hours),
        }
    }

    /// Regular on-demand instances.
    pub fn on_demand() -> Self {
        LaunchMode::OnDemand
    }
}
165
/// Availability-zone / placement preference for launched instances.
#[derive(Debug, Clone)]
pub enum AvailabilityZoneSpec {
    /// No preference: no placement group is created and EC2 chooses the zone.
    Any,
    /// Launch into a "cluster" placement group in an EC2-chosen zone. The number is
    /// folded into the machine's region key, so different numbers are grouped apart.
    Cluster(usize),
    /// Launch into a "cluster" placement group pinned to the named availability zone.
    Specify(String),
}

impl Default for AvailabilityZoneSpec {
    fn default() -> Self {
        AvailabilityZoneSpec::Any
    }
}

impl std::fmt::Display for AvailabilityZoneSpec {
    /// Renders `any`, `cluster(<n>)`, or the zone name verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Any => f.write_str("any"),
            Self::Cluster(n) => write!(f, "cluster({})", n),
            Self::Specify(zone) => f.write_str(zone),
        }
    }
}
198
/// Description of one EC2 machine to launch.
///
/// Start from `Setup::default()` and adjust it with the chainable methods on
/// `Setup` (`instance_type`, `ami`, `region`, `setup`, ...).
#[derive(Clone, Educe)]
#[educe(Debug)]
pub struct Setup {
    /// AWS region to launch in.
    region: Region,
    /// Placement preference; also part of the region key that groups machines
    /// onto a shared `RegionLauncher` (see `MachineSetup::region`).
    availability_zone: AvailabilityZoneSpec,
    /// EC2 instance type, e.g. `t3.small`.
    instance_type: String,
    /// AMI id to boot.
    ami: String,
    /// User name used to log in over SSH.
    username: String,
    /// Optional async closure run against the machine once it is reachable over
    /// SSH. Excluded from `Debug` output.
    #[educe(Debug(ignore))]
    setup_fn: Option<
        Arc<
            dyn for<'r> Fn(
                    &'r crate::Machine<'_>,
                )
                    -> Pin<Box<dyn Future<Output = Result<(), Report>> + Send + 'r>>
                + Send
                + Sync
                + 'static,
        >,
    >,
}
228
229impl super::MachineSetup for Setup {
230 type Region = String;
231
232 fn region(&self) -> Self::Region {
233 match self.availability_zone {
234 AvailabilityZoneSpec::Specify(ref id) => format!("{}-{}", self.region.name(), id),
235 AvailabilityZoneSpec::Cluster(id) => format!("{}-{}", self.region.name(), id),
236 AvailabilityZoneSpec::Any => self.region.name().to_string(),
237 }
238 }
239}
240
241impl Default for Setup {
242 fn default() -> Self {
243 Setup {
244 region: Region::UsEast1,
245 availability_zone: AvailabilityZoneSpec::Any,
246 instance_type: "t3.small".into(),
247 ami: String::from("ami-085925f297f89fce1"),
248 username: "ubuntu".into(),
249 setup_fn: None,
250 }
251 }
252}
253
254impl Setup {
255 pub async fn region_with_ubuntu_ami(mut self, region: Region) -> Result<Self, Report> {
266 self.region = region.clone();
267 let ami: String = UbuntuAmi::new(region).await?.into();
268 Ok(self.ami(ami, "ubuntu"))
269 }
270
271 pub fn username(self, username: impl ToString) -> Self {
276 Self {
277 username: username.to_string(),
278 ..self
279 }
280 }
281
282 pub fn ami(self, ami: impl ToString, username: impl ToString) -> Self {
285 Self {
286 ami: ami.to_string(),
287 username: username.to_string(),
288 ..self
289 }
290 }
291
292 pub fn instance_type(mut self, typ: impl ToString) -> Self {
297 self.instance_type = typ.to_string();
298 self
299 }
300
301 pub fn setup(
326 mut self,
327 setup: impl for<'r> Fn(
328 &'r crate::Machine<'_>,
329 ) -> Pin<Box<dyn Future<Output = Result<(), Report>> + Send + 'r>>
330 + Send
331 + Sync
332 + 'static,
333 ) -> Self {
334 self.setup_fn = Some(Arc::new(setup));
335 self
336 }
337
338 pub fn region(mut self, region: Region, ami: impl ToString, username: impl ToString) -> Self {
348 self.region = region;
349 self.ami(ami, username)
350 }
351
352 pub fn availability_zone(self, az: AvailabilityZoneSpec) -> Self {
357 Self {
358 availability_zone: az,
359 ..self
360 }
361 }
362}
363
/// EC2 launcher that spins up machines described by [`Setup`] across regions.
///
/// `P` is the AWS credential-provider type, obtained by calling `credential_provider`.
#[derive(Educe)]
#[educe(Debug)]
pub struct Launcher<P = DefaultCredentialsProvider> {
    /// Called to obtain a credential provider that is handed to `RegionLauncher::new`.
    #[educe(Debug(ignore))]
    credential_provider: Box<dyn Fn() -> Result<P, Report> + Send + Sync>,
    /// How instances are requested (spot, try-spot, or on-demand).
    mode: LaunchMode,
    /// If true, new security groups open all TCP/UDP ports to `0.0.0.0/0`
    /// instead of only to `172.31.0.0/16`.
    use_open_ports: bool,
    /// One launcher per region key (`MachineSetup::region`: region name plus AZ spec).
    regions: HashMap<<Setup as super::MachineSetup>::Region, RegionLauncher>,
}
385
386impl Default for Launcher {
387 fn default() -> Self {
388 Launcher {
389 credential_provider: Box::new(|| Ok(DefaultCredentialsProvider::new()?)),
390 mode: LaunchMode::DefinedDuration { hours: 6 },
391 use_open_ports: false,
392 regions: Default::default(),
393 }
394 }
395}
396
397impl<P> Launcher<P> {
398 #[deprecated(note = "prefer set_mode")]
407 pub fn set_max_instance_duration(&mut self, t: usize) -> &mut Self {
408 self.mode = LaunchMode::duration_spot(t);
409 self
410 }
411
412 pub fn set_mode(&mut self, mode: LaunchMode) -> &mut Self {
416 self.mode = mode;
417 self
418 }
419
420 pub fn open_ports(&mut self) -> &mut Self {
423 self.use_open_ports = true;
424 self
425 }
426
427 pub fn with_credentials<P2>(
433 self,
434 f: impl Fn() -> Result<P2, Report> + Send + Sync + 'static,
435 ) -> Launcher<P2> {
436 Launcher {
437 credential_provider: Box::new(f),
438 mode: self.mode,
439 use_open_ports: self.use_open_ports,
440 regions: self.regions,
441 }
442 }
443}
444
445impl<P> super::Launcher for Launcher<P>
446where
447 P: ProvideAwsCredentials + Send + Sync + 'static,
448{
449 type MachineDescriptor = Setup;
450
451 #[instrument(level = "debug", skip(self))]
452 fn launch<'l>(
453 &'l mut self,
454 l: super::LaunchDescriptor<Self::MachineDescriptor>,
455 ) -> Pin<Box<dyn Future<Output = Result<(), Report>> + Send + 'l>> {
456 Box::pin(async move {
457 let prov = (*self.credential_provider)()?;
458 let Self {
459 use_open_ports,
460 mode,
461 ref mut regions,
462 ..
463 } = self;
464
465 if !regions.contains_key(&l.region) {
466 let region_span = tracing::debug_span!("new_region", name = %l.region, az = %l.machines[0].1.availability_zone);
467 let awsregion = RegionLauncher::new(
468 l.machines[0].1.region.name(),
471 l.machines[0].1.availability_zone.clone(),
472 prov,
473 *use_open_ports,
474 )
475 .instrument(region_span)
476 .await?;
477 regions.insert(l.region.clone(), awsregion);
478 }
479
480 let region_span = tracing::debug_span!("region", name = %l.region);
481 regions
482 .get_mut(&l.region)
483 .unwrap()
484 .launch(mode.clone(), l.max_wait, l.machines)
485 .instrument(region_span)
486 .await?;
487 Ok(())
488 }.in_current_span())
489 }
490
491 #[instrument(level = "debug", skip(self, max_wait))]
492 fn spawn<'l, I>(
493 &'l mut self,
494 descriptors: I,
495 max_wait: Option<std::time::Duration>,
496 ) -> Pin<Box<dyn Future<Output = Result<(), Report>> + Send + 'l>>
497 where
498 I: IntoIterator<Item = (String, Self::MachineDescriptor)> + Send + 'static,
499 I: std::fmt::Debug,
500 I::IntoIter: Send,
501 {
502 use super::MachineSetup;
503 Box::pin(
504 async move {
505 tracing::info!("spinning up tsunami");
506
507 let names_to_setups = descriptors
509 .into_iter()
510 .map(|(name, setup)| (MachineSetup::region(&setup), (name, setup)))
511 .into_group_map();
512
513 let (mut haves, have_nots): (Vec<_>, Vec<_>) = names_to_setups
517 .into_iter()
518 .partition(|(region_name, _)| self.regions.contains_key(region_name));
519
520 let _prov = (*self.credential_provider)()?;
522 let use_open_ports = self.use_open_ports;
523
524 let newly_initialized: Vec<Result<_, _>> =
525 futures_util::future::join_all(have_nots.iter().map(|(region_name, s)| {
526 let region_span = tracing::debug_span!("new_region", region = %region_name);
527 let prov = (*self.credential_provider)().unwrap();
528 async move {
529 let awsregion = RegionLauncher::new(
530 s[0].1.region.name(),
533 s[0].1.availability_zone.clone(),
534 prov,
535 use_open_ports,
536 )
537 .await?;
538 Ok::<_, Report>((region_name.clone(), awsregion))
539 }
540 .instrument(region_span)
541 }))
542 .await;
543 self.regions.extend(
544 newly_initialized
545 .into_iter()
546 .collect::<Result<Vec<_>, _>>()?,
547 );
548
549 haves.extend(have_nots);
551
552 let max_wait = max_wait;
559 let regions = futures_util::future::join_all(haves.into_iter().map(
560 |(region_name, machines)| {
561 let mut region_launcher = self.regions.remove(®ion_name).unwrap();
563 let region_span = tracing::debug_span!("region", region = %region_name);
564 let mode = self.mode.clone();
565 async move {
566 if let Err(e) = region_launcher.launch(mode, max_wait, machines).await {
567 Err((region_name, region_launcher, e))
568 } else {
569 Ok((region_name, region_launcher))
570 }
571 }
572 .instrument(region_span)
573 },
574 ))
575 .await;
576
577 let (regions, res) =
579 regions
580 .into_iter()
581 .fold((vec![], None), |acc, r| match (acc, r) {
582 ((mut rs, x), Ok((name, rl))) => {
583 rs.push((name, rl));
584 (rs, x)
585 }
586 ((mut rs, None), Err((name, rl, e))) => {
587 rs.push((name, rl));
588 (rs, Some(e))
589 }
590 ((mut rs, x @ Some(_)), Err((name, rl, _))) => {
591 rs.push((name, rl));
592 (rs, x)
593 }
594 });
595 self.regions.extend(regions.into_iter());
596
597 if let Some(e) = res {
598 Err(e)
599 } else {
600 Ok(())
601 }
602 }
603 .in_current_span(),
604 )
605 }
606
607 #[instrument(level = "debug")]
608 fn connect_all<'l>(
609 &'l self,
610 ) -> Pin<
611 Box<dyn Future<Output = Result<HashMap<String, crate::Machine<'l>>, Report>> + Send + 'l>,
612 > {
613 Box::pin(async move { collect!(self.regions) }.in_current_span())
614 }
615
616 #[instrument(level = "debug")]
617 fn terminate_all(mut self) -> Pin<Box<dyn Future<Output = Result<(), Report>> + Send>> {
618 Box::pin(
619 async move {
620 if self.regions.is_empty() {
621 return Ok(());
622 }
623
624 let res =
625 futures_util::future::join_all(self.regions.drain().map(|(region, mut rl)| {
626 let region_span = tracing::debug_span!("region", %region);
627 async move { rl.terminate_all().await }.instrument(region_span)
628 }))
629 .await;
630 res.into_iter().fold(Ok(()), |acc, x| match acc {
631 Ok(_) => x,
632 Err(a) => match x {
633 Ok(_) => Err(a),
634 Err(e) => Err(a.wrap_err(e)),
635 },
636 })
637 }
638 .in_current_span(),
639 )
640 }
641}
642
/// Addresses of a running instance, recorded once it has been reached over SSH.
#[derive(Debug, Clone)]
struct IpInfo {
    /// Public DNS name reported by EC2.
    public_dns: String,
    /// Public IP address reported by EC2 (used for SSH).
    public_ip: String,
    /// Private IP address reported by EC2.
    private_ip: String,
}
649
/// A machine's user-chosen name and [`Setup`], plus its addresses once reachable.
#[derive(Debug, Clone)]
struct TaggedSetup {
    /// Machine name; the key of the map returned by `connect_all`.
    name: String,
    /// How the machine was requested.
    setup: Setup,
    /// `None` until the instance has been reached over SSH.
    ip_info: Option<IpInfo>,
}
659
/// Launcher for a single region key (region plus availability-zone spec).
///
/// Tracks the AWS resources it creates (a security group, an SSH key pair whose
/// private key lives in a temporary file, spot requests, and instances) so that
/// `terminate_all` can release them.
#[derive(Educe, Default)]
#[educe(Debug)]
pub struct RegionLauncher {
    /// The AWS region this launcher operates in.
    pub region: rusoto_core::region::Region,
    /// Placement preference for instances launched here.
    availability_zone: AvailabilityZoneSpec,
    /// Id of the security group created for these instances; empty until created.
    security_group_id: String,
    /// Name of the EC2 key pair created for SSH access; empty until created.
    ssh_key_name: String,
    /// Temporary file holding the private key of `ssh_key_name`.
    private_key_path: Option<tempfile::NamedTempFile>,
    /// EC2 client; `None` only for an unconnected (`Default`-constructed) launcher.
    #[educe(Debug(ignore))]
    client: Option<rusoto_ec2::Ec2Client>,
    /// Spot requests issued so far, keyed by spot request id.
    spot_requests: HashMap<String, TaggedSetup>,
    /// Launched instances, keyed by EC2 instance id.
    instances: HashMap<String, TaggedSetup>,
}
686
687impl RegionLauncher {
688 pub async fn new<P>(
694 region: &str,
695 availability_zone: AvailabilityZoneSpec,
696 provider: P,
697 use_open_ports: bool,
698 ) -> Result<Self, Report>
699 where
700 P: ProvideAwsCredentials + Send + Sync + 'static,
701 {
702 let region = region.parse()?;
703 let ec2 = RegionLauncher::connect(region, availability_zone, provider)
704 .wrap_err("failed to connect to region")?
705 .make_security_group(use_open_ports)
706 .await
707 .wrap_err("failed to make security groups")?
708 .make_ssh_key()
709 .await
710 .wrap_err("failed to make ssh key")?;
711
712 Ok(ec2)
713 }
714
715 #[instrument(level = "debug", skip(provider))]
716 fn connect<P>(
717 region: rusoto_core::region::Region,
718 availability_zone: AvailabilityZoneSpec,
719 provider: P,
720 ) -> Result<Self, Report>
721 where
722 P: ProvideAwsCredentials + Send + Sync + 'static,
723 {
724 tracing::debug!("connecting to ec2");
725 let ec2 = rusoto_ec2::Ec2Client::new_with(
726 HttpClient::new().wrap_err("failed to construct new http client")?,
727 provider,
728 region.clone(),
729 );
730
731 Ok(Self {
732 region,
733 availability_zone,
734 security_group_id: Default::default(),
735 ssh_key_name: Default::default(),
736 private_key_path: Some(
737 tempfile::NamedTempFile::new()
738 .wrap_err("failed to create temporary file for keypair")?,
739 ),
740 spot_requests: Default::default(),
741 instances: Default::default(),
742 client: Some(ec2),
743 })
744 }
745
746 #[instrument(level = "debug", skip(self, max_wait))]
751 pub async fn launch<M>(
752 &mut self,
753 mode: LaunchMode,
754 mut max_wait: Option<time::Duration>,
755 machines: M,
756 ) -> Result<(), Report>
757 where
758 M: IntoIterator<Item = (String, Setup)> + std::fmt::Debug,
759 {
760 let machines: Vec<_> = machines.into_iter().collect();
761 let machines = machines.clone();
762 let mut do_ondemand = false;
763 match mode {
764 LaunchMode::TrySpot {
765 hours: max_instance_duration_hours,
766 }
767 | LaunchMode::DefinedDuration {
768 hours: max_instance_duration_hours,
769 } => {
770 let machines = machines.clone();
771
772 self.make_spot_instance_requests(
775 max_instance_duration_hours * 60, machines,
777 )
778 .await
779 .wrap_err("failed to make spot instance requests")?;
780
781 let start = time::Instant::now();
782 if let Err(e) = self
783 .wait_for_spot_instance_requests(max_wait)
784 .await
785 .wrap_err(eyre!(
786 "failed while waiting for spot instances fulfilment in {}",
787 self.region.name()
788 ))
789 {
790 if let LaunchMode::TrySpot { .. } = mode {
793 tracing::debug!(err = ?e, "re-trying with OnDemand instace");
794 do_ondemand = true;
795 } else {
796 return Err(e);
797 }
798 } else {
799 if let Some(ref mut d) = max_wait {
800 *d -= time::Instant::now().duration_since(start);
801 }
802 }
803 }
804 LaunchMode::OnDemand => {
805 do_ondemand = true;
806 }
807 }
808
809 if do_ondemand {
810 self.make_on_demand_requests(machines)
811 .await
812 .wrap_err(eyre!(
813 "failed to start on demand instances in {}",
814 self.region.name()
815 ))?;
816
817 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
819 }
820
821 self.wait_for_instances(max_wait)
822 .await
823 .wrap_err("failed while waiting for instances to come up")?;
824 Ok(())
825 }
826
827 #[instrument(level = "trace", skip(self))]
828 async fn make_security_group(mut self, use_open_ports: bool) -> Result<Self, Report> {
829 let ec2 = self.client.as_mut().expect("RegionLauncher unconnected");
830
831 let group_name = super::rand_name("security");
833 tracing::debug!(name = %group_name, "creating security group");
834 let req = rusoto_ec2::CreateSecurityGroupRequest {
835 group_name,
836 description: "temporary access group for tsunami VMs".to_string(),
837 ..Default::default()
838 };
839 let res = ec2
840 .create_security_group(req)
841 .await
842 .wrap_err("failed to create security group for new machines")?;
843 let group_id = res
844 .group_id
845 .expect("aws created security group with no group id");
846 tracing::trace!(id = %group_id, "security group created");
847
848 let mut req = rusoto_ec2::AuthorizeSecurityGroupIngressRequest {
849 group_id: Some(group_id.clone()),
850 ip_protocol: Some("icmp".to_string()),
852 from_port: Some(-1),
853 to_port: Some(-1),
854 cidr_ip: Some("0.0.0.0/0".to_string()),
855 ..Default::default()
856 };
857 tracing::trace!("adding icmp access");
858 ec2.authorize_security_group_ingress(req.clone())
859 .await
860 .wrap_err("failed to fill in security group for new machines")?;
861
862 req.ip_protocol = Some("tcp".to_string());
864 req.from_port = Some(22);
865 req.to_port = Some(22);
866 req.cidr_ip = Some("0.0.0.0/0".to_string());
867 tracing::trace!("adding ssh access");
868 ec2.authorize_security_group_ingress(req.clone())
869 .await
870 .wrap_err("failed to fill in security group for new machines")?;
871
872 req.ip_protocol = Some("tcp".to_string());
876 req.from_port = Some(0);
877 req.to_port = Some(65535);
878 if use_open_ports {
879 req.cidr_ip = Some("0.0.0.0/0".to_string());
880 } else {
881 req.cidr_ip = Some("172.31.0.0/16".to_string());
882 }
883
884 tracing::trace!("adding intra-vm tcp access");
885 ec2.authorize_security_group_ingress(req.clone())
886 .await
887 .wrap_err("failed to fill in security group for new machines")?;
888
889 req.ip_protocol = Some("udp".to_string());
890 req.from_port = Some(0);
891 req.to_port = Some(65535);
892 if use_open_ports {
893 req.cidr_ip = Some("0.0.0.0/0".to_string());
894 } else {
895 req.cidr_ip = Some("172.31.0.0/16".to_string());
896 }
897
898 tracing::trace!("adding intra-vm udp access");
899 ec2.authorize_security_group_ingress(req)
900 .await
901 .wrap_err("failed to fill in security group for new machines")?;
902
903 self.security_group_id = group_id;
904 Ok(self)
905 }
906
907 #[instrument(level = "trace", skip(self))]
908 async fn make_ssh_key(mut self) -> Result<Self, Report> {
909 let ec2 = self.client.as_mut().expect("RegionLauncher unconnected");
910 let private_key_path = self
911 .private_key_path
912 .as_mut()
913 .expect("RegionLauncher unconnected");
914
915 tracing::debug!("creating keypair");
917 let key_name = super::rand_name("key");
918 let req = rusoto_ec2::CreateKeyPairRequest {
919 key_name: key_name.clone(),
920 ..Default::default()
921 };
922 let res = ec2
923 .create_key_pair(req)
924 .await
925 .context("failed to generate new key pair")?;
926 tracing::trace!(fingerprint = ?res.key_fingerprint, "created keypair");
927
928 let private_key = res
930 .key_material
931 .expect("aws did not generate key material for new key");
932 private_key_path
933 .write_all(private_key.as_bytes())
934 .context("could not write private key to file")?;
935 tracing::debug!(
936 filename = %private_key_path.path().display(),
937 "wrote keypair to file"
938 );
939
940 self.ssh_key_name = key_name;
941 Ok(self)
942 }
943
944 #[instrument(level = "trace", skip(self, mk))]
950 async fn make_placement<R>(
951 &mut self,
952 mk: impl FnOnce(String, Option<String>) -> R,
953 ) -> Result<Option<R>, Report> {
954 if let AvailabilityZoneSpec::Any = self.availability_zone {
955 Ok(None)
956 } else {
957 let ec2 = self.client.as_mut().expect("RegionLauncher unconnected");
958 tracing::trace!("creating placement group");
959 let placement_name = super::rand_name("placement");
960 let req = rusoto_ec2::CreatePlacementGroupRequest {
961 group_name: Some(placement_name.clone()),
962 strategy: Some(String::from("cluster")),
963 ..Default::default()
964 };
965 ec2.create_placement_group(req).await?;
966 tracing::trace!("created placement group");
967
968 Ok(Some(mk(
969 placement_name,
970 match self.availability_zone {
971 AvailabilityZoneSpec::Cluster(_) => None,
972 AvailabilityZoneSpec::Specify(ref av) => Some(av.clone()),
973 _ => unreachable!(),
974 },
975 )))
976 }
977 }
978
979 fn for_each_machine_group<M>(
980 machines: M,
981 ) -> impl Iterator<Item = ((String, String), Vec<(String, Setup)>)> + Send
982 where
983 M: IntoIterator<Item = (String, Setup)>,
984 M: std::fmt::Debug,
985 {
986 machines
988 .into_iter()
989 .map(|(name, m)| {
990 ((m.ami.clone(), m.instance_type.clone()), (name, m))
993 })
994 .into_group_map()
995 .into_iter()
996 }
997
998 #[instrument(level = "trace", skip(self))]
999 async fn make_on_demand_requests<M>(&mut self, machines: M) -> Result<(), Report>
1000 where
1001 M: IntoIterator<Item = (String, Setup)>,
1002 M: std::fmt::Debug,
1003 {
1004 tracing::info!("launching on demand instances");
1005
1006 for ((ami, instance_type), reqs) in Self::for_each_machine_group(machines) {
1008 let inst_span = tracing::debug_span!("run_instance", ?ami, ?instance_type);
1009 async {
1010 let placement = self
1012 .make_placement(|group_name, az| rusoto_ec2::Placement {
1013 group_name: Some(group_name),
1014 availability_zone: az,
1015 ..Default::default()
1016 })
1017 .await
1018 .wrap_err("create new placement group")?;
1019 let req = rusoto_ec2::RunInstancesRequest {
1020 image_id: Some(ami),
1021 instance_type: Some(instance_type),
1022 placement,
1023 security_group_ids: Some(vec![self.security_group_id.clone()]),
1024 key_name: Some(self.ssh_key_name.clone()),
1025 min_count: reqs.len() as i64,
1026 max_count: reqs.len() as i64,
1027 instance_initiated_shutdown_behavior: Some("terminate".to_string()),
1028 ..Default::default()
1029 };
1030
1031 tracing::trace!("issuing request");
1034 let res = self
1035 .client
1036 .as_mut()
1037 .unwrap()
1038 .run_instances(req)
1039 .await
1040 .wrap_err("failed to request on demand instances")?;
1041
1042 let instances: Vec<String> = res
1044 .instances
1045 .expect("run_instances should always return instances")
1046 .into_iter()
1047 .filter_map(|i| i.instance_id)
1048 .inspect(|instance_id| {
1049 tracing::trace!(id = %instance_id, "launched on-demand instance");
1050 })
1051 .collect();
1052
1053 eyre::ensure!(
1055 instances.len() == reqs.len(),
1056 "Got {} instances but expected {}",
1057 instances.len(),
1058 reqs.len(),
1059 );
1060
1061 self.instances
1062 .extend(instances.into_iter().zip_eq(reqs.into_iter()).map(
1063 |(instance_id, (name, setup))| {
1064 let setup = TaggedSetup {
1065 name,
1066 setup,
1067 ip_info: None,
1068 };
1069 (instance_id, setup)
1070 },
1071 ));
1072
1073 Ok(())
1074 }
1075 .instrument(inst_span)
1076 .await?;
1077 }
1078
1079 Ok(())
1080 }
1081
1082 #[instrument(level = "trace", skip(self, max_duration))]
1093 async fn make_spot_instance_requests<M>(
1094 &mut self,
1095 max_duration: usize,
1096 machines: M,
1097 ) -> Result<(), Report>
1098 where
1099 M: IntoIterator<Item = (String, Setup)>,
1100 M: std::fmt::Debug,
1101 {
1102 tracing::info!("launching spot requests");
1103
1104 for ((ami, instance_type), reqs) in Self::for_each_machine_group(machines) {
1106 let spot_span = tracing::debug_span!("spot_request", ?ami, ?instance_type);
1107 async {
1108 let placement = self
1110 .make_placement(|group_name, az| rusoto_ec2::SpotPlacement {
1111 group_name: Some(group_name),
1112 availability_zone: az,
1113 ..Default::default()
1114 })
1115 .await
1116 .wrap_err("create new placement group")?;
1117 let launch = rusoto_ec2::RequestSpotLaunchSpecification {
1118 image_id: Some(ami),
1119 instance_type: Some(instance_type),
1120 placement,
1121 security_group_ids: Some(vec![self.security_group_id.clone()]),
1122 key_name: Some(self.ssh_key_name.clone()),
1123 ..Default::default()
1124 };
1125
1126 let req = rusoto_ec2::RequestSpotInstancesRequest {
1129 instance_count: Some(reqs.len() as i64),
1130 block_duration_minutes: Some(max_duration as i64),
1131 launch_specification: Some(launch),
1132 type_: Some("one-time".into()),
1135 ..Default::default()
1136 };
1137
1138 tracing::trace!("issuing spot request");
1139 let res = self
1140 .client
1141 .as_mut()
1142 .unwrap()
1143 .request_spot_instances(req)
1144 .await
1145 .wrap_err("failed to request spot instance")?;
1146
1147 let spot_instance_requests: Vec<String> = res
1149 .spot_instance_requests
1150 .expect("request_spot_instances should always return spot instance requests")
1151 .into_iter()
1152 .filter_map(|sir| sir.spot_instance_request_id)
1153 .inspect(|request_id| {
1154 tracing::trace!(id = %request_id, "activated spot request");
1155 })
1156 .collect();
1157
1158 eyre::ensure!(
1160 spot_instance_requests.len() == reqs.len(),
1161 "Got {} spot instance requests but expected {}",
1162 spot_instance_requests.len(),
1163 reqs.len(),
1164 );
1165
1166 for (request_id, (name, setup)) in
1167 spot_instance_requests.into_iter().zip_eq(reqs.into_iter())
1168 {
1169 self.spot_requests.insert(
1170 request_id,
1171 TaggedSetup {
1172 name,
1173 setup,
1174 ip_info: None,
1175 },
1176 );
1177 }
1178
1179 Ok(())
1180 }
1181 .instrument(spot_span)
1182 .await?;
1183 }
1184
1185 Ok(())
1186 }
1187
1188 #[instrument(level = "trace", skip(self, max_wait))]
1197 async fn wait_for_spot_instance_requests(
1198 &mut self,
1199 max_wait: Option<time::Duration>,
1200 ) -> Result<(), Report> {
1201 tracing::info!("waiting for instances to spawn");
1202
1203 let start = time::Instant::now();
1204
1205 loop {
1206 tracing::trace!("checking spot request status");
1207 let instances = self.describe_spot_instance_requests().await?;
1208
1209 let mut any_pending = false;
1210 for (request_id, state, status, instance_id) in &instances {
1211 match &**state {
1212 "active" if instance_id.is_some() => {
1213 tracing::trace!(%request_id, %state, ?instance_id, "spot instance request ready");
1214 }
1215 "active" | "open" => {
1216 any_pending = true;
1217 }
1218 s => {
1219 let _ = self.cancel_spot_instance_requests().await;
1221 eyre::bail!("spot request unexpectedly {}: {}", s, status);
1222 }
1223 }
1224 }
1225 let all_active = !any_pending;
1226
1227 if all_active {
1228 self.instances = instances
1230 .into_iter()
1231 .map(|(request_id, state, _, instance_id)| {
1232 assert_eq!(state, "active");
1233 let instance_id = instance_id.unwrap();
1234 let setup = self.spot_requests[&request_id].clone();
1235 (instance_id, setup)
1236 })
1237 .collect();
1238 break;
1239 }
1240
1241 tokio::time::sleep(time::Duration::from_secs(1)).await;
1243
1244 if let Some(wait_limit) = max_wait {
1245 if start.elapsed() <= wait_limit {
1246 continue;
1247 }
1248 self.cancel_spot_instance_requests().await?;
1249 eyre::bail!("wait limit reached");
1250 }
1251 }
1252
1253 Ok(())
1254 }
1255
1256 #[instrument(level = "trace", skip(self, max_wait))]
1258 async fn wait_for_instances(&mut self, max_wait: Option<time::Duration>) -> Result<(), Report> {
1259 let start = time::Instant::now();
1260 let desc_req = rusoto_ec2::DescribeInstancesRequest {
1261 instance_ids: Some(self.instances.keys().cloned().collect()),
1262 ..Default::default()
1263 };
1264 let client = self.client.as_ref().unwrap();
1265 let private_key_path = self.private_key_path.as_ref().unwrap();
1266 let mut all_ready = self.instances.is_empty();
1267 while !all_ready {
1268 all_ready = true;
1269
1270 for reservation in client
1271 .describe_instances(desc_req.clone())
1272 .await
1273 .wrap_err("could not query AWS for instance state")?
1274 .reservations
1275 .unwrap_or_else(Vec::new)
1276 {
1277 for instance in reservation.instances.unwrap_or_else(Vec::new) {
1278 match instance {
1279 rusoto_ec2::Instance {
1282 state: Some(rusoto_ec2::InstanceState { code: Some(16), .. }),
1283 instance_id: Some(instance_id),
1284 public_dns_name: Some(public_dns),
1285 public_ip_address: Some(public_ip),
1286 private_ip_address: Some(private_ip),
1287 ..
1288 } => {
1289 let instance_span =
1290 tracing::debug_span!("instance", %instance_id, ip = %public_ip);
1291 let instances = &mut self.instances;
1292 async {
1293 tracing::trace!("instance running");
1294
1295 let tag_setup = instances.get_mut(&instance_id).unwrap();
1297
1298 let m = crate::MachineDescriptor {
1300 nickname: Default::default(),
1301 public_dns: Default::default(),
1302 public_ip: public_ip.to_string(),
1303 private_ip: Default::default(),
1304 _tsunami: Default::default(),
1305 };
1306
1307 if let Err(e) = m
1308 .connect_ssh(
1309 &tag_setup.setup.username,
1310 Some(private_key_path.path()),
1311 max_wait,
1312 22,
1313 )
1314 .await
1315 {
1316 tracing::trace!("ssh failed: {}", e);
1317 all_ready = false;
1318 } else {
1319 tracing::debug!("instance ready");
1320
1321 tag_setup.ip_info = Some(IpInfo {
1322 public_dns: public_dns.clone(),
1323 public_ip: public_ip.clone(),
1324 private_ip: private_ip.clone(),
1325 });
1326 }
1327 }
1328 .instrument(instance_span)
1329 .await
1330 }
1331 _ => {
1332 all_ready = false;
1333 }
1334 }
1335 }
1336 }
1337
1338 tokio::time::sleep(time::Duration::from_secs(1)).await;
1340
1341 if let Some(wait_limit) = max_wait {
1342 if start.elapsed() <= wait_limit {
1343 continue;
1344 }
1345 self.cancel_spot_instance_requests().await?;
1346 eyre::bail!("wait limit reached");
1347 }
1348 }
1349
1350 futures_util::future::join_all(self.instances.iter().map(
1351 |(
1352 instance_id,
1353 TaggedSetup {
1354 ip_info,
1355 name,
1356 setup,
1357 },
1358 )| {
1359 let IpInfo {
1360 public_dns,
1361 public_ip,
1362 private_ip,
1363 } = ip_info.as_ref().unwrap();
1364 let instance_span = tracing::debug_span!("instance", %instance_id, ip = %public_ip);
1365 async move {
1366 if let Setup {
1367 username,
1368 setup_fn: Some(f),
1369 ..
1370 } = setup
1371 {
1372 super::setup_machine(
1373 &name,
1374 Some(&public_dns),
1375 &public_ip,
1376 Some(&private_ip),
1377 &username,
1378 max_wait,
1379 Some(private_key_path.path()),
1380 f.as_ref(),
1381 )
1382 .await?;
1383 }
1384
1385 Ok(())
1386 }
1387 .instrument(instance_span)
1388 },
1389 ))
1390 .await
1391 .into_iter()
1392 .collect()
1393 }
1394
1395 #[instrument(level = "debug")]
1398 pub async fn connect_all<'l>(&'l self) -> Result<HashMap<String, crate::Machine<'l>>, Report> {
1399 let private_key_path = self.private_key_path.as_ref().unwrap();
1400 futures_util::future::join_all(self.instances.values().map(|info| {
1401 let instance_span = tracing::trace_span!("instance", name = %info.name);
1402 async move {
1403 match info {
1404 TaggedSetup {
1405 name,
1406 setup: Setup { username, .. },
1407 ip_info:
1408 Some(IpInfo {
1409 public_dns,
1410 public_ip,
1411 private_ip,
1412 }),
1413 } => {
1414 let m = crate::MachineDescriptor {
1415 public_dns: Some(public_dns.clone()),
1416 public_ip: public_ip.clone(),
1417 private_ip: Some(private_ip.clone()),
1418 nickname: name.clone(),
1419 _tsunami: Default::default(),
1420 };
1421
1422 let m = m
1423 .connect_ssh(&username, Some(private_key_path.path()), None, 22)
1424 .await?;
1425 Ok((name.clone(), m))
1426 }
1427 _ => eyre::bail!("machine has no ip information"),
1428 }
1429 }
1430 .instrument(instance_span)
1431 }))
1432 .await
1433 .into_iter()
1434 .collect()
1435 }
1436
1437 #[instrument(level = "debug")]
1447 pub async fn terminate_all(&mut self) -> Result<(), Report> {
1448 let client = self.client.as_ref().unwrap();
1449
1450 if !self.ssh_key_name.trim().is_empty() {
1451 let key_span = tracing::trace_span!("key", name = %self.ssh_key_name);
1452 async {
1453 tracing::trace!("removing keypair");
1454 let req = rusoto_ec2::DeleteKeyPairRequest {
1455 key_name: Some(self.ssh_key_name.clone()),
1456 ..Default::default()
1457 };
1458 if let Err(e) = client.delete_key_pair(req).await {
1459 tracing::warn!("failed to clean up temporary SSH key: {}", e);
1460 }
1461 }
1462 .instrument(key_span)
1463 .await;
1464 }
1465
1466 if !self.instances.is_empty() {
1468 tracing::info!("terminating instances");
1469 let instance_ids = self.instances.keys().cloned().collect();
1470 self.instances.clear();
1471 self.terminate_instances(instance_ids).await?;
1477 }
1478
1479 use rusoto_core::RusotoError;
1480 if !self.security_group_id.trim().is_empty() {
1481 let group_span =
1482 tracing::trace_span!("removing security group", id = %self.security_group_id);
1483 async {
1484 tracing::trace!("removing security group.");
1485 let start = tokio::time::Instant::now();
1487 loop {
1488 if start.elapsed() > tokio::time::Duration::from_secs(5 * 60) {
1489 return Err(Report::msg(
1490 "failed to clean up temporary security group after 5 minutes.",
1491 ));
1492 }
1493
1494 let req = rusoto_ec2::DeleteSecurityGroupRequest {
1495 group_id: Some(self.security_group_id.clone()),
1496 ..Default::default()
1497 };
1498 match client.delete_security_group(req).await {
1499 Ok(_) => break,
1500 Err(RusotoError::Unknown(r)) => {
1501 let err = r.body_as_str();
1502 if err.contains("<Code>DependencyViolation</Code>") {
1503 tracing::trace!("instances not yet shut down -- retrying");
1504 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
1505 } else {
1506 Err(Report::new(RusotoError::<
1507 rusoto_ec2::DeleteSecurityGroupError,
1508 >::Unknown(r)))
1509 .wrap_err("failed to clean up temporary security group")?;
1510 unreachable!();
1511 }
1512 }
1513 Err(e) => {
1514 return Err(Report::new(e)
1515 .wrap_err("failed to clean up temporary security group"));
1516 }
1517 }
1518 }
1519
1520 tracing::trace!("cleaned up temporary security group");
1521 Ok::<_, Report>(())
1522 }
1523 .instrument(group_span)
1524 .await?;
1525 }
1526
1527 Ok(())
1528 }
1529
1530 #[instrument(level = "debug")]
1531 async fn describe_spot_instance_requests(
1532 &self,
1533 ) -> Result<Vec<(String, String, String, Option<String>)>, Report> {
1534 let client = self.client.as_ref().unwrap();
1535 let request_ids = self.spot_requests.keys().cloned().collect();
1536 let req = rusoto_ec2::DescribeSpotInstanceRequestsRequest {
1537 spot_instance_request_ids: Some(request_ids),
1538 ..Default::default()
1539 };
1540 loop {
1541 let res = client.describe_spot_instance_requests(req.clone()).await;
1542 if let Err(ref e) = res {
1543 let msg = e.to_string();
1544 if msg.contains("The spot instance request ID") && msg.contains("does not exist") {
1545 tracing::trace!("spot instance requests not yet ready");
1546
1547 tokio::time::sleep(time::Duration::from_secs(1)).await;
1549 continue;
1550 } else {
1551 res.wrap_err("failed to describe spot instances")?;
1552 unreachable!();
1553 }
1554 }
1555
1556 let res = res.expect("Err checked above");
1557 let instances = res
1558 .spot_instance_requests
1559 .expect("describe always returns at least one spot instance")
1560 .into_iter()
1561 .map(|sir| {
1562 let request_id = sir
1563 .spot_instance_request_id
1564 .expect("spot request did not have id specified");
1565 let state = sir
1566 .state
1567 .expect("spot request did not have state specified");
1568 let status = sir
1569 .status
1570 .expect("spot request did not have status specified")
1571 .code
1572 .expect("spot request status did not have status code");
1573 let instance_id = sir.instance_id;
1574 (request_id, state, status, instance_id)
1575 })
1576 .collect();
1577 break Ok(instances);
1578 }
1579 }
1580
1581 #[instrument(level = "debug")]
1582 async fn cancel_spot_instance_requests(&self) -> Result<(), Report> {
1583 tracing::warn!("wait time exceeded for -- cancelling run");
1584 if self.spot_requests.is_empty() {
1585 return Ok(());
1586 }
1587 let request_ids = self.spot_requests.keys().cloned().collect();
1588 let cancel = rusoto_ec2::CancelSpotInstanceRequestsRequest {
1589 spot_instance_request_ids: request_ids,
1590 ..Default::default()
1591 };
1592 self.client
1593 .as_ref()
1594 .unwrap()
1595 .cancel_spot_instance_requests(cancel)
1596 .await
1597 .wrap_err("failed to cancel spot instances")?;
1598
1599 tracing::trace!("spot instances cancelled -- waiting for cancellation");
1600 loop {
1601 tracing::trace!("checking spot request status");
1602 let instances = self.describe_spot_instance_requests().await?;
1603
1604 let all_cancelled = instances.iter().all(|(request_id, state, _, instance_id)| {
1605 if state == "closed"
1606 || state == "cancelled"
1607 || state == "failed"
1608 || state == "completed"
1609 {
1610 tracing::trace!(%request_id, ?instance_id, "spot instance request {}", state);
1611 true
1612 } else {
1613 false
1614 }
1615 });
1616
1617 if all_cancelled {
1618 tracing::trace!("deleting spot instances");
1619 let instance_ids = instances
1621 .into_iter()
1622 .filter_map(|(_, _, _, instance_id)| instance_id)
1623 .collect();
1624 self.terminate_instances(instance_ids).await?;
1625 break;
1626 }
1627
1628 tokio::time::sleep(time::Duration::from_secs(1)).await;
1630 }
1631 Ok(())
1632 }
1633
1634 #[instrument(level = "debug")]
1635 async fn terminate_instances(&self, instance_ids: Vec<String>) -> Result<(), Report> {
1636 if instance_ids.is_empty() {
1637 return Ok(());
1638 }
1639 let client = self.client.as_ref().unwrap();
1640 let termination_req = rusoto_ec2::TerminateInstancesRequest {
1641 instance_ids,
1642 ..Default::default()
1643 };
1644 while let Err(e) = client.terminate_instances(termination_req.clone()).await {
1645 let msg = e.to_string();
1646 if msg.contains("Pooled stream disconnected") || msg.contains("broken pipe") {
1647 tracing::trace!("retrying instance termination");
1648 continue;
1649 } else {
1650 Err(e).wrap_err("failed to terminate tsunami instances")?;
1651 unreachable!();
1652 }
1653 }
1654 Ok(())
1655 }
1656}
1657
1658struct UbuntuAmi(String);
1659
1660impl UbuntuAmi {
1661 async fn new(r: Region) -> Result<Self, Report> {
1662 Ok(UbuntuAmi(
1663 ubuntu_ami::get_latest(
1664 &r.name(),
1665 Some("bionic"),
1666 None,
1667 Some("hvm:ebs-ssd"),
1668 Some("amd64"),
1669 )
1670 .await
1671 .map_err(|e| eyre!(e))?,
1672 ))
1673 }
1674}
1675
1676impl From<UbuntuAmi> for String {
1677 fn from(s: UbuntuAmi) -> String {
1678 s.0
1679 }
1680}
1681
#[cfg(test)]
mod test {
    use super::RegionLauncher;
    use super::*;
    use rusoto_core::credential::DefaultCredentialsProvider;
    use rusoto_core::region::Region;
    use rusoto_ec2::Ec2;
    use std::future::Future;

    /// Spawns one machine through `l`, whose setup function requires `whoami` to succeed over
    /// SSH. It then connects to all machines and runs `echo` on "my machine".
    ///
    /// Does not clean up; the caller must call `terminate_all`.
    fn do_make_machine_and_ssh_setupfn<'l>(
        l: &'l mut super::Launcher,
    ) -> impl Future<Output = Result<(), Report>> + 'l {
        use crate::providers::Launcher;
        async move {
            l.spawn(
                vec![(
                    String::from("my machine"),
                    super::Setup::default().setup(|vm| {
                        Box::pin(async move {
                            if vm.ssh.command("whoami").status().await?.success() {
                                Ok(())
                            } else {
                                Err(eyre!("failed"))
                            }
                        })
                    }),
                )],
                None,
            )
            .await?;
            let vms = l.connect_all().await?;
            let my_machine = vms
                .get("my machine")
                .ok_or_else(|| eyre!("machine not found"))?;
            my_machine
                .ssh
                .command("echo")
                .arg("\"Hello, EC2\"")
                .status()
                .await?;

            Ok(())
        }
    }

    #[test]
    #[ignore]
    fn make_machine_and_ssh_setupfn() {
        use crate::providers::Launcher;
        tracing_subscriber::fmt::init();
        let rt = tokio::runtime::Runtime::new().unwrap();
        let mut l = super::Launcher::default();
        l.set_mode(LaunchMode::duration_spot(1));
        rt.block_on(async move {
            // Always tear down before reporting, whether or not the body succeeded.
            let res = do_make_machine_and_ssh_setupfn(&mut l).await;
            l.terminate_all().await.unwrap();
            if let Err(e) = res {
                // Format the report: a bare `panic!(e)` hides the message behind
                // `Box<dyn Any>` (and is rejected outright in edition 2021).
                panic!("{:?}", e);
            }
        })
    }

    /// Creates a temporary SSH key pair in `us-east-1`, checks that it has a name and a private
    /// key file, then deletes the key pair from EC2.
    #[test]
    #[ignore]
    fn make_key() -> Result<(), Report> {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let region = Region::UsEast1;
        let provider = DefaultCredentialsProvider::new()?;
        let ec2 = RegionLauncher::connect(region, super::AvailabilityZoneSpec::Any, provider)?;
        rt.block_on(async {
            let mut ec2 = ec2.make_ssh_key().await?;
            tracing::debug!(
                name = %ec2.ssh_key_name,
                path = %ec2.private_key_path.as_ref().unwrap().path().display(),
                "made key"
            );
            assert!(!ec2.ssh_key_name.is_empty());
            assert!(ec2.private_key_path.as_ref().unwrap().path().exists());

            let req = rusoto_ec2::DeleteKeyPairRequest {
                key_name: Some(ec2.ssh_key_name.clone()),
                ..Default::default()
            };
            ec2.client
                .as_mut()
                .unwrap()
                .delete_key_pair(req)
                .await
                .wrap_err(format!(
                    "Could not delete ssh key pair {:?}",
                    ec2.ssh_key_name
                ))?;

            Ok(())
        })
    }

    /// Issues five spot instance requests with the default `Setup` and waits for them all to
    /// be fulfilled.
    ///
    /// Does not clean up; the caller must call `terminate_all`.
    fn do_multi_instance_spot_request<'l>(
        ec2: &'l mut super::RegionLauncher,
    ) -> impl Future<Output = Result<(), Report>> + 'l {
        async move {
            let names = (1..).map(|x| format!("{}", x));
            let setup = Setup::default();
            let ms: Vec<(String, Setup)> = names.zip(itertools::repeat_n(setup, 5)).collect();

            tracing::debug!(num = %ms.len(), "make spot instance requests");
            ec2.make_spot_instance_requests(60 as _, ms).await?;
            assert_eq!(ec2.spot_requests.len(), 5);
            tracing::debug!("wait for spot instance requests");
            ec2.wait_for_spot_instance_requests(None).await?;

            Ok(())
        }
    }

    #[test]
    #[ignore]
    fn multi_instance_spot_request() -> Result<(), Report> {
        let region = "us-east-1";
        let provider = DefaultCredentialsProvider::new()?;

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut ec2 =
                RegionLauncher::new(region, super::AvailabilityZoneSpec::Any, provider, false)
                    .await?;

            // Always tear down before reporting, whether or not the body succeeded.
            let res = do_multi_instance_spot_request(&mut ec2).await;
            ec2.terminate_all().await.unwrap();
            if let Err(e) = res {
                panic!("{:?}", e);
            }

            Ok(())
        })
    }
}
1820}