1use std::collections::HashMap;
25use std::sync::{Arc, Mutex};
26use std::time::Duration;
27
28use async_trait::async_trait;
29use futures::stream::StreamExt;
30use serde::{Deserialize, Serialize};
31use tokio::sync::broadcast;
32
33use crate::mcp_auth::{canonical_resource_indicator, OAuthClientAuthMode};
34use crate::mcp_oauth::{self, BeginAuthorization, PendingAuthorization, StoredMcpToken};
35
/// Buffer size handed to `broadcast::channel` when the status stream is created.
const STATUS_CHANNEL_CAPACITY: usize = 256;
40
/// Selects which servers a bulk run should start an authorization flow for.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BulkAuthMode {
    /// Start a flow unless a bearer can already be resolved for the server.
    Missing,
    /// Start a flow only for servers that have a stored token but for which no
    /// bearer can currently be resolved.
    Expired,
    /// Start a flow for every server, without consulting the engine first.
    All,
}
54
55#[derive(Clone, Debug, Default)]
58pub struct BulkAuthServer {
59 pub name: String,
60 pub server_url: String,
61 pub mode: Option<OAuthClientAuthMode>,
62 pub client_id: Option<String>,
63 pub client_secret: Option<String>,
64 pub static_secret_id: Option<String>,
65 pub scopes: Option<String>,
66}
67
68#[derive(Clone, Debug, Serialize)]
72#[serde(rename_all = "camelCase")]
73pub struct PreparedFlow {
74 pub name: String,
75 pub server_url: String,
76 pub authorize_url: String,
77 pub state: String,
78 pub redirect_uri: String,
79}
80
81#[derive(Clone, Debug)]
83pub enum PrepareOutcome {
84 Pending(PreparedFlow),
86 Skipped {
88 name: String,
89 server_url: String,
90 reason: String,
91 },
92 Failed {
94 name: String,
95 server_url: String,
96 error: String,
97 },
98}
99
100#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
102#[serde(rename_all = "snake_case")]
103pub enum McpAuthPhase {
104 Discovering,
106 AwaitingConsent,
108 Exchanging,
110 Connected,
112 Failed,
114 Skipped,
116}
117
118#[derive(Clone, Debug, Serialize, Deserialize)]
120#[serde(rename_all = "snake_case")]
121pub struct McpAuthStatus {
122 pub server: String,
124 pub server_url: String,
126 pub phase: McpAuthPhase,
128 #[serde(default, skip_serializing_if = "Option::is_none")]
130 pub detail: Option<String>,
131}
132
133#[derive(Clone, Copy, Debug, Deserialize)]
137#[serde(default)]
138pub struct BulkAuthConfig {
139 pub concurrency: usize,
141 pub prepare_timeout_secs: u64,
143}
144
145impl Default for BulkAuthConfig {
146 fn default() -> Self {
147 Self {
148 concurrency: 8,
149 prepare_timeout_secs: 30,
150 }
151 }
152}
153
154#[derive(Debug, Default, Deserialize)]
155struct BulkAuthConfigFile {
156 #[serde(default)]
157 bulk_auth: BulkAuthConfig,
158}
159
160impl BulkAuthConfig {
161 pub fn load() -> Self {
165 if let Ok(path) = std::env::var("HARN_MCP_BULK_AUTH_CONFIG") {
166 if let Some(config) = Self::read(&path) {
167 return config;
168 }
169 }
170 if !cfg!(test) {
171 if let Some(home) = crate::user_dirs::home_dir() {
172 let path = home.join(".config").join("harn").join("mcp_bulk_auth.toml");
173 if let Some(config) = Self::read(&path.to_string_lossy()) {
174 return config;
175 }
176 }
177 }
178 Self::default()
179 }
180
181 fn read(path: &str) -> Option<Self> {
182 let content = std::fs::read_to_string(path).ok()?;
183 match toml::from_str::<BulkAuthConfigFile>(&content) {
184 Ok(file) => Some(file.bulk_auth),
185 Err(error) => {
186 eprintln!("[mcp_bulk_auth] TOML parse error in {path}: {error}");
187 None
188 }
189 }
190 }
191}
192
193#[async_trait]
196pub trait OAuthFlowEngine: Send + Sync {
197 async fn current_bearer(&self, server_url: &str) -> Result<Option<String>, String>;
200 async fn has_token(&self, server_url: &str) -> Result<bool, String>;
202 async fn begin(&self, request: BeginAuthorization) -> Result<PendingAuthorization, String>;
204 async fn complete(
206 &self,
207 state: &str,
208 code: &str,
209 issuer: Option<&str>,
210 ) -> Result<StoredMcpToken, String>;
211}
212
/// Stateless engine marker type; its `OAuthFlowEngine` implementation
/// delegates to the `mcp_oauth` module.
#[derive(Clone, Copy, Debug, Default)]
pub struct RealOAuthFlowEngine;
216
217#[async_trait]
218impl OAuthFlowEngine for RealOAuthFlowEngine {
219 async fn current_bearer(&self, server_url: &str) -> Result<Option<String>, String> {
220 mcp_oauth::resolve_bearer(server_url).await
221 }
222
223 async fn has_token(&self, server_url: &str) -> Result<bool, String> {
224 let discovery = mcp_oauth::discover(server_url).await?;
225 let resource =
226 canonical_resource_indicator(server_url).map_err(|error| error.to_string())?;
227 Ok(
228 mcp_oauth::load_token(&resource, &discovery.authorization_server_issuer, None)
229 .await?
230 .is_some(),
231 )
232 }
233
234 async fn begin(&self, request: BeginAuthorization) -> Result<PendingAuthorization, String> {
235 mcp_oauth::begin_authorization(request).await
236 }
237
238 async fn complete(
239 &self,
240 state: &str,
241 code: &str,
242 issuer: Option<&str>,
243 ) -> Result<StoredMcpToken, String> {
244 mcp_oauth::complete_authorization(state, code, issuer).await
245 }
246}
247
/// What the driver remembers about a started flow, keyed by its state value,
/// so later status events can name the server.
#[derive(Clone, Debug)]
struct FlowMeta {
    /// Label of the server the flow was started for.
    name: String,
    /// URL of the server the flow was started for.
    server_url: String,
}
255
256pub struct McpBulkAuth<E: OAuthFlowEngine = RealOAuthFlowEngine> {
260 engine: Arc<E>,
261 config: BulkAuthConfig,
262 status_tx: broadcast::Sender<McpAuthStatus>,
263 pending: Arc<Mutex<HashMap<String, FlowMeta>>>,
264}
265
266impl McpBulkAuth<RealOAuthFlowEngine> {
267 pub fn new() -> Self {
270 Self::with_engine(RealOAuthFlowEngine, BulkAuthConfig::load())
271 }
272}
273
274impl Default for McpBulkAuth<RealOAuthFlowEngine> {
275 fn default() -> Self {
276 Self::new()
277 }
278}
279
280impl<E: OAuthFlowEngine> McpBulkAuth<E> {
281 pub fn with_engine(engine: E, config: BulkAuthConfig) -> Self {
283 let (status_tx, _rx) = broadcast::channel(STATUS_CHANNEL_CAPACITY);
284 Self {
285 engine: Arc::new(engine),
286 config,
287 status_tx,
288 pending: Arc::new(Mutex::new(HashMap::new())),
289 }
290 }
291
292 pub fn subscribe(&self) -> broadcast::Receiver<McpAuthStatus> {
295 self.status_tx.subscribe()
296 }
297
298 pub async fn prepare(
303 &self,
304 servers: Vec<BulkAuthServer>,
305 mode: BulkAuthMode,
306 redirect_uri: &str,
307 ) -> Vec<PrepareOutcome> {
308 let concurrency = self.config.concurrency.max(1);
309 let timeout = Duration::from_secs(self.config.prepare_timeout_secs.max(1));
310 futures::stream::iter(servers.into_iter().map(|server| {
311 let engine = self.engine.clone();
312 let status_tx = self.status_tx.clone();
313 let pending = self.pending.clone();
314 let redirect_uri = redirect_uri.to_string();
315 async move {
316 prepare_one(
317 engine,
318 status_tx,
319 pending,
320 server,
321 mode,
322 redirect_uri,
323 timeout,
324 )
325 .await
326 }
327 }))
328 .buffer_unordered(concurrency)
329 .collect::<Vec<_>>()
330 .await
331 }
332
333 pub async fn complete(
337 &self,
338 state: &str,
339 code: &str,
340 issuer: Option<&str>,
341 ) -> Result<StoredMcpToken, String> {
342 let meta = self
343 .pending
344 .lock()
345 .unwrap_or_else(|poison| poison.into_inner())
346 .get(state)
347 .cloned();
348 let (name, server_url) = match meta {
349 Some(meta) => (meta.name, meta.server_url),
350 None => ("<unknown>".to_string(), String::new()),
351 };
352 emit(
353 &self.status_tx,
354 &name,
355 &server_url,
356 McpAuthPhase::Exchanging,
357 None,
358 );
359 match self.engine.complete(state, code, issuer).await {
360 Ok(token) => {
361 self.pending
362 .lock()
363 .unwrap_or_else(|poison| poison.into_inner())
364 .remove(state);
365 emit(
366 &self.status_tx,
367 &name,
368 &server_url,
369 McpAuthPhase::Connected,
370 None,
371 );
372 Ok(token)
373 }
374 Err(error) => {
375 emit(
376 &self.status_tx,
377 &name,
378 &server_url,
379 McpAuthPhase::Failed,
380 Some(error.clone()),
381 );
382 Err(error)
383 }
384 }
385 }
386
387 pub fn pending_count(&self) -> usize {
389 self.pending
390 .lock()
391 .unwrap_or_else(|poison| poison.into_inner())
392 .len()
393 }
394
395 pub fn knows_state(&self, state: &str) -> bool {
401 self.pending
402 .lock()
403 .unwrap_or_else(|poison| poison.into_inner())
404 .contains_key(state)
405 }
406}
407
408async fn prepare_one<E: OAuthFlowEngine>(
412 engine: Arc<E>,
413 status_tx: broadcast::Sender<McpAuthStatus>,
414 pending: Arc<Mutex<HashMap<String, FlowMeta>>>,
415 server: BulkAuthServer,
416 mode: BulkAuthMode,
417 redirect_uri: String,
418 timeout: Duration,
419) -> PrepareOutcome {
420 emit(
421 &status_tx,
422 &server.name,
423 &server.server_url,
424 McpAuthPhase::Discovering,
425 None,
426 );
427
428 match tokio::time::timeout(timeout, decide(&*engine, &server, mode)).await {
429 Ok(AuthDecision::Begin) => {}
430 Ok(AuthDecision::Skip(reason)) => {
431 emit(
432 &status_tx,
433 &server.name,
434 &server.server_url,
435 McpAuthPhase::Skipped,
436 Some(reason.to_string()),
437 );
438 return PrepareOutcome::Skipped {
439 name: server.name,
440 server_url: server.server_url,
441 reason: reason.to_string(),
442 };
443 }
444 Err(_) => {
445 return fail(
446 &status_tx,
447 server,
448 "timed out resolving authorization server",
449 );
450 }
451 }
452
453 let request = BeginAuthorization {
454 server_url: server.server_url.clone(),
455 redirect_uri: redirect_uri.clone(),
456 mode: server.mode,
457 client_id: server.client_id.clone(),
458 client_secret: server.client_secret.clone(),
459 static_secret_id: server.static_secret_id.clone(),
460 scopes: server.scopes.clone(),
461 };
462 match tokio::time::timeout(timeout, engine.begin(request)).await {
463 Ok(Ok(pending_auth)) => {
464 pending
465 .lock()
466 .unwrap_or_else(|poison| poison.into_inner())
467 .insert(
468 pending_auth.state.clone(),
469 FlowMeta {
470 name: server.name.clone(),
471 server_url: server.server_url.clone(),
472 },
473 );
474 emit(
475 &status_tx,
476 &server.name,
477 &server.server_url,
478 McpAuthPhase::AwaitingConsent,
479 None,
480 );
481 PrepareOutcome::Pending(PreparedFlow {
482 name: server.name,
483 server_url: server.server_url,
484 authorize_url: pending_auth.authorize_url,
485 state: pending_auth.state,
486 redirect_uri,
487 })
488 }
489 Ok(Err(error)) => fail(&status_tx, server, &error),
490 Err(_) => fail(&status_tx, server, "timed out minting authorization URL"),
491 }
492}
493
/// Result of checking whether a server needs an authorization flow.
enum AuthDecision {
    /// Start a flow.
    Begin,
    /// Do not start a flow; the payload is the reason reported to the caller.
    Skip(&'static str),
}
499
500async fn decide<E: OAuthFlowEngine>(
501 engine: &E,
502 server: &BulkAuthServer,
503 mode: BulkAuthMode,
504) -> AuthDecision {
505 match mode {
506 BulkAuthMode::All => AuthDecision::Begin,
507 BulkAuthMode::Missing => match engine.current_bearer(&server.server_url).await {
508 Ok(Some(_)) => AuthDecision::Skip("already connected"),
509 _ => AuthDecision::Begin,
512 },
513 BulkAuthMode::Expired => {
514 match engine.has_token(&server.server_url).await {
516 Ok(false) => return AuthDecision::Skip("no stored token"),
517 Ok(true) => {}
518 Err(_) => return AuthDecision::Skip("no stored token"),
519 }
520 match engine.current_bearer(&server.server_url).await {
521 Ok(Some(_)) => AuthDecision::Skip("token still valid"),
522 _ => AuthDecision::Begin,
523 }
524 }
525 }
526}
527
528fn fail(
529 status_tx: &broadcast::Sender<McpAuthStatus>,
530 server: BulkAuthServer,
531 error: &str,
532) -> PrepareOutcome {
533 emit(
534 status_tx,
535 &server.name,
536 &server.server_url,
537 McpAuthPhase::Failed,
538 Some(error.to_string()),
539 );
540 PrepareOutcome::Failed {
541 name: server.name,
542 server_url: server.server_url,
543 error: error.to_string(),
544 }
545}
546
547fn emit(
548 status_tx: &broadcast::Sender<McpAuthStatus>,
549 server: &str,
550 server_url: &str,
551 phase: McpAuthPhase,
552 detail: Option<String>,
553) {
554 let _ = status_tx.send(McpAuthStatus {
556 server: server.to_string(),
557 server_url: server_url.to_string(),
558 phase,
559 detail,
560 });
561}
562
563pub fn prepare_outcome_to_json(outcome: &PrepareOutcome) -> serde_json::Value {
570 match outcome {
571 PrepareOutcome::Pending(flow) => serde_json::json!({
572 "server": flow.name,
573 "server_url": flow.server_url,
574 "status": "reauth_required",
575 "authorize_url": flow.authorize_url,
576 "state": flow.state,
577 }),
578 PrepareOutcome::Skipped {
579 name,
580 server_url,
581 reason,
582 } => serde_json::json!({
583 "server": name,
584 "server_url": server_url,
585 "status": "skipped",
586 "reason": reason,
587 }),
588 PrepareOutcome::Failed {
589 name,
590 server_url,
591 error,
592 } => serde_json::json!({
593 "server": name,
594 "server_url": server_url,
595 "status": "failed",
596 "error": error,
597 }),
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Scriptable engine: behaviour is driven entirely by which URLs appear in
    /// the lists below.
    #[derive(Default)]
    struct MockEngine {
        /// URLs for which `current_bearer` returns a bearer.
        valid: Vec<String>,
        /// URLs for which `has_token` returns `true`.
        stored: Vec<String>,
        /// URLs for which `begin` returns an error.
        begin_fails: Vec<String>,
        /// Number of `begin` invocations.
        begin_calls: AtomicUsize,
        /// Source of unique `state-N` values.
        state_counter: AtomicUsize,
    }

    #[async_trait]
    impl OAuthFlowEngine for MockEngine {
        async fn current_bearer(&self, server_url: &str) -> Result<Option<String>, String> {
            Ok(self
                .valid
                .iter()
                .any(|u| u == server_url)
                .then(|| "bearer".to_string()))
        }
        async fn has_token(&self, server_url: &str) -> Result<bool, String> {
            Ok(self.stored.iter().any(|u| u == server_url))
        }
        async fn begin(&self, request: BeginAuthorization) -> Result<PendingAuthorization, String> {
            self.begin_calls.fetch_add(1, Ordering::SeqCst);
            if self.begin_fails.contains(&request.server_url) {
                return Err("discovery exploded".to_string());
            }
            let n = self.state_counter.fetch_add(1, Ordering::SeqCst);
            let state = format!("state-{n}");
            Ok(PendingAuthorization {
                authorize_url: format!("https://auth.example/authorize?state={state}"),
                state,
                redirect_uri: request.redirect_uri,
                resource: request.server_url,
                issuer: "https://auth.example".to_string(),
            })
        }
        async fn complete(
            &self,
            state: &str,
            _code: &str,
            _issuer: Option<&str>,
        ) -> Result<StoredMcpToken, String> {
            // The literal state "bad-state" is the switch for a failed exchange.
            if state == "bad-state" {
                return Err("token exchange failed".to_string());
            }
            Ok(StoredMcpToken {
                access_token: "access".to_string(),
                refresh_token: None,
                expires_at_unix: None,
                token_endpoint: "https://auth.example/token".to_string(),
                client_id: "client".to_string(),
                client_secret: None,
                token_endpoint_auth_method: "none".to_string(),
                issuer: "https://auth.example".to_string(),
                resource: "https://mcp.example/mcp".to_string(),
                scopes: None,
                token_response_extra: None,
            })
        }
    }

    /// Server entry with only name and URL set.
    fn server(name: &str, url: &str) -> BulkAuthServer {
        BulkAuthServer {
            name: name.to_string(),
            server_url: url.to_string(),
            ..Default::default()
        }
    }

    /// Driver with concurrency 1 and a 5 second timeout.
    fn driver(engine: MockEngine) -> McpBulkAuth<MockEngine> {
        McpBulkAuth::with_engine(
            engine,
            BulkAuthConfig {
                concurrency: 1,
                prepare_timeout_secs: 5,
            },
        )
    }

    /// Collect every status event that is already buffered on `rx`.
    async fn drain(rx: &mut broadcast::Receiver<McpAuthStatus>) -> Vec<McpAuthStatus> {
        let mut out = Vec::new();
        while let Ok(status) = rx.try_recv() {
            out.push(status);
        }
        out
    }

    /// Phases reported for `server`, in emission order.
    fn phases(events: &[McpAuthStatus], server: &str) -> Vec<McpAuthPhase> {
        events
            .iter()
            .filter(|e| e.server == server)
            .map(|e| e.phase)
            .collect()
    }

    #[tokio::test]
    async fn prepares_all_servers_and_emits_phase_sequence() {
        let driver = driver(MockEngine::default());
        let mut rx = driver.subscribe();
        let outcomes = driver
            .prepare(
                vec![
                    server("a", "https://a.example/mcp"),
                    server("b", "https://b.example/mcp"),
                    server("c", "https://c.example/mcp"),
                ],
                BulkAuthMode::All,
                "http://127.0.0.1:9783/callback",
            )
            .await;

        assert_eq!(outcomes.len(), 3);
        assert!(outcomes
            .iter()
            .all(|o| matches!(o, PrepareOutcome::Pending(_))));
        assert_eq!(driver.pending_count(), 3);

        let events = drain(&mut rx).await;
        for name in ["a", "b", "c"] {
            assert_eq!(
                phases(&events, name),
                vec![McpAuthPhase::Discovering, McpAuthPhase::AwaitingConsent],
                "server {name}"
            );
        }
    }

    #[tokio::test]
    async fn missing_mode_skips_connected_servers() {
        let engine = MockEngine {
            valid: vec!["https://b.example/mcp".to_string()],
            ..Default::default()
        };
        let driver = driver(engine);
        let outcomes = driver
            .prepare(
                vec![
                    server("a", "https://a.example/mcp"),
                    server("b", "https://b.example/mcp"),
                ],
                BulkAuthMode::Missing,
                "http://127.0.0.1:9783/callback",
            )
            .await;

        let a = outcomes.iter().find(|o| outcome_name(o) == "a").unwrap();
        let b = outcomes.iter().find(|o| outcome_name(o) == "b").unwrap();
        assert!(matches!(a, PrepareOutcome::Pending(_)));
        assert!(
            matches!(b, PrepareOutcome::Skipped { reason, .. } if reason == "already connected")
        );
    }

    #[tokio::test]
    async fn expired_mode_only_reauths_stale_stored_tokens() {
        let engine = MockEngine {
            valid: vec!["https://fresh.example/mcp".to_string()],
            stored: vec![
                "https://stale.example/mcp".to_string(),
                "https://fresh.example/mcp".to_string(),
            ],
            ..Default::default()
        };
        let driver = driver(engine);
        let outcomes = driver
            .prepare(
                vec![
                    server("stale", "https://stale.example/mcp"),
                    server("fresh", "https://fresh.example/mcp"),
                    server("none", "https://none.example/mcp"),
                ],
                BulkAuthMode::Expired,
                "http://127.0.0.1:9783/callback",
            )
            .await;

        let stale = outcomes
            .iter()
            .find(|o| outcome_name(o) == "stale")
            .unwrap();
        let fresh = outcomes
            .iter()
            .find(|o| outcome_name(o) == "fresh")
            .unwrap();
        let none = outcomes.iter().find(|o| outcome_name(o) == "none").unwrap();
        assert!(
            matches!(stale, PrepareOutcome::Pending(_)),
            "stale → re-auth"
        );
        assert!(
            matches!(fresh, PrepareOutcome::Skipped { reason, .. } if reason == "token still valid")
        );
        assert!(
            matches!(none, PrepareOutcome::Skipped { reason, .. } if reason == "no stored token")
        );
    }

    #[tokio::test]
    async fn reauth_expired_outcomes_as_json_drive_only_stale() {
        let engine = MockEngine {
            valid: vec!["https://fresh.example/mcp".to_string()],
            stored: vec![
                "https://stale1.example/mcp".to_string(),
                "https://stale2.example/mcp".to_string(),
                "https://fresh.example/mcp".to_string(),
            ],
            ..Default::default()
        };
        let driver = driver(engine);
        let outcomes = driver
            .prepare(
                vec![
                    server("stale1", "https://stale1.example/mcp"),
                    server("stale2", "https://stale2.example/mcp"),
                    server("fresh", "https://fresh.example/mcp"),
                ],
                BulkAuthMode::Expired,
                "http://127.0.0.1:9783/callback",
            )
            .await;

        let json: Vec<serde_json::Value> = outcomes.iter().map(prepare_outcome_to_json).collect();
        let by_server = |name: &str| {
            json.iter()
                .find(|value| value["server"] == name)
                .cloned()
                .unwrap()
        };

        let reauthed: Vec<_> = json
            .iter()
            .filter(|value| value["status"] == "reauth_required")
            .collect();
        assert_eq!(
            reauthed.len(),
            2,
            "exactly the two stale servers are driven"
        );

        let stale1 = by_server("stale1");
        assert_eq!(stale1["status"], "reauth_required");
        assert!(
            stale1["authorize_url"].as_str().is_some(),
            "a re-auth outcome carries an authorize_url for the caller to open"
        );
        assert_eq!(by_server("stale2")["status"], "reauth_required");

        let fresh = by_server("fresh");
        assert_eq!(fresh["status"], "skipped");
        assert_eq!(fresh["reason"], "token still valid");
    }

    #[tokio::test]
    async fn one_servers_failure_is_isolated() {
        let engine = MockEngine {
            begin_fails: vec!["https://b.example/mcp".to_string()],
            ..Default::default()
        };
        let driver = driver(engine);
        let mut rx = driver.subscribe();
        let outcomes = driver
            .prepare(
                vec![
                    server("a", "https://a.example/mcp"),
                    server("b", "https://b.example/mcp"),
                    server("c", "https://c.example/mcp"),
                ],
                BulkAuthMode::All,
                "http://127.0.0.1:9783/callback",
            )
            .await;

        let b = outcomes.iter().find(|o| outcome_name(o) == "b").unwrap();
        assert!(matches!(b, PrepareOutcome::Failed { error, .. } if error.contains("discovery")));
        assert_eq!(
            outcomes
                .iter()
                .filter(|o| matches!(o, PrepareOutcome::Pending(_)))
                .count(),
            2
        );
        let events = drain(&mut rx).await;
        assert_eq!(
            phases(&events, "b"),
            vec![McpAuthPhase::Discovering, McpAuthPhase::Failed]
        );
    }

    #[tokio::test]
    async fn complete_routes_by_state_and_streams_terminal_phase() {
        let driver = driver(MockEngine::default());
        let mut rx = driver.subscribe();
        let outcomes = driver
            .prepare(
                vec![server("a", "https://a.example/mcp")],
                BulkAuthMode::All,
                "http://127.0.0.1:9783/callback",
            )
            .await;
        let state = match &outcomes[0] {
            PrepareOutcome::Pending(flow) => flow.state.clone(),
            other => panic!("expected pending, got {other:?}"),
        };
        // Discard the prepare-phase events so only completion events remain.
        let _ = drain(&mut rx).await;

        let token = driver.complete(&state, "auth-code", None).await.unwrap();
        assert_eq!(token.access_token, "access");
        assert_eq!(driver.pending_count(), 0, "completed flow is cleared");

        let events = drain(&mut rx).await;
        assert_eq!(
            phases(&events, "a"),
            vec![McpAuthPhase::Exchanging, McpAuthPhase::Connected]
        );
    }

    #[tokio::test]
    async fn complete_failure_emits_failed_and_keeps_pending() {
        let driver = driver(MockEngine::default());
        // Register the flow by hand under the state the mock rejects.
        driver.pending.lock().unwrap().insert(
            "bad-state".to_string(),
            FlowMeta {
                name: "a".to_string(),
                server_url: "https://a.example/mcp".to_string(),
            },
        );
        let mut rx = driver.subscribe();
        let error = driver
            .complete("bad-state", "code", None)
            .await
            .unwrap_err();
        assert!(error.contains("token exchange failed"));
        let events = drain(&mut rx).await;
        assert_eq!(
            phases(&events, "a"),
            vec![McpAuthPhase::Exchanging, McpAuthPhase::Failed]
        );
        // The "keeps pending" half of the contract: a failed exchange must
        // leave the flow registered.
        assert_eq!(
            driver.pending_count(),
            1,
            "failed flow stays registered as pending"
        );
        assert!(driver.knows_state("bad-state"));
    }

    #[test]
    fn status_serializes_snake_case() {
        let json = serde_json::to_value(McpAuthStatus {
            server: "Notion".to_string(),
            server_url: "https://mcp.notion.com/mcp".to_string(),
            phase: McpAuthPhase::AwaitingConsent,
            detail: None,
        })
        .unwrap();
        assert_eq!(json["server"], serde_json::json!("Notion"));
        assert_eq!(json["phase"], serde_json::json!("awaiting_consent"));
        assert!(json.get("detail").is_none(), "None detail is omitted");
    }

    #[test]
    fn config_defaults_when_no_overlay() {
        let config = BulkAuthConfig::load();
        assert_eq!(config.concurrency, 8);
        assert_eq!(config.prepare_timeout_secs, 30);
    }

    /// Server name carried by any outcome variant.
    fn outcome_name(outcome: &PrepareOutcome) -> &str {
        match outcome {
            PrepareOutcome::Pending(flow) => &flow.name,
            PrepareOutcome::Skipped { name, .. } => name,
            PrepareOutcome::Failed { name, .. } => name,
        }
    }
}