1use base64::Engine;
2use reqwest::Client;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use sha2::Digest;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14use tiny_http::Response;
15use tiny_http::Server;
16use tokio::sync::oneshot;
17use tokio::time::timeout;
18use url::Url;
19
20pub use kontext_dev_core::AccessToken;
21pub use kontext_dev_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
22pub use kontext_dev_core::DEFAULT_RESOURCE;
23pub use kontext_dev_core::DEFAULT_SCOPE;
24pub use kontext_dev_core::DEFAULT_SERVER_NAME;
25pub use kontext_dev_core::KontextDevConfig;
26pub use kontext_dev_core::KontextDevCoreError;
27pub use kontext_dev_core::TokenExchangeToken;
28pub use kontext_dev_core::build_mcp_url;
29pub use kontext_dev_core::normalize_server_url;
30pub use kontext_dev_core::resolve_authorize_url;
31pub use kontext_dev_core::resolve_connect_session_url;
32pub use kontext_dev_core::resolve_integration_connection_url;
33pub use kontext_dev_core::resolve_integration_oauth_init_url;
34pub use kontext_dev_core::resolve_mcp_url;
35pub use kontext_dev_core::resolve_server_base_url;
36pub use kontext_dev_core::resolve_token_url;
37
/// RFC 8693 `grant_type` sent when exchanging the identity token for a resource token.
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// RFC 8693 token-type URI for an OAuth access token. Used as `subject_token_type` in token
/// exchange and reported as `issued_token_type` for gateway tokens restored from the cache.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
/// Safety margin: a cached token counts as expired this many seconds before its recorded
/// expiry, so it is not handed out just before it lapses.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
41
/// Errors produced by the Kontext-Dev client.
#[derive(Debug, thiserror::Error)]
pub enum KontextDevError {
    /// Configuration / URL-resolution error propagated from `kontext_dev_core`.
    #[error(transparent)]
    Core(#[from] kontext_dev_core::KontextDevCoreError),
    /// A configured, discovered, or derived URL could not be parsed.
    #[error("failed to parse URL `{url}`")]
    InvalidUrl {
        url: String,
        source: url::ParseError,
    },
    /// `webbrowser::open` failed for the authorization page or the integration connect page.
    #[error("failed to open browser for OAuth authorization")]
    BrowserOpenFailed,
    /// No OAuth callback arrived within the configured `auth_timeout_seconds`.
    #[error("OAuth callback timed out after {timeout_seconds}s")]
    OAuthCallbackTimeout { timeout_seconds: i64 },
    /// The callback channel's sender was dropped before a callback was delivered.
    #[error("OAuth callback channel was unexpectedly cancelled")]
    OAuthCallbackCancelled,
    /// The callback carried no `code` query parameter.
    #[error("OAuth callback is missing the authorization code")]
    MissingAuthorizationCode,
    /// The callback carried an `error` parameter (with `error_description` appended if present).
    /// Also used for an invalid `redirect_uri` configuration and for a callback server that
    /// fails to bind.
    #[error("OAuth callback returned an error: {error}")]
    OAuthCallbackError { error: String },
    /// The callback's `state` did not match the value generated for this authorization request.
    #[error("OAuth state mismatch")]
    InvalidOAuthState,
    /// A token-endpoint request (authorization code, refresh, or client credentials) failed;
    /// `message` summarizes the transport or HTTP error.
    #[error("Kontext-Dev token request failed for {token_url}: {message}")]
    TokenRequest { token_url: String, message: String },
    /// Token exchange for `resource` failed.
    #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
    TokenExchange { resource: String, message: String },
    /// Creating a connect session failed, or an empty session id was supplied.
    #[error("Kontext-Dev connect session request failed: {message}")]
    ConnectSession { message: String },
    /// Integration OAuth initiation failed.
    /// NOTE: also returned by `integration_connection_status` when the status check fails.
    #[error("Kontext-Dev integration OAuth init failed: {message}")]
    IntegrationOAuthInit { message: String },
    /// The token cache file exists but could not be read.
    #[error("failed to read token cache at `{path}`: {source}")]
    TokenCacheRead {
        path: String,
        source: std::io::Error,
    },
    /// The token cache directory or file could not be written.
    #[error("failed to write token cache at `{path}`: {source}")]
    TokenCacheWrite {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file is not valid JSON for the expected schema.
    #[error("failed to deserialize token cache at `{path}`: {source}")]
    TokenCacheDeserialize {
        path: String,
        source: serde_json::Error,
    },
    /// The token cache payload could not be serialized to JSON.
    #[error("failed to serialize token cache: {source}")]
    TokenCacheSerialize { source: serde_json::Error },
    /// A token response, or a token about to be cached, had an empty `access_token`.
    #[error("Kontext-Dev access token is empty")]
    EmptyAccessToken,
    /// NOTE(review): not constructed in the code visible here; confirm whether it is still used.
    #[error("missing integration UI URL; set `integration_ui_url` in config")]
    MissingIntegrationUiUrl,
}
93
94#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
95pub struct ConnectSession {
96 #[serde(rename = "sessionId")]
97 pub session_id: String,
98 #[serde(rename = "expiresAt")]
99 pub expires_at: String,
100}
101
102#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
103pub struct IntegrationOAuthInitResponse {
104 #[serde(rename = "authorizationUrl")]
105 pub authorization_url: String,
106 #[serde(default)]
107 pub state: Option<String>,
108}
109
110#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
111pub struct IntegrationConnectionStatus {
112 pub connected: bool,
113 #[serde(default)]
114 pub expires_at: Option<String>,
115}
116
/// Tokens produced by `KontextDevClient::authenticate_mcp`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct KontextAuthSession {
    /// Identity token obtained from the browser flow, a refresh, or the cache.
    pub identity_token: AccessToken,
    /// Token obtained by exchanging the identity token for the configured resource.
    pub gateway_token: TokenExchangeToken,
    /// `true` when the interactive browser PKCE flow ran during the call.
    pub browser_auth_performed: bool,
}
123
/// Cached form of an identity or gateway token, as stored in the token cache file.
///
/// An absolute expiry is stored instead of the server's relative `expires_in`, so validity
/// can be re-checked when the cache is read back later.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CachedAccessToken {
    access_token: String,
    token_type: String,
    refresh_token: Option<String>,
    scope: Option<String>,
    /// Absolute expiry in milliseconds since the Unix epoch. `None` when neither `expires_in`
    /// nor a JWT `exp` claim was available; `is_valid` treats that as never expiring.
    expires_at_unix_ms: Option<u64>,
}
132
/// JSON document written to the token cache file.
///
/// `client_id` and `resource` record what the tokens were issued for. `authenticate_mcp`
/// only reuses the cached tokens when both match the current configuration.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct TokenCacheFile {
    client_id: String,
    resource: String,
    identity: CachedAccessToken,
    gateway: CachedAccessToken,
}
140
141impl CachedAccessToken {
142 fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
143 if token.access_token.is_empty() {
144 return Err(KontextDevError::EmptyAccessToken);
145 }
146
147 Ok(Self {
148 access_token: token.access_token.clone(),
149 token_type: token.token_type.clone(),
150 refresh_token: token.refresh_token.clone(),
151 scope: token.scope.clone(),
152 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
153 })
154 }
155
156 fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
157 if token.access_token.is_empty() {
158 return Err(KontextDevError::EmptyAccessToken);
159 }
160
161 Ok(Self {
162 access_token: token.access_token.clone(),
163 token_type: token.token_type.clone(),
164 refresh_token: token.refresh_token.clone(),
165 scope: token.scope.clone(),
166 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
167 })
168 }
169
170 fn is_valid(&self) -> bool {
171 match self.expires_at_unix_ms {
172 Some(expires_at) => {
173 let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
174 now_unix_ms().saturating_add(buffer_ms) < expires_at
175 }
176 None => true,
177 }
178 }
179
180 fn to_access_token(&self) -> AccessToken {
181 AccessToken {
182 access_token: self.access_token.clone(),
183 token_type: self.token_type.clone(),
184 expires_in: self
185 .expires_at_unix_ms
186 .and_then(unix_ms_to_relative_seconds),
187 refresh_token: self.refresh_token.clone(),
188 scope: self.scope.clone(),
189 }
190 }
191
192 fn to_token_exchange_token(&self) -> TokenExchangeToken {
193 TokenExchangeToken {
194 access_token: self.access_token.clone(),
195 issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
196 token_type: self.token_type.clone(),
197 expires_in: self
198 .expires_at_unix_ms
199 .and_then(unix_ms_to_relative_seconds),
200 scope: self.scope.clone(),
201 refresh_token: self.refresh_token.clone(),
202 }
203 }
204}
205
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads before 1970. Saturates at `u64::MAX` instead of
/// silently truncating the `u128` millisecond count.
fn now_unix_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
212
213fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
214 if unix_ms <= now_unix_ms() {
215 return Some(0);
216 }
217
218 let delta_ms = unix_ms - now_unix_ms();
219 let secs = delta_ms / 1000;
220 i64::try_from(secs).ok()
221}
222
223fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
224 if let Some(expires_in) = expires_in
225 && expires_in > 0
226 {
227 let expires_in_u64 = u64::try_from(expires_in).ok()?;
228 return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
229 }
230
231 decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
232}
233
234fn decode_jwt_exp(token: &str) -> Option<u64> {
235 let mut parts = token.split('.');
236 let _header = parts.next()?;
237 let payload = parts.next()?;
238
239 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
240 .decode(payload)
241 .ok()?;
242 let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
243 value.get("exp").and_then(|exp| {
244 exp.as_u64()
245 .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
246 })
247}
248
/// Fields inspected in a JSON error response body by `build_error_message`.
/// Preference order when building the message: `error_description`, `message`, `error`.
#[derive(Clone, Debug, Deserialize)]
struct OAuthErrorBody {
    error: Option<String>,
    error_description: Option<String>,
    message: Option<String>,
}
255
/// Subset of the authorization-server metadata fetched from
/// `/.well-known/oauth-authorization-server[/mcp]`; only `authorization_endpoint` is read.
#[derive(Clone, Debug, Deserialize)]
struct OAuthAuthorizationServerMetadata {
    authorization_endpoint: Option<String>,
}
260
/// OAuth redirect query parameters captured by `parse_callback`.
/// Each field is `None` when the parameter was absent from the callback URL.
#[derive(Clone, Debug, Deserialize)]
struct OAuthCallbackPayload {
    code: Option<String>,
    state: Option<String>,
    error: Option<String>,
    error_description: Option<String>,
}
268
/// PKCE verifier and its `S256` challenge, as produced by `generate_pkce_pair`.
#[derive(Clone, Debug)]
struct PkcePair {
    // Sent as `code_verifier` in the authorization-code token request.
    verifier: String,
    // Sent as `code_challenge` in the authorization URL.
    challenge: String,
}
274
275fn generate_pkce_pair() -> PkcePair {
276 let mut raw = [0u8; 64];
277 rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
278
279 let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
280 let digest = sha2::Sha256::digest(verifier.as_bytes());
281 let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
282
283 PkcePair {
284 verifier,
285 challenge,
286 }
287}
288
289fn generate_state() -> String {
290 let mut bytes = [0u8; 16];
291 rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
292 bytes.iter().map(|b| format!("{b:02x}")).collect()
293}
294
/// Calls `unblock()` on the shared callback server when dropped.
///
/// This releases the `server.recv()` loop in `spawn_callback_server` when the callback wait
/// ends, whether it completes, times out, or is dropped.
struct CallbackServerGuard {
    server: Arc<Server>,
}

