use crate::error::{Error, Result};
use crate::workflows::{WorkflowContext, WorkflowResult};
use futures::future::join_all;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::time::Instant;
use tracing::{debug, info};
/// A certificate as seen by the trust-chain workflow.
///
/// Chains are linked purely by identifiers: a certificate's `issuer` is looked up as the
/// `id` of another certificate supplied in the same input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Certificate {
/// Identifier used as the lookup key when following `issuer` links.
pub id: String,
/// `id` of the certificate that issued this one. A certificate whose `issuer` equals its
/// own `id` can never be selected as the leaf of a chain.
pub issuer: String,
/// Subject name; not read by the trust-chain workflow in this file.
pub subject: String,
/// Public key bytes. NOTE(review): not used for any verification in this file.
pub public_key: Vec<u8>,
/// Signature bytes. NOTE(review): never checked in this file, so chain validation here is
/// structural only — confirm whether signature verification happens elsewhere.
pub signature: Vec<u8>,
/// Marks the certificate at which chain walking stops; the chain is valid only if this
/// certificate's `id` is also one of the workflow's trusted roots.
pub is_root: bool,
}
/// Outcome of walking one certificate chain from its leaf towards a root.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrustChainResult {
/// `true` only when the walk reached a certificate flagged `is_root` whose id is trusted.
pub valid: bool,
/// Number of distinct certificates visited, starting with the leaf.
pub chain_length: usize,
/// Id of the `is_root` certificate the walk stopped at (trusted or not); `None` when the
/// walk stopped for another reason (cycle, excessive length, missing issuer).
pub root_id: Option<String>,
/// Human-readable reasons for rejection; empty when `valid` is `true`.
pub errors: Vec<String>,
}
/// Validates certificate chains against a fixed set of trusted root ids.
pub struct AutonomousTrustChainWorkflow {
/// Ids of `is_root` certificates that are accepted as trust anchors.
trusted_roots: HashSet<String>,
/// Maximum number of certificates (leaf and root included) a chain may contain; a chain of
/// exactly this length is accepted, longer chains are rejected.
max_chain_length: usize,
}
impl AutonomousTrustChainWorkflow {
pub fn new(trusted_roots: HashSet<String>, max_chain_length: usize) -> Self {
Self {
trusted_roots,
max_chain_length,
}
}
pub async fn execute(
&self,
certificates: Vec<Certificate>,
context: WorkflowContext,
) -> Result<WorkflowResult<TrustChainResult>> {
let start = Instant::now();
info!("Starting trust chain validation workflow {}", context.id);
let cert_map: HashMap<String, Certificate> = certificates
.iter()
.map(|cert| (cert.id.clone(), cert.clone()))
.collect();
let issuers: HashSet<String> = certificates.iter().map(|c| c.issuer.clone()).collect();
let leaf = certificates
.iter()
.find(|cert| !issuers.contains(&cert.id))
.ok_or_else(|| Error::TrustChainFailed("No leaf certificate found".to_string()))?;
let result = self.validate_chain(leaf, &cert_map).await?;
let execution_time_ms = start.elapsed().as_millis() as u64;
info!(
"Trust chain validation workflow {} completed: {}",
context.id,
if result.valid { "VALID" } else { "INVALID" }
);
Ok(WorkflowResult::success(context, result, execution_time_ms))
}
async fn validate_chain(
&self,
leaf: &Certificate,
cert_map: &HashMap<String, Certificate>,
) -> Result<TrustChainResult> {
let mut chain = Vec::new();
let mut errors = Vec::new();
let mut current = leaf.clone();
let mut visited = HashSet::new();
loop {
if visited.contains(¤t.id) {
errors.push(format!("Cycle detected at certificate {}", current.id));
return Ok(TrustChainResult {
valid: false,
chain_length: chain.len(),
root_id: None,
errors,
});
}
visited.insert(current.id.clone());
chain.push(current.clone());
if chain.len() > self.max_chain_length {
errors.push(format!("Chain too long: {} > {}", chain.len(), self.max_chain_length));
return Ok(TrustChainResult {
valid: false,
chain_length: chain.len(),
root_id: None,
errors,
});
}
if current.is_root {
if self.trusted_roots.contains(¤t.id) {
debug!("Reached trusted root: {}", current.id);
return Ok(TrustChainResult {
valid: true,
chain_length: chain.len(),
root_id: Some(current.id.clone()),
errors: Vec::new(),
});
} else {
errors.push(format!("Root {} is not trusted", current.id));
return Ok(TrustChainResult {
valid: false,
chain_length: chain.len(),
root_id: Some(current.id.clone()),
errors,
});
}
}
match cert_map.get(¤t.issuer) {
Some(issuer) => current = issuer.clone(),
None => {
errors.push(format!("Issuer {} not found", current.issuer));
return Ok(TrustChainResult {
valid: false,
chain_length: chain.len(),
root_id: None,
errors,
});
}
}
}
}
pub async fn validate_multiple(
&self,
certificate_chains: Vec<Vec<Certificate>>,
context: WorkflowContext,
) -> Result<Vec<WorkflowResult<TrustChainResult>>> {
info!("Validating {} certificate chains in parallel", certificate_chains.len());
let futures = certificate_chains.into_iter().map(|certs| {
let ctx = context.clone();
async move { self.execute(certs, ctx).await }
});
let results = join_all(futures).await;
Ok(results.into_iter().filter_map(Result::ok).collect())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a certificate with dummy key/signature bytes.
    fn create_cert(id: &str, issuer: &str, is_root: bool) -> Certificate {
        Certificate {
            id: id.to_string(),
            issuer: issuer.to_string(),
            subject: id.to_string(),
            public_key: vec![1, 2, 3],
            signature: vec![4, 5, 6],
            is_root,
        }
    }

    /// Builds a workflow whose only trusted root has id `"root"`.
    fn workflow_trusting_root(max_chain_length: usize) -> AutonomousTrustChainWorkflow {
        let mut roots = HashSet::new();
        roots.insert("root".to_string());
        AutonomousTrustChainWorkflow::new(roots, max_chain_length)
    }

    #[tokio::test]
    async fn test_valid_chain() {
        let workflow = workflow_trusting_root(10);
        let certs = vec![
            create_cert("leaf", "intermediate", false),
            create_cert("intermediate", "root", false),
            create_cert("root", "root", true),
        ];
        let result = workflow
            .execute(certs, WorkflowContext::default())
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.data.valid);
        assert_eq!(result.data.chain_length, 3);
        assert_eq!(result.data.root_id, Some("root".to_string()));
        assert!(result.data.errors.is_empty());
    }

