Skip to main content

greentic_setup/
setup_actions.rs

//! Provider-agnostic setup actions and OAuth setup helpers.
2
3use std::collections::BTreeMap;
4use std::path::{Path, PathBuf};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use anyhow::{Context, Result, anyhow, bail};
8use base64::Engine;
9use base64::engine::general_purpose::URL_SAFE_NO_PAD;
10use hmac::{Hmac, KeyInit, Mac};
11use serde::{Deserialize, Serialize};
12use serde_json::{Map as JsonMap, Value};
13use sha2::Sha256;
14
15type HmacSha256 = Hmac<Sha256>;
16
17#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19pub enum SetupActionKind {
20    OauthInstallButton,
21    OauthDeviceCode,
22    OpenUrl,
23    CopySecret,
24    ManualStep,
25    DownloadFile,
26    AdminConsentButton,
27    #[serde(untagged)]
28    Other(String),
29}
30
31#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum SetupActionStatus {
34    Pending,
35    Complete,
36    Failed,
37}
38
39#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
40pub struct SetupAction {
41    pub id: String,
42    pub kind: SetupActionKind,
43    pub label: String,
44    pub provider_id: String,
45    pub tenant: String,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub team: Option<String>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub authorize_url: Option<String>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub callback_path: Option<String>,
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub state: Option<String>,
54    pub status: SetupActionStatus,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub created_at: Option<String>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub completed_at: Option<String>,
59    #[serde(flatten)]
60    pub extra: JsonMap<String, Value>,
61}
62
63#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
64pub struct SetupActionStateFile {
65    pub provider_id: String,
66    pub tenant: String,
67    pub team: String,
68    pub actions: Vec<SetupAction>,
69}
70
71#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
72pub struct OAuthStatePayload {
73    pub provider_id: String,
74    pub tenant: String,
75    pub team: String,
76    pub action_id: String,
77    pub nonce: String,
78    pub expires_at: u64,
79}
80
81#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
82pub struct OAuthMetadata {
83    #[serde(default)]
84    pub auth_type: Option<String>,
85    #[serde(default)]
86    pub authorize_url: Option<String>,
87    pub token_url: String,
88    #[serde(default)]
89    pub redirect_path: Option<String>,
90    #[serde(default)]
91    pub scopes: Vec<String>,
92    #[serde(default)]
93    pub secret_keys: Vec<String>,
94    #[serde(default)]
95    pub response_secret_map: BTreeMap<String, String>,
96}
97
98pub fn extract_setup_actions(
99    provider_id: &str,
100    tenant: &str,
101    team: Option<&str>,
102    value: &Value,
103) -> Result<Vec<SetupAction>> {
104    let Some(actions) = value.get("setup_actions").and_then(Value::as_array) else {
105        return Ok(Vec::new());
106    };
107
108    actions
109        .iter()
110        .map(|raw| parse_setup_action(provider_id, tenant, team, raw))
111        .collect()
112}
113
114pub fn strip_setup_actions(value: &Value) -> Value {
115    let mut cloned = value.clone();
116    if let Some(obj) = cloned.as_object_mut() {
117        obj.remove("setup_actions");
118        obj.remove("pending_setup_actions");
119    }
120    cloned
121}
122
123pub fn persist_setup_actions(bundle_root: &Path, actions: &[SetupAction]) -> Result<Vec<PathBuf>> {
124    let mut grouped: BTreeMap<(String, String, String), Vec<SetupAction>> = BTreeMap::new();
125    for action in actions {
126        grouped
127            .entry((
128                action.provider_id.clone(),
129                action.tenant.clone(),
130                team_segment(action.team.as_deref()).to_string(),
131            ))
132            .or_default()
133            .push(action.clone());
134    }
135
136    let mut paths = Vec::new();
137    for ((provider_id, tenant, team), new_actions) in grouped {
138        let path = setup_actions_state_path(bundle_root, &tenant, &team, &provider_id);
139        let mut file = if path.exists() {
140            let raw = std::fs::read_to_string(&path)
141                .with_context(|| format!("failed to read {}", path.display()))?;
142            serde_json::from_str::<SetupActionStateFile>(&raw)
143                .with_context(|| format!("failed to parse {}", path.display()))?
144        } else {
145            SetupActionStateFile {
146                provider_id: provider_id.clone(),
147                tenant: tenant.clone(),
148                team: team.clone(),
149                actions: Vec::new(),
150            }
151        };
152
153        for mut action in new_actions {
154            if action.created_at.is_none() {
155                action.created_at = Some(now_stamp());
156            }
157            if let Some(existing) = file.actions.iter_mut().find(|a| a.id == action.id) {
158                let created_at = existing.created_at.clone().or(action.created_at.clone());
159                *existing = action;
160                existing.created_at = created_at;
161            } else {
162                file.actions.push(action);
163            }
164        }
165
166        if let Some(parent) = path.parent() {
167            std::fs::create_dir_all(parent)?;
168        }
169        let payload = serde_json::to_string_pretty(&file)?;
170        std::fs::write(&path, payload)
171            .with_context(|| format!("failed to write {}", path.display()))?;
172        paths.push(path);
173    }
174    Ok(paths)
175}
176
177pub fn sign_pending_oauth_actions(bundle_root: &Path, actions: &mut [SetupAction]) -> Result<()> {
178    let key = load_or_create_signing_key(bundle_root)?;
179    for action in actions {
180        if action.status != SetupActionStatus::Pending
181            || action.kind != SetupActionKind::OauthInstallButton
182            || action.state.is_some()
183        {
184            continue;
185        }
186        let team = team_segment(action.team.as_deref()).to_string();
187        let payload = OAuthStatePayload {
188            provider_id: action.provider_id.clone(),
189            tenant: action.tenant.clone(),
190            team,
191            action_id: action.id.clone(),
192            nonce: URL_SAFE_NO_PAD.encode(rand::random::<[u8; 16]>()),
193            expires_at: current_epoch_secs() + 15 * 60,
194        };
195        let state = sign_oauth_state(&payload, &key)?;
196        if let Some(authorize_url) = action.authorize_url.as_mut()
197            && !authorize_url_contains_state(authorize_url)
198            && let Ok(mut parsed) = url::Url::parse(authorize_url)
199        {
200            parsed.query_pairs_mut().append_pair("state", &state);
201            *authorize_url = parsed.to_string();
202        }
203        action.state = Some(state);
204    }
205    Ok(())
206}
207
208pub fn load_setup_action(
209    bundle_root: &Path,
210    tenant: &str,
211    team: &str,
212    provider_id: &str,
213    action_id: &str,
214) -> Result<Option<SetupAction>> {
215    let path = setup_actions_state_path(bundle_root, tenant, team, provider_id);
216    if !path.exists() {
217        return Ok(None);
218    }
219    let raw = std::fs::read_to_string(&path)
220        .with_context(|| format!("failed to read {}", path.display()))?;
221    let file: SetupActionStateFile = serde_json::from_str(&raw)
222        .with_context(|| format!("failed to parse {}", path.display()))?;
223    Ok(file.actions.into_iter().find(|a| a.id == action_id))
224}
225
226pub fn mark_setup_action_complete(
227    bundle_root: &Path,
228    tenant: &str,
229    team: &str,
230    provider_id: &str,
231    action_id: &str,
232) -> Result<()> {
233    let path = setup_actions_state_path(bundle_root, tenant, team, provider_id);
234    let raw = std::fs::read_to_string(&path)
235        .with_context(|| format!("failed to read {}", path.display()))?;
236    let mut file: SetupActionStateFile = serde_json::from_str(&raw)
237        .with_context(|| format!("failed to parse {}", path.display()))?;
238    let Some(action) = file.actions.iter_mut().find(|a| a.id == action_id) else {
239        bail!("setup action not found: {action_id}");
240    };
241    action.status = SetupActionStatus::Complete;
242    action.completed_at = Some(now_stamp());
243    let payload = serde_json::to_string_pretty(&file)?;
244    std::fs::write(&path, payload)
245        .with_context(|| format!("failed to write {}", path.display()))?;
246    Ok(())
247}
248
249pub fn setup_actions_state_path(
250    bundle_root: &Path,
251    tenant: &str,
252    team: &str,
253    provider_id: &str,
254) -> PathBuf {
255    bundle_root
256        .join("state")
257        .join("config")
258        .join("setup-actions")
259        .join(tenant)
260        .join(team_segment(Some(team)))
261        .join(format!("{provider_id}.json"))
262}
263
/// Location of the OAuth state signing key:
/// `<bundle_root>/.greentic/setup-oauth-state-key`.
pub fn signing_key_path(bundle_root: &Path) -> PathBuf {
    let mut path = bundle_root.to_path_buf();
    path.push(".greentic");
    path.push("setup-oauth-state-key");
    path
}
267
268pub fn load_or_create_signing_key(bundle_root: &Path) -> Result<Vec<u8>> {
269    let path = signing_key_path(bundle_root);
270    if path.exists() {
271        let raw = std::fs::read_to_string(&path)
272            .with_context(|| format!("failed to read {}", path.display()))?;
273        return URL_SAFE_NO_PAD
274            .decode(raw.trim())
275            .context("failed to decode setup OAuth state signing key");
276    }
277    let bytes: [u8; 32] = rand::random();
278    if let Some(parent) = path.parent() {
279        std::fs::create_dir_all(parent)?;
280    }
281    std::fs::write(&path, URL_SAFE_NO_PAD.encode(bytes))
282        .with_context(|| format!("failed to write {}", path.display()))?;
283    Ok(bytes.to_vec())
284}
285
286pub fn sign_oauth_state(payload: &OAuthStatePayload, key: &[u8]) -> Result<String> {
287    let payload_json = serde_json::to_vec(payload)?;
288    let payload_b64 = URL_SAFE_NO_PAD.encode(payload_json);
289    let mut mac = HmacSha256::new_from_slice(key).context("invalid HMAC key")?;
290    mac.update(payload_b64.as_bytes());
291    let sig = mac.finalize().into_bytes();
292    Ok(format!("{payload_b64}.{}", URL_SAFE_NO_PAD.encode(sig)))
293}
294
295pub fn validate_oauth_state(
296    token: &str,
297    key: &[u8],
298    expected_provider_id: Option<&str>,
299    expected_tenant: Option<&str>,
300    expected_team: Option<&str>,
301    now_epoch: u64,
302) -> Result<OAuthStatePayload> {
303    let (payload_b64, sig_b64) = token
304        .split_once('.')
305        .ok_or_else(|| anyhow!("invalid OAuth state format"))?;
306    let sig = URL_SAFE_NO_PAD
307        .decode(sig_b64)
308        .context("invalid OAuth state signature encoding")?;
309    let mut mac = HmacSha256::new_from_slice(key).context("invalid HMAC key")?;
310    mac.update(payload_b64.as_bytes());
311    mac.verify_slice(&sig)
312        .map_err(|_| anyhow!("invalid OAuth state signature"))?;
313    let payload_bytes = URL_SAFE_NO_PAD
314        .decode(payload_b64)
315        .context("invalid OAuth state payload encoding")?;
316    let payload: OAuthStatePayload =
317        serde_json::from_slice(&payload_bytes).context("invalid OAuth state payload")?;
318    if payload.expires_at <= now_epoch {
319        bail!("OAuth state has expired");
320    }
321    if let Some(expected) = expected_provider_id
322        && payload.provider_id != expected
323    {
324        bail!("OAuth state provider mismatch");
325    }
326    if let Some(expected) = expected_tenant
327        && payload.tenant != expected
328    {
329        bail!("OAuth state tenant mismatch");
330    }
331    if let Some(expected) = expected_team
332        && payload.team != expected
333    {
334        bail!("OAuth state team mismatch");
335    }
336    Ok(payload)
337}
338
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
pub fn current_epoch_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
345
346pub fn map_oauth_token_response(
347    metadata: &OAuthMetadata,
348    response: &Value,
349) -> Result<BTreeMap<String, String>> {
350    let mut mapped = BTreeMap::new();
351    for (secret_key, response_key) in &metadata.response_secret_map {
352        if let Some(value) = response.get(response_key).and_then(value_to_string) {
353            mapped.insert(secret_key.clone(), value);
354        }
355    }
356    if mapped.is_empty()
357        && let Some(token) = response.get("access_token").and_then(value_to_string)
358    {
359        for key in &metadata.secret_keys {
360            mapped.insert(key.clone(), token.clone());
361        }
362    }
363    if mapped.is_empty() {
364        bail!("OAuth token response did not contain mappable secrets");
365    }
366    Ok(mapped)
367}
368
369fn parse_setup_action(
370    provider_id: &str,
371    tenant: &str,
372    team: Option<&str>,
373    raw: &Value,
374) -> Result<SetupAction> {
375    let mut obj = raw
376        .as_object()
377        .cloned()
378        .ok_or_else(|| anyhow!("setup action must be an object"))?;
379    let id = take_string(&mut obj, "id").ok_or_else(|| anyhow!("setup action missing id"))?;
380    let kind = match take_string(&mut obj, "kind")
381        .ok_or_else(|| anyhow!("setup action missing kind"))?
382        .as_str()
383    {
384        "oauth_install_button" => SetupActionKind::OauthInstallButton,
385        "oauth_device_code" => SetupActionKind::OauthDeviceCode,
386        "open_url" => SetupActionKind::OpenUrl,
387        "copy_secret" => SetupActionKind::CopySecret,
388        "manual_step" => SetupActionKind::ManualStep,
389        "download_file" => SetupActionKind::DownloadFile,
390        "admin_consent_button" => SetupActionKind::AdminConsentButton,
391        other => SetupActionKind::Other(other.to_string()),
392    };
393    let label = take_string(&mut obj, "label").unwrap_or_else(|| id.clone());
394    let provider_id =
395        take_string(&mut obj, "provider_id").unwrap_or_else(|| provider_id.to_string());
396    let tenant = take_string(&mut obj, "tenant").unwrap_or_else(|| tenant.to_string());
397    let team = take_string(&mut obj, "team").or_else(|| team.map(ToString::to_string));
398    let status = match take_string(&mut obj, "status").as_deref() {
399        Some("complete") => SetupActionStatus::Complete,
400        Some("failed") => SetupActionStatus::Failed,
401        _ => SetupActionStatus::Pending,
402    };
403    Ok(SetupAction {
404        id,
405        kind,
406        label,
407        provider_id,
408        tenant,
409        team,
410        authorize_url: take_string(&mut obj, "authorize_url"),
411        callback_path: take_string(&mut obj, "callback_path"),
412        state: take_string(&mut obj, "state"),
413        status,
414        created_at: take_string(&mut obj, "created_at"),
415        completed_at: take_string(&mut obj, "completed_at"),
416        extra: obj,
417    })
418}
419
420fn take_string(obj: &mut JsonMap<String, Value>, key: &str) -> Option<String> {
421    obj.remove(key).and_then(|value| match value {
422        Value::String(text) if !text.trim().is_empty() => Some(text),
423        Value::Number(number) => Some(number.to_string()),
424        Value::Bool(value) => Some(value.to_string()),
425        _ => None,
426    })
427}
428
/// Normalises a team name: the trimmed name, or `default` when the team is
/// absent or blank.
fn team_segment(team: Option<&str>) -> &str {
    match team.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed,
        _ => "default",
    }
}
434
435fn now_stamp() -> String {
436    current_epoch_secs().to_string()
437}
438
439fn value_to_string(value: &Value) -> Option<String> {
440    match value {
441        Value::String(text) if !text.is_empty() => Some(text.clone()),
442        Value::Number(number) => Some(number.to_string()),
443        Value::Bool(value) => Some(value.to_string()),
444        _ => None,
445    }
446}
447
448fn authorize_url_contains_state(value: &str) -> bool {
449    url::Url::parse(value)
450        .ok()
451        .and_then(|url| {
452            url.query_pairs()
453                .any(|(key, _)| key == "state")
454                .then_some(())
455        })
456        .is_some()
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds the pending OAuth install action shared by several tests,
    /// scoped to `messaging-example` / `demo` / `default`.
    fn pending_install_action(authorize_url: &str, callback_path: Option<&str>) -> SetupAction {
        SetupAction {
            id: "install".to_string(),
            kind: SetupActionKind::OauthInstallButton,
            label: "Add".to_string(),
            provider_id: "messaging-example".to_string(),
            tenant: "demo".to_string(),
            team: Some("default".to_string()),
            authorize_url: Some(authorize_url.to_string()),
            callback_path: callback_path.map(str::to_string),
            state: None,
            status: SetupActionStatus::Pending,
            created_at: None,
            completed_at: None,
            extra: JsonMap::new(),
        }
    }

    #[test]
    fn extract_setup_actions_fills_scope_defaults() {
        let value = json!({
            "setup_actions": [{
                "id": "install",
                "kind": "oauth_install_button",
                "label": "Add to Example",
                "authorize_url": "https://example.com/auth"
            }]
        });
        let actions =
            extract_setup_actions("messaging-example", "demo", Some("default"), &value).unwrap();
        assert_eq!(actions.len(), 1);
        let action = &actions[0];
        assert_eq!(action.provider_id, "messaging-example");
        assert_eq!(action.tenant, "demo");
        assert_eq!(action.team.as_deref(), Some("default"));
    }

    #[test]
    fn extract_setup_actions_supports_oauth_device_code() {
        let value = json!({
            "setup_actions": [{
                "id": "connect",
                "kind": "oauth_device_code",
                "label": "Connect"
            }]
        });
        let actions =
            extract_setup_actions("messaging-teams", "demo", Some("default"), &value).unwrap();
        assert_eq!(actions.len(), 1);
        let action = &actions[0];
        assert_eq!(action.kind, SetupActionKind::OauthDeviceCode);
        assert_eq!(action.provider_id, "messaging-teams");
    }

    #[test]
    fn persist_setup_actions_upserts_by_id() {
        let temp = tempfile::tempdir().unwrap();
        let first = pending_install_action("https://example.com/one", None);
        let mut second = first.clone();
        second.authorize_url = Some("https://example.com/two".into());

        persist_setup_actions(temp.path(), &[first]).unwrap();
        persist_setup_actions(temp.path(), &[second]).unwrap();

        let path = setup_actions_state_path(temp.path(), "demo", "default", "messaging-example");
        let raw = std::fs::read_to_string(path).unwrap();
        let file: SetupActionStateFile = serde_json::from_str(&raw).unwrap();
        assert_eq!(file.actions.len(), 1);
        assert_eq!(
            file.actions[0].authorize_url.as_deref(),
            Some("https://example.com/two")
        );
    }

    #[test]
    fn oauth_state_rejects_bad_signature_and_expiry() {
        let key = b"test-key";
        let payload = OAuthStatePayload {
            provider_id: "messaging-example".into(),
            tenant: "demo".into(),
            team: "default".into(),
            action_id: "install".into(),
            nonce: "n".into(),
            expires_at: 100,
        };
        let token = sign_oauth_state(&payload, key).unwrap();

        // Valid one second before expiry.
        assert!(validate_oauth_state(&token, key, None, None, None, 99).is_ok());
        // A different key must not verify.
        assert!(validate_oauth_state(&token, b"other", None, None, None, 99).is_err());
        // Expired exactly at `expires_at`.
        assert!(validate_oauth_state(&token, key, None, None, None, 100).is_err());
    }

    #[test]
    fn sign_pending_oauth_actions_adds_state_to_action_and_url() {
        let temp = tempfile::tempdir().unwrap();
        let mut actions = vec![pending_install_action(
            "https://example.com/oauth?client_id=abc",
            Some("/oauth/callback/example"),
        )];

        sign_pending_oauth_actions(temp.path(), &mut actions).unwrap();

        let signed = &actions[0];
        let state = signed.state.as_deref().unwrap();
        assert!(signed.authorize_url.as_deref().unwrap().contains("state="));

        let key = load_or_create_signing_key(temp.path()).unwrap();
        let payload =
            validate_oauth_state(state, &key, Some("messaging-example"), None, None, 0).unwrap();
        assert_eq!(payload.action_id, "install");
    }

    #[test]
    fn token_response_maps_access_token_to_secret_keys() {
        let metadata = OAuthMetadata {
            token_url: "https://example.com/token".into(),
            secret_keys: vec!["EXAMPLE_TOKEN".into()],
            ..Default::default()
        };
        let response = json!({"access_token": "xoxb"});
        let mapped = map_oauth_token_response(&metadata, &response).unwrap();
        assert_eq!(
            mapped.get("EXAMPLE_TOKEN").map(String::as_str),
            Some("xoxb")
        );
    }
}