1use std::collections::HashMap;
31use std::convert::Infallible;
32use std::net::SocketAddr;
33use std::sync::Arc;
34
35use bytes::Bytes;
36use http_body_util::Full;
37use hyper::body::Incoming;
38use hyper::server::conn::http1;
39use hyper::service::service_fn;
40use hyper::{Method, Request, Response, StatusCode};
41use hyper_util::rt::TokioIo;
42use log::{debug, error, info, warn};
43use tokio::sync::{oneshot, RwLock};
44
/// Body type for every response this server produces: a single in-memory
/// chunk (`http_body_util::Full`) holding `Bytes`.
type ResponseBody = Full<Bytes>;
50
51use super::{Challenge, IdentityProvider, Requirement};
52use crate::identity::cache::ProofCache;
53use crate::identity::proof::ProofSigner;
54
/// Context stored next to a [`Challenge`] while it is in flight, so the
/// `/verify/` and `/callback` handlers know which rule and provider it is for.
#[derive(Debug, Clone)]
pub struct Inflight {
    /// Identifier of the rule this verification is for; logged and shown on
    /// the success page.
    pub rule_id: String,
    /// Provider id, matched against `IdentityProvider::id()` to pick the
    /// provider that performs the code exchange.
    pub provider: String,
    /// What the verified identity must satisfy: the callback checks
    /// `allows(..)` and the minimum `loa`, and copies `scope` and
    /// `max_proof_age_seconds` into the issued proof.
    pub requirement: Requirement,
}
62
/// Mutable server state shared (behind an `RwLock`) by the request handlers
/// and the [`ServerHandle`].
#[derive(Default)]
struct State {
    /// In-flight challenges keyed by `challenge_id`.
    // NOTE(review): within this file entries are removed only after a
    // successful callback; abandoned or failed challenges are never pruned.
    challenges: HashMap<String, (Challenge, Inflight)>,
}
68
/// Handle to a running callback server, returned by `CallbackServer::spawn`.
///
/// Dropping the handle drops `shutdown_tx`, which completes the accept loop's
/// shutdown receiver and makes the loop stop accepting new connections.
pub struct ServerHandle {
    /// Address the listener is actually bound to (as reported by the OS).
    pub bound: SocketAddr,
    /// `http://` base URL derived from `bound`.
    pub base_url: String,
    /// State shared with the request handlers.
    state: Arc<RwLock<State>>,
    // Never read: it is kept only so that dropping the handle drops the sender,
    // which signals the accept loop to exit.
    #[allow(dead_code)] shutdown_tx: Option<oneshot::Sender<()>>,
}
78
79impl ServerHandle {
80 pub fn base_url(&self) -> String {
81 self.base_url.clone()
82 }
83
84 pub async fn register(&self, challenge: Challenge, inflight: Inflight) {
87 let cid = challenge.challenge_id.clone();
88 let entry = (challenge, inflight);
89 let mut g = self.state.write().await;
90 g.challenges.insert(cid, entry);
91 }
92
93 #[cfg(test)]
96 pub async fn inflight_count(&self) -> usize {
97 self.state.read().await.challenges.len()
98 }
99}
100
/// Stateless entry point for starting the identity callback server; see
/// `CallbackServer::spawn`.
pub struct CallbackServer;
102
103impl CallbackServer {
104 pub async fn spawn(
107 host: &str,
108 port: u16,
109 providers: Vec<Arc<dyn IdentityProvider>>,
110 cache: Arc<ProofCache>,
111 signer: Arc<ProofSigner>,
112 ) -> anyhow::Result<ServerHandle> {
113 if host != "127.0.0.1" && host != "localhost" && host != "::1" {
114 warn!(
115 "[shield-identity] callback_host='{}' is NOT loopback -- this exposes the \
116 OAuth callback to the network. Only override this if you know what you're doing.",
117 host,
118 );
119 }
120 let addr: SocketAddr = format!("{}:{}", host, port).parse()?;
121 let listener = std::net::TcpListener::bind(addr)?;
124 listener.set_nonblocking(true)?;
125 let bound = listener.local_addr()?;
126 let base_url = format!("http://{}:{}", bound.ip(), bound.port());
127
128 let state: Arc<RwLock<State>> = Arc::new(RwLock::new(State::default()));
129 let providers: Arc<Vec<Arc<dyn IdentityProvider>>> = Arc::new(providers);
130
131 let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
132 let tokio_listener = tokio::net::TcpListener::from_std(listener)?;
138 let accept_state = state.clone();
139 let accept_providers = providers.clone();
140 let accept_cache = cache.clone();
141 let accept_signer = signer.clone();
142 tokio::spawn(async move {
143 loop {
144 tokio::select! {
145 _ = &mut shutdown_rx => {
146 debug!("[shield-identity] callback server shutting down");
147 return;
148 }
149 accept = tokio_listener.accept() => {
150 let (stream, peer) = match accept {
151 Ok(p) => p,
152 Err(e) => {
153 error!("[shield-identity] accept failed: {}", e);
154 continue;
155 }
156 };
157 let io = TokioIo::new(stream);
158 let svc_state = accept_state.clone();
159 let svc_providers = accept_providers.clone();
160 let svc_cache = accept_cache.clone();
161 let svc_signer = accept_signer.clone();
162 let svc = service_fn(move |req: Request<Incoming>| {
163 let state = svc_state.clone();
164 let providers = svc_providers.clone();
165 let cache = svc_cache.clone();
166 let signer = svc_signer.clone();
167 async move {
168 Ok::<_, Infallible>(handle(req, state, providers, cache, signer).await)
169 }
170 });
171 tokio::spawn(async move {
172 if let Err(e) = http1::Builder::new()
173 .serve_connection(io, svc)
174 .await
175 {
176 debug!(
177 "[shield-identity] connection from {} ended: {}",
178 peer, e
179 );
180 }
181 });
182 }
183 }
184 }
185 });
186
187 info!(
188 "[shield-identity] callback server listening on {} (base_url={})",
189 bound, base_url
190 );
191
192 Ok(ServerHandle {
193 bound,
194 base_url,
195 state,
196 shutdown_tx: Some(shutdown_tx),
197 })
198 }
199}
200
201async fn handle(
206 req: Request<Incoming>,
207 state: Arc<RwLock<State>>,
208 providers: Arc<Vec<Arc<dyn IdentityProvider>>>,
209 cache: Arc<ProofCache>,
210 signer: Arc<ProofSigner>,
211) -> Response<ResponseBody> {
212 let method = req.method().clone();
213 let path = req.uri().path().to_string();
214 debug!("[shield-identity] {} {}", method, path);
215
216 if method != Method::GET {
222 return text(StatusCode::METHOD_NOT_ALLOWED, "GET only");
223 }
224
225 if path == "/healthz" {
226 return text(StatusCode::OK, "ok\n");
227 }
228 if path == "/" {
229 return html(StatusCode::OK, INDEX_HTML);
230 }
231
232 if let Some(rest) = path.strip_prefix("/verify/") {
233 let challenge_id = rest.trim_end_matches('/').to_string();
234 let query = req.uri().query().unwrap_or("").to_string();
235 return handle_verify(&challenge_id, &query, &state, &providers).await;
236 }
237
238 if path == "/callback" {
239 let query = req.uri().query().unwrap_or("").to_string();
240 return handle_callback(&query, &state, &providers, &cache, &signer).await;
241 }
242
243 text(StatusCode::NOT_FOUND, "not found")
244}
245
246async fn handle_verify(
247 challenge_id: &str,
248 query: &str,
249 state: &Arc<RwLock<State>>,
250 providers: &Arc<Vec<Arc<dyn IdentityProvider>>>,
251) -> Response<ResponseBody> {
252 let mock_flow = query_param(query, "mock").is_some();
253 let (challenge, inflight) = {
254 let g = state.read().await;
255 match g.challenges.get(challenge_id) {
256 Some((c, i)) => (c.clone(), i.clone()),
257 None => return html(StatusCode::NOT_FOUND, EXPIRED_HTML),
258 }
259 };
260
261 if mock_flow {
262 let provider = providers.iter().find(|p| p.id() == inflight.provider).cloned();
266 let Some(p) = provider else {
267 return html(StatusCode::INTERNAL_SERVER_ERROR, "no matching provider");
268 };
269 match p.exchange(challenge_id, "synthetic-code", challenge_id, challenge.pkce_verifier.as_deref()).await {
270 Ok(_vi) => {
271 let location = format!("/callback?code=synthetic-code&state={}&mock=1", challenge_id);
275 let mut resp = Response::new(empty_body());
276 *resp.status_mut() = StatusCode::FOUND;
277 resp.headers_mut().insert(
278 hyper::header::LOCATION,
279 hyper::header::HeaderValue::from_str(&location).unwrap(),
280 );
281 resp
282 }
283 Err(e) => html(
284 StatusCode::INTERNAL_SERVER_ERROR,
285 &format!("<h1>Mock exchange failed</h1><pre>{}</pre>", html_escape(&e.to_string())),
286 ),
287 }
288 } else {
289 let mut resp = Response::new(empty_body());
291 *resp.status_mut() = StatusCode::FOUND;
292 if let Ok(hv) = hyper::header::HeaderValue::from_str(&challenge.verify_url) {
293 resp.headers_mut().insert(hyper::header::LOCATION, hv);
294 }
295 resp
296 }
297}
298
299async fn handle_callback(
300 query: &str,
301 state: &Arc<RwLock<State>>,
302 providers: &Arc<Vec<Arc<dyn IdentityProvider>>>,
303 cache: &Arc<ProofCache>,
304 signer: &Arc<ProofSigner>,
305) -> Response<ResponseBody> {
306 let code = match query_param(query, "code") {
307 Some(c) => c,
308 None => return html(StatusCode::BAD_REQUEST, "<h1>Missing OAuth code</h1>"),
309 };
310 let cb_state = match query_param(query, "state") {
311 Some(s) => s,
312 None => return html(StatusCode::BAD_REQUEST, "<h1>Missing OAuth state</h1>"),
313 };
314
315 let (challenge, inflight) = {
316 let g = state.read().await;
317 match g.challenges.get(&cb_state) {
318 Some((c, i)) => (c.clone(), i.clone()),
319 None => return html(StatusCode::NOT_FOUND, EXPIRED_HTML),
320 }
321 };
322
323 let provider = providers.iter().find(|p| p.id() == inflight.provider).cloned();
324 let Some(p) = provider else {
325 return html(StatusCode::INTERNAL_SERVER_ERROR, "<h1>Provider gone</h1>");
326 };
327
328 let vi = match p
329 .exchange(&cb_state, &code, &cb_state, challenge.pkce_verifier.as_deref())
330 .await
331 {
332 Ok(v) => v,
333 Err(e) => {
334 warn!("[shield-identity] exchange failed: {}", e);
335 return html(
336 StatusCode::INTERNAL_SERVER_ERROR,
337 &format!("<h1>Verification failed</h1><pre>{}</pre>", html_escape(&e.to_string())),
338 );
339 }
340 };
341
342 if !inflight.requirement.allows(&vi) {
343 return html(
344 StatusCode::FORBIDDEN,
345 &format!(
346 "<h1>Verified, but not authorised</h1>\
347 <p>Identity \"{}\" verified successfully, but is not on the allow-list for this rule.</p>",
348 html_escape(&vi.email.clone().unwrap_or_else(|| vi.subject.clone())),
349 ),
350 );
351 }
352 if vi.loa < inflight.requirement.loa {
353 return html(
354 StatusCode::FORBIDDEN,
355 &format!(
356 "<h1>Verified, but LOA too low</h1>\
357 <p>This rule requires LOA ≥ {} but the provider reported LOA {}.</p>",
358 inflight.requirement.loa, vi.loa,
359 ),
360 );
361 }
362
363 let now = super::unix_now();
366 let proof = super::Proof {
367 v: 1,
368 provider: vi.provider.clone(),
369 subject: vi.subject.clone(),
370 email: vi.email.clone(),
371 loa: vi.loa,
372 scope: inflight.requirement.scope.clone(),
373 verified_at: now,
374 expires_at: now.saturating_add(inflight.requirement.max_proof_age_seconds),
375 nonce: hex::encode(rand::random::<[u8; 16]>()),
376 sig: String::new(),
377 };
378 let signed = match signer.sign(proof) {
379 Ok(p) => p,
380 Err(e) => {
381 error!("[shield-identity] sign failed: {}", e);
382 return html(StatusCode::INTERNAL_SERVER_ERROR, "<h1>Internal signing error</h1>");
383 }
384 };
385 if let Err(e) = cache.insert(signed) {
386 error!("[shield-identity] cache insert failed: {}", e);
387 return html(StatusCode::INTERNAL_SERVER_ERROR, "<h1>Internal cache error</h1>");
388 }
389
390 {
392 let mut g = state.write().await;
393 g.challenges.remove(&cb_state);
394 }
395
396 info!(
397 "[shield-identity] verified provider={} subject={} loa={} scope={} rule={}",
398 vi.provider, vi.subject, vi.loa, inflight.requirement.scope, inflight.rule_id
399 );
400
401 let display = vi.email.unwrap_or_else(|| vi.subject.clone());
402 html(StatusCode::OK, &success_page(&display, &inflight.rule_id, &inflight.requirement.scope))
403}
404
405fn empty_body() -> ResponseBody {
410 Full::new(Bytes::new())
411}
412
413fn body_from_string(s: String) -> ResponseBody {
414 Full::new(Bytes::from(s))
415}
416
417fn text(status: StatusCode, body: &str) -> Response<ResponseBody> {
418 let mut resp = Response::new(body_from_string(body.to_string()));
419 *resp.status_mut() = status;
420 resp.headers_mut().insert(
421 hyper::header::CONTENT_TYPE,
422 hyper::header::HeaderValue::from_static("text/plain; charset=utf-8"),
423 );
424 resp
425}
426
427fn html(status: StatusCode, body: &str) -> Response<ResponseBody> {
428 let full = format!(
429 "<!doctype html><html><head><meta charset=\"utf-8\">\
430 <meta name=\"viewport\" content=\"width=device-width, initial-scale=1\">\
431 <title>Aperion Shield</title>\
432 <style>{css}</style></head><body><main class=\"card\">{body}</main></body></html>",
433 css = PAGE_CSS,
434 body = body,
435 );
436 let mut resp = Response::new(body_from_string(full));
437 *resp.status_mut() = status;
438 resp.headers_mut().insert(
439 hyper::header::CONTENT_TYPE,
440 hyper::header::HeaderValue::from_static("text/html; charset=utf-8"),
441 );
442 resp.headers_mut().insert(
443 hyper::header::CACHE_CONTROL,
444 hyper::header::HeaderValue::from_static("no-store"),
445 );
446 resp.headers_mut().insert(
447 hyper::header::HeaderName::from_static("content-security-policy"),
448 hyper::header::HeaderValue::from_static("default-src 'self'; script-src 'none'; style-src 'unsafe-inline'"),
449 );
450 resp
451}
452
453fn query_param(query: &str, key: &str) -> Option<String> {
454 for pair in query.split('&') {
455 let mut it = pair.splitn(2, '=');
456 let k = it.next()?;
457 let v = it.next().unwrap_or("");
458 if k == key {
459 return Some(percent_decode(v));
460 }
461 }
462 None
463}
464
465fn percent_decode(s: &str) -> String {
466 let bytes = s.as_bytes();
470 let mut out = Vec::with_capacity(bytes.len());
471 let mut i = 0;
472 while i < bytes.len() {
473 if bytes[i] == b'%' && i + 2 < bytes.len() {
474 if let (Some(h), Some(l)) = (hex_val(bytes[i + 1]), hex_val(bytes[i + 2])) {
475 out.push((h << 4) | l);
476 i += 3;
477 continue;
478 }
479 }
480 if bytes[i] == b'+' {
481 out.push(b' ');
482 } else {
483 out.push(bytes[i]);
484 }
485 i += 1;
486 }
487 String::from_utf8_lossy(&out).into_owned()
488}
489
/// Value of one ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly the ASCII hex digits and yields 0..=15,
    // so narrowing the result to `u8` cannot truncate.
    char::from(b).to_digit(16).map(|d| d as u8)
}
498
/// Escapes `s` so it can be interpolated into HTML text content or a quoted
/// attribute value.
///
/// `&`, `<`, `>`, `"` and `'` become character references; every other
/// character is copied through unchanged.
fn html_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#39;"),
            _ => out.push(c),
        }
    }
    out
}

