use crate::auth::federation::{generate_session_keypair, request_security_token};
use crate::auth::imds::ImdsClient;
use crate::auth::provider::AuthProvider;
use crate::auth::x509_utils::extract_tenant_id;
use crate::core::region::Region;
use chrono::{DateTime, Duration, Utc};
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Per-session credentials. All three fields are replaced together whenever the security
/// token is refreshed.
struct SessionState {
    /// Instant at which `security_token` expires, as reported by the token exchange.
    expires_at: DateTime<Utc>,
    /// Security token embedded in the key id handed out by the provider (`ST$<token>`).
    security_token: String,
    /// PEM-encoded private half of the session keypair whose public key the token was issued
    /// for; this is what the provider returns as its private key.
    session_private_key_pem: String,
}
/// Auth provider backed by credentials from the instance metadata service (IMDS).
///
/// The leaf/intermediate certificates fetched at construction are exchanged for a short-lived
/// security token plus session keypair, which are refreshed on demand as the token nears expiry.
pub struct InstancePrincipalsAuthProvider {
    /// Region reported by IMDS; used for token exchanges and exposed via `region()`.
    region: Region,
    /// Tenancy id extracted from the leaf certificate.
    tenancy_id: String,
    // Instance credentials fetched from IMDS once and reused for every token exchange.
    leaf_certificate: String,
    leaf_private_key_pem: String,
    intermediate_certificates: Vec<String>,
    /// Current security token and session key, behind an async `RwLock` so a refresh can
    /// replace them while other callers read.
    session_state: Arc<RwLock<SessionState>>,
    /// Handle to the runtime `new` was awaited on, used to drive async refreshes from the
    /// synchronous `AuthProvider` methods.
    runtime_handle: tokio::runtime::Handle,
}
impl InstancePrincipalsAuthProvider {
    /// Margin before the token's reported expiry at which it is already treated as expired,
    /// so it gets refreshed before requests start being rejected.
    const REFRESH_BUFFER_MINUTES: i64 = 5;

    /// Builds a provider from credentials served by the instance metadata service (IMDS).
    ///
    /// Fetches the region, leaf certificate, leaf private key and intermediate certificates
    /// from IMDS, extracts the tenancy id from the leaf certificate, generates a fresh session
    /// keypair and exchanges the leaf credentials for an initial security token bound to the
    /// session public key.
    ///
    /// The current Tokio runtime handle is captured so that the synchronous [`AuthProvider`]
    /// methods can later drive token refreshes.
    ///
    /// # Errors
    /// Returns `OciError::ConfigError` if IMDS reports a region that [`Region`] cannot parse,
    /// and propagates any error from the IMDS client, tenancy extraction, session key
    /// generation or the token exchange.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime (via `tokio::runtime::Handle::current`).
    pub async fn new() -> crate::core::Result<Self> {
        let runtime_handle = tokio::runtime::Handle::current();
        let imds = ImdsClient::new()?;

        let region_str = imds.get_region().await?;
        let region = Region::from_str(&region_str).map_err(|e| {
            crate::core::OciError::ConfigError(format!("Invalid region from IMDS: {}", e))
        })?;

        let leaf_cert = imds.get_leaf_certificate().await?;
        let leaf_key = imds.get_leaf_private_key().await?;
        let intermediate_certs = imds.get_intermediate_certificates().await?;
        let tenancy_id = extract_tenant_id(&leaf_cert)?;

        // The token is requested for the *session* public key; the provider later hands out the
        // matching session private key, never the leaf key.
        let session_keypair = generate_session_keypair()?;
        let security_token = request_security_token(
            &region,
            &tenancy_id,
            &leaf_cert,
            &leaf_key,
            &intermediate_certs,
            &session_keypair.public_key_pem,
        )
        .await?;

        let session_state = Arc::new(RwLock::new(SessionState {
            security_token: security_token.token,
            session_private_key_pem: session_keypair.private_key_pem,
            expires_at: security_token.expires_at,
        }));

        Ok(Self {
            runtime_handle,
            region,
            tenancy_id,
            leaf_certificate: leaf_cert,
            leaf_private_key_pem: leaf_key,
            intermediate_certificates: intermediate_certs,
            session_state,
        })
    }

