use crate::error::{ConnectorError, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::time::sleep;
/// Maximum number of token-endpoint attempts made by the private-key-JWT flow.
const MAX_RETRIES: u32 = 3;
/// Delay before the first token-request retry; doubled after each failed attempt.
const INITIAL_RETRY_DELAY_MS: u64 = 500;
/// Upper bound on the delay between token-request retries.
const MAX_RETRY_DELAY_MS: u64 = 5_000;
/// Locations probed for a registration token when no environment variable supplies
/// one. Relative entries are resolved against `$HOME`; the first existing file wins.
const DEFAULT_OTT_PATHS: &[&str] = &[
    "/var/run/secrets/matrix/registration-token",
    ".matrix/registration-token",
];
/// Fallback private-key locations for direct config when `STRIKE48_PRIVATE_KEY_PATH`
/// is unset; the first existing file wins.
const DEFAULT_PRIVATE_KEY_PATHS: &[&str] = &[
    "/var/run/secrets/matrix/connector-key.pem",
    "/var/run/secrets/matrix/tls.key",
];
/// A registration token ("OTT") plus optional metadata bundled alongside it.
///
/// Accepted as inline JSON, base64-encoded JSON, or a bare token string (see
/// `OttProvider::parse_ott`). The JSON keys `matrix_url` and `keycloak_url` map onto
/// `api_url` and `auth_url`; every other key matches its field name.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OttData {
    /// The one-time token sent to the registration endpoint.
    token: String,
    /// API base URL (JSON key `matrix_url`); preferred over `STRIKE48_API_URL`.
    #[serde(rename = "matrix_url")]
    api_url: Option<String>,
    /// Auth server URL (JSON key `keycloak_url`).
    #[serde(rename = "keycloak_url")]
    auth_url: Option<String>,
    /// Expiry timestamp string as supplied with the token.
    expires_at: Option<String>,
    /// Connector type the token was issued for, if stated.
    connector_type: Option<String>,
    /// Tenant the token was issued for, if stated.
    tenant_id: Option<String>,
}
/// Connector credentials returned by OTT registration and persisted as JSON.
///
/// `auth_url` travels under the JSON key `keycloak_url`. `kid` is omitted from the
/// serialized form when absent and defaults to `None` when missing on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Credentials {
    /// OAuth client ID; also the `iss`/`sub` of the client-assertion JWT.
    pub client_id: String,
    /// Keycloak realm base URL (JSON key `keycloak_url`); the token endpoint is
    /// derived from it.
    #[serde(rename = "keycloak_url")]
    pub auth_url: String,
    /// Tenant the connector belongs to.
    pub tenant_id: String,
    /// Optional key ID placed in the client-assertion JWT header.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kid: Option<String>,
}
/// Pre-provisioned ("direct") authentication settings read from the environment.
///
/// Built by `OttProvider::load_direct_config` only when the key file exists.
#[derive(Debug, Clone)]
struct DirectConfig {
    /// Path to the PEM private key (`STRIKE48_PRIVATE_KEY_PATH` or a default location).
    private_key_path: String,
    /// OAuth client ID (`STRIKE48_CLIENT_ID`).
    client_id: String,
    /// Auth server base URL (`STRIKE48_AUTH_URL`).
    auth_url: String,
}
/// Obtains connector credentials and OAuth access tokens.
///
/// Two bootstrap paths are supported:
/// - registering a generated EC P-256 public key with a one-time registration token
///   (OTT);
/// - using a pre-provisioned private key ("direct config").
///
/// Access tokens are then fetched with JWT-bearer client assertions
/// (`private_key_jwt`) and cached until shortly before they expire.
pub struct OttProvider {
    // Configuration, set at construction.
    /// API base URL from `STRIKE48_API_URL`, if set.
    api_url: Option<String>,
    /// Directory holding per-connector private keys.
    keys_dir: PathBuf,
    /// Directory holding per-connector saved credentials.
    credentials_dir: PathBuf,
    /// Pre-provisioned key configuration, when present in the environment.
    direct_config: Option<DirectConfig>,
    /// HTTP client used for registration and token requests.
    http_client: Client,

    // Connector identity, used to locate key and credential files.
    /// Connector type component of the file names.
    connector_type: Option<String>,
    /// Instance component of the file names (`default` when absent).
    instance_id: Option<String>,

    // Runtime state.
    /// PEM-encoded private key, once loaded or generated.
    private_key_pem: Option<String>,
    /// Active credentials.
    credentials: Option<Credentials>,
    /// Cached access token.
    access_token: Option<String>,
    /// Expiry of `access_token`, in Unix seconds.
    token_expires_at: Option<u64>,
}
impl OttProvider {
    /// Creates a provider configured from the process environment.
    ///
    /// Variables read:
    /// - `HOME` (falls back to `.`): base for the default key and credential directories.
    /// - `STRIKE48_KEYS_DIR`: key directory (default `$HOME/.strike48/keys`).
    /// - `STRIKE48_API_URL`: fallback API base URL for OTT registration.
    /// - `MATRIX_TLS_INSECURE` (`true`, case-insensitive, or `1`): disables TLS
    ///   certificate verification for the HTTP client.
    /// - the direct-config variables read by `load_direct_config`.
    ///
    /// Saved credentials always live in `$HOME/.strike48/credentials`.
    pub fn new(connector_type: Option<String>, instance_id: Option<String>) -> Self {
        let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
        let keys_dir =
            std::env::var("STRIKE48_KEYS_DIR").unwrap_or_else(|_| format!("{home}/.strike48/keys"));
        let credentials_dir = format!("{home}/.strike48/credentials");
        let api_url = std::env::var("STRIKE48_API_URL").ok();
        let direct_config = Self::load_direct_config();
        // `install_default` returns an error when a process-wide provider is already
        // installed (e.g. by an earlier instance); that is harmless, so it is ignored.
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let tls_insecure = std::env::var("MATRIX_TLS_INSECURE")
            .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
            .unwrap_or(false);
        let http_client = if tls_insecure {
            tracing::warn!(
                "TLS certificate verification DISABLED for OTT/Keycloak HTTPS \
                 (MATRIX_TLS_INSECURE=true). Do NOT use in production!"
            );
            Client::builder()
                .danger_accept_invalid_certs(true)
                .build()
                .unwrap_or_else(|e| {
                    // The operator asked for an insecure client but gets a verifying
                    // one; make that visible instead of failing silently.
                    tracing::warn!(
                        "Failed to build insecure HTTP client ({e}); using a default client \
                         with certificate verification enabled"
                    );
                    Client::new()
                })
        } else {
            Client::new()
        };
        Self {
            api_url,
            keys_dir: PathBuf::from(keys_dir),
            credentials_dir: PathBuf::from(credentials_dir),
            private_key_pem: None,
            credentials: None,
            access_token: None,
            token_expires_at: None,
            connector_type,
            instance_id,
            direct_config,
            http_client,
        }
    }

    /// Reads the direct (pre-provisioned key) configuration from the environment.
    ///
    /// Requires `STRIKE48_CLIENT_ID` and `STRIKE48_AUTH_URL`. The key path comes from
    /// `STRIKE48_PRIVATE_KEY_PATH` or, failing that, the first existing entry of
    /// `DEFAULT_PRIVATE_KEY_PATHS`. Returns `None` if any piece is missing or the key
    /// file does not exist.
    fn load_direct_config() -> Option<DirectConfig> {
        let private_key_path = std::env::var("STRIKE48_PRIVATE_KEY_PATH")
            .ok()
            .or_else(Self::find_default_private_key)?;
        let client_id = std::env::var("STRIKE48_CLIENT_ID").ok()?;
        let auth_url = std::env::var("STRIKE48_AUTH_URL").ok()?;
        if !Path::new(&private_key_path).exists() {
            return None;
        }
        Some(DirectConfig {
            private_key_path,
            client_id,
            auth_url,
        })
    }

    /// Returns the first entry of `DEFAULT_PRIVATE_KEY_PATHS` that exists on disk.
    fn find_default_private_key() -> Option<String> {
        DEFAULT_PRIVATE_KEY_PATHS
            .iter()
            .copied()
            .find(|path| Path::new(path).exists())
            .map(str::to_string)
    }

    /// Whether a direct (pre-provisioned key) configuration was found at construction.
    pub fn has_direct_config(&self) -> bool {
        self.direct_config.is_some()
    }

    /// Loads the pre-provisioned private key and builds credentials from it.
    ///
    /// # Errors
    /// Returns `InvalidConfig` if no direct config exists, the key file is not a PEM
    /// private key, or `TENANT_ID` is unset. Returns an I/O-derived error if the key
    /// cannot be read. On error the provider state is left unchanged.
    pub fn initialize_from_direct_config(&mut self) -> Result<Credentials> {
        let config = self.direct_config.as_ref().ok_or_else(|| {
            ConnectorError::InvalidConfig("Direct config not available".to_string())
        })?;
        let private_key_pem = Self::load_private_key_from_path(&config.private_key_path)?;
        let tenant_id = std::env::var("TENANT_ID").map_err(|_| {
            ConnectorError::InvalidConfig(
                "TENANT_ID is required for direct-config auth".to_string(),
            )
        })?;
        let credentials = Credentials {
            client_id: config.client_id.clone(),
            auth_url: config.auth_url.clone(),
            tenant_id,
            kid: None,
        };
        // Commit state only after every step has succeeded.
        self.private_key_pem = Some(private_key_pem);
        self.credentials = Some(credentials.clone());
        Ok(credentials)
    }

