ai_agent/bridge/
trusted_device.rs1use crate::constants::env::ai;
10use crate::utils::http::get_user_agent;
11use std::sync::{Arc, RwLock};
12
13use reqwest;
14
/// Feature-gate name handed to the registered gate predicate by `is_gate_enabled`.
const TRUSTED_DEVICE_GATE: &str = "tengu_sessions_elevated_auth_enforcement";
/// Per-request timeout, in milliseconds, applied to the enrollment HTTP call.
const ENROLLMENT_TIMEOUT_MS: u64 = 10 * 1_000;
21
22pub trait SecureStorage: Send + Sync {
28 fn read(&self) -> Option<StorageData>;
29 fn update(&self, data: &StorageData) -> Result<(), String>;
30}
31
/// Record kept in secure storage for trusted-device authentication.
///
/// `Debug` is implemented by hand so that formatting a `StorageData` (for example
/// in a diagnostic line) never prints the device token itself.
#[derive(Clone, Default, PartialEq, Eq)]
pub struct StorageData {
    /// Token obtained through trusted-device enrollment; `None` when absent.
    pub trusted_device_token: Option<String>,
    /// Identifier returned alongside the token at enrollment, if any.
    pub device_id: Option<String>,
}

impl std::fmt::Debug for StorageData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a token is present, never its value.
        let redacted_token = self.trusted_device_token.as_ref().map(|_| "<redacted>");
        f.debug_struct("StorageData")
            .field("trusted_device_token", &redacted_token)
            .field("device_id", &self.device_id)
            .finish()
    }
}
38
/// Predicate answering "is the feature gate with this name enabled?".
pub type GateFn = Box<dyn Fn(&str) -> bool + Send + Sync>;

/// Supplier of the current OAuth access token; yields `None` when no token is available.
pub type AuthTokenGetterFn = Box<dyn Fn() -> Option<String> + Send + Sync>;

