Skip to main content

aperion_shield/identity/
server.rs

//! Local OAuth callback server.
//!
//! Bound to `127.0.0.1` on an OS-assigned port (unless the user pinned
//! one in `identity.yaml`). Two main routes (plus `GET /healthz` and a
//! static index page at `GET /`):
//!
//!   * `GET /verify/<challenge_id>` -- the URL Shield hands the user.
//!     For OAuth providers it 302-redirects to the provider's
//!     authorize URL (which is what the provider's [`begin`] already
//!     returned as `verify_url`). In the mock flow (`?mock=1`) it runs
//!     the provider's exchange in-process and then short-circuits to
//!     `/callback?code=synthetic-code&state=<id>&mock=1`.
//!   * `GET /callback` -- the OAuth redirect target. Pulls the
//!     in-flight challenge state, calls [`exchange`] on the matching
//!     provider, mints + caches a [`Proof`], and renders a friendly
//!     "you may return to your editor" page.
//!
//! Concurrency model: each Shield process has at most one callback
//! server. It's started lazily on the first identity-required decision
//! (via `IdentityGate::callback_base`). In-flight challenges live behind
//! an `Arc<RwLock<...>>` shared with the request handlers; the proof
//! cache and signer are shared with them as plain `Arc`s.
//!
//! Security:
//!   * Binds to `127.0.0.1` by default (or whatever the config says, but
//!     we log a warning if anyone tries to bind a non-loopback address).
//!   * Each callback validates `state` against an in-flight entry.
//!   * Stale challenges (`expires_at <= now`) are meant to be reaped at
//!     lookup time. NOTE(review): the lookups in this file do not check
//!     expiry yet; an entry is only removed after a successful callback.
//!   * The HTML response includes a `cache-control: no-store` and a
//!     CSP that disables script execution -- the page is static text.
29
30use std::collections::HashMap;
31use std::convert::Infallible;
32use std::net::SocketAddr;
33use std::sync::Arc;
34
35use bytes::Bytes;
36use http_body_util::Full;
37use hyper::body::Incoming;
38use hyper::server::conn::http1;
39use hyper::service::service_fn;
40use hyper::{Method, Request, Response, StatusCode};
41use hyper_util::rt::TokioIo;
42use log::{debug, error, info, warn};
43use tokio::sync::{oneshot, RwLock};
44
45/// Type alias for the response body we produce. hyper 1.x decoupled
46/// the body type from the framework, so we pick `Full<Bytes>` for our
47/// small static HTML/text responses -- it is the simplest body that
48/// implements `http_body::Body` and works with the http1 server.
49type ResponseBody = Full<Bytes>;
50
51use super::{Challenge, IdentityProvider, Requirement};
52use crate::identity::cache::ProofCache;
53use crate::identity::proof::ProofSigner;
54
55/// State the server holds for one in-flight verification.
56#[derive(Debug, Clone)]
57pub struct Inflight {
58    pub rule_id: String,
59    pub provider: String,
60    pub requirement: Requirement,
61}
62
63#[derive(Default)]
64struct State {
65    /// challenge_id -> (challenge, in-flight meta)
66    challenges: HashMap<String, (Challenge, Inflight)>,
67}
68
69/// Handle returned by [`CallbackServer::spawn`]. Owns the abort guard
70/// for the background task; dropping it shuts the server down.
71pub struct ServerHandle {
72    pub bound: SocketAddr,
73    pub base_url: String,
74    state: Arc<RwLock<State>>,
75    #[allow(dead_code)] // kept alive so the server shuts down on drop
76    shutdown_tx: Option<oneshot::Sender<()>>,
77}
78
79impl ServerHandle {
80    pub fn base_url(&self) -> String {
81        self.base_url.clone()
82    }
83
84    /// Register a freshly-minted challenge so the server knows what to
85    /// do when its `/callback` route is hit with this `state`.
86    pub async fn register(&self, challenge: Challenge, inflight: Inflight) {
87        let cid = challenge.challenge_id.clone();
88        let entry = (challenge, inflight);
89        let mut g = self.state.write().await;
90        g.challenges.insert(cid, entry);
91    }
92
93    /// For tests / diagnostics: how many challenges are currently
94    /// in-flight.
95    #[cfg(test)]
96    pub async fn inflight_count(&self) -> usize {
97        self.state.read().await.challenges.len()
98    }
99}
100
101pub struct CallbackServer;
102
103impl CallbackServer {
104    /// Spin up the callback server. Returns a handle whose `base_url`
105    /// points at the bound address.
106    pub async fn spawn(
107        host: &str,
108        port: u16,
109        providers: Vec<Arc<dyn IdentityProvider>>,
110        cache: Arc<ProofCache>,
111        signer: Arc<ProofSigner>,
112    ) -> anyhow::Result<ServerHandle> {
113        if host != "127.0.0.1" && host != "localhost" && host != "::1" {
114            warn!(
115                "[shield-identity] callback_host='{}' is NOT loopback -- this exposes the \
116                 OAuth callback to the network. Only override this if you know what you're doing.",
117                host,
118            );
119        }
120        let addr: SocketAddr = format!("{}:{}", host, port).parse()?;
121        // Bind a std listener so we can read the OS-assigned port
122        // synchronously, then hand it off to hyper.
123        let listener = std::net::TcpListener::bind(addr)?;
124        listener.set_nonblocking(true)?;
125        let bound = listener.local_addr()?;
126        let base_url = format!("http://{}:{}", bound.ip(), bound.port());
127
128        let state: Arc<RwLock<State>> = Arc::new(RwLock::new(State::default()));
129        let providers: Arc<Vec<Arc<dyn IdentityProvider>>> = Arc::new(providers);
130
131        let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
132        // hyper 1.x is connection-oriented: we accept directly from a
133        // tokio TcpListener and serve each connection with
134        // `http1::Builder::new().serve_connection`. The per-connection
135        // service is built fresh each loop so it owns its own clones
136        // of the shared state.
137        let tokio_listener = tokio::net::TcpListener::from_std(listener)?;
138        let accept_state = state.clone();
139        let accept_providers = providers.clone();
140        let accept_cache = cache.clone();
141        let accept_signer = signer.clone();
142        tokio::spawn(async move {
143            loop {
144                tokio::select! {
145                    _ = &mut shutdown_rx => {
146                        debug!("[shield-identity] callback server shutting down");
147                        return;
148                    }
149                    accept = tokio_listener.accept() => {
150                        let (stream, peer) = match accept {
151                            Ok(p) => p,
152                            Err(e) => {
153                                error!("[shield-identity] accept failed: {}", e);
154                                continue;
155                            }
156                        };
157                        let io = TokioIo::new(stream);
158                        let svc_state = accept_state.clone();
159                        let svc_providers = accept_providers.clone();
160                        let svc_cache = accept_cache.clone();
161                        let svc_signer = accept_signer.clone();
162                        let svc = service_fn(move |req: Request<Incoming>| {
163                            let state = svc_state.clone();
164                            let providers = svc_providers.clone();
165                            let cache = svc_cache.clone();
166                            let signer = svc_signer.clone();
167                            async move {
168                                Ok::<_, Infallible>(handle(req, state, providers, cache, signer).await)
169                            }
170                        });
171                        tokio::spawn(async move {
172                            if let Err(e) = http1::Builder::new()
173                                .serve_connection(io, svc)
174                                .await
175                            {
176                                debug!(
177                                    "[shield-identity] connection from {} ended: {}",
178                                    peer, e
179                                );
180                            }
181                        });
182                    }
183                }
184            }
185        });
186
187        info!(
188            "[shield-identity] callback server listening on {} (base_url={})",
189            bound, base_url
190        );
191
192        Ok(ServerHandle {
193            bound,
194            base_url,
195            state,
196            shutdown_tx: Some(shutdown_tx),
197        })
198    }
199}
200
201// ────────────────────────────────────────────────────────────────────
202// Request handler
203// ────────────────────────────────────────────────────────────────────
204
205async fn handle(
206    req: Request<Incoming>,
207    state: Arc<RwLock<State>>,
208    providers: Arc<Vec<Arc<dyn IdentityProvider>>>,
209    cache: Arc<ProofCache>,
210    signer: Arc<ProofSigner>,
211) -> Response<ResponseBody> {
212    let method = req.method().clone();
213    let path = req.uri().path().to_string();
214    debug!("[shield-identity] {} {}", method, path);
215
216    // Routes:
217    //   GET /healthz
218    //   GET /verify/<id>?mock=1  -- mock-only short-circuit
219    //   GET /callback?code=...&state=...
220    //   GET /                    -- friendly index
221    if method != Method::GET {
222        return text(StatusCode::METHOD_NOT_ALLOWED, "GET only");
223    }
224
225    if path == "/healthz" {
226        return text(StatusCode::OK, "ok\n");
227    }
228    if path == "/" {
229        return html(StatusCode::OK, INDEX_HTML);
230    }
231
232    if let Some(rest) = path.strip_prefix("/verify/") {
233        let challenge_id = rest.trim_end_matches('/').to_string();
234        let query = req.uri().query().unwrap_or("").to_string();
235        return handle_verify(&challenge_id, &query, &state, &providers).await;
236    }
237
238    if path == "/callback" {
239        let query = req.uri().query().unwrap_or("").to_string();
240        return handle_callback(&query, &state, &providers, &cache, &signer).await;
241    }
242
243    text(StatusCode::NOT_FOUND, "not found")
244}
245
246async fn handle_verify(
247    challenge_id: &str,
248    query: &str,
249    state: &Arc<RwLock<State>>,
250    providers: &Arc<Vec<Arc<dyn IdentityProvider>>>,
251) -> Response<ResponseBody> {
252    let mock_flow = query_param(query, "mock").is_some();
253    let (challenge, inflight) = {
254        let g = state.read().await;
255        match g.challenges.get(challenge_id) {
256            Some((c, i)) => (c.clone(), i.clone()),
257            None => return html(StatusCode::NOT_FOUND, EXPIRED_HTML),
258        }
259    };
260
261    if mock_flow {
262        // Find the matching provider and run an in-process exchange,
263        // then redirect to /callback?code=...&state=... so the rest of
264        // the pipeline runs identically to the real OAuth flow.
265        let provider = providers.iter().find(|p| p.id() == inflight.provider).cloned();
266        let Some(p) = provider else {
267            return html(StatusCode::INTERNAL_SERVER_ERROR, "no matching provider");
268        };
269        match p.exchange(challenge_id, "synthetic-code", challenge_id, challenge.pkce_verifier.as_deref()).await {
270            Ok(_vi) => {
271                // We could mint+cache here, but to keep the
272                // happy-path identical between mock and real, redirect
273                // to /callback which will go through the same code.
274                let location = format!("/callback?code=synthetic-code&state={}&mock=1", challenge_id);
275                let mut resp = Response::new(empty_body());
276                *resp.status_mut() = StatusCode::FOUND;
277                resp.headers_mut().insert(
278                    hyper::header::LOCATION,
279                    hyper::header::HeaderValue::from_str(&location).unwrap(),
280                );
281                resp
282            }
283            Err(e) => html(
284                StatusCode::INTERNAL_SERVER_ERROR,
285                &format!("<h1>Mock exchange failed</h1><pre>{}</pre>", html_escape(&e.to_string())),
286            ),
287        }
288    } else {
289        // Real flow: just bounce the user to the provider's verify_url.
290        let mut resp = Response::new(empty_body());
291        *resp.status_mut() = StatusCode::FOUND;
292        if let Ok(hv) = hyper::header::HeaderValue::from_str(&challenge.verify_url) {
293            resp.headers_mut().insert(hyper::header::LOCATION, hv);
294        }
295        resp
296    }
297}
298
299async fn handle_callback(
300    query: &str,
301    state: &Arc<RwLock<State>>,
302    providers: &Arc<Vec<Arc<dyn IdentityProvider>>>,
303    cache: &Arc<ProofCache>,
304    signer: &Arc<ProofSigner>,
305) -> Response<ResponseBody> {
306    let code = match query_param(query, "code") {
307        Some(c) => c,
308        None => return html(StatusCode::BAD_REQUEST, "<h1>Missing OAuth code</h1>"),
309    };
310    let cb_state = match query_param(query, "state") {
311        Some(s) => s,
312        None => return html(StatusCode::BAD_REQUEST, "<h1>Missing OAuth state</h1>"),
313    };
314
315    let (challenge, inflight) = {
316        let g = state.read().await;
317        match g.challenges.get(&cb_state) {
318            Some((c, i)) => (c.clone(), i.clone()),
319            None => return html(StatusCode::NOT_FOUND, EXPIRED_HTML),
320        }
321    };
322
323    let provider = providers.iter().find(|p| p.id() == inflight.provider).cloned();
324    let Some(p) = provider else {
325        return html(StatusCode::INTERNAL_SERVER_ERROR, "<h1>Provider gone</h1>");
326    };
327
328    let vi = match p
329        .exchange(&cb_state, &code, &cb_state, challenge.pkce_verifier.as_deref())
330        .await
331    {
332        Ok(v) => v,
333        Err(e) => {
334            warn!("[shield-identity] exchange failed: {}", e);
335            return html(
336                StatusCode::INTERNAL_SERVER_ERROR,
337                &format!("<h1>Verification failed</h1><pre>{}</pre>", html_escape(&e.to_string())),
338            );
339        }
340    };
341
342    if !inflight.requirement.allows(&vi) {
343        return html(
344            StatusCode::FORBIDDEN,
345            &format!(
346                "<h1>Verified, but not authorised</h1>\
347                 <p>Identity \"{}\" verified successfully, but is not on the allow-list for this rule.</p>",
348                html_escape(&vi.email.clone().unwrap_or_else(|| vi.subject.clone())),
349            ),
350        );
351    }
352    if vi.loa < inflight.requirement.loa {
353        return html(
354            StatusCode::FORBIDDEN,
355            &format!(
356                "<h1>Verified, but LOA too low</h1>\
357                 <p>This rule requires LOA &ge; {} but the provider reported LOA {}.</p>",
358                inflight.requirement.loa, vi.loa,
359            ),
360        );
361    }
362
363    // Mint + cache the signed proof. Outside the request handler this
364    // is exactly what the held tool call is polling on.
365    let now = super::unix_now();
366    let proof = super::Proof {
367        v: 1,
368        provider: vi.provider.clone(),
369        subject: vi.subject.clone(),
370        email: vi.email.clone(),
371        loa: vi.loa,
372        scope: inflight.requirement.scope.clone(),
373        verified_at: now,
374        expires_at: now.saturating_add(inflight.requirement.max_proof_age_seconds),
375        nonce: hex::encode(rand::random::<[u8; 16]>()),
376        sig: String::new(),
377    };
378    let signed = match signer.sign(proof) {
379        Ok(p) => p,
380        Err(e) => {
381            error!("[shield-identity] sign failed: {}", e);
382            return html(StatusCode::INTERNAL_SERVER_ERROR, "<h1>Internal signing error</h1>");
383        }
384    };
385    if let Err(e) = cache.insert(signed) {
386        error!("[shield-identity] cache insert failed: {}", e);
387        return html(StatusCode::INTERNAL_SERVER_ERROR, "<h1>Internal cache error</h1>");
388    }
389
390    // Clean up the in-flight slot so an expired state can't be reused.
391    {
392        let mut g = state.write().await;
393        g.challenges.remove(&cb_state);
394    }
395
396    info!(
397        "[shield-identity] verified provider={} subject={} loa={} scope={} rule={}",
398        vi.provider, vi.subject, vi.loa, inflight.requirement.scope, inflight.rule_id
399    );
400
401    let display = vi.email.unwrap_or_else(|| vi.subject.clone());
402    html(StatusCode::OK, &success_page(&display, &inflight.rule_id, &inflight.requirement.scope))
403}
404
405// ────────────────────────────────────────────────────────────────────
406// Helpers
407// ────────────────────────────────────────────────────────────────────
408
409fn empty_body() -> ResponseBody {
410    Full::new(Bytes::new())
411}
412
413fn body_from_string(s: String) -> ResponseBody {
414    Full::new(Bytes::from(s))
415}
416
417fn text(status: StatusCode, body: &str) -> Response<ResponseBody> {
418    let mut resp = Response::new(body_from_string(body.to_string()));
419    *resp.status_mut() = status;
420    resp.headers_mut().insert(
421        hyper::header::CONTENT_TYPE,
422        hyper::header::HeaderValue::from_static("text/plain; charset=utf-8"),
423    );
424    resp
425}
426
427fn html(status: StatusCode, body: &str) -> Response<ResponseBody> {
428    let full = format!(
429        "<!doctype html><html><head><meta charset=\"utf-8\">\
430         <meta name=\"viewport\" content=\"width=device-width, initial-scale=1\">\
431         <title>Aperion Shield</title>\
432         <style>{css}</style></head><body><main class=\"card\">{body}</main></body></html>",
433        css = PAGE_CSS,
434        body = body,
435    );
436    let mut resp = Response::new(body_from_string(full));
437    *resp.status_mut() = status;
438    resp.headers_mut().insert(
439        hyper::header::CONTENT_TYPE,
440        hyper::header::HeaderValue::from_static("text/html; charset=utf-8"),
441    );
442    resp.headers_mut().insert(
443        hyper::header::CACHE_CONTROL,
444        hyper::header::HeaderValue::from_static("no-store"),
445    );
446    resp.headers_mut().insert(
447        hyper::header::HeaderName::from_static("content-security-policy"),
448        hyper::header::HeaderValue::from_static("default-src 'self'; script-src 'none'; style-src 'unsafe-inline'"),
449    );
450    resp
451}
452
/// Return the percent-decoded value of the first `key=value` pair in a
/// raw query string whose key equals `key`.
///
/// Keys are compared as-is (not decoded). A bare `key` with no `=` yields
/// `Some("")`; only the first `=` splits key from value.
fn query_param(query: &str, key: &str) -> Option<String> {
    query
        .split('&')
        .map(|pair| pair.split_once('=').unwrap_or((pair, "")))
        .find(|&(name, _)| name == key)
        .map(|(_, raw)| percent_decode(raw))
}

