use std::collections::HashMap;
/// Permission flags carried in the `sp` field of a SAS token.
///
/// `Default` yields a value with every flag cleared (no access), which combines well with
/// struct-update syntax: `SasPermissions { list: true, ..Default::default() }`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SasPermissions {
    /// Read (`r`).
    pub read: bool,
    /// Write (`w`).
    pub write: bool,
    /// Delete (`d`).
    pub delete: bool,
    /// List (`l`).
    pub list: bool,
    /// Add (`a`).
    pub add: bool,
    /// Create (`c`).
    pub create: bool,
}

impl SasPermissions {
    /// Only `read` is set.
    pub fn read_only() -> Self {
        Self {
            read: true,
            ..Self::default()
        }
    }

    /// `read` and `write` are set; every other flag is cleared.
    pub fn read_write() -> Self {
        Self {
            read: true,
            write: true,
            ..Self::default()
        }
    }

    /// Every flag is set.
    pub fn full() -> Self {
        Self {
            read: true,
            write: true,
            delete: true,
            list: true,
            add: true,
            create: true,
        }
    }

    /// Each flag paired with its `sp` character, in the order the characters are emitted.
    ///
    /// Keeping this in one table means encoding cannot drift from the field list.
    fn flag_chars(&self) -> [(bool, char); 6] {
        [
            (self.read, 'r'),
            (self.write, 'w'),
            (self.delete, 'd'),
            (self.list, 'l'),
            (self.add, 'a'),
            (self.create, 'c'),
        ]
    }

    /// Encodes the set flags as an `sp` string, always in `rwdlac` order (e.g. `"rw"`).
    ///
    /// Returns an empty string when no flag is set.
    // NOTE(review): `rwdlac` appears to be a subsequence of the account-SAS permission order
    // documented by Azure (`rwdxylacuptfi`) — keep that order if flags are added.
    pub fn as_permission_string(&self) -> String {
        self.flag_chars()
            .iter()
            .filter(|(enabled, _)| *enabled)
            .map(|&(_, ch)| ch)
            .collect()
    }

    /// Decodes an `sp` string such as one produced by [`Self::as_permission_string`].
    ///
    /// Characters may appear in any order and may repeat; characters this type does not
    /// model are ignored rather than rejected.
    pub fn from_permission_string(s: &str) -> Self {
        Self {
            read: s.contains('r'),
            write: s.contains('w'),
            delete: s.contains('d'),
            list: s.contains('l'),
            add: s.contains('a'),
            create: s.contains('c'),
        }
    }
}
/// Kind of Azure Storage resource a SAS token is scoped to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SasResource {
    /// A single blob.
    Blob,
    /// A blob container.
    Container,
    /// A queue.
    Queue,
    /// A table.
    Table,
}

