Skip to main content

ai_agent/bridge/
trusted_device.rs

//! Trusted device token source for bridge (remote-control) sessions.
//!
//! Translated from openclaudecode/src/bridge/trustedDevice.ts
//!
//! Bridge sessions have SecurityTier=ELEVATED on the server (CCR v2).
//! The server gates ConnectBridgeWorker on its own flag; this CLI-side
//! flag controls whether the CLI sends X-Trusted-Device-Token at all.
8
9use crate::constants::env::ai;
10use crate::utils::http::get_user_agent;
11use std::sync::{Arc, RwLock};
12
13use reqwest;
14
15// =============================================================================
16// CONSTANTS
17// =============================================================================
18
/// Gate name handed to the registered gate-check callback to decide whether
/// trusted-device handling is active on this CLI.
const TRUSTED_DEVICE_GATE: &str = "tengu_sessions_elevated_auth_enforcement";

/// Request timeout, in milliseconds, applied to the enrollment POST.
const ENROLLMENT_TIMEOUT_MS: u64 = 10_000;
21
22// =============================================================================
23// STORAGE TRAIT (for dependency injection)
24// =============================================================================
25
26/// Trait for secure storage operations.
27pub trait SecureStorage: Send + Sync {
28    fn read(&self) -> Option<StorageData>;
29    fn update(&self, data: &StorageData) -> Result<(), String>;
30}
31
/// Record exchanged with a [`SecureStorage`] implementation.
///
/// Both fields start out as `None` via `Default`.
#[derive(Default, Clone)]
pub struct StorageData {
    /// Device token issued by enrollment; cleared again on logout/login.
    pub trusted_device_token: Option<String>,
    /// Device identifier returned alongside the token, when the server sends one.
    pub device_id: Option<String>,
    // Further fields can be added here as the storage schema grows.
}
38
39// =============================================================================
40// STATE
41// =============================================================================
42
/// Callback that answers whether the named gate is enabled.
pub type GateFn = Box<dyn Fn(&str) -> bool + Send + Sync>;

/// Callback that yields the current auth token, or `None` when there is none.
pub type AuthTokenGetterFn = Box<dyn Fn() -> Option<String> + Send + Sync>;

