Skip to main content

kontext_dev_sdk/
lib.rs

1use base64::Engine;
2use reqwest::Client;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use sha2::Digest;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14use tiny_http::Response;
15use tiny_http::Server;
16use tokio::sync::oneshot;
17use tokio::time::timeout;
18use url::Url;
19
20pub mod client;
21pub mod errors;
22mod http_client;
23pub mod management;
24pub mod mcp;
25pub mod oauth;
26pub mod orchestrator;
27pub mod prompt_guidance;
28pub mod server;
29pub mod verify;
30
31pub use client::ClientState;
32pub use client::ConnectSessionResult;
33pub use client::IntegrationInfo;
34pub use client::KontextClient as SdkKontextClient;
35pub use client::KontextClientConfig;
36pub use client::ToolResult;
37pub use client::create_kontext_client;
38pub use errors::*;
39pub use management::KontextManagementClient;
40pub use management::KontextManagementClientConfig;
41pub use management::*;
42pub use mcp::KontextMcp;
43pub use mcp::KontextMcpConfig;
44pub use mcp::RuntimeIntegrationCategory;
45pub use mcp::RuntimeIntegrationConnectType;
46pub use mcp::RuntimeIntegrationRecord;
47pub use oauth::KontextOAuthProvider;
48pub use oauth::KontextOAuthProviderConfig;
49pub use oauth::ParsedOAuthCallback;
50pub use oauth::TokenExchangeConfig;
51pub use oauth::exchange_token;
52pub use oauth::parse_oauth_callback;
53pub use orchestrator::KontextOrchestrator;
54pub use orchestrator::KontextOrchestratorConfig;
55pub use orchestrator::KontextOrchestratorState;
56pub use orchestrator::create_kontext_orchestrator;
57pub use prompt_guidance::KontextPromptGuidance;
58pub use prompt_guidance::build_kontext_prompt_guidance;
59pub use server::IntegrationCredential;
60pub use server::IntegrationName;
61pub use server::IntegrationResolvedCredentials;
62pub use server::KnownIntegration;
63pub use server::Kontext;
64pub use server::KontextOptions;
65pub use server::MiddlewareOptions;
66pub use verify::JwksClient;
67pub use verify::KontextTokenVerifier;
68pub use verify::KontextTokenVerifierConfig;
69pub use verify::TokenVerificationError;
70pub use verify::TokenVerificationErrorCode;
71pub use verify::VerifiedTokenClaims;
72pub use verify::VerifyResult;
73
74pub use kontext_dev_sdk_core::AccessToken;
75pub use kontext_dev_sdk_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
76pub use kontext_dev_sdk_core::DEFAULT_RESOURCE;
77pub use kontext_dev_sdk_core::DEFAULT_SCOPE;
78pub use kontext_dev_sdk_core::DEFAULT_SERVER;
79pub use kontext_dev_sdk_core::DEFAULT_SERVER_NAME;
80pub use kontext_dev_sdk_core::KontextDevConfig;
81pub use kontext_dev_sdk_core::KontextDevCoreError;
82pub use kontext_dev_sdk_core::TokenExchangeToken;
83pub use kontext_dev_sdk_core::normalize_server_url;
84pub use kontext_dev_sdk_core::resolve_authorize_url;
85pub use kontext_dev_sdk_core::resolve_connect_session_url;
86pub use kontext_dev_sdk_core::resolve_integration_connection_url;
87pub use kontext_dev_sdk_core::resolve_integration_oauth_init_url;
88pub use kontext_dev_sdk_core::resolve_mcp_url;
89pub use kontext_dev_sdk_core::resolve_server_base_url;
90pub use kontext_dev_sdk_core::resolve_token_url;
91
/// RFC 8693 grant type sent to the token endpoint when exchanging the identity
/// token for a resource-scoped gateway token.
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// RFC 8693 token-type URN for OAuth access tokens. Sent as `subject_token_type`
/// and restored as `issued_token_type` for gateway tokens rebuilt from the cache.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
/// Safety margin: a cached token is treated as expired this many seconds before
/// its recorded expiry time.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
95
/// Errors produced by [`KontextDevClient`] and the helpers in this module.
///
/// Variants with a field named `source` expose it through
/// `std::error::Error::source` (thiserror picks it up by name).
// NOTE(review): the TokenCache* messages also interpolate `{source}`, so
// reporters that print the full error chain will show the cause twice.
#[derive(Debug, thiserror::Error)]
pub enum KontextDevError {
    /// An error from `kontext_dev_sdk_core` (e.g. URL resolution), displayed transparently.
    #[error(transparent)]
    Core(#[from] kontext_dev_sdk_core::KontextDevCoreError),
    /// A configured or derived URL could not be parsed.
    #[error("failed to parse URL `{url}`")]
    InvalidUrl {
        url: String,
        source: url::ParseError,
    },
    /// The system browser could not be launched.
    #[error("failed to open browser for OAuth authorization")]
    BrowserOpenFailed,
    /// The OAuth callback did not arrive within `timeout_seconds`.
    #[error("OAuth callback timed out after {timeout_seconds}s")]
    OAuthCallbackTimeout { timeout_seconds: i64 },
    /// The channel that delivers the OAuth callback closed before a value arrived.
    #[error("OAuth callback channel was unexpectedly cancelled")]
    OAuthCallbackCancelled,
    /// The OAuth callback had no `code` query parameter.
    #[error("OAuth callback is missing the authorization code")]
    MissingAuthorizationCode,
    /// The OAuth callback reported an `error` (with its description when present),
    /// or the local callback listener / `redirect_uri` could not be set up.
    #[error("OAuth callback returned an error: {error}")]
    OAuthCallbackError { error: String },
    /// The callback's `state` differed from the one sent in the authorization request.
    #[error("OAuth state mismatch")]
    InvalidOAuthState,
    /// A request to the token endpoint failed (presumably raised by the token-POST
    /// helper defined elsewhere in this file — confirm).
    #[error("Kontext-Dev token request failed for {token_url}: {message}")]
    TokenRequest { token_url: String, message: String },
    /// The RFC 8693 token exchange for `resource` failed.
    #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
    TokenExchange { resource: String, message: String },
    /// Creating or using a connect session failed (transport, HTTP status, body
    /// decoding, or an empty session id).
    #[error("Kontext-Dev connect session request failed: {message}")]
    ConnectSession { message: String },
    /// Integration OAuth init failed. Also returned for connection-status failures.
    #[error("Kontext-Dev integration OAuth init failed: {message}")]
    IntegrationOAuthInit { message: String },
    /// The token cache file at `path` could not be read.
    #[error("failed to read token cache at `{path}`: {source}")]
    TokenCacheRead {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file at `path` could not be written or removed.
    #[error("failed to write token cache at `{path}`: {source}")]
    TokenCacheWrite {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file at `path` did not contain valid cache JSON.
    #[error("failed to deserialize token cache at `{path}`: {source}")]
    TokenCacheDeserialize {
        path: String,
        source: serde_json::Error,
    },
    /// The in-memory cache could not be serialized.
    #[error("failed to serialize token cache: {source}")]
    TokenCacheSerialize { source: serde_json::Error },
    /// A token that must carry an access token had an empty one.
    #[error("Kontext-Dev access token is empty")]
    EmptyAccessToken,
    /// No integration UI URL is available; the message asks callers to set
    /// `integration_ui_url` in the config.
    #[error("missing integration UI URL; set `integration_ui_url` in config")]
    MissingIntegrationUiUrl,
}
147
148#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
149pub struct ConnectSession {
150    #[serde(rename = "sessionId")]
151    pub session_id: String,
152    #[serde(rename = "expiresAt")]
153    pub expires_at: String,
154}
155
156#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
157pub struct IntegrationOAuthInitResponse {
158    #[serde(rename = "authorizationUrl")]
159    pub authorization_url: String,
160    #[serde(default)]
161    pub state: Option<String>,
162}
163
/// Connection state of a single integration, as returned by the connection-status endpoint.
// NOTE(review): `expires_at` is (de)serialized with its snake_case Rust name,
// while the sibling response types use camelCase keys (`expiresAt`). If the
// server actually sends `expiresAt`, this field silently stays `None` — confirm
// the wire format.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct IntegrationConnectionStatus {
    /// Whether the integration is currently connected.
    pub connected: bool,
    /// Optional expiry string; `None` when the key is absent.
    #[serde(default)]
    pub expires_at: Option<String>,
}
170
/// Outcome of [`KontextDevClient::authenticate_mcp`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct KontextAuthSession {
    /// The user's identity token (from the cache, a refresh, or the browser flow).
    pub identity_token: AccessToken,
    /// Resource-scoped gateway token obtained by token exchange (or from the cache).
    pub gateway_token: TokenExchangeToken,
    /// `true` only when this call had to run the interactive browser login.
    pub browser_auth_performed: bool,
}
177
/// Persistable snapshot of an access token. The expiry is stored as an absolute
/// time so it stays meaningful across process restarts.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CachedAccessToken {
    /// The bearer token string; never empty (enforced by the constructors).
    access_token: String,
    /// Token type as returned by the server.
    token_type: String,
    /// Refresh token, if the server issued one.
    refresh_token: Option<String>,
    /// Granted scope string, if reported.
    scope: Option<String>,
    /// Absolute expiry in Unix milliseconds. `None` means no expiry could be
    /// determined; such tokens are treated as valid by `is_valid`.
    expires_at_unix_ms: Option<u64>,
}
186
/// On-disk token cache contents. An entry is only reused when `client_id` and
/// `resource` match the current client configuration.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct TokenCacheFile {
    /// OAuth client the tokens were issued to.
    client_id: String,
    /// Resource the gateway token was exchanged for.
    resource: String,
    /// The user's identity token.
    identity: CachedAccessToken,
    /// The resource-scoped gateway token.
    gateway: CachedAccessToken,
}
194
195impl CachedAccessToken {
196    fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
197        if token.access_token.is_empty() {
198            return Err(KontextDevError::EmptyAccessToken);
199        }
200
201        Ok(Self {
202            access_token: token.access_token.clone(),
203            token_type: token.token_type.clone(),
204            refresh_token: token.refresh_token.clone(),
205            scope: token.scope.clone(),
206            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
207        })
208    }
209
210    fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
211        if token.access_token.is_empty() {
212            return Err(KontextDevError::EmptyAccessToken);
213        }
214
215        Ok(Self {
216            access_token: token.access_token.clone(),
217            token_type: token.token_type.clone(),
218            refresh_token: token.refresh_token.clone(),
219            scope: token.scope.clone(),
220            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
221        })
222    }
223
224    fn is_valid(&self) -> bool {
225        match self.expires_at_unix_ms {
226            Some(expires_at) => {
227                let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
228                now_unix_ms().saturating_add(buffer_ms) < expires_at
229            }
230            None => true,
231        }
232    }
233
234    fn to_access_token(&self) -> AccessToken {
235        AccessToken {
236            access_token: self.access_token.clone(),
237            token_type: self.token_type.clone(),
238            expires_in: self
239                .expires_at_unix_ms
240                .and_then(unix_ms_to_relative_seconds),
241            refresh_token: self.refresh_token.clone(),
242            scope: self.scope.clone(),
243        }
244    }
245
246    fn to_token_exchange_token(&self) -> TokenExchangeToken {
247        TokenExchangeToken {
248            access_token: self.access_token.clone(),
249            issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
250            token_type: self.token_type.clone(),
251            expires_in: self
252                .expires_at_unix_ms
253                .and_then(unix_ms_to_relative_seconds),
254            scope: self.scope.clone(),
255            refresh_token: self.refresh_token.clone(),
256        }
257    }
258}
259
/// Current wall-clock time in Unix milliseconds.
///
/// If the clock reads before the epoch, this returns 0. A value too large for
/// `u64` saturates instead of being silently truncated by an `as` cast.
fn now_unix_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
266
267fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
268    if unix_ms <= now_unix_ms() {
269        return Some(0);
270    }
271
272    let delta_ms = unix_ms - now_unix_ms();
273    let secs = delta_ms / 1000;
274    i64::try_from(secs).ok()
275}
276
277fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
278    if let Some(expires_in) = expires_in
279        && expires_in > 0
280    {
281        let expires_in_u64 = u64::try_from(expires_in).ok()?;
282        return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
283    }
284
285    decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
286}
287
288fn decode_jwt_exp(token: &str) -> Option<u64> {
289    let mut parts = token.split('.');
290    let _header = parts.next()?;
291    let payload = parts.next()?;
292
293    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
294        .decode(payload)
295        .ok()?;
296    let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
297    value.get("exp").and_then(|exp| {
298        exp.as_u64()
299            .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
300    })
301}
302
/// Error body shape tolerated from OAuth / API endpoints. All fields are optional.
// NOTE(review): not referenced in the visible code; presumably consumed by the
// error-message helper defined further down the file — confirm.
#[derive(Clone, Debug, Deserialize)]
struct OAuthErrorBody {
    /// OAuth `error` code.
    error: Option<String>,
    /// OAuth `error_description`.
    error_description: Option<String>,
    /// Generic `message` field.
    message: Option<String>,
}
309
/// Subset of OAuth authorization-server metadata read from the
/// `/.well-known/oauth-authorization-server` documents during discovery.
#[derive(Clone, Debug, Deserialize)]
struct OAuthAuthorizationServerMetadata {
    /// Advertised authorization endpoint; only used when it parses as a URL.
    authorization_endpoint: Option<String>,
}
314
/// Query parameters captured from the OAuth redirect to the local callback server.
/// Each field is `None` when the parameter was absent.
#[derive(Clone, Debug, Deserialize)]
struct OAuthCallbackPayload {
    /// Authorization code to redeem at the token endpoint.
    code: Option<String>,
    /// `state` echoed back; compared against the value sent in the request.
    state: Option<String>,
    /// OAuth error code, if authorization failed.
    error: Option<String>,
    /// Human-readable detail accompanying `error`.
    error_description: Option<String>,
}
322
/// A PKCE verifier and its derived challenge (see `generate_pkce_pair`).
#[derive(Clone, Debug)]
struct PkcePair {
    /// Secret sent as `code_verifier` when redeeming the authorization code.
    verifier: String,
    /// Value sent as `code_challenge` in the authorization request.
    challenge: String,
}
328
329fn generate_pkce_pair() -> PkcePair {
330    let mut raw = [0u8; 64];
331    rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
332
333    let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
334    let digest = sha2::Sha256::digest(verifier.as_bytes());
335    let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
336
337    PkcePair {
338        verifier,
339        challenge,
340    }
341}
342
343fn generate_state() -> String {
344    let mut bytes = [0u8; 16];
345    rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
346    bytes.iter().map(|b| format!("{b:02x}")).collect()
347}
348
/// Trim surrounding whitespace from `scope`.
///
/// A blank result becomes `None` so callers can omit the `scope` parameter entirely.
fn normalized_scope(scope: &str) -> Option<&str> {
    let trimmed = scope.trim();
    (!trimmed.is_empty()).then_some(trimmed)
}
357
/// Drop guard for the loopback OAuth callback server.
struct CallbackServerGuard {
    /// Server shared with the blocking accept loop in `spawn_callback_server`.
    server: Arc<Server>,
}

