Skip to main content

kontext_dev/
lib.rs

1use base64::Engine;
2use reqwest::Client;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use sha2::Digest;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14use tiny_http::Response;
15use tiny_http::Server;
16use tokio::sync::oneshot;
17use tokio::time::timeout;
18use url::Url;
19
20pub use kontext_dev_core::AccessToken;
21pub use kontext_dev_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
22pub use kontext_dev_core::DEFAULT_RESOURCE;
23pub use kontext_dev_core::DEFAULT_SCOPE;
24pub use kontext_dev_core::DEFAULT_SERVER_NAME;
25pub use kontext_dev_core::KontextDevConfig;
26pub use kontext_dev_core::KontextDevCoreError;
27pub use kontext_dev_core::TokenExchangeToken;
28pub use kontext_dev_core::build_mcp_url;
29pub use kontext_dev_core::normalize_server_url;
30pub use kontext_dev_core::resolve_authorize_url;
31pub use kontext_dev_core::resolve_connect_session_url;
32pub use kontext_dev_core::resolve_integration_connection_url;
33pub use kontext_dev_core::resolve_integration_oauth_init_url;
34pub use kontext_dev_core::resolve_mcp_url;
35pub use kontext_dev_core::resolve_server_base_url;
36pub use kontext_dev_core::resolve_token_url;
37
/// Path on the loopback callback server that receives the OAuth redirect.
const CONNECT_CALLBACK_PATH: &str = "/callback";
/// Safety margin for cached tokens: a token whose expiry is within this many
/// seconds of "now" is already treated as expired.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;

