ai_agent/bridge/
trusted_device.rs1use crate::constants::env::ai;
10use std::sync::{Arc, RwLock};
11
12use reqwest;
13
/// Feature-gate name handed to the registered gate check by `is_gate_enabled`.
const TRUSTED_DEVICE_GATE: &str = "tengu_sessions_elevated_auth_enforcement";
/// Per-request timeout for the enrollment POST, in milliseconds (ten seconds).
const ENROLLMENT_TIMEOUT_MS: u64 = 10 * 1_000;
20
/// Backend used to persist trusted-device data (token and device id).
///
/// The `Send + Sync` supertraits are required because one instance is shared process-wide
/// as `Arc<dyn SecureStorage>` in a `static` (see `register_secure_storage`).
pub trait SecureStorage: Send + Sync {
    /// Returns the currently stored record.
    ///
    /// NOTE(review): whether `None` means "nothing stored" or "read failed" is up to the
    /// implementation — confirm. Callers in this module skip writing when it is `None`.
    fn read(&self) -> Option<StorageData>;
    /// Writes `data` back to storage. The `Err` string is used as a log message on the
    /// enrollment path.
    fn update(&self, data: &StorageData) -> Result<(), String>;
}
30
/// Record read from and written to [`SecureStorage`] for trusted-device enrollment.
///
/// Both fields default to `None`.
#[derive(Default, Clone)]
pub struct StorageData {
    /// Token taken from the enrollment response's `device_token` field, if enrolled.
    pub trusted_device_token: Option<String>,
    /// Identifier taken from the enrollment response's `device_id` field, if any.
    pub device_id: Option<String>,
}
37
/// Feature-gate predicate: receives a gate name and reports whether it is enabled.
pub type GateFn = Box<dyn Fn(&str) -> bool + Sync + Send>;

/// Supplies the current OAuth access token; `None` means no token is available.
pub type AuthTokenGetterFn = Box<dyn Fn() -> Option<String> + Sync + Send>;

/// Supplies the base URL that `/api/auth/trusted_devices` is appended to during enrollment.
pub type BaseUrlGetterFn = Box<dyn Fn() -> String + Sync + Send>;
50
// Process-wide registrations, each filled by the matching `register_*` function.
// `OnceLock::set` succeeds only once, and the registration functions ignore its error,
// so the first registration wins.
static GATE_GETTER: std::sync::OnceLock<GateFn> = std::sync::OnceLock::new();
static AUTH_TOKEN_GETTER: std::sync::OnceLock<AuthTokenGetterFn> = std::sync::OnceLock::new();
static BASE_URL_GETTER: std::sync::OnceLock<BaseUrlGetterFn> = std::sync::OnceLock::new();
static STORAGE: std::sync::OnceLock<Arc<dyn SecureStorage>> = std::sync::OnceLock::new();

