use std::sync::Arc;
use std::time::Duration;
use chrono::{DateTime, Utc};
use thiserror::Error;
use tokio::time::sleep;
use super::client::{IdentityCheckRequest, IdentityCheckResponse, OrgApi};
use crate::IdentityRequirement;
/// Interval between `identity_result` polls while a tool call is being held.
const POLL_EVERY: Duration = Duration::from_millis(2_000);
/// Result of [`SmartflowProvider::resolve`].
#[derive(Debug)]
pub enum ResolveOutcome {
    /// The requirement is satisfied; carries the proof built from the server response.
    Verified(SmartflowProof),
    /// The user did not verify in time, or polling for the challenge got an HTTP 404.
    /// The verification URL and challenge id are handed back to the caller.
    HoldExpired {
        verify_url: String,
        challenge_id: String,
    },
    /// The initial identity check was answered with HTTP 503; `message` is the
    /// response body.
    ProviderUnready {
        provider: String,
        message: String,
    },
    /// Any other failure while talking to, or interpreting, the org API.
    Error(SmartflowError),
}
/// Evidence that an identity requirement was satisfied, assembled from the org
/// API's verification response by `proof_from_response`.
#[derive(Debug, Clone)]
pub struct SmartflowProof {
    /// Provider name, copied from the requirement rather than the response.
    pub provider: String,
    /// Verified subject; empty when the server response carried none.
    pub subject: String,
    /// Level of assurance; `0` when the server response carried none.
    pub loa: u8,
    /// Expiry of the proof; falls back to the build time when the server omits it.
    pub expires_at: DateTime<Utc>,
    /// Signature string from the server response, if one was present.
    pub signature: Option<String>,
}
/// Failures surfaced through [`ResolveOutcome::Error`].
#[derive(Debug, Error)]
pub enum SmartflowError {
    /// A non-503 error from the org API client, stringified.
    #[error("network/http: {0}")]
    Net(String),
    /// Response decoding failure. Not constructed in this module; presumably
    /// produced elsewhere — confirm before relying on it.
    #[error("decode: {0}")]
    Decode(String),
    /// The server answered, but the response was not usable (e.g. missing fields).
    #[error("server response: {0}")]
    Server(String),
}
/// Identity provider that asks the org API for a proof and, when the user still
/// has to verify, holds the pending tool call while polling for the result.
pub struct SmartflowProvider {
    /// Shared org API client used for the check and the follow-up polls.
    api: Arc<OrgApi>,
    /// How long `resolve` keeps polling for a verification, in seconds.
    pub hold_seconds: u64,
}
impl SmartflowProvider {
pub fn new(api: Arc<OrgApi>) -> Self {
Self {
api,
hold_seconds: 120,
}
}
pub fn with_hold_seconds(mut self, secs: u64) -> Self {
self.hold_seconds = secs;
self
}
fn requirement_to_request(req: &IdentityRequirement) -> IdentityCheckRequest {
IdentityCheckRequest {
provider: req.provider.clone(),
scope: req.scope.clone(),
allowed_subjects: req.allowed_subjects.iter().cloned().collect(),
min_loa: if req.loa == 0 { None } else { Some(req.loa) },
max_age_seconds: req.max_proof_age_seconds,
}
}
pub async fn resolve(&self, req: &IdentityRequirement) -> ResolveOutcome {
let wire = Self::requirement_to_request(req);
let first = match self.api.identity_check(&wire).await {
Ok(r) => r,
Err(super::client::OrgApiError::Http { status: 503, body }) => {
return ResolveOutcome::ProviderUnready {
provider: req.provider.clone(),
message: body,
};
}
Err(e) => return ResolveOutcome::Error(SmartflowError::Net(e.to_string())),
};
if first.verified {
return ResolveOutcome::Verified(proof_from_response(req, &first));
}
let (verify_url, challenge_id) = match (first.verify_url, first.challenge_id) {
(Some(u), Some(c)) => (u, c),
_ => {
return ResolveOutcome::Error(SmartflowError::Server(
"server reported unverified but did not return verify_url + challenge_id"
.into(),
));
}
};
eprintln!(
"[shield] identity verification required for scope='{}' provider='{}'",
req.scope, req.provider
);
eprintln!("[shield] open: {}", verify_url);
eprintln!(
"[shield] holding tool call for {}s (challenge={})",
self.hold_seconds, challenge_id
);
let deadline = std::time::Instant::now() + Duration::from_secs(self.hold_seconds);
while std::time::Instant::now() < deadline {
sleep(POLL_EVERY).await;
match self.api.identity_result(&challenge_id).await {
Ok(r) if r.verified => {
return ResolveOutcome::Verified(proof_from_response(req, &r));
}
Ok(_) => continue,
Err(super::client::OrgApiError::Http { status: 404, .. }) => {
break;
}
Err(e) => {
log::warn!("[shield] identity_result poll error: {}", e);
}
}
}
ResolveOutcome::HoldExpired {
verify_url,
challenge_id,
}
}
}
/// Builds a [`SmartflowProof`] for `req` from a (verified) server response.
///
/// The provider name comes from the requirement. Optional response fields are
/// filled with defaults when absent: empty subject, LoA `0`, and the current
/// time as expiry.
fn proof_from_response(
    req: &IdentityRequirement,
    resp: &IdentityCheckResponse,
) -> SmartflowProof {
    let expires_at = match resp.expires_at {
        Some(at) => at,
        None => Utc::now(),
    };
    SmartflowProof {
        provider: req.provider.clone(),
        subject: resp.subject.clone().unwrap_or_else(String::new),
        loa: resp.loa.unwrap_or_default(),
        expires_at,
        signature: resp.signature.clone(),
    }
}