Skip to main content

greentic_setup/
oauth_callback.rs

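//! Provider-agnostic OAuth callback completion helpers.
//!
//! These helpers validate the signed `state` returned to an OAuth callback. They then obtain
//! tokens, either by exchanging the authorization code or from a caller-supplied token response.
//! Finally they persist the provider-mapped values as secrets and mark the setup action complete.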
2
3use std::path::Path;
4
5use anyhow::{Context, Result, bail};
6use serde::{Deserialize, Serialize};
7use serde_json::{Map as JsonMap, Value};
8
9use crate::setup_actions::{OAuthMetadata, SetupActionStatus};
10
/// Values handed back to the OAuth callback: the authorization `code` and the signed `state`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OAuthCallbackInput {
    /// Authorization code issued by the provider; the callback helpers reject it when blank.
    pub code: String,
    /// Signed state token identifying the provider, tenant, team and setup action.
    pub state: String,
}
16
/// Summary returned after an OAuth callback has been completed successfully.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OAuthCallbackReport {
    /// Provider whose setup action was completed (taken from the validated state).
    pub provider_id: String,
    /// Tenant taken from the validated state.
    pub tenant: String,
    /// Team taken from the validated state.
    pub team: String,
    /// Setup action that was marked complete.
    pub action_id: String,
    /// Secret keys written, as produced by the provider's token-response mapping.
    pub persisted_secret_keys: Vec<String>,
}
25
26pub fn load_provider_oauth_metadata(
27    bundle_root: &Path,
28    provider_id: &str,
29    extension_key: &str,
30) -> Result<OAuthMetadata> {
31    let discovered = crate::discovery::discover(bundle_root)
32        .context("failed to discover providers for OAuth callback")?;
33    let provider = discovered
34        .find_setup_target(provider_id)
35        .ok_or_else(|| anyhow::anyhow!("provider not found for OAuth callback: {provider_id}"))?;
36    let raw = crate::discovery::read_pack_extension(&provider.pack_path, extension_key)?
37        .ok_or_else(|| anyhow::anyhow!("provider missing OAuth metadata: {extension_key}"))?;
38    let metadata = raw.get("inline").cloned().unwrap_or(raw);
39    serde_json::from_value(metadata).context("failed to parse provider OAuth metadata")
40}
41
42pub async fn complete_oauth_callback_with_token_response(
43    bundle_root: &Path,
44    env: &str,
45    input: &OAuthCallbackInput,
46    token_response: &Value,
47    extension_key: &str,
48) -> Result<OAuthCallbackReport> {
49    if input.code.trim().is_empty() {
50        bail!("OAuth callback missing code");
51    }
52    let key = crate::setup_actions::load_or_create_signing_key(bundle_root)?;
53    let state = crate::setup_actions::validate_oauth_state(
54        &input.state,
55        &key,
56        None,
57        None,
58        None,
59        crate::setup_actions::current_epoch_secs(),
60    )?;
61    let action = crate::setup_actions::load_setup_action(
62        bundle_root,
63        &state.tenant,
64        &state.team,
65        &state.provider_id,
66        &state.action_id,
67    )?
68    .ok_or_else(|| anyhow::anyhow!("setup action not found: {}", state.action_id))?;
69    if action.status != SetupActionStatus::Pending {
70        bail!("setup action is not pending");
71    }
72
73    let metadata = load_provider_oauth_metadata(bundle_root, &state.provider_id, extension_key)?;
74    let mapped = crate::setup_actions::map_oauth_token_response(&metadata, token_response)?;
75    let config = Value::Object(
76        mapped
77            .iter()
78            .map(|(key, value)| (key.clone(), Value::String(value.clone())))
79            .collect::<JsonMap<_, _>>(),
80    );
81
82    crate::qa::persist::persist_all_config_as_secrets(
83        bundle_root,
84        env,
85        &state.tenant,
86        Some(&state.team),
87        &state.provider_id,
88        &config,
89        None,
90    )
91    .await?;
92    crate::setup_actions::mark_setup_action_complete(
93        bundle_root,
94        &state.tenant,
95        &state.team,
96        &state.provider_id,
97        &state.action_id,
98    )?;
99
100    Ok(OAuthCallbackReport {
101        provider_id: state.provider_id,
102        tenant: state.tenant,
103        team: state.team,
104        action_id: state.action_id,
105        persisted_secret_keys: mapped.keys().cloned().collect(),
106    })
107}
108
109pub async fn complete_oauth_callback(
110    bundle_root: &Path,
111    env: &str,
112    input: &OAuthCallbackInput,
113    extension_key: &str,
114) -> Result<OAuthCallbackReport> {
115    if input.code.trim().is_empty() {
116        bail!("OAuth callback missing code");
117    }
118    let key = crate::setup_actions::load_or_create_signing_key(bundle_root)?;
119    let state = crate::setup_actions::validate_oauth_state(
120        &input.state,
121        &key,
122        None,
123        None,
124        None,
125        crate::setup_actions::current_epoch_secs(),
126    )?;
127    let action = crate::setup_actions::load_setup_action(
128        bundle_root,
129        &state.tenant,
130        &state.team,
131        &state.provider_id,
132        &state.action_id,
133    )?
134    .ok_or_else(|| anyhow::anyhow!("setup action not found: {}", state.action_id))?;
135    if action.status != SetupActionStatus::Pending {
136        bail!("setup action is not pending");
137    }
138
139    let metadata = load_provider_oauth_metadata(bundle_root, &state.provider_id, extension_key)?;
140    let callback_path = action
141        .callback_path
142        .as_deref()
143        .or(metadata.redirect_path.as_deref())
144        .ok_or_else(|| anyhow::anyhow!("OAuth callback path is missing"))?;
145    let public_base_url = resolve_public_base_url(
146        bundle_root,
147        &state.tenant,
148        Some(&state.team),
149        &state.provider_id,
150    )?;
151    let redirect_uri = format!(
152        "{}{}",
153        public_base_url.trim_end_matches('/'),
154        ensure_leading_slash(callback_path)
155    );
156    let setup_answers = load_provider_setup_answers(bundle_root, &state.provider_id)?;
157    let client_id = first_nonempty(&setup_answers, &["client_id", "oauth_client_id"])
158        .ok_or_else(|| anyhow::anyhow!("OAuth client_id is missing from provider setup answers"))?;
159    let client_secret = first_nonempty(&setup_answers, &["client_secret", "oauth_client_secret"])
160        .ok_or_else(|| {
161        anyhow::anyhow!("OAuth client_secret is missing from provider setup answers")
162    })?;
163
164    let token_response = exchange_oauth_code(
165        &metadata,
166        input.code.trim(),
167        &redirect_uri,
168        &client_id,
169        &client_secret,
170    )?;
171
172    complete_oauth_callback_with_token_response(
173        bundle_root,
174        env,
175        input,
176        &token_response,
177        extension_key,
178    )
179    .await
180}
181
182pub fn exchange_oauth_code(
183    metadata: &OAuthMetadata,
184    code: &str,
185    redirect_uri: &str,
186    client_id: &str,
187    client_secret: &str,
188) -> Result<Value> {
189    let mut response = ureq::post(&metadata.token_url)
190        .send_form([
191            ("grant_type", "authorization_code"),
192            ("code", code),
193            ("redirect_uri", redirect_uri),
194            ("client_id", client_id),
195            ("client_secret", client_secret),
196        ])
197        .context("OAuth token exchange failed")?;
198    response
199        .body_mut()
200        .read_json::<Value>()
201        .context("failed to parse OAuth token response")
202}
203
204pub fn resolve_public_base_url(
205    bundle_root: &Path,
206    tenant: &str,
207    team: Option<&str>,
208    provider_id: &str,
209) -> Result<String> {
210    if let Some(value) = load_provider_setup_answers(bundle_root, provider_id)?
211        .get("public_base_url")
212        .and_then(Value::as_str)
213        .map(str::trim)
214        .filter(|value| !value.is_empty())
215    {
216        return Ok(value.to_string());
217    }
218
219    if let Some(policy) =
220        crate::platform_setup::load_effective_static_routes_defaults(bundle_root, tenant, team)?
221        && let Some(value) = policy.public_base_url
222    {
223        return Ok(value);
224    }
225
226    bail!("This provider requires a public_base_url to generate OAuth callback and webhook URLs.")
227}
228
229fn load_provider_setup_answers(bundle_root: &Path, provider_id: &str) -> Result<Value> {
230    let path = bundle_root
231        .join("state")
232        .join("config")
233        .join(provider_id)
234        .join("setup-answers.json");
235    if !path.exists() {
236        return Ok(Value::Object(JsonMap::new()));
237    }
238    let raw = std::fs::read_to_string(&path)
239        .with_context(|| format!("failed to read {}", path.display()))?;
240    serde_json::from_str(&raw).with_context(|| format!("failed to parse {}", path.display()))
241}
242
243fn first_nonempty(value: &Value, keys: &[&str]) -> Option<String> {
244    let obj = value.as_object()?;
245    keys.iter().find_map(|key| {
246        obj.get(*key)
247            .and_then(Value::as_str)
248            .map(str::trim)
249            .filter(|value| !value.is_empty())
250            .map(ToString::to_string)
251    })
252}
253
/// Returns `value` prefixed with `/`, unless it already starts with one.
fn ensure_leading_slash(value: &str) -> String {
    let mut path = String::with_capacity(value.len() + 1);
    if !value.starts_with('/') {
        path.push('/');
    }
    path.push_str(value);
    path
}
261
#[cfg(test)]
mod tests {
    use super::*;
    use greentic_secrets_lib::SecretsStore;
    use serde_json::json;
    use std::io::Write;
    use zip::write::{FileOptions, ZipWriter};

    /// Writes a `.gtpack` for `messaging-example` whose `messaging.oauth.v1` extension declares
    /// a token URL and the single secret key `EXAMPLE_TOKEN` (no `redirect_path`).
    fn write_provider_pack(path: &Path) -> anyhow::Result<()> {
        write_provider_pack_with_manifest(
            path,
            json!({
                "pack_id": "messaging-example",
                "extensions": {
                    "messaging.oauth.v1": {
                        "token_url": "https://example.com/token",
                        "secret_keys": ["EXAMPLE_TOKEN"]
                    }
                }
            }),
        )
    }

    /// Writes a zip archive at `path` holding only `pack.manifest.json` (stored, uncompressed)
    /// with `manifest` serialized as its contents.
    fn write_provider_pack_with_manifest(
        path: &Path,
        manifest: serde_json::Value,
    ) -> anyhow::Result<()> {
        let file = std::fs::File::create(path)?;
        let mut writer = ZipWriter::new(file);
        let options: FileOptions<'_, ()> =
            FileOptions::default().compression_method(zip::CompressionMethod::Stored);
        writer.start_file("pack.manifest.json", options)?;
        writer.write_all(manifest.to_string().as_bytes())?;
        writer.finish()?;
        Ok(())
    }

    /// Writes `answers` to `state/config/<provider_id>/setup-answers.json` under `bundle`.
    fn persist_provider_answers(
        bundle: &Path,
        provider_id: &str,
        answers: serde_json::Value,
    ) -> anyhow::Result<()> {
        let dir = bundle.join("state/config").join(provider_id);
        std::fs::create_dir_all(&dir)?;
        std::fs::write(dir.join("setup-answers.json"), answers.to_string())?;
        Ok(())
    }

    /// Signs a state token for provider `messaging-example`, tenant `demo`, team `default` and
    /// `action_id`, expiring 60 seconds from now.
    fn signed_state(bundle: &Path, action_id: &str) -> anyhow::Result<String> {
        let key = crate::setup_actions::load_or_create_signing_key(bundle)?;
        let state_payload = crate::setup_actions::OAuthStatePayload {
            provider_id: "messaging-example".into(),
            tenant: "demo".into(),
            team: "default".into(),
            action_id: action_id.into(),
            nonce: "nonce".into(),
            expires_at: crate::setup_actions::current_epoch_secs() + 60,
        };
        crate::setup_actions::sign_oauth_state(&state_payload, &key)
    }

    /// Persists one `oauth_install_button` setup action for `messaging-example`
    /// (tenant `demo`, team `default`) with `status` and `callback_path`, and returns the signed
    /// state embedded in it.
    fn persist_action(
        bundle: &Path,
        action_id: &str,
        status: crate::setup_actions::SetupActionStatus,
        callback_path: Option<&str>,
    ) -> anyhow::Result<String> {
        let state = signed_state(bundle, action_id)?;
        let mut actions = crate::setup_actions::extract_setup_actions(
            "messaging-example",
            "demo",
            Some("default"),
            &json!({
                "setup_actions": [{
                    "id": action_id,
                    "kind": "oauth_install_button",
                    "label": "Add",
                    "authorize_url": "https://example.com/auth",
                    "callback_path": callback_path,
                    "state": state,
                    "status": status
                }]
            }),
        )?;
        // Force the requested status onto the extracted action.
        // NOTE(review): presumably `extract_setup_actions` does not carry `status` through as-is;
        // confirm.
        actions[0].status = status;
        crate::setup_actions::persist_setup_actions(bundle, &actions)?;
        Ok(state)
    }

    /// The happy path with a pre-fetched token response. The report lists `EXAMPLE_TOKEN`, the
    /// action becomes `Complete`, and the dev secrets store holds the `access_token` value under
    /// the canonical URI.
    #[tokio::test]
    async fn callback_maps_token_to_secret_and_marks_action_complete() -> anyhow::Result<()> {
        let temp = tempfile::tempdir()?;
        let bundle = temp.path();
        std::fs::create_dir_all(bundle.join("providers/messaging"))?;
        write_provider_pack(&bundle.join("providers/messaging/messaging-example.gtpack"))?;

        let state = persist_action(
            bundle,
            "install",
            crate::setup_actions::SetupActionStatus::Pending,
            None,
        )?;

        let report = complete_oauth_callback_with_token_response(
            bundle,
            "dev",
            &OAuthCallbackInput {
                code: "code".into(),
                state,
            },
            &json!({"access_token": "token-value"}),
            "messaging.oauth.v1",
        )
        .await?;

        assert_eq!(report.persisted_secret_keys, vec!["EXAMPLE_TOKEN"]);
        let action = crate::setup_actions::load_setup_action(
            bundle,
            "demo",
            "default",
            "messaging-example",
            "install",
        )?
        .unwrap();
        assert_eq!(
            action.status,
            crate::setup_actions::SetupActionStatus::Complete
        );

        let store = crate::secrets::open_dev_store(bundle)?;
        let uri = crate::canonical_secret_uri(
            "dev",
            "demo",
            Some("default"),
            "messaging-example",
            "EXAMPLE_TOKEN",
        );
        let bytes = store.get(&uri).await?;
        assert_eq!(String::from_utf8(bytes)?, "token-value");
        Ok(())
    }

    /// Error messages distinguish an unknown provider from a pack that lacks the extension.
    #[test]
    fn load_provider_oauth_metadata_reports_missing_provider_and_extension() -> anyhow::Result<()> {
        let temp = tempfile::tempdir()?;
        let bundle = temp.path();
        std::fs::create_dir_all(bundle.join("providers/messaging"))?;
        write_provider_pack_with_manifest(
            &bundle.join("providers/messaging/messaging-example.gtpack"),
            json!({"pack_id": "messaging-example"}),
        )?;

        let missing_provider =
            load_provider_oauth_metadata(bundle, "messaging-missing", "messaging.oauth.v1")
                .unwrap_err()
                .to_string();
        assert!(missing_provider.contains("provider not found"));

        let missing_extension =
            load_provider_oauth_metadata(bundle, "messaging-example", "messaging.oauth.v1")
                .unwrap_err()
                .to_string();
        assert!(missing_extension.contains("missing OAuth metadata"));
        Ok(())
    }

    /// Extensions wrapped as `{"kind": ..., "inline": {...}}` are unwrapped to the inline object.
    #[test]
    fn load_provider_oauth_metadata_accepts_inline_extension_wrapper() -> anyhow::Result<()> {
        let temp = tempfile::tempdir()?;
        let bundle = temp.path();
        std::fs::create_dir_all(bundle.join("providers/messaging"))?;
        write_provider_pack_with_manifest(
            &bundle.join("providers/messaging/messaging-example.gtpack"),
            json!({
                "pack_id": "messaging-example",
                "extensions": {
                    "messaging.oauth.v1": {
                        "kind": "messaging.oauth.v1",
                        "inline": {
                            "token_url": "https://example.com/token",
                            "secret_keys": ["EXAMPLE_TOKEN"]
                        }
                    }
                }
            }),
        )?;

        let metadata =
            load_provider_oauth_metadata(bundle, "messaging-example", "messaging.oauth.v1")?;

        assert_eq!(metadata.token_url, "https://example.com/token");
        assert_eq!(metadata.secret_keys, vec!["EXAMPLE_TOKEN"]);
        Ok(())
    }

    /// A `public_base_url` provider setup answer is returned as-is.
    #[test]
    fn resolve_public_base_url_prefers_provider_answer() -> anyhow::Result<()> {
        let temp = tempfile::tempdir()?;
        let bundle = temp.path();
        persist_provider_answers(
            bundle,
            "messaging-example",
            json!({"public_base_url": "https://provider.example.com"}),
        )?;

        let resolved =
            resolve_public_base_url(bundle, "demo", Some("default"), "messaging-example")?;
        assert_eq!(resolved, "https://provider.example.com");
        Ok(())
    }

    /// Without provider answers, the URL comes from the persisted static-routes policy (first
    /// bundle) or from `state/runtime/demo.default/endpoints.json` (second bundle).
    /// `resolve_public_base_url` does not read the runtime file directly, so the second half
    /// relies on `load_effective_static_routes_defaults` falling back to runtime endpoints.
    /// NOTE(review): confirm that fallback.
    #[test]
    fn resolve_public_base_url_uses_static_routes_and_runtime_fallback() -> anyhow::Result<()> {
        let temp = tempfile::tempdir()?;
        let bundle = temp.path();
        crate::platform_setup::persist_static_routes_artifact(
            bundle,
            &crate::platform_setup::StaticRoutesPolicy {
                public_base_url: Some("https://static.example.com".into()),
                ..crate::platform_setup::StaticRoutesPolicy::default()
            },
        )?;
        let resolved =
            resolve_public_base_url(bundle, "demo", Some("default"), "messaging-example")?;
        assert_eq!(resolved, "https://static.example.com");

        let temp = tempfile::tempdir()?;
        let bundle = temp.path();
        let runtime_dir = bundle.join("state/runtime/demo.default");
        std::fs::create_dir_all(&runtime_dir)?;
        std::fs::write(
            runtime_dir.join("endpoints.json"),
            json!({"public_base_url": "https://runtime.example.com"}).to_string(),
        )?;
        let resolved =
            resolve_public_base_url(bundle, "demo", Some("default"), "messaging-example")?;
        assert_eq!(resolved, "https://runtime.example.com");
        Ok(())
    }

    /// An empty bundle produces the "requires a public_base_url" error.
    #[test]
    fn resolve_public_base_url_errors_when_missing() {
        let temp = tempfile::tempdir().unwrap();
        let err =
            resolve_public_base_url(temp.path(), "demo", Some("default"), "messaging-example")
                .unwrap_err()
                .to_string();
        assert!(err.contains("requires a public_base_url"));
    }

    /// A missing answers file yields `{}`. Blank aliases are skipped in favour of later
    /// non-blank ones, and callback paths gain a single leading slash only when needed.
    #[test]
    fn setup_answer_helpers_handle_missing_file_and_nonempty_aliases() -> anyhow::Result<()> {
        let temp = tempfile::tempdir()?;
        let empty = load_provider_setup_answers(temp.path(), "messaging-example")?;
        assert_eq!(empty, Value::Object(JsonMap::new()));

        let answers = json!({"client_id": "  ", "oauth_client_id": "client"});
        assert_eq!(
            first_nonempty(&answers, &["client_id", "oauth_client_id"]).as_deref(),
            Some("client")
        );
        assert_eq!(ensure_leading_slash("oauth/callback"), "/oauth/callback");
        assert_eq!(ensure_leading_slash("/oauth/callback"), "/oauth/callback");
        Ok(())
    }

    /// Validation order is: a blank code is rejected first, then a missing action, then an
    /// action that is not pending.
    #[tokio::test]
    async fn callback_rejects_empty_code_missing_action_and_completed_action() -> anyhow::Result<()>
    {
        let temp = tempfile::tempdir()?;
        let bundle = temp.path();
        let state = signed_state(bundle, "missing")?;
        let err = complete_oauth_callback_with_token_response(
            bundle,
            "dev",
            &OAuthCallbackInput {
                code: " ".into(),
                state: state.clone(),
            },
            &json!({"access_token": "token"}),
            "messaging.oauth.v1",
        )
        .await
        .unwrap_err()
        .to_string();
        assert!(err.contains("missing code"));

        let err = complete_oauth_callback_with_token_response(
            bundle,
            "dev",
            &OAuthCallbackInput {
                code: "code".into(),
                state,
            },
            &json!({"access_token": "token"}),
            "messaging.oauth.v1",
        )
        .await
        .unwrap_err()
        .to_string();
        assert!(err.contains("setup action not found"));

        std::fs::create_dir_all(bundle.join("providers/messaging"))?;
        write_provider_pack(&bundle.join("providers/messaging/messaging-example.gtpack"))?;
        let state = persist_action(
            bundle,
            "install",
            crate::setup_actions::SetupActionStatus::Complete,
            None,
        )?;
        let err = complete_oauth_callback_with_token_response(
            bundle,
            "dev",
            &OAuthCallbackInput {
                code: "code".into(),
                state,
            },
            &json!({"access_token": "token"}),
            "messaging.oauth.v1",
        )
        .await
        .unwrap_err()
        .to_string();
        assert!(err.contains("not pending"));
        Ok(())
    }

    /// The live flow rejects a blank code, and it fails on a missing callback path before any
    /// HTTP request. Here the action has `callback_path: None` and the pack metadata has no
    /// `redirect_path`.
    #[tokio::test]
    async fn live_callback_validates_before_network_exchange() -> anyhow::Result<()> {
        let temp = tempfile::tempdir()?;
        let bundle = temp.path();
        std::fs::create_dir_all(bundle.join("providers/messaging"))?;
        write_provider_pack(&bundle.join("providers/messaging/messaging-example.gtpack"))?;
        let state = persist_action(
            bundle,
            "install",
            crate::setup_actions::SetupActionStatus::Pending,
            None,
        )?;

        let err = complete_oauth_callback(
            bundle,
            "dev",
            &OAuthCallbackInput {
                code: " ".into(),
                state: state.clone(),
            },
            "messaging.oauth.v1",
        )
        .await
        .unwrap_err()
        .to_string();
        assert!(err.contains("missing code"));

        let err = complete_oauth_callback(
            bundle,
            "dev",
            &OAuthCallbackInput {
                code: "code".into(),
                state,
            },
            "messaging.oauth.v1",
        )
        .await
        .unwrap_err()
        .to_string();
        assert!(err.contains("callback path is missing"));
        Ok(())
    }
}