/// RFC 8693 grant type used to exchange an identity token for a
/// resource-scoped token.
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// RFC 8693 token-type identifier for OAuth access tokens.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
42
43#[derive(Debug, thiserror::Error)]
44pub enum KontextDevError {
45    #[error(transparent)]
46    Core(#[from] kontext_dev_core::KontextDevCoreError),
47    #[error("failed to parse URL `{url}`")]
48    InvalidUrl {
49        url: String,
50        source: url::ParseError,
51    },
52    #[error("failed to open browser for OAuth authorization")]
53    BrowserOpenFailed,
54    #[error("OAuth callback timed out after {timeout_seconds}s")]
55    OAuthCallbackTimeout { timeout_seconds: i64 },
56    #[error("OAuth callback channel was unexpectedly cancelled")]
57    OAuthCallbackCancelled,
58    #[error("OAuth callback is missing the authorization code")]
59    MissingAuthorizationCode,
60    #[error("OAuth callback returned an error: {error}")]
61    OAuthCallbackError { error: String },
62    #[error("OAuth state mismatch")]
63    InvalidOAuthState,
64    #[error("Kontext-Dev token request failed for {token_url}: {message}")]
65    TokenRequest { token_url: String, message: String },
66    #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
67    TokenExchange { resource: String, message: String },
68    #[error("Kontext-Dev connect session request failed: {message}")]
69    ConnectSession { message: String },
70    #[error("Kontext-Dev integration OAuth init failed: {message}")]
71    IntegrationOAuthInit { message: String },
72    #[error("failed to read token cache at `{path}`: {source}")]
73    TokenCacheRead {
74        path: String,
75        source: std::io::Error,
76    },
77    #[error("failed to write token cache at `{path}`: {source}")]
78    TokenCacheWrite {
79        path: String,
80        source: std::io::Error,
81    },
82    #[error("failed to deserialize token cache at `{path}`: {source}")]
83    TokenCacheDeserialize {
84        path: String,
85        source: serde_json::Error,
86    },
87    #[error("failed to serialize token cache: {source}")]
88    TokenCacheSerialize { source: serde_json::Error },
89    #[error("Kontext-Dev access token is empty")]
90    EmptyAccessToken,
91    #[error("missing integration UI URL; set `integration_ui_url` in config")]
92    MissingIntegrationUiUrl,
93}
94
95#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
96pub struct ConnectSession {
97    #[serde(rename = "sessionId")]
98    pub session_id: String,
99    #[serde(rename = "expiresAt")]
100    pub expires_at: String,
101}
102
103#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
104pub struct IntegrationOAuthInitResponse {
105    #[serde(rename = "authorizationUrl")]
106    pub authorization_url: String,
107    #[serde(default)]
108    pub state: Option<String>,
109}
110
111#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
112pub struct IntegrationConnectionStatus {
113    pub connected: bool,
114    #[serde(default)]
115    pub expires_at: Option<String>,
116}
117
118#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
119pub struct KontextAuthSession {
120    pub identity_token: AccessToken,
121    pub gateway_token: TokenExchangeToken,
122    pub browser_auth_performed: bool,
123}
124
125#[derive(Clone, Debug, Deserialize, Serialize)]
126struct CachedAccessToken {
127    access_token: String,
128    token_type: String,
129    refresh_token: Option<String>,
130    scope: Option<String>,
131    expires_at_unix_ms: Option<u64>,
132}
133
134#[derive(Clone, Debug, Deserialize, Serialize)]
135struct TokenCacheFile {
136    client_id: String,
137    resource: String,
138    identity: CachedAccessToken,
139    gateway: CachedAccessToken,
140}
141
142impl CachedAccessToken {
143    fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
144        if token.access_token.is_empty() {
145            return Err(KontextDevError::EmptyAccessToken);
146        }
147
148        Ok(Self {
149            access_token: token.access_token.clone(),
150            token_type: token.token_type.clone(),
151            refresh_token: token.refresh_token.clone(),
152            scope: token.scope.clone(),
153            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
154        })
155    }
156
157    fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
158        if token.access_token.is_empty() {
159            return Err(KontextDevError::EmptyAccessToken);
160        }
161
162        Ok(Self {
163            access_token: token.access_token.clone(),
164            token_type: token.token_type.clone(),
165            refresh_token: token.refresh_token.clone(),
166            scope: token.scope.clone(),
167            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
168        })
169    }
170
171    fn is_valid(&self) -> bool {
172        match self.expires_at_unix_ms {
173            Some(expires_at) => {
174                let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
175                now_unix_ms().saturating_add(buffer_ms) < expires_at
176            }
177            None => true,
178        }
179    }
180
181    fn to_access_token(&self) -> AccessToken {
182        AccessToken {
183            access_token: self.access_token.clone(),
184            token_type: self.token_type.clone(),
185            expires_in: self
186                .expires_at_unix_ms
187                .and_then(unix_ms_to_relative_seconds),
188            refresh_token: self.refresh_token.clone(),
189            scope: self.scope.clone(),
190        }
191    }
192
193    fn to_token_exchange_token(&self) -> TokenExchangeToken {
194        TokenExchangeToken {
195            access_token: self.access_token.clone(),
196            issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
197            token_type: self.token_type.clone(),
198            expires_in: self
199                .expires_at_unix_ms
200                .and_then(unix_ms_to_relative_seconds),
201            scope: self.scope.clone(),
202            refresh_token: self.refresh_token.clone(),
203        }
204    }
205}
206
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch, and saturates
/// at `u64::MAX` instead of silently truncating the `u128` millisecond count.
fn now_unix_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}