/// Inner HTML of the success page; all three inputs are HTML-escaped.
fn success_page(who: &str, rule_id: &str, scope: &str) -> String {
    // The space before the trailing `\` matters: a string continuation drops
    // the newline and the next line's leading whitespace.
    format!(
        "<h1>Identity verified</h1>\
         <p class=\"who\"><strong>{who}</strong> -- verified for scope <code>{scope}</code>.</p>\
         <p class=\"rule\">Rule: <code>{rule_id}</code></p>\
         <p class=\"close\">You may close this tab and return to your editor. \
         The pending tool call will be released automatically.</p>",
        who = html_escape(who),
        scope = html_escape(scope),
        rule_id = html_escape(rule_id),
    )
}

/// Body of the landing page served at `/`.
const INDEX_HTML: &str = "<h1>Aperion Shield</h1><p>This server handles identity verification callbacks. \
Open a verification URL emitted by Shield to continue.</p>";

/// Body served when `/verify/` or `/callback` names a challenge that is not registered.
const EXPIRED_HTML: &str = "<h1>Verification expired</h1><p>This challenge is no longer in flight. \
Re-run the gated tool call to start a new verification.</p>";

/// Stylesheet inlined into every page; the pages' CSP allows inline styles only.
const PAGE_CSS: &str = "body { background: #0f172a; color: #e2e8f0; font-family: -apple-system, BlinkMacSystemFont, \"Segoe UI\", sans-serif; margin: 0; padding: 40px 20px; }
.card { max-width: 560px; margin: 0 auto; background: #1e293b; border: 1px solid #334155; border-radius: 12px; padding: 32px; box-shadow: 0 20px 60px rgba(0,0,0,0.5); }
h1 { color: #6ee7b7; margin: 0 0 16px; font-size: 22px; }
p { line-height: 1.55; margin: 8px 0; }
.who strong { color: #fff; }
code { background: #0f172a; color: #6ee7b7; padding: 2px 6px; border-radius: 4px; font-family: \"JetBrains Mono\", \"SF Mono\", monospace; font-size: 0.92em; }
.close { color: #94a3b8; font-size: 0.92em; margin-top: 18px; }
pre { background: #0f172a; padding: 12px; border-radius: 8px; overflow-x: auto; color: #fda4af; }";

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn percent_decode_basics() {
        assert_eq!(percent_decode("hello%20world"), "hello world");
        assert_eq!(percent_decode("a+b"), "a b");
        assert_eq!(percent_decode("%21%40%23"), "!@#");
        assert_eq!(percent_decode("no-encoded-chars"), "no-encoded-chars");
    }

    #[test]
    fn query_param_finds_keys() {
        assert_eq!(query_param("a=1&b=hi", "a"), Some("1".into()));
        assert_eq!(query_param("a=1&b=hi", "b"), Some("hi".into()));
        assert_eq!(query_param("a=1&b=hi", "c"), None);
        assert_eq!(query_param("mock=1", "mock"), Some("1".into()));
        assert_eq!(query_param("code=abc&state=xyz", "state"), Some("xyz".into()));
    }

    #[test]
    fn html_escape_is_safe() {
        assert_eq!(html_escape("<script>"), "&lt;script&gt;");
        assert_eq!(html_escape("a&b"), "a&amp;b");
        assert_eq!(html_escape("\"x\" 'y'"), "&quot;x&quot; &#39;y&#39;");
        assert_eq!(html_escape("plain"), "plain");
        assert_eq!(html_escape(""), "");
    }

    #[test]
    fn success_page_escapes_inputs_and_keeps_sentence_spacing() {
        let page = success_page("<b>me</b>", "r&1", "s");
        assert!(page.contains("<strong>&lt;b&gt;me&lt;/b&gt;</strong>"));
        assert!(page.contains("<code>r&amp;1</code>"));
        assert!(page.contains("editor. The pending"));
    }
}