    /// Region reported by IMDS when the provider was created.
    pub fn region(&self) -> Region {
        self.region
    }

    /// Refreshes the security token and session keypair if the current token is expired or
    /// within [`Self::REFRESH_BUFFER_MINUTES`] of expiring.
    ///
    /// Expiry is checked under a read lock first so the common case does not contend. It is
    /// re-checked after acquiring the write lock so that, when several callers notice expiry at
    /// the same time, only the first performs the token exchange. The write lock is held across
    /// the exchange, so readers wait until the new session is installed.
    ///
    /// # Errors
    /// Propagates session key generation or token exchange failures; the previous session is
    /// left untouched in that case.
    async fn ensure_token_valid(&self) -> crate::core::Result<()> {
        {
            let state = self.session_state.read().await;
            if !Self::is_expired(&state.expires_at) {
                return Ok(());
            }
        }

        let mut state = self.session_state.write().await;
        // Another caller may have refreshed while we waited for the write lock.
        if !Self::is_expired(&state.expires_at) {
            return Ok(());
        }

        let session_keypair = generate_session_keypair()?;
        let security_token = request_security_token(
            &self.region,
            &self.tenancy_id,
            &self.leaf_certificate,
            &self.leaf_private_key_pem,
            &self.intermediate_certificates,
            &session_keypair.public_key_pem,
        )
        .await?;

        // Token and key are replaced together so they always belong to the same session.
        state.security_token = security_token.token;
        state.session_private_key_pem = session_keypair.private_key_pem;
        state.expires_at = security_token.expires_at;
        Ok(())
    }

    /// Returns `true` if `expires_at` is in the past or within the refresh buffer from now.
    fn is_expired(expires_at: &DateTime<Utc>) -> bool {
        let buffer = Duration::minutes(Self::REFRESH_BUFFER_MINUTES);
        Utc::now() + buffer >= *expires_at
    }
}
impl AuthProvider for InstancePrincipalsAuthProvider {
fn get_key_id(&self) -> String {
self.runtime_handle
.block_on(self.ensure_token_valid())
.unwrap_or_else(|e| {
eprintln!("Warning: Failed to refresh token: {}", e);
});
let state = self.session_state.blocking_read();
format!("ST${}", state.security_token)
}
fn get_private_key(&self) -> &str {
self.runtime_handle
.block_on(self.ensure_token_valid())
.unwrap_or_else(|e| {
eprintln!("Warning: Failed to refresh token: {}", e);
});
let pem = self
.session_state
.blocking_read()
.session_private_key_pem
.clone();
Box::leak(pem.into_boxed_str())
}
fn get_passphrase(&self) -> Option<&str> {
None
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokens inside the refresh buffer, or already past expiry, count as expired.
    #[test]
    fn test_is_expired() {
        let cases = [
            (Duration::minutes(10), false),
            (Duration::minutes(4), true),
            (Duration::minutes(-10), true),
        ];
        for &(offset, expected) in cases.iter() {
            let expires_at = Utc::now() + offset;
            assert_eq!(
                InstancePrincipalsAuthProvider::is_expired(&expires_at),
                expected,
                "offset from now: {:?}",
                offset
            );
        }
    }

    /// Documents the `ST$<token>` key-id shape produced for security tokens.
    ///
    /// NOTE(review): this formats the string locally rather than calling `get_key_id`, so it
    /// does not exercise the provider itself.
    #[test]
    fn test_key_id_format() {
        const TOKEN: &str = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.test.sig";
        let key_id = format!("ST${}", TOKEN);
        assert!(key_id.starts_with("ST$"), "missing ST$ prefix: {}", key_id);
        assert!(key_id.contains("eyJ"), "token not embedded: {}", key_id);
    }
}