ffs_cli/providers/aws/mod.rs

1use std::io::prelude::*;
2use std::net::TcpStream;
3use std::path::Path;
4
5use async_trait::async_trait;
6use aws_config::meta::region::RegionProviderChain;
7use aws_sdk_ec2::client::Waiters;
8use aws_sdk_ec2::config::Region;
9use aws_sdk_ec2::types::{InstanceType, ResourceType, Tag, TagSpecification};
10use aws_sdk_ec2::Client;
11use ssh2::Session;
12
13use super::Provider;
14use crate::config::Config;
15use crate::jobs::Job;
16
/// EC2-backed implementation of the `Provider` trait.
///
/// The provider holds no state of its own; every operation builds its EC2
/// client from the CLI configuration at call time.
#[derive(Debug, Clone)]
pub struct AWSProvider {}
19
20impl Default for AWSProvider {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl AWSProvider {
27    #[must_use]
28    pub const fn new() -> Self {
29        Self {}
30    }
31}
32
33#[async_trait]
34impl Provider for AWSProvider {
35    async fn start_job(&self, name: &str) -> Result<Job, Box<dyn std::error::Error + Send + Sync>> {
36        let cfg = config();
37
38        let region_provider =
39            RegionProviderChain::first_try(Some(Region::new(cfg.location.clone())))
40                .or_default_provider();
41        let shared_config = aws_config::from_env().region(region_provider).load().await;
42        let client = Client::new(&shared_config);
43
44        let tag_spec = TagSpecification::builder()
45            .resource_type(ResourceType::Instance)
46            .tags(Tag::builder().key("Name").value(name).build())
47            .build();
48
49        let run_out = client
50            .run_instances()
51            .image_id(cfg.image.clone())
52            .instance_type(InstanceType::from(cfg.server_type.as_str()))
53            .min_count(1)
54            .max_count(1)
55            .key_name(cfg.ssh_key_name.clone())
56            .tag_specifications(tag_spec)
57            .send()
58            .await?;
59
60        let instance = run_out.instances().first().ok_or("no instance created")?;
61
62        let instance_id = instance
63            .instance_id()
64            .ok_or("missing instance id")?
65            .to_string();
66
67        // Wait until instance is running
68        let _ = client
69            .wait_until_instance_running()
70            .instance_ids(instance_id.clone())
71            .wait(std::time::Duration::from_secs(120))
72            .await;
73
74        let desc = client
75            .describe_instances()
76            .instance_ids(instance_id.clone())
77            .send()
78            .await?;
79
80        let ipv4 = desc
81            .reservations()
82            .first()
83            .and_then(|res| res.instances().first())
84            .and_then(|inst| inst.public_ip_address())
85            .unwrap_or("")
86            .to_string();
87
88        let job = Job {
89            id: instance_id.clone(),
90            ipv4: ipv4.clone(),
91            name: Some(name.to_string()),
92        };
93
94        let key_path = cfg.ssh_key_path.clone();
95        tokio::spawn(async move {
96            let _ = tokio::task::spawn_blocking(move || super::install_over_ssh(&ipv4, &key_path))
97                .await;
98        });
99
100        Ok(job)
101    }
102
103    async fn get_job(
104        &self,
105        job_id: &str,
106    ) -> Result<Option<Job>, Box<dyn std::error::Error + Send + Sync>> {
107        let cfg = config();
108        let region_provider =
109            RegionProviderChain::first_try(Some(Region::new(cfg.location.clone())))
110                .or_default_provider();
111        let shared_config = aws_config::from_env().region(region_provider).load().await;
112        let client = Client::new(&shared_config);
113
114        let desc = client
115            .describe_instances()
116            .instance_ids(job_id)
117            .send()
118            .await?;
119
120        if let Some(reservation) = desc.reservations().first() {
121            if let Some(instance) = reservation.instances().first() {
122                let name_tag = instance
123                    .tags()
124                    .iter()
125                    .find(|t| t.key() == Some("Name"))
126                    .and_then(|t| t.value())
127                    .map(ToString::to_string);
128
129                return Ok(Some(Job {
130                    id: instance.instance_id().unwrap_or_default().to_string(),
131                    ipv4: instance.public_ip_address().unwrap_or_default().to_string(),
132                    name: name_tag,
133                }));
134            }
135        }
136
137        Ok(None)
138    }
139
140    async fn stop_job(
141        &self,
142        job_id: &str,
143    ) -> Result<Job, Box<dyn std::error::Error + Send + Sync>> {
144        let cfg = config();
145        let region_provider =
146            RegionProviderChain::first_try(Some(Region::new(cfg.location.clone())))
147                .or_default_provider();
148        let shared_config = aws_config::from_env().region(region_provider).load().await;
149        let client = Client::new(&shared_config);
150
151        client
152            .terminate_instances()
153            .instance_ids(job_id)
154            .send()
155            .await?;
156
157        Ok(Job {
158            id: job_id.to_string(),
159            ipv4: String::new(),
160            name: None,
161        })
162    }
163
164    async fn list_jobs(&self) -> Result<Vec<Job>, Box<dyn std::error::Error + Send + Sync>> {
165        let cfg = config();
166        let region_provider =
167            RegionProviderChain::first_try(Some(Region::new(cfg.location.clone())))
168                .or_default_provider();
169        let shared_config = aws_config::from_env().region(region_provider).load().await;
170        let client = Client::new(&shared_config);
171
172        let desc = client.describe_instances().send().await?;
173
174        let mut jobs = Vec::new();
175        for reservation in desc.reservations() {
176            for instance in reservation.instances() {
177                let name_tag = instance
178                    .tags()
179                    .iter()
180                    .find(|t| t.key() == Some("Name"))
181                    .and_then(|t| t.value())
182                    .map(ToString::to_string);
183
184                jobs.push(Job {
185                    id: instance.instance_id().unwrap_or_default().to_string(),
186                    ipv4: instance.public_ip_address().unwrap_or_default().to_string(),
187                    name: name_tag,
188                });
189            }
190        }
191
192        Ok(jobs)
193    }
194
195    async fn tail(
196        &self,
197        job_id: &str,
198        filename: &str,
199    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
200        let cfg = config();
201
202        if let Some(job) = self.get_job(job_id).await? {
203            let tcp = TcpStream::connect((job.ipv4.as_str(), 22))?;
204            let mut sess = Session::new()?;
205            sess.set_tcp_stream(tcp);
206            sess.handshake()?;
207            sess.userauth_pubkey_file("root", None, Path::new(&cfg.ssh_key_path), None)?;
208
209            let mut channel = sess.channel_session()?;
210            channel.exec(&format!("cat {filename}"))?;
211
212            let mut s = String::new();
213            channel.read_to_string(&mut s)?;
214            println!("{s}");
215            channel.wait_close()?;
216            println!("{}", channel.exit_status()?);
217        }
218
219        Ok(())
220    }
221
222    async fn scp(
223        &self,
224        _job_id: &str,
225        _filename: &str,
226        _destination: &str,
227    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
228        Ok(())
229    }
230}
231
232fn config() -> Config {
233    Config::new()
234}