/// Callback that yields the API base URL used to build request URLs.
pub type BaseUrlGetterFn = Box<dyn Fn() -> String + Send + Sync>;
51
52static GATE_GETTER: std::sync::OnceLock<GateFn> = std::sync::OnceLock::new();
53static AUTH_TOKEN_GETTER: std::sync::OnceLock<AuthTokenGetterFn> = std::sync::OnceLock::new();
54static BASE_URL_GETTER: std::sync::OnceLock<BaseUrlGetterFn> = std::sync::OnceLock::new();
55static STORAGE: std::sync::OnceLock<Arc<dyn SecureStorage>> = std::sync::OnceLock::new();
56
57// Cached token storage
58static CACHED_TOKEN: std::sync::OnceLock<RwLock<Option<String>>> = std::sync::OnceLock::new();
59
60// =============================================================================
61// INITIALIZATION
62// =============================================================================
63
64/// Register the gate check function.
65pub fn register_gate_check(gate: impl Fn(&str) -> bool + Send + Sync + 'static) {
66    let _ = GATE_GETTER.set(Box::new(gate));
67}
68
69/// Register the auth token getter function.
70pub fn register_auth_token_getter(getter: impl Fn() -> Option<String> + Send + Sync + 'static) {
71    let _ = AUTH_TOKEN_GETTER.set(Box::new(getter));
72}
73
74/// Register the base URL getter function.
75pub fn register_base_url_getter(getter: impl Fn() -> String + Send + Sync + 'static) {
76    let _ = BASE_URL_GETTER.set(Box::new(getter));
77}
78
79/// Register the secure storage implementation.
80pub fn register_secure_storage(storage: Arc<dyn SecureStorage>) {
81    let _ = STORAGE.set(storage);
82}
83
84// =============================================================================
85// GATE CHECK
86// =============================================================================
87
88fn is_gate_enabled() -> bool {
89    GATE_GETTER
90        .get()
91        .map(|gate| gate(TRUSTED_DEVICE_GATE))
92        // Default to false if not set
93        .unwrap_or(false)
94}
95
96// =============================================================================
97// TOKEN READ/WRITE
98// =============================================================================
99
100/// Get the stored trusted device token.
101/// Uses env var override for testing/canary, falls back to secure storage.
102/// Memoized for performance.
103pub fn get_trusted_device_token() -> Option<String> {
104    // Check env var first
105    if let Ok(env_token) = std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN) {
106        if !env_token.is_empty() {
107            return Some(env_token);
108        }
109    }
110
111    if !is_gate_enabled() {
112        return None;
113    }
114
115    // Use cached token if available
116    if let Some(cached) = CACHED_TOKEN.get() {
117        if let Ok(token) = cached.read() {
118            return token.clone();
119        }
120    }
121
122    // Read from storage
123    let token = STORAGE
124        .get()
125        .and_then(|s| s.read())
126        .and_then(|data| data.trusted_device_token);
127
128    // Cache it
129    if let Some(ref t) = token {
130        if let Some(cache) = CACHED_TOKEN.get() {
131            if let Ok(mut guard) = cache.write() {
132                *guard = Some(t.clone());
133            }
134        }
135    }
136
137    token
138}
139
140/// Clear the cached trusted device token.
141pub fn clear_trusted_device_token_cache() {
142    if let Some(cache) = CACHED_TOKEN.get() {
143        if let Ok(mut guard) = cache.write() {
144            *guard = None;
145        }
146    }
147}
148
149/// Clear the stored trusted device token from secure storage and the cache.
150/// Called before enrollTrustedDevice during /login so a stale token from the
151/// previous account isn't sent while enrollment is in-flight.
152pub fn clear_trusted_device_token() {
153    if !is_gate_enabled() {
154        return;
155    }
156
157    if let Some(storage) = STORAGE.get() {
158        if let Some(mut data) = storage.read() {
159            data.trusted_device_token = None;
160            let _ = storage.update(&data);
161        }
162    }
163
164    clear_trusted_device_token_cache();
165}
166
167/// Enroll this device via POST /auth/trusted_devices and persist the token
168/// to storage. Best-effort — returns on failure so callers don't block.
169pub async fn enroll_trusted_device() {
170    // Check gate
171    if !is_gate_enabled() {
172        log_debug("[trusted-device] Gate is off, skipping enrollment");
173        return;
174    }
175
176    // Check env var override
177    if std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN).is_ok() {
178        log_debug(
179            "[trusted-device] CLAUDE_TRUSTED_DEVICE_TOKEN env var is set, skipping enrollment",
180        );
181        return;
182    }
183
184    // Get access token
185    let access_token = match AUTH_TOKEN_GETTER.get() {
186        Some(getter) => getter(),
187        None => {
188            log_debug("[trusted-device] No auth token getter registered, skipping enrollment");
189            return;
190        }
191    };
192
193    let access_token = match access_token {
194        Some(t) => t,
195        None => {
196            log_debug("[trusted-device] No OAuth token, skipping enrollment");
197            return;
198        }
199    };
200
201    // Get base URL
202    let base_url = match BASE_URL_GETTER.get() {
203        Some(getter) => getter(),
204        None => {
205            log_debug("[trusted-device] No base URL getter registered, skipping enrollment");
206            return;
207        }
208    };
209
210    let client = reqwest::Client::new();
211
212    let hostname = hostname::get()
213        .map(|h| h.to_string_lossy().into_owned())
214        .unwrap_or_else(|_| "unknown".to_string());
215
216    let platform = std::env::consts::OS;
217    let display_name = format!("Claude Code on {} · {}", hostname, platform);
218
219    match client
220        .post(&format!("{}/api/auth/trusted_devices", base_url))
221        .header("Authorization", format!("Bearer {}", access_token))
222        .header("Content-Type", "application/json")
223        .header("User-Agent", get_user_agent())
224        .timeout(std::time::Duration::from_millis(ENROLLMENT_TIMEOUT_MS))
225        .json(&serde_json::json!({ "display_name": display_name }))
226        .send()
227        .await
228    {
229        Ok(response) => {
230            if response.status() != 200 && response.status() != 201 {
231                log_debug(&format!(
232                    "[trusted-device] Enrollment failed {}",
233                    response.status()
234                ));
235                return;
236            }
237
238            // Parse response
239            match response.json::<serde_json::Value>().await {
240                Ok(data) => {
241                    let token = data.get("device_token").and_then(|v| v.as_str());
242                    let device_id = data.get("device_id").and_then(|v| v.as_str());
243
244                    match token {
245                        Some(token) => {
246                            // Persist to storage
247                            if let Some(storage) = STORAGE.get() {
248                                if let Some(mut data) = storage.read() {
249                                    data.trusted_device_token = Some(token.to_string());
250                                    if let Some(id) = device_id {
251                                        data.device_id = Some(id.to_string());
252                                    }
253                                    match storage.update(&data) {
254                                        Ok(_) => {
255                                            clear_trusted_device_token_cache();
256                                            log_debug(&format!(
257                                                "[trusted-device] Enrolled device_id={}",
258                                                device_id.unwrap_or("unknown")
259                                            ));
260                                        }
261                                        Err(e) => {
262                                            log_debug(&format!(
263                                                "[trusted-device] Storage write failed: {}",
264                                                e
265                                            ));
266                                        }
267                                    }
268                                }
269                            }
270                        }
271                        None => {
272                            log_debug(
273                                "[trusted-device] Enrollment response missing device_token field",
274                            );
275                        }
276                    }
277                }
278                Err(e) => {
279                    log_debug(&format!("[trusted-device] Failed to parse response: {}", e));
280                }
281            }
282        }
283        Err(e) => {
284            log_debug(&format!(
285                "[trusted-device] Enrollment request failed: {}",
286                e
287            ));
288        }
289    }
290}
291
292// =============================================================================
293// DEBUG LOGGING
294// =============================================================================
295
/// Writes a debug message to stderr, followed by a newline.
///
/// Minimal stand-in that can later be swapped for a real logging facility.
fn log_debug(msg: &str) {
    eprintln!("{msg}");
}
300
301// =============================================================================
302// TESTS
303// =============================================================================
304
#[cfg(test)]
mod tests {
    use super::*;

    /// With no gate check registered the gate is off, so the only token that
    /// can come back is a non-empty env var override. Comparing against the
    /// ambient environment keeps the test from failing on machines where the
    /// override happens to be set.
    #[test]
    fn test_token_follows_env_override_when_gate_disabled() {
        let expected = std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN)
            .ok()
            .filter(|token| !token.is_empty());
        assert_eq!(get_trusted_device_token(), expected);
    }

    /// Clearing the cache must be safe at any time, including before the
    /// cache has ever been populated.
    #[test]
    fn test_clear_token_cache() {
        clear_trusted_device_token_cache();
    }
}