1use std::collections::BTreeMap;
4use std::path::{Path, PathBuf};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use anyhow::{Context, Result, anyhow, bail};
8use base64::Engine;
9use base64::engine::general_purpose::URL_SAFE_NO_PAD;
10use hmac::{Hmac, KeyInit, Mac};
11use serde::{Deserialize, Serialize};
12use serde_json::{Map as JsonMap, Value};
13use sha2::Sha256;
14
/// HMAC-SHA256 instance used to sign ([`sign_oauth_state`]) and verify
/// ([`validate_oauth_state`]) setup OAuth `state` tokens.
type HmacSha256 = Hmac<Sha256>;
16
/// Kind of interactive step a provider asks the operator to complete during setup.
///
/// Known kinds (de)serialize as their `snake_case` names, e.g. `"oauth_install_button"`.
/// Any other string round-trips unchanged through [`SetupActionKind::Other`]. Variant-level
/// `#[serde(untagged)]` needs serde 1.0.181+, and the untagged variant must stay last.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SetupActionKind {
    /// OAuth install button; [`sign_pending_oauth_actions`] attaches a signed `state` to
    /// pending actions of this kind.
    OauthInstallButton,
    OauthDeviceCode,
    OpenUrl,
    CopySecret,
    ManualStep,
    DownloadFile,
    AdminConsentButton,
    /// Any kind string not listed above, preserved verbatim.
    #[serde(untagged)]
    Other(String),
}
30
/// Lifecycle status of a setup action, serialized in `snake_case`.
///
/// `parse_setup_action` maps a missing or unrecognized `status` to `Pending`.
/// [`mark_setup_action_complete`] sets `Complete`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SetupActionStatus {
    Pending,
    Complete,
    Failed,
}
38
/// One setup step requested by a provider, scoped to a provider / tenant / team.
///
/// Optional fields are omitted from the JSON when `None`. Unknown keys from the provider's
/// raw JSON are kept in `extra` and flattened back into the object when serialized.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetupAction {
    /// Identifier used as the upsert key within one state file.
    pub id: String,
    pub kind: SetupActionKind,
    /// Display label; `parse_setup_action` falls back to `id` when absent.
    pub label: String,
    pub provider_id: String,
    pub tenant: String,
    /// Team scope; `None` (or a blank value) is stored under the `default` team segment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub team: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authorize_url: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub callback_path: Option<String>,
    /// OAuth `state` token, e.g. one produced by [`sign_pending_oauth_actions`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub state: Option<String>,
    pub status: SetupActionStatus,
    /// Unix epoch seconds as a decimal string when stamped by this module.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<String>,
    /// Unix epoch seconds as a decimal string when stamped by this module.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<String>,
    /// All remaining provider-supplied keys, preserved as-is.
    #[serde(flatten)]
    pub extra: JsonMap<String, Value>,
}
62
/// On-disk contents of one setup-action state file.
///
/// The file lives at `state/config/setup-actions/<tenant>/<team>/<provider_id>.json` under the
/// bundle root; see [`setup_actions_state_path`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetupActionStateFile {
    pub provider_id: String,
    pub tenant: String,
    /// Normalized team segment (`default` when no team was given).
    pub team: String,
    pub actions: Vec<SetupAction>,
}
70
/// Claims carried inside a signed OAuth `state` token (see [`sign_oauth_state`]).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthStatePayload {
    pub provider_id: String,
    pub tenant: String,
    pub team: String,
    /// Id of the [`SetupAction`] this token was issued for.
    pub action_id: String,
    /// Random value; [`sign_pending_oauth_actions`] uses 16 random bytes, base64-encoded.
    pub nonce: String,
    /// Unix epoch seconds; [`validate_oauth_state`] rejects the token once `now >= expires_at`.
    pub expires_at: u64,
}
80
/// OAuth configuration for a provider.
///
/// Only `token_url` is required when deserializing; every other field defaults.
/// In this part of the file only `secret_keys` and `response_secret_map` are read, by
/// [`map_oauth_token_response`]. The remaining fields are presumably consumed elsewhere.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthMetadata {
    #[serde(default)]
    pub auth_type: Option<String>,
    #[serde(default)]
    pub authorize_url: Option<String>,
    pub token_url: String,
    #[serde(default)]
    pub redirect_path: Option<String>,
    #[serde(default)]
    pub scopes: Vec<String>,
    /// Secret names that receive the `access_token` when `response_secret_map` yields nothing.
    #[serde(default)]
    pub secret_keys: Vec<String>,
    /// Secret name -> top-level key in the token response whose value is stored.
    #[serde(default)]
    pub response_secret_map: BTreeMap<String, String>,
}
97
98pub fn extract_setup_actions(
99 provider_id: &str,
100 tenant: &str,
101 team: Option<&str>,
102 value: &Value,
103) -> Result<Vec<SetupAction>> {
104 let Some(actions) = value.get("setup_actions").and_then(Value::as_array) else {
105 return Ok(Vec::new());
106 };
107
108 actions
109 .iter()
110 .map(|raw| parse_setup_action(provider_id, tenant, team, raw))
111 .collect()
112}
113
114pub fn strip_setup_actions(value: &Value) -> Value {
115 let mut cloned = value.clone();
116 if let Some(obj) = cloned.as_object_mut() {
117 obj.remove("setup_actions");
118 obj.remove("pending_setup_actions");
119 }
120 cloned
121}
122
123pub fn persist_setup_actions(bundle_root: &Path, actions: &[SetupAction]) -> Result<Vec<PathBuf>> {
124 let mut grouped: BTreeMap<(String, String, String), Vec<SetupAction>> = BTreeMap::new();
125 for action in actions {
126 grouped
127 .entry((
128 action.provider_id.clone(),
129 action.tenant.clone(),
130 team_segment(action.team.as_deref()).to_string(),
131 ))
132 .or_default()
133 .push(action.clone());
134 }
135
136 let mut paths = Vec::new();
137 for ((provider_id, tenant, team), new_actions) in grouped {
138 let path = setup_actions_state_path(bundle_root, &tenant, &team, &provider_id);
139 let mut file = if path.exists() {
140 let raw = std::fs::read_to_string(&path)
141 .with_context(|| format!("failed to read {}", path.display()))?;
142 serde_json::from_str::<SetupActionStateFile>(&raw)
143 .with_context(|| format!("failed to parse {}", path.display()))?
144 } else {
145 SetupActionStateFile {
146 provider_id: provider_id.clone(),
147 tenant: tenant.clone(),
148 team: team.clone(),
149 actions: Vec::new(),
150 }
151 };
152
153 for mut action in new_actions {
154 if action.created_at.is_none() {
155 action.created_at = Some(now_stamp());
156 }
157 if let Some(existing) = file.actions.iter_mut().find(|a| a.id == action.id) {
158 let created_at = existing.created_at.clone().or(action.created_at.clone());
159 *existing = action;
160 existing.created_at = created_at;
161 } else {
162 file.actions.push(action);
163 }
164 }
165
166 if let Some(parent) = path.parent() {
167 std::fs::create_dir_all(parent)?;
168 }
169 let payload = serde_json::to_string_pretty(&file)?;
170 std::fs::write(&path, payload)
171 .with_context(|| format!("failed to write {}", path.display()))?;
172 paths.push(path);
173 }
174 Ok(paths)
175}
176
177pub fn sign_pending_oauth_actions(bundle_root: &Path, actions: &mut [SetupAction]) -> Result<()> {
178 let key = load_or_create_signing_key(bundle_root)?;
179 for action in actions {
180 if action.status != SetupActionStatus::Pending
181 || action.kind != SetupActionKind::OauthInstallButton
182 || action.state.is_some()
183 {
184 continue;
185 }
186 let team = team_segment(action.team.as_deref()).to_string();
187 let payload = OAuthStatePayload {
188 provider_id: action.provider_id.clone(),
189 tenant: action.tenant.clone(),
190 team,
191 action_id: action.id.clone(),
192 nonce: URL_SAFE_NO_PAD.encode(rand::random::<[u8; 16]>()),
193 expires_at: current_epoch_secs() + 15 * 60,
194 };
195 let state = sign_oauth_state(&payload, &key)?;
196 if let Some(authorize_url) = action.authorize_url.as_mut()
197 && !authorize_url_contains_state(authorize_url)
198 && let Ok(mut parsed) = url::Url::parse(authorize_url)
199 {
200 parsed.query_pairs_mut().append_pair("state", &state);
201 *authorize_url = parsed.to_string();
202 }
203 action.state = Some(state);
204 }
205 Ok(())
206}
207
208pub fn load_setup_action(
209 bundle_root: &Path,
210 tenant: &str,
211 team: &str,
212 provider_id: &str,
213 action_id: &str,
214) -> Result<Option<SetupAction>> {
215 let path = setup_actions_state_path(bundle_root, tenant, team, provider_id);
216 if !path.exists() {
217 return Ok(None);
218 }
219 let raw = std::fs::read_to_string(&path)
220 .with_context(|| format!("failed to read {}", path.display()))?;
221 let file: SetupActionStateFile = serde_json::from_str(&raw)
222 .with_context(|| format!("failed to parse {}", path.display()))?;
223 Ok(file.actions.into_iter().find(|a| a.id == action_id))
224}
225
226pub fn mark_setup_action_complete(
227 bundle_root: &Path,
228 tenant: &str,
229 team: &str,
230 provider_id: &str,
231 action_id: &str,
232) -> Result<()> {
233 let path = setup_actions_state_path(bundle_root, tenant, team, provider_id);
234 let raw = std::fs::read_to_string(&path)
235 .with_context(|| format!("failed to read {}", path.display()))?;
236 let mut file: SetupActionStateFile = serde_json::from_str(&raw)
237 .with_context(|| format!("failed to parse {}", path.display()))?;
238 let Some(action) = file.actions.iter_mut().find(|a| a.id == action_id) else {
239 bail!("setup action not found: {action_id}");
240 };
241 action.status = SetupActionStatus::Complete;
242 action.completed_at = Some(now_stamp());
243 let payload = serde_json::to_string_pretty(&file)?;
244 std::fs::write(&path, payload)
245 .with_context(|| format!("failed to write {}", path.display()))?;
246 Ok(())
247}
248
249pub fn setup_actions_state_path(
250 bundle_root: &Path,
251 tenant: &str,
252 team: &str,
253 provider_id: &str,
254) -> PathBuf {
255 bundle_root
256 .join("state")
257 .join("config")
258 .join("setup-actions")
259 .join(tenant)
260 .join(team_segment(Some(team)))
261 .join(format!("{provider_id}.json"))
262}
263
/// Path of the bundle's setup OAuth state signing key:
/// `<bundle_root>/.greentic/setup-oauth-state-key`.
pub fn signing_key_path(bundle_root: &Path) -> PathBuf {
    let mut key_path = bundle_root.join(".greentic");
    key_path.push("setup-oauth-state-key");
    key_path
}
267
268pub fn load_or_create_signing_key(bundle_root: &Path) -> Result<Vec<u8>> {
269 let path = signing_key_path(bundle_root);
270 if path.exists() {
271 let raw = std::fs::read_to_string(&path)
272 .with_context(|| format!("failed to read {}", path.display()))?;
273 return URL_SAFE_NO_PAD
274 .decode(raw.trim())
275 .context("failed to decode setup OAuth state signing key");
276 }
277 let bytes: [u8; 32] = rand::random();
278 if let Some(parent) = path.parent() {
279 std::fs::create_dir_all(parent)?;
280 }
281 std::fs::write(&path, URL_SAFE_NO_PAD.encode(bytes))
282 .with_context(|| format!("failed to write {}", path.display()))?;
283 Ok(bytes.to_vec())
284}
285
286pub fn sign_oauth_state(payload: &OAuthStatePayload, key: &[u8]) -> Result<String> {
287 let payload_json = serde_json::to_vec(payload)?;
288 let payload_b64 = URL_SAFE_NO_PAD.encode(payload_json);
289 let mut mac = HmacSha256::new_from_slice(key).context("invalid HMAC key")?;
290 mac.update(payload_b64.as_bytes());
291 let sig = mac.finalize().into_bytes();
292 Ok(format!("{payload_b64}.{}", URL_SAFE_NO_PAD.encode(sig)))
293}
294
295pub fn validate_oauth_state(
296 token: &str,
297 key: &[u8],
298 expected_provider_id: Option<&str>,
299 expected_tenant: Option<&str>,
300 expected_team: Option<&str>,
301 now_epoch: u64,
302) -> Result<OAuthStatePayload> {
303 let (payload_b64, sig_b64) = token
304 .split_once('.')
305 .ok_or_else(|| anyhow!("invalid OAuth state format"))?;
306 let sig = URL_SAFE_NO_PAD
307 .decode(sig_b64)
308 .context("invalid OAuth state signature encoding")?;
309 let mut mac = HmacSha256::new_from_slice(key).context("invalid HMAC key")?;
310 mac.update(payload_b64.as_bytes());
311 mac.verify_slice(&sig)
312 .map_err(|_| anyhow!("invalid OAuth state signature"))?;
313 let payload_bytes = URL_SAFE_NO_PAD
314 .decode(payload_b64)
315 .context("invalid OAuth state payload encoding")?;
316 let payload: OAuthStatePayload =
317 serde_json::from_slice(&payload_bytes).context("invalid OAuth state payload")?;
318 if payload.expires_at <= now_epoch {
319 bail!("OAuth state has expired");
320 }
321 if let Some(expected) = expected_provider_id
322 && payload.provider_id != expected
323 {
324 bail!("OAuth state provider mismatch");
325 }
326 if let Some(expected) = expected_tenant
327 && payload.tenant != expected
328 {
329 bail!("OAuth state tenant mismatch");
330 }
331 if let Some(expected) = expected_team
332 && payload.team != expected
333 {
334 bail!("OAuth state team mismatch");
335 }
336 Ok(payload)
337}
338
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch.
pub fn current_epoch_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
345
346pub fn map_oauth_token_response(
347 metadata: &OAuthMetadata,
348 response: &Value,
349) -> Result<BTreeMap<String, String>> {
350 let mut mapped = BTreeMap::new();
351 for (secret_key, response_key) in &metadata.response_secret_map {
352 if let Some(value) = response.get(response_key).and_then(value_to_string) {
353 mapped.insert(secret_key.clone(), value);
354 }
355 }
356 if mapped.is_empty()
357 && let Some(token) = response.get("access_token").and_then(value_to_string)
358 {
359 for key in &metadata.secret_keys {
360 mapped.insert(key.clone(), token.clone());
361 }
362 }
363 if mapped.is_empty() {
364 bail!("OAuth token response did not contain mappable secrets");
365 }
366 Ok(mapped)
367}
368
369fn parse_setup_action(
370 provider_id: &str,
371 tenant: &str,
372 team: Option<&str>,
373 raw: &Value,
374) -> Result<SetupAction> {
375 let mut obj = raw
376 .as_object()
377 .cloned()
378 .ok_or_else(|| anyhow!("setup action must be an object"))?;
379 let id = take_string(&mut obj, "id").ok_or_else(|| anyhow!("setup action missing id"))?;
380 let kind = match take_string(&mut obj, "kind")
381 .ok_or_else(|| anyhow!("setup action missing kind"))?
382 .as_str()
383 {
384 "oauth_install_button" => SetupActionKind::OauthInstallButton,
385 "oauth_device_code" => SetupActionKind::OauthDeviceCode,
386 "open_url" => SetupActionKind::OpenUrl,
387 "copy_secret" => SetupActionKind::CopySecret,
388 "manual_step" => SetupActionKind::ManualStep,
389 "download_file" => SetupActionKind::DownloadFile,
390 "admin_consent_button" => SetupActionKind::AdminConsentButton,
391 other => SetupActionKind::Other(other.to_string()),
392 };
393 let label = take_string(&mut obj, "label").unwrap_or_else(|| id.clone());
394 let provider_id =
395 take_string(&mut obj, "provider_id").unwrap_or_else(|| provider_id.to_string());
396 let tenant = take_string(&mut obj, "tenant").unwrap_or_else(|| tenant.to_string());
397 let team = take_string(&mut obj, "team").or_else(|| team.map(ToString::to_string));
398 let status = match take_string(&mut obj, "status").as_deref() {
399 Some("complete") => SetupActionStatus::Complete,
400 Some("failed") => SetupActionStatus::Failed,
401 _ => SetupActionStatus::Pending,
402 };
403 Ok(SetupAction {
404 id,
405 kind,
406 label,
407 provider_id,
408 tenant,
409 team,
410 authorize_url: take_string(&mut obj, "authorize_url"),
411 callback_path: take_string(&mut obj, "callback_path"),
412 state: take_string(&mut obj, "state"),
413 status,
414 created_at: take_string(&mut obj, "created_at"),
415 completed_at: take_string(&mut obj, "completed_at"),
416 extra: obj,
417 })
418}
419
420fn take_string(obj: &mut JsonMap<String, Value>, key: &str) -> Option<String> {
421 obj.remove(key).and_then(|value| match value {
422 Value::String(text) if !text.trim().is_empty() => Some(text),
423 Value::Number(number) => Some(number.to_string()),
424 Value::Bool(value) => Some(value.to_string()),
425 _ => None,
426 })
427}
428
/// Normalizes an optional team name into a path segment.
///
/// The name is trimmed. `None` or a blank name becomes `default`.
fn team_segment(team: Option<&str>) -> &str {
    match team.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed,
        _ => "default",
    }
}
434
435fn now_stamp() -> String {
436 current_epoch_secs().to_string()
437}
438
439fn value_to_string(value: &Value) -> Option<String> {
440 match value {
441 Value::String(text) if !text.is_empty() => Some(text.clone()),
442 Value::Number(number) => Some(number.to_string()),
443 Value::Bool(value) => Some(value.to_string()),
444 _ => None,
445 }
446}
447
448fn authorize_url_contains_state(value: &str) -> bool {
449 url::Url::parse(value)
450 .ok()
451 .and_then(|url| {
452 url.query_pairs()
453 .any(|(key, _)| key == "state")
454 .then_some(())
455 })
456 .is_some()
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Pending OAuth install action scoped to `messaging-example` / `demo` / `default`.
    fn pending_install_action() -> SetupAction {
        SetupAction {
            id: "install".into(),
            kind: SetupActionKind::OauthInstallButton,
            label: "Add".into(),
            provider_id: "messaging-example".into(),
            tenant: "demo".into(),
            team: Some("default".into()),
            authorize_url: Some("https://example.com/oauth?client_id=abc".into()),
            callback_path: None,
            state: None,
            status: SetupActionStatus::Pending,
            created_at: None,
            completed_at: None,
            extra: JsonMap::new(),
        }
    }

    #[test]
    fn extract_setup_actions_fills_scope_defaults() {
        let value = json!({
            "setup_actions": [{
                "id": "install",
                "kind": "oauth_install_button",
                "label": "Add to Example",
                "authorize_url": "https://example.com/auth"
            }]
        });
        let actions =
            extract_setup_actions("messaging-example", "demo", Some("default"), &value).unwrap();
        assert_eq!(actions.len(), 1);
        assert_eq!(actions[0].provider_id, "messaging-example");
        assert_eq!(actions[0].tenant, "demo");
        assert_eq!(actions[0].team.as_deref(), Some("default"));
    }

    #[test]
    fn extract_setup_actions_supports_oauth_device_code() {
        let value = json!({
            "setup_actions": [{
                "id": "connect",
                "kind": "oauth_device_code",
                "label": "Connect"
            }]
        });
        let actions =
            extract_setup_actions("messaging-teams", "demo", Some("default"), &value).unwrap();
        assert_eq!(actions.len(), 1);
        assert_eq!(actions[0].kind, SetupActionKind::OauthDeviceCode);
        assert_eq!(actions[0].provider_id, "messaging-teams");
    }

    #[test]
    fn extract_setup_actions_returns_empty_without_actions_key() {
        let actions =
            extract_setup_actions("messaging-example", "demo", None, &json!({"config": {}}))
                .unwrap();
        assert!(actions.is_empty());
    }

    #[test]
    fn extract_setup_actions_keeps_overrides_unknown_kinds_and_extra_fields() {
        let value = json!({
            "setup_actions": [{
                "id": "docs",
                "kind": "read_docs",
                "status": "complete",
                "tenant": "other",
                "note": "see wiki"
            }]
        });
        let actions = extract_setup_actions("messaging-example", "demo", None, &value).unwrap();
        assert_eq!(actions.len(), 1);
        let action = &actions[0];
        assert_eq!(action.kind, SetupActionKind::Other("read_docs".into()));
        assert_eq!(action.status, SetupActionStatus::Complete);
        assert_eq!(action.label, "docs");
        assert_eq!(action.tenant, "other");
        assert_eq!(action.team, None);
        assert_eq!(action.extra.get("note"), Some(&json!("see wiki")));
    }

    #[test]
    fn extract_setup_actions_rejects_entries_without_id() {
        let value = json!({"setup_actions": [{"kind": "open_url"}]});
        assert!(extract_setup_actions("messaging-example", "demo", None, &value).is_err());
    }

    #[test]
    fn strip_setup_actions_removes_action_keys_only() {
        let value = json!({"setup_actions": [], "pending_setup_actions": [], "keep": 1});
        assert_eq!(strip_setup_actions(&value), json!({"keep": 1}));
    }

    #[test]
    fn team_segment_falls_back_to_default() {
        assert_eq!(team_segment(None), "default");
        assert_eq!(team_segment(Some("  ")), "default");
        assert_eq!(team_segment(Some(" ops ")), "ops");
    }

    #[test]
    fn persist_setup_actions_upserts_by_id() {
        let temp = tempfile::tempdir().unwrap();
        let mut action = SetupAction {
            id: "install".into(),
            kind: SetupActionKind::OauthInstallButton,
            label: "Add".into(),
            provider_id: "messaging-example".into(),
            tenant: "demo".into(),
            team: Some("default".into()),
            authorize_url: Some("https://example.com/one".into()),
            callback_path: None,
            state: None,
            status: SetupActionStatus::Pending,
            created_at: None,
            completed_at: None,
            extra: JsonMap::new(),
        };
        persist_setup_actions(temp.path(), &[action.clone()]).unwrap();
        action.authorize_url = Some("https://example.com/two".into());
        persist_setup_actions(temp.path(), &[action]).unwrap();
        let path = setup_actions_state_path(temp.path(), "demo", "default", "messaging-example");
        let file: SetupActionStateFile =
            serde_json::from_str(&std::fs::read_to_string(path).unwrap()).unwrap();
        assert_eq!(file.actions.len(), 1);
        assert_eq!(
            file.actions[0].authorize_url.as_deref(),
            Some("https://example.com/two")
        );
    }

    #[test]
    fn persist_setup_actions_preserves_created_at_on_upsert() {
        let temp = tempfile::tempdir().unwrap();
        let mut action = pending_install_action();
        action.created_at = Some("1".into());
        persist_setup_actions(temp.path(), &[action.clone()]).unwrap();
        action.created_at = None;
        persist_setup_actions(temp.path(), &[action]).unwrap();
        let stored = load_setup_action(
            temp.path(),
            "demo",
            "default",
            "messaging-example",
            "install",
        )
        .unwrap()
        .unwrap();
        assert_eq!(stored.created_at.as_deref(), Some("1"));
    }

    #[test]
    fn load_setup_action_returns_none_when_nothing_persisted() {
        let temp = tempfile::tempdir().unwrap();
        let loaded = load_setup_action(
            temp.path(),
            "demo",
            "default",
            "messaging-example",
            "install",
        )
        .unwrap();
        assert!(loaded.is_none());
    }

    #[test]
    fn mark_setup_action_complete_sets_status_and_timestamp() {
        let temp = tempfile::tempdir().unwrap();
        persist_setup_actions(temp.path(), &[pending_install_action()]).unwrap();
        mark_setup_action_complete(
            temp.path(),
            "demo",
            "default",
            "messaging-example",
            "install",
        )
        .unwrap();
        let stored = load_setup_action(
            temp.path(),
            "demo",
            "default",
            "messaging-example",
            "install",
        )
        .unwrap()
        .unwrap();
        assert_eq!(stored.status, SetupActionStatus::Complete);
        assert!(stored.completed_at.is_some());
        assert!(
            mark_setup_action_complete(
                temp.path(),
                "demo",
                "default",
                "messaging-example",
                "missing",
            )
            .is_err()
        );
    }

    #[test]
    fn oauth_state_rejects_bad_signature_and_expiry() {
        let key = b"test-key";
        let payload = OAuthStatePayload {
            provider_id: "messaging-example".into(),
            tenant: "demo".into(),
            team: "default".into(),
            action_id: "install".into(),
            nonce: "n".into(),
            expires_at: 100,
        };
        let token = sign_oauth_state(&payload, key).unwrap();
        assert!(validate_oauth_state(&token, key, None, None, None, 99).is_ok());
        assert!(validate_oauth_state(&token, b"other", None, None, None, 99).is_err());
        assert!(validate_oauth_state(&token, key, None, None, None, 100).is_err());
    }

    #[test]
    fn oauth_state_rejects_scope_mismatch_and_malformed_tokens() {
        let key = b"test-key";
        let payload = OAuthStatePayload {
            provider_id: "messaging-example".into(),
            tenant: "demo".into(),
            team: "default".into(),
            action_id: "install".into(),
            nonce: "n".into(),
            expires_at: 100,
        };
        let token = sign_oauth_state(&payload, key).unwrap();
        assert!(
            validate_oauth_state(
                &token,
                key,
                Some("messaging-example"),
                Some("demo"),
                Some("default"),
                0,
            )
            .is_ok()
        );
        assert!(validate_oauth_state(&token, key, Some("other"), None, None, 0).is_err());
        assert!(validate_oauth_state(&token, key, None, Some("other"), None, 0).is_err());
        assert!(validate_oauth_state(&token, key, None, None, Some("other"), 0).is_err());
        assert!(validate_oauth_state("not-a-token", key, None, None, None, 0).is_err());
    }

    #[test]
    fn signing_key_is_created_once_and_reused() {
        let temp = tempfile::tempdir().unwrap();
        let first = load_or_create_signing_key(temp.path()).unwrap();
        let second = load_or_create_signing_key(temp.path()).unwrap();
        assert_eq!(first.len(), 32);
        assert_eq!(first, second);
    }

    #[test]
    fn sign_pending_oauth_actions_adds_state_to_action_and_url() {
        let temp = tempfile::tempdir().unwrap();
        let mut actions = vec![SetupAction {
            id: "install".into(),
            kind: SetupActionKind::OauthInstallButton,
            label: "Add".into(),
            provider_id: "messaging-example".into(),
            tenant: "demo".into(),
            team: Some("default".into()),
            authorize_url: Some("https://example.com/oauth?client_id=abc".into()),
            callback_path: Some("/oauth/callback/example".into()),
            state: None,
            status: SetupActionStatus::Pending,
            created_at: None,
            completed_at: None,
            extra: JsonMap::new(),
        }];
        sign_pending_oauth_actions(temp.path(), &mut actions).unwrap();
        let state = actions[0].state.as_deref().unwrap();
        assert!(
            actions[0]
                .authorize_url
                .as_deref()
                .unwrap()
                .contains("state=")
        );
        let key = load_or_create_signing_key(temp.path()).unwrap();
        let payload =
            validate_oauth_state(state, &key, Some("messaging-example"), None, None, 0).unwrap();
        assert_eq!(payload.action_id, "install");
    }

    #[test]
    fn token_response_maps_access_token_to_secret_keys() {
        let metadata = OAuthMetadata {
            token_url: "https://example.com/token".into(),
            secret_keys: vec!["EXAMPLE_TOKEN".into()],
            ..Default::default()
        };
        let mapped = map_oauth_token_response(&metadata, &json!({"access_token": "xoxb"})).unwrap();
        assert_eq!(
            mapped.get("EXAMPLE_TOKEN").map(String::as_str),
            Some("xoxb")
        );
    }

    #[test]
    fn token_response_prefers_response_secret_map() {
        let metadata = OAuthMetadata {
            token_url: "https://example.com/token".into(),
            secret_keys: vec!["FALLBACK_TOKEN".into()],
            response_secret_map: BTreeMap::from([
                ("BOT_TOKEN".to_string(), "access_token".to_string()),
                ("REFRESH_TOKEN".to_string(), "refresh_token".to_string()),
            ]),
            ..Default::default()
        };
        let mapped = map_oauth_token_response(
            &metadata,
            &json!({"access_token": "xoxb", "refresh_token": "r1"}),
        )
        .unwrap();
        assert_eq!(mapped.len(), 2);
        assert_eq!(mapped.get("BOT_TOKEN").map(String::as_str), Some("xoxb"));
        assert_eq!(mapped.get("REFRESH_TOKEN").map(String::as_str), Some("r1"));
        assert!(map_oauth_token_response(&metadata, &json!({"error": "invalid_grant"})).is_err());
    }
}