impl Drop for CallbackServerGuard {
    /// Unblock the server so a thread waiting in `recv()` returns and the accept
    /// loop in `spawn_callback_server` can exit.
    fn drop(&mut self) {
        self.server.unblock();
    }
}
367
368fn parse_callback(url_path: &str, callback_path: &str) -> Option<OAuthCallbackPayload> {
369    let full = format!("http://localhost{url_path}");
370    let parsed = Url::parse(&full).ok()?;
371
372    if parsed.path() != callback_path {
373        return None;
374    }
375
376    let mut payload = OAuthCallbackPayload {
377        code: None,
378        state: None,
379        error: None,
380        error_description: None,
381    };
382
383    for (key, value) in parsed.query_pairs() {
384        match key.as_ref() {
385            "code" => payload.code = Some(value.to_string()),
386            "state" => payload.state = Some(value.to_string()),
387            "error" => payload.error = Some(value.to_string()),
388            "error_description" => payload.error_description = Some(value.to_string()),
389            _ => {}
390        }
391    }
392
393    Some(payload)
394}
395
396fn spawn_callback_server(
397    server: Arc<Server>,
398    callback_path: String,
399    tx: oneshot::Sender<OAuthCallbackPayload>,
400) -> tokio::task::JoinHandle<()> {
401    tokio::task::spawn_blocking(move || {
402        while let Ok(request) = server.recv() {
403            let path = request.url().to_string();
404            if let Some(payload) = parse_callback(&path, &callback_path) {
405                let response = Response::from_string(
406                    "Authentication complete. You can return to your terminal.",
407                );
408                let _ = request.respond(response);
409                let _ = tx.send(payload);
410                break;
411            }
412
413            let response = Response::from_string("Invalid callback").with_status_code(400);
414            let _ = request.respond(response);
415        }
416    })
417}
418
/// Client for Kontext-Dev authentication: browser PKCE login, token refresh,
/// RFC 8693 token exchange, a local token cache, and integration-connect helpers.
#[derive(Clone, Debug)]
pub struct KontextDevClient {
    /// Configuration the endpoint URLs and client credentials are resolved from.
    config: KontextDevConfig,
    /// HTTP client used for all regular requests.
    http: Client,
}
424
425impl KontextDevClient {
426    pub fn new(config: KontextDevConfig) -> Self {
427        Self {
428            config,
429            http: Client::new(),
430        }
431    }
432
433    pub fn config(&self) -> &KontextDevConfig {
434        &self.config
435    }
436
437    pub fn server_base_url(&self) -> Result<String, KontextDevError> {
438        resolve_server_base_url(&self.config).map_err(KontextDevError::from)
439    }
440
441    pub fn mcp_url(&self) -> Result<String, KontextDevError> {
442        resolve_mcp_url(&self.config).map_err(KontextDevError::from)
443    }
444
445    pub fn token_url(&self) -> Result<String, KontextDevError> {
446        resolve_token_url(&self.config).map_err(KontextDevError::from)
447    }
448
449    pub fn authorize_url(&self) -> Result<String, KontextDevError> {
450        resolve_authorize_url(&self.config).map_err(KontextDevError::from)
451    }
452
453    pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
454        resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
455    }
456
457    pub fn clear_token_cache(&self) -> Result<(), KontextDevError> {
458        let Some(path) = self.token_cache_path() else {
459            return Ok(());
460        };
461
462        if !path.exists() {
463            return Ok(());
464        }
465
466        fs::remove_file(&path).map_err(|source| KontextDevError::TokenCacheWrite {
467            path: path.display().to_string(),
468            source,
469        })
470    }
471
472    /// Authenticate a user via browser PKCE, exchange the identity token to a
473    /// resource-scoped MCP gateway token, and persist cache for future runs.
474    pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
475        let resource = self.config.resource.clone();
476
477        if let Some(cache) = self.read_cache()?
478            && cache.client_id == self.config.client_id
479            && cache.resource == resource
480            && cache.gateway.is_valid()
481            && cache.identity.is_valid()
482        {
483            return Ok(KontextAuthSession {
484                identity_token: cache.identity.to_access_token(),
485                gateway_token: cache.gateway.to_token_exchange_token(),
486                browser_auth_performed: false,
487            });
488        }
489
490        if let Some(cache) = self.read_cache()?
491            && cache.client_id == self.config.client_id
492            && cache.resource == resource
493        {
494            let maybe_refreshed = if cache.identity.is_valid() {
495                Some(cache.identity.to_access_token())
496            } else if let Some(refresh_token) = &cache.identity.refresh_token {
497                Some(self.refresh_identity_token(refresh_token).await?)
498            } else {
499                None
500            };
501
502            if let Some(identity_token) = maybe_refreshed {
503                let gateway_token = self
504                    .exchange_for_resource(&identity_token.access_token, &resource, None)
505                    .await?;
506                self.write_cache(&identity_token, &gateway_token)?;
507                return Ok(KontextAuthSession {
508                    identity_token,
509                    gateway_token,
510                    browser_auth_performed: false,
511                });
512            }
513        }
514
515        let identity_token = self.authorize_with_browser_pkce().await?;
516        let gateway_token = self
517            .exchange_for_resource(&identity_token.access_token, &resource, None)
518            .await?;
519
520        self.write_cache(&identity_token, &gateway_token)?;
521
522        Ok(KontextAuthSession {
523            identity_token,
524            gateway_token,
525            browser_auth_performed: true,
526        })
527    }
528
529    /// Request a short-lived connect session and build the hosted integration
530    /// connect URL (`.../oauth/connect?session=...`).
531    pub async fn create_integration_connect_url(
532        &self,
533        gateway_access_token: &str,
534    ) -> Result<String, KontextDevError> {
535        let session = self.create_connect_session(gateway_access_token).await?;
536        self.integration_connect_url(&session.session_id)
537    }
538
539    pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
540        if session_id.trim().is_empty() {
541            return Err(KontextDevError::ConnectSession {
542                message: "connect session id is empty".to_string(),
543            });
544        }
545
546        let base = if let Some(explicit) = &self.config.integration_ui_url {
547            explicit.trim_end_matches('/').to_string()
548        } else {
549            let server = resolve_server_base_url(&self.config)?;
550            if server.contains("api.kontext.dev") {
551                "https://app.kontext.dev".to_string()
552            } else {
553                server
554            }
555        };
556
557        let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
558            url: base.clone(),
559            source,
560        })?;
561        url.set_path("/oauth/connect");
562        url.query_pairs_mut().append_pair("session", session_id);
563        Ok(url.to_string())
564    }
565
566    /// Opens the integration connect page in the browser.
567    pub async fn open_integration_connect_page(
568        &self,
569        gateway_access_token: &str,
570    ) -> Result<String, KontextDevError> {
571        let url = self
572            .create_integration_connect_url(gateway_access_token)
573            .await?;
574
575        if webbrowser::open(&url).is_err() {
576            return Err(KontextDevError::BrowserOpenFailed);
577        }
578
579        Ok(url)
580    }
581
582    pub async fn create_connect_session(
583        &self,
584        gateway_access_token: &str,
585    ) -> Result<ConnectSession, KontextDevError> {
586        let url = self.connect_session_url()?;
587        let response = self
588            .http
589            .post(&url)
590            .header("Authorization", format!("Bearer {gateway_access_token}"))
591            .json(&serde_json::json!({}))
592            .send()
593            .await
594            .map_err(|err| KontextDevError::ConnectSession {
595                message: err.to_string(),
596            })?;
597
598        if !response.status().is_success() {
599            let message = build_error_message(response).await;
600            return Err(KontextDevError::ConnectSession { message });
601        }
602
603        response
604            .json::<ConnectSession>()
605            .await
606            .map_err(|err| KontextDevError::ConnectSession {
607                message: err.to_string(),
608            })
609    }
610
611    pub async fn initiate_integration_oauth(
612        &self,
613        gateway_access_token: &str,
614        integration_id: &str,
615        return_to: Option<&str>,
616    ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
617        let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
618
619        let payload = return_to
620            .map(|value| serde_json::json!({ "returnTo": value }))
621            .unwrap_or_else(|| serde_json::json!({}));
622
623        let request = self
624            .http
625            .post(&url)
626            .header("Authorization", format!("Bearer {gateway_access_token}"))
627            .json(&payload);
628
629        let response =
630            request
631                .send()
632                .await
633                .map_err(|err| KontextDevError::IntegrationOAuthInit {
634                    message: err.to_string(),
635                })?;
636
637        if !response.status().is_success() {
638            let message = build_error_message(response).await;
639            return Err(KontextDevError::IntegrationOAuthInit { message });
640        }
641
642        response
643            .json::<IntegrationOAuthInitResponse>()
644            .await
645            .map_err(|err| KontextDevError::IntegrationOAuthInit {
646                message: err.to_string(),
647            })
648    }
649
650    pub async fn integration_connection_status(
651        &self,
652        gateway_access_token: &str,
653        integration_id: &str,
654    ) -> Result<IntegrationConnectionStatus, KontextDevError> {
655        let url = resolve_integration_connection_url(&self.config, integration_id)?;
656        let response = self
657            .http
658            .get(url)
659            .header("Authorization", format!("Bearer {gateway_access_token}"))
660            .send()
661            .await
662            .map_err(|err| KontextDevError::IntegrationOAuthInit {
663                message: err.to_string(),
664            })?;
665
666        if !response.status().is_success() {
667            let message = build_error_message(response).await;
668            return Err(KontextDevError::IntegrationOAuthInit { message });
669        }
670
671        response
672            .json::<IntegrationConnectionStatus>()
673            .await
674            .map_err(|err| KontextDevError::IntegrationOAuthInit {
675                message: err.to_string(),
676            })
677    }
678
679    pub async fn wait_for_integration_connection(
680        &self,
681        gateway_access_token: &str,
682        integration_id: &str,
683        timeout_ms: u64,
684        interval_ms: u64,
685    ) -> Result<bool, KontextDevError> {
686        let started = now_unix_ms();
687
688        loop {
689            let status = self
690                .integration_connection_status(gateway_access_token, integration_id)
691                .await?;
692            if status.connected {
693                return Ok(true);
694            }
695
696            if now_unix_ms().saturating_sub(started) >= timeout_ms {
697                return Ok(false);
698            }
699
700            tokio::time::sleep(Duration::from_millis(interval_ms)).await;
701        }
702    }
703
704    async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
705        let auth_url = self.resolve_authorization_url().await?;
706        let token_url = self.token_url()?;
707        let pkce = generate_pkce_pair();
708        let state = generate_state();
709
710        let (callback_url, callback_payload) = self.listen_for_callback().await?;
711
712        let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
713            url: auth_url.clone(),
714            source,
715        })?;
716
717        {
718            let mut query = url.query_pairs_mut();
719            query
720                .append_pair("client_id", &self.config.client_id)
721                .append_pair("response_type", "code")
722                .append_pair("redirect_uri", &callback_url)
723                .append_pair("state", &state);
724
725            if let Some(scope) = normalized_scope(&self.config.scope) {
726                query.append_pair("scope", scope);
727            }
728
729            query
730                .append_pair("code_challenge_method", "S256")
731                .append_pair("code_challenge", &pkce.challenge);
732        }
733
734        if webbrowser::open(url.as_str()).is_err() {
735            return Err(KontextDevError::BrowserOpenFailed);
736        }
737
738        let payload = callback_payload.await?;
739
740        if let Some(error) = payload.error {
741            let with_details = payload
742                .error_description
743                .map(|description| format!("{error}: {description}"))
744                .unwrap_or(error);
745            return Err(KontextDevError::OAuthCallbackError {
746                error: with_details,
747            });
748        }
749
750        if payload.state.as_deref() != Some(state.as_str()) {
751            return Err(KontextDevError::InvalidOAuthState);
752        }
753
754        let code = payload
755            .code
756            .ok_or(KontextDevError::MissingAuthorizationCode)?;
757
758        let mut body = vec![
759            ("grant_type", "authorization_code".to_string()),
760            ("code", code),
761            ("redirect_uri", callback_url),
762            ("client_id", self.config.client_id.clone()),
763            ("code_verifier", pkce.verifier),
764        ];
765
766        if let Some(client_secret) = &self.config.client_secret {
767            body.push(("client_secret", client_secret.clone()));
768        }
769
770        post_token(&self.http, &token_url, None, &body).await
771    }
772
773    async fn resolve_authorization_url(&self) -> Result<String, KontextDevError> {
774        if let Some(discovered) = self.discover_authorization_endpoint().await {
775            return Ok(discovered);
776        }
777
778        let authorize_url = self.authorize_url()?;
779        if !self.endpoint_is_missing(&authorize_url).await {
780            return Ok(authorize_url);
781        }
782
783        let server_base = resolve_server_base_url(&self.config)?;
784        Ok(format!("{}/oauth2/auth", server_base.trim_end_matches('/')))
785    }
786
787    async fn discover_authorization_endpoint(&self) -> Option<String> {
788        let base = resolve_server_base_url(&self.config).ok()?;
789        let base = base.trim_end_matches('/');
790
791        let candidates = [
792            format!("{base}/.well-known/oauth-authorization-server/mcp"),
793            format!("{base}/.well-known/oauth-authorization-server"),
794        ];
795
796        for url in candidates {
797            let response = match self.http.get(&url).send().await {
798                Ok(response) => response,
799                Err(_) => continue,
800            };
801
802            if !response.status().is_success() {
803                continue;
804            }
805
806            let metadata = match response.json::<OAuthAuthorizationServerMetadata>().await {
807                Ok(metadata) => metadata,
808                Err(_) => continue,
809            };
810
811            let Some(endpoint) = metadata.authorization_endpoint else {
812                continue;
813            };
814
815            if Url::parse(&endpoint).is_ok() {
816                return Some(endpoint);
817            }
818        }
819
820        None
821    }
822
823    async fn endpoint_is_missing(&self, url: &str) -> bool {
824        let probe_client = match reqwest::Client::builder()
825            .redirect(reqwest::redirect::Policy::none())
826            .build()
827        {
828            Ok(client) => client,
829            Err(_) => return false,
830        };
831
832        match probe_client.get(url).send().await {
833            Ok(response) => response.status() == StatusCode::NOT_FOUND,
834            Err(_) => false,
835        }
836    }
837
838    async fn refresh_identity_token(
839        &self,
840        refresh_token: &str,
841    ) -> Result<AccessToken, KontextDevError> {
842        let token_url = self.token_url()?;
843        let mut body = vec![
844            ("grant_type", "refresh_token".to_string()),
845            ("refresh_token", refresh_token.to_string()),
846            ("client_id", self.config.client_id.clone()),
847        ];
848
849        if let Some(client_secret) = &self.config.client_secret {
850            body.push(("client_secret", client_secret.clone()));
851        }
852
853        post_token(&self.http, &token_url, None, &body).await
854    }
855
856    pub async fn exchange_for_resource(
857        &self,
858        subject_token: &str,
859        resource: &str,
860        scope: Option<&str>,
861    ) -> Result<TokenExchangeToken, KontextDevError> {
862        let token_url = self.token_url()?;
863
864        let mut body = vec![
865            ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
866            ("subject_token", subject_token.to_string()),
867            ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
868            ("resource", resource.to_string()),
869        ];
870
871        if let Some(scope) = scope {
872            body.push(("scope", scope.to_string()));
873        }
874
875        let auth_header = self.config.client_secret.as_ref().map(|secret| {
876            let raw = format!("{}:{}", self.config.client_id, secret);
877            format!(
878                "Basic {}",
879                base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
880            )
881        });
882
883        if self.config.client_secret.is_none() {
884            body.push(("client_id", self.config.client_id.clone()));
885        }
886
887        let response = post_form_with_optional_auth::<TokenExchangeToken>(
888            &self.http,
889            &token_url,
890            auth_header,
891            &body,
892        )
893        .await
894        .map_err(|message| KontextDevError::TokenExchange {
895            resource: resource.to_string(),
896            message,
897        })?;
898
899        if response.access_token.is_empty() {
900            return Err(KontextDevError::EmptyAccessToken);
901        }
902
903        Ok(response)
904    }
905
906    async fn listen_for_callback(
907        &self,
908    ) -> Result<
909        (
910            String,
911            impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
912        ),
913        KontextDevError,
914    > {
915        let redirect_uri = self.config.redirect_uri.trim().to_string();
916        let parsed = Url::parse(&redirect_uri).map_err(|source| KontextDevError::InvalidUrl {
917            url: redirect_uri.clone(),
918            source,
919        })?;
920
921        if parsed.scheme() != "http" {
922            return Err(KontextDevError::OAuthCallbackError {
923                error: "redirect_uri must use http".to_string(),
924            });
925        }
926
927        if parsed.query().is_some() || parsed.fragment().is_some() {
928            return Err(KontextDevError::OAuthCallbackError {
929                error: "redirect_uri must not include query parameters or fragments".to_string(),
930            });
931        }
932
933        let host = parsed
934            .host_str()
935            .ok_or_else(|| KontextDevError::OAuthCallbackError {
936                error: "redirect_uri host is missing".to_string(),
937            })?;
938
939        let port = parsed
940            .port()
941            .ok_or_else(|| KontextDevError::OAuthCallbackError {
942                error: "redirect_uri must include an explicit port".to_string(),
943            })?;
944
945        let callback_path = parsed.path().to_string();
946        let bind_addr = if host.contains(':') {
947            format!("[{host}]:{port}")
948        } else {
949            format!("{host}:{port}")
950        };
951
952        let server = Arc::new(Server::http(&bind_addr).map_err(|err| {
953            KontextDevError::OAuthCallbackError {
954                error: format!("failed to start callback server at {bind_addr}: {err}"),
955            }
956        })?);
957
958        let (tx, rx) = oneshot::channel();
959        let _join = spawn_callback_server(server.clone(), callback_path, tx);
960        let _guard = CallbackServerGuard { server };
961        let timeout_seconds = self.config.auth_timeout_seconds.max(1);
962
963        let fut = async move {
964            let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
965                .await
966                .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
967                .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
968            drop(_guard);
969            Ok(payload)
970        };
971
972        Ok((redirect_uri, fut))
973    }
974
975    fn token_cache_path(&self) -> Option<PathBuf> {
976        if let Some(explicit) = &self.config.token_cache_path {
977            return Some(PathBuf::from(explicit));
978        }
979
980        let home = dirs::home_dir()?;
981        let mut path = home;
982        path.push(".kontext-dev");
983        path.push("tokens");
984
985        let sanitized_client_id: String = self
986            .config
987            .client_id
988            .chars()
989            .map(|ch| {
990                if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
991                    ch
992                } else {
993                    '_'
994                }
995            })
996            .collect();
997
998        path.push(format!("{sanitized_client_id}.json"));
999        Some(path)
1000    }
1001
1002    fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
1003        let Some(path) = self.token_cache_path() else {
1004            return Ok(None);
1005        };
1006
1007        let raw = match fs::read_to_string(&path) {
1008            Ok(raw) => raw,
1009            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
1010            Err(source) => {
1011                return Err(KontextDevError::TokenCacheRead {
1012                    path: path.display().to_string(),
1013                    source,
1014                });
1015            }
1016        };
1017
1018        serde_json::from_str(&raw).map(Some).map_err(|source| {
1019            KontextDevError::TokenCacheDeserialize {
1020                path: path.display().to_string(),
1021                source,
1022            }
1023        })
1024    }
1025
1026    fn write_cache(
1027        &self,
1028        identity: &AccessToken,
1029        gateway: &TokenExchangeToken,
1030    ) -> Result<(), KontextDevError> {
1031        let Some(path) = self.token_cache_path() else {
1032            return Ok(());
1033        };
1034
1035        if let Some(parent) = path.parent() {
1036            fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
1037                path: parent.display().to_string(),
1038                source,
1039            })?;
1040        }
1041
1042        let payload = TokenCacheFile {
1043            client_id: self.config.client_id.clone(),
1044            resource: self.config.resource.clone(),
1045            identity: CachedAccessToken::from_access_token(identity)?,
1046            gateway: CachedAccessToken::from_token_exchange(gateway)?,
1047        };
1048
1049        let serialized = serde_json::to_string_pretty(&payload)
1050            .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
1051
1052        fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
1053            path: path.display().to_string(),
1054            source,
1055        })
1056    }
1057}
1058
1059/// Legacy helper that uses the client-credentials grant.
1060///
1061/// New PKCE-based apps should use `KontextDevClient::authenticate_mcp`.
1062pub async fn request_access_token(
1063    config: &KontextDevConfig,
1064) -> Result<AccessToken, KontextDevError> {
1065    let token_url = resolve_token_url(config)?;
1066
1067    let mut body = vec![("grant_type", "client_credentials".to_string())];
1068
1069    if let Some(scope) = normalized_scope(&config.scope) {
1070        body.push(("scope", scope.to_string()));
1071    }
1072
1073    let auth_header = if let Some(secret) = &config.client_secret {
1074        let raw = format!("{}:{}", config.client_id, secret);
1075        Some(format!(
1076            "Basic {}",
1077            base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
1078        ))
1079    } else {
1080        body.push(("client_id", config.client_id.clone()));
1081        None
1082    };
1083
1084    post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
1085}
1086
1087async fn post_token(
1088    http: &Client,
1089    token_url: &str,
1090    authorization: Option<&str>,
1091    body: &[(impl AsRef<str>, String)],
1092) -> Result<AccessToken, KontextDevError> {
1093    let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
1094
1095    let response = post_form_with_optional_auth::<AccessToken>(
1096        http,
1097        token_url,
1098        authorization.map(ToString::to_string),
1099        &body_vec,
1100    )
1101    .await
1102    .map_err(|message| KontextDevError::TokenRequest {
1103        token_url: token_url.to_string(),
1104        message,
1105    })?;
1106
1107    Ok(response)
1108}
1109
1110async fn post_form_with_optional_auth<T>(
1111    http: &Client,
1112    url: &str,
1113    authorization: Option<String>,
1114    body: &[(&str, String)],
1115) -> Result<T, String>
1116where
1117    T: DeserializeOwned,
1118{
1119    let mut request = http
1120        .post(url)
1121        .header("Content-Type", "application/x-www-form-urlencoded");
1122
1123    if let Some(header) = authorization {
1124        request = request.header("Authorization", header);
1125    }
1126
1127    let form = body
1128        .iter()
1129        .map(|(k, v)| (k.to_string(), v.to_string()))
1130        .collect::<Vec<(String, String)>>();
1131
1132    let response = request
1133        .form(&form)
1134        .send()
1135        .await
1136        .map_err(|err| err.to_string())?;
1137
1138    if !response.status().is_success() {
1139        return Err(build_error_message(response).await);
1140    }
1141
1142    response.json::<T>().await.map_err(|err| err.to_string())
1143}
1144
1145async fn build_error_message(response: reqwest::Response) -> String {
1146    let status = response.status();
1147    let fallback = format!(
1148        "{} {}",
1149        status.as_u16(),
1150        status.canonical_reason().unwrap_or("")
1151    );
1152
1153    let body = response.text().await.unwrap_or_default();
1154    if body.is_empty() {
1155        return fallback.trim().to_string();
1156    }
1157
1158    if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
1159        if let Some(description) = parsed.error_description {
1160            return description;
1161        }
1162        if let Some(message) = parsed.message {
1163            return message;
1164        }
1165        if let Some(error) = parsed.error {
1166            return error;
1167        }
1168    }
1169
1170    format!("{fallback}: {body}")
1171}
1172
#[cfg(test)]
mod tests {
    //! Unit tests for connect-URL building, JWT `exp` decoding, OAuth metadata parsing and
    //! scope normalization.
    use super::*;

