Skip to main content

kontext_dev_sdk/
lib.rs

1use base64::Engine;
2use reqwest::Client;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use sha2::Digest;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14use tiny_http::Response;
15use tiny_http::Server;
16use tokio::sync::oneshot;
17use tokio::time::timeout;
18use url::Url;
19
20pub mod client;
21pub mod errors;
22mod http_client;
23pub mod management;
24pub mod mcp;
25pub mod oauth;
26pub mod orchestrator;
27pub mod server;
28pub mod verify;
29
30pub use client::ClientState;
31pub use client::ConnectSessionResult;
32pub use client::IntegrationInfo;
33pub use client::KontextClient as SdkKontextClient;
34pub use client::KontextClientConfig;
35pub use client::ToolResult;
36pub use client::create_kontext_client;
37pub use errors::*;
38pub use management::KontextManagementClient;
39pub use management::KontextManagementClientConfig;
40pub use management::*;
41pub use mcp::KontextMcp;
42pub use mcp::KontextMcpConfig;
43pub use mcp::RuntimeIntegrationCategory;
44pub use mcp::RuntimeIntegrationConnectType;
45pub use mcp::RuntimeIntegrationRecord;
46pub use oauth::KontextOAuthProvider;
47pub use oauth::KontextOAuthProviderConfig;
48pub use oauth::ParsedOAuthCallback;
49pub use oauth::TokenExchangeConfig;
50pub use oauth::exchange_token;
51pub use oauth::parse_oauth_callback;
52pub use orchestrator::KontextOrchestrator;
53pub use orchestrator::KontextOrchestratorConfig;
54pub use orchestrator::KontextOrchestratorState;
55pub use orchestrator::create_kontext_orchestrator;
56pub use server::IntegrationCredential;
57pub use server::IntegrationName;
58pub use server::IntegrationResolvedCredentials;
59pub use server::KnownIntegration;
60pub use server::Kontext;
61pub use server::KontextOptions;
62pub use server::MiddlewareOptions;
63pub use verify::JwksClient;
64pub use verify::KontextTokenVerifier;
65pub use verify::KontextTokenVerifierConfig;
66pub use verify::TokenVerificationError;
67pub use verify::TokenVerificationErrorCode;
68pub use verify::VerifiedTokenClaims;
69pub use verify::VerifyResult;
70
71pub use kontext_dev_sdk_core::AccessToken;
72pub use kontext_dev_sdk_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
73pub use kontext_dev_sdk_core::DEFAULT_RESOURCE;
74pub use kontext_dev_sdk_core::DEFAULT_SCOPE;
75pub use kontext_dev_sdk_core::DEFAULT_SERVER;
76pub use kontext_dev_sdk_core::DEFAULT_SERVER_NAME;
77pub use kontext_dev_sdk_core::KontextDevConfig;
78pub use kontext_dev_sdk_core::KontextDevCoreError;
79pub use kontext_dev_sdk_core::TokenExchangeToken;
80pub use kontext_dev_sdk_core::normalize_server_url;
81pub use kontext_dev_sdk_core::resolve_authorize_url;
82pub use kontext_dev_sdk_core::resolve_connect_session_url;
83pub use kontext_dev_sdk_core::resolve_integration_connection_url;
84pub use kontext_dev_sdk_core::resolve_integration_oauth_init_url;
85pub use kontext_dev_sdk_core::resolve_mcp_url;
86pub use kontext_dev_sdk_core::resolve_server_base_url;
87pub use kontext_dev_sdk_core::resolve_token_url;
88
/// OAuth 2.0 Token Exchange grant type (RFC 8693), used to trade the identity
/// token for a resource-scoped gateway token.
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// RFC 8693 token-type identifier for access tokens. Sent as `subject_token_type`
/// and reported as `issued_token_type` for gateway tokens restored from the cache.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
/// Safety margin: cached tokens expiring within this many seconds are treated as
/// already expired, so they are not handed out right before they lapse.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
92
/// Errors returned by [`KontextDevClient`] and the helpers in this module.
///
/// Fields named `source` are picked up by `thiserror` as the underlying error
/// source.
#[derive(Debug, thiserror::Error)]
pub enum KontextDevError {
    /// Error propagated unchanged from `kontext_dev_sdk_core` (e.g. endpoint URL
    /// resolution).
    #[error(transparent)]
    Core(#[from] kontext_dev_sdk_core::KontextDevCoreError),
    /// A URL from configuration or discovery could not be parsed.
    #[error("failed to parse URL `{url}`")]
    InvalidUrl {
        url: String,
        source: url::ParseError,
    },
    /// The system browser could not be opened for the authorization or connect page.
    #[error("failed to open browser for OAuth authorization")]
    BrowserOpenFailed,
    /// Intended for when no OAuth redirect reaches the local callback server in time.
    #[error("OAuth callback timed out after {timeout_seconds}s")]
    OAuthCallbackTimeout { timeout_seconds: i64 },
    /// Intended for when the callback channel closes before delivering a payload.
    #[error("OAuth callback channel was unexpectedly cancelled")]
    OAuthCallbackCancelled,
    /// The OAuth redirect carried neither an error nor a `code` parameter.
    #[error("OAuth callback is missing the authorization code")]
    MissingAuthorizationCode,
    /// The OAuth redirect reported an error. Also used for an unusable
    /// `redirect_uri` or a callback server that fails to start.
    #[error("OAuth callback returned an error: {error}")]
    OAuthCallbackError { error: String },
    /// The `state` returned on the redirect did not match the one that was sent.
    #[error("OAuth state mismatch")]
    InvalidOAuthState,
    /// A request to the token endpoint failed.
    /// NOTE(review): presumably produced by `post_token` — confirm.
    #[error("Kontext-Dev token request failed for {token_url}: {message}")]
    TokenRequest { token_url: String, message: String },
    /// RFC 8693 token exchange for `resource` failed.
    #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
    TokenExchange { resource: String, message: String },
    /// Creating a connect session failed, or the session id was empty.
    #[error("Kontext-Dev connect session request failed: {message}")]
    ConnectSession { message: String },
    /// Integration OAuth init failed. Also returned by the integration
    /// connection-status check.
    #[error("Kontext-Dev integration OAuth init failed: {message}")]
    IntegrationOAuthInit { message: String },
    /// The token cache file could not be read.
    #[error("failed to read token cache at `{path}`: {source}")]
    TokenCacheRead {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file could not be written or removed.
    #[error("failed to write token cache at `{path}`: {source}")]
    TokenCacheWrite {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file exists but could not be decoded.
    #[error("failed to deserialize token cache at `{path}`: {source}")]
    TokenCacheDeserialize {
        path: String,
        source: serde_json::Error,
    },
    /// The token cache could not be encoded for writing.
    #[error("failed to serialize token cache: {source}")]
    TokenCacheSerialize { source: serde_json::Error },
    /// A token response or cache entry contained an empty access token.
    #[error("Kontext-Dev access token is empty")]
    EmptyAccessToken,
    /// An integration UI URL is required but `integration_ui_url` is not configured.
    #[error("missing integration UI URL; set `integration_ui_url` in config")]
    MissingIntegrationUiUrl,
}
144
145#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
146pub struct ConnectSession {
147    #[serde(rename = "sessionId")]
148    pub session_id: String,
149    #[serde(rename = "expiresAt")]
150    pub expires_at: String,
151}
152
153#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
154pub struct IntegrationOAuthInitResponse {
155    #[serde(rename = "authorizationUrl")]
156    pub authorization_url: String,
157    #[serde(default)]
158    pub state: Option<String>,
159}
160
161#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
162pub struct IntegrationConnectionStatus {
163    pub connected: bool,
164    #[serde(default)]
165    pub expires_at: Option<String>,
166}
167
/// Result of [`KontextDevClient::authenticate_mcp`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct KontextAuthSession {
    /// User identity token, obtained from the browser PKCE flow, a refresh, or the cache.
    pub identity_token: AccessToken,
    /// Resource-scoped gateway token, exchanged from `identity_token` or restored
    /// from the cache.
    pub gateway_token: TokenExchangeToken,
    /// `true` only when this call had to run the interactive browser flow.
    pub browser_auth_performed: bool,
}
174
/// On-disk form of a token. It mirrors [`AccessToken`] / [`TokenExchangeToken`],
/// with the relative `expires_in` replaced by an absolute deadline.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CachedAccessToken {
    access_token: String,
    token_type: String,
    refresh_token: Option<String>,
    scope: Option<String>,
    // Absolute expiry in milliseconds since the Unix epoch. `None` when neither
    // `expires_in` nor a JWT `exp` claim was available; `is_valid` treats that
    // as "never expires".
    expires_at_unix_ms: Option<u64>,
}
183
/// Top-level token cache document: one identity token plus the gateway token
/// exchanged from it for a single `resource`.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct TokenCacheFile {
    // The cache is reused only when both `client_id` and `resource` match the
    // current configuration.
    client_id: String,
    resource: String,
    identity: CachedAccessToken,
    gateway: CachedAccessToken,
}
191
192impl CachedAccessToken {
193    fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
194        if token.access_token.is_empty() {
195            return Err(KontextDevError::EmptyAccessToken);
196        }
197
198        Ok(Self {
199            access_token: token.access_token.clone(),
200            token_type: token.token_type.clone(),
201            refresh_token: token.refresh_token.clone(),
202            scope: token.scope.clone(),
203            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
204        })
205    }
206
207    fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
208        if token.access_token.is_empty() {
209            return Err(KontextDevError::EmptyAccessToken);
210        }
211
212        Ok(Self {
213            access_token: token.access_token.clone(),
214            token_type: token.token_type.clone(),
215            refresh_token: token.refresh_token.clone(),
216            scope: token.scope.clone(),
217            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
218        })
219    }
220
221    fn is_valid(&self) -> bool {
222        match self.expires_at_unix_ms {
223            Some(expires_at) => {
224                let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
225                now_unix_ms().saturating_add(buffer_ms) < expires_at
226            }
227            None => true,
228        }
229    }
230
231    fn to_access_token(&self) -> AccessToken {
232        AccessToken {
233            access_token: self.access_token.clone(),
234            token_type: self.token_type.clone(),
235            expires_in: self
236                .expires_at_unix_ms
237                .and_then(unix_ms_to_relative_seconds),
238            refresh_token: self.refresh_token.clone(),
239            scope: self.scope.clone(),
240        }
241    }
242
243    fn to_token_exchange_token(&self) -> TokenExchangeToken {
244        TokenExchangeToken {
245            access_token: self.access_token.clone(),
246            issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
247            token_type: self.token_type.clone(),
248            expires_in: self
249                .expires_at_unix_ms
250                .and_then(unix_ms_to_relative_seconds),
251            scope: self.scope.clone(),
252            refresh_token: self.refresh_token.clone(),
253        }
254    }
255}
256
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. The `u128`
/// millisecond count saturates at `u64::MAX` instead of being silently truncated
/// by an `as` cast.
fn now_unix_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
263
264fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
265    if unix_ms <= now_unix_ms() {
266        return Some(0);
267    }
268
269    let delta_ms = unix_ms - now_unix_ms();
270    let secs = delta_ms / 1000;
271    i64::try_from(secs).ok()
272}
273
274fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
275    if let Some(expires_in) = expires_in
276        && expires_in > 0
277    {
278        let expires_in_u64 = u64::try_from(expires_in).ok()?;
279        return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
280    }
281
282    decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
283}
284
285fn decode_jwt_exp(token: &str) -> Option<u64> {
286    let mut parts = token.split('.');
287    let _header = parts.next()?;
288    let payload = parts.next()?;
289
290    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
291        .decode(payload)
292        .ok()?;
293    let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
294    value.get("exp").and_then(|exp| {
295        exp.as_u64()
296            .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
297    })
298}
299
// Optional error fields of a failed HTTP response body. `error` and
// `error_description` follow the OAuth 2.0 error-response shape; `message` is a
// generic fallback. NOTE(review): presumably consumed by `build_error_message` —
// confirm.
#[derive(Clone, Debug, Deserialize)]
struct OAuthErrorBody {
    error: Option<String>,
    error_description: Option<String>,
    message: Option<String>,
}
306
// Subset of RFC 8414 authorization-server metadata. Only the authorization
// endpoint is needed for discovery; other fields in the document are ignored.
#[derive(Clone, Debug, Deserialize)]
struct OAuthAuthorizationServerMetadata {
    authorization_endpoint: Option<String>,
}
311
// Query parameters of the OAuth redirect received by the local callback server.
// Each field is `None` when that parameter is absent.
#[derive(Clone, Debug, Deserialize)]
struct OAuthCallbackPayload {
    code: Option<String>,
    state: Option<String>,
    error: Option<String>,
    error_description: Option<String>,
}
319
// PKCE (RFC 7636) values for one authorization attempt. `verifier` is sent on the
// token request; `challenge` (base64url of SHA-256(verifier)) is sent on the
// authorization request with method `S256`.
#[derive(Clone, Debug)]
struct PkcePair {
    verifier: String,
    challenge: String,
}
325
326fn generate_pkce_pair() -> PkcePair {
327    let mut raw = [0u8; 64];
328    rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
329
330    let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
331    let digest = sha2::Sha256::digest(verifier.as_bytes());
332    let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
333
334    PkcePair {
335        verifier,
336        challenge,
337    }
338}
339
340fn generate_state() -> String {
341    let mut bytes = [0u8; 16];
342    rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
343    bytes.iter().map(|b| format!("{b:02x}")).collect()
344}
345
/// Trim surrounding whitespace from a configured scope string.
///
/// Returns `None` for empty or whitespace-only input, so the caller can omit the
/// `scope` parameter entirely.
fn normalized_scope(scope: &str) -> Option<&str> {
    Some(scope.trim()).filter(|trimmed| !trimmed.is_empty())
}
354
355struct CallbackServerGuard {
356    server: Arc<Server>,
357}
358
359impl Drop for CallbackServerGuard {
360    fn drop(&mut self) {
361        self.server.unblock();
362    }
363}
364
365fn parse_callback(url_path: &str, callback_path: &str) -> Option<OAuthCallbackPayload> {
366    let full = format!("http://localhost{url_path}");
367    let parsed = Url::parse(&full).ok()?;
368
369    if parsed.path() != callback_path {
370        return None;
371    }
372
373    let mut payload = OAuthCallbackPayload {
374        code: None,
375        state: None,
376        error: None,
377        error_description: None,
378    };
379
380    for (key, value) in parsed.query_pairs() {
381        match key.as_ref() {
382            "code" => payload.code = Some(value.to_string()),
383            "state" => payload.state = Some(value.to_string()),
384            "error" => payload.error = Some(value.to_string()),
385            "error_description" => payload.error_description = Some(value.to_string()),
386            _ => {}
387        }
388    }
389
390    Some(payload)
391}
392
393fn spawn_callback_server(
394    server: Arc<Server>,
395    callback_path: String,
396    tx: oneshot::Sender<OAuthCallbackPayload>,
397) -> tokio::task::JoinHandle<()> {
398    tokio::task::spawn_blocking(move || {
399        while let Ok(request) = server.recv() {
400            let path = request.url().to_string();
401            if let Some(payload) = parse_callback(&path, &callback_path) {
402                let response = Response::from_string(
403                    "Authentication complete. You can return to your terminal.",
404                );
405                let _ = request.respond(response);
406                let _ = tx.send(payload);
407                break;
408            }
409
410            let response = Response::from_string("Invalid callback").with_status_code(400);
411            let _ = request.respond(response);
412        }
413    })
414}
415
/// Client for the Kontext-Dev OAuth, token-exchange and integration endpoints
/// described by a [`KontextDevConfig`].
///
/// The `reqwest::Client` handle is internally reference-counted, so clones of
/// this client share its connection pool.
#[derive(Clone, Debug)]
pub struct KontextDevClient {
    // Source of every endpoint URL and of the client credentials.
    config: KontextDevConfig,
    // Shared HTTP client. `endpoint_is_missing` builds its own redirect-less
    // client instead.
    http: Client,
}
421
422impl KontextDevClient {
423    pub fn new(config: KontextDevConfig) -> Self {
424        Self {
425            config,
426            http: Client::new(),
427        }
428    }
429
430    pub fn config(&self) -> &KontextDevConfig {
431        &self.config
432    }
433
434    pub fn server_base_url(&self) -> Result<String, KontextDevError> {
435        resolve_server_base_url(&self.config).map_err(KontextDevError::from)
436    }
437
438    pub fn mcp_url(&self) -> Result<String, KontextDevError> {
439        resolve_mcp_url(&self.config).map_err(KontextDevError::from)
440    }
441
442    pub fn token_url(&self) -> Result<String, KontextDevError> {
443        resolve_token_url(&self.config).map_err(KontextDevError::from)
444    }
445
446    pub fn authorize_url(&self) -> Result<String, KontextDevError> {
447        resolve_authorize_url(&self.config).map_err(KontextDevError::from)
448    }
449
450    pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
451        resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
452    }
453
454    pub fn clear_token_cache(&self) -> Result<(), KontextDevError> {
455        let Some(path) = self.token_cache_path() else {
456            return Ok(());
457        };
458
459        if !path.exists() {
460            return Ok(());
461        }
462
463        fs::remove_file(&path).map_err(|source| KontextDevError::TokenCacheWrite {
464            path: path.display().to_string(),
465            source,
466        })
467    }
468
469    /// Authenticate a user via browser PKCE, exchange the identity token to a
470    /// resource-scoped MCP gateway token, and persist cache for future runs.
471    pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
472        let resource = self.config.resource.clone();
473
474        if let Some(cache) = self.read_cache()?
475            && cache.client_id == self.config.client_id
476            && cache.resource == resource
477            && cache.gateway.is_valid()
478            && cache.identity.is_valid()
479        {
480            return Ok(KontextAuthSession {
481                identity_token: cache.identity.to_access_token(),
482                gateway_token: cache.gateway.to_token_exchange_token(),
483                browser_auth_performed: false,
484            });
485        }
486
487        if let Some(cache) = self.read_cache()?
488            && cache.client_id == self.config.client_id
489            && cache.resource == resource
490        {
491            let maybe_refreshed = if cache.identity.is_valid() {
492                Some(cache.identity.to_access_token())
493            } else if let Some(refresh_token) = &cache.identity.refresh_token {
494                Some(self.refresh_identity_token(refresh_token).await?)
495            } else {
496                None
497            };
498
499            if let Some(identity_token) = maybe_refreshed {
500                let gateway_token = self
501                    .exchange_for_resource(&identity_token.access_token, &resource, None)
502                    .await?;
503                self.write_cache(&identity_token, &gateway_token)?;
504                return Ok(KontextAuthSession {
505                    identity_token,
506                    gateway_token,
507                    browser_auth_performed: false,
508                });
509            }
510        }
511
512        let identity_token = self.authorize_with_browser_pkce().await?;
513        let gateway_token = self
514            .exchange_for_resource(&identity_token.access_token, &resource, None)
515            .await?;
516
517        self.write_cache(&identity_token, &gateway_token)?;
518
519        Ok(KontextAuthSession {
520            identity_token,
521            gateway_token,
522            browser_auth_performed: true,
523        })
524    }
525
526    /// Request a short-lived connect session and build the hosted integration
527    /// connect URL (`.../oauth/connect?session=...`).
528    pub async fn create_integration_connect_url(
529        &self,
530        gateway_access_token: &str,
531    ) -> Result<String, KontextDevError> {
532        let session = self.create_connect_session(gateway_access_token).await?;
533        self.integration_connect_url(&session.session_id)
534    }
535
536    pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
537        if session_id.trim().is_empty() {
538            return Err(KontextDevError::ConnectSession {
539                message: "connect session id is empty".to_string(),
540            });
541        }
542
543        let base = if let Some(explicit) = &self.config.integration_ui_url {
544            explicit.trim_end_matches('/').to_string()
545        } else {
546            let server = resolve_server_base_url(&self.config)?;
547            if server.contains("api.kontext.dev") {
548                "https://app.kontext.dev".to_string()
549            } else {
550                server
551            }
552        };
553
554        let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
555            url: base.clone(),
556            source,
557        })?;
558        url.set_path("/oauth/connect");
559        url.query_pairs_mut().append_pair("session", session_id);
560        Ok(url.to_string())
561    }
562
563    /// Opens the integration connect page in the browser.
564    pub async fn open_integration_connect_page(
565        &self,
566        gateway_access_token: &str,
567    ) -> Result<String, KontextDevError> {
568        let url = self
569            .create_integration_connect_url(gateway_access_token)
570            .await?;
571
572        if webbrowser::open(&url).is_err() {
573            return Err(KontextDevError::BrowserOpenFailed);
574        }
575
576        Ok(url)
577    }
578
579    pub async fn create_connect_session(
580        &self,
581        gateway_access_token: &str,
582    ) -> Result<ConnectSession, KontextDevError> {
583        let url = self.connect_session_url()?;
584        let response = self
585            .http
586            .post(&url)
587            .header("Authorization", format!("Bearer {gateway_access_token}"))
588            .json(&serde_json::json!({}))
589            .send()
590            .await
591            .map_err(|err| KontextDevError::ConnectSession {
592                message: err.to_string(),
593            })?;
594
595        if !response.status().is_success() {
596            let message = build_error_message(response).await;
597            return Err(KontextDevError::ConnectSession { message });
598        }
599
600        response
601            .json::<ConnectSession>()
602            .await
603            .map_err(|err| KontextDevError::ConnectSession {
604                message: err.to_string(),
605            })
606    }
607
608    pub async fn initiate_integration_oauth(
609        &self,
610        gateway_access_token: &str,
611        integration_id: &str,
612        return_to: Option<&str>,
613    ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
614        let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
615
616        let payload = return_to
617            .map(|value| serde_json::json!({ "returnTo": value }))
618            .unwrap_or_else(|| serde_json::json!({}));
619
620        let request = self
621            .http
622            .post(&url)
623            .header("Authorization", format!("Bearer {gateway_access_token}"))
624            .json(&payload);
625
626        let response =
627            request
628                .send()
629                .await
630                .map_err(|err| KontextDevError::IntegrationOAuthInit {
631                    message: err.to_string(),
632                })?;
633
634        if !response.status().is_success() {
635            let message = build_error_message(response).await;
636            return Err(KontextDevError::IntegrationOAuthInit { message });
637        }
638
639        response
640            .json::<IntegrationOAuthInitResponse>()
641            .await
642            .map_err(|err| KontextDevError::IntegrationOAuthInit {
643                message: err.to_string(),
644            })
645    }
646
647    pub async fn integration_connection_status(
648        &self,
649        gateway_access_token: &str,
650        integration_id: &str,
651    ) -> Result<IntegrationConnectionStatus, KontextDevError> {
652        let url = resolve_integration_connection_url(&self.config, integration_id)?;
653        let response = self
654            .http
655            .get(url)
656            .header("Authorization", format!("Bearer {gateway_access_token}"))
657            .send()
658            .await
659            .map_err(|err| KontextDevError::IntegrationOAuthInit {
660                message: err.to_string(),
661            })?;
662
663        if !response.status().is_success() {
664            let message = build_error_message(response).await;
665            return Err(KontextDevError::IntegrationOAuthInit { message });
666        }
667
668        response
669            .json::<IntegrationConnectionStatus>()
670            .await
671            .map_err(|err| KontextDevError::IntegrationOAuthInit {
672                message: err.to_string(),
673            })
674    }
675
676    pub async fn wait_for_integration_connection(
677        &self,
678        gateway_access_token: &str,
679        integration_id: &str,
680        timeout_ms: u64,
681        interval_ms: u64,
682    ) -> Result<bool, KontextDevError> {
683        let started = now_unix_ms();
684
685        loop {
686            let status = self
687                .integration_connection_status(gateway_access_token, integration_id)
688                .await?;
689            if status.connected {
690                return Ok(true);
691            }
692
693            if now_unix_ms().saturating_sub(started) >= timeout_ms {
694                return Ok(false);
695            }
696
697            tokio::time::sleep(Duration::from_millis(interval_ms)).await;
698        }
699    }
700
701    async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
702        let auth_url = self.resolve_authorization_url().await?;
703        let token_url = self.token_url()?;
704        let pkce = generate_pkce_pair();
705        let state = generate_state();
706
707        let (callback_url, callback_payload) = self.listen_for_callback().await?;
708
709        let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
710            url: auth_url.clone(),
711            source,
712        })?;
713
714        {
715            let mut query = url.query_pairs_mut();
716            query
717                .append_pair("client_id", &self.config.client_id)
718                .append_pair("response_type", "code")
719                .append_pair("redirect_uri", &callback_url)
720                .append_pair("state", &state);
721
722            if let Some(scope) = normalized_scope(&self.config.scope) {
723                query.append_pair("scope", scope);
724            }
725
726            query
727                .append_pair("code_challenge_method", "S256")
728                .append_pair("code_challenge", &pkce.challenge);
729        }
730
731        if webbrowser::open(url.as_str()).is_err() {
732            return Err(KontextDevError::BrowserOpenFailed);
733        }
734
735        let payload = callback_payload.await?;
736
737        if let Some(error) = payload.error {
738            let with_details = payload
739                .error_description
740                .map(|description| format!("{error}: {description}"))
741                .unwrap_or(error);
742            return Err(KontextDevError::OAuthCallbackError {
743                error: with_details,
744            });
745        }
746
747        if payload.state.as_deref() != Some(state.as_str()) {
748            return Err(KontextDevError::InvalidOAuthState);
749        }
750
751        let code = payload
752            .code
753            .ok_or(KontextDevError::MissingAuthorizationCode)?;
754
755        let mut body = vec![
756            ("grant_type", "authorization_code".to_string()),
757            ("code", code),
758            ("redirect_uri", callback_url),
759            ("client_id", self.config.client_id.clone()),
760            ("code_verifier", pkce.verifier),
761        ];
762
763        if let Some(client_secret) = &self.config.client_secret {
764            body.push(("client_secret", client_secret.clone()));
765        }
766
767        post_token(&self.http, &token_url, None, &body).await
768    }
769
770    async fn resolve_authorization_url(&self) -> Result<String, KontextDevError> {
771        if let Some(discovered) = self.discover_authorization_endpoint().await {
772            return Ok(discovered);
773        }
774
775        let authorize_url = self.authorize_url()?;
776        if !self.endpoint_is_missing(&authorize_url).await {
777            return Ok(authorize_url);
778        }
779
780        let server_base = resolve_server_base_url(&self.config)?;
781        Ok(format!("{}/oauth2/auth", server_base.trim_end_matches('/')))
782    }
783
784    async fn discover_authorization_endpoint(&self) -> Option<String> {
785        let base = resolve_server_base_url(&self.config).ok()?;
786        let base = base.trim_end_matches('/');
787
788        let candidates = [
789            format!("{base}/.well-known/oauth-authorization-server/mcp"),
790            format!("{base}/.well-known/oauth-authorization-server"),
791        ];
792
793        for url in candidates {
794            let response = match self.http.get(&url).send().await {
795                Ok(response) => response,
796                Err(_) => continue,
797            };
798
799            if !response.status().is_success() {
800                continue;
801            }
802
803            let metadata = match response.json::<OAuthAuthorizationServerMetadata>().await {
804                Ok(metadata) => metadata,
805                Err(_) => continue,
806            };
807
808            let Some(endpoint) = metadata.authorization_endpoint else {
809                continue;
810            };
811
812            if Url::parse(&endpoint).is_ok() {
813                return Some(endpoint);
814            }
815        }
816
817        None
818    }
819
820    async fn endpoint_is_missing(&self, url: &str) -> bool {
821        let probe_client = match reqwest::Client::builder()
822            .redirect(reqwest::redirect::Policy::none())
823            .build()
824        {
825            Ok(client) => client,
826            Err(_) => return false,
827        };
828
829        match probe_client.get(url).send().await {
830            Ok(response) => response.status() == StatusCode::NOT_FOUND,
831            Err(_) => false,
832        }
833    }
834
835    async fn refresh_identity_token(
836        &self,
837        refresh_token: &str,
838    ) -> Result<AccessToken, KontextDevError> {
839        let token_url = self.token_url()?;
840        let mut body = vec![
841            ("grant_type", "refresh_token".to_string()),
842            ("refresh_token", refresh_token.to_string()),
843            ("client_id", self.config.client_id.clone()),
844        ];
845
846        if let Some(client_secret) = &self.config.client_secret {
847            body.push(("client_secret", client_secret.clone()));
848        }
849
850        post_token(&self.http, &token_url, None, &body).await
851    }
852
853    pub async fn exchange_for_resource(
854        &self,
855        subject_token: &str,
856        resource: &str,
857        scope: Option<&str>,
858    ) -> Result<TokenExchangeToken, KontextDevError> {
859        let token_url = self.token_url()?;
860
861        let mut body = vec![
862            ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
863            ("subject_token", subject_token.to_string()),
864            ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
865            ("resource", resource.to_string()),
866        ];
867
868        if let Some(scope) = scope {
869            body.push(("scope", scope.to_string()));
870        }
871
872        let auth_header = self.config.client_secret.as_ref().map(|secret| {
873            let raw = format!("{}:{}", self.config.client_id, secret);
874            format!(
875                "Basic {}",
876                base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
877            )
878        });
879
880        if self.config.client_secret.is_none() {
881            body.push(("client_id", self.config.client_id.clone()));
882        }
883
884        let response = post_form_with_optional_auth::<TokenExchangeToken>(
885            &self.http,
886            &token_url,
887            auth_header,
888            &body,
889        )
890        .await
891        .map_err(|message| KontextDevError::TokenExchange {
892            resource: resource.to_string(),
893            message,
894        })?;
895
896        if response.access_token.is_empty() {
897            return Err(KontextDevError::EmptyAccessToken);
898        }
899
900        Ok(response)
901    }
902
903    async fn listen_for_callback(
904        &self,
905    ) -> Result<
906        (
907            String,
908            impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
909        ),
910        KontextDevError,
911    > {
912        let redirect_uri = self.config.redirect_uri.trim().to_string();
913        let parsed = Url::parse(&redirect_uri).map_err(|source| KontextDevError::InvalidUrl {
914            url: redirect_uri.clone(),
915            source,
916        })?;
917
918        if parsed.scheme() != "http" {
919            return Err(KontextDevError::OAuthCallbackError {
920                error: "redirect_uri must use http".to_string(),
921            });
922        }
923
924        if parsed.query().is_some() || parsed.fragment().is_some() {
925            return Err(KontextDevError::OAuthCallbackError {
926                error: "redirect_uri must not include query parameters or fragments".to_string(),
927            });
928        }
929
930        let host = parsed
931            .host_str()
932            .ok_or_else(|| KontextDevError::OAuthCallbackError {
933                error: "redirect_uri host is missing".to_string(),
934            })?;
935
936        let port = parsed
937            .port()
938            .ok_or_else(|| KontextDevError::OAuthCallbackError {
939                error: "redirect_uri must include an explicit port".to_string(),
940            })?;
941
942        let callback_path = parsed.path().to_string();
943        let bind_addr = if host.contains(':') {
944            format!("[{host}]:{port}")
945        } else {
946            format!("{host}:{port}")
947        };
948
949        let server = Arc::new(Server::http(&bind_addr).map_err(|err| {
950            KontextDevError::OAuthCallbackError {
951                error: format!("failed to start callback server at {bind_addr}: {err}"),
952            }
953        })?);
954
955        let (tx, rx) = oneshot::channel();
956        let _join = spawn_callback_server(server.clone(), callback_path, tx);
957        let _guard = CallbackServerGuard { server };
958        let timeout_seconds = self.config.auth_timeout_seconds.max(1);
959
960        let fut = async move {
961            let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
962                .await
963                .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
964                .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
965            drop(_guard);
966            Ok(payload)
967        };
968
969        Ok((redirect_uri, fut))
970    }
971
972    fn token_cache_path(&self) -> Option<PathBuf> {
973        if let Some(explicit) = &self.config.token_cache_path {
974            return Some(PathBuf::from(explicit));
975        }
976
977        let home = dirs::home_dir()?;
978        let mut path = home;
979        path.push(".kontext-dev");
980        path.push("tokens");
981
982        let sanitized_client_id: String = self
983            .config
984            .client_id
985            .chars()
986            .map(|ch| {
987                if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
988                    ch
989                } else {
990                    '_'
991                }
992            })
993            .collect();
994
995        path.push(format!("{sanitized_client_id}.json"));
996        Some(path)
997    }
998
999    fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
1000        let Some(path) = self.token_cache_path() else {
1001            return Ok(None);
1002        };
1003
1004        let raw = match fs::read_to_string(&path) {
1005            Ok(raw) => raw,
1006            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
1007            Err(source) => {
1008                return Err(KontextDevError::TokenCacheRead {
1009                    path: path.display().to_string(),
1010                    source,
1011                });
1012            }
1013        };
1014
1015        serde_json::from_str(&raw).map(Some).map_err(|source| {
1016            KontextDevError::TokenCacheDeserialize {
1017                path: path.display().to_string(),
1018                source,
1019            }
1020        })
1021    }
1022
1023    fn write_cache(
1024        &self,
1025        identity: &AccessToken,
1026        gateway: &TokenExchangeToken,
1027    ) -> Result<(), KontextDevError> {
1028        let Some(path) = self.token_cache_path() else {
1029            return Ok(());
1030        };
1031
1032        if let Some(parent) = path.parent() {
1033            fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
1034                path: parent.display().to_string(),
1035                source,
1036            })?;
1037        }
1038
1039        let payload = TokenCacheFile {
1040            client_id: self.config.client_id.clone(),
1041            resource: self.config.resource.clone(),
1042            identity: CachedAccessToken::from_access_token(identity)?,
1043            gateway: CachedAccessToken::from_token_exchange(gateway)?,
1044        };
1045
1046        let serialized = serde_json::to_string_pretty(&payload)
1047            .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
1048
1049        fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
1050            path: path.display().to_string(),
1051            source,
1052        })
1053    }
1054}
1055
1056/// Legacy helper that uses the client-credentials grant.
1057///
1058/// New PKCE-based apps should use `KontextDevClient::authenticate_mcp`.
1059pub async fn request_access_token(
1060    config: &KontextDevConfig,
1061) -> Result<AccessToken, KontextDevError> {
1062    let token_url = resolve_token_url(config)?;
1063
1064    let mut body = vec![("grant_type", "client_credentials".to_string())];
1065
1066    if let Some(scope) = normalized_scope(&config.scope) {
1067        body.push(("scope", scope.to_string()));
1068    }
1069
1070    let auth_header = if let Some(secret) = &config.client_secret {
1071        let raw = format!("{}:{}", config.client_id, secret);
1072        Some(format!(
1073            "Basic {}",
1074            base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
1075        ))
1076    } else {
1077        body.push(("client_id", config.client_id.clone()));
1078        None
1079    };
1080
1081    post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
1082}
1083
1084async fn post_token(
1085    http: &Client,
1086    token_url: &str,
1087    authorization: Option<&str>,
1088    body: &[(impl AsRef<str>, String)],
1089) -> Result<AccessToken, KontextDevError> {
1090    let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
1091
1092    let response = post_form_with_optional_auth::<AccessToken>(
1093        http,
1094        token_url,
1095        authorization.map(ToString::to_string),
1096        &body_vec,
1097    )
1098    .await
1099    .map_err(|message| KontextDevError::TokenRequest {
1100        token_url: token_url.to_string(),
1101        message,
1102    })?;
1103
1104    Ok(response)
1105}
1106
1107async fn post_form_with_optional_auth<T>(
1108    http: &Client,
1109    url: &str,
1110    authorization: Option<String>,
1111    body: &[(&str, String)],
1112) -> Result<T, String>
1113where
1114    T: DeserializeOwned,
1115{
1116    let mut request = http
1117        .post(url)
1118        .header("Content-Type", "application/x-www-form-urlencoded");
1119
1120    if let Some(header) = authorization {
1121        request = request.header("Authorization", header);
1122    }
1123
1124    let form = body
1125        .iter()
1126        .map(|(k, v)| (k.to_string(), v.to_string()))
1127        .collect::<Vec<(String, String)>>();
1128
1129    let response = request
1130        .form(&form)
1131        .send()
1132        .await
1133        .map_err(|err| err.to_string())?;
1134
1135    if !response.status().is_success() {
1136        return Err(build_error_message(response).await);
1137    }
1138
1139    response.json::<T>().await.map_err(|err| err.to_string())
1140}
1141
1142async fn build_error_message(response: reqwest::Response) -> String {
1143    let status = response.status();
1144    let fallback = format!(
1145        "{} {}",
1146        status.as_u16(),
1147        status.canonical_reason().unwrap_or("")
1148    );
1149
1150    let body = response.text().await.unwrap_or_default();
1151    if body.is_empty() {
1152        return fallback.trim().to_string();
1153    }
1154
1155    if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
1156        if let Some(description) = parsed.error_description {
1157            return description;
1158        }
1159        if let Some(message) = parsed.message {
1160            return message;
1161        }
1162        if let Some(error) = parsed.error {
1163            return error;
1164        }
1165    }
1166
1167    format!("{fallback}: {body}")
1168}
1169
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline public-client configuration shared by the tests below.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: "http://localhost:4000".to_string(),
            client_id: "client_123".to_string(),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        }
    }

    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let client = KontextDevClient::new(config());
        let url = client
            .integration_connect_url("session-123")
            .expect("url should be built");
        assert_eq!(
            url,
            "https://app.kontext.dev/oauth/connect?session=session-123"
        );
    }

    #[test]
    fn jwt_exp_decode_reads_exp() {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"none"}"#);
        let payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"exp":4070908800}"#);
        let token = format!("{header}.{payload}.sig");
        assert_eq!(decode_jwt_exp(&token), Some(4_070_908_800));
    }

    #[test]
    fn oauth_metadata_parses_authorization_endpoint() {
        let metadata = serde_json::from_str::<OAuthAuthorizationServerMetadata>(
            r#"{
              "issuer": "https://issuer.example.com",
              "authorization_endpoint": "https://issuer.example.com/oauth2/auth"
            }"#,
        )
        .expect("metadata should parse");

        assert_eq!(
            metadata.authorization_endpoint.as_deref(),
            Some("https://issuer.example.com/oauth2/auth")
        );
    }

    #[test]
    fn normalized_scope_omits_blank_values() {
        assert_eq!(normalized_scope(""), None);
        assert_eq!(normalized_scope("   "), None);
    }

    #[test]
    fn normalized_scope_preserves_non_empty_values() {
        assert_eq!(normalized_scope("mcp:invoke"), Some("mcp:invoke"));
        assert_eq!(
            normalized_scope("  mcp:invoke openid  "),
            Some("mcp:invoke openid")
        );
    }

    #[test]
    fn token_cache_path_prefers_explicit_path() {
        let mut cfg = config();
        cfg.token_cache_path = Some("/tmp/kontext-tokens.json".into());
        let client = KontextDevClient::new(cfg);
        assert_eq!(
            client.token_cache_path(),
            Some(PathBuf::from("/tmp/kontext-tokens.json"))
        );
    }

    #[test]
    fn token_cache_path_sanitizes_client_id() {
        let mut cfg = config();
        cfg.client_id = "org/client id.v2".to_string();
        let client = KontextDevClient::new(cfg);
        // With no resolvable home directory there is no default location,
        // so there is nothing to check.
        if let Some(path) = client.token_cache_path() {
            let expected_tail = PathBuf::from(".kontext-dev")
                .join("tokens")
                .join("org_client_id_v2.json");
            assert!(
                path.ends_with(&expected_tail),
                "unexpected cache path: {}",
                path.display()
            );
        }
    }
}