/// Minimal `application/x-www-form-urlencoded` value decoder.
///
/// `%XX` with two hex digits becomes that byte and `+` becomes a space.
/// Malformed or truncated escapes are kept literally. Invalid UTF-8 in the
/// result is replaced lossily. (Hand-rolled to avoid a dependency; this is
/// not on any hot path.)
fn percent_decode(s: &str) -> String {
    let input = s.as_bytes();
    let mut decoded = Vec::with_capacity(input.len());
    let mut pos = 0;
    while let Some(&byte) = input.get(pos) {
        let escaped = match (byte, input.get(pos + 1), input.get(pos + 2)) {
            (b'%', Some(&hi), Some(&lo)) => hex_val(hi).zip(hex_val(lo)).map(|(h, l)| (h << 4) | l),
            _ => None,
        };
        if let Some(value) = escaped {
            decoded.push(value);
            pos += 3;
        } else {
            decoded.push(if byte == b'+' { b' ' } else { byte });
            pos += 1;
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}

/// Value of one ASCII hex digit (either case), or `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit` only accepts ASCII digits/letters, so bytes >= 0x80
    // (which map to Latin-1 chars here) are rejected; results are <= 15.
    char::from(b).to_digit(16).map(|d| d as u8)
}
498
/// Escape text for safe inclusion in HTML.
///
/// Escapes `&`, `<`, `>`, `"` and `'`, so the result is safe both as
/// element content and inside a quoted attribute value. Provider-supplied
/// strings (emails, subjects, error messages) are passed through this.
fn html_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            other => out.push(other),
        }
    }
    out
}
502
503fn success_page(who: &str, rule_id: &str, scope: &str) -> String {
504    format!(
505        "<h1>Identity verified</h1>\
506         <p class=\"who\"><strong>{who}</strong> -- verified for scope <code>{scope}</code>.</p>\
507         <p class=\"rule\">Rule: <code>{rule_id}</code></p>\
508         <p class=\"close\">You may close this tab and return to your editor.\
509         The pending tool call will be released automatically.</p>",
510        who = html_escape(who),
511        scope = html_escape(scope),
512        rule_id = html_escape(rule_id),
513    )
514}
515
/// Body of the `GET /` landing page.
const INDEX_HTML: &str = concat!(
    "<h1>Aperion Shield</h1>",
    "<p>This server handles identity verification callbacks. ",
    "Open a verification URL emitted by Shield to continue.</p>"
);
518
/// Body shown when a `/verify` or `/callback` request names no
/// registered challenge.
const EXPIRED_HTML: &str = concat!(
    "<h1>Verification expired</h1>",
    "<p>This challenge is no longer in flight. ",
    "Re-run the gated tool call to start a new verification.</p>"
);
521
/// Inline stylesheet embedded in every HTML page (allowed by the
/// `style-src 'unsafe-inline'` CSP directive). One rule per line.
const PAGE_CSS: &str = concat!(
    "body { background: #0f172a; color: #e2e8f0; font-family: -apple-system, BlinkMacSystemFont, \"Segoe UI\", sans-serif; margin: 0; padding: 40px 20px; }\n",
    ".card { max-width: 560px; margin: 0 auto; background: #1e293b; border: 1px solid #334155; border-radius: 12px; padding: 32px; box-shadow: 0 20px 60px rgba(0,0,0,0.5); }\n",
    "h1 { color: #6ee7b7; margin: 0 0 16px; font-size: 22px; }\n",
    "p { line-height: 1.55; margin: 8px 0; }\n",
    ".who strong { color: #fff; }\n",
    "code { background: #0f172a; color: #6ee7b7; padding: 2px 6px; border-radius: 4px; font-family: \"JetBrains Mono\", \"SF Mono\", monospace; font-size: 0.92em; }\n",
    ".close { color: #94a3b8; font-size: 0.92em; margin-top: 18px; }\n",
    "pre { background: #0f172a; padding: 12px; border-radius: 8px; overflow-x: auto; color: #fda4af; }"
);
530
531// ────────────────────────────────────────────────────────────────────
532// Tests
533// ────────────────────────────────────────────────────────────────────
534
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn percent_decode_basics() {
        assert_eq!(percent_decode("hello%20world"), "hello world");
        assert_eq!(percent_decode("a+b"), "a b");
        assert_eq!(percent_decode("%21%40%23"), "!@#");
        assert_eq!(percent_decode("no-encoded-chars"), "no-encoded-chars");
    }

    /// Truncated or non-hex escapes are kept literally; multi-byte UTF-8
    /// sequences decode to the expected character.
    #[test]
    fn percent_decode_edge_cases() {
        assert_eq!(percent_decode("100%"), "100%");
        assert_eq!(percent_decode("%2"), "%2");
        assert_eq!(percent_decode("%zz"), "%zz");
        assert_eq!(percent_decode("%E2%9C%93"), "\u{2713}");
        assert_eq!(percent_decode(""), "");
    }

    #[test]
    fn query_param_finds_keys() {
        assert_eq!(query_param("a=1&b=hi", "a"), Some("1".into()));
        assert_eq!(query_param("a=1&b=hi", "b"), Some("hi".into()));
        assert_eq!(query_param("a=1&b=hi", "c"), None);
        assert_eq!(query_param("mock=1", "mock"), Some("1".into()));
        assert_eq!(query_param("code=abc&state=xyz", "state"), Some("xyz".into()));
    }

    /// Bare keys, embedded `=`, duplicates, empty queries and encoded
    /// separators inside values.
    #[test]
    fn query_param_edge_cases() {
        assert_eq!(query_param("mock", "mock"), Some(String::new()));
        assert_eq!(query_param("a=x=y", "a"), Some("x=y".into()));
        assert_eq!(query_param("a=1&a=2", "a"), Some("1".into()));
        assert_eq!(query_param("", "a"), None);
        assert_eq!(query_param("state=a%26b", "state"), Some("a&b".into()));
    }

    #[test]
    fn html_escape_is_safe() {
        assert_eq!(html_escape("<script>"), "&lt;script&gt;");
        assert_eq!(html_escape("a&b"), "a&amp;b");
        assert_eq!(html_escape("plain"), "plain");
        assert_eq!(html_escape(""), "");
    }
}