Skip to main content

harn_vm/
mcp_bulk_auth.rs

1//! Bulk MCP OAuth driver (harn#3355) — the keystone of the bulk-login program
2//! (harn#3354).
3//!
4//! Authenticating many OAuth-backed MCP servers today is one-at-a-time by
5//! construction. This driver orchestrates **all** pending flows at once on top
6//! of the unchanged per-server engine in [`crate::mcp_oauth`], and publishes a
7//! **per-server status event stream** so every surface (CLI `mcp login --all`,
8//! ACP `mcp/authorize_batch`, the burin "Connect all" GUI) renders incremental
9//! progress from one source. No protocol changes — bulk auth is N independent,
10//! spec-compliant flows that share one driver and one loopback listener
11//! (demuxed by the OAuth `state`).
12//!
13//! Division of labour: the driver owns flow orchestration + status; the
14//! **surface** owns IO (opening browsers, the loopback listener, terminal/GUI
15//! rendering). [`McpBulkAuth::prepare`] returns the authorize URLs to open
16//! (keyed by `state`); the surface opens them (serialized, to avoid a popup
17//! storm) and feeds each captured `{ state, code }` back via
18//! [`McpBulkAuth::complete`].
19//!
20//! The OAuth operations are abstracted behind [`OAuthFlowEngine`] so the driver
21//! is unit-testable end-to-end against a mock — no network, no browser. The
22//! real engine ([`RealOAuthFlowEngine`]) simply delegates to `mcp_oauth`.
23
24use std::collections::HashMap;
25use std::sync::{Arc, Mutex};
26use std::time::Duration;
27
28use async_trait::async_trait;
29use futures::stream::StreamExt;
30use serde::{Deserialize, Serialize};
31use tokio::sync::broadcast;
32
33use crate::mcp_auth::{canonical_resource_indicator, OAuthClientAuthMode};
34use crate::mcp_oauth::{self, BeginAuthorization, PendingAuthorization, StoredMcpToken};
35
/// Capacity of the broadcast channel carrying per-server status events.
///
/// Chosen to sit well above the number of servers a workspace realistically
/// declares. A subscriber that falls further behind than this drops the
/// oldest events (it can recover by re-querying `mcp status`); the sending
/// driver is never held up by a slow subscriber.
const STATUS_CHANNEL_CAPACITY: usize = 256;
40
/// Selects which servers a bulk pass starts an OAuth flow for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BulkAuthMode {
    /// First-time login (`mcp login --all`): start a flow only where no
    /// currently-valid bearer exists; already-connected servers are skipped.
    Missing,
    /// Re-authentication (`mcp.reauth_expired()` / harn#3358): start a flow
    /// only where a token *is* stored but is no longer valid (expired or
    /// revoked). Servers holding a valid token, or no token at all, are
    /// skipped.
    Expired,
    /// Start a fresh flow for every server, whatever its current state.
    All,
}
54
55/// One server to (maybe) authenticate. Mirrors the per-server inputs the
56/// engine's [`BeginAuthorization`] needs, plus a display `name` for status.
57#[derive(Clone, Debug, Default)]
58pub struct BulkAuthServer {
59    pub name: String,
60    pub server_url: String,
61    pub mode: Option<OAuthClientAuthMode>,
62    pub client_id: Option<String>,
63    pub client_secret: Option<String>,
64    pub static_secret_id: Option<String>,
65    pub scopes: Option<String>,
66}
67
68/// A flow that needs the surface to open a browser. The surface opens
69/// `authorize_url`, captures the redirect, and calls [`McpBulkAuth::complete`]
70/// with the `state` echoed back.
71#[derive(Clone, Debug, Serialize)]
72#[serde(rename_all = "camelCase")]
73pub struct PreparedFlow {
74    pub name: String,
75    pub server_url: String,
76    pub authorize_url: String,
77    pub state: String,
78    pub redirect_uri: String,
79}
80
81/// The result of preparing one server.
82#[derive(Clone, Debug)]
83pub enum PrepareOutcome {
84    /// Needs a browser consent; open `authorize_url`.
85    Pending(PreparedFlow),
86    /// Nothing to do (already connected, or nothing to re-auth).
87    Skipped {
88        name: String,
89        server_url: String,
90        reason: String,
91    },
92    /// Could not begin a flow (discovery/registration/timeout failure).
93    Failed {
94        name: String,
95        server_url: String,
96        error: String,
97    },
98}
99
100/// Phase of one server's bulk-auth lifecycle, streamed to subscribers.
101#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
102#[serde(rename_all = "snake_case")]
103pub enum McpAuthPhase {
104    /// Resolving the authorization server + client (discovery/registration).
105    Discovering,
106    /// Authorize URL minted; waiting for the user to consent in the browser.
107    AwaitingConsent,
108    /// Exchanging the authorization code for a token.
109    Exchanging,
110    /// Token stored — the server is connected.
111    Connected,
112    /// This server failed; `detail` carries a redacted reason.
113    Failed,
114    /// This server was skipped (e.g. already connected); `detail` says why.
115    Skipped,
116}
117
118/// One per-server status event on the stream.
119#[derive(Clone, Debug, Serialize, Deserialize)]
120#[serde(rename_all = "snake_case")]
121pub struct McpAuthStatus {
122    /// Display name of the server.
123    pub server: String,
124    /// Server URL (empty when unknown, e.g. a late callback).
125    pub server_url: String,
126    /// Current phase.
127    pub phase: McpAuthPhase,
128    /// Human-readable note or redacted error, when relevant.
129    #[serde(default, skip_serializing_if = "Option::is_none")]
130    pub detail: Option<String>,
131}
132
133/// Driver tunables. **Data, not code:** loaded via [`BulkAuthConfig::load`]
134/// from `HARN_MCP_BULK_AUTH_CONFIG` or `~/.config/harn/mcp_bulk_auth.toml`
135/// (a `[bulk_auth]` table), so concurrency/timeout change without a recompile.
136#[derive(Clone, Copy, Debug, Deserialize)]
137#[serde(default)]
138pub struct BulkAuthConfig {
139    /// Max concurrent `begin_authorization` flows during prepare.
140    pub concurrency: usize,
141    /// Per-server budget for discovery + authorize-URL minting.
142    pub prepare_timeout_secs: u64,
143}
144
145impl Default for BulkAuthConfig {
146    fn default() -> Self {
147        Self {
148            concurrency: 8,
149            prepare_timeout_secs: 30,
150        }
151    }
152}
153
154#[derive(Debug, Default, Deserialize)]
155struct BulkAuthConfigFile {
156    #[serde(default)]
157    bulk_auth: BulkAuthConfig,
158}
159
160impl BulkAuthConfig {
161    /// Resolve the effective config: the `HARN_MCP_BULK_AUTH_CONFIG` path wins,
162    /// else `~/.config/harn/mcp_bulk_auth.toml`, else defaults. Skipped under
163    /// `cfg(test)` so unit tests are deterministic.
164    pub fn load() -> Self {
165        if let Ok(path) = std::env::var("HARN_MCP_BULK_AUTH_CONFIG") {
166            if let Some(config) = Self::read(&path) {
167                return config;
168            }
169        }
170        if !cfg!(test) {
171            if let Some(home) = crate::user_dirs::home_dir() {
172                let path = home.join(".config").join("harn").join("mcp_bulk_auth.toml");
173                if let Some(config) = Self::read(&path.to_string_lossy()) {
174                    return config;
175                }
176            }
177        }
178        Self::default()
179    }
180
181    fn read(path: &str) -> Option<Self> {
182        let content = std::fs::read_to_string(path).ok()?;
183        match toml::from_str::<BulkAuthConfigFile>(&content) {
184            Ok(file) => Some(file.bulk_auth),
185            Err(error) => {
186                eprintln!("[mcp_bulk_auth] TOML parse error in {path}: {error}");
187                None
188            }
189        }
190    }
191}
192
193/// The OAuth operations the driver needs, abstracted so it can be exercised
194/// against a mock. The real impl delegates to [`crate::mcp_oauth`].
195#[async_trait]
196pub trait OAuthFlowEngine: Send + Sync {
197    /// A currently-valid bearer for the server (refreshing if needed), or
198    /// `None` when none is stored / the stored one cannot be made valid.
199    async fn current_bearer(&self, server_url: &str) -> Result<Option<String>, String>;
200    /// Whether *any* token is stored for the server (valid or not).
201    async fn has_token(&self, server_url: &str) -> Result<bool, String>;
202    /// Begin an authorization, returning the authorize URL + `state`.
203    async fn begin(&self, request: BeginAuthorization) -> Result<PendingAuthorization, String>;
204    /// Complete a flow by `state`, exchanging the code and persisting the token.
205    async fn complete(
206        &self,
207        state: &str,
208        code: &str,
209        issuer: Option<&str>,
210    ) -> Result<StoredMcpToken, String>;
211}
212
213/// Production [`OAuthFlowEngine`] — a thin delegation to `mcp_oauth`.
214#[derive(Clone, Copy, Debug, Default)]
215pub struct RealOAuthFlowEngine;
216
217#[async_trait]
218impl OAuthFlowEngine for RealOAuthFlowEngine {
219    async fn current_bearer(&self, server_url: &str) -> Result<Option<String>, String> {
220        mcp_oauth::resolve_bearer(server_url).await
221    }
222
223    async fn has_token(&self, server_url: &str) -> Result<bool, String> {
224        let discovery = mcp_oauth::discover(server_url).await?;
225        let resource =
226            canonical_resource_indicator(server_url).map_err(|error| error.to_string())?;
227        Ok(
228            mcp_oauth::load_token(&resource, &discovery.authorization_server_issuer, None)
229                .await?
230                .is_some(),
231        )
232    }
233
234    async fn begin(&self, request: BeginAuthorization) -> Result<PendingAuthorization, String> {
235        mcp_oauth::begin_authorization(request).await
236    }
237
238    async fn complete(
239        &self,
240        state: &str,
241        code: &str,
242        issuer: Option<&str>,
243    ) -> Result<StoredMcpToken, String> {
244        mcp_oauth::complete_authorization(state, code, issuer).await
245    }
246}
247
/// Labels remembered for a pending `state`.
///
/// They let [`McpBulkAuth::complete`] tag its status events without the
/// surface passing the server's name and URL back in.
#[derive(Debug, Clone)]
struct FlowMeta {
    /// Display name of the server whose flow owns the `state`.
    name: String,
    /// URL of that server.
    server_url: String,
}
255
256/// Bulk OAuth orchestrator. Construct once, [`subscribe`](Self::subscribe) for
257/// the status stream, [`prepare`](Self::prepare) to begin all pending flows,
258/// then [`complete`](Self::complete) each captured callback.
259pub struct McpBulkAuth<E: OAuthFlowEngine = RealOAuthFlowEngine> {
260    engine: Arc<E>,
261    config: BulkAuthConfig,
262    status_tx: broadcast::Sender<McpAuthStatus>,
263    pending: Arc<Mutex<HashMap<String, FlowMeta>>>,
264}
265
266impl McpBulkAuth<RealOAuthFlowEngine> {
267    /// Construct the driver backed by the real `mcp_oauth` engine, with config
268    /// resolved from the overlay.
269    pub fn new() -> Self {
270        Self::with_engine(RealOAuthFlowEngine, BulkAuthConfig::load())
271    }
272}
273
274impl Default for McpBulkAuth<RealOAuthFlowEngine> {
275    fn default() -> Self {
276        Self::new()
277    }
278}
279
280impl<E: OAuthFlowEngine> McpBulkAuth<E> {
281    /// Construct with an explicit engine + config (the test/injection seam).
282    pub fn with_engine(engine: E, config: BulkAuthConfig) -> Self {
283        let (status_tx, _rx) = broadcast::channel(STATUS_CHANNEL_CAPACITY);
284        Self {
285            engine: Arc::new(engine),
286            config,
287            status_tx,
288            pending: Arc::new(Mutex::new(HashMap::new())),
289        }
290    }
291
292    /// Subscribe to the per-server status stream. Subscribe *before* calling
293    /// [`prepare`](Self::prepare) to observe every event.
294    pub fn subscribe(&self) -> broadcast::Receiver<McpAuthStatus> {
295        self.status_tx.subscribe()
296    }
297
298    /// Begin flows for all servers selected by `mode`, concurrently (bounded by
299    /// the configured concurrency). All flows share `redirect_uri` — the surface
300    /// binds one loopback listener and demuxes callbacks by `state`. Returns one
301    /// [`PrepareOutcome`] per input server; emits status events throughout.
302    pub async fn prepare(
303        &self,
304        servers: Vec<BulkAuthServer>,
305        mode: BulkAuthMode,
306        redirect_uri: &str,
307    ) -> Vec<PrepareOutcome> {
308        let concurrency = self.config.concurrency.max(1);
309        let timeout = Duration::from_secs(self.config.prepare_timeout_secs.max(1));
310        futures::stream::iter(servers.into_iter().map(|server| {
311            let engine = self.engine.clone();
312            let status_tx = self.status_tx.clone();
313            let pending = self.pending.clone();
314            let redirect_uri = redirect_uri.to_string();
315            async move {
316                prepare_one(
317                    engine,
318                    status_tx,
319                    pending,
320                    server,
321                    mode,
322                    redirect_uri,
323                    timeout,
324                )
325                .await
326            }
327        }))
328        .buffer_unordered(concurrency)
329        .collect::<Vec<_>>()
330        .await
331    }
332
333    /// Complete a flow whose callback the surface captured. Looks up the
334    /// server's display name by `state`, emits `Exchanging` then
335    /// `Connected`/`Failed`, and returns the stored token on success.
336    pub async fn complete(
337        &self,
338        state: &str,
339        code: &str,
340        issuer: Option<&str>,
341    ) -> Result<StoredMcpToken, String> {
342        let meta = self
343            .pending
344            .lock()
345            .unwrap_or_else(|poison| poison.into_inner())
346            .get(state)
347            .cloned();
348        let (name, server_url) = match meta {
349            Some(meta) => (meta.name, meta.server_url),
350            None => ("<unknown>".to_string(), String::new()),
351        };
352        emit(
353            &self.status_tx,
354            &name,
355            &server_url,
356            McpAuthPhase::Exchanging,
357            None,
358        );
359        match self.engine.complete(state, code, issuer).await {
360            Ok(token) => {
361                self.pending
362                    .lock()
363                    .unwrap_or_else(|poison| poison.into_inner())
364                    .remove(state);
365                emit(
366                    &self.status_tx,
367                    &name,
368                    &server_url,
369                    McpAuthPhase::Connected,
370                    None,
371                );
372                Ok(token)
373            }
374            Err(error) => {
375                emit(
376                    &self.status_tx,
377                    &name,
378                    &server_url,
379                    McpAuthPhase::Failed,
380                    Some(error.clone()),
381                );
382                Err(error)
383            }
384        }
385    }
386
387    /// Number of flows still awaiting a callback (begun but not completed).
388    pub fn pending_count(&self) -> usize {
389        self.pending
390            .lock()
391            .unwrap_or_else(|poison| poison.into_inner())
392            .len()
393    }
394
395    /// Whether `state` belongs to a flow this driver began (and has not yet
396    /// completed). A surface that multiplexes single-URL and batch callbacks on
397    /// one channel uses this to route a captured `{ state, code }` to the driver
398    /// (so completion streams `Exchanging`/`Connected`) only when the state is
399    /// one of ours — otherwise it falls back to its per-flow path unchanged.
400    pub fn knows_state(&self, state: &str) -> bool {
401        self.pending
402            .lock()
403            .unwrap_or_else(|poison| poison.into_inner())
404            .contains_key(state)
405    }
406}
407
408/// Begin (or skip) one server, emitting its phase events. Free function so the
409/// per-server future owns only cloned handles (no `&self` borrow across the
410/// concurrent stream).
411async fn prepare_one<E: OAuthFlowEngine>(
412    engine: Arc<E>,
413    status_tx: broadcast::Sender<McpAuthStatus>,
414    pending: Arc<Mutex<HashMap<String, FlowMeta>>>,
415    server: BulkAuthServer,
416    mode: BulkAuthMode,
417    redirect_uri: String,
418    timeout: Duration,
419) -> PrepareOutcome {
420    emit(
421        &status_tx,
422        &server.name,
423        &server.server_url,
424        McpAuthPhase::Discovering,
425        None,
426    );
427
428    match tokio::time::timeout(timeout, decide(&*engine, &server, mode)).await {
429        Ok(AuthDecision::Begin) => {}
430        Ok(AuthDecision::Skip(reason)) => {
431            emit(
432                &status_tx,
433                &server.name,
434                &server.server_url,
435                McpAuthPhase::Skipped,
436                Some(reason.to_string()),
437            );
438            return PrepareOutcome::Skipped {
439                name: server.name,
440                server_url: server.server_url,
441                reason: reason.to_string(),
442            };
443        }
444        Err(_) => {
445            return fail(
446                &status_tx,
447                server,
448                "timed out resolving authorization server",
449            );
450        }
451    }
452
453    let request = BeginAuthorization {
454        server_url: server.server_url.clone(),
455        redirect_uri: redirect_uri.clone(),
456        mode: server.mode,
457        client_id: server.client_id.clone(),
458        client_secret: server.client_secret.clone(),
459        static_secret_id: server.static_secret_id.clone(),
460        scopes: server.scopes.clone(),
461    };
462    match tokio::time::timeout(timeout, engine.begin(request)).await {
463        Ok(Ok(pending_auth)) => {
464            pending
465                .lock()
466                .unwrap_or_else(|poison| poison.into_inner())
467                .insert(
468                    pending_auth.state.clone(),
469                    FlowMeta {
470                        name: server.name.clone(),
471                        server_url: server.server_url.clone(),
472                    },
473                );
474            emit(
475                &status_tx,
476                &server.name,
477                &server.server_url,
478                McpAuthPhase::AwaitingConsent,
479                None,
480            );
481            PrepareOutcome::Pending(PreparedFlow {
482                name: server.name,
483                server_url: server.server_url,
484                authorize_url: pending_auth.authorize_url,
485                state: pending_auth.state,
486                redirect_uri,
487            })
488        }
489        Ok(Err(error)) => fail(&status_tx, server, &error),
490        Err(_) => fail(&status_tx, server, "timed out minting authorization URL"),
491    }
492}
493
494/// Outcome of the per-mode "does this server need a flow?" decision.
495enum AuthDecision {
496    Begin,
497    Skip(&'static str),
498}
499
500async fn decide<E: OAuthFlowEngine>(
501    engine: &E,
502    server: &BulkAuthServer,
503    mode: BulkAuthMode,
504) -> AuthDecision {
505    match mode {
506        BulkAuthMode::All => AuthDecision::Begin,
507        BulkAuthMode::Missing => match engine.current_bearer(&server.server_url).await {
508            Ok(Some(_)) => AuthDecision::Skip("already connected"),
509            // None or error → not currently usable; begin (begin surfaces any
510            // real discovery error as a Failed outcome).
511            _ => AuthDecision::Begin,
512        },
513        BulkAuthMode::Expired => {
514            // Only re-auth servers that have a token which is no longer valid.
515            match engine.has_token(&server.server_url).await {
516                Ok(false) => return AuthDecision::Skip("no stored token"),
517                Ok(true) => {}
518                Err(_) => return AuthDecision::Skip("no stored token"),
519            }
520            match engine.current_bearer(&server.server_url).await {
521                Ok(Some(_)) => AuthDecision::Skip("token still valid"),
522                _ => AuthDecision::Begin,
523            }
524        }
525    }
526}
527
528fn fail(
529    status_tx: &broadcast::Sender<McpAuthStatus>,
530    server: BulkAuthServer,
531    error: &str,
532) -> PrepareOutcome {
533    emit(
534        status_tx,
535        &server.name,
536        &server.server_url,
537        McpAuthPhase::Failed,
538        Some(error.to_string()),
539    );
540    PrepareOutcome::Failed {
541        name: server.name,
542        server_url: server.server_url,
543        error: error.to_string(),
544    }
545}
546
547fn emit(
548    status_tx: &broadcast::Sender<McpAuthStatus>,
549    server: &str,
550    server_url: &str,
551    phase: McpAuthPhase,
552    detail: Option<String>,
553) {
554    // A send with no live receivers is fine — status is best-effort telemetry.
555    let _ = status_tx.send(McpAuthStatus {
556        server: server.to_string(),
557        server_url: server_url.to_string(),
558        phase,
559        detail,
560    });
561}
562
563/// Canonical snake_case JSON for one prepare outcome — the "outcome as data"
564/// shape consumed by script/CLI surfaces (e.g. the `mcp.reauth_expired()`
565/// builtin, harn#3358). A `Pending` flow needs interactive consent, so it is
566/// reported as `reauth_required` with its `authorize_url`/`state`; `Skipped`
567/// and `Failed` carry their reason/error. (The ACP wire uses its own camelCase
568/// grouping — different consumer, different convention.)
569pub fn prepare_outcome_to_json(outcome: &PrepareOutcome) -> serde_json::Value {
570    match outcome {
571        PrepareOutcome::Pending(flow) => serde_json::json!({
572            "server": flow.name,
573            "server_url": flow.server_url,
574            "status": "reauth_required",
575            "authorize_url": flow.authorize_url,
576            "state": flow.state,
577        }),
578        PrepareOutcome::Skipped {
579            name,
580            server_url,
581            reason,
582        } => serde_json::json!({
583            "server": name,
584            "server_url": server_url,
585            "status": "skipped",
586            "reason": reason,
587        }),
588        PrepareOutcome::Failed {
589            name,
590            server_url,
591            error,
592        } => serde_json::json!({
593            "server": name,
594            "server_url": server_url,
595            "status": "failed",
596            "error": error,
597        }),
598    }
599}
600
601#[cfg(test)]
602mod tests {
603    use super::*;
604    use std::sync::atomic::{AtomicUsize, Ordering};
605
606    /// Scripted mock engine: per-URL valid/stored-token state + a begin/complete
607    /// counter, so tests assert phase sequences and mode filtering with no
608    /// network or browser.
609    #[derive(Default)]
610    struct MockEngine {
611        /// URLs that currently have a valid bearer.
612        valid: Vec<String>,
613        /// URLs that have a stored token (valid or not).
614        stored: Vec<String>,
615        /// URLs whose `begin` should fail.
616        begin_fails: Vec<String>,
617        begin_calls: AtomicUsize,
618        state_counter: AtomicUsize,
619    }
620
621    #[async_trait]
622    impl OAuthFlowEngine for MockEngine {
623        async fn current_bearer(&self, server_url: &str) -> Result<Option<String>, String> {
624            Ok(self
625                .valid
626                .iter()
627                .any(|u| u == server_url)
628                .then(|| "bearer".to_string()))
629        }
630        async fn has_token(&self, server_url: &str) -> Result<bool, String> {
631            Ok(self.stored.iter().any(|u| u == server_url))
632        }
633        async fn begin(&self, request: BeginAuthorization) -> Result<PendingAuthorization, String> {
634            self.begin_calls.fetch_add(1, Ordering::SeqCst);
635            if self.begin_fails.contains(&request.server_url) {
636                return Err("discovery exploded".to_string());
637            }
638            let n = self.state_counter.fetch_add(1, Ordering::SeqCst);
639            let state = format!("state-{n}");
640            Ok(PendingAuthorization {
641                authorize_url: format!("https://auth.example/authorize?state={state}"),
642                state,
643                redirect_uri: request.redirect_uri,
644                resource: request.server_url,
645                issuer: "https://auth.example".to_string(),
646            })
647        }
648        async fn complete(
649            &self,
650            state: &str,
651            _code: &str,
652            _issuer: Option<&str>,
653        ) -> Result<StoredMcpToken, String> {
654            if state == "bad-state" {
655                return Err("token exchange failed".to_string());
656            }
657            Ok(StoredMcpToken {
658                access_token: "access".to_string(),
659                refresh_token: None,
660                expires_at_unix: None,
661                token_endpoint: "https://auth.example/token".to_string(),
662                client_id: "client".to_string(),
663                client_secret: None,
664                token_endpoint_auth_method: "none".to_string(),
665                issuer: "https://auth.example".to_string(),
666                resource: "https://mcp.example/mcp".to_string(),
667                scopes: None,
668                token_response_extra: None,
669            })
670        }
671    }
672
673    fn server(name: &str, url: &str) -> BulkAuthServer {
674        BulkAuthServer {
675            name: name.to_string(),
676            server_url: url.to_string(),
677            ..Default::default()
678        }
679    }
680
681    fn driver(engine: MockEngine) -> McpBulkAuth<MockEngine> {
682        // concurrency 1 keeps the mock's `state-N` assignment deterministic.
683        McpBulkAuth::with_engine(
684            engine,
685            BulkAuthConfig {
686                concurrency: 1,
687                prepare_timeout_secs: 5,
688            },
689        )
690    }
691
692    async fn drain(rx: &mut broadcast::Receiver<McpAuthStatus>) -> Vec<McpAuthStatus> {
693        let mut out = Vec::new();
694        while let Ok(status) = rx.try_recv() {
695            out.push(status);
696        }
697        out
698    }
699
700    fn phases(events: &[McpAuthStatus], server: &str) -> Vec<McpAuthPhase> {
701        events
702            .iter()
703            .filter(|e| e.server == server)
704            .map(|e| e.phase)
705            .collect()
706    }
707
708    #[tokio::test]
709    async fn prepares_all_servers_and_emits_phase_sequence() {
710        let driver = driver(MockEngine::default());
711        let mut rx = driver.subscribe();
712        let outcomes = driver
713            .prepare(
714                vec![
715                    server("a", "https://a.example/mcp"),
716                    server("b", "https://b.example/mcp"),
717                    server("c", "https://c.example/mcp"),
718                ],
719                BulkAuthMode::All,
720                "http://127.0.0.1:9783/callback",
721            )
722            .await;
723
724        assert_eq!(outcomes.len(), 3);
725        assert!(outcomes
726            .iter()
727            .all(|o| matches!(o, PrepareOutcome::Pending(_))));
728        assert_eq!(driver.pending_count(), 3);
729
730        let events = drain(&mut rx).await;
731        for name in ["a", "b", "c"] {
732            assert_eq!(
733                phases(&events, name),
734                vec![McpAuthPhase::Discovering, McpAuthPhase::AwaitingConsent],
735                "server {name}"
736            );
737        }
738    }
739
740    #[tokio::test]
741    async fn missing_mode_skips_connected_servers() {
742        let engine = MockEngine {
743            valid: vec!["https://b.example/mcp".to_string()],
744            ..Default::default()
745        };
746        let driver = driver(engine);
747        let outcomes = driver
748            .prepare(
749                vec![
750                    server("a", "https://a.example/mcp"),
751                    server("b", "https://b.example/mcp"),
752                ],
753                BulkAuthMode::Missing,
754                "http://127.0.0.1:9783/callback",
755            )
756            .await;
757
758        let a = outcomes.iter().find(|o| outcome_name(o) == "a").unwrap();
759        let b = outcomes.iter().find(|o| outcome_name(o) == "b").unwrap();
760        assert!(matches!(a, PrepareOutcome::Pending(_)));
761        assert!(
762            matches!(b, PrepareOutcome::Skipped { reason, .. } if reason == "already connected")
763        );
764    }
765
766    #[tokio::test]
767    async fn expired_mode_only_reauths_stale_stored_tokens() {
768        let engine = MockEngine {
769            // "stale" has a stored-but-invalid token, "fresh" is valid, "none"
770            // has nothing stored.
771            valid: vec!["https://fresh.example/mcp".to_string()],
772            stored: vec![
773                "https://stale.example/mcp".to_string(),
774                "https://fresh.example/mcp".to_string(),
775            ],
776            ..Default::default()
777        };
778        let driver = driver(engine);
779        let outcomes = driver
780            .prepare(
781                vec![
782                    server("stale", "https://stale.example/mcp"),
783                    server("fresh", "https://fresh.example/mcp"),
784                    server("none", "https://none.example/mcp"),
785                ],
786                BulkAuthMode::Expired,
787                "http://127.0.0.1:9783/callback",
788            )
789            .await;
790
791        let stale = outcomes
792            .iter()
793            .find(|o| outcome_name(o) == "stale")
794            .unwrap();
795        let fresh = outcomes
796            .iter()
797            .find(|o| outcome_name(o) == "fresh")
798            .unwrap();
799        let none = outcomes.iter().find(|o| outcome_name(o) == "none").unwrap();
800        assert!(
801            matches!(stale, PrepareOutcome::Pending(_)),
802            "stale → re-auth"
803        );
804        assert!(
805            matches!(fresh, PrepareOutcome::Skipped { reason, .. } if reason == "token still valid")
806        );
807        assert!(
808            matches!(none, PrepareOutcome::Skipped { reason, .. } if reason == "no stored token")
809        );
810    }
811
812    #[tokio::test]
813    async fn reauth_expired_outcomes_as_json_drive_only_stale() {
814        // The `mcp.reauth_expired()` (harn#3358) acceptance: two servers with
815        // expired (stored-but-invalid) tokens + one valid → exactly the two are
816        // driven to re-auth and the valid one is left untouched (Skipped).
817        let engine = MockEngine {
818            valid: vec!["https://fresh.example/mcp".to_string()],
819            stored: vec![
820                "https://stale1.example/mcp".to_string(),
821                "https://stale2.example/mcp".to_string(),
822                "https://fresh.example/mcp".to_string(),
823            ],
824            ..Default::default()
825        };
826        let driver = driver(engine);
827        let outcomes = driver
828            .prepare(
829                vec![
830                    server("stale1", "https://stale1.example/mcp"),
831                    server("stale2", "https://stale2.example/mcp"),
832                    server("fresh", "https://fresh.example/mcp"),
833                ],
834                BulkAuthMode::Expired,
835                "http://127.0.0.1:9783/callback",
836            )
837            .await;
838
839        let json: Vec<serde_json::Value> = outcomes.iter().map(prepare_outcome_to_json).collect();
840        let by_server = |name: &str| {
841            json.iter()
842                .find(|value| value["server"] == name)
843                .cloned()
844                .unwrap()
845        };
846
847        let reauthed: Vec<_> = json
848            .iter()
849            .filter(|value| value["status"] == "reauth_required")
850            .collect();
851        assert_eq!(
852            reauthed.len(),
853            2,
854            "exactly the two stale servers are driven"
855        );
856
857        let stale1 = by_server("stale1");
858        assert_eq!(stale1["status"], "reauth_required");
859        assert!(
860            stale1["authorize_url"].as_str().is_some(),
861            "a re-auth outcome carries an authorize_url for the caller to open"
862        );
863        assert_eq!(by_server("stale2")["status"], "reauth_required");
864
865        let fresh = by_server("fresh");
866        assert_eq!(fresh["status"], "skipped");
867        assert_eq!(fresh["reason"], "token still valid");
868    }
869
870    #[tokio::test]
871    async fn one_servers_failure_is_isolated() {
872        let engine = MockEngine {
873            begin_fails: vec!["https://b.example/mcp".to_string()],
874            ..Default::default()
875        };
876        let driver = driver(engine);
877        let mut rx = driver.subscribe();
878        let outcomes = driver
879            .prepare(
880                vec![
881                    server("a", "https://a.example/mcp"),
882                    server("b", "https://b.example/mcp"),
883                    server("c", "https://c.example/mcp"),
884                ],
885                BulkAuthMode::All,
886                "http://127.0.0.1:9783/callback",
887            )
888            .await;
889
890        let b = outcomes.iter().find(|o| outcome_name(o) == "b").unwrap();
891        assert!(matches!(b, PrepareOutcome::Failed { error, .. } if error.contains("discovery")));
892        // a and c still succeeded.
893        assert_eq!(
894            outcomes
895                .iter()
896                .filter(|o| matches!(o, PrepareOutcome::Pending(_)))
897                .count(),
898            2
899        );
900        let events = drain(&mut rx).await;
901        assert_eq!(
902            phases(&events, "b"),
903            vec![McpAuthPhase::Discovering, McpAuthPhase::Failed]
904        );
905    }
906
907    #[tokio::test]
908    async fn complete_routes_by_state_and_streams_terminal_phase() {
909        let driver = driver(MockEngine::default());
910        let mut rx = driver.subscribe();
911        let outcomes = driver
912            .prepare(
913                vec![server("a", "https://a.example/mcp")],
914                BulkAuthMode::All,
915                "http://127.0.0.1:9783/callback",
916            )
917            .await;
918        let state = match &outcomes[0] {
919            PrepareOutcome::Pending(flow) => flow.state.clone(),
920            other => panic!("expected pending, got {other:?}"),
921        };
922        let _ = drain(&mut rx).await;
923
924        let token = driver.complete(&state, "auth-code", None).await.unwrap();
925        assert_eq!(token.access_token, "access");
926        assert_eq!(driver.pending_count(), 0, "completed flow is cleared");
927
928        let events = drain(&mut rx).await;
929        assert_eq!(
930            phases(&events, "a"),
931            vec![McpAuthPhase::Exchanging, McpAuthPhase::Connected]
932        );
933    }
934
935    #[tokio::test]
936    async fn complete_failure_emits_failed_and_keeps_pending() {
937        let driver = driver(MockEngine::default());
938        // Seed a pending flow whose state the mock will reject.
939        driver.pending.lock().unwrap().insert(
940            "bad-state".to_string(),
941            FlowMeta {
942                name: "a".to_string(),
943                server_url: "https://a.example/mcp".to_string(),
944            },
945        );
946        let mut rx = driver.subscribe();
947        let error = driver
948            .complete("bad-state", "code", None)
949            .await
950            .unwrap_err();
951        assert!(error.contains("token exchange failed"));
952        let events = drain(&mut rx).await;
953        assert_eq!(
954            phases(&events, "a"),
955            vec![McpAuthPhase::Exchanging, McpAuthPhase::Failed]
956        );
957    }
958
959    #[test]
960    fn status_serializes_snake_case() {
961        let json = serde_json::to_value(McpAuthStatus {
962            server: "Notion".to_string(),
963            server_url: "https://mcp.notion.com/mcp".to_string(),
964            phase: McpAuthPhase::AwaitingConsent,
965            detail: None,
966        })
967        .unwrap();
968        assert_eq!(json["server"], serde_json::json!("Notion"));
969        assert_eq!(json["phase"], serde_json::json!("awaiting_consent"));
970        assert!(json.get("detail").is_none(), "None detail is omitted");
971    }
972
973    #[test]
974    fn config_defaults_when_no_overlay() {
975        let config = BulkAuthConfig::load();
976        assert_eq!(config.concurrency, 8);
977        assert_eq!(config.prepare_timeout_secs, 30);
978    }
979
980    fn outcome_name(outcome: &PrepareOutcome) -> &str {
981        match outcome {
982            PrepareOutcome::Pending(flow) => &flow.name,
983            PrepareOutcome::Skipped { name, .. } => name,
984            PrepareOutcome::Failed { name, .. } => name,
985        }
986    }
987}