impl SasResource {
    /// One-letter code for this resource kind (`"b"`, `"c"`, `"q"` or `"t"`).
    ///
    /// `generate_sas_token` writes this code into both the string-to-sign and the `ss=`
    /// query parameter.
    pub fn as_code(&self) -> &'static str {
        match *self {
            Self::Blob => "b",
            Self::Container => "c",
            Self::Queue => "q",
            Self::Table => "t",
        }
    }
}
/// Inputs for `generate_sas_token` and `build_sas_url`.
#[derive(Debug, Clone)]
pub struct AzureSasParams {
    /// Storage account name; used as the URL host prefix and as the first signed field.
    pub account_name: String,
    /// Container name; part of the signed resource path and the URL path.
    pub container: String,
    /// Blob name within `container`; `None` signs and addresses only the container path.
    pub blob: Option<String>,
    /// Permissions to grant (emitted as `sp`).
    pub permissions: SasPermissions,
    /// Expiry time in Unix seconds (emitted as `se`).
    pub expiry: u64,
    /// Optional start time in Unix seconds (emitted as `st` when present).
    pub start: Option<u64>,
    /// Resource kind; its one-letter code is emitted as `ss`.
    pub resource: SasResource,
}
/// Errors for the SAS helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum AzureError {
    /// A token could not be parsed; `parse_sas_token` returns this for a segment without `=`.
    #[error("parse error: {0}")]
    Parse(String),
    /// A required field is absent.
    // NOTE(review): not constructed by any function visible in this file — confirm it is
    // used elsewhere.
    #[error("missing field: {0}")]
    MissingField(String),
    /// The token is past its expiry.
    // NOTE(review): not constructed in the visible code; `is_sas_valid` reports expiry by
    // returning `false` instead.
    #[error("expired token")]
    Expired,
}
/// Storage service version, emitted as `sv` and included in the string-to-sign.
const SAS_VERSION: &str = "2021-06-08";
/// Builds the query-string portion of a SAS token for `params`, signed with `account_key`.
///
/// The result has the form `sv=..&ss=..&srt=o&sp=..[&st=..]&se=..&sig=..`, with no leading
/// `?`. Every emitted field that depends on `params` (`ss`, `sp`, `st`, `se`) is also part of
/// the string-to-sign, so none of them can be altered without invalidating `sig`.
///
/// NOTE(review): the signature comes from `mock_sign` and is hex-encoded; this is a
/// placeholder, not Azure's HMAC-SHA256 / base64 signature.
pub fn generate_sas_token(params: &AzureSasParams, account_key: &[u8]) -> String {
    let expiry_str = unix_to_iso8601(params.expiry);
    let start_str = params.start.map(unix_to_iso8601);
    let perm_str = params.permissions.as_permission_string();
    let resource_code = params.resource.as_code();
    let signed_resource = match &params.blob {
        Some(blob) => format!("{}/{}/{}", params.account_name, params.container, blob),
        None => format!("{}/{}", params.account_name, params.container),
    };
    // `st` is emitted below when present, so it must be signed too; otherwise a token holder
    // could move the start time freely. An absent start is signed as an empty line.
    let string_to_sign = format!(
        "{account}\n{permissions}\n{start}\n{expiry}\n{resource}\n{version}\n{resource_code}",
        account = params.account_name,
        permissions = perm_str,
        start = start_str.as_deref().unwrap_or(""),
        expiry = expiry_str,
        resource = signed_resource,
        version = SAS_VERSION,
        resource_code = resource_code,
    );
    let sig_hex = hex::encode(mock_sign(string_to_sign.as_bytes(), account_key));

    let mut parts: Vec<String> = Vec::with_capacity(7);
    parts.push(format!("sv={}", urlencoding::encode(SAS_VERSION)));
    parts.push(format!("ss={}", resource_code));
    parts.push("srt=o".to_owned());
    parts.push(format!("sp={}", urlencoding::encode(&perm_str)));
    if let Some(start) = &start_str {
        parts.push(format!("st={}", urlencoding::encode(start)));
    }
    parts.push(format!("se={}", urlencoding::encode(&expiry_str)));
    parts.push(format!("sig={}", urlencoding::encode(&sig_hex)));
    parts.join("&")
}
/// Builds a full blob-endpoint URL for `params` with a freshly generated SAS token appended.
///
/// Produces `https://{account}.blob.core.windows.net/{container}[/{blob}]?{token}`, where the
/// container and blob names are percent-encoded with `urlencoding::encode` and the token comes
/// from `generate_sas_token`.
///
/// NOTE(review): `urlencoding::encode` also escapes `/`, so a blob name with virtual
/// directories (e.g. `dir/file.bin`) is emitted as one `dir%2Ffile.bin` segment — confirm
/// this is the intended addressing.
pub fn build_sas_url(params: &AzureSasParams, account_key: &[u8]) -> String {
    let token = generate_sas_token(params, account_key);
    let container = urlencoding::encode(&params.container);
    let blob_path = params
        .blob
        .as_deref()
        .map(|blob| format!("/{}", urlencoding::encode(blob)))
        .unwrap_or_default();
    format!(
        "https://{account}.blob.core.windows.net/{container}{blob_path}?{token}",
        account = params.account_name,
    )
}
/// Parses a `key=value&key=value` SAS query string into a map.
///
/// Empty segments (e.g. from a trailing `&`) are skipped. Only the first `=` in a segment
/// separates key from value. Values are percent-decoded; if decoding fails the raw value is
/// kept unchanged. Keys are stored verbatim, and a repeated key keeps its last value.
///
/// # Errors
/// Returns [`AzureError::Parse`] for the first non-empty segment that contains no `=`.
pub fn parse_sas_token(token: &str) -> Result<HashMap<String, String>, AzureError> {
    token
        .split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let (key, raw) = segment.split_once('=').ok_or_else(|| {
                AzureError::Parse(format!("missing '=' in token segment: {segment}"))
            })?;
            let decoded = match urlencoding::decode(raw) {
                Ok(text) => text.into_owned(),
                Err(_) => raw.to_owned(),
            };
            Ok((key.to_owned(), decoded))
        })
        .collect()
}
pub fn is_sas_valid(token_params: &HashMap<String, String>, current_time: u64) -> bool {
match token_params.get("se") {
Some(se) => {
let expiry = iso8601_to_unix(se).unwrap_or(0);
current_time < expiry
}
None => false,
}
}
/// Formats Unix seconds as a UTC timestamp of the form `YYYY-MM-DDThh:mm:ssZ`.
///
/// The year is zero-padded to at least four digits. Years past 9999 print with more digits
/// and are not accepted back by `iso8601_to_unix`, which expects exactly four.
fn unix_to_iso8601(ts: u64) -> String {
    let days = ts / 86400;
    let time_of_day = ts % 86400;
    let hh = time_of_day / 3600;
    let mm = (time_of_day % 3600) / 60;
    let ss = time_of_day % 60;
    let (year, month, day) = days_to_ymd(days);
    format!("{year:04}-{month:02}-{day:02}T{hh:02}:{mm:02}:{ss:02}Z")
}

