Skip to main content

kontext_dev/
lib.rs

1use base64::Engine;
2use reqwest::Client;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use sha2::Digest;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14use tiny_http::Response;
15use tiny_http::Server;
16use tokio::sync::oneshot;
17use tokio::time::timeout;
18use url::Url;
19
20pub use kontext_dev_core::AccessToken;
21pub use kontext_dev_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
22pub use kontext_dev_core::DEFAULT_RESOURCE;
23pub use kontext_dev_core::DEFAULT_SCOPE;
24pub use kontext_dev_core::DEFAULT_SERVER_NAME;
25pub use kontext_dev_core::KontextDevConfig;
26pub use kontext_dev_core::KontextDevCoreError;
27pub use kontext_dev_core::TokenExchangeToken;
28pub use kontext_dev_core::build_mcp_url;
29pub use kontext_dev_core::normalize_server_url;
30pub use kontext_dev_core::resolve_authorize_url;
31pub use kontext_dev_core::resolve_connect_session_url;
32pub use kontext_dev_core::resolve_integration_connection_url;
33pub use kontext_dev_core::resolve_integration_oauth_init_url;
34pub use kontext_dev_core::resolve_mcp_url;
35pub use kontext_dev_core::resolve_server_base_url;
36pub use kontext_dev_core::resolve_token_url;
37
/// OAuth 2.0 Token Exchange grant type (RFC 8693), sent as `grant_type` by
/// `exchange_for_resource`.
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// Token-type URI for access tokens. Used as `subject_token_type` when exchanging
/// and as the `issued_token_type` of gateway tokens rebuilt from the cache.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
/// Safety margin: a cached token is treated as expired this many seconds before
/// its recorded expiry.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
41
/// Errors returned by [`KontextDevClient`] and the helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum KontextDevError {
    /// Configuration / URL resolution error from `kontext_dev_core`, displayed
    /// transparently.
    #[error(transparent)]
    Core(#[from] kontext_dev_core::KontextDevCoreError),
    /// A URL (authorization endpoint, redirect URI, integration UI base) could
    /// not be parsed.
    #[error("failed to parse URL `{url}`")]
    InvalidUrl {
        /// The string that failed to parse.
        url: String,
        source: url::ParseError,
    },
    /// The system browser could not be opened.
    #[error("failed to open browser for OAuth authorization")]
    BrowserOpenFailed,
    /// No OAuth callback arrived within the configured timeout.
    #[error("OAuth callback timed out after {timeout_seconds}s")]
    OAuthCallbackTimeout { timeout_seconds: i64 },
    /// The callback listener stopped without delivering a callback (its sending
    /// half of the channel was dropped).
    #[error("OAuth callback channel was unexpectedly cancelled")]
    OAuthCallbackCancelled,
    /// The OAuth callback had neither `error` nor `code`.
    #[error("OAuth callback is missing the authorization code")]
    MissingAuthorizationCode,
    /// The OAuth redirect carried an `error` parameter (formatted as
    /// `error: description` when a description is present). Also used for local
    /// callback-listener problems: an unusable `redirect_uri` or a failure to
    /// bind the listener.
    #[error("OAuth callback returned an error: {error}")]
    OAuthCallbackError { error: String },
    /// The `state` returned on the callback did not match the one sent.
    #[error("OAuth state mismatch")]
    InvalidOAuthState,
    /// A token-endpoint request failed.
    /// NOTE(review): not constructed in the visible code; presumably produced by
    /// the `post_token` helper — confirm.
    #[error("Kontext-Dev token request failed for {token_url}: {message}")]
    TokenRequest { token_url: String, message: String },
    /// The RFC 8693 token exchange for `resource` failed.
    #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
    TokenExchange { resource: String, message: String },
    /// Creating a connect session failed, or the session id was empty.
    #[error("Kontext-Dev connect session request failed: {message}")]
    ConnectSession { message: String },
    /// Integration OAuth initiation failed. Also reused for failures of the
    /// integration connection-status request.
    #[error("Kontext-Dev integration OAuth init failed: {message}")]
    IntegrationOAuthInit { message: String },
    /// The token cache file exists but could not be read.
    #[error("failed to read token cache at `{path}`: {source}")]
    TokenCacheRead {
        path: String,
        source: std::io::Error,
    },
    /// The token cache (or its parent directory) could not be written.
    #[error("failed to write token cache at `{path}`: {source}")]
    TokenCacheWrite {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file is not valid cache JSON.
    #[error("failed to deserialize token cache at `{path}`: {source}")]
    TokenCacheDeserialize {
        path: String,
        source: serde_json::Error,
    },
    /// The cache payload could not be serialized.
    /// NOTE(review): presumably raised by `write_cache` — confirm.
    #[error("failed to serialize token cache: {source}")]
    TokenCacheSerialize { source: serde_json::Error },
    /// A token endpoint returned, or the cache was asked to store, an empty
    /// access token.
    #[error("Kontext-Dev access token is empty")]
    EmptyAccessToken,
    /// NOTE(review): not constructed in the visible code; `integration_connect_url`
    /// falls back to the server base URL instead — confirm where this is used.
    #[error("missing integration UI URL; set `integration_ui_url` in config")]
    MissingIntegrationUiUrl,
}
93
94#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
95pub struct ConnectSession {
96    #[serde(rename = "sessionId")]
97    pub session_id: String,
98    #[serde(rename = "expiresAt")]
99    pub expires_at: String,
100}
101
102#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
103pub struct IntegrationOAuthInitResponse {
104    #[serde(rename = "authorizationUrl")]
105    pub authorization_url: String,
106    #[serde(default)]
107    pub state: Option<String>,
108}
109
110#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
111pub struct IntegrationConnectionStatus {
112    pub connected: bool,
113    #[serde(default)]
114    pub expires_at: Option<String>,
115}
116
/// Result of [`KontextDevClient::authenticate_mcp`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct KontextAuthSession {
    /// User identity token (from the browser PKCE flow, a refresh, or the cache).
    pub identity_token: AccessToken,
    /// Resource-scoped token obtained by exchanging `identity_token` for the
    /// configured resource (or restored from the cache).
    pub gateway_token: TokenExchangeToken,
    /// `true` only when this call ran the interactive browser flow.
    pub browser_auth_performed: bool,
}
123
/// On-disk form of one token inside the token cache file.
///
/// Stores an absolute expiry instead of the relative `expires_in`, so validity
/// can be re-evaluated on later runs.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CachedAccessToken {
    /// Token value. The constructors reject empty values.
    access_token: String,
    /// Token type string as returned by the server.
    token_type: String,
    /// Refresh token, if one was issued.
    refresh_token: Option<String>,
    /// Granted scope string, if reported.
    scope: Option<String>,
    /// Absolute expiry in milliseconds since the Unix epoch. `None` when it could
    /// not be determined; `is_valid` then treats the token as never expiring.
    expires_at_unix_ms: Option<u64>,
}
132
/// JSON document persisted at the token cache path.
///
/// `authenticate_mcp` only reuses an entry whose `client_id` and `resource` match
/// the current configuration.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct TokenCacheFile {
    /// OAuth client id the tokens were issued to.
    client_id: String,
    /// Resource the gateway token was exchanged for.
    resource: String,
    /// User identity token.
    identity: CachedAccessToken,
    /// Resource-scoped gateway token.
    gateway: CachedAccessToken,
}
140
141impl CachedAccessToken {
142    fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
143        if token.access_token.is_empty() {
144            return Err(KontextDevError::EmptyAccessToken);
145        }
146
147        Ok(Self {
148            access_token: token.access_token.clone(),
149            token_type: token.token_type.clone(),
150            refresh_token: token.refresh_token.clone(),
151            scope: token.scope.clone(),
152            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
153        })
154    }
155
156    fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
157        if token.access_token.is_empty() {
158            return Err(KontextDevError::EmptyAccessToken);
159        }
160
161        Ok(Self {
162            access_token: token.access_token.clone(),
163            token_type: token.token_type.clone(),
164            refresh_token: token.refresh_token.clone(),
165            scope: token.scope.clone(),
166            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
167        })
168    }
169
170    fn is_valid(&self) -> bool {
171        match self.expires_at_unix_ms {
172            Some(expires_at) => {
173                let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
174                now_unix_ms().saturating_add(buffer_ms) < expires_at
175            }
176            None => true,
177        }
178    }
179
180    fn to_access_token(&self) -> AccessToken {
181        AccessToken {
182            access_token: self.access_token.clone(),
183            token_type: self.token_type.clone(),
184            expires_in: self
185                .expires_at_unix_ms
186                .and_then(unix_ms_to_relative_seconds),
187            refresh_token: self.refresh_token.clone(),
188            scope: self.scope.clone(),
189        }
190    }
191
192    fn to_token_exchange_token(&self) -> TokenExchangeToken {
193        TokenExchangeToken {
194            access_token: self.access_token.clone(),
195            issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
196            token_type: self.token_type.clone(),
197            expires_in: self
198                .expires_at_unix_ms
199                .and_then(unix_ms_to_relative_seconds),
200            scope: self.scope.clone(),
201            refresh_token: self.refresh_token.clone(),
202        }
203    }
204}
205
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Clamps to `0` if the system clock is before the epoch, and to `u64::MAX` if
/// the millisecond count does not fit in a `u64`.
fn now_unix_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}