    /// Whether a registration token can be found (environment or default files).
    pub fn has_ott(&self) -> bool {
        self.load_ott().is_some()
    }

    /// Checks that `target` has the same origin (scheme, host, port) as `configured`.
    ///
    /// When `configured` is `None` or blank, the check is skipped with a warning.
    ///
    /// # Errors
    /// Returns `InvalidConfig` if either URL is not a valid HTTP(S) URL or the origins
    /// differ.
    pub(crate) fn validate_register_origin(target: &str, configured: Option<&str>) -> Result<()> {
        let configured = match configured {
            Some(c) if !c.trim().is_empty() => c,
            _ => {
                tracing::warn!(
                    "STRIKE48_API_URL is not configured; skipping OTT register-URL \
                     allowlist check (dev only). Set STRIKE48_API_URL in production \
                     to enforce a same-origin policy on server-supplied register URLs."
                );
                return Ok(());
            }
        };
        let target_origin = parse_origin(target).ok_or_else(|| {
            ConnectorError::InvalidConfig(format!(
                "OTT register URL is not a valid HTTP(S) URL: {target}"
            ))
        })?;
        let allowed_origin = parse_origin(configured).ok_or_else(|| {
            ConnectorError::InvalidConfig(format!(
                "STRIKE48_API_URL is not a valid HTTP(S) URL: {configured}"
            ))
        })?;
        if target_origin == allowed_origin {
            Ok(())
        } else {
            Err(ConnectorError::InvalidConfig(format!(
                "OTT register URL origin {target_origin:?} does not match \
                 STRIKE48_API_URL origin {allowed_origin:?}; refusing to send \
                 credentials to an unapproved host"
            )))
        }
    }

    /// Locates and parses a registration token.
    ///
    /// Sources are checked in this order:
    /// 1. `STRIKE48_REGISTRATION_TOKEN`, if set and not blank;
    /// 2. the file named by `STRIKE48_REGISTRATION_TOKEN_FILE`, if it exists;
    /// 3. the first existing entry of `DEFAULT_OTT_PATHS`, with relative entries
    ///    resolved against `$HOME`.
    fn load_ott(&self) -> Option<OttData> {
        // A variable that is set but blank (e.g. `VAR=` in a compose file) is treated
        // as unset so the file-based sources still get a chance.
        if let Ok(ott_value) = std::env::var("STRIKE48_REGISTRATION_TOKEN")
            && !ott_value.trim().is_empty()
        {
            return Self::parse_ott(&ott_value);
        }
        if let Ok(ott_file) = std::env::var("STRIKE48_REGISTRATION_TOKEN_FILE")
            && Path::new(&ott_file).exists()
        {
            return Self::load_ott_from_file(&ott_file);
        }
        let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
        let home = Path::new(&home);
        // `Path::join` keeps absolute entries as-is and prefixes relative ones.
        DEFAULT_OTT_PATHS
            .iter()
            .map(|ott_path| home.join(ott_path))
            .find(|full_path| full_path.exists())
            .and_then(Self::load_ott_from_file)
    }

    /// Reads `file_path` and parses its content as a registration token.
    fn load_ott_from_file<P: AsRef<Path>>(file_path: P) -> Option<OttData> {
        let content = fs::read_to_string(file_path).ok()?;
        Self::parse_ott(&content)
    }

    /// Parses a registration token after trimming surrounding whitespace.
    ///
    /// Formats, tried in order:
    /// - inline JSON (starts with `{`): returns `None` if the JSON is invalid;
    /// - base64-encoded JSON;
    /// - otherwise the whole value is taken as a bare token with no metadata.
    fn parse_ott(value: &str) -> Option<OttData> {
        use base64::{Engine as _, engine::general_purpose::STANDARD};
        let value = value.trim();
        if value.starts_with('{') {
            return serde_json::from_str(value).ok();
        }
        if let Ok(decoded) = STANDARD.decode(value)
            && let Ok(utf8) = String::from_utf8(decoded)
            && let Ok(ott_data) = serde_json::from_str::<OttData>(&utf8)
        {
            return Some(ott_data);
        }
        Some(OttData {
            token: value.to_string(),
            api_url: None,
            auth_url: None,
            expires_at: None,
            connector_type: None,
            tenant_id: None,
        })
    }

    /// Builds `<connector_type>_<instance_id | "default">.<extension>`, the naming
    /// scheme shared by key files and saved-credential files.
    fn identity_file_name(connector_type: &str, instance_id: Option<&str>, extension: &str) -> String {
        format!(
            "{connector_type}_{}.{extension}",
            instance_id.unwrap_or("default")
        )
    }

    /// Path of the private key file for the given connector identity.
    fn get_private_key_path(&self, connector_type: &str, instance_id: Option<&str>) -> PathBuf {
        self.keys_dir
            .join(Self::identity_file_name(connector_type, instance_id, "pem"))
    }

    /// Registers this connector's public key using the locally available OTT.
    ///
    /// The API base URL comes from the OTT's `matrix_url`, falling back to
    /// `STRIKE48_API_URL`. On success the credentials are saved to disk and cached.
    ///
    /// # Errors
    /// - `InvalidConfig` if no OTT or API URL is available, or the server answers 401.
    /// - `ConnectionError` on transport failure.
    /// - `SerializationError` on an unparsable success body.
    /// - `RegistrationError` for any other non-success status.
    pub async fn register_with_ott(
        &mut self,
        connector_type: &str,
        instance_id: Option<&str>,
    ) -> Result<Credentials> {
        let ott_data = self
            .load_ott()
            .ok_or_else(|| ConnectorError::InvalidConfig("No OTT found".to_string()))?;
        let public_key_pem = self
            .get_or_create_keypair_for_connector(connector_type, instance_id)
            .await?;
        let api_url = ott_data
            .api_url
            .as_deref()
            .or(self.api_url.as_deref())
            .ok_or_else(|| ConnectorError::InvalidConfig("API URL not configured".to_string()))?;
        // Trim a trailing slash so `https://host/` does not produce `//api/...`.
        let register_url = format!(
            "{}/api/connectors/register-with-ott",
            api_url.trim_end_matches('/')
        );
        let payload = serde_json::json!({
            "token": ott_data.token,
            "public_key": public_key_pem,
            "connector_type": connector_type,
            "instance_id": instance_id,
        });
        let response = self
            .http_client
            .post(&register_url)
            .json(&payload)
            .send()
            .await
            .map_err(|e| ConnectorError::ConnectionError(format!("HTTP request failed: {e}")))?;
        let status = response.status();
        if status.is_success() {
            let credentials: Credentials = response.json().await.map_err(|e| {
                ConnectorError::SerializationError(format!("Failed to parse response: {e}"))
            })?;
            self.save_credentials(connector_type, instance_id, &credentials)?;
            self.credentials = Some(credentials.clone());
            Ok(credentials)
        } else if status.as_u16() == 401 {
            Err(ConnectorError::InvalidConfig(
                "Invalid or expired OTT".to_string(),
            ))
        } else {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            Err(ConnectorError::RegistrationError(format!(
                "OTT registration failed: {error_text}"
            )))
        }
    }

    /// Loads the connector's EC P-256 key, generating and persisting one if needed.
    ///
    /// An existing key file that does not parse as an EC P-256 PKCS#8 key (e.g. a
    /// legacy RSA key) is replaced. The loaded or generated key and the connector
    /// identity are remembered on `self`.
    ///
    /// Returns the public key as a `PUBLIC KEY` (SPKI) PEM.
    async fn get_or_create_keypair_for_connector(
        &mut self,
        connector_type: &str,
        instance_id: Option<&str>,
    ) -> Result<String> {
        let key_path = self.get_private_key_path(connector_type, instance_id);
        if key_path.exists() {
            let pem = fs::read_to_string(&key_path)
                .map_err(|e| std::io::Error::other(format!("Failed to read private key: {e}")))?;
            if let Ok(public_key_pem) = Self::extract_ec_public_key_pem(&pem) {
                self.remember_identity(pem, connector_type, instance_id);
                return Ok(public_key_pem);
            }
            tracing::info!(
                "Upgrading legacy RSA key to EC P-256 for connector {}",
                connector_type
            );
        }
        let (private_pem, public_pem) = Self::generate_ec_keypair()?;
        Self::ensure_private_dir(&self.keys_dir).map_err(|e| {
            std::io::Error::other(format!("Failed to create keys directory: {e}"))
        })?;
        Self::write_private_file(&key_path, private_pem.as_bytes())
            .map_err(|e| std::io::Error::other(format!("Failed to write private key: {e}")))?;
        self.remember_identity(private_pem, connector_type, instance_id);
        Ok(public_pem)
    }