/// Parses the leading `YYYY-MM-DDThh:mm:ss` of a UTC timestamp into Unix seconds.
///
/// The characters at the separator positions are not checked, and anything after the 19th
/// byte (e.g. a trailing `Z`) is ignored. Returns `None` — never panics — when:
/// - the input is shorter than 19 bytes;
/// - a digit position holds anything other than an ASCII digit;
/// - a field is out of range;
/// - the date does not exist (e.g. Feb 30);
/// - the date is before 1970-01-01, which cannot be represented as `u64` seconds.
fn iso8601_to_unix(s: &str) -> Option<u64> {
    // Work on bytes: slicing the `&str` directly panics when an index falls inside a
    // multi-byte character, and this input can come from a decoded, untrusted token.
    let bytes = s.as_bytes();
    if bytes.len() < 19 {
        return None;
    }
    // Fixed-width unsigned decimal at `bytes[start..end]`. Every byte must be an ASCII digit,
    // so signs like `+` are rejected. `then` is lazy so `b - b'0'` only runs for digits.
    let field = |start: usize, end: usize| -> Option<u64> {
        bytes.get(start..end)?.iter().try_fold(0u64, |acc, &b| {
            b.is_ascii_digit().then(|| acc * 10 + u64::from(b - b'0'))
        })
    };
    let year = field(0, 4)?;
    let month = field(5, 7)?;
    let day = field(8, 10)?;
    let hh = field(11, 13)?;
    let mm = field(14, 16)?;
    let ss = field(17, 19)?;
    // `ymd_to_days` underflows before 1970 and for day 0, so reject those up front.
    if year < 1970 || !(1..=12).contains(&month) || day == 0 || hh > 23 || mm > 59 || ss > 59 {
        return None;
    }
    let days = ymd_to_days(year, month, day);
    // Reject impossible dates (e.g. Feb 30) that would otherwise roll into the next month.
    if days_to_ymd(days) != (year, month, day) {
        return None;
    }
    Some(days * 86400 + hh * 3600 + mm * 60 + ss)
}

