use crate::vmm::resources::{MountSpec, SymlinkPolicy, VolumeSpec};
/// A 32-byte TSI token, stored both as raw bytes and as a hex string.
///
/// When built by [`parse_tsi_token`] or [`TsiToken::generate`], `hex` is the
/// lowercase encoding of `bytes`. The fields are public, so other construction
/// paths are not checked.
///
/// `Debug` is implemented by hand and redacts both fields, so the token value
/// cannot leak through `{:?}` formatting (for example, in logs of a config
/// struct that contains it). Read `hex` or `bytes` explicitly when the value is
/// actually needed.
#[derive(Clone, PartialEq, Eq)]
pub struct TsiToken {
    /// Hex encoding of `bytes` (lowercase when produced by this module).
    pub hex: String,
    /// Raw token bytes.
    pub bytes: [u8; 32],
}

impl std::fmt::Debug for TsiToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit the token material; prints `TsiToken { .. }`.
        f.debug_struct("TsiToken").finish_non_exhaustive()
    }
}
/// Parses a `--tsi-token` value: exactly 64 hexadecimal characters encoding
/// 32 bytes.
///
/// Upper- and lower-case digits are both accepted. The returned
/// [`TsiToken::hex`] is lowercased, so it matches what [`hex_lower`] produces
/// for the same bytes.
///
/// # Errors
///
/// Returns a `--tsi-token: …` message in two cases:
/// - the input is not exactly 64 characters long (counted as characters, not
///   bytes);
/// - any character is not an ASCII hex digit. The message names the first
///   offending character and its 0-based position.
pub fn parse_tsi_token(s: &str) -> Result<TsiToken, String> {
    // Count chars rather than bytes so the reported length is correct for
    // non-ASCII input (e.g. 32 two-byte chars must not read as "64 chars").
    let n_chars = s.chars().count();
    if n_chars != 64 {
        return Err(format!(
            "--tsi-token: expected 64 hex chars (32 bytes), got {n_chars} chars"
        ));
    }
    if let Some((pos, c)) = s.chars().enumerate().find(|&(_, c)| !c.is_ascii_hexdigit()) {
        return Err(format!(
            "--tsi-token: expected 64 hex chars (32 bytes), found non-hex char {c:?} at position {pos}"
        ));
    }
    // Every char is now an ASCII hex digit, so `hex` is exactly 64 bytes long.
    let hex = s.to_ascii_lowercase();
    // Cannot fail after the validation above. `to_digit(16)` returns a value
    // below 16, so the narrowing cast is lossless.
    let nibble = |d: u8| -> u8 {
        char::from(d)
            .to_digit(16)
            .expect("validated as an ASCII hex digit") as u8
    };
    let mut bytes = [0u8; 32];
    for (byte, pair) in bytes.iter_mut().zip(hex.as_bytes().chunks_exact(2)) {
        *byte = (nibble(pair[0]) << 4) | nibble(pair[1]);
    }
    Ok(TsiToken { hex, bytes })
}
impl TsiToken {
    /// Creates a new random token by reading 32 bytes from `/dev/urandom`.
    ///
    /// The `hex` field is the [`hex_lower`] encoding of those bytes.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from opening or reading `/dev/urandom`.
    pub fn generate() -> std::io::Result<Self> {
        use std::io::Read;
        let mut urandom = std::fs::File::open("/dev/urandom")?;
        let mut bytes = [0u8; 32];
        urandom.read_exact(&mut bytes)?;
        let hex = hex_lower(&bytes);
        Ok(Self { hex, bytes })
    }
}
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte with the high
/// nibble first. For example, `[0x0f, 0xa5]` becomes `"0fa5"`, and an empty
/// slice becomes `""`.
pub fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    out.extend(
        bytes
            .iter()
            .flat_map(|&byte| [byte >> 4, byte & 0x0f])
            .map(|nib| char::from(DIGITS[usize::from(nib)])),
    );
    out
}
/// Appends `supermachine.tsi_token=<hex>` to `cmdline` in place.
///
/// A single space separator is added first unless `cmdline` is empty or
/// already ends with a space. `hex` is appended exactly as given; this
/// function does not validate it.
pub fn append_tsi_token_cmdline(cmdline: &mut String, hex: &str) {
    let needs_separator = cmdline
        .chars()
        .next_back()
        .map_or(false, |last| last != ' ');
    if needs_separator {
        cmdline.push(' ');
    }
    for piece in ["supermachine.tsi_token=", hex] {
        cmdline.push_str(piece);
    }
}
/// Parses a `--volume` argument of the form `HOST:GUEST[:SIZE_BYTES]` into a
/// [`VolumeSpec`].
///
/// The input is split on at most the first two `:`. Everything after the
/// second colon is the optional size, which must parse as a `u64` and is
/// applied with `VolumeSpec::with_size_bytes`.
///
/// # Errors
///
/// Returns a `--volume …` message in any of these cases:
/// - fewer than two fields are present;
/// - HOST or GUEST is empty;
/// - SIZE_BYTES is present but is not a `u64`.
pub fn parse_volume_spec(raw: &str) -> Result<VolumeSpec, String> {
    let parts: Vec<&str> = raw.splitn(3, ':').collect();
    if parts.len() < 2 {
        return Err(format!(
            "--volume expects HOST:GUEST[:SIZE_BYTES], got {raw:?}"
        ));
    }
    let (host, guest) = (parts[0], parts[1]);
    // Same empty-field rejection as `parse_mount_spec`: an empty path is never
    // a meaningful volume endpoint, so report it up front with a clear message.
    if host.is_empty() {
        return Err(format!("--volume host path is empty: {raw:?}"));
    }
    if guest.is_empty() {
        return Err(format!("--volume guest path is empty: {raw:?}"));
    }
    let mut spec = VolumeSpec::new(host, guest);
    if let Some(s) = parts.get(2) {
        let sz = s
            .parse::<u64>()
            .map_err(|_| format!("--volume SIZE_BYTES not a u64: {s:?}"))?;
        spec = spec.with_size_bytes(sz);
    }
    Ok(spec)
}
/// Parses a `--mount` argument of the form `HOST:TAG:GUEST_PATH[:POLICY]` into
/// a [`MountSpec`].
///
/// The input is split on at most the first three `:`, so POLICY receives the
/// rest of the string. POLICY must be one of `deny`, `opaque` or `follow`. When
/// it is omitted, `SymlinkPolicy::default()` is used.
///
/// # Errors
///
/// Returns a `--mount …` message in any of these cases:
/// - the field count is not 3 or 4;
/// - HOST or TAG is empty;
/// - TAG is longer than 35 bytes;
/// - GUEST_PATH is empty or does not start with `/`;
/// - POLICY is not one of the accepted values.
pub fn parse_mount_spec(raw: &str) -> Result<MountSpec, String> {
    let parts: Vec<&str> = raw.splitn(4, ':').collect();
    let (host, tag, guest_path, policy_str) = match parts.len() {
        3 => (parts[0], parts[1], parts[2], None),
        4 => (parts[0], parts[1], parts[2], Some(parts[3])),
        _ => {
            return Err(format!(
                "--mount expects HOST:TAG:GUEST_PATH[:POLICY], got {raw:?}"
            ))
        }
    };
    // An empty HOST cannot name a directory to share. Reject it alongside the
    // other empty-field checks instead of accepting it silently.
    if host.is_empty() {
        return Err(format!("--mount host path is empty: {raw:?}"));
    }
    if tag.is_empty() {
        return Err(format!("--mount tag is empty: {raw:?}"));
    }
    // The limit is measured in bytes (`str::len`), not characters.
    if tag.len() > 35 {
        return Err(format!(
            "--mount tag too long (max 35 bytes, got {}): {raw:?}",
            tag.len()
        ));
    }
    if guest_path.is_empty() {
        return Err(format!("--mount guest_path is empty: {raw:?}"));
    }
    if !guest_path.starts_with('/') {
        return Err(format!(
            "--mount guest_path must be absolute (start with `/`), got {guest_path:?}"
        ));
    }
    let policy = match policy_str {
        None => SymlinkPolicy::default(),
        Some("deny") => SymlinkPolicy::Deny,
        Some("opaque") => SymlinkPolicy::Opaque,
        Some("follow") => SymlinkPolicy::Follow,
        Some(other) => {
            return Err(format!(
                "--mount policy must be one of: deny, opaque, follow (got {other:?})"
            ))
        }
    };
    Ok(MountSpec::new(host, tag, guest_path).with_symlinks(policy))
}
#[cfg(test)]
mod tests {
    // Unit tests for the CLI parsing helpers in this module: TSI token
    // parsing/generation/encoding, kernel-cmdline token injection, and the
    // `--volume` / `--mount` spec parsers.
    use super::*;