// In-process cache of the trusted-device token read from `STORAGE`.
// NOTE(review): the `OnceLock` starts empty. Code that only calls `CACHED_TOKEN.get()`
// sees no cache until something initializes the cell (e.g. via `get_or_init`).
static CACHED_TOKEN: std::sync::OnceLock<RwLock<Option<String>>> = std::sync::OnceLock::new();
58
59pub fn register_gate_check(gate: impl Fn(&str) -> bool + Send + Sync + 'static) {
65 let _ = GATE_GETTER.set(Box::new(gate));
66}
67
68pub fn register_auth_token_getter(getter: impl Fn() -> Option<String> + Send + Sync + 'static) {
70 let _ = AUTH_TOKEN_GETTER.set(Box::new(getter));
71}
72
73pub fn register_base_url_getter(getter: impl Fn() -> String + Send + Sync + 'static) {
75 let _ = BASE_URL_GETTER.set(Box::new(getter));
76}
77
78pub fn register_secure_storage(storage: Arc<dyn SecureStorage>) {
80 let _ = STORAGE.set(storage);
81}
82
83fn is_gate_enabled() -> bool {
88 GATE_GETTER
89 .get()
90 .map(|gate| gate(TRUSTED_DEVICE_GATE))
91 .unwrap_or(false)
93}
94
95pub fn get_trusted_device_token() -> Option<String> {
103 if let Ok(env_token) = std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN) {
105 if !env_token.is_empty() {
106 return Some(env_token);
107 }
108 }
109
110 if !is_gate_enabled() {
111 return None;
112 }
113
114 if let Some(cached) = CACHED_TOKEN.get() {
116 if let Ok(token) = cached.read() {
117 return token.clone();
118 }
119 }
120
121 let token = STORAGE
123 .get()
124 .and_then(|s| s.read())
125 .and_then(|data| data.trusted_device_token);
126
127 if let Some(ref t) = token {
129 if let Some(cache) = CACHED_TOKEN.get() {
130 if let Ok(mut guard) = cache.write() {
131 *guard = Some(t.clone());
132 }
133 }
134 }
135
136 token
137}
138
139pub fn clear_trusted_device_token_cache() {
141 if let Some(cache) = CACHED_TOKEN.get() {
142 if let Ok(mut guard) = cache.write() {
143 *guard = None;
144 }
145 }
146}
147
148pub fn clear_trusted_device_token() {
152 if !is_gate_enabled() {
153 return;
154 }
155
156 if let Some(storage) = STORAGE.get() {
157 if let Some(mut data) = storage.read() {
158 data.trusted_device_token = None;
159 let _ = storage.update(&data);
160 }
161 }
162
163 clear_trusted_device_token_cache();
164}
165
166pub async fn enroll_trusted_device() {
169 if !is_gate_enabled() {
171 log_debug("[trusted-device] Gate is off, skipping enrollment");
172 return;
173 }
174
175 if std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN).is_ok() {
177 log_debug(
178 "[trusted-device] CLAUDE_TRUSTED_DEVICE_TOKEN env var is set, skipping enrollment",
179 );
180 return;
181 }
182
183 let access_token = match AUTH_TOKEN_GETTER.get() {
185 Some(getter) => getter(),
186 None => {
187 log_debug("[trusted-device] No auth token getter registered, skipping enrollment");
188 return;
189 }
190 };
191
192 let access_token = match access_token {
193 Some(t) => t,
194 None => {
195 log_debug("[trusted-device] No OAuth token, skipping enrollment");
196 return;
197 }
198 };
199
200 let base_url = match BASE_URL_GETTER.get() {
202 Some(getter) => getter(),
203 None => {
204 log_debug("[trusted-device] No base URL getter registered, skipping enrollment");
205 return;
206 }
207 };
208
209 let client = reqwest::Client::new();
210
211 let hostname = hostname::get()
212 .map(|h| h.to_string_lossy().into_owned())
213 .unwrap_or_else(|_| "unknown".to_string());
214
215 let platform = std::env::consts::OS;
216 let display_name = format!("Claude Code on {} · {}", hostname, platform);
217
218 match client
219 .post(&format!("{}/api/auth/trusted_devices", base_url))
220 .header("Authorization", format!("Bearer {}", access_token))
221 .header("Content-Type", "application/json")
222 .timeout(std::time::Duration::from_millis(ENROLLMENT_TIMEOUT_MS))
223 .json(&serde_json::json!({ "display_name": display_name }))
224 .send()
225 .await
226 {
227 Ok(response) => {
228 if response.status() != 200 && response.status() != 201 {
229 log_debug(&format!(
230 "[trusted-device] Enrollment failed {}",
231 response.status()
232 ));
233 return;
234 }
235
236 match response.json::<serde_json::Value>().await {
238 Ok(data) => {
239 let token = data.get("device_token").and_then(|v| v.as_str());
240 let device_id = data.get("device_id").and_then(|v| v.as_str());
241
242 match token {
243 Some(token) => {
244 if let Some(storage) = STORAGE.get() {
246 if let Some(mut data) = storage.read() {
247 data.trusted_device_token = Some(token.to_string());
248 if let Some(id) = device_id {
249 data.device_id = Some(id.to_string());
250 }
251 match storage.update(&data) {
252 Ok(_) => {
253 clear_trusted_device_token_cache();
254 log_debug(&format!(
255 "[trusted-device] Enrolled device_id={}",
256 device_id.unwrap_or("unknown")
257 ));
258 }
259 Err(e) => {
260 log_debug(&format!(
261 "[trusted-device] Storage write failed: {}",
262 e
263 ));
264 }
265 }
266 }
267 }
268 }
269 None => {
270 log_debug(
271 "[trusted-device] Enrollment response missing device_token field",
272 );
273 }
274 }
275 }
276 Err(e) => {
277 log_debug(&format!("[trusted-device] Failed to parse response: {}", e));
278 }
279 }
280 }
281 Err(e) => {
282 log_debug(&format!(
283 "[trusted-device] Enrollment request failed: {}",
284 e
285 ));
286 }
287 }
288}
289
/// Writes a diagnostic line to stderr.
///
/// Note that this prints unconditionally; there is no debug-level filtering here.
fn log_debug(msg: &str) {
    eprintln!("{msg}");
}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `CLAUDE_TRUSTED_DEVICE_TOKEN` holds a non-empty value.
    /// `get_trusted_device_token` returns such a value regardless of the gate.
    fn env_token_set() -> bool {
        std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN).is_ok_and(|t| !t.is_empty())
    }

    #[test]
    fn test_token_returns_none_when_gate_disabled() {
        // The env override is honored even with the gate off. Asserting `None` while it is
        // set would fail for reasons unrelated to the gate, so skip in that environment.
        if env_token_set() {
            return;
        }
        assert_eq!(get_trusted_device_token(), None);
    }

    #[test]
    fn test_clear_token_cache() {
        clear_trusted_device_token_cache();
        // After clearing, either the cache cell was never created or it holds no token.
        let cached = CACHED_TOKEN
            .get()
            .and_then(|cache| cache.read().ok().and_then(|guard| guard.clone()));
        assert_eq!(cached, None);
    }
}