Skip to main content

ai_agent/bridge/
trusted_device.rs

//! Trusted device token source for bridge (remote-control) sessions.
//!
//! Translated from openclaudecode/src/bridge/trustedDevice.ts
//!
//! Bridge sessions have SecurityTier=ELEVATED on the server (CCR v2).
//! The server gates ConnectBridgeWorker on its own flag; this CLI-side
//! flag controls whether the CLI sends X-Trusted-Device-Token at all.
8
9use crate::constants::env::ai;
10use std::sync::{Arc, RwLock};
11
12use reqwest;
13
// =============================================================================
// CONSTANTS
// =============================================================================

/// Name of the feature gate handed to the registered gate check; while it is
/// off, token lookup, clearing, and enrollment are all skipped.
const TRUSTED_DEVICE_GATE: &str = "tengu_sessions_elevated_auth_enforcement";

/// Per-request timeout, in milliseconds, applied to the enrollment HTTP call.
const ENROLLMENT_TIMEOUT_MS: u64 = 10 * 1_000;
20
// =============================================================================
// STORAGE TRAIT (for dependency injection)
// =============================================================================

/// Pluggable secure store that holds the trusted-device record.
///
/// An implementation is supplied at runtime through `register_secure_storage`,
/// which keeps this module independent of any concrete keychain backend.
pub trait SecureStorage: Send + Sync {
    /// Fetch the current record; `None` means no record could be obtained.
    fn read(&self) -> Option<StorageData>;

    /// Replace the stored record with `data`. On failure the `Err` string is
    /// a message that callers in this module write to the debug log.
    fn update(&self, data: &StorageData) -> Result<(), String>;
}

/// The secure-storage record this module reads and writes back as a whole.
#[derive(Default, Clone)]
pub struct StorageData {
    /// The `device_token` value returned by a successful enrollment.
    pub trusted_device_token: Option<String>,
    /// The `device_id` value returned by a successful enrollment, if any.
    pub device_id: Option<String>,
    // Further fields can be added here as the record grows.
}
37
// =============================================================================
// STATE
// =============================================================================

/// Callback answering "is the feature gate with this name enabled?".
pub type GateFn = Box<dyn Fn(&str) -> bool + Send + Sync + 'static>;

/// Callback producing the current OAuth access token, or `None` if absent.
pub type AuthTokenGetterFn = Box<dyn Fn() -> Option<String> + Send + Sync + 'static>;