/// Converts an absolute Unix-millisecond deadline into whole seconds remaining
/// from now (rounded down).
///
/// Deadlines at or before the current time yield `Some(0)`. The `i64`
/// conversion is checked, although `u64::MAX / 1000` always fits.
fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
    // Sample the clock once. Reading it a second time could observe the clock
    // moving past `unix_ms` between the comparison and the subtraction, which
    // would underflow.
    let remaining_ms = unix_ms.saturating_sub(now_unix_ms());
    i64::try_from(remaining_ms / 1000).ok()
}
222
223fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
224    if let Some(expires_in) = expires_in
225        && expires_in > 0
226    {
227        let expires_in_u64 = u64::try_from(expires_in).ok()?;
228        return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
229    }
230
231    decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
232}
233
234fn decode_jwt_exp(token: &str) -> Option<u64> {
235    let mut parts = token.split('.');
236    let _header = parts.next()?;
237    let payload = parts.next()?;
238
239    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
240        .decode(payload)
241        .ok()?;
242    let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
243    value.get("exp").and_then(|exp| {
244        exp.as_u64()
245            .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
246    })
247}
248
/// Loosely-typed error body of a failed HTTP response: the OAuth 2.0 `error` /
/// `error_description` pair plus a generic `message`. Every field is optional,
/// so any subset deserializes.
///
/// NOTE(review): not referenced in the visible code; presumably consumed by
/// `build_error_message` — confirm.
#[derive(Clone, Debug, Deserialize)]
struct OAuthErrorBody {
    error: Option<String>,
    error_description: Option<String>,
    message: Option<String>,
}
255
/// Subset of the OAuth authorization server metadata document read by
/// `discover_authorization_endpoint`. Only `authorization_endpoint` is used;
/// other members are ignored when deserializing.
#[derive(Clone, Debug, Deserialize)]
struct OAuthAuthorizationServerMetadata {
    authorization_endpoint: Option<String>,
}
260
/// Query parameters received on the OAuth redirect (built by `parse_callback`).
/// Each field is `None` when the parameter was absent.
#[derive(Clone, Debug, Deserialize)]
struct OAuthCallbackPayload {
    /// Authorization code to redeem at the token endpoint.
    code: Option<String>,
    /// `state` echoed back by the authorization server; compared with the sent one.
    state: Option<String>,
    /// OAuth error code when authorization failed.
    error: Option<String>,
    /// Human-readable detail accompanying `error`.
    error_description: Option<String>,
}
268
/// PKCE values for one authorization attempt.
#[derive(Clone, Debug)]
struct PkcePair {
    /// Secret sent as `code_verifier` when redeeming the authorization code.
    verifier: String,
    /// S256 challenge of `verifier`, sent as `code_challenge` in the authorization request.
    challenge: String,
}
274
275fn generate_pkce_pair() -> PkcePair {
276    let mut raw = [0u8; 64];
277    rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
278
279    let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
280    let digest = sha2::Sha256::digest(verifier.as_bytes());
281    let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
282
283    PkcePair {
284        verifier,
285        challenge,
286    }
287}
288
289fn generate_state() -> String {
290    let mut bytes = [0u8; 16];
291    rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
292    bytes.iter().map(|b| format!("{b:02x}")).collect()
293}
294
295struct CallbackServerGuard {
296    server: Arc<Server>,
297}
298
299impl Drop for CallbackServerGuard {
300    fn drop(&mut self) {
301        self.server.unblock();
302    }
303}
304
305fn parse_callback(url_path: &str, callback_path: &str) -> Option<OAuthCallbackPayload> {
306    let full = format!("http://localhost{url_path}");
307    let parsed = Url::parse(&full).ok()?;
308
309    if parsed.path() != callback_path {
310        return None;
311    }
312
313    let mut payload = OAuthCallbackPayload {
314        code: None,
315        state: None,
316        error: None,
317        error_description: None,
318    };
319
320    for (key, value) in parsed.query_pairs() {
321        match key.as_ref() {
322            "code" => payload.code = Some(value.to_string()),
323            "state" => payload.state = Some(value.to_string()),
324            "error" => payload.error = Some(value.to_string()),
325            "error_description" => payload.error_description = Some(value.to_string()),
326            _ => {}
327        }
328    }
329
330    Some(payload)
331}
332
333fn spawn_callback_server(
334    server: Arc<Server>,
335    callback_path: String,
336    tx: oneshot::Sender<OAuthCallbackPayload>,
337) -> tokio::task::JoinHandle<()> {
338    tokio::task::spawn_blocking(move || {
339        while let Ok(request) = server.recv() {
340            let path = request.url().to_string();
341            if let Some(payload) = parse_callback(&path, &callback_path) {
342                let response = Response::from_string(
343                    "Authentication complete. You can return to your terminal.",
344                );
345                let _ = request.respond(response);
346                let _ = tx.send(payload);
347                break;
348            }
349
350            let response = Response::from_string("Invalid callback").with_status_code(400);
351            let _ = request.respond(response);
352        }
353    })
354}
355
/// Client for the Kontext-Dev authentication and integration endpoints.
///
/// Holds the configuration plus one `reqwest` client reused for this client's
/// HTTP calls. A separate redirect-free client is built for the endpoint probe.
pub struct KontextDevClient {
    /// Endpoint, OAuth client, and token-cache settings.
    config: KontextDevConfig,
    /// Shared HTTP client.
    http: Client,
}
360
361impl KontextDevClient {
362    pub fn new(config: KontextDevConfig) -> Self {
363        Self {
364            config,
365            http: Client::new(),
366        }
367    }
368
    /// Returns the configuration this client was constructed with.
    pub fn config(&self) -> &KontextDevConfig {
        &self.config
    }
372
373    pub fn mcp_url(&self) -> Result<String, KontextDevError> {
374        resolve_mcp_url(&self.config).map_err(KontextDevError::from)
375    }
376
377    pub fn token_url(&self) -> Result<String, KontextDevError> {
378        resolve_token_url(&self.config).map_err(KontextDevError::from)
379    }
380
381    pub fn authorize_url(&self) -> Result<String, KontextDevError> {
382        resolve_authorize_url(&self.config).map_err(KontextDevError::from)
383    }
384
385    pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
386        resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
387    }
388
389    /// Authenticate a user via browser PKCE, exchange the identity token to a
390    /// resource-scoped MCP gateway token, and persist cache for future runs.
391    pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
392        let resource = self.config.resource.clone();
393
394        if let Some(cache) = self.read_cache()?
395            && cache.client_id == self.config.client_id
396            && cache.resource == resource
397            && cache.gateway.is_valid()
398            && cache.identity.is_valid()
399        {
400            return Ok(KontextAuthSession {
401                identity_token: cache.identity.to_access_token(),
402                gateway_token: cache.gateway.to_token_exchange_token(),
403                browser_auth_performed: false,
404            });
405        }
406
407        if let Some(cache) = self.read_cache()?
408            && cache.client_id == self.config.client_id
409            && cache.resource == resource
410        {
411            let maybe_refreshed = if cache.identity.is_valid() {
412                Some(cache.identity.to_access_token())
413            } else if let Some(refresh_token) = &cache.identity.refresh_token {
414                Some(self.refresh_identity_token(refresh_token).await?)
415            } else {
416                None
417            };
418
419            if let Some(identity_token) = maybe_refreshed {
420                let gateway_token = self
421                    .exchange_for_resource(&identity_token.access_token, &resource, None)
422                    .await?;
423                self.write_cache(&identity_token, &gateway_token)?;
424                return Ok(KontextAuthSession {
425                    identity_token,
426                    gateway_token,
427                    browser_auth_performed: false,
428                });
429            }
430        }
431
432        let identity_token = self.authorize_with_browser_pkce().await?;
433        let gateway_token = self
434            .exchange_for_resource(&identity_token.access_token, &resource, None)
435            .await?;
436
437        self.write_cache(&identity_token, &gateway_token)?;
438
439        Ok(KontextAuthSession {
440            identity_token,
441            gateway_token,
442            browser_auth_performed: true,
443        })
444    }
445
446    /// Request a short-lived connect session and build the hosted integration
447    /// connect URL (`.../oauth/connect?session=...`).
448    pub async fn create_integration_connect_url(
449        &self,
450        gateway_access_token: &str,
451    ) -> Result<String, KontextDevError> {
452        let session = self.create_connect_session(gateway_access_token).await?;
453        self.integration_connect_url(&session.session_id)
454    }
455
456    pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
457        if session_id.trim().is_empty() {
458            return Err(KontextDevError::ConnectSession {
459                message: "connect session id is empty".to_string(),
460            });
461        }
462
463        let base = if let Some(explicit) = &self.config.integration_ui_url {
464            explicit.trim_end_matches('/').to_string()
465        } else {
466            let server = resolve_server_base_url(&self.config)?;
467            if server.contains("api.kontext.dev") {
468                "https://app.kontext.dev".to_string()
469            } else {
470                server
471            }
472        };
473
474        let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
475            url: base.clone(),
476            source,
477        })?;
478        url.set_path("/oauth/connect");
479        url.query_pairs_mut().append_pair("session", session_id);
480        Ok(url.to_string())
481    }
482
483    /// Opens the integration connect page in the browser.
484    pub async fn open_integration_connect_page(
485        &self,
486        gateway_access_token: &str,
487    ) -> Result<String, KontextDevError> {
488        let url = self
489            .create_integration_connect_url(gateway_access_token)
490            .await?;
491
492        if webbrowser::open(&url).is_err() {
493            return Err(KontextDevError::BrowserOpenFailed);
494        }
495
496        Ok(url)
497    }
498
499    pub async fn create_connect_session(
500        &self,
501        gateway_access_token: &str,
502    ) -> Result<ConnectSession, KontextDevError> {
503        let url = self.connect_session_url()?;
504        let response = self
505            .http
506            .post(&url)
507            .header("Authorization", format!("Bearer {gateway_access_token}"))
508            .send()
509            .await
510            .map_err(|err| KontextDevError::ConnectSession {
511                message: err.to_string(),
512            })?;
513
514        if !response.status().is_success() {
515            let message = build_error_message(response).await;
516            return Err(KontextDevError::ConnectSession { message });
517        }
518
519        response
520            .json::<ConnectSession>()
521            .await
522            .map_err(|err| KontextDevError::ConnectSession {
523                message: err.to_string(),
524            })
525    }
526
527    pub async fn initiate_integration_oauth(
528        &self,
529        gateway_access_token: &str,
530        integration_id: &str,
531        return_to: Option<&str>,
532    ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
533        let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
534
535        let mut request = self
536            .http
537            .post(&url)
538            .header("Authorization", format!("Bearer {gateway_access_token}"));
539
540        if let Some(return_to) = return_to {
541            request = request.json(&serde_json::json!({ "returnTo": return_to }));
542        }
543
544        let response =
545            request
546                .send()
547                .await
548                .map_err(|err| KontextDevError::IntegrationOAuthInit {
549                    message: err.to_string(),
550                })?;
551
552        if !response.status().is_success() {
553            let message = build_error_message(response).await;
554            return Err(KontextDevError::IntegrationOAuthInit { message });
555        }
556
557        response
558            .json::<IntegrationOAuthInitResponse>()
559            .await
560            .map_err(|err| KontextDevError::IntegrationOAuthInit {
561                message: err.to_string(),
562            })
563    }
564
565    pub async fn integration_connection_status(
566        &self,
567        gateway_access_token: &str,
568        integration_id: &str,
569    ) -> Result<IntegrationConnectionStatus, KontextDevError> {
570        let url = resolve_integration_connection_url(&self.config, integration_id)?;
571        let response = self
572            .http
573            .get(url)
574            .header("Authorization", format!("Bearer {gateway_access_token}"))
575            .send()
576            .await
577            .map_err(|err| KontextDevError::IntegrationOAuthInit {
578                message: err.to_string(),
579            })?;
580
581        if !response.status().is_success() {
582            let message = build_error_message(response).await;
583            return Err(KontextDevError::IntegrationOAuthInit { message });
584        }
585
586        response
587            .json::<IntegrationConnectionStatus>()
588            .await
589            .map_err(|err| KontextDevError::IntegrationOAuthInit {
590                message: err.to_string(),
591            })
592    }
593
594    pub async fn wait_for_integration_connection(
595        &self,
596        gateway_access_token: &str,
597        integration_id: &str,
598        timeout_ms: u64,
599        interval_ms: u64,
600    ) -> Result<bool, KontextDevError> {
601        let started = now_unix_ms();
602
603        loop {
604            let status = self
605                .integration_connection_status(gateway_access_token, integration_id)
606                .await?;
607            if status.connected {
608                return Ok(true);
609            }
610
611            if now_unix_ms().saturating_sub(started) >= timeout_ms {
612                return Ok(false);
613            }
614
615            tokio::time::sleep(Duration::from_millis(interval_ms)).await;
616        }
617    }
618
619    async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
620        let auth_url = self.resolve_authorization_url().await?;
621        let token_url = self.token_url()?;
622        let pkce = generate_pkce_pair();
623        let state = generate_state();
624
625        let (callback_url, callback_payload) = self.listen_for_callback().await?;
626
627        let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
628            url: auth_url.clone(),
629            source,
630        })?;
631
632        url.query_pairs_mut()
633            .append_pair("client_id", &self.config.client_id)
634            .append_pair("response_type", "code")
635            .append_pair("redirect_uri", &callback_url)
636            .append_pair("state", &state)
637            .append_pair("scope", &self.config.scope)
638            .append_pair("code_challenge_method", "S256")
639            .append_pair("code_challenge", &pkce.challenge);
640
641        if webbrowser::open(url.as_str()).is_err() {
642            return Err(KontextDevError::BrowserOpenFailed);
643        }
644
645        let payload = callback_payload.await?;
646
647        if let Some(error) = payload.error {
648            let with_details = payload
649                .error_description
650                .map(|description| format!("{error}: {description}"))
651                .unwrap_or(error);
652            return Err(KontextDevError::OAuthCallbackError {
653                error: with_details,
654            });
655        }
656
657        if payload.state.as_deref() != Some(state.as_str()) {
658            return Err(KontextDevError::InvalidOAuthState);
659        }
660
661        let code = payload
662            .code
663            .ok_or(KontextDevError::MissingAuthorizationCode)?;
664
665        let mut body = vec![
666            ("grant_type", "authorization_code".to_string()),
667            ("code", code),
668            ("redirect_uri", callback_url),
669            ("client_id", self.config.client_id.clone()),
670            ("code_verifier", pkce.verifier),
671        ];
672
673        if let Some(client_secret) = &self.config.client_secret {
674            body.push(("client_secret", client_secret.clone()));
675        }
676
677        post_token(&self.http, &token_url, None, &body).await
678    }
679
680    async fn resolve_authorization_url(&self) -> Result<String, KontextDevError> {
681        if let Some(discovered) = self.discover_authorization_endpoint().await {
682            return Ok(discovered);
683        }
684
685        let authorize_url = self.authorize_url()?;
686        if !self.endpoint_is_missing(&authorize_url).await {
687            return Ok(authorize_url);
688        }
689
690        let server_base = resolve_server_base_url(&self.config)?;
691        Ok(format!("{}/oauth2/auth", server_base.trim_end_matches('/')))
692    }
693
694    async fn discover_authorization_endpoint(&self) -> Option<String> {
695        let base = resolve_server_base_url(&self.config).ok()?;
696        let base = base.trim_end_matches('/');
697
698        let candidates = [
699            format!("{base}/.well-known/oauth-authorization-server/mcp"),
700            format!("{base}/.well-known/oauth-authorization-server"),
701        ];
702
703        for url in candidates {
704            let response = match self.http.get(&url).send().await {
705                Ok(response) => response,
706                Err(_) => continue,
707            };
708
709            if !response.status().is_success() {
710                continue;
711            }
712
713            let metadata = match response.json::<OAuthAuthorizationServerMetadata>().await {
714                Ok(metadata) => metadata,
715                Err(_) => continue,
716            };
717
718            let Some(endpoint) = metadata.authorization_endpoint else {
719                continue;
720            };
721
722            if Url::parse(&endpoint).is_ok() {
723                return Some(endpoint);
724            }
725        }
726
727        None
728    }
729
730    async fn endpoint_is_missing(&self, url: &str) -> bool {
731        let probe_client = match reqwest::Client::builder()
732            .redirect(reqwest::redirect::Policy::none())
733            .build()
734        {
735            Ok(client) => client,
736            Err(_) => return false,
737        };
738
739        match probe_client.get(url).send().await {
740            Ok(response) => response.status() == StatusCode::NOT_FOUND,
741            Err(_) => false,
742        }
743    }
744
745    async fn refresh_identity_token(
746        &self,
747        refresh_token: &str,
748    ) -> Result<AccessToken, KontextDevError> {
749        let token_url = self.token_url()?;
750        let mut body = vec![
751            ("grant_type", "refresh_token".to_string()),
752            ("refresh_token", refresh_token.to_string()),
753            ("client_id", self.config.client_id.clone()),
754        ];
755
756        if let Some(client_secret) = &self.config.client_secret {
757            body.push(("client_secret", client_secret.clone()));
758        }
759
760        post_token(&self.http, &token_url, None, &body).await
761    }
762
763    pub async fn exchange_for_resource(
764        &self,
765        subject_token: &str,
766        resource: &str,
767        scope: Option<&str>,
768    ) -> Result<TokenExchangeToken, KontextDevError> {
769        let token_url = self.token_url()?;
770
771        let mut body = vec![
772            ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
773            ("subject_token", subject_token.to_string()),
774            ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
775            ("resource", resource.to_string()),
776        ];
777
778        if let Some(scope) = scope {
779            body.push(("scope", scope.to_string()));
780        }
781
782        let auth_header = self.config.client_secret.as_ref().map(|secret| {
783            let raw = format!("{}:{}", self.config.client_id, secret);
784            format!(
785                "Basic {}",
786                base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
787            )
788        });
789
790        if self.config.client_secret.is_none() {
791            body.push(("client_id", self.config.client_id.clone()));
792        }
793
794        let response = post_form_with_optional_auth::<TokenExchangeToken>(
795            &self.http,
796            &token_url,
797            auth_header,
798            &body,
799        )
800        .await
801        .map_err(|message| KontextDevError::TokenExchange {
802            resource: resource.to_string(),
803            message,
804        })?;
805
806        if response.access_token.is_empty() {
807            return Err(KontextDevError::EmptyAccessToken);
808        }
809
810        Ok(response)
811    }
812
813    async fn listen_for_callback(
814        &self,
815    ) -> Result<
816        (
817            String,
818            impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
819        ),
820        KontextDevError,
821    > {
822        let redirect_uri = self.config.redirect_uri.trim().to_string();
823        let parsed = Url::parse(&redirect_uri).map_err(|source| KontextDevError::InvalidUrl {
824            url: redirect_uri.clone(),
825            source,
826        })?;
827
828        if parsed.scheme() != "http" {
829            return Err(KontextDevError::OAuthCallbackError {
830                error: "redirect_uri must use http".to_string(),
831            });
832        }
833
834        if parsed.query().is_some() || parsed.fragment().is_some() {
835            return Err(KontextDevError::OAuthCallbackError {
836                error: "redirect_uri must not include query parameters or fragments".to_string(),
837            });
838        }
839
840        let host = parsed
841            .host_str()
842            .ok_or_else(|| KontextDevError::OAuthCallbackError {
843                error: "redirect_uri host is missing".to_string(),
844            })?;
845
846        let port = parsed
847            .port()
848            .ok_or_else(|| KontextDevError::OAuthCallbackError {
849                error: "redirect_uri must include an explicit port".to_string(),
850            })?;
851
852        let callback_path = parsed.path().to_string();
853        let bind_addr = if host.contains(':') {
854            format!("[{host}]:{port}")
855        } else {
856            format!("{host}:{port}")
857        };
858
859        let server = Arc::new(Server::http(&bind_addr).map_err(|err| {
860            KontextDevError::OAuthCallbackError {
861                error: format!("failed to start callback server at {bind_addr}: {err}"),
862            }
863        })?);
864
865        let (tx, rx) = oneshot::channel();
866        let _join = spawn_callback_server(server.clone(), callback_path, tx);
867        let _guard = CallbackServerGuard { server };
868        let timeout_seconds = self.config.auth_timeout_seconds.max(1);
869
870        let fut = async move {
871            let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
872                .await
873                .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
874                .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
875            drop(_guard);
876            Ok(payload)
877        };
878
879        Ok((redirect_uri, fut))
880    }
881
882    fn token_cache_path(&self) -> Option<PathBuf> {
883        if let Some(explicit) = &self.config.token_cache_path {
884            return Some(PathBuf::from(explicit));
885        }
886
887        let home = dirs::home_dir()?;
888        let mut path = home;
889        path.push(".kontext-dev");
890        path.push("tokens");
891
892        let sanitized_client_id: String = self
893            .config
894            .client_id
895            .chars()
896            .map(|ch| {
897                if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
898                    ch
899                } else {
900                    '_'
901                }
902            })
903            .collect();
904
905        path.push(format!("{sanitized_client_id}.json"));
906        Some(path)
907    }
908
909    fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
910        let Some(path) = self.token_cache_path() else {
911            return Ok(None);
912        };
913
914        let raw = match fs::read_to_string(&path) {
915            Ok(raw) => raw,
916            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
917            Err(source) => {
918                return Err(KontextDevError::TokenCacheRead {
919                    path: path.display().to_string(),
920                    source,
921                });
922            }
923        };
924
925        serde_json::from_str(&raw).map(Some).map_err(|source| {
926            KontextDevError::TokenCacheDeserialize {
927                path: path.display().to_string(),
928                source,
929            }
930        })
931    }
932
933    fn write_cache(
934        &self,
935        identity: &AccessToken,
936        gateway: &TokenExchangeToken,
937    ) -> Result<(), KontextDevError> {
938        let Some(path) = self.token_cache_path() else {
939            return Ok(());
940        };
941
942        if let Some(parent) = path.parent() {
943            fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
944                path: parent.display().to_string(),
945                source,
946            })?;
947        }
948
949        let payload = TokenCacheFile {
950            client_id: self.config.client_id.clone(),
951            resource: self.config.resource.clone(),
952            identity: CachedAccessToken::from_access_token(identity)?,
953            gateway: CachedAccessToken::from_token_exchange(gateway)?,
954        };
955
956        let serialized = serde_json::to_string_pretty(&payload)
957            .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
958
959        fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
960            path: path.display().to_string(),
961            source,
962        })
963    }
964}
965
966/// Legacy helper that uses the client-credentials grant.
967///
968/// New PKCE-based apps should use `KontextDevClient::authenticate_mcp`.
969pub async fn request_access_token(
970    config: &KontextDevConfig,
971) -> Result<AccessToken, KontextDevError> {
972    let token_url = resolve_token_url(config)?;
973
974    let mut body = vec![
975        ("grant_type", "client_credentials".to_string()),
976        ("scope", config.scope.clone()),
977    ];
978
979    let auth_header = if let Some(secret) = &config.client_secret {
980        let raw = format!("{}:{}", config.client_id, secret);
981        Some(format!(
982            "Basic {}",
983            base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
984        ))
985    } else {
986        body.push(("client_id", config.client_id.clone()));
987        None
988    };
989
990    post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
991}
992
993async fn post_token(
994    http: &Client,
995    token_url: &str,
996    authorization: Option<&str>,
997    body: &[(impl AsRef<str>, String)],
998) -> Result<AccessToken, KontextDevError> {
999    let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
1000
1001    let response = post_form_with_optional_auth::<AccessToken>(
1002        http,
1003        token_url,
1004        authorization.map(ToString::to_string),
1005        &body_vec,
1006    )
1007    .await
1008    .map_err(|message| KontextDevError::TokenRequest {
1009        token_url: token_url.to_string(),
1010        message,
1011    })?;
1012
1013    Ok(response)
1014}
1015
1016async fn post_form_with_optional_auth<T>(
1017    http: &Client,
1018    url: &str,
1019    authorization: Option<String>,
1020    body: &[(&str, String)],
1021) -> Result<T, String>
1022where
1023    T: DeserializeOwned,
1024{
1025    let mut request = http
1026        .post(url)
1027        .header("Content-Type", "application/x-www-form-urlencoded");
1028
1029    if let Some(header) = authorization {
1030        request = request.header("Authorization", header);
1031    }
1032
1033    let form = body
1034        .iter()
1035        .map(|(k, v)| (k.to_string(), v.to_string()))
1036        .collect::<Vec<(String, String)>>();
1037
1038    let response = request
1039        .form(&form)
1040        .send()
1041        .await
1042        .map_err(|err| err.to_string())?;
1043
1044    if !response.status().is_success() {
1045        return Err(build_error_message(response).await);
1046    }
1047
1048    response.json::<T>().await.map_err(|err| err.to_string())
1049}
1050
1051async fn build_error_message(response: reqwest::Response) -> String {
1052    let status = response.status();
1053    let fallback = format!(
1054        "{} {}",
1055        status.as_u16(),
1056        status.canonical_reason().unwrap_or("")
1057    );
1058
1059    let body = response.text().await.unwrap_or_default();
1060    if body.is_empty() {
1061        return fallback.trim().to_string();
1062    }
1063
1064    if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
1065        if let Some(description) = parsed.error_description {
1066            return description;
1067        }
1068        if let Some(message) = parsed.message {
1069            return message;
1070        }
1071        if let Some(error) = parsed.error {
1072            return error;
1073        }
1074    }
1075
1076    format!("{fallback}: {body}")
1077}
1078
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture with these settings:
    /// - a public (secret-less) client pointed at a local server;
    /// - the hosted integration UI;
    /// - library defaults for scope, server name, resource and auth timeout.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            client_id: "client_123".to_string(),
            client_secret: None,
            server: Some("http://localhost:4000".to_string()),
            mcp_url: None,
            token_url: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        }
    }

    /// The connect URL points at the hosted UI's `/oauth/connect` page and
    /// carries the session id in the `session` query parameter.
    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let client = KontextDevClient::new(config());
        let expected = "https://app.kontext.dev/oauth/connect?session=session-123";

        let actual = client
            .integration_connect_url("session-123")
            .expect("url should be built");

        assert_eq!(actual, expected);
    }

    /// `decode_jwt_exp` extracts the numeric `exp` claim from the payload
    /// segment of a three-part token.
    #[test]
    fn jwt_exp_decode_reads_exp() {
        let segment =
            |json: &str| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json);
        let token = [
            segment(r#"{"alg":"none"}"#),
            segment(r#"{"exp":4070908800}"#),
            String::from("sig"),
        ]
        .join(".");

        assert_eq!(decode_jwt_exp(&token), Some(4_070_908_800));
    }

    /// Authorization-server metadata exposes `authorization_endpoint` when the
    /// discovery document contains it.
    #[test]
    fn oauth_metadata_parses_authorization_endpoint() {
        let document = r#"{
              "issuer": "https://issuer.example.com",
              "authorization_endpoint": "https://issuer.example.com/oauth2/auth"
            }"#;

        let metadata: OAuthAuthorizationServerMetadata =
            serde_json::from_str(document).expect("metadata should parse");

        assert_eq!(
            metadata.authorization_endpoint.as_deref(),
            Some("https://issuer.example.com/oauth2/auth")
        );
    }
}