    /// Stores the active private key and the identity it belongs to.
    fn remember_identity(
        &mut self,
        private_key_pem: String,
        connector_type: &str,
        instance_id: Option<&str>,
    ) {
        self.private_key_pem = Some(private_key_pem);
        self.connector_type = Some(connector_type.to_string());
        self.instance_id = instance_id.map(str::to_string);
    }

    /// Creates `dir` (and missing parents) if needed. On Unix, newly created
    /// directories get mode 0700; existing directories are left untouched.
    fn ensure_private_dir(dir: &Path) -> std::io::Result<()> {
        let mut builder = fs::DirBuilder::new();
        builder.recursive(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::DirBuilderExt;
            builder.mode(0o700);
        }
        builder.create(dir)
    }

    /// Writes secret `contents` to `path` so the file is never group/world readable.
    ///
    /// On Unix the file is created with mode 0600. A pre-existing file is truncated
    /// and tightened to 0600 before any new bytes are written. This avoids the window
    /// that a write-then-chmod sequence leaves open.
    fn write_private_file(path: &Path, contents: &[u8]) -> std::io::Result<()> {
        use std::io::Write as _;
        let mut options = fs::OpenOptions::new();
        options.write(true).create(true).truncate(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::OpenOptionsExt;
            options.mode(0o600);
        }
        let mut file = options.open(path)?;
        // `mode` only applies on creation (and is subject to umask); force 0600 for
        // pre-existing files too.
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            file.set_permissions(fs::Permissions::from_mode(0o600))?;
        }
        file.write_all(contents)
    }

    /// Reads a PEM private key from `key_path`.
    ///
    /// This is a lightweight sanity check only: the content must contain `BEGIN` and
    /// `PRIVATE KEY`.
    ///
    /// # Errors
    /// Returns an I/O-derived error if the file cannot be read, or `InvalidConfig`
    /// if it does not look like a PEM private key.
    fn load_private_key_from_path(key_path: &str) -> Result<String> {
        let pem = fs::read_to_string(key_path)
            .map_err(|e| std::io::Error::other(format!("Failed to read private key: {e}")))?;
        if !pem.contains("BEGIN") || !pem.contains("PRIVATE KEY") {
            return Err(ConnectorError::InvalidConfig(
                "File does not contain a PEM private key".to_string(),
            ));
        }
        Ok(pem)
    }

    /// Generates a fresh EC P-256 key pair.
    ///
    /// Returns `(private PKCS#8 PEM, public SPKI PEM)`.
    fn generate_ec_keypair() -> Result<(String, String)> {
        use aws_lc_rs::signature::{ECDSA_P256_SHA256_ASN1_SIGNING, EcdsaKeyPair, KeyPair as _};
        let rng = aws_lc_rs::rand::SystemRandom::new();
        let pkcs8_doc = EcdsaKeyPair::generate_pkcs8(&ECDSA_P256_SHA256_ASN1_SIGNING, &rng)
            .map_err(|e| ConnectorError::Other(format!("Failed to generate EC keypair: {e}")))?;
        let key_pair =
            EcdsaKeyPair::from_pkcs8(&ECDSA_P256_SHA256_ASN1_SIGNING, pkcs8_doc.as_ref()).map_err(
                |e| ConnectorError::Other(format!("Failed to parse generated keypair: {e}")),
            )?;
        let private_pem = Self::der_to_pem(pkcs8_doc.as_ref(), "PRIVATE KEY");
        let spki_der = Self::ec_p256_public_key_to_spki(key_pair.public_key().as_ref());
        let public_pem = Self::der_to_pem(&spki_der, "PUBLIC KEY");
        Ok((private_pem, public_pem))
    }

    /// Derives the SPKI PEM public key from a PKCS#8 EC P-256 private key PEM.
    ///
    /// # Errors
    /// Fails if the PEM does not decode or is not an EC P-256 PKCS#8 key.
    fn extract_ec_public_key_pem(private_key_pem: &str) -> Result<String> {
        use aws_lc_rs::signature::{ECDSA_P256_SHA256_ASN1_SIGNING, EcdsaKeyPair, KeyPair as _};
        let der = Self::pem_to_der(private_key_pem)?;
        let key_pair = EcdsaKeyPair::from_pkcs8(&ECDSA_P256_SHA256_ASN1_SIGNING, &der)
            .map_err(|e| ConnectorError::Other(format!("Not a valid EC P-256 key: {e}")))?;
        let spki_der = Self::ec_p256_public_key_to_spki(key_pair.public_key().as_ref());
        Ok(Self::der_to_pem(&spki_der, "PUBLIC KEY"))
    }

    /// Wraps a 65-byte uncompressed P-256 point in a DER SubjectPublicKeyInfo.
    fn ec_p256_public_key_to_spki(public_key_bytes: &[u8]) -> Vec<u8> {
        // Fixed DER header; the length bytes assume a 65-byte uncompressed point:
        //   SEQUENCE (0x59)
        //     SEQUENCE (0x13)
        //       OID 1.2.840.10045.2.1   (id-ecPublicKey)
        //       OID 1.2.840.10045.3.1.7 (prime256v1)
        //     BIT STRING (0x42, 0 unused bits) || point
        const SPKI_PREFIX: [u8; 26] = [
            0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, 0x06,
            0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00,
        ];
        debug_assert_eq!(public_key_bytes.len(), 65, "expected an uncompressed P-256 point");
        let mut spki = Vec::with_capacity(SPKI_PREFIX.len() + public_key_bytes.len());
        spki.extend_from_slice(&SPKI_PREFIX);
        spki.extend_from_slice(public_key_bytes);
        spki
    }

    /// Encodes DER bytes as PEM with the given label, wrapping base64 at 64 columns.
    fn der_to_pem(der: &[u8], label: &str) -> String {
        use base64::{Engine as _, engine::general_purpose::STANDARD};
        let b64 = STANDARD.encode(der);
        let mut pem = format!("-----BEGIN {label}-----\n");
        for chunk in b64.as_bytes().chunks(64) {
            pem.push_str(std::str::from_utf8(chunk).expect("base64 output is ASCII"));
            pem.push('\n');
        }
        pem.push_str(&format!("-----END {label}-----\n"));
        pem
    }

    /// Decodes the base64 body of a PEM block.
    ///
    /// `-----` boundary lines, blank lines and surrounding whitespace are ignored.
    fn pem_to_der(pem: &str) -> Result<Vec<u8>> {
        use base64::{Engine as _, engine::general_purpose::STANDARD};
        let b64: String = pem
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty() && !line.starts_with("-----"))
            .collect();
        STANDARD
            .decode(&b64)
            .map_err(|e| ConnectorError::Other(format!("Failed to decode PEM: {e}")))
    }

    /// Loads previously saved credentials for the given identity and caches them.
    ///
    /// Returns `None` if the file is missing, unreadable or unparsable; the last two
    /// cases are logged as warnings.
    pub fn load_saved_credentials(
        &mut self,
        connector_type: &str,
        instance_id: Option<&str>,
    ) -> Option<Credentials> {
        let filepath = self
            .credentials_dir
            .join(Self::identity_file_name(connector_type, instance_id, "json"));
        tracing::debug!("Looking for saved credentials at: {}", filepath.display());
        if !filepath.exists() {
            tracing::debug!("Credentials file not found: {}", filepath.display());
            return None;
        }
        tracing::debug!("Credentials file found, loading...");
        let Ok(data) = fs::read_to_string(&filepath) else {
            tracing::warn!("Failed to read credentials file: {}", filepath.display());
            return None;
        };
        let Ok(creds) = serde_json::from_str::<Credentials>(&data) else {
            tracing::warn!(
                "Failed to parse credentials JSON from {}",
                filepath.display()
            );
            return None;
        };
        tracing::debug!(
            "Loaded credentials from {}: client_id={}",
            filepath.display(),
            creds.client_id
        );
        self.credentials = Some(creds.clone());
        Some(creds)
    }

    /// Persists `credentials` as pretty JSON for the given identity.
    ///
    /// On Unix the file is owner-only (0600) and a newly created directory is 0700.
    fn save_credentials(
        &self,
        connector_type: &str,
        instance_id: Option<&str>,
        credentials: &Credentials,
    ) -> Result<()> {
        Self::ensure_private_dir(&self.credentials_dir).map_err(|e| {
            std::io::Error::other(format!("Failed to create credentials directory: {e}"))
        })?;
        let filepath = self
            .credentials_dir
            .join(Self::identity_file_name(connector_type, instance_id, "json"));
        let json = serde_json::to_string_pretty(credentials).map_err(|e| {
            ConnectorError::SerializationError(format!("Failed to serialize credentials: {e}"))
        })?;
        Self::write_private_file(&filepath, json.as_bytes())
            .map_err(|e| std::io::Error::other(format!("Failed to write credentials: {e}")))?;
        Ok(())
    }

    /// Registers the connector's public key using server-supplied OTT data.
    ///
    /// `register_url` may be absolute (`http://` / `https://`, case-insensitive) or a
    /// path joined onto `api_url`. The resulting URL must share its origin with
    /// `STRIKE48_API_URL` when that variable is set. This is checked before any key
    /// material is generated or written.
    ///
    /// Retries up to 4 attempts with capped exponential backoff on transport errors
    /// and on 401 responses whose body contains "Invalid or expired". This gives the
    /// server time to sync a just-approved OTT. Any other failure status is returned
    /// immediately.
    pub async fn register_public_key_with_ott_data(
        &mut self,
        ott: &str,
        api_url: &str,
        register_url: &str,
        connector_type: &str,
        instance_id: Option<&str>,
    ) -> Result<Credentials> {
        // Budget for this flow only; separate from the token-request constants.
        const OTT_MAX_ATTEMPTS: u32 = 4;
        const OTT_INITIAL_DELAY_MS: u64 = 500;
        const OTT_MAX_DELAY_MS: u64 = 3000;

        let is_absolute = ["http://", "https://"].iter().any(|scheme| {
            register_url
                .get(..scheme.len())
                .is_some_and(|prefix| prefix.eq_ignore_ascii_case(scheme))
        });
        let full_url = if is_absolute {
            register_url.to_string()
        } else {
            format!(
                "{}/{}",
                api_url.trim_end_matches('/'),
                register_url.trim_start_matches('/')
            )
        };
        let configured_api_url = std::env::var("STRIKE48_API_URL").ok();
        Self::validate_register_origin(&full_url, configured_api_url.as_deref())?;
        let public_key_pem = self
            .get_or_create_keypair_for_connector(connector_type, instance_id)
            .await?;
        tracing::debug!(
            "OTT registration: api_url={}, register_url={}, full_url={}",
            api_url,
            register_url,
            full_url
        );
        tracing::debug!(
            "OTT registration: sending token for connector_type={}",
            connector_type
        );
        let payload = serde_json::json!({
            "token": ott,
            "public_key": public_key_pem,
            "connector_type": connector_type,
            "instance_id": instance_id,
        });
        tracing::debug!(
            "OTT registration payload: connector_type={}, instance_id={:?}",
            connector_type,
            instance_id
        );
        let mut last_error = None;
        for attempt in 0..OTT_MAX_ATTEMPTS {
            if attempt > 0 {
                let delay = std::cmp::min(
                    OTT_INITIAL_DELAY_MS * 2_u64.pow(attempt - 1),
                    OTT_MAX_DELAY_MS,
                );
                tracing::warn!(
                    "OTT registration retry {}/{} after {}ms (waiting for cluster sync)",
                    attempt + 1,
                    OTT_MAX_ATTEMPTS,
                    delay
                );
                sleep(Duration::from_millis(delay)).await;
            }
            let response = match self.http_client.post(&full_url).json(&payload).send().await {
                Ok(resp) => resp,
                Err(e) => {
                    last_error = Some(format!("HTTP request failed: {e}"));
                    continue;
                }
            };
            tracing::debug!(
                "OTT registration response status: {} (attempt {})",
                response.status(),
                attempt + 1
            );
            let status = response.status();
            if status.is_success() {
                let credentials: Credentials = response.json().await.map_err(|e| {
                    ConnectorError::SerializationError(format!("Failed to parse response: {e}"))
                })?;
                self.save_credentials(connector_type, instance_id, &credentials)?;
                self.credentials = Some(credentials.clone());
                return Ok(credentials);
            }
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            // The server may not yet know a freshly approved OTT; treat this specific
            // 401 as transient.
            if status.as_u16() == 401 && error_text.contains("Invalid or expired") {
                last_error = Some(error_text);
                continue;
            }
            return Err(ConnectorError::RegistrationError(format!(
                "Post-approval OTT registration failed: {error_text}"
            )));
        }
        Err(ConnectorError::RegistrationError(format!(
            "Post-approval OTT registration failed after {} retries: {}",
            OTT_MAX_ATTEMPTS,
            last_error.unwrap_or_else(|| "Unknown error".to_string())
        )))
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// A clock set before 1970 yields 0 instead of panicking.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .unwrap_or(0)
    }

    /// Returns a valid access token, reusing the cached one if it has more than
    /// 30 seconds of validity left.
    ///
    /// # Errors
    /// Returns `InvalidConfig` if no credentials are loaded. Otherwise propagates
    /// errors from the private-key-JWT token request.
    pub async fn get_token(&mut self) -> Result<String> {
        // Refresh slightly early so a token is never handed out right before expiry.
        const TOKEN_REFRESH_MARGIN_SECS: u64 = 30;
        if let (Some(token), Some(expires_at)) = (&self.access_token, self.token_expires_at)
            && Self::unix_now() < expires_at.saturating_sub(TOKEN_REFRESH_MARGIN_SECS)
        {
            return Ok(token.clone());
        }
        let credentials = self
            .credentials
            .as_ref()
            .ok_or_else(|| ConnectorError::InvalidConfig("No credentials available".to_string()))?
            .clone();
        self.get_token_via_private_key_jwt(&credentials).await
    }

    /// Drops the cached access token so the next `get_token` fetches a new one.
    pub fn clear_token_cache(&mut self) {
        self.access_token = None;
        self.token_expires_at = None;
        tracing::debug!("Token cache cleared");
    }

    /// Best-effort removal of the saved-credentials file for the current identity.
    ///
    /// Does nothing if no connector type is set or the file does not exist. Failures
    /// are logged, not returned.
    pub fn delete_saved_credentials(&self) {
        let Some(connector_type) = self.connector_type.as_deref() else {
            return;
        };
        let filepath = self.credentials_dir.join(Self::identity_file_name(
            connector_type,
            self.instance_id.as_deref(),
            "json",
        ));
        if !filepath.exists() {
            return;
        }
        match fs::remove_file(&filepath) {
            Ok(()) => {
                tracing::info!("Deleted stale credentials file: {}", filepath.display())
            }
            Err(e) => tracing::warn!(
                "Failed to delete stale credentials file {}: {}",
                filepath.display(),
                e
            ),
        }
    }

    /// Clears credentials, the private key, and the token cache.
    ///
    /// The connector identity and configuration are kept.
    pub fn reset(&mut self) {
        self.credentials = None;
        self.private_key_pem = None;
        self.access_token = None;
        self.token_expires_at = None;
        tracing::debug!("OttProvider state reset");
    }

    /// Whether credentials are currently loaded.
    // NOTE(review): `allow` kept from the original; presumably no in-crate callers.
    #[allow(dead_code)]
    pub fn has_credentials(&self) -> bool {
        self.credentials.is_some()
    }

    /// Fetches an access token with the OAuth2 client-credentials grant, authenticating
    /// via a signed JWT client assertion.
    ///
    /// Loads the private key from the key directory if it is not already in memory.
    /// Transport errors and 5xx responses are retried up to `MAX_RETRIES` times with
    /// doubling backoff capped at `MAX_RETRY_DELAY_MS`. 4xx responses fail immediately;
    /// a 401 also clears the token cache. On success the token and its expiry are
    /// cached.
    async fn get_token_via_private_key_jwt(&mut self, credentials: &Credentials) -> Result<String> {
        #[derive(Deserialize)]
        struct TokenResponse {
            access_token: String,
            expires_in: Option<u64>,
        }

        if self.private_key_pem.is_none() {
            let connector_type = self.connector_type.as_deref().ok_or_else(|| {
                ConnectorError::InvalidConfig("Connector identity not set".to_string())
            })?;
            let key_path = self.get_private_key_path(connector_type, self.instance_id.as_deref());
            let key_path_str = key_path
                .to_str()
                .ok_or_else(|| ConnectorError::InvalidConfig("Invalid key path".to_string()))?;
            self.private_key_pem = Some(Self::load_private_key_from_path(key_path_str)?);
        }
        // Owned copy so the retry loop may take `&mut self` (e.g. to clear the cache).
        let private_key_pem = self.private_key_pem.clone().ok_or_else(|| {
            ConnectorError::InvalidConfig("Private key not available".to_string())
        })?;
        let token_url = format!(
            "{}/protocol/openid-connect/token",
            credentials.auth_url.trim_end_matches('/')
        );
        let mut last_error = None;
        let mut delay_ms = INITIAL_RETRY_DELAY_MS;
        for attempt in 1..=MAX_RETRIES {
            tracing::debug!(
                "Token request attempt {}/{} to {}",
                attempt,
                MAX_RETRIES,
                token_url
            );
            // Sign a fresh assertion for every attempt. Each one carries a unique `jti`.
            // Servers that enforce single use of assertion IDs (allowed by RFC 7523;
            // Keycloak does it) would reject a retry that replays an assertion from an
            // attempt they already received.
            let client_assertion = self.create_client_assertion(&private_key_pem, credentials)?;
            let mut params = HashMap::new();
            params.insert("grant_type", "client_credentials");
            params.insert("client_id", credentials.client_id.as_str());
            params.insert(
                "client_assertion_type",
                "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            );
            params.insert("client_assertion", client_assertion.as_str());
            match self.http_client.post(&token_url).form(&params).send().await {
                Ok(response) => {
                    let status = response.status();
                    if status.is_success() {
                        let token_data = response.json::<TokenResponse>().await.map_err(|e| {
                            ConnectorError::SerializationError(format!(
                                "Failed to parse token response: {e}"
                            ))
                        })?;
                        // Servers may omit `expires_in`; assume a conservative 5 minutes.
                        let expires_in = token_data.expires_in.unwrap_or(300);
                        self.access_token = Some(token_data.access_token.clone());
                        self.token_expires_at = Some(Self::unix_now().saturating_add(expires_in));
                        tracing::debug!(
                            "Token obtained successfully, expires in {}s",
                            expires_in
                        );
                        return Ok(token_data.access_token);
                    }
                    if status.is_client_error() {
                        let error_text = response
                            .text()
                            .await
                            .unwrap_or_else(|_| "Unknown error".to_string());
                        tracing::error!("Token request rejected ({status}): {error_text}");
                        if status.as_u16() == 401 {
                            self.clear_token_cache();
                            tracing::warn!("Cleared token cache due to 401 Unauthorized");
                        }
                        return Err(ConnectorError::InvalidConfig(format!(
                            "Token request failed ({status}): {error_text}"
                        )));
                    }
                    let error_text = response
                        .text()
                        .await
                        .unwrap_or_else(|_| "Unknown error".to_string());
                    last_error = Some(ConnectorError::ConnectionError(format!(
                        "Token request failed ({status}): {error_text}"
                    )));
                }
                Err(e) => {
                    use std::error::Error as _;
                    // reqwest's top-level message is often generic; include the cause chain.
                    let mut chain = format!("{e}");
                    let mut src: Option<&(dyn std::error::Error + 'static)> = e.source();
                    while let Some(s) = src {
                        chain.push_str(&format!(" -> {s}"));
                        src = s.source();
                    }
                    tracing::debug!("Token request failed (full chain): {chain}");
                    last_error = Some(ConnectorError::ConnectionError(format!(
                        "Token request network error: {chain}"
                    )));
                }
            }
            if attempt < MAX_RETRIES {
                tracing::warn!(
                    "Token request failed (attempt {}/{}), retrying in {}ms...",
                    attempt,
                    MAX_RETRIES,
                    delay_ms
                );
                sleep(Duration::from_millis(delay_ms)).await;
                delay_ms = std::cmp::min(delay_ms * 2, MAX_RETRY_DELAY_MS);
            }
        }
        Err(last_error.unwrap_or_else(|| {
            ConnectorError::ConnectionError("Token request failed after all retries".to_string())
        }))
    }

    /// Signs a short-lived (60 s) JWT client assertion for `credentials`.
    ///
    /// `iss` and `sub` are the client ID, `aud` is the auth URL, and `jti` is a random
    /// UUID. The signing algorithm follows the PEM:
    /// - `BEGIN RSA PRIVATE KEY` → RS256;
    /// - `BEGIN EC PRIVATE KEY` → ES256;
    /// - otherwise ES256 is tried first, then RS256.
    ///
    /// `kid` is added to the header when present.
    fn create_client_assertion(
        &self,
        private_key_pem: &str,
        credentials: &Credentials,
    ) -> Result<String> {
        use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
        use serde_json::json;
        let now = Self::unix_now();
        let claims = json!({
            "iss": credentials.client_id,
            "sub": credentials.client_id,
            "aud": credentials.auth_url,
            "exp": now + 60,
            "iat": now,
            "jti": uuid::Uuid::new_v4().to_string(),
        });
        let pem_bytes = private_key_pem.as_bytes();
        let (algorithm, encoding_key) = if private_key_pem.contains("BEGIN RSA PRIVATE KEY") {
            let key = EncodingKey::from_rsa_pem(pem_bytes)
                .map_err(|e| ConnectorError::Other(format!("Failed to create RSA key: {e}")))?;
            (Algorithm::RS256, key)
        } else if private_key_pem.contains("BEGIN EC PRIVATE KEY") {
            let key = EncodingKey::from_ec_pem(pem_bytes)
                .map_err(|e| ConnectorError::Other(format!("Failed to create EC key: {e}")))?;
            (Algorithm::ES256, key)
        } else {
            match EncodingKey::from_ec_pem(pem_bytes) {
                Ok(key) => (Algorithm::ES256, key),
                Err(_) => {
                    let key = EncodingKey::from_rsa_pem(pem_bytes).map_err(|e| {
                        ConnectorError::Other(format!("Failed to create encoding key: {e}"))
                    })?;
                    (Algorithm::RS256, key)
                }
            }
        };
        let mut header = Header::new(algorithm);
        if let Some(ref kid) = credentials.kid {
            header.kid = Some(kid.clone());
        }
        encode(&header, &claims, &encoding_key)
            .map_err(|e| ConnectorError::Other(format!("Failed to sign JWT: {e}")))
    }

    /// Like `get_token`, but returns `None` instead of an error.
    // NOTE(review): `allow` kept from the original; presumably no in-crate callers.
    #[allow(dead_code)]
    pub async fn get_auth_token(&mut self) -> Option<String> {
        self.get_token().await.ok()
    }
}
/// Reduces an HTTP(S) URL to its origin `(scheme, host, port)` for comparison.
///
/// Scheme and host are lower-cased. The port falls back to the scheme's default,
/// so `https://h` and `https://h:443/x` compare equal.
///
/// Returns `None` if the URL:
/// - does not parse;
/// - uses a scheme other than `http`/`https`;
/// - has no host or resolvable port.
fn parse_origin(s: &str) -> Option<(String, String, u16)> {
    let parsed = reqwest::Url::parse(s).ok()?;
    let scheme = parsed.scheme().to_ascii_lowercase();
    if !matches!(scheme.as_str(), "http" | "https") {
        return None;
    }
    let host = parsed.host_str().map(str::to_ascii_lowercase)?;
    let port = parsed.port_or_known_default()?;
    Some((scheme, host, port))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn validate_register_origin_same_host_passes() {
OttProvider::validate_register_origin(
"https://api.matrix.example.com/connectors/v1/register",
Some("https://api.matrix.example.com"),
)
.expect("same origin must pass");
}
#[test]
fn validate_register_origin_default_port_normalised() {
OttProvider::validate_register_origin(
"https://api.matrix.example.com:443/x",
Some("https://api.matrix.example.com/"),
)
.expect("443 == default https port");
}
#[test]
fn validate_register_origin_cross_host_rejected() {
let err = OttProvider::validate_register_origin(
"https://attacker.example.com/connectors/v1/register",
Some("https://api.matrix.example.com"),
)
.expect_err("cross-host must be rejected");
assert!(matches!(err, ConnectorError::InvalidConfig(_)));
}
#[test]
fn validate_register_origin_scheme_mismatch_rejected() {
let err = OttProvider::validate_register_origin(
"http://api.matrix.example.com/x",
Some("https://api.matrix.example.com"),
)
.expect_err("scheme mismatch must be rejected");
assert!(matches!(err, ConnectorError::InvalidConfig(_)));
}
#[test]
fn validate_register_origin_port_mismatch_rejected() {
let err = OttProvider::validate_register_origin(
"https://api.matrix.example.com:8443/x",
Some("https://api.matrix.example.com"),
)
.expect_err("port mismatch must be rejected");
assert!(matches!(err, ConnectorError::InvalidConfig(_)));
}
#[test]
fn validate_register_origin_no_allowlist_skips() {
OttProvider::validate_register_origin("https://api.example.com/x", None)
.expect("no allowlist => allow");
OttProvider::validate_register_origin("https://api.example.com/x", Some(" "))
.expect("blank allowlist => allow");
}
#[test]
fn validate_register_origin_invalid_target_rejected() {
let err =
OttProvider::validate_register_origin("not a url", Some("https://api.example.com"))
.expect_err("malformed URL must be rejected");
assert!(matches!(err, ConnectorError::InvalidConfig(_)));
}
#[cfg(unix)]
#[test]
fn save_credentials_writes_file_with_mode_0600() {
use std::os::unix::fs::PermissionsExt;
let tmp = tempfile::tempdir().expect("tempdir");
let provider = OttProvider {
api_url: None,
keys_dir: tmp.path().join("keys"),
credentials_dir: tmp.path().join("creds"),
private_key_pem: None,
credentials: None,
access_token: None,
token_expires_at: None,
connector_type: Some("perm_test".into()),
instance_id: Some("inst_a".into()),
direct_config: None,
http_client: Client::new(),
};
let credentials = Credentials {
client_id: "ci-1".into(),
auth_url: "https://auth.example.com".into(),
tenant_id: "demo".into(),
kid: None,
};
provider
.save_credentials("perm_test", Some("inst_a"), &credentials)
.expect("save_credentials");
let path = tmp.path().join("creds").join("perm_test_inst_a.json");
let mode = std::fs::metadata(&path).unwrap().permissions().mode();
assert_eq!(
mode & 0o777,
0o600,
"credential file must be owner-only readable, got mode={mode:o}"
);
}
#[test]
fn test_parse_ott_raw_token_string() {
let result = OttProvider::parse_ott("ott_hXg1Adwu12345");
assert!(result.is_some());
let ott = result.unwrap();
assert_eq!(ott.token, "ott_hXg1Adwu12345");
assert!(ott.api_url.is_none());
assert!(ott.auth_url.is_none());
assert!(ott.expires_at.is_none());
}
#[test]
fn test_parse_ott_json_inline() {
let json_str = r#"{"token":"ott_abc123","matrix_url":"https://api.example.com","keycloak_url":"https://auth.example.com/realms/matrix"}"#;
let result = OttProvider::parse_ott(json_str);
assert!(result.is_some());
let ott = result.unwrap();
assert_eq!(ott.token, "ott_abc123");
assert_eq!(ott.api_url.as_deref(), Some("https://api.example.com"));
assert_eq!(
ott.auth_url.as_deref(),
Some("https://auth.example.com/realms/matrix")
);
}
#[test]
fn test_parse_ott_json_with_all_fields() {
let json_str = r#"{
"token": "ott_full",
"matrix_url": "https://api.example.com",
"keycloak_url": "https://auth.example.com",
"expires_at": "2026-12-31T23:59:59Z",
"connector_type": "my-connector",
"tenant_id": "tenant-1"
}"#;
let result = OttProvider::parse_ott(json_str);
assert!(result.is_some());
let ott = result.unwrap();
assert_eq!(ott.token, "ott_full");
assert_eq!(ott.connector_type.as_deref(), Some("my-connector"));
assert_eq!(ott.tenant_id.as_deref(), Some("tenant-1"));
assert_eq!(ott.expires_at.as_deref(), Some("2026-12-31T23:59:59Z"));
}
#[test]
fn test_parse_ott_base64_encoded_json() {
use base64::{Engine as _, engine::general_purpose::STANDARD};
let json_str = r#"{"token":"ott_b64","matrix_url":"https://api.test.com"}"#;
let encoded = STANDARD.encode(json_str.as_bytes());
let result = OttProvider::parse_ott(&encoded);
assert!(result.is_some());
let ott = result.unwrap();
assert_eq!(ott.token, "ott_b64");
assert_eq!(ott.api_url.as_deref(), Some("https://api.test.com"));
}
#[test]
fn test_parse_ott_empty_string() {
let result = OttProvider::parse_ott("");
assert!(result.is_some());
let ott = result.unwrap();
assert_eq!(ott.token, "");
}
#[test]
fn test_parse_ott_json_missing_token_fails() {
let json_str = r#"{"matrix_url":"https://api.example.com"}"#;
let result = OttProvider::parse_ott(json_str);
assert!(result.is_none());
}
#[test]
fn test_parse_ott_json_minimal() {
let json_str = r#"{"token":"ott_min"}"#;
let result = OttProvider::parse_ott(json_str);
assert!(result.is_some());
let ott = result.unwrap();
assert_eq!(ott.token, "ott_min");
assert!(ott.api_url.is_none());
}
#[tokio::test]
async fn test_keypair_generation_and_pem_roundtrip() {
let temp_dir = tempfile::tempdir().unwrap();
let keys_dir = temp_dir.path().join("keys");
let mut provider = OttProvider {
api_url: None,
keys_dir,
credentials_dir: temp_dir.path().join("creds"),
private_key_pem: None,
credentials: None,
access_token: None,
token_expires_at: None,
connector_type: Some("test-connector".to_string()),
instance_id: Some("test-instance".to_string()),
direct_config: None,
http_client: Client::new(),
};
let public_key_pem = provider
.get_or_create_keypair_for_connector("test-connector", Some("test-instance"))
.await
.unwrap();
assert!(public_key_pem.contains("BEGIN PUBLIC KEY"));
assert!(public_key_pem.contains("END PUBLIC KEY"));
assert!(provider.private_key_pem.is_some());
let key_path = provider.get_private_key_path("test-connector", Some("test-instance"));
assert!(key_path.exists());
let key_data = std::fs::read_to_string(&key_path).unwrap();
assert!(key_data.contains("BEGIN PRIVATE KEY"));
let loaded_pem =
OttProvider::load_private_key_from_path(key_path.to_str().unwrap()).unwrap();
let loaded_pub_pem = OttProvider::extract_ec_public_key_pem(&loaded_pem).unwrap();
assert_eq!(public_key_pem, loaded_pub_pem);
}
#[tokio::test]
async fn test_keypair_reuse_on_second_call() {
let temp_dir = tempfile::tempdir().unwrap();
let keys_dir = temp_dir.path().join("keys");
let mut provider = OttProvider {
api_url: None,
keys_dir,
credentials_dir: temp_dir.path().join("creds"),
private_key_pem: None,
credentials: None,
access_token: None,
token_expires_at: None,
connector_type: Some("test-connector".to_string()),
instance_id: Some("inst".to_string()),
direct_config: None,
http_client: Client::new(),
};
let pub1 = provider
.get_or_create_keypair_for_connector("test-connector", Some("inst"))
.await
.unwrap();
provider.private_key_pem = None;
let pub2 = provider
.get_or_create_keypair_for_connector("test-connector", Some("inst"))
.await
.unwrap();
assert_eq!(pub1, pub2);
}
#[test]
fn test_private_key_path_format() {
let provider = OttProvider::new(Some("my-type".to_string()), Some("my-inst".to_string()));
let path = provider.get_private_key_path("my-type", Some("my-inst"));
let filename = path.file_name().unwrap().to_str().unwrap();
assert_eq!(filename, "my-type_my-inst.pem");
}
#[test]
fn test_private_key_path_default_instance() {
let provider = OttProvider::new(Some("my-type".to_string()), None);
let path = provider.get_private_key_path("my-type", None);
let filename = path.file_name().unwrap().to_str().unwrap();
assert_eq!(filename, "my-type_default.pem");
}
#[test]
fn test_save_and_load_credentials() {
    // Round trip: save_credentials writes `<type>_<instance>.json` under
    // credentials_dir in the serialized Credentials format, and
    // load_saved_credentials reads the same values back.
    let scratch = tempfile::tempdir().unwrap();
    let creds_dir = scratch.path().join("credentials");
    let mut provider = OttProvider {
        api_url: None,
        keys_dir: scratch.path().join("keys"),
        credentials_dir: creds_dir.clone(),
        private_key_pem: None,
        credentials: None,
        access_token: None,
        token_expires_at: None,
        connector_type: Some("cred-test".to_string()),
        instance_id: Some("inst-1".to_string()),
        direct_config: None,
        http_client: Client::new(),
    };
    let to_save = Credentials {
        client_id: "client-abc".to_string(),
        auth_url: "https://auth.example.com/realms/matrix".to_string(),
        tenant_id: "tenant-1".to_string(),
        kid: None,
    };
    provider
        .save_credentials("cred-test", Some("inst-1"), &to_save)
        .unwrap();

    // First inspect the file directly, independent of the loader.
    let saved_file = creds_dir.join("cred-test_inst-1.json");
    assert!(saved_file.exists());
    let raw_json = std::fs::read_to_string(&saved_file).unwrap();
    let on_disk: Credentials = serde_json::from_str(&raw_json).unwrap();
    assert_eq!(on_disk.client_id, "client-abc");
    assert_eq!(on_disk.auth_url, "https://auth.example.com/realms/matrix");
    assert_eq!(on_disk.tenant_id, "tenant-1");

    // Then read it back through the provider.
    let reloaded = provider
        .load_saved_credentials("cred-test", Some("inst-1"))
        .expect("saved credentials should load back");
    assert_eq!(reloaded.client_id, "client-abc");
    assert_eq!(reloaded.auth_url, "https://auth.example.com/realms/matrix");
    assert_eq!(reloaded.tenant_id, "tenant-1");
}
#[test]
fn test_load_credentials_not_found() {
    // The credentials directory below is never created, so there is nothing to
    // load; the provider must report None rather than panic.
    let scratch = tempfile::tempdir().unwrap();
    let root = scratch.path();
    let mut provider = OttProvider {
        api_url: None,
        keys_dir: root.join("keys"),
        credentials_dir: root.join("credentials"),
        private_key_pem: None,
        credentials: None,
        access_token: None,
        token_expires_at: None,
        connector_type: Some("missing".to_string()),
        instance_id: Some("inst".to_string()),
        direct_config: None,
        http_client: Client::new(),
    };
    assert!(provider
        .load_saved_credentials("missing", Some("inst"))
        .is_none());
}
#[test]
fn test_load_credentials_default_instance() {
    // With no instance id, save and load must resolve to the same file.
    let scratch = tempfile::tempdir().unwrap();
    let mut provider = OttProvider {
        api_url: None,
        keys_dir: scratch.path().join("keys"),
        credentials_dir: scratch.path().join("credentials"),
        private_key_pem: None,
        credentials: None,
        access_token: None,
        token_expires_at: None,
        connector_type: Some("test".to_string()),
        instance_id: None,
        direct_config: None,
        http_client: Client::new(),
    };
    let creds = Credentials {
        client_id: "default-client".to_string(),
        auth_url: "https://auth.example.com".to_string(),
        tenant_id: "default".to_string(),
        kid: None,
    };
    provider.save_credentials("test", None, &creds).unwrap();

    let reloaded = provider
        .load_saved_credentials("test", None)
        .expect("credentials saved without an instance id should load");
    assert_eq!(reloaded.client_id, "default-client");
}
#[test]
fn test_delete_saved_credentials() {
    // delete_saved_credentials() takes no arguments. It removes the file that
    // belongs to the provider's own connector_type/instance_id pair.
    let scratch = tempfile::tempdir().unwrap();
    let creds_dir = scratch.path().join("credentials");
    let provider = OttProvider {
        api_url: None,
        keys_dir: scratch.path().join("keys"),
        credentials_dir: creds_dir.clone(),
        private_key_pem: None,
        credentials: None,
        access_token: None,
        token_expires_at: None,
        connector_type: Some("del-test".to_string()),
        instance_id: Some("inst".to_string()),
        direct_config: None,
        http_client: Client::new(),
    };
    let doomed = Credentials {
        client_id: "to-delete".to_string(),
        auth_url: "https://auth.example.com".to_string(),
        tenant_id: "t".to_string(),
        kid: None,
    };
    provider
        .save_credentials("del-test", Some("inst"), &doomed)
        .unwrap();

    let saved_file = creds_dir.join("del-test_inst.json");
    assert!(saved_file.exists());
    provider.delete_saved_credentials();
    assert!(!saved_file.exists());
}
#[test]
fn test_delete_saved_credentials_nonexistent_is_noop() {
    // Nothing was ever saved and the credentials directory does not exist.
    // Deleting must simply do nothing; the test passes as long as it does not panic.
    let scratch = tempfile::tempdir().unwrap();
    let root = scratch.path();
    let provider = OttProvider {
        api_url: None,
        keys_dir: root.join("keys"),
        credentials_dir: root.join("credentials"),
        private_key_pem: None,
        credentials: None,
        access_token: None,
        token_expires_at: None,
        connector_type: Some("nope".to_string()),
        instance_id: Some("nope".to_string()),
        direct_config: None,
        http_client: Client::new(),
    };
    provider.delete_saved_credentials();
}
#[test]
fn test_clear_token_cache() {
    // clear_token_cache() must forget both the cached token and its expiry.
    let mut provider = OttProvider::new(Some("test".to_string()), Some("inst".to_string()));
    provider.token_expires_at = Some(9999999999);
    provider.access_token = Some("cached-token".to_string());

    provider.clear_token_cache();

    assert!(matches!(
        (&provider.access_token, provider.token_expires_at),
        (None, None)
    ));
}
#[test]
fn test_reset_clears_all_state() {
    // Populate every field that reset() is expected to clear. This includes the
    // private key: the final `private_key_pem.is_none()` check only means
    // something if the field was `Some` before the reset.
    let mut provider = OttProvider::new(Some("test".to_string()), Some("inst".to_string()));
    provider.access_token = Some("token".to_string());
    provider.token_expires_at = Some(123);
    provider.credentials = Some(Credentials {
        client_id: "c".to_string(),
        auth_url: "a".to_string(),
        tenant_id: "t".to_string(),
        kid: None,
    });
    provider.private_key_pem = Some("dummy-private-key-pem".to_string());

    provider.reset();

    assert!(provider.access_token.is_none());
    assert!(provider.token_expires_at.is_none());
    assert!(provider.credentials.is_none());
    assert!(provider.private_key_pem.is_none());
}
#[tokio::test]
async fn test_create_client_assertion_produces_valid_jwt() {
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};

    // Decode (without verifying) the client assertion and check its header and
    // claims: ES256, iss/sub = client_id, aud = auth URL, a 60s lifetime and a
    // UUID jti.
    let temp_dir = tempfile::tempdir().unwrap();
    let mut provider = OttProvider {
        api_url: None,
        keys_dir: temp_dir.path().join("keys"),
        credentials_dir: temp_dir.path().join("creds"),
        private_key_pem: None,
        credentials: None,
        access_token: None,
        token_expires_at: None,
        connector_type: Some("jwt-test".to_string()),
        instance_id: Some("inst".to_string()),
        direct_config: None,
        http_client: Client::new(),
    };
    provider
        .get_or_create_keypair_for_connector("jwt-test", Some("inst"))
        .await
        .unwrap();
    let private_key_pem = provider.private_key_pem.as_ref().unwrap();
    let credentials = Credentials {
        client_id: "test-client-id".to_string(),
        auth_url: "https://auth.example.com/realms/matrix".to_string(),
        tenant_id: "test-tenant".to_string(),
        kid: None,
    };
    let jwt = provider
        .create_client_assertion(private_key_pem, &credentials)
        .unwrap();

    let parts: Vec<&str> = jwt.split('.').collect();
    assert_eq!(parts.len(), 3, "JWT should have 3 dot-separated parts");

    let header_bytes = URL_SAFE_NO_PAD.decode(parts[0]).unwrap();
    let header: serde_json::Value = serde_json::from_slice(&header_bytes).unwrap();
    assert_eq!(header["alg"], "ES256");
    assert_eq!(header["typ"], "JWT");

    let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
    let claims: serde_json::Value = serde_json::from_slice(&payload_bytes).unwrap();
    assert_eq!(claims["iss"], "test-client-id");
    assert_eq!(claims["sub"], "test-client-id");
    assert_eq!(claims["aud"], "https://auth.example.com/realms/matrix");
    assert!(claims["exp"].is_number());
    assert!(claims["iat"].is_number());
    assert!(claims["jti"].is_string());

    let iat = claims["iat"].as_u64().unwrap();
    let exp = claims["exp"].as_u64().unwrap();
    // checked_sub turns an `exp < iat` regression into a readable assertion
    // failure instead of a u64 subtraction-overflow panic.
    assert_eq!(
        exp.checked_sub(iat),
        Some(60),
        "assertion should expire 60s after iat"
    );

    let jti = claims["jti"].as_str().unwrap();
    assert!(uuid::Uuid::parse_str(jti).is_ok());
}
#[tokio::test]
async fn test_create_client_assertion_signature_is_verifiable() {
    use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};

    // The assertion must verify against the public key that
    // get_or_create_keypair_for_connector returned, with iss/sub/aud matching
    // the credentials used to sign it.
    let scratch = tempfile::tempdir().unwrap();
    let mut provider = OttProvider {
        api_url: None,
        keys_dir: scratch.path().join("keys"),
        credentials_dir: scratch.path().join("creds"),
        private_key_pem: None,
        credentials: None,
        access_token: None,
        token_expires_at: None,
        connector_type: Some("verify-test".to_string()),
        instance_id: Some("inst".to_string()),
        direct_config: None,
        http_client: Client::new(),
    };
    let public_pem = provider
        .get_or_create_keypair_for_connector("verify-test", Some("inst"))
        .await
        .unwrap();
    let signing_pem = provider.private_key_pem.as_ref().unwrap();
    let credentials = Credentials {
        client_id: "verify-client".to_string(),
        auth_url: "https://auth.example.com".to_string(),
        tenant_id: "t".to_string(),
        kid: None,
    };
    let token = provider
        .create_client_assertion(signing_pem, &credentials)
        .unwrap();

    let verifier_key = DecodingKey::from_ec_pem(public_pem.as_bytes()).unwrap();
    let mut rules = Validation::new(Algorithm::ES256);
    rules.set_audience(&["https://auth.example.com"]);
    rules.set_issuer(&["verify-client"]);
    rules.sub = Some("verify-client".to_string());

    let outcome = decode::<serde_json::Value>(&token, &verifier_key, &rules);
    assert!(
        outcome.is_ok(),
        "JWT signature verification failed: {:?}",
        outcome.err()
    );
}
#[tokio::test]
async fn test_create_two_assertions_have_different_jti() {
let temp_dir = tempfile::tempdir().unwrap();
let mut provider = OttProvider {
api_url: None,
keys_dir: temp_dir.path().join("keys"),
credentials_dir: temp_dir.path().join("creds"),
private_key_pem: None,
credentials: None,
access_token: None,
token_expires_at: None,
connector_type: Some("jti-test".to_string()),
instance_id: Some("inst".to_string()),
direct_config: None,
http_client: Client::new(),
};
provider
.get_or_create_keypair_for_connector("jti-test", Some("inst"))
.await
.unwrap();
let private_key_pem = provider.private_key_pem.as_ref().unwrap();
let credentials = Credentials {
client_id: "c".to_string(),
auth_url: "a".to_string(),
tenant_id: "t".to_string(),
kid: None,
};
let jwt1 = provider
.create_client_assertion(private_key_pem, &credentials)
.unwrap();
let jwt2 = provider
.create_client_assertion(private_key_pem, &credentials)
.unwrap();
assert_ne!(jwt1, jwt2);
}
#[test]
fn test_legacy_rsa_key_signs_jwt_with_rs256() {
    // Nothing in this test awaits, so a plain #[test] is enough; no Tokio
    // runtime is needed. Other synchronous tests in this module already
    // construct providers outside a runtime.
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};

    // A pre-existing (legacy) RSA private key must still be usable for signing.
    // The algorithm in the JWT header has to follow the key type.
    let rsa_pem = include_str!("../../test_fixtures/legacy_rsa_key.pem");
    let provider = OttProvider::new(Some("test".to_string()), None);
    let credentials = Credentials {
        client_id: "legacy-client".to_string(),
        auth_url: "https://auth.example.com".to_string(),
        tenant_id: "t".to_string(),
        kid: None,
    };
    let jwt = provider
        .create_client_assertion(rsa_pem, &credentials)
        .unwrap();

    let parts: Vec<&str> = jwt.split('.').collect();
    assert_eq!(parts.len(), 3, "JWT should have 3 dot-separated parts");
    let header_bytes = URL_SAFE_NO_PAD.decode(parts[0]).unwrap();
    let header: serde_json::Value = serde_json::from_slice(&header_bytes).unwrap();
    assert_eq!(header["alg"], "RS256");
}
#[tokio::test]
async fn test_legacy_rsa_key_upgraded_on_re_registration() {
let temp_dir = tempfile::tempdir().unwrap();
let keys_dir = temp_dir.path().join("keys");
fs::create_dir_all(&keys_dir).unwrap();
let rsa_pem = include_str!("../../test_fixtures/legacy_rsa_key.pem");
let key_path = keys_dir.join("upgrade-test_inst.pem");
fs::write(&key_path, rsa_pem).unwrap();
let mut provider = OttProvider {
api_url: None,
keys_dir,
credentials_dir: temp_dir.path().join("creds"),
private_key_pem: None,
credentials: None,
access_token: None,
token_expires_at: None,
connector_type: Some("upgrade-test".to_string()),
instance_id: Some("inst".to_string()),
direct_config: None,
http_client: Client::new(),
};
let public_key_pem = provider
.get_or_create_keypair_for_connector("upgrade-test", Some("inst"))
.await
.unwrap();
assert!(public_key_pem.contains("BEGIN PUBLIC KEY"));
let new_key_data = fs::read_to_string(&key_path).unwrap();
assert!(new_key_data.contains("BEGIN PRIVATE KEY"));
assert!(!new_key_data.contains("RSA"));
}
#[tokio::test]
async fn test_get_token_returns_cached_when_valid() {
    // A cached token with two minutes of remaining lifetime must be returned
    // as-is. auth_url "a" is not a usable endpoint, so a refresh attempt would
    // presumably fail rather than yield this value.
    let mut provider = OttProvider::new(Some("test".to_string()), Some("inst".to_string()));
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();

    provider.access_token = Some("cached-jwt".to_string());
    provider.token_expires_at = Some(now_secs + 120);
    provider.credentials = Some(Credentials {
        client_id: "c".to_string(),
        auth_url: "a".to_string(),
        tenant_id: "t".to_string(),
        kid: None,
    });

    assert_eq!(provider.get_token().await.unwrap(), "cached-jwt");
}
#[tokio::test]
async fn test_get_token_expired_cache_triggers_refresh() {
let mut provider = OttProvider::new(Some("test".to_string()), Some("inst".to_string()));
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
provider.access_token = Some("expired-jwt".to_string());
provider.token_expires_at = Some(now + 10); provider.credentials = Some(Credentials {
client_id: "c".to_string(),
auth_url: "https://auth.invalid.test".to_string(),
tenant_id: "t".to_string(),
kid: None,
});
let result = provider.get_token().await;
assert!(
result.is_err(),
"Should fail trying to refresh expired token"
);
}
#[test]
fn test_has_direct_config_false_by_default() {
    // With none of the direct-auth variables set, `OttProvider::new` must not
    // report a direct config. Clear them first in case the developer's shell
    // exports them.
    for var in [
        "STRIKE48_PRIVATE_KEY_PATH",
        "STRIKE48_CLIENT_ID",
        "STRIKE48_AUTH_URL",
    ] {
        // SAFETY: `remove_var` is unsafe because mutating the environment races
        // with concurrent reads on other threads, and the default test harness
        // runs tests in parallel.
        // NOTE(review): no test visible in this module sets these variables,
        // and std serializes its own env accesses. Other threads may still read
        // the environment through libc (e.g. during DNS resolution). If that
        // becomes a problem, run this test with `--test-threads=1` or move it
        // to its own test binary.
        unsafe { std::env::remove_var(var) };
    }
    let provider = OttProvider::new(Some("test".to_string()), None);
    assert!(!provider.has_direct_config());
}
#[test]
fn test_credentials_json_field_names() {
    // Credentials serialize with the server's field names. `auth_url` goes out
    // as `keycloak_url`, and `kid` is omitted entirely when it is None
    // (`skip_serializing_if = "Option::is_none"`).
    let creds = Credentials {
        client_id: "cid".to_string(),
        auth_url: "https://auth.example.com".to_string(),
        tenant_id: "tid".to_string(),
        kid: None,
    };
    let json = serde_json::to_value(&creds).unwrap();
    assert!(json.get("client_id").is_some());
    assert!(json.get("keycloak_url").is_some());
    assert!(json.get("tenant_id").is_some());
    assert!(
        json.get("auth_url").is_none(),
        "Rust field name `auth_url` must not appear in JSON"
    );
    assert!(json.get("kid").is_none(), "unset kid should be skipped");
    assert_eq!(json["client_id"], "cid");
    assert_eq!(json["keycloak_url"], "https://auth.example.com");
    assert_eq!(json["tenant_id"], "tid");

    // When kid is set it is emitted under its own name.
    let with_kid = Credentials {
        kid: Some("key-1".to_string()),
        ..creds
    };
    let json = serde_json::to_value(&with_kid).unwrap();
    assert_eq!(json["kid"], "key-1");
}
#[test]
fn test_credentials_deserialization_from_server_format() {
    // The server response omits `kid`. `#[serde(default)]` must turn a missing
    // kid into None rather than an error, and a present kid must be read.
    let server_json = r#"{
        "client_id": "connector-client-abc",
        "keycloak_url": "https://keycloak.example.com/realms/matrix",
        "tenant_id": "production"
    }"#;
    let creds: Credentials = serde_json::from_str(server_json).unwrap();
    assert_eq!(creds.client_id, "connector-client-abc");
    assert_eq!(creds.auth_url, "https://keycloak.example.com/realms/matrix");
    assert_eq!(creds.tenant_id, "production");
    assert!(creds.kid.is_none());

    let with_kid_json = r#"{
        "client_id": "c",
        "keycloak_url": "https://keycloak.example.com",
        "tenant_id": "t",
        "kid": "key-1"
    }"#;
    let with_kid: Credentials = serde_json::from_str(with_kid_json).unwrap();
    assert_eq!(with_kid.kid.as_deref(), Some("key-1"));
}
#[test]
fn test_ott_data_json_field_names() {
    // OttData must serialize with the server's names (`matrix_url`,
    // `keycloak_url`), never the internal Rust field names.
    let ott = OttData {
        token: "ott_test".to_string(),
        api_url: Some("https://api.example.com".to_string()),
        auth_url: Some("https://auth.example.com".to_string()),
        expires_at: Some("2026-12-31".to_string()),
        connector_type: Some("my-conn".to_string()),
        tenant_id: Some("t1".to_string()),
    };
    let json = serde_json::to_value(&ott).unwrap();
    assert_eq!(json["token"], "ott_test");
    assert_eq!(json["matrix_url"], "https://api.example.com");
    assert_eq!(json["keycloak_url"], "https://auth.example.com");
    assert_eq!(json["expires_at"], "2026-12-31");
    assert_eq!(json["connector_type"], "my-conn");
    assert_eq!(json["tenant_id"], "t1");
    for internal_name in ["api_url", "auth_url"] {
        assert!(
            json.get(internal_name).is_none(),
            "internal field name `{internal_name}` leaked into JSON"
        );
    }
}
#[test]
fn test_ott_data_deserialization_from_server_format() {
    // A minimal server payload carries only the token and the two URLs. Every
    // other optional field must deserialize to None.
    let server_json = r#"{
        "token": "ott_from_server",
        "matrix_url": "https://matrix.prod.example.com",
        "keycloak_url": "https://auth.prod.example.com/realms/matrix"
    }"#;
    let ott: OttData = serde_json::from_str(server_json).unwrap();
    assert_eq!(ott.token, "ott_from_server");
    assert_eq!(
        ott.api_url.as_deref(),
        Some("https://matrix.prod.example.com")
    );
    assert_eq!(
        ott.auth_url.as_deref(),
        Some("https://auth.prod.example.com/realms/matrix")
    );
    assert!(ott.expires_at.is_none());
    assert!(ott.connector_type.is_none());
    assert!(ott.tenant_id.is_none());
}
}