1use base64::Engine;
2use reqwest::Client;
3use serde::Deserialize;
4use serde::Serialize;
5use serde::de::DeserializeOwned;
6use sha2::Digest;
7use std::fs;
8use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11use std::time::SystemTime;
12use std::time::UNIX_EPOCH;
13use tiny_http::Response;
14use tiny_http::Server;
15use tokio::sync::oneshot;
16use tokio::time::timeout;
17use url::Url;
18
19pub use kontext_dev_core::AccessToken;
20pub use kontext_dev_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
21pub use kontext_dev_core::DEFAULT_RESOURCE;
22pub use kontext_dev_core::DEFAULT_SCOPE;
23pub use kontext_dev_core::DEFAULT_SERVER_NAME;
24pub use kontext_dev_core::KontextDevConfig;
25pub use kontext_dev_core::KontextDevCoreError;
26pub use kontext_dev_core::TokenExchangeToken;
27pub use kontext_dev_core::build_mcp_url;
28pub use kontext_dev_core::normalize_server_url;
29pub use kontext_dev_core::resolve_authorize_url;
30pub use kontext_dev_core::resolve_connect_session_url;
31pub use kontext_dev_core::resolve_integration_connection_url;
32pub use kontext_dev_core::resolve_integration_oauth_init_url;
33pub use kontext_dev_core::resolve_mcp_url;
34pub use kontext_dev_core::resolve_server_base_url;
35pub use kontext_dev_core::resolve_token_url;
36
/// `grant_type` value for OAuth 2.0 Token Exchange (RFC 8693).
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// Token-type URI identifying an OAuth access token (RFC 8693).
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
/// Path on the local loopback server that receives the OAuth redirect.
const CONNECT_CALLBACK_PATH: &str = "/callback";
/// Cached tokens expiring within this many seconds are treated as already expired.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
41
42#[derive(Debug, thiserror::Error)]
43pub enum KontextDevError {
44 #[error(transparent)]
45 Core(#[from] kontext_dev_core::KontextDevCoreError),
46 #[error("failed to parse URL `{url}`")]
47 InvalidUrl {
48 url: String,
49 source: url::ParseError,
50 },
51 #[error("failed to open browser for OAuth authorization")]
52 BrowserOpenFailed,
53 #[error("OAuth callback timed out after {timeout_seconds}s")]
54 OAuthCallbackTimeout { timeout_seconds: i64 },
55 #[error("OAuth callback channel was unexpectedly cancelled")]
56 OAuthCallbackCancelled,
57 #[error("OAuth callback is missing the authorization code")]
58 MissingAuthorizationCode,
59 #[error("OAuth callback returned an error: {error}")]
60 OAuthCallbackError { error: String },
61 #[error("OAuth state mismatch")]
62 InvalidOAuthState,
63 #[error("Kontext-Dev token request failed for {token_url}: {message}")]
64 TokenRequest { token_url: String, message: String },
65 #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
66 TokenExchange { resource: String, message: String },
67 #[error("Kontext-Dev connect session request failed: {message}")]
68 ConnectSession { message: String },
69 #[error("Kontext-Dev integration OAuth init failed: {message}")]
70 IntegrationOAuthInit { message: String },
71 #[error("failed to read token cache at `{path}`: {source}")]
72 TokenCacheRead {
73 path: String,
74 source: std::io::Error,
75 },
76 #[error("failed to write token cache at `{path}`: {source}")]
77 TokenCacheWrite {
78 path: String,
79 source: std::io::Error,
80 },
81 #[error("failed to deserialize token cache at `{path}`: {source}")]
82 TokenCacheDeserialize {
83 path: String,
84 source: serde_json::Error,
85 },
86 #[error("failed to serialize token cache: {source}")]
87 TokenCacheSerialize { source: serde_json::Error },
88 #[error("Kontext-Dev access token is empty")]
89 EmptyAccessToken,
90 #[error("missing integration UI URL; set `integration_ui_url` in config")]
91 MissingIntegrationUiUrl,
92}
93
94#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
95pub struct ConnectSession {
96 #[serde(rename = "sessionId")]
97 pub session_id: String,
98 #[serde(rename = "expiresAt")]
99 pub expires_at: String,
100}
101
102#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
103pub struct IntegrationOAuthInitResponse {
104 #[serde(rename = "authorizationUrl")]
105 pub authorization_url: String,
106 #[serde(default)]
107 pub state: Option<String>,
108}
109
110#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
111pub struct IntegrationConnectionStatus {
112 pub connected: bool,
113 #[serde(default)]
114 pub expires_at: Option<String>,
115}
116
117#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
118pub struct KontextAuthSession {
119 pub identity_token: AccessToken,
120 pub gateway_token: TokenExchangeToken,
121 pub browser_auth_performed: bool,
122}
123
124#[derive(Clone, Debug, Deserialize, Serialize)]
125struct CachedAccessToken {
126 access_token: String,
127 token_type: String,
128 refresh_token: Option<String>,
129 scope: Option<String>,
130 expires_at_unix_ms: Option<u64>,
131}
132
133#[derive(Clone, Debug, Deserialize, Serialize)]
134struct TokenCacheFile {
135 client_id: String,
136 resource: String,
137 identity: CachedAccessToken,
138 gateway: CachedAccessToken,
139}
140
141impl CachedAccessToken {
142 fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
143 if token.access_token.is_empty() {
144 return Err(KontextDevError::EmptyAccessToken);
145 }
146
147 Ok(Self {
148 access_token: token.access_token.clone(),
149 token_type: token.token_type.clone(),
150 refresh_token: token.refresh_token.clone(),
151 scope: token.scope.clone(),
152 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
153 })
154 }
155
156 fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
157 if token.access_token.is_empty() {
158 return Err(KontextDevError::EmptyAccessToken);
159 }
160
161 Ok(Self {
162 access_token: token.access_token.clone(),
163 token_type: token.token_type.clone(),
164 refresh_token: token.refresh_token.clone(),
165 scope: token.scope.clone(),
166 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
167 })
168 }
169
170 fn is_valid(&self) -> bool {
171 match self.expires_at_unix_ms {
172 Some(expires_at) => {
173 let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
174 now_unix_ms().saturating_add(buffer_ms) < expires_at
175 }
176 None => true,
177 }
178 }
179
180 fn to_access_token(&self) -> AccessToken {
181 AccessToken {
182 access_token: self.access_token.clone(),
183 token_type: self.token_type.clone(),
184 expires_in: self
185 .expires_at_unix_ms
186 .and_then(unix_ms_to_relative_seconds),
187 refresh_token: self.refresh_token.clone(),
188 scope: self.scope.clone(),
189 }
190 }
191
192 fn to_token_exchange_token(&self) -> TokenExchangeToken {
193 TokenExchangeToken {
194 access_token: self.access_token.clone(),
195 issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
196 token_type: self.token_type.clone(),
197 expires_in: self
198 .expires_at_unix_ms
199 .and_then(unix_ms_to_relative_seconds),
200 scope: self.scope.clone(),
201 refresh_token: self.refresh_token.clone(),
202 }
203 }
204}
205
/// Current wall-clock time in Unix milliseconds.
///
/// Returns 0 if the clock reads before the epoch. Saturates at `u64::MAX` instead of
/// truncating the `u128` millisecond count.
fn now_unix_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
212
213fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
214 if unix_ms <= now_unix_ms() {
215 return Some(0);
216 }
217
218 let delta_ms = unix_ms - now_unix_ms();
219 let secs = delta_ms / 1000;
220 i64::try_from(secs).ok()
221}
222
223fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
224 if let Some(expires_in) = expires_in
225 && expires_in > 0
226 {
227 let expires_in_u64 = u64::try_from(expires_in).ok()?;
228 return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
229 }
230
231 decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
232}
233
234fn decode_jwt_exp(token: &str) -> Option<u64> {
235 let mut parts = token.split('.');
236 let _header = parts.next()?;
237 let payload = parts.next()?;
238
239 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
240 .decode(payload)
241 .ok()?;
242 let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
243 value.get("exp").and_then(|exp| {
244 exp.as_u64()
245 .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
246 })
247}
248
249#[derive(Clone, Debug, Deserialize)]
250struct OAuthErrorBody {
251 error: Option<String>,
252 error_description: Option<String>,
253 message: Option<String>,
254}
255
256#[derive(Clone, Debug, Deserialize)]
257struct OAuthCallbackPayload {
258 code: Option<String>,
259 state: Option<String>,
260 error: Option<String>,
261 error_description: Option<String>,
262}
263
/// PKCE verifier together with its S256 challenge.
#[derive(Clone, Debug)]
struct PkcePair {
    /// Random secret, sent later as `code_verifier`.
    verifier: String,
    /// base64url(SHA-256(verifier)), sent as `code_challenge`.
    challenge: String,
}
269
270fn generate_pkce_pair() -> PkcePair {
271 let mut raw = [0u8; 64];
272 rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
273
274 let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
275 let digest = sha2::Sha256::digest(verifier.as_bytes());
276 let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
277
278 PkcePair {
279 verifier,
280 challenge,
281 }
282}
283
284fn generate_state() -> String {
285 let mut bytes = [0u8; 16];
286 rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
287 bytes.iter().map(|b| format!("{b:02x}")).collect()
288}
289
290struct CallbackServerGuard {
291 server: Arc<Server>,
292}
293
294impl Drop for CallbackServerGuard {
295 fn drop(&mut self) {
296 self.server.unblock();
297 }
298}
299
300fn parse_callback(url_path: &str) -> Option<OAuthCallbackPayload> {
301 let full = format!("http://localhost{url_path}");
302 let parsed = Url::parse(&full).ok()?;
303
304 if parsed.path() != CONNECT_CALLBACK_PATH {
305 return None;
306 }
307
308 let mut payload = OAuthCallbackPayload {
309 code: None,
310 state: None,
311 error: None,
312 error_description: None,
313 };
314
315 for (key, value) in parsed.query_pairs() {
316 match key.as_ref() {
317 "code" => payload.code = Some(value.to_string()),
318 "state" => payload.state = Some(value.to_string()),
319 "error" => payload.error = Some(value.to_string()),
320 "error_description" => payload.error_description = Some(value.to_string()),
321 _ => {}
322 }
323 }
324
325 Some(payload)
326}
327
328fn spawn_callback_server(
329 server: Arc<Server>,
330 tx: oneshot::Sender<OAuthCallbackPayload>,
331) -> tokio::task::JoinHandle<()> {
332 tokio::task::spawn_blocking(move || {
333 while let Ok(request) = server.recv() {
334 let path = request.url().to_string();
335 if let Some(payload) = parse_callback(&path) {
336 let response = Response::from_string(
337 "Authentication complete. You can return to your terminal.",
338 );
339 let _ = request.respond(response);
340 let _ = tx.send(payload);
341 break;
342 }
343
344 let response = Response::from_string("Invalid callback").with_status_code(400);
345 let _ = request.respond(response);
346 }
347 })
348}
349
350pub struct KontextDevClient {
351 config: KontextDevConfig,
352 http: Client,
353}
354
355impl KontextDevClient {
356 pub fn new(config: KontextDevConfig) -> Self {
357 Self {
358 config,
359 http: Client::new(),
360 }
361 }
362
363 pub fn config(&self) -> &KontextDevConfig {
364 &self.config
365 }
366
367 pub fn mcp_url(&self) -> Result<String, KontextDevError> {
368 resolve_mcp_url(&self.config).map_err(KontextDevError::from)
369 }
370
371 pub fn token_url(&self) -> Result<String, KontextDevError> {
372 resolve_token_url(&self.config).map_err(KontextDevError::from)
373 }
374
375 pub fn authorize_url(&self) -> Result<String, KontextDevError> {
376 resolve_authorize_url(&self.config).map_err(KontextDevError::from)
377 }
378
379 pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
380 resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
381 }
382
383 pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
386 let resource = self.config.resource.clone();
387
388 if let Some(cache) = self.read_cache()?
389 && cache.client_id == self.config.client_id
390 && cache.resource == resource
391 && cache.gateway.is_valid()
392 && cache.identity.is_valid()
393 {
394 return Ok(KontextAuthSession {
395 identity_token: cache.identity.to_access_token(),
396 gateway_token: cache.gateway.to_token_exchange_token(),
397 browser_auth_performed: false,
398 });
399 }
400
401 if let Some(cache) = self.read_cache()?
402 && cache.client_id == self.config.client_id
403 && cache.resource == resource
404 {
405 let maybe_refreshed = if cache.identity.is_valid() {
406 Some(cache.identity.to_access_token())
407 } else if let Some(refresh_token) = &cache.identity.refresh_token {
408 Some(self.refresh_identity_token(refresh_token).await?)
409 } else {
410 None
411 };
412
413 if let Some(identity_token) = maybe_refreshed {
414 let gateway_token = self
415 .exchange_for_resource(&identity_token.access_token, &resource, None)
416 .await?;
417 self.write_cache(&identity_token, &gateway_token)?;
418 return Ok(KontextAuthSession {
419 identity_token,
420 gateway_token,
421 browser_auth_performed: false,
422 });
423 }
424 }
425
426 let identity_token = self.authorize_with_browser_pkce().await?;
427 let gateway_token = self
428 .exchange_for_resource(&identity_token.access_token, &resource, None)
429 .await?;
430
431 self.write_cache(&identity_token, &gateway_token)?;
432
433 Ok(KontextAuthSession {
434 identity_token,
435 gateway_token,
436 browser_auth_performed: true,
437 })
438 }
439
440 pub async fn create_integration_connect_url(
443 &self,
444 gateway_access_token: &str,
445 ) -> Result<String, KontextDevError> {
446 let session = self.create_connect_session(gateway_access_token).await?;
447 self.integration_connect_url(&session.session_id)
448 }
449
450 pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
451 if session_id.trim().is_empty() {
452 return Err(KontextDevError::ConnectSession {
453 message: "connect session id is empty".to_string(),
454 });
455 }
456
457 let base = if let Some(explicit) = &self.config.integration_ui_url {
458 explicit.trim_end_matches('/').to_string()
459 } else {
460 let server = resolve_server_base_url(&self.config)?;
461 if server.contains("api.kontext.dev") {
462 "https://app.kontext.dev".to_string()
463 } else {
464 server
465 }
466 };
467
468 let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
469 url: base.clone(),
470 source,
471 })?;
472 url.set_path("/oauth/connect");
473 url.query_pairs_mut().append_pair("session", session_id);
474 Ok(url.to_string())
475 }
476
477 pub async fn open_integration_connect_page(
479 &self,
480 gateway_access_token: &str,
481 ) -> Result<String, KontextDevError> {
482 let url = self
483 .create_integration_connect_url(gateway_access_token)
484 .await?;
485
486 if webbrowser::open(&url).is_err() {
487 return Err(KontextDevError::BrowserOpenFailed);
488 }
489
490 Ok(url)
491 }
492
493 pub async fn create_connect_session(
494 &self,
495 gateway_access_token: &str,
496 ) -> Result<ConnectSession, KontextDevError> {
497 let url = self.connect_session_url()?;
498 let response = self
499 .http
500 .post(&url)
501 .header("Authorization", format!("Bearer {gateway_access_token}"))
502 .send()
503 .await
504 .map_err(|err| KontextDevError::ConnectSession {
505 message: err.to_string(),
506 })?;
507
508 if !response.status().is_success() {
509 let message = build_error_message(response).await;
510 return Err(KontextDevError::ConnectSession { message });
511 }
512
513 response
514 .json::<ConnectSession>()
515 .await
516 .map_err(|err| KontextDevError::ConnectSession {
517 message: err.to_string(),
518 })
519 }
520
521 pub async fn initiate_integration_oauth(
522 &self,
523 gateway_access_token: &str,
524 integration_id: &str,
525 return_to: Option<&str>,
526 ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
527 let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
528
529 let mut request = self
530 .http
531 .post(&url)
532 .header("Authorization", format!("Bearer {gateway_access_token}"));
533
534 if let Some(return_to) = return_to {
535 request = request.json(&serde_json::json!({ "returnTo": return_to }));
536 }
537
538 let response =
539 request
540 .send()
541 .await
542 .map_err(|err| KontextDevError::IntegrationOAuthInit {
543 message: err.to_string(),
544 })?;
545
546 if !response.status().is_success() {
547 let message = build_error_message(response).await;
548 return Err(KontextDevError::IntegrationOAuthInit { message });
549 }
550
551 response
552 .json::<IntegrationOAuthInitResponse>()
553 .await
554 .map_err(|err| KontextDevError::IntegrationOAuthInit {
555 message: err.to_string(),
556 })
557 }
558
559 pub async fn integration_connection_status(
560 &self,
561 gateway_access_token: &str,
562 integration_id: &str,
563 ) -> Result<IntegrationConnectionStatus, KontextDevError> {
564 let url = resolve_integration_connection_url(&self.config, integration_id)?;
565 let response = self
566 .http
567 .get(url)
568 .header("Authorization", format!("Bearer {gateway_access_token}"))
569 .send()
570 .await
571 .map_err(|err| KontextDevError::IntegrationOAuthInit {
572 message: err.to_string(),
573 })?;
574
575 if !response.status().is_success() {
576 let message = build_error_message(response).await;
577 return Err(KontextDevError::IntegrationOAuthInit { message });
578 }
579
580 response
581 .json::<IntegrationConnectionStatus>()
582 .await
583 .map_err(|err| KontextDevError::IntegrationOAuthInit {
584 message: err.to_string(),
585 })
586 }
587
588 pub async fn wait_for_integration_connection(
589 &self,
590 gateway_access_token: &str,
591 integration_id: &str,
592 timeout_ms: u64,
593 interval_ms: u64,
594 ) -> Result<bool, KontextDevError> {
595 let started = now_unix_ms();
596
597 loop {
598 let status = self
599 .integration_connection_status(gateway_access_token, integration_id)
600 .await?;
601 if status.connected {
602 return Ok(true);
603 }
604
605 if now_unix_ms().saturating_sub(started) >= timeout_ms {
606 return Ok(false);
607 }
608
609 tokio::time::sleep(Duration::from_millis(interval_ms)).await;
610 }
611 }
612
613 async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
614 let auth_url = self.authorize_url()?;
615 let token_url = self.token_url()?;
616 let pkce = generate_pkce_pair();
617 let state = generate_state();
618
619 let (callback_url, callback_payload) = self.listen_for_callback().await?;
620
621 let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
622 url: auth_url.clone(),
623 source,
624 })?;
625
626 url.query_pairs_mut()
627 .append_pair("client_id", &self.config.client_id)
628 .append_pair("response_type", "code")
629 .append_pair("redirect_uri", &callback_url)
630 .append_pair("state", &state)
631 .append_pair("scope", &self.config.scope)
632 .append_pair("code_challenge_method", "S256")
633 .append_pair("code_challenge", &pkce.challenge);
634
635 if webbrowser::open(url.as_str()).is_err() {
636 return Err(KontextDevError::BrowserOpenFailed);
637 }
638
639 let payload = callback_payload.await?;
640
641 if let Some(error) = payload.error {
642 let with_details = payload
643 .error_description
644 .map(|description| format!("{error}: {description}"))
645 .unwrap_or(error);
646 return Err(KontextDevError::OAuthCallbackError {
647 error: with_details,
648 });
649 }
650
651 if payload.state.as_deref() != Some(state.as_str()) {
652 return Err(KontextDevError::InvalidOAuthState);
653 }
654
655 let code = payload
656 .code
657 .ok_or(KontextDevError::MissingAuthorizationCode)?;
658
659 let mut body = vec![
660 ("grant_type", "authorization_code".to_string()),
661 ("code", code),
662 ("redirect_uri", callback_url),
663 ("client_id", self.config.client_id.clone()),
664 ("code_verifier", pkce.verifier),
665 ];
666
667 if let Some(client_secret) = &self.config.client_secret {
668 body.push(("client_secret", client_secret.clone()));
669 }
670
671 post_token(&self.http, &token_url, None, &body).await
672 }
673
674 async fn refresh_identity_token(
675 &self,
676 refresh_token: &str,
677 ) -> Result<AccessToken, KontextDevError> {
678 let token_url = self.token_url()?;
679 let mut body = vec![
680 ("grant_type", "refresh_token".to_string()),
681 ("refresh_token", refresh_token.to_string()),
682 ("client_id", self.config.client_id.clone()),
683 ];
684
685 if let Some(client_secret) = &self.config.client_secret {
686 body.push(("client_secret", client_secret.clone()));
687 }
688
689 post_token(&self.http, &token_url, None, &body).await
690 }
691
692 pub async fn exchange_for_resource(
693 &self,
694 subject_token: &str,
695 resource: &str,
696 scope: Option<&str>,
697 ) -> Result<TokenExchangeToken, KontextDevError> {
698 let token_url = self.token_url()?;
699
700 let mut body = vec![
701 ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
702 ("subject_token", subject_token.to_string()),
703 ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
704 ("resource", resource.to_string()),
705 ];
706
707 if let Some(scope) = scope {
708 body.push(("scope", scope.to_string()));
709 }
710
711 let auth_header = self.config.client_secret.as_ref().map(|secret| {
712 let raw = format!("{}:{}", self.config.client_id, secret);
713 format!(
714 "Basic {}",
715 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
716 )
717 });
718
719 if self.config.client_secret.is_none() {
720 body.push(("client_id", self.config.client_id.clone()));
721 }
722
723 let response = post_form_with_optional_auth::<TokenExchangeToken>(
724 &self.http,
725 &token_url,
726 auth_header,
727 &body,
728 )
729 .await
730 .map_err(|message| KontextDevError::TokenExchange {
731 resource: resource.to_string(),
732 message,
733 })?;
734
735 if response.access_token.is_empty() {
736 return Err(KontextDevError::EmptyAccessToken);
737 }
738
739 Ok(response)
740 }
741
742 async fn listen_for_callback(
743 &self,
744 ) -> Result<
745 (
746 String,
747 impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
748 ),
749 KontextDevError,
750 > {
751 let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| {
752 KontextDevError::OAuthCallbackError {
753 error: format!("failed to start callback server: {err}"),
754 }
755 })?);
756
757 let callback_url = match server.server_addr() {
758 tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
759 format!(
760 "http://{}:{}{}",
761 addr.ip(),
762 addr.port(),
763 CONNECT_CALLBACK_PATH
764 )
765 }
766 tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
767 format!(
768 "http://[{}]:{}{}",
769 addr.ip(),
770 addr.port(),
771 CONNECT_CALLBACK_PATH
772 )
773 }
774 #[cfg(not(target_os = "windows"))]
775 _ => {
776 return Err(KontextDevError::OAuthCallbackError {
777 error: "unable to determine callback address".to_string(),
778 });
779 }
780 };
781
782 let (tx, rx) = oneshot::channel();
783 let _join = spawn_callback_server(server.clone(), tx);
784 let _guard = CallbackServerGuard { server };
785 let timeout_seconds = self.config.auth_timeout_seconds.max(1);
786
787 let fut = async move {
788 let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
789 .await
790 .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
791 .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
792 drop(_guard);
793 Ok(payload)
794 };
795
796 Ok((callback_url, fut))
797 }
798
799 fn token_cache_path(&self) -> Option<PathBuf> {
800 if let Some(explicit) = &self.config.token_cache_path {
801 return Some(PathBuf::from(explicit));
802 }
803
804 let home = dirs::home_dir()?;
805 let mut path = home;
806 path.push(".kontext-dev");
807 path.push("tokens");
808
809 let sanitized_client_id: String = self
810 .config
811 .client_id
812 .chars()
813 .map(|ch| {
814 if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
815 ch
816 } else {
817 '_'
818 }
819 })
820 .collect();
821
822 path.push(format!("{sanitized_client_id}.json"));
823 Some(path)
824 }
825
826 fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
827 let Some(path) = self.token_cache_path() else {
828 return Ok(None);
829 };
830
831 let raw = match fs::read_to_string(&path) {
832 Ok(raw) => raw,
833 Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
834 Err(source) => {
835 return Err(KontextDevError::TokenCacheRead {
836 path: path.display().to_string(),
837 source,
838 });
839 }
840 };
841
842 serde_json::from_str(&raw).map(Some).map_err(|source| {
843 KontextDevError::TokenCacheDeserialize {
844 path: path.display().to_string(),
845 source,
846 }
847 })
848 }
849
850 fn write_cache(
851 &self,
852 identity: &AccessToken,
853 gateway: &TokenExchangeToken,
854 ) -> Result<(), KontextDevError> {
855 let Some(path) = self.token_cache_path() else {
856 return Ok(());
857 };
858
859 if let Some(parent) = path.parent() {
860 fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
861 path: parent.display().to_string(),
862 source,
863 })?;
864 }
865
866 let payload = TokenCacheFile {
867 client_id: self.config.client_id.clone(),
868 resource: self.config.resource.clone(),
869 identity: CachedAccessToken::from_access_token(identity)?,
870 gateway: CachedAccessToken::from_token_exchange(gateway)?,
871 };
872
873 let serialized = serde_json::to_string_pretty(&payload)
874 .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
875
876 fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
877 path: path.display().to_string(),
878 source,
879 })
880 }
881}
882
883pub async fn request_access_token(
887 config: &KontextDevConfig,
888) -> Result<AccessToken, KontextDevError> {
889 let token_url = resolve_token_url(config)?;
890
891 let mut body = vec![
892 ("grant_type", "client_credentials".to_string()),
893 ("scope", config.scope.clone()),
894 ];
895
896 let auth_header = if let Some(secret) = &config.client_secret {
897 let raw = format!("{}:{}", config.client_id, secret);
898 Some(format!(
899 "Basic {}",
900 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
901 ))
902 } else {
903 body.push(("client_id", config.client_id.clone()));
904 None
905 };
906
907 post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
908}
909
910async fn post_token(
911 http: &Client,
912 token_url: &str,
913 authorization: Option<&str>,
914 body: &[(impl AsRef<str>, String)],
915) -> Result<AccessToken, KontextDevError> {
916 let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
917
918 let response = post_form_with_optional_auth::<AccessToken>(
919 http,
920 token_url,
921 authorization.map(ToString::to_string),
922 &body_vec,
923 )
924 .await
925 .map_err(|message| KontextDevError::TokenRequest {
926 token_url: token_url.to_string(),
927 message,
928 })?;
929
930 Ok(response)
931}
932
933async fn post_form_with_optional_auth<T>(
934 http: &Client,
935 url: &str,
936 authorization: Option<String>,
937 body: &[(&str, String)],
938) -> Result<T, String>
939where
940 T: DeserializeOwned,
941{
942 let mut request = http
943 .post(url)
944 .header("Content-Type", "application/x-www-form-urlencoded");
945
946 if let Some(header) = authorization {
947 request = request.header("Authorization", header);
948 }
949
950 let form = body
951 .iter()
952 .map(|(k, v)| (k.to_string(), v.to_string()))
953 .collect::<Vec<(String, String)>>();
954
955 let response = request
956 .form(&form)
957 .send()
958 .await
959 .map_err(|err| err.to_string())?;
960
961 if !response.status().is_success() {
962 return Err(build_error_message(response).await);
963 }
964
965 response.json::<T>().await.map_err(|err| err.to_string())
966}
967
968async fn build_error_message(response: reqwest::Response) -> String {
969 let status = response.status();
970 let fallback = format!(
971 "{} {}",
972 status.as_u16(),
973 status.canonical_reason().unwrap_or("")
974 );
975
976 let body = response.text().await.unwrap_or_default();
977 if body.is_empty() {
978 return fallback.trim().to_string();
979 }
980
981 if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
982 if let Some(description) = parsed.error_description {
983 return description;
984 }
985 if let Some(message) = parsed.message {
986 return message;
987 }
988 if let Some(error) = parsed.error {
989 return error;
990 }
991 }
992
993 format!("{fallback}: {body}")
994}
995
#[cfg(test)]
mod tests {
    use super::*;

    /// Local-server configuration with the hosted integration UI set explicitly.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: Some("http://localhost:4000".to_string()),
            mcp_url: None,
            token_url: None,
            client_id: "client_123".to_string(),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
        }
    }

    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let built = KontextDevClient::new(config())
            .integration_connect_url("session-123")
            .expect("url should be built");
        let expected = "https://app.kontext.dev/oauth/connect?session=session-123";
        assert_eq!(built, expected);
    }

    #[test]
    fn jwt_exp_decode_reads_exp() {
        let encode =
            |json: &str| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json);
        let token = format!(
            "{}.{}.sig",
            encode(r#"{"alg":"none"}"#),
            encode(r#"{"exp":4070908800}"#)
        );
        assert_eq!(decode_jwt_exp(&token), Some(4_070_908_800));
    }
}