/// Supplier of the base URL that enrollment requests are sent to.
pub type BaseUrlGetterFn = Box<dyn Fn() -> String + Send + Sync>;
51
52static GATE_GETTER: std::sync::OnceLock<GateFn> = std::sync::OnceLock::new();
53static AUTH_TOKEN_GETTER: std::sync::OnceLock<AuthTokenGetterFn> = std::sync::OnceLock::new();
54static BASE_URL_GETTER: std::sync::OnceLock<BaseUrlGetterFn> = std::sync::OnceLock::new();
55static STORAGE: std::sync::OnceLock<Arc<dyn SecureStorage>> = std::sync::OnceLock::new();
56
57static CACHED_TOKEN: std::sync::OnceLock<RwLock<Option<String>>> = std::sync::OnceLock::new();
59
60pub fn register_gate_check(gate: impl Fn(&str) -> bool + Send + Sync + 'static) {
66 let _ = GATE_GETTER.set(Box::new(gate));
67}
68
69pub fn register_auth_token_getter(getter: impl Fn() -> Option<String> + Send + Sync + 'static) {
71 let _ = AUTH_TOKEN_GETTER.set(Box::new(getter));
72}
73
74pub fn register_base_url_getter(getter: impl Fn() -> String + Send + Sync + 'static) {
76 let _ = BASE_URL_GETTER.set(Box::new(getter));
77}
78
79pub fn register_secure_storage(storage: Arc<dyn SecureStorage>) {
81 let _ = STORAGE.set(storage);
82}
83
84fn is_gate_enabled() -> bool {
89 GATE_GETTER
90 .get()
91 .map(|gate| gate(TRUSTED_DEVICE_GATE))
92 .unwrap_or(false)
94}
95
96pub fn get_trusted_device_token() -> Option<String> {
104 if let Ok(env_token) = std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN) {
106 if !env_token.is_empty() {
107 return Some(env_token);
108 }
109 }
110
111 if !is_gate_enabled() {
112 return None;
113 }
114
115 if let Some(cached) = CACHED_TOKEN.get() {
117 if let Ok(token) = cached.read() {
118 return token.clone();
119 }
120 }
121
122 let token = STORAGE
124 .get()
125 .and_then(|s| s.read())
126 .and_then(|data| data.trusted_device_token);
127
128 if let Some(ref t) = token {
130 if let Some(cache) = CACHED_TOKEN.get() {
131 if let Ok(mut guard) = cache.write() {
132 *guard = Some(t.clone());
133 }
134 }
135 }
136
137 token
138}
139
140pub fn clear_trusted_device_token_cache() {
142 if let Some(cache) = CACHED_TOKEN.get() {
143 if let Ok(mut guard) = cache.write() {
144 *guard = None;
145 }
146 }
147}
148
149pub fn clear_trusted_device_token() {
153 if !is_gate_enabled() {
154 return;
155 }
156
157 if let Some(storage) = STORAGE.get() {
158 if let Some(mut data) = storage.read() {
159 data.trusted_device_token = None;
160 let _ = storage.update(&data);
161 }
162 }
163
164 clear_trusted_device_token_cache();
165}
166
167pub async fn enroll_trusted_device() {
170 if !is_gate_enabled() {
172 log_debug("[trusted-device] Gate is off, skipping enrollment");
173 return;
174 }
175
176 if std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN).is_ok() {
178 log_debug(
179 "[trusted-device] CLAUDE_TRUSTED_DEVICE_TOKEN env var is set, skipping enrollment",
180 );
181 return;
182 }
183
184 let access_token = match AUTH_TOKEN_GETTER.get() {
186 Some(getter) => getter(),
187 None => {
188 log_debug("[trusted-device] No auth token getter registered, skipping enrollment");
189 return;
190 }
191 };
192
193 let access_token = match access_token {
194 Some(t) => t,
195 None => {
196 log_debug("[trusted-device] No OAuth token, skipping enrollment");
197 return;
198 }
199 };
200
201 let base_url = match BASE_URL_GETTER.get() {
203 Some(getter) => getter(),
204 None => {
205 log_debug("[trusted-device] No base URL getter registered, skipping enrollment");
206 return;
207 }
208 };
209
210 let client = reqwest::Client::new();
211
212 let hostname = hostname::get()
213 .map(|h| h.to_string_lossy().into_owned())
214 .unwrap_or_else(|_| "unknown".to_string());
215
216 let platform = std::env::consts::OS;
217 let display_name = format!("Claude Code on {} · {}", hostname, platform);
218
219 match client
220 .post(&format!("{}/api/auth/trusted_devices", base_url))
221 .header("Authorization", format!("Bearer {}", access_token))
222 .header("Content-Type", "application/json")
223 .header("User-Agent", get_user_agent())
224 .timeout(std::time::Duration::from_millis(ENROLLMENT_TIMEOUT_MS))
225 .json(&serde_json::json!({ "display_name": display_name }))
226 .send()
227 .await
228 {
229 Ok(response) => {
230 if response.status() != 200 && response.status() != 201 {
231 log_debug(&format!(
232 "[trusted-device] Enrollment failed {}",
233 response.status()
234 ));
235 return;
236 }
237
238 match response.json::<serde_json::Value>().await {
240 Ok(data) => {
241 let token = data.get("device_token").and_then(|v| v.as_str());
242 let device_id = data.get("device_id").and_then(|v| v.as_str());
243
244 match token {
245 Some(token) => {
246 if let Some(storage) = STORAGE.get() {
248 if let Some(mut data) = storage.read() {
249 data.trusted_device_token = Some(token.to_string());
250 if let Some(id) = device_id {
251 data.device_id = Some(id.to_string());
252 }
253 match storage.update(&data) {
254 Ok(_) => {
255 clear_trusted_device_token_cache();
256 log_debug(&format!(
257 "[trusted-device] Enrolled device_id={}",
258 device_id.unwrap_or("unknown")
259 ));
260 }
261 Err(e) => {
262 log_debug(&format!(
263 "[trusted-device] Storage write failed: {}",
264 e
265 ));
266 }
267 }
268 }
269 }
270 }
271 None => {
272 log_debug(
273 "[trusted-device] Enrollment response missing device_token field",
274 );
275 }
276 }
277 }
278 Err(e) => {
279 log_debug(&format!("[trusted-device] Failed to parse response: {}", e));
280 }
281 }
282 }
283 Err(e) => {
284 log_debug(&format!(
285 "[trusted-device] Enrollment request failed: {}",
286 e
287 ));
288 }
289 }
290}
291
/// Emits a diagnostic message on standard error, one message per line.
fn log_debug(msg: &str) {
    eprintln!("{msg}");
}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether a non-empty `CLAUDE_TRUSTED_DEVICE_TOKEN` is set in the test environment.
    /// If so, `get_trusted_device_token` returns it even with the gate off, so gate-off
    /// assertions about `None` do not apply.
    fn env_override_present() -> bool {
        std::env::var(ai::CLAUDE_TRUSTED_DEVICE_TOKEN).is_ok_and(|value| !value.is_empty())
    }

    #[test]
    fn test_token_returns_none_when_gate_disabled() {
        // No test in this module registers a gate predicate, so the gate reads as off.
        if env_override_present() {
            return;
        }
        assert_eq!(get_trusted_device_token(), None);
    }

    #[test]
    fn test_clear_token_cache() {
        // Clearing must not panic whether or not the cache was ever initialized,
        // and must be safe to repeat.
        clear_trusted_device_token_cache();
        clear_trusted_device_token_cache();
        // With the gate off, clearing the persisted token is an early-return no-op.
        clear_trusted_device_token();
        if !env_override_present() {
            assert_eq!(get_trusted_device_token(), None);
        }
    }

    #[test]
    fn test_storage_data_default_is_empty() {
        let data = StorageData::default();
        assert!(data.trusted_device_token.is_none());
        assert!(data.device_id.is_none());
    }
}
320}