impl Drop for CallbackServerGuard {
    fn drop(&mut self) {
        // Wakes a thread blocked in `recv()` so the blocking accept loop can exit.
        self.server.unblock();
    }
}
304
305fn parse_callback(url_path: &str, callback_path: &str) -> Option<OAuthCallbackPayload> {
306 let full = format!("http://localhost{url_path}");
307 let parsed = Url::parse(&full).ok()?;
308
309 if parsed.path() != callback_path {
310 return None;
311 }
312
313 let mut payload = OAuthCallbackPayload {
314 code: None,
315 state: None,
316 error: None,
317 error_description: None,
318 };
319
320 for (key, value) in parsed.query_pairs() {
321 match key.as_ref() {
322 "code" => payload.code = Some(value.to_string()),
323 "state" => payload.state = Some(value.to_string()),
324 "error" => payload.error = Some(value.to_string()),
325 "error_description" => payload.error_description = Some(value.to_string()),
326 _ => {}
327 }
328 }
329
330 Some(payload)
331}
332
333fn spawn_callback_server(
334 server: Arc<Server>,
335 callback_path: String,
336 tx: oneshot::Sender<OAuthCallbackPayload>,
337) -> tokio::task::JoinHandle<()> {
338 tokio::task::spawn_blocking(move || {
339 while let Ok(request) = server.recv() {
340 let path = request.url().to_string();
341 if let Some(payload) = parse_callback(&path, &callback_path) {
342 let response = Response::from_string(
343 "Authentication complete. You can return to your terminal.",
344 );
345 let _ = request.respond(response);
346 let _ = tx.send(payload);
347 break;
348 }
349
350 let response = Response::from_string("Invalid callback").with_status_code(400);
351 let _ = request.respond(response);
352 }
353 })
354}
355
/// Client for Kontext-Dev authentication and integration APIs.
///
/// Combines the user configuration with a reusable `reqwest` HTTP client.
pub struct KontextDevClient {
    config: KontextDevConfig,
    http: Client,
}
360
361impl KontextDevClient {
362 pub fn new(config: KontextDevConfig) -> Self {
363 Self {
364 config,
365 http: Client::new(),
366 }
367 }
368
369 pub fn config(&self) -> &KontextDevConfig {
370 &self.config
371 }
372
373 pub fn mcp_url(&self) -> Result<String, KontextDevError> {
374 resolve_mcp_url(&self.config).map_err(KontextDevError::from)
375 }
376
377 pub fn token_url(&self) -> Result<String, KontextDevError> {
378 resolve_token_url(&self.config).map_err(KontextDevError::from)
379 }
380
381 pub fn authorize_url(&self) -> Result<String, KontextDevError> {
382 resolve_authorize_url(&self.config).map_err(KontextDevError::from)
383 }
384
385 pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
386 resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
387 }
388
389 pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
392 let resource = self.config.resource.clone();
393
394 if let Some(cache) = self.read_cache()?
395 && cache.client_id == self.config.client_id
396 && cache.resource == resource
397 && cache.gateway.is_valid()
398 && cache.identity.is_valid()
399 {
400 return Ok(KontextAuthSession {
401 identity_token: cache.identity.to_access_token(),
402 gateway_token: cache.gateway.to_token_exchange_token(),
403 browser_auth_performed: false,
404 });
405 }
406
407 if let Some(cache) = self.read_cache()?
408 && cache.client_id == self.config.client_id
409 && cache.resource == resource
410 {
411 let maybe_refreshed = if cache.identity.is_valid() {
412 Some(cache.identity.to_access_token())
413 } else if let Some(refresh_token) = &cache.identity.refresh_token {
414 Some(self.refresh_identity_token(refresh_token).await?)
415 } else {
416 None
417 };
418
419 if let Some(identity_token) = maybe_refreshed {
420 let gateway_token = self
421 .exchange_for_resource(&identity_token.access_token, &resource, None)
422 .await?;
423 self.write_cache(&identity_token, &gateway_token)?;
424 return Ok(KontextAuthSession {
425 identity_token,
426 gateway_token,
427 browser_auth_performed: false,
428 });
429 }
430 }
431
432 let identity_token = self.authorize_with_browser_pkce().await?;
433 let gateway_token = self
434 .exchange_for_resource(&identity_token.access_token, &resource, None)
435 .await?;
436
437 self.write_cache(&identity_token, &gateway_token)?;
438
439 Ok(KontextAuthSession {
440 identity_token,
441 gateway_token,
442 browser_auth_performed: true,
443 })
444 }
445
446 pub async fn create_integration_connect_url(
449 &self,
450 gateway_access_token: &str,
451 ) -> Result<String, KontextDevError> {
452 let session = self.create_connect_session(gateway_access_token).await?;
453 self.integration_connect_url(&session.session_id)
454 }
455
456 pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
457 if session_id.trim().is_empty() {
458 return Err(KontextDevError::ConnectSession {
459 message: "connect session id is empty".to_string(),
460 });
461 }
462
463 let base = if let Some(explicit) = &self.config.integration_ui_url {
464 explicit.trim_end_matches('/').to_string()
465 } else {
466 let server = resolve_server_base_url(&self.config)?;
467 if server.contains("api.kontext.dev") {
468 "https://app.kontext.dev".to_string()
469 } else {
470 server
471 }
472 };
473
474 let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
475 url: base.clone(),
476 source,
477 })?;
478 url.set_path("/oauth/connect");
479 url.query_pairs_mut().append_pair("session", session_id);
480 Ok(url.to_string())
481 }
482
483 pub async fn open_integration_connect_page(
485 &self,
486 gateway_access_token: &str,
487 ) -> Result<String, KontextDevError> {
488 let url = self
489 .create_integration_connect_url(gateway_access_token)
490 .await?;
491
492 if webbrowser::open(&url).is_err() {
493 return Err(KontextDevError::BrowserOpenFailed);
494 }
495
496 Ok(url)
497 }
498
499 pub async fn create_connect_session(
500 &self,
501 gateway_access_token: &str,
502 ) -> Result<ConnectSession, KontextDevError> {
503 let url = self.connect_session_url()?;
504 let response = self
505 .http
506 .post(&url)
507 .header("Authorization", format!("Bearer {gateway_access_token}"))
508 .send()
509 .await
510 .map_err(|err| KontextDevError::ConnectSession {
511 message: err.to_string(),
512 })?;
513
514 if !response.status().is_success() {
515 let message = build_error_message(response).await;
516 return Err(KontextDevError::ConnectSession { message });
517 }
518
519 response
520 .json::<ConnectSession>()
521 .await
522 .map_err(|err| KontextDevError::ConnectSession {
523 message: err.to_string(),
524 })
525 }
526
527 pub async fn initiate_integration_oauth(
528 &self,
529 gateway_access_token: &str,
530 integration_id: &str,
531 return_to: Option<&str>,
532 ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
533 let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
534
535 let mut request = self
536 .http
537 .post(&url)
538 .header("Authorization", format!("Bearer {gateway_access_token}"));
539
540 if let Some(return_to) = return_to {
541 request = request.json(&serde_json::json!({ "returnTo": return_to }));
542 }
543
544 let response =
545 request
546 .send()
547 .await
548 .map_err(|err| KontextDevError::IntegrationOAuthInit {
549 message: err.to_string(),
550 })?;
551
552 if !response.status().is_success() {
553 let message = build_error_message(response).await;
554 return Err(KontextDevError::IntegrationOAuthInit { message });
555 }
556
557 response
558 .json::<IntegrationOAuthInitResponse>()
559 .await
560 .map_err(|err| KontextDevError::IntegrationOAuthInit {
561 message: err.to_string(),
562 })
563 }
564
565 pub async fn integration_connection_status(
566 &self,
567 gateway_access_token: &str,
568 integration_id: &str,
569 ) -> Result<IntegrationConnectionStatus, KontextDevError> {
570 let url = resolve_integration_connection_url(&self.config, integration_id)?;
571 let response = self
572 .http
573 .get(url)
574 .header("Authorization", format!("Bearer {gateway_access_token}"))
575 .send()
576 .await
577 .map_err(|err| KontextDevError::IntegrationOAuthInit {
578 message: err.to_string(),
579 })?;
580
581 if !response.status().is_success() {
582 let message = build_error_message(response).await;
583 return Err(KontextDevError::IntegrationOAuthInit { message });
584 }
585
586 response
587 .json::<IntegrationConnectionStatus>()
588 .await
589 .map_err(|err| KontextDevError::IntegrationOAuthInit {
590 message: err.to_string(),
591 })
592 }
593
594 pub async fn wait_for_integration_connection(
595 &self,
596 gateway_access_token: &str,
597 integration_id: &str,
598 timeout_ms: u64,
599 interval_ms: u64,
600 ) -> Result<bool, KontextDevError> {
601 let started = now_unix_ms();
602
603 loop {
604 let status = self
605 .integration_connection_status(gateway_access_token, integration_id)
606 .await?;
607 if status.connected {
608 return Ok(true);
609 }
610
611 if now_unix_ms().saturating_sub(started) >= timeout_ms {
612 return Ok(false);
613 }
614
615 tokio::time::sleep(Duration::from_millis(interval_ms)).await;
616 }
617 }
618
619 async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
620 let auth_url = self.resolve_authorization_url().await?;
621 let token_url = self.token_url()?;
622 let pkce = generate_pkce_pair();
623 let state = generate_state();
624
625 let (callback_url, callback_payload) = self.listen_for_callback().await?;
626
627 let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
628 url: auth_url.clone(),
629 source,
630 })?;
631
632 url.query_pairs_mut()
633 .append_pair("client_id", &self.config.client_id)
634 .append_pair("response_type", "code")
635 .append_pair("redirect_uri", &callback_url)
636 .append_pair("state", &state)
637 .append_pair("scope", &self.config.scope)
638 .append_pair("code_challenge_method", "S256")
639 .append_pair("code_challenge", &pkce.challenge);
640
641 if webbrowser::open(url.as_str()).is_err() {
642 return Err(KontextDevError::BrowserOpenFailed);
643 }
644
645 let payload = callback_payload.await?;
646
647 if let Some(error) = payload.error {
648 let with_details = payload
649 .error_description
650 .map(|description| format!("{error}: {description}"))
651 .unwrap_or(error);
652 return Err(KontextDevError::OAuthCallbackError {
653 error: with_details,
654 });
655 }
656
657 if payload.state.as_deref() != Some(state.as_str()) {
658 return Err(KontextDevError::InvalidOAuthState);
659 }
660
661 let code = payload
662 .code
663 .ok_or(KontextDevError::MissingAuthorizationCode)?;
664
665 let mut body = vec![
666 ("grant_type", "authorization_code".to_string()),
667 ("code", code),
668 ("redirect_uri", callback_url),
669 ("client_id", self.config.client_id.clone()),
670 ("code_verifier", pkce.verifier),
671 ];
672
673 if let Some(client_secret) = &self.config.client_secret {
674 body.push(("client_secret", client_secret.clone()));
675 }
676
677 post_token(&self.http, &token_url, None, &body).await
678 }
679
680 async fn resolve_authorization_url(&self) -> Result<String, KontextDevError> {
681 if let Some(discovered) = self.discover_authorization_endpoint().await {
682 return Ok(discovered);
683 }
684
685 let authorize_url = self.authorize_url()?;
686 if !self.endpoint_is_missing(&authorize_url).await {
687 return Ok(authorize_url);
688 }
689
690 let server_base = resolve_server_base_url(&self.config)?;
691 Ok(format!("{}/oauth2/auth", server_base.trim_end_matches('/')))
692 }
693
694 async fn discover_authorization_endpoint(&self) -> Option<String> {
695 let base = resolve_server_base_url(&self.config).ok()?;
696 let base = base.trim_end_matches('/');
697
698 let candidates = [
699 format!("{base}/.well-known/oauth-authorization-server/mcp"),
700 format!("{base}/.well-known/oauth-authorization-server"),
701 ];
702
703 for url in candidates {
704 let response = match self.http.get(&url).send().await {
705 Ok(response) => response,
706 Err(_) => continue,
707 };
708
709 if !response.status().is_success() {
710 continue;
711 }
712
713 let metadata = match response.json::<OAuthAuthorizationServerMetadata>().await {
714 Ok(metadata) => metadata,
715 Err(_) => continue,
716 };
717
718 let Some(endpoint) = metadata.authorization_endpoint else {
719 continue;
720 };
721
722 if Url::parse(&endpoint).is_ok() {
723 return Some(endpoint);
724 }
725 }
726
727 None
728 }
729
730 async fn endpoint_is_missing(&self, url: &str) -> bool {
731 let probe_client = match reqwest::Client::builder()
732 .redirect(reqwest::redirect::Policy::none())
733 .build()
734 {
735 Ok(client) => client,
736 Err(_) => return false,
737 };
738
739 match probe_client.get(url).send().await {
740 Ok(response) => response.status() == StatusCode::NOT_FOUND,
741 Err(_) => false,
742 }
743 }
744
745 async fn refresh_identity_token(
746 &self,
747 refresh_token: &str,
748 ) -> Result<AccessToken, KontextDevError> {
749 let token_url = self.token_url()?;
750 let mut body = vec![
751 ("grant_type", "refresh_token".to_string()),
752 ("refresh_token", refresh_token.to_string()),
753 ("client_id", self.config.client_id.clone()),
754 ];
755
756 if let Some(client_secret) = &self.config.client_secret {
757 body.push(("client_secret", client_secret.clone()));
758 }
759
760 post_token(&self.http, &token_url, None, &body).await
761 }
762
763 pub async fn exchange_for_resource(
764 &self,
765 subject_token: &str,
766 resource: &str,
767 scope: Option<&str>,
768 ) -> Result<TokenExchangeToken, KontextDevError> {
769 let token_url = self.token_url()?;
770
771 let mut body = vec![
772 ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
773 ("subject_token", subject_token.to_string()),
774 ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
775 ("resource", resource.to_string()),
776 ];
777
778 if let Some(scope) = scope {
779 body.push(("scope", scope.to_string()));
780 }
781
782 let auth_header = self.config.client_secret.as_ref().map(|secret| {
783 let raw = format!("{}:{}", self.config.client_id, secret);
784 format!(
785 "Basic {}",
786 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
787 )
788 });
789
790 if self.config.client_secret.is_none() {
791 body.push(("client_id", self.config.client_id.clone()));
792 }
793
794 let response = post_form_with_optional_auth::<TokenExchangeToken>(
795 &self.http,
796 &token_url,
797 auth_header,
798 &body,
799 )
800 .await
801 .map_err(|message| KontextDevError::TokenExchange {
802 resource: resource.to_string(),
803 message,
804 })?;
805
806 if response.access_token.is_empty() {
807 return Err(KontextDevError::EmptyAccessToken);
808 }
809
810 Ok(response)
811 }
812
813 async fn listen_for_callback(
814 &self,
815 ) -> Result<
816 (
817 String,
818 impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
819 ),
820 KontextDevError,
821 > {
822 let redirect_uri = self.config.redirect_uri.trim().to_string();
823 let parsed = Url::parse(&redirect_uri).map_err(|source| KontextDevError::InvalidUrl {
824 url: redirect_uri.clone(),
825 source,
826 })?;
827
828 if parsed.scheme() != "http" {
829 return Err(KontextDevError::OAuthCallbackError {
830 error: "redirect_uri must use http".to_string(),
831 });
832 }
833
834 if parsed.query().is_some() || parsed.fragment().is_some() {
835 return Err(KontextDevError::OAuthCallbackError {
836 error: "redirect_uri must not include query parameters or fragments".to_string(),
837 });
838 }
839
840 let host = parsed
841 .host_str()
842 .ok_or_else(|| KontextDevError::OAuthCallbackError {
843 error: "redirect_uri host is missing".to_string(),
844 })?;
845
846 let port = parsed
847 .port()
848 .ok_or_else(|| KontextDevError::OAuthCallbackError {
849 error: "redirect_uri must include an explicit port".to_string(),
850 })?;
851
852 let callback_path = parsed.path().to_string();
853 let bind_addr = if host.contains(':') {
854 format!("[{host}]:{port}")
855 } else {
856 format!("{host}:{port}")
857 };
858
859 let server = Arc::new(Server::http(&bind_addr).map_err(|err| {
860 KontextDevError::OAuthCallbackError {
861 error: format!("failed to start callback server at {bind_addr}: {err}"),
862 }
863 })?);
864
865 let (tx, rx) = oneshot::channel();
866 let _join = spawn_callback_server(server.clone(), callback_path, tx);
867 let _guard = CallbackServerGuard { server };
868 let timeout_seconds = self.config.auth_timeout_seconds.max(1);
869
870 let fut = async move {
871 let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
872 .await
873 .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
874 .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
875 drop(_guard);
876 Ok(payload)
877 };
878
879 Ok((redirect_uri, fut))
880 }
881
882 fn token_cache_path(&self) -> Option<PathBuf> {
883 if let Some(explicit) = &self.config.token_cache_path {
884 return Some(PathBuf::from(explicit));
885 }
886
887 let home = dirs::home_dir()?;
888 let mut path = home;
889 path.push(".kontext-dev");
890 path.push("tokens");
891
892 let sanitized_client_id: String = self
893 .config
894 .client_id
895 .chars()
896 .map(|ch| {
897 if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
898 ch
899 } else {
900 '_'
901 }
902 })
903 .collect();
904
905 path.push(format!("{sanitized_client_id}.json"));
906 Some(path)
907 }
908
909 fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
910 let Some(path) = self.token_cache_path() else {
911 return Ok(None);
912 };
913
914 let raw = match fs::read_to_string(&path) {
915 Ok(raw) => raw,
916 Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
917 Err(source) => {
918 return Err(KontextDevError::TokenCacheRead {
919 path: path.display().to_string(),
920 source,
921 });
922 }
923 };
924
925 serde_json::from_str(&raw).map(Some).map_err(|source| {
926 KontextDevError::TokenCacheDeserialize {
927 path: path.display().to_string(),
928 source,
929 }
930 })
931 }
932
933 fn write_cache(
934 &self,
935 identity: &AccessToken,
936 gateway: &TokenExchangeToken,
937 ) -> Result<(), KontextDevError> {
938 let Some(path) = self.token_cache_path() else {
939 return Ok(());
940 };
941
942 if let Some(parent) = path.parent() {
943 fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
944 path: parent.display().to_string(),
945 source,
946 })?;
947 }
948
949 let payload = TokenCacheFile {
950 client_id: self.config.client_id.clone(),
951 resource: self.config.resource.clone(),
952 identity: CachedAccessToken::from_access_token(identity)?,
953 gateway: CachedAccessToken::from_token_exchange(gateway)?,
954 };
955
956 let serialized = serde_json::to_string_pretty(&payload)
957 .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
958
959 fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
960 path: path.display().to_string(),
961 source,
962 })
963 }
964}
965
966pub async fn request_access_token(
970 config: &KontextDevConfig,
971) -> Result<AccessToken, KontextDevError> {
972 let token_url = resolve_token_url(config)?;
973
974 let mut body = vec![
975 ("grant_type", "client_credentials".to_string()),
976 ("scope", config.scope.clone()),
977 ];
978
979 let auth_header = if let Some(secret) = &config.client_secret {
980 let raw = format!("{}:{}", config.client_id, secret);
981 Some(format!(
982 "Basic {}",
983 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
984 ))
985 } else {
986 body.push(("client_id", config.client_id.clone()));
987 None
988 };
989
990 post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
991}
992
993async fn post_token(
994 http: &Client,
995 token_url: &str,
996 authorization: Option<&str>,
997 body: &[(impl AsRef<str>, String)],
998) -> Result<AccessToken, KontextDevError> {
999 let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
1000
1001 let response = post_form_with_optional_auth::<AccessToken>(
1002 http,
1003 token_url,
1004 authorization.map(ToString::to_string),
1005 &body_vec,
1006 )
1007 .await
1008 .map_err(|message| KontextDevError::TokenRequest {
1009 token_url: token_url.to_string(),
1010 message,
1011 })?;
1012
1013 Ok(response)
1014}
1015
1016async fn post_form_with_optional_auth<T>(
1017 http: &Client,
1018 url: &str,
1019 authorization: Option<String>,
1020 body: &[(&str, String)],
1021) -> Result<T, String>
1022where
1023 T: DeserializeOwned,
1024{
1025 let mut request = http
1026 .post(url)
1027 .header("Content-Type", "application/x-www-form-urlencoded");
1028
1029 if let Some(header) = authorization {
1030 request = request.header("Authorization", header);
1031 }
1032
1033 let form = body
1034 .iter()
1035 .map(|(k, v)| (k.to_string(), v.to_string()))
1036 .collect::<Vec<(String, String)>>();
1037
1038 let response = request
1039 .form(&form)
1040 .send()
1041 .await
1042 .map_err(|err| err.to_string())?;
1043
1044 if !response.status().is_success() {
1045 return Err(build_error_message(response).await);
1046 }
1047
1048 response.json::<T>().await.map_err(|err| err.to_string())
1049}
1050
1051async fn build_error_message(response: reqwest::Response) -> String {
1052 let status = response.status();
1053 let fallback = format!(
1054 "{} {}",
1055 status.as_u16(),
1056 status.canonical_reason().unwrap_or("")
1057 );
1058
1059 let body = response.text().await.unwrap_or_default();
1060 if body.is_empty() {
1061 return fallback.trim().to_string();
1062 }
1063
1064 if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
1065 if let Some(description) = parsed.error_description {
1066 return description;
1067 }
1068 if let Some(message) = parsed.message {
1069 return message;
1070 }
1071 if let Some(error) = parsed.error {
1072 return error;
1073 }
1074 }
1075
1076 format!("{fallback}: {body}")
1077}
1078
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration pointing at a local server with a hosted integration UI.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: Some("http://localhost:4000".to_string()),
            mcp_url: None,
            token_url: None,
            client_id: "client_123".to_string(),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        }
    }

    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let client = KontextDevClient::new(config());
        let url = client
            .integration_connect_url("session-123")
            .expect("url should be built");
        assert_eq!(
            url,
            "https://app.kontext.dev/oauth/connect?session=session-123"
        );
    }

    #[test]
    fn integration_connect_url_rejects_blank_session_id() {
        let client = KontextDevClient::new(config());
        assert!(matches!(
            client.integration_connect_url("   "),
            Err(KontextDevError::ConnectSession { .. })
        ));
    }

    #[test]
    fn jwt_exp_decode_reads_exp() {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"none"}"#);
        let payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"exp":4070908800}"#);
        let token = format!("{header}.{payload}.sig");
        assert_eq!(decode_jwt_exp(&token), Some(4_070_908_800));
    }

    #[test]
    fn opaque_token_without_expires_in_has_no_expiry() {
        assert_eq!(compute_expires_at_unix_ms(None, "opaque-token"), None);
    }

    #[test]
    fn relative_seconds_clamp_past_timestamps_to_zero() {
        assert_eq!(unix_ms_to_relative_seconds(0), Some(0));
    }

    #[test]
    fn cached_token_validity_respects_expiry_and_buffer() {
        let mut token = CachedAccessToken {
            access_token: "token".to_string(),
            token_type: "Bearer".to_string(),
            refresh_token: None,
            scope: None,
            expires_at_unix_ms: None,
        };
        // Unknown expiry is treated as valid.
        assert!(token.is_valid());

        // Inside the expiry buffer: no longer valid.
        token.expires_at_unix_ms = Some(now_unix_ms());
        assert!(!token.is_valid());

        token.expires_at_unix_ms = Some(now_unix_ms() + 3_600_000);
        assert!(token.is_valid());
    }

    #[test]
    fn parse_callback_extracts_query_parameters() {
        let payload = parse_callback("/callback?code=abc&state=xyz", "/callback")
            .expect("callback path should match");
        assert_eq!(payload.code.as_deref(), Some("abc"));
        assert_eq!(payload.state.as_deref(), Some("xyz"));
        assert!(payload.error.is_none());
        assert!(payload.error_description.is_none());
    }

    #[test]
    fn parse_callback_decodes_error_parameters() {
        let payload = parse_callback(
            "/callback?error=access_denied&error_description=user%20denied",
            "/callback",
        )
        .expect("callback path should match");
        assert_eq!(payload.error.as_deref(), Some("access_denied"));
        assert_eq!(payload.error_description.as_deref(), Some("user denied"));
        assert!(payload.code.is_none());
    }

    #[test]
    fn parse_callback_rejects_other_paths() {
        assert!(parse_callback("/favicon.ico", "/callback").is_none());
    }

    #[test]
    fn pkce_challenge_is_s256_of_verifier() {
        let pkce = generate_pkce_pair();
        let expected = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(sha2::Sha256::digest(pkce.verifier.as_bytes()));
        assert_eq!(pkce.challenge, expected);
        // RFC 7636 §4.1: the verifier must be 43..=128 characters long.
        assert!((43..=128).contains(&pkce.verifier.len()));
    }

    #[test]
    fn state_is_32_lowercase_hex_chars() {
        let state = generate_state();
        assert_eq!(state.len(), 32);
        assert!(state
            .chars()
            .all(|ch| ch.is_ascii_digit() || ('a'..='f').contains(&ch)));
    }

    #[test]
    fn oauth_metadata_parses_authorization_endpoint() {
        let metadata = serde_json::from_str::<OAuthAuthorizationServerMetadata>(
            r#"{
                "issuer": "https://issuer.example.com",
                "authorization_endpoint": "https://issuer.example.com/oauth2/auth"
            }"#,
        )
        .expect("metadata should parse");

        assert_eq!(
            metadata.authorization_endpoint.as_deref(),
            Some("https://issuer.example.com/oauth2/auth")
        );
    }
}