/// Callback producing the API base URL that enrollment requests target.
pub type BaseUrlGetterFn = Box<dyn Fn() -> String + Send + Sync + 'static>;
50
51static GATE_GETTER: std::sync::OnceLock<GateFn> = std::sync::OnceLock::new();
52static AUTH_TOKEN_GETTER: std::sync::OnceLock<AuthTokenGetterFn> = std::sync::OnceLock::new();
53static BASE_URL_GETTER: std::sync::OnceLock<BaseUrlGetterFn> = std::sync::OnceLock::new();
54static STORAGE: std::sync::OnceLock<Arc<dyn SecureStorage>> = std::sync::OnceLock::new();
55
56// Cached token storage
57static CACHED_TOKEN: std::sync::OnceLock<RwLock<Option<String>>> = std::sync::OnceLock::new();
58
59// =============================================================================
60// INITIALIZATION
61// =============================================================================
62
63/// Register the gate check function.
64pub fn register_gate_check(gate: impl Fn(&str) -> bool + Send + Sync + 'static) {
65    let _ = GATE_GETTER.set(Box::new(gate));
66}
67
68/// Register the auth token getter function.
69pub fn register_auth_token_getter(getter: impl Fn() -> Option<String> + Send + Sync + 'static) {
70    let _ = AUTH_TOKEN_GETTER.set(Box::new(getter));
71}
72
73/// Register the base URL getter function.
74pub fn register_base_url_getter(getter: impl Fn() -> String + Send + Sync + 'static) {
75    let _ = BASE_URL_GETTER.set(Box::new(getter));
76}
77
78/// Register the secure storage implementation.
79pub fn register_secure_storage(storage: Arc<dyn SecureStorage>) {
80    let _ = STORAGE.set(storage);
81}
82
83// =============================================================================
84// GATE CHECK
85// =============================================================================
86
87fn is_gate_enabled() -> bool {
88    GATE_GETTER
89        .get()
90        .map(|gate| gate(TRUSTED_DEVICE_GATE))
91        // Default to false if not set
92        .unwrap_or(false)
93}
94
95// =============================================================================
96// TOKEN READ/WRITE
97// =============================================================================
98
99/// Get the stored trusted device token.
100/// Uses env var override for testing/canary, falls back to secure storage.
101/// Memoized for performance.
102pub fn get_trusted_device_token() -> Option<String> {
103    // Check env var first
104    if let Ok(env_token) = std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN) {
105        if !env_token.is_empty() {
106            return Some(env_token);
107        }
108    }
109
110    if !is_gate_enabled() {
111        return None;
112    }
113
114    // Use cached token if available
115    if let Some(cached) = CACHED_TOKEN.get() {
116        if let Ok(token) = cached.read() {
117            return token.clone();
118        }
119    }
120
121    // Read from storage
122    let token = STORAGE
123        .get()
124        .and_then(|s| s.read())
125        .and_then(|data| data.trusted_device_token);
126
127    // Cache it
128    if let Some(ref t) = token {
129        if let Some(cache) = CACHED_TOKEN.get() {
130            if let Ok(mut guard) = cache.write() {
131                *guard = Some(t.clone());
132            }
133        }
134    }
135
136    token
137}
138
139/// Clear the cached trusted device token.
140pub fn clear_trusted_device_token_cache() {
141    if let Some(cache) = CACHED_TOKEN.get() {
142        if let Ok(mut guard) = cache.write() {
143            *guard = None;
144        }
145    }
146}
147
148/// Clear the stored trusted device token from secure storage and the cache.
149/// Called before enrollTrustedDevice during /login so a stale token from the
150/// previous account isn't sent while enrollment is in-flight.
151pub fn clear_trusted_device_token() {
152    if !is_gate_enabled() {
153        return;
154    }
155
156    if let Some(storage) = STORAGE.get() {
157        if let Some(mut data) = storage.read() {
158            data.trusted_device_token = None;
159            let _ = storage.update(&data);
160        }
161    }
162
163    clear_trusted_device_token_cache();
164}
165
166/// Enroll this device via POST /auth/trusted_devices and persist the token
167/// to storage. Best-effort — returns on failure so callers don't block.
168pub async fn enroll_trusted_device() {
169    // Check gate
170    if !is_gate_enabled() {
171        log_debug("[trusted-device] Gate is off, skipping enrollment");
172        return;
173    }
174
175    // Check env var override
176    if std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN).is_ok() {
177        log_debug(
178            "[trusted-device] CLAUDE_TRUSTED_DEVICE_TOKEN env var is set, skipping enrollment",
179        );
180        return;
181    }
182
183    // Get access token
184    let access_token = match AUTH_TOKEN_GETTER.get() {
185        Some(getter) => getter(),
186        None => {
187            log_debug("[trusted-device] No auth token getter registered, skipping enrollment");
188            return;
189        }
190    };
191
192    let access_token = match access_token {
193        Some(t) => t,
194        None => {
195            log_debug("[trusted-device] No OAuth token, skipping enrollment");
196            return;
197        }
198    };
199
200    // Get base URL
201    let base_url = match BASE_URL_GETTER.get() {
202        Some(getter) => getter(),
203        None => {
204            log_debug("[trusted-device] No base URL getter registered, skipping enrollment");
205            return;
206        }
207    };
208
209    let client = reqwest::Client::new();
210
211    let hostname = hostname::get()
212        .map(|h| h.to_string_lossy().into_owned())
213        .unwrap_or_else(|_| "unknown".to_string());
214
215    let platform = std::env::consts::OS;
216    let display_name = format!("Claude Code on {} · {}", hostname, platform);
217
218    match client
219        .post(&format!("{}/api/auth/trusted_devices", base_url))
220        .header("Authorization", format!("Bearer {}", access_token))
221        .header("Content-Type", "application/json")
222        .timeout(std::time::Duration::from_millis(ENROLLMENT_TIMEOUT_MS))
223        .json(&serde_json::json!({ "display_name": display_name }))
224        .send()
225        .await
226    {
227        Ok(response) => {
228            if response.status() != 200 && response.status() != 201 {
229                log_debug(&format!(
230                    "[trusted-device] Enrollment failed {}",
231                    response.status()
232                ));
233                return;
234            }
235
236            // Parse response
237            match response.json::<serde_json::Value>().await {
238                Ok(data) => {
239                    let token = data.get("device_token").and_then(|v| v.as_str());
240                    let device_id = data.get("device_id").and_then(|v| v.as_str());
241
242                    match token {
243                        Some(token) => {
244                            // Persist to storage
245                            if let Some(storage) = STORAGE.get() {
246                                if let Some(mut data) = storage.read() {
247                                    data.trusted_device_token = Some(token.to_string());
248                                    if let Some(id) = device_id {
249                                        data.device_id = Some(id.to_string());
250                                    }
251                                    match storage.update(&data) {
252                                        Ok(_) => {
253                                            clear_trusted_device_token_cache();
254                                            log_debug(&format!(
255                                                "[trusted-device] Enrolled device_id={}",
256                                                device_id.unwrap_or("unknown")
257                                            ));
258                                        }
259                                        Err(e) => {
260                                            log_debug(&format!(
261                                                "[trusted-device] Storage write failed: {}",
262                                                e
263                                            ));
264                                        }
265                                    }
266                                }
267                            }
268                        }
269                        None => {
270                            log_debug(
271                                "[trusted-device] Enrollment response missing device_token field",
272                            );
273                        }
274                    }
275                }
276                Err(e) => {
277                    log_debug(&format!("[trusted-device] Failed to parse response: {}", e));
278                }
279            }
280        }
281        Err(e) => {
282            log_debug(&format!(
283                "[trusted-device] Enrollment request failed: {}",
284                e
285            ));
286        }
287    }
288}
289
// =============================================================================
// DEBUG LOGGING
// =============================================================================

/// Emit one diagnostic line on stderr.
///
/// Output is unconditional: it is not gated on any debug flag. A proper
/// logging facade could take over this role later.
fn log_debug(msg: &str) {
    eprintln!("{msg}");
}
298
// =============================================================================
// TESTS
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Nothing in this module registers a gate check, so the gate reads as
    /// off and no token is produced.
    #[test]
    fn test_token_returns_none_when_gate_disabled() {
        let token = get_trusted_device_token();
        assert!(token.is_none());
    }

    /// Clearing must be safe even when the cache was never populated.
    #[test]
    fn test_clear_token_cache() {
        clear_trusted_device_token_cache();
    }
}