    #[tokio::test]
    async fn test_untrusted_root() {
        let workflow = AutonomousTrustChainWorkflow::new(HashSet::new(), 10);
        let certs = vec![
            create_cert("leaf", "root", false),
            create_cert("root", "root", true),
        ];
        let result = workflow
            .execute(certs, WorkflowContext::default())
            .await
            .unwrap();
        assert!(!result.data.valid);
        assert_eq!(result.data.chain_length, 2);
        assert_eq!(result.data.root_id, Some("root".to_string()));
        assert_eq!(result.data.errors, vec!["Root root is not trusted".to_string()]);
    }

    #[tokio::test]
    async fn test_missing_issuer() {
        let workflow = workflow_trusting_root(10);
        let certs = vec![create_cert("leaf", "missing", false)];
        let result = workflow
            .execute(certs, WorkflowContext::default())
            .await
            .unwrap();
        assert!(!result.data.valid);
        assert_eq!(result.data.chain_length, 1);
        assert_eq!(result.data.root_id, None);
        assert_eq!(result.data.errors, vec!["Issuer missing not found".to_string()]);
    }

    #[tokio::test]
    async fn test_chain_too_long() {
        let workflow = workflow_trusting_root(2);
        let certs = vec![
            create_cert("leaf", "intermediate", false),
            create_cert("intermediate", "root", false),
            create_cert("root", "root", true),
        ];
        let result = workflow
            .execute(certs, WorkflowContext::default())
            .await
            .unwrap();
        assert!(!result.data.valid);
        assert_eq!(result.data.chain_length, 3);
        assert_eq!(result.data.root_id, None);
        assert_eq!(result.data.errors, vec!["Chain too long: 3 > 2".to_string()]);
    }

    #[tokio::test]
    async fn test_cycle_detected() {
        let workflow = workflow_trusting_root(10);
        let certs = vec![
            create_cert("leaf", "a", false),
            create_cert("a", "b", false),
            create_cert("b", "a", false),
        ];
        let result = workflow
            .execute(certs, WorkflowContext::default())
            .await
            .unwrap();
        assert!(!result.data.valid);
        assert_eq!(result.data.chain_length, 3);
        assert_eq!(
            result.data.errors,
            vec!["Cycle detected at certificate a".to_string()]
        );
    }

    #[tokio::test]
    async fn test_no_leaf_is_error() {
        let workflow = workflow_trusting_root(10);
        assert!(workflow
            .execute(Vec::new(), WorkflowContext::default())
            .await
            .is_err());
        // A lone self-issued root issues itself, so nothing qualifies as a leaf.
        assert!(workflow
            .execute(vec![create_cert("root", "root", true)], WorkflowContext::default())
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_validate_multiple_preserves_order() {
        let workflow = workflow_trusting_root(10);
        let chains = vec![
            vec![
                create_cert("leaf", "root", false),
                create_cert("root", "root", true),
            ],
            vec![
                create_cert("leaf", "other", false),
                create_cert("other", "other", true),
            ],
        ];
        let results = workflow
            .validate_multiple(chains, WorkflowContext::default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].data.valid);
        assert!(!results[1].data.valid);
        assert_eq!(results[1].data.root_id, Some("other".to_string()));
    }
}