    /// Configuration shared by the tests: a public client (no secret) pointed at a local server.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: String::from("http://localhost:4000"),
            client_id: String::from("client_123"),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some(String::from("https://app.kontext.dev")),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
            redirect_uri: String::from("http://localhost:3333/callback"),
        }
    }

    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let dev_client = KontextDevClient::new(config());
        let connect_url = dev_client
            .integration_connect_url("session-123")
            .expect("url should be built");

        let expected = "https://app.kontext.dev/oauth/connect?session=session-123";
        assert_eq!(connect_url, expected);
    }

    #[test]
    fn jwt_exp_decode_reads_exp() {
        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let encoded_header = engine.encode(r#"{"alg":"none"}"#);
        let encoded_claims = engine.encode(r#"{"exp":4070908800}"#);
        // The signature segment is never verified by the decoder, so any placeholder works.
        let jwt = format!("{encoded_header}.{encoded_claims}.sig");

        assert_eq!(decode_jwt_exp(&jwt), Some(4_070_908_800));
    }

    #[test]
    fn oauth_metadata_parses_authorization_endpoint() {
        let document = r#"{
              "issuer": "https://issuer.example.com",
              "authorization_endpoint": "https://issuer.example.com/oauth2/auth"
            }"#;
        let metadata = serde_json::from_str::<OAuthAuthorizationServerMetadata>(document)
            .expect("metadata should parse");

        let endpoint = metadata.authorization_endpoint.as_deref();
        assert_eq!(endpoint, Some("https://issuer.example.com/oauth2/auth"));
    }

    #[test]
    fn normalized_scope_omits_blank_values() {
        for blank in ["", "   "] {
            assert_eq!(normalized_scope(blank), None);
        }
    }

    #[test]
    fn normalized_scope_preserves_non_empty_values() {
        let cases = [
            ("mcp:invoke", "mcp:invoke"),
            ("  mcp:invoke openid  ", "mcp:invoke openid"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalized_scope(input), Some(expected));
        }
    }
}