    #[test]
    fn tsi_token_decodes_64_hex_to_32_bytes() {
        let hex = "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff";
        let t = parse_tsi_token(hex).unwrap();
        assert_eq!(t.hex, hex);
        assert_eq!(t.bytes[0], 0x00);
        assert_eq!(t.bytes[1], 0x11);
        assert_eq!(t.bytes[15], 0xff);
        assert_eq!(t.bytes[31], 0xff);
    }

    #[test]
    fn tsi_token_uppercase_is_canonicalised_to_lowercase() {
        let up = "AABBCCDDEEFF00112233445566778899AABBCCDDEEFF00112233445566778899";
        let t = parse_tsi_token(up).unwrap();
        assert_eq!(t.hex, up.to_ascii_lowercase());
        assert_eq!(t.bytes[0], 0xaa);
        assert_eq!(parse_tsi_token(&t.hex).unwrap().bytes, t.bytes);
    }

    #[test]
    fn tsi_token_all_zero_is_valid() {
        let z = "0".repeat(64);
        assert_eq!(parse_tsi_token(&z).unwrap().bytes, [0u8; 32]);
    }

    #[test]
    fn hex_lower_renders_each_byte_as_two_lowercase_digits() {
        assert_eq!(hex_lower(&[]), "");
        assert_eq!(hex_lower(&[0x00, 0x0f, 0xa5, 0xff]), "000fa5ff");
        let bytes: [u8; 32] = std::array::from_fn(|i| (i * 7 + 1) as u8);
        let hex = hex_lower(&bytes);
        assert_eq!(hex.len(), 64);
        assert_eq!(parse_tsi_token(&hex).unwrap().bytes, bytes);
    }