/// Converts days since 1970-01-01 into a proleptic-Gregorian `(year, month, day)`.
///
/// Unsigned adaptation of Howard Hinnant's `civil_from_days`. The day count is shifted to
/// start at 0000-03-01 (719468 days before the Unix epoch), so each computed year ends with
/// February and its leap day. The calendar is then split into 400-year eras of 146097 days.
fn days_to_ymd(days: u64) -> (u64, u64, u64) {
    let z = days + 719468;
    let era = z / 146097;
    let doe = z % 146097; // day of era, 0..=146096
    let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; // year of era, 0..=399
    let y = yoe + era * 400;
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, 0..=365
    let mp = (5 * doy + 2) / 153; // March-based month, 0..=11
    let d = doy - (153 * mp + 2) / 5 + 1;
    let m = if mp < 10 { mp + 3 } else { mp - 9 };
    // January and February belong to the next civil year in the March-based scheme.
    let y = if m <= 2 { y + 1 } else { y };
    (y, m, d)
}

/// Converts a proleptic-Gregorian date into days since 1970-01-01 (inverse of `days_to_ymd`).
///
/// Unsigned adaptation of Howard Hinnant's `days_from_civil`. Callers must pass
/// `1 <= m <= 12`, `d >= 1`, and a date on or after 1970-01-01. Otherwise the result is
/// wrong, and the unsigned subtractions can underflow (a panic in debug builds).
fn ymd_to_days(y: u64, m: u64, d: u64) -> u64 {
    // January and February count as months 10 and 11 of the previous March-based year.
    let y = if m <= 2 { y - 1 } else { y };
    let era = y / 400;
    let yoe = y % 400;
    let doy = (153 * (if m > 2 { m - 3 } else { m + 9 }) + 2) / 5 + d - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    era * 146097 + doe - 719468
}
/// Placeholder signer that XOR-folds `data`, mixed with a repeating `key`, into 32 bytes.
///
/// Byte `i` of `data` is XORed with `key[i % key.len()]` (or used unchanged when `key` is
/// empty) and folded into output slot `i % 32`. This is not a cryptographic MAC.
fn mock_sign(data: &[u8], key: &[u8]) -> [u8; 32] {
    // `max(1)` keeps the modulus non-zero. With an empty key, `get` returns `None` and the
    // data byte is folded in as-is.
    let key_period = key.len().max(1);
    data.iter()
        .enumerate()
        .fold([0u8; 32], |mut digest, (pos, &byte)| {
            let key_byte = key.get(pos % key_period).copied().unwrap_or(0);
            digest[pos % 32] ^= byte ^ key_byte;
            digest
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_read_only_permission_string() {
        let perm = SasPermissions::read_only();
        let s = perm.as_permission_string();
        assert!(s.contains('r'), "read_only must contain 'r'");
        assert!(!s.contains('w'), "read_only must not contain 'w'");
        assert!(!s.contains('d'), "read_only must not contain 'd'");
        assert!(!s.contains('l'), "read_only must not contain 'l'");
        assert!(!s.contains('a'), "read_only must not contain 'a'");
        assert!(!s.contains('c'), "read_only must not contain 'c'");
    }

    #[test]
    fn test_full_permission_string() {
        let s = SasPermissions::full().as_permission_string();
        for ch in ['r', 'w', 'd', 'l', 'a', 'c'] {
            assert!(s.contains(ch), "full must contain '{ch}'");
        }
    }

    #[test]
    fn test_generate_sas_token_contains_required_fields() {
        let params = AzureSasParams {
            account_name: "account".into(),
            container: "container".into(),
            blob: Some("file.bin".into()),
            permissions: SasPermissions::read_write(),
            expiry: 9_999_999_999,
            start: None,
            resource: SasResource::Blob,
        };
        let token = generate_sas_token(&params, b"fake-key");
        assert!(token.contains("sv="), "token must contain sv=");
        assert!(token.contains("sp="), "token must contain sp=");
        assert!(token.contains("se="), "token must contain se=");
        assert!(token.contains("sig="), "token must contain sig=");
    }

    #[test]
    fn test_parse_sas_token_round_trip() {
        let params = AzureSasParams {
            account_name: "roundtrip".into(),
            container: "c".into(),
            blob: None,
            permissions: SasPermissions::full(),
            expiry: 9_000_000_000,
            start: None,
            resource: SasResource::Container,
        };
        let token = generate_sas_token(&params, b"key");
        let map = parse_sas_token(&token).expect("parse");
        assert_eq!(map.get("sv").map(|s| s.as_str()), Some(SAS_VERSION));
        let perm_encoded = SasPermissions::full().as_permission_string();
        assert_eq!(
            map.get("sp").map(|s| s.as_str()),
            Some(perm_encoded.as_str())
        );
    }

    #[test]
    fn test_is_sas_valid_expiry() {
        let future_params = AzureSasParams {
            account_name: "a".into(),
            container: "c".into(),
            blob: None,
            permissions: SasPermissions::read_only(),
            expiry: 9_999_999_999,
            start: None,
            resource: SasResource::Blob,
        };
        let future_token = generate_sas_token(&future_params, b"k");
        let future_map = parse_sas_token(&future_token).expect("parse future");
        let past_params = AzureSasParams {
            account_name: "a".into(),
            container: "c".into(),
            blob: None,
            permissions: SasPermissions::read_only(),
            expiry: 1_000_000,
            start: None,
            resource: SasResource::Blob,
        };
        let past_token = generate_sas_token(&past_params, b"k");
        let past_map = parse_sas_token(&past_token).expect("parse past");
        let now: u64 = 1_744_000_000;
        assert!(
            is_sas_valid(&future_map, now),
            "future token should be valid"
        );
        assert!(
            !is_sas_valid(&past_map, now),
            "past token should be expired"
        );
    }

    #[test]
    fn test_is_sas_valid_missing_se() {
        let map: HashMap<String, String> = [("sv".into(), SAS_VERSION.into())].into();
        assert!(!is_sas_valid(&map, 1_000_000));
    }

    #[test]
    fn test_permission_round_trip() {
        for perm in [
            SasPermissions::read_only(),
            SasPermissions::read_write(),
            SasPermissions::full(),
        ] {
            let s = perm.as_permission_string();
            let decoded = SasPermissions::from_permission_string(&s);
            assert_eq!(decoded, perm, "round-trip failed for '{s}'");
        }
    }

    #[test]
    fn test_build_sas_url_format() {
        let params = AzureSasParams {
            account_name: "myaccount".into(),
            container: "mycontainer".into(),
            blob: Some("blob.bin".into()),
            permissions: SasPermissions::read_only(),
            expiry: 9_999_999_999,
            start: None,
            resource: SasResource::Blob,
        };
        let url = build_sas_url(&params, b"key");
        assert!(url.starts_with("https://myaccount.blob.core.windows.net/"));
        assert!(url.contains('?'), "URL must contain a query string");
        assert!(url.contains("sig="), "URL query must contain sig=");
    }

    #[test]
    fn test_different_keys_produce_different_signatures() {
        let params = AzureSasParams {
            account_name: "a".into(),
            container: "c".into(),
            blob: None,
            permissions: SasPermissions::read_only(),
            expiry: 9_000_000_000,
            start: None,
            resource: SasResource::Blob,
        };
        let t1 = generate_sas_token(&params, b"key-one");
        let t2 = generate_sas_token(&params, b"key-two");
        let m1 = parse_sas_token(&t1).expect("parse 1");
        let m2 = parse_sas_token(&t2).expect("parse 2");
        assert_ne!(
            m1.get("sig"),
            m2.get("sig"),
            "different keys must produce different signatures"
        );
    }
}