/// Converts an absolute expiry (milliseconds since the Unix epoch) into whole
/// seconds remaining from now, in `expires_in` form.
///
/// Already-expired timestamps yield `Some(0)`. The clock is sampled exactly once.
/// Sampling it separately for the comparison and the subtraction could let time
/// advance past `unix_ms` in between and underflow.
fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
    let now = now_unix_ms();
    let remaining_secs = unix_ms.saturating_sub(now) / 1000;
    i64::try_from(remaining_secs).ok()
}
223
224fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
225    if let Some(expires_in) = expires_in
226        && expires_in > 0
227    {
228        let expires_in_u64 = u64::try_from(expires_in).ok()?;
229        return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
230    }
231
232    decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
233}
234
235fn decode_jwt_exp(token: &str) -> Option<u64> {
236    let mut parts = token.split('.');
237    let _header = parts.next()?;
238    let payload = parts.next()?;
239
240    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
241        .decode(payload)
242        .ok()?;
243    let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
244    value.get("exp").and_then(|exp| {
245        exp.as_u64()
246            .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
247    })
248}
249
250#[derive(Clone, Debug, Deserialize)]
251struct OAuthErrorBody {
252    error: Option<String>,
253    error_description: Option<String>,
254    message: Option<String>,
255}
256
257#[derive(Clone, Debug, Deserialize)]
258struct OAuthAuthorizationServerMetadata {
259    authorization_endpoint: Option<String>,
260}
261
262#[derive(Clone, Debug, Deserialize)]
263struct OAuthCallbackPayload {
264    code: Option<String>,
265    state: Option<String>,
266    error: Option<String>,
267    error_description: Option<String>,
268}
269
/// PKCE secret (`verifier`, sent to the token endpoint) and its S256
/// `challenge` (sent with the authorization request).
#[derive(Clone, Debug)]
struct PkcePair {
    verifier: String,
    challenge: String,
}
275
276fn generate_pkce_pair() -> PkcePair {
277    let mut raw = [0u8; 64];
278    rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
279
280    let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
281    let digest = sha2::Sha256::digest(verifier.as_bytes());
282    let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
283
284    PkcePair {
285        verifier,
286        challenge,
287    }
288}
289
290fn generate_state() -> String {
291    let mut bytes = [0u8; 16];
292    rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
293    bytes.iter().map(|b| format!("{b:02x}")).collect()
294}
295
296struct CallbackServerGuard {
297    server: Arc<Server>,
298}
299
300impl Drop for CallbackServerGuard {
301    fn drop(&mut self) {
302        self.server.unblock();
303    }
304}
305
306fn parse_callback(url_path: &str) -> Option<OAuthCallbackPayload> {
307    let full = format!("http://localhost{url_path}");
308    let parsed = Url::parse(&full).ok()?;
309
310    if parsed.path() != CONNECT_CALLBACK_PATH {
311        return None;
312    }
313
314    let mut payload = OAuthCallbackPayload {
315        code: None,
316        state: None,
317        error: None,
318        error_description: None,
319    };
320
321    for (key, value) in parsed.query_pairs() {
322        match key.as_ref() {
323            "code" => payload.code = Some(value.to_string()),
324            "state" => payload.state = Some(value.to_string()),
325            "error" => payload.error = Some(value.to_string()),
326            "error_description" => payload.error_description = Some(value.to_string()),
327            _ => {}
328        }
329    }
330
331    Some(payload)
332}
333
334fn spawn_callback_server(
335    server: Arc<Server>,
336    tx: oneshot::Sender<OAuthCallbackPayload>,
337) -> tokio::task::JoinHandle<()> {
338    tokio::task::spawn_blocking(move || {
339        while let Ok(request) = server.recv() {
340            let path = request.url().to_string();
341            if let Some(payload) = parse_callback(&path) {
342                let response = Response::from_string(
343                    "Authentication complete. You can return to your terminal.",
344                );
345                let _ = request.respond(response);
346                let _ = tx.send(payload);
347                break;
348            }
349
350            let response = Response::from_string("Invalid callback").with_status_code(400);
351            let _ = request.respond(response);
352        }
353    })
354}
355
356pub struct KontextDevClient {
357    config: KontextDevConfig,
358    http: Client,
359}
360
361impl KontextDevClient {
362    pub fn new(config: KontextDevConfig) -> Self {
363        Self {
364            config,
365            http: Client::new(),
366        }
367    }
368
369    pub fn config(&self) -> &KontextDevConfig {
370        &self.config
371    }
372
373    pub fn mcp_url(&self) -> Result<String, KontextDevError> {
374        resolve_mcp_url(&self.config).map_err(KontextDevError::from)
375    }
376
377    pub fn token_url(&self) -> Result<String, KontextDevError> {
378        resolve_token_url(&self.config).map_err(KontextDevError::from)
379    }
380
381    pub fn authorize_url(&self) -> Result<String, KontextDevError> {
382        resolve_authorize_url(&self.config).map_err(KontextDevError::from)
383    }
384
385    pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
386        resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
387    }
388
389    /// Authenticate a user via browser PKCE, exchange the identity token to a
390    /// resource-scoped MCP gateway token, and persist cache for future runs.
391    pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
392        let resource = self.config.resource.clone();
393
394        if let Some(cache) = self.read_cache()?
395            && cache.client_id == self.config.client_id
396            && cache.resource == resource
397            && cache.gateway.is_valid()
398            && cache.identity.is_valid()
399        {
400            return Ok(KontextAuthSession {
401                identity_token: cache.identity.to_access_token(),
402                gateway_token: cache.gateway.to_token_exchange_token(),
403                browser_auth_performed: false,
404            });
405        }
406
407        if let Some(cache) = self.read_cache()?
408            && cache.client_id == self.config.client_id
409            && cache.resource == resource
410        {
411            let maybe_refreshed = if cache.identity.is_valid() {
412                Some(cache.identity.to_access_token())
413            } else if let Some(refresh_token) = &cache.identity.refresh_token {
414                Some(self.refresh_identity_token(refresh_token).await?)
415            } else {
416                None
417            };
418
419            if let Some(identity_token) = maybe_refreshed {
420                let gateway_token = self
421                    .exchange_for_resource(&identity_token.access_token, &resource, None)
422                    .await?;
423                self.write_cache(&identity_token, &gateway_token)?;
424                return Ok(KontextAuthSession {
425                    identity_token,
426                    gateway_token,
427                    browser_auth_performed: false,
428                });
429            }
430        }
431
432        let identity_token = self.authorize_with_browser_pkce().await?;
433        let gateway_token = self
434            .exchange_for_resource(&identity_token.access_token, &resource, None)
435            .await?;
436
437        self.write_cache(&identity_token, &gateway_token)?;
438
439        Ok(KontextAuthSession {
440            identity_token,
441            gateway_token,
442            browser_auth_performed: true,
443        })
444    }
445
446    /// Request a short-lived connect session and build the hosted integration
447    /// connect URL (`.../oauth/connect?session=...`).
448    pub async fn create_integration_connect_url(
449        &self,
450        gateway_access_token: &str,
451    ) -> Result<String, KontextDevError> {
452        let session = self.create_connect_session(gateway_access_token).await?;
453        self.integration_connect_url(&session.session_id)
454    }
455
456    pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
457        if session_id.trim().is_empty() {
458            return Err(KontextDevError::ConnectSession {
459                message: "connect session id is empty".to_string(),
460            });
461        }
462
463        let base = if let Some(explicit) = &self.config.integration_ui_url {
464            explicit.trim_end_matches('/').to_string()
465        } else {
466            let server = resolve_server_base_url(&self.config)?;
467            if server.contains("api.kontext.dev") {
468                "https://app.kontext.dev".to_string()
469            } else {
470                server
471            }
472        };
473
474        let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
475            url: base.clone(),
476            source,
477        })?;
478        url.set_path("/oauth/connect");
479        url.query_pairs_mut().append_pair("session", session_id);
480        Ok(url.to_string())
481    }
482
483    /// Opens the integration connect page in the browser.
484    pub async fn open_integration_connect_page(
485        &self,
486        gateway_access_token: &str,
487    ) -> Result<String, KontextDevError> {
488        let url = self
489            .create_integration_connect_url(gateway_access_token)
490            .await?;
491
492        if webbrowser::open(&url).is_err() {
493            return Err(KontextDevError::BrowserOpenFailed);
494        }
495
496        Ok(url)
497    }
498
499    pub async fn create_connect_session(
500        &self,
501        gateway_access_token: &str,
502    ) -> Result<ConnectSession, KontextDevError> {
503        let url = self.connect_session_url()?;
504        let response = self
505            .http
506            .post(&url)
507            .header("Authorization", format!("Bearer {gateway_access_token}"))
508            .send()
509            .await
510            .map_err(|err| KontextDevError::ConnectSession {
511                message: err.to_string(),
512            })?;
513
514        if !response.status().is_success() {
515            let message = build_error_message(response).await;
516            return Err(KontextDevError::ConnectSession { message });
517        }
518
519        response
520            .json::<ConnectSession>()
521            .await
522            .map_err(|err| KontextDevError::ConnectSession {
523                message: err.to_string(),
524            })
525    }
526
527    pub async fn initiate_integration_oauth(
528        &self,
529        gateway_access_token: &str,
530        integration_id: &str,
531        return_to: Option<&str>,
532    ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
533        let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
534
535        let mut request = self
536            .http
537            .post(&url)
538            .header("Authorization", format!("Bearer {gateway_access_token}"));
539
540        if let Some(return_to) = return_to {
541            request = request.json(&serde_json::json!({ "returnTo": return_to }));
542        }
543
544        let response =
545            request
546                .send()
547                .await
548                .map_err(|err| KontextDevError::IntegrationOAuthInit {
549                    message: err.to_string(),
550                })?;
551
552        if !response.status().is_success() {
553            let message = build_error_message(response).await;
554            return Err(KontextDevError::IntegrationOAuthInit { message });
555        }
556
557        response
558            .json::<IntegrationOAuthInitResponse>()
559            .await
560            .map_err(|err| KontextDevError::IntegrationOAuthInit {
561                message: err.to_string(),
562            })
563    }
564
565    pub async fn integration_connection_status(
566        &self,
567        gateway_access_token: &str,
568        integration_id: &str,
569    ) -> Result<IntegrationConnectionStatus, KontextDevError> {
570        let url = resolve_integration_connection_url(&self.config, integration_id)?;
571        let response = self
572            .http
573            .get(url)
574            .header("Authorization", format!("Bearer {gateway_access_token}"))
575            .send()
576            .await
577            .map_err(|err| KontextDevError::IntegrationOAuthInit {
578                message: err.to_string(),
579            })?;
580
581        if !response.status().is_success() {
582            let message = build_error_message(response).await;
583            return Err(KontextDevError::IntegrationOAuthInit { message });
584        }
585
586        response
587            .json::<IntegrationConnectionStatus>()
588            .await
589            .map_err(|err| KontextDevError::IntegrationOAuthInit {
590                message: err.to_string(),
591            })
592    }
593
594    pub async fn wait_for_integration_connection(
595        &self,
596        gateway_access_token: &str,
597        integration_id: &str,
598        timeout_ms: u64,
599        interval_ms: u64,
600    ) -> Result<bool, KontextDevError> {
601        let started = now_unix_ms();
602
603        loop {
604            let status = self
605                .integration_connection_status(gateway_access_token, integration_id)
606                .await?;
607            if status.connected {
608                return Ok(true);
609            }
610
611            if now_unix_ms().saturating_sub(started) >= timeout_ms {
612                return Ok(false);
613            }
614
615            tokio::time::sleep(Duration::from_millis(interval_ms)).await;
616        }
617    }
618
619    async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
620        let auth_url = self.resolve_authorization_url().await?;
621        let token_url = self.token_url()?;
622        let pkce = generate_pkce_pair();
623        let state = generate_state();
624
625        let (callback_url, callback_payload) = self.listen_for_callback().await?;
626
627        let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
628            url: auth_url.clone(),
629            source,
630        })?;
631
632        url.query_pairs_mut()
633            .append_pair("client_id", &self.config.client_id)
634            .append_pair("response_type", "code")
635            .append_pair("redirect_uri", &callback_url)
636            .append_pair("state", &state)
637            .append_pair("scope", &self.config.scope)
638            .append_pair("code_challenge_method", "S256")
639            .append_pair("code_challenge", &pkce.challenge);
640
641        if webbrowser::open(url.as_str()).is_err() {
642            return Err(KontextDevError::BrowserOpenFailed);
643        }
644
645        let payload = callback_payload.await?;
646
647        if let Some(error) = payload.error {
648            let with_details = payload
649                .error_description
650                .map(|description| format!("{error}: {description}"))
651                .unwrap_or(error);
652            return Err(KontextDevError::OAuthCallbackError {
653                error: with_details,
654            });
655        }
656
657        if payload.state.as_deref() != Some(state.as_str()) {
658            return Err(KontextDevError::InvalidOAuthState);
659        }
660
661        let code = payload
662            .code
663            .ok_or(KontextDevError::MissingAuthorizationCode)?;
664
665        let mut body = vec![
666            ("grant_type", "authorization_code".to_string()),
667            ("code", code),
668            ("redirect_uri", callback_url),
669            ("client_id", self.config.client_id.clone()),
670            ("code_verifier", pkce.verifier),
671        ];
672
673        if let Some(client_secret) = &self.config.client_secret {
674            body.push(("client_secret", client_secret.clone()));
675        }
676
677        post_token(&self.http, &token_url, None, &body).await
678    }
679
680    async fn resolve_authorization_url(&self) -> Result<String, KontextDevError> {
681        if let Some(discovered) = self.discover_authorization_endpoint().await {
682            return Ok(discovered);
683        }
684
685        let authorize_url = self.authorize_url()?;
686        if !self.endpoint_is_missing(&authorize_url).await {
687            return Ok(authorize_url);
688        }
689
690        let server_base = resolve_server_base_url(&self.config)?;
691        Ok(format!("{}/oauth2/auth", server_base.trim_end_matches('/')))
692    }
693
694    async fn discover_authorization_endpoint(&self) -> Option<String> {
695        let base = resolve_server_base_url(&self.config).ok()?;
696        let base = base.trim_end_matches('/');
697
698        let candidates = [
699            format!("{base}/.well-known/oauth-authorization-server/mcp"),
700            format!("{base}/.well-known/oauth-authorization-server"),
701        ];
702
703        for url in candidates {
704            let response = match self.http.get(&url).send().await {
705                Ok(response) => response,
706                Err(_) => continue,
707            };
708
709            if !response.status().is_success() {
710                continue;
711            }
712
713            let metadata = match response.json::<OAuthAuthorizationServerMetadata>().await {
714                Ok(metadata) => metadata,
715                Err(_) => continue,
716            };
717
718            let Some(endpoint) = metadata.authorization_endpoint else {
719                continue;
720            };
721
722            if Url::parse(&endpoint).is_ok() {
723                return Some(endpoint);
724            }
725        }
726
727        None
728    }
729
730    async fn endpoint_is_missing(&self, url: &str) -> bool {
731        let probe_client = match reqwest::Client::builder()
732            .redirect(reqwest::redirect::Policy::none())
733            .build()
734        {
735            Ok(client) => client,
736            Err(_) => return false,
737        };
738
739        match probe_client.get(url).send().await {
740            Ok(response) => response.status() == StatusCode::NOT_FOUND,
741            Err(_) => false,
742        }
743    }
744
745    async fn refresh_identity_token(
746        &self,
747        refresh_token: &str,
748    ) -> Result<AccessToken, KontextDevError> {
749        let token_url = self.token_url()?;
750        let mut body = vec![
751            ("grant_type", "refresh_token".to_string()),
752            ("refresh_token", refresh_token.to_string()),
753            ("client_id", self.config.client_id.clone()),
754        ];
755
756        if let Some(client_secret) = &self.config.client_secret {
757            body.push(("client_secret", client_secret.clone()));
758        }
759
760        post_token(&self.http, &token_url, None, &body).await
761    }
762
763    pub async fn exchange_for_resource(
764        &self,
765        subject_token: &str,
766        resource: &str,
767        scope: Option<&str>,
768    ) -> Result<TokenExchangeToken, KontextDevError> {
769        let token_url = self.token_url()?;
770
771        let mut body = vec![
772            ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
773            ("subject_token", subject_token.to_string()),
774            ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
775            ("resource", resource.to_string()),
776        ];
777
778        if let Some(scope) = scope {
779            body.push(("scope", scope.to_string()));
780        }
781
782        let auth_header = self.config.client_secret.as_ref().map(|secret| {
783            let raw = format!("{}:{}", self.config.client_id, secret);
784            format!(
785                "Basic {}",
786                base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
787            )
788        });
789
790        if self.config.client_secret.is_none() {
791            body.push(("client_id", self.config.client_id.clone()));
792        }
793
794        let response = post_form_with_optional_auth::<TokenExchangeToken>(
795            &self.http,
796            &token_url,
797            auth_header,
798            &body,
799        )
800        .await
801        .map_err(|message| KontextDevError::TokenExchange {
802            resource: resource.to_string(),
803            message,
804        })?;
805
806        if response.access_token.is_empty() {
807            return Err(KontextDevError::EmptyAccessToken);
808        }
809
810        Ok(response)
811    }
812
813    async fn listen_for_callback(
814        &self,
815    ) -> Result<
816        (
817            String,
818            impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
819        ),
820        KontextDevError,
821    > {
822        let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| {
823            KontextDevError::OAuthCallbackError {
824                error: format!("failed to start callback server: {err}"),
825            }
826        })?);
827
828        let callback_url = match server.server_addr() {
829            tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
830                format!(
831                    "http://{}:{}{}",
832                    addr.ip(),
833                    addr.port(),
834                    CONNECT_CALLBACK_PATH
835                )
836            }
837            tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
838                format!(
839                    "http://[{}]:{}{}",
840                    addr.ip(),
841                    addr.port(),
842                    CONNECT_CALLBACK_PATH
843                )
844            }
845            #[cfg(not(target_os = "windows"))]
846            _ => {
847                return Err(KontextDevError::OAuthCallbackError {
848                    error: "unable to determine callback address".to_string(),
849                });
850            }
851        };
852
853        let (tx, rx) = oneshot::channel();
854        let _join = spawn_callback_server(server.clone(), tx);
855        let _guard = CallbackServerGuard { server };
856        let timeout_seconds = self.config.auth_timeout_seconds.max(1);
857
858        let fut = async move {
859            let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
860                .await
861                .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
862                .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
863            drop(_guard);
864            Ok(payload)
865        };
866
867        Ok((callback_url, fut))
868    }
869
870    fn token_cache_path(&self) -> Option<PathBuf> {
871        if let Some(explicit) = &self.config.token_cache_path {
872            return Some(PathBuf::from(explicit));
873        }
874
875        let home = dirs::home_dir()?;
876        let mut path = home;
877        path.push(".kontext-dev");
878        path.push("tokens");
879
880        let sanitized_client_id: String = self
881            .config
882            .client_id
883            .chars()
884            .map(|ch| {
885                if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
886                    ch
887                } else {
888                    '_'
889                }
890            })
891            .collect();
892
893        path.push(format!("{sanitized_client_id}.json"));
894        Some(path)
895    }
896
897    fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
898        let Some(path) = self.token_cache_path() else {
899            return Ok(None);
900        };
901
902        let raw = match fs::read_to_string(&path) {
903            Ok(raw) => raw,
904            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
905            Err(source) => {
906                return Err(KontextDevError::TokenCacheRead {
907                    path: path.display().to_string(),
908                    source,
909                });
910            }
911        };
912
913        serde_json::from_str(&raw).map(Some).map_err(|source| {
914            KontextDevError::TokenCacheDeserialize {
915                path: path.display().to_string(),
916                source,
917            }
918        })
919    }
920
921    fn write_cache(
922        &self,
923        identity: &AccessToken,
924        gateway: &TokenExchangeToken,
925    ) -> Result<(), KontextDevError> {
926        let Some(path) = self.token_cache_path() else {
927            return Ok(());
928        };
929
930        if let Some(parent) = path.parent() {
931            fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
932                path: parent.display().to_string(),
933                source,
934            })?;
935        }
936
937        let payload = TokenCacheFile {
938            client_id: self.config.client_id.clone(),
939            resource: self.config.resource.clone(),
940            identity: CachedAccessToken::from_access_token(identity)?,
941            gateway: CachedAccessToken::from_token_exchange(gateway)?,
942        };
943
944        let serialized = serde_json::to_string_pretty(&payload)
945            .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
946
947        fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
948            path: path.display().to_string(),
949            source,
950        })
951    }
952}
953
954/// Legacy helper that uses the client-credentials grant.
955///
956/// New PKCE-based apps should use `KontextDevClient::authenticate_mcp`.
957pub async fn request_access_token(
958    config: &KontextDevConfig,
959) -> Result<AccessToken, KontextDevError> {
960    let token_url = resolve_token_url(config)?;
961
962    let mut body = vec![
963        ("grant_type", "client_credentials".to_string()),
964        ("scope", config.scope.clone()),
965    ];
966
967    let auth_header = if let Some(secret) = &config.client_secret {
968        let raw = format!("{}:{}", config.client_id, secret);
969        Some(format!(
970            "Basic {}",
971            base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
972        ))
973    } else {
974        body.push(("client_id", config.client_id.clone()));
975        None
976    };
977
978    post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
979}
980
981async fn post_token(
982    http: &Client,
983    token_url: &str,
984    authorization: Option<&str>,
985    body: &[(impl AsRef<str>, String)],
986) -> Result<AccessToken, KontextDevError> {
987    let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
988
989    let response = post_form_with_optional_auth::<AccessToken>(
990        http,
991        token_url,
992        authorization.map(ToString::to_string),
993        &body_vec,
994    )
995    .await
996    .map_err(|message| KontextDevError::TokenRequest {
997        token_url: token_url.to_string(),
998        message,
999    })?;
1000
1001    Ok(response)
1002}
1003
1004async fn post_form_with_optional_auth<T>(
1005    http: &Client,
1006    url: &str,
1007    authorization: Option<String>,
1008    body: &[(&str, String)],
1009) -> Result<T, String>
1010where
1011    T: DeserializeOwned,
1012{
1013    let mut request = http
1014        .post(url)
1015        .header("Content-Type", "application/x-www-form-urlencoded");
1016
1017    if let Some(header) = authorization {
1018        request = request.header("Authorization", header);
1019    }
1020
1021    let form = body
1022        .iter()
1023        .map(|(k, v)| (k.to_string(), v.to_string()))
1024        .collect::<Vec<(String, String)>>();
1025
1026    let response = request
1027        .form(&form)
1028        .send()
1029        .await
1030        .map_err(|err| err.to_string())?;
1031
1032    if !response.status().is_success() {
1033        return Err(build_error_message(response).await);
1034    }
1035
1036    response.json::<T>().await.map_err(|err| err.to_string())
1037}
1038
1039async fn build_error_message(response: reqwest::Response) -> String {
1040    let status = response.status();
1041    let fallback = format!(
1042        "{} {}",
1043        status.as_u16(),
1044        status.canonical_reason().unwrap_or("")
1045    );
1046
1047    let body = response.text().await.unwrap_or_default();
1048    if body.is_empty() {
1049        return fallback.trim().to_string();
1050    }
1051
1052    if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
1053        if let Some(description) = parsed.error_description {
1054            return description;
1055        }
1056        if let Some(message) = parsed.message {
1057            return message;
1058        }
1059        if let Some(error) = parsed.error {
1060            return error;
1061        }
1062    }
1063
1064    format!("{fallback}: {body}")
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by the tests; individual tests tweak fields.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: Some("http://localhost:4000".to_string()),
            mcp_url: None,
            token_url: None,
            client_id: "client_123".to_string(),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
        }
    }

    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let client = KontextDevClient::new(config());
        let url = client
            .integration_connect_url("session-123")
            .expect("url should be built");
        assert_eq!(
            url,
            "https://app.kontext.dev/oauth/connect?session=session-123"
        );
    }

    #[test]
    fn jwt_exp_decode_reads_exp() {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"none"}"#);
        let payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"exp":4070908800}"#);
        let token = format!("{header}.{payload}.sig");
        assert_eq!(decode_jwt_exp(&token), Some(4_070_908_800));
    }

    #[test]
    fn oauth_metadata_parses_authorization_endpoint() {
        let metadata = serde_json::from_str::<OAuthAuthorizationServerMetadata>(
            r#"{
              "issuer": "https://issuer.example.com",
              "authorization_endpoint": "https://issuer.example.com/oauth2/auth"
            }"#,
        )
        .expect("metadata should parse");

        assert_eq!(
            metadata.authorization_endpoint.as_deref(),
            Some("https://issuer.example.com/oauth2/auth")
        );
    }

    #[test]
    fn token_cache_path_prefers_explicit_override() {
        let mut cfg = config();
        cfg.token_cache_path = Some("custom/cache.json".into());
        let client = KontextDevClient::new(cfg);
        assert_eq!(
            client.token_cache_path(),
            Some(PathBuf::from("custom/cache.json"))
        );
    }

    #[test]
    fn token_cache_path_sanitizes_client_id() {
        let mut cfg = config();
        cfg.client_id = "acme.app/v1:prod".to_string();
        let client = KontextDevClient::new(cfg);

        // Sandboxes without a home directory yield no default path; nothing to check then.
        if let Some(path) = client.token_cache_path() {
            assert_eq!(
                path.file_name().and_then(|name| name.to_str()),
                Some("acme_app_v1_prod.json")
            );
            let expected_dir = std::path::Path::new(".kontext-dev").join("tokens");
            assert!(path
                .parent()
                .map_or(false, |dir| dir.ends_with(&expected_dir)));
        }
    }
}