    #[test]
    fn generate_yields_canonical_distinct_tokens() {
        let a = TsiToken::generate().expect("urandom");
        let b = TsiToken::generate().expect("urandom");
        assert_eq!(a.hex.len(), 64);
        assert_eq!(a.hex, hex_lower(&a.bytes));
        assert_eq!(parse_tsi_token(&a.hex).unwrap().bytes, a.bytes);
        assert_ne!(a.bytes, b.bytes);
    }

    #[test]
    fn tsi_token_rejects_bad_length() {
        assert!(parse_tsi_token("").is_err());
        assert!(parse_tsi_token("abcd").is_err());
        assert!(parse_tsi_token(&"a".repeat(63)).is_err());
        assert!(parse_tsi_token(&"a".repeat(65)).is_err());
    }

    #[test]
    fn tsi_token_rejects_non_hex() {
        let mut s = "a".repeat(63);
        s.push('z');
        assert!(parse_tsi_token(&s).is_err());
        let s2 = format!("{} ", "a".repeat(63));
        assert!(parse_tsi_token(&s2).is_err());
    }

    #[test]
    fn tsi_token_rejects_non_ascii_even_at_64_bytes() {
        // 32 copies of a two-byte UTF-8 char: 64 bytes but only 32 chars.
        let s = "\u{e9}".repeat(32);
        assert_eq!(s.len(), 64);
        assert!(parse_tsi_token(&s).is_err());
    }

    #[test]
    fn append_token_inserts_space_only_when_needed() {
        let mut c = String::from("console=ttyAMA0 root=/dev/vda");
        append_tsi_token_cmdline(&mut c, "deadbeef");
        assert_eq!(
            c,
            "console=ttyAMA0 root=/dev/vda supermachine.tsi_token=deadbeef"
        );
        let mut c2 = String::from("trailing ");
        append_tsi_token_cmdline(&mut c2, "ab");
        assert_eq!(c2, "trailing supermachine.tsi_token=ab");
        let mut c3 = String::new();
        append_tsi_token_cmdline(&mut c3, "ff");
        assert_eq!(c3, "supermachine.tsi_token=ff");
    }

    #[test]
    fn volume_two_and_three_field_forms() {
        let v = parse_volume_spec("/host/db:/var/lib/db").unwrap();
        assert_eq!(v.host_path, "/host/db");
        assert_eq!(v.guest_path, "/var/lib/db");
        let v3 = parse_volume_spec("/h:/g:4096").unwrap();
        assert_eq!(v3.size_bytes, 4096);
    }

    #[test]
    fn volume_rejects_missing_guest_and_bad_size() {
        assert!(parse_volume_spec("/only-host").is_err());
        assert!(parse_volume_spec("").is_err());
        assert!(parse_volume_spec("/h:/g:notanum").is_err());
        assert!(parse_volume_spec("/h:/g:").is_err(), "empty size");
    }

    #[test]
    fn mount_three_field_defaults_to_opaque() {
        let m = parse_mount_spec("/host/share:work:/mnt/work").unwrap();
        assert_eq!(m.host_path, "/host/share");
        assert_eq!(m.guest_tag, "work");
        assert_eq!(m.guest_path, "/mnt/work");
        assert_eq!(m.symlinks, SymlinkPolicy::Opaque);
    }

    #[test]
    fn mount_policy_each_value_parses() {
        assert_eq!(
            parse_mount_spec("/h:t:/g:deny").unwrap().symlinks,
            SymlinkPolicy::Deny
        );
        assert_eq!(
            parse_mount_spec("/h:t:/g:opaque").unwrap().symlinks,
            SymlinkPolicy::Opaque
        );
        assert_eq!(
            parse_mount_spec("/h:t:/g:follow").unwrap().symlinks,
            SymlinkPolicy::Follow
        );
    }

    #[test]
    fn mount_rejects_malformed() {
        assert!(parse_mount_spec("/h:t").is_err(), "too few fields");
        assert!(parse_mount_spec("/h::/g").is_err(), "empty tag");
        assert!(
            parse_mount_spec(&format!("/h:{}:/g", "x".repeat(36))).is_err(),
            "tag > 35 bytes"
        );
        assert!(
            parse_mount_spec("/h:t:relative").is_err(),
            "non-absolute guest path"
        );
        assert!(parse_mount_spec("/h:t:/g:bogus").is_err(), "unknown policy");
    }
}