use clap::{Parser, Subcommand};
use freta::{argparse::parse_key_val, Client, Error, Image, ImageFormat, Result};
use powershell_script::PsScriptBuilder;
use serde::Deserialize;
use std::{io::stderr, path::PathBuf};
use tracing::{info, level_filters::LevelFilter};
use tracing_subscriber::EnvFilter;
use uuid::Uuid;
/// Checkpoint running Hyper-V virtual machines and upload them to Freta as images.
#[derive(Parser)]
struct Args {
    #[command(subcommand)]
    command: Commands,
}
// Subcommands accepted by the CLI; the `///` variant docs are shown by `--help`.
#[derive(Subcommand)]
enum Commands {
    /// List the names of the currently running Hyper-V VMs
    List,
    /// Checkpoint a running Hyper-V VM, upload the checkpoint to Freta as a VMRS image, then delete the checkpoint
    ImageVm(ImageOpt),
}
// Options for the `image-vm` subcommand; the `///` field docs are shown by `--help`.
#[derive(Parser)]
struct ImageOpt {
    /// Name of the running Hyper-V VM to image
    vm_name: String,
    /// Tag to attach to the uploaded image, as KEY=VALUE (may be repeated). A `name` tag set to the VM name is always added
    #[clap(long, value_name = "KEY=VALUE", value_parser = parse_key_val::<String, String>, action = clap::ArgAction::Append)]
    tags: Option<Vec<(String, String)>>,
    /// Monitor the uploaded image after the upload completes
    #[arg(long)]
    monitor: bool,
}
/// Subset of a Hyper-V checkpoint, as emitted by
/// `get-vmsnapshot ... | select id, path | convertto-json`.
///
/// The `alias` attributes accept the capitalised property names (`Id`,
/// `Path`) as well as the lowercase field names.
#[derive(Deserialize)]
struct Snapshot {
    /// Checkpoint identifier; the checkpoint's file is named `<id>.VMRS`.
    #[serde(alias = "Id")]
    id: String,
    /// Base directory of the checkpoint; the `.VMRS` file is looked up in its
    /// `Snapshots` subdirectory.
    #[serde(alias = "Path")]
    path: PathBuf,
}
fn run<Q>(query: Q) -> Result<String>
where
Q: AsRef<str>,
{
let ps = PsScriptBuilder::new()
.no_profile(true)
.non_interactive(true)
.hidden(true)
.print_commands(false)
.build();
let output = ps
.run(query.as_ref())
.map_err(|e| Error::Other("launching powershell failed", format!("{e:?}")))?;
if !output.success() {
return Err(Error::Other(
"command failed",
output
.stderr()
.or_else(|| output.stdout())
.unwrap_or_else(|| "unknown error".to_string()),
));
}
Ok(output.stdout().unwrap_or_default())
}
/// One running VM as reported by `get-vm ... | select vmname,vmid | convertto-json`.
///
/// The `alias` attributes accept the `VMName`/`VMId` property names as well as
/// the lowercase field names.
#[derive(Deserialize, Debug)]
struct Entry {
    /// Hyper-V display name of the VM.
    #[serde(alias = "VMName")]
    name: String,
    /// Hyper-V VM id; later passed to `get-vm -id`.
    #[serde(alias = "VMId")]
    id: Uuid,
}
/// A list of [`Entry`] values, deserialized from a JSON array of VM objects.
#[derive(Deserialize, Debug)]
struct Entries(Vec<Entry>);
/// Lists the Hyper-V VMs whose state is `running`.
///
/// The list comes from a PowerShell `get-vm` pipeline that filters on state and
/// serializes the name and id of each VM with `convertto-json`. No running VMs
/// yields an empty list.
///
/// # Errors
///
/// Returns an error if the PowerShell command fails, or if its output is
/// neither empty, a single VM object, nor an array of VM objects.
fn list_vms() -> Result<Entries> {
    let out = run("get-vm | select vmname, vmid, state | where state -eq 'running' | select vmname,vmid | convertto-json")?;
    let out = out.trim();

    // When nothing matches the filter the pipeline produces no output, which is
    // not valid JSON; report that as "no running VMs" instead of a parse error.
    if out.is_empty() {
        return Ok(Entries(Vec::new()));
    }

    // `convertto-json` emits a bare object rather than a one-element array when
    // exactly one VM matches, so try the single-object shape first.
    if let Ok(entry) = serde_json::from_str::<Entry>(out) {
        return Ok(Entries(vec![entry]));
    }

    Ok(serde_json::from_str::<Entries>(out)?)
}
fn get_vm_id(vm_name: &str) -> Result<Uuid> {
for vm in list_vms()?.0 {
if vm.name == vm_name {
return Ok(vm.id);
}
}
Err(Error::Other(
"unable to find running VM",
vm_name.to_string(),
))
}
async fn create_snapshot(
vm_name: String,
mut tags: Vec<(String, String)>,
client: &Client,
) -> Result<Image> {
let vm_id = get_vm_id(&vm_name)?;
let snapshot_id = Uuid::new_v4();
info!("creating hyperv snapshot id: {}", snapshot_id);
run(format!(
"get-vm -id {vm_id} | checkpoint-vm -snapshotname {snapshot_id}"
))?;
let output = run(format!(
"get-vm -id {vm_id} | get-vmsnapshot -name {snapshot_id} | select id, path | convertto-json"
))?;
let snapshot: Snapshot = serde_json::from_str(&output)?;
let path = snapshot
.path
.join("Snapshots")
.join(format!("{}.VMRS", snapshot.id));
tags.push(("name".to_string(), vm_name.clone()));
let image = client.images_upload(ImageFormat::Vmrs, tags, path).await?;
info!("image_id: {}", image.image_id);
run(format!(
"get-vm -id {vm_id} | get-vmsnapshot -name {snapshot_id} | remove-vmsnapshot"
))?;
Ok(image)
}
/// Entry point: configures logging, parses the command line, and runs the
/// selected subcommand.
///
/// Logs go to stderr. The filter defaults to `INFO` and can be overridden from
/// the environment via `EnvFilter::from_env`.
#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::builder()
                .with_default_directive(LevelFilter::INFO.into())
                .from_env()
                .map_err(|e| Error::Other("invalid env filter", e.to_string()))?,
        )
        .with_writer(stderr)
        .init();

    let cmd = Args::parse();
    match cmd.command {
        Commands::List => {
            // Listing is purely local (PowerShell) and never touches Freta, so
            // no client is constructed here. Creating one would make `list`
            // fail whenever the client cannot be set up.
            for vm in list_vms()?.0 {
                info!("{}", vm.name);
            }
        }
        Commands::ImageVm(opts) => {
            let client = Client::new().await?;
            let image =
                create_snapshot(opts.vm_name, opts.tags.unwrap_or_default(), &client).await?;
            if opts.monitor {
                client.images_monitor(image.image_id).await?;
            }
        }
    }
    Ok(())
}