1use std::sync::Arc;
2
3use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
4use axum::extract::{Query, State};
5use axum::http::{HeaderMap, StatusCode};
6use axum::response::{IntoResponse, Response};
7use axum::routing::get;
8use axum::{Json, Router};
9use futures_util::{SinkExt, StreamExt};
10use serde::Deserialize;
11use serde_json::{Value, json};
12use tracing::{info, warn};
13
14use crate::bridge_protocol::{
15 ApiError, ClientEnvelope, RuntimeStatusSnapshot, RuntimeSummary, ServerEnvelope,
16 error_response, event_envelope, ok_response,
17};
18use crate::state::BridgeState;
19
20#[derive(Debug, Deserialize, Default)]
21struct WsQuery {
22 token: Option<String>,
23}
24
25pub fn build_router(state: Arc<BridgeState>) -> Router {
26 Router::new()
27 .route("/health", get(health_handler))
28 .route("/ws", get(ws_handler))
29 .with_state(state)
30}
31
32async fn health_handler(State(state): State<Arc<BridgeState>>) -> Json<Value> {
33 let runtime = state.runtime_snapshot_for_client().await;
34 let runtimes = state.runtime_summaries_for_client().await;
35 Json(build_health_payload(&runtime, &runtimes))
36}
37
38fn build_health_payload(runtime: &RuntimeStatusSnapshot, runtimes: &[RuntimeSummary]) -> Value {
39 let primary_runtime_id = runtimes
40 .iter()
41 .find(|item| item.is_primary)
42 .map(|item| item.runtime_id.clone());
43
44 json!({
45 "ok": true,
46 "bridgeVersion": crate::BRIDGE_VERSION,
47 "buildHash": crate::BRIDGE_BUILD_HASH,
48 "protocolVersion": crate::BRIDGE_PROTOCOL_VERSION,
49 "runtimeCount": runtimes.len(),
50 "primaryRuntimeId": primary_runtime_id,
51 "runtime": runtime,
52 })
53}
54
55async fn ws_handler(
56 State(state): State<Arc<BridgeState>>,
57 Query(query): Query<WsQuery>,
58 headers: HeaderMap,
59 ws: WebSocketUpgrade,
60) -> Response {
61 match authorize(&state, &query, &headers) {
62 Ok(()) => ws
63 .on_upgrade(move |socket| handle_socket(state, socket))
64 .into_response(),
65 Err(error) => (StatusCode::UNAUTHORIZED, error).into_response(),
66 }
67}
68
69fn authorize(
70 state: &BridgeState,
71 query: &WsQuery,
72 headers: &HeaderMap,
73) -> Result<(), &'static str> {
74 let token = query
75 .token
76 .clone()
77 .or_else(|| {
78 headers
79 .get(axum::http::header::AUTHORIZATION)
80 .and_then(|value| value.to_str().ok())
81 .and_then(|value| value.strip_prefix("Bearer "))
82 .map(ToOwned::to_owned)
83 })
84 .ok_or("missing token")?;
85
86 if token == state.config_token() {
87 Ok(())
88 } else {
89 Err("invalid token")
90 }
91}
92
93async fn handle_socket(state: Arc<BridgeState>, socket: WebSocket) {
94 let (mut sender, mut receiver) = socket.split();
95 let mut event_rx = state.subscribe_events();
96 let mut device_id: Option<String> = None;
97
98 loop {
99 tokio::select! {
100 incoming = receiver.next() => {
101 let Some(incoming) = incoming else {
102 info!(
103 "bridge ws 对端已断开 device_id={}",
104 device_id.as_deref().unwrap_or("<pending>")
105 );
106 break;
107 };
108
109 let Ok(message) = incoming else {
110 warn!(
111 "bridge ws 接收失败 device_id={}: {:?}",
112 device_id.as_deref().unwrap_or("<pending>"),
113 incoming.err()
114 );
115 break;
116 };
117
118 match handle_incoming_message(&state, &mut sender, &mut device_id, message).await {
119 Ok(should_continue) if should_continue => {}
120 Ok(_) => break,
121 Err(error) => {
122 warn!(
123 "bridge ws 处理消息失败 device_id={}: {error}",
124 device_id.as_deref().unwrap_or("<pending>")
125 );
126 break;
127 }
128 }
129 }
130 broadcast_result = event_rx.recv(), if device_id.is_some() => {
131 match broadcast_result {
132 Ok(event) => {
133 if send_json(&mut sender, &event_envelope(event)).await.is_err() {
134 break;
135 }
136 }
137 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
138 let envelope = ServerEnvelope::Response {
139 request_id: "system".to_string(),
140 success: false,
141 data: None,
142 error: Some(ApiError::new("lagged", "事件流丢失,请重新连接")),
143 };
144 let _ = send_json(&mut sender, &envelope).await;
145 break;
146 }
147 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
148 }
149 }
150 }
151 }
152}
153
154async fn handle_incoming_message(
155 state: &BridgeState,
156 sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
157 device_id: &mut Option<String>,
158 message: Message,
159) -> anyhow::Result<bool> {
160 let text = match message {
161 Message::Text(text) => text,
162 Message::Close(frame) => {
163 let detail = frame
164 .as_ref()
165 .map(|close| format!("code={} reason={}", close.code, close.reason))
166 .unwrap_or_else(|| "no close frame".to_string());
167 info!(
168 "bridge ws 收到 close 帧 device_id={}: {detail}",
169 device_id.as_deref().unwrap_or("<pending>")
170 );
171 return Ok(false);
172 }
173 _ => return Ok(true),
174 };
175
176 let envelope = parse_client_envelope(&text).map_err(|error| {
177 anyhow::anyhow!(
178 "解析客户端消息失败: {error}; payload={}",
179 truncate_text(&text, 240)
180 )
181 })?;
182 match envelope {
183 ClientEnvelope::Hello {
184 device_id: next_device_id,
185 last_ack_seq,
186 } => {
187 info!(
188 "bridge ws 收到 hello device_id={} last_ack_seq={last_ack_seq:?}",
189 next_device_id
190 );
191 let (
192 runtime,
193 runtimes,
194 directory_bookmarks,
195 directory_history,
196 pending_requests,
197 replay_events,
198 ) = state.hello_payload(&next_device_id, last_ack_seq).await?;
199 *device_id = Some(next_device_id);
200 let connected_device_id = device_id.as_deref().unwrap_or("<pending>");
201
202 send_json(
203 sender,
204 &ServerEnvelope::Hello {
205 bridge_version: crate::BRIDGE_VERSION.to_string(),
206 protocol_version: crate::BRIDGE_PROTOCOL_VERSION,
207 runtime,
208 runtimes,
209 directory_bookmarks,
210 directory_history,
211 pending_requests,
212 },
213 )
214 .await?;
215
216 info!(
217 "bridge ws hello 已完成 device_id={} replay_events={}",
218 connected_device_id,
219 replay_events.len()
220 );
221 for event in replay_events {
222 send_json(sender, &event_envelope(event)).await?;
223 }
224 }
225 ClientEnvelope::Request {
226 request_id,
227 action,
228 payload,
229 } => {
230 if device_id.is_none() {
231 send_json(
232 sender,
233 &error_response(
234 request_id,
235 ApiError::new("handshake_required", "请先发送 hello"),
236 ),
237 )
238 .await?;
239 return Ok(true);
240 }
241
242 let response = match state.handle_request(&action, payload).await {
243 Ok(data) => ok_response(request_id, data),
244 Err(error) => error_response(
245 request_id,
246 ApiError::new("request_failed", error.to_string()),
247 ),
248 };
249 send_json(sender, &response).await?;
250 }
251 ClientEnvelope::AckEvents { last_seq } => {
252 if let Some(device_id) = device_id.as_deref() {
253 state.ack_events(device_id, last_seq)?;
254 }
255 }
256 ClientEnvelope::Ping => {
257 send_json(
258 sender,
259 &ServerEnvelope::Pong {
260 server_time_ms: crate::bridge_protocol::now_millis(),
261 },
262 )
263 .await?;
264 }
265 }
266
267 Ok(true)
268}
269
270fn parse_client_envelope(text: &str) -> anyhow::Result<ClientEnvelope> {
271 match serde_json::from_str::<ClientEnvelope>(text) {
272 Ok(envelope) => Ok(envelope),
273 Err(primary_error) => {
274 let nested_payload = serde_json::from_str::<String>(text).map_err(|_| primary_error)?;
275 serde_json::from_str::<ClientEnvelope>(&nested_payload).map_err(Into::into)
276 }
277 }
278}
279
/// Returns at most `max_chars` characters of `text`.
///
/// When `text` is longer than `max_chars` characters (Unicode scalar values,
/// not bytes), the kept prefix is followed by a single `…`; otherwise the text
/// is returned unchanged.
fn truncate_text(text: &str, max_chars: usize) -> String {
    // The byte offset of the first character past the limit, if there is one,
    // is exactly where the prefix must be cut.
    match text.char_indices().nth(max_chars) {
        Some((cutoff, _)) => {
            let mut shortened = String::with_capacity(cutoff + '…'.len_utf8());
            shortened.push_str(&text[..cutoff]);
            shortened.push('…');
            shortened
        }
        None => text.to_owned(),
    }
}
291
292async fn send_json(
293 sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
294 envelope: &ServerEnvelope,
295) -> anyhow::Result<()> {
296 let text = serde_json::to_string(envelope)?;
297 sender.send(Message::Text(text.into())).await?;
298 Ok(())
299}
300
#[cfg(test)]
mod tests {
    use std::env;
    use std::fs;
    use std::path::PathBuf;
    use std::sync::Arc;

    use axum::extract::State;
    use serde_json::{Value, json};
    use tokio::time::{Duration, timeout};
    use uuid::Uuid;

    use super::{build_health_payload, health_handler, parse_client_envelope};
    use crate::bridge_protocol::{
        AppServerHandshakeSummary, ClientEnvelope, RuntimeRecord, RuntimeStatusSnapshot,
        RuntimeSummary,
    };
    use crate::config::Config;
    use crate::state::BridgeState;

    /// Longest time any state call may take before a test treats it as hung.
    const DEADLINE: Duration = Duration::from_secs(2);

    /// Picks an executable that exits immediately, to stand in for the real
    /// codex binary; falls back to a bare `true` resolved through `PATH`.
    fn resolve_true_binary() -> String {
        ["/usr/bin/true", "/bin/true"]
            .iter()
            .find(|candidate| PathBuf::from(**candidate).exists())
            .copied()
            .unwrap_or("true")
            .to_string()
    }

    /// Boots a `BridgeState` backed by a database inside a fresh, uniquely
    /// named directory under the system temp dir.
    async fn bootstrap_test_state() -> Arc<BridgeState> {
        let dir_name = format!("codex-mobile-bridge-test-{}", Uuid::new_v4());
        let base_dir = env::temp_dir().join(dir_name);
        fs::create_dir_all(&base_dir).expect("创建测试目录失败");

        let config = Config {
            listen_addr: "127.0.0.1:0".to_string(),
            token: "test-token".to_string(),
            runtime_limit: 4,
            db_path: base_dir.join("bridge.db"),
            codex_home: None,
            codex_binary: resolve_true_binary(),
            directory_bookmarks: Vec::new(),
        };

        BridgeState::bootstrap(config)
            .await
            .expect("bootstrap 测试 BridgeState 失败")
    }

    #[test]
    fn build_health_payload_contains_bridge_metadata_and_primary_runtime() {
        let handshake = AppServerHandshakeSummary::new(
            "ready",
            true,
            vec!["fs/changed".to_string()],
            Some("握手完成,initialized 已发送".to_string()),
        );
        let runtime = RuntimeStatusSnapshot {
            runtime_id: "primary".to_string(),
            status: "running".to_string(),
            codex_home: Some("/srv/codex-home".to_string()),
            user_agent: Some("codex-mobile".to_string()),
            platform_family: Some("linux".to_string()),
            platform_os: Some("ubuntu".to_string()),
            last_error: None,
            pid: Some(4242),
            app_server_handshake: handshake,
            updated_at_ms: 1234,
        };
        let record = RuntimeRecord {
            runtime_id: "primary".to_string(),
            display_name: "Primary".to_string(),
            codex_home: Some("/srv/codex-home".to_string()),
            codex_binary: "codex".to_string(),
            is_primary: true,
            auto_start: true,
            created_at_ms: 1000,
            updated_at_ms: 1000,
        };
        let runtimes = vec![RuntimeSummary::from_parts(&record, runtime.clone())];

        let payload = build_health_payload(&runtime, &runtimes);

        // Top-level fields, checked as a table.
        let expected_top_level = [
            ("ok", Value::Bool(true)),
            (
                "bridgeVersion",
                Value::String(crate::BRIDGE_VERSION.to_string()),
            ),
            (
                "buildHash",
                Value::String(crate::BRIDGE_BUILD_HASH.to_string()),
            ),
            (
                "protocolVersion",
                Value::Number(crate::BRIDGE_PROTOCOL_VERSION.into()),
            ),
            ("runtimeCount", Value::Number(1.into())),
            ("primaryRuntimeId", Value::String("primary".to_string())),
        ];
        for (key, expected) in &expected_top_level {
            assert_eq!(&payload[*key], expected, "unexpected value for `{key}`");
        }

        // Fields of the embedded runtime snapshot.
        let embedded = &payload["runtime"];
        assert_eq!(embedded["runtimeId"], Value::String("primary".to_string()));
        assert_eq!(embedded["status"], Value::String("running".to_string()));
    }

    #[test]
    fn parse_client_envelope_accepts_plain_hello_payload() {
        let raw = r#"{"kind":"hello","device_id":"device-alpha","last_ack_seq":7}"#;

        let envelope = parse_client_envelope(raw).expect("hello payload 应可解析");

        let ClientEnvelope::Hello {
            device_id,
            last_ack_seq,
        } = envelope
        else {
            panic!("应解析为 hello");
        };
        assert_eq!(device_id, "device-alpha");
        assert_eq!(last_ack_seq, Some(7));
    }

    #[test]
    fn parse_client_envelope_accepts_double_encoded_hello_payload() {
        let raw = r#""{\"kind\":\"hello\",\"device_id\":\"device-beta\",\"last_ack_seq\":9}""#;

        let envelope = parse_client_envelope(raw).expect("双重编码 hello payload 应可解析");

        let ClientEnvelope::Hello {
            device_id,
            last_ack_seq,
        } = envelope
        else {
            panic!("应解析为 hello");
        };
        assert_eq!(device_id, "device-beta");
        assert_eq!(last_ack_seq, Some(9));
    }

    #[tokio::test]
    async fn runtime_snapshot_returns_without_hanging() {
        let state = bootstrap_test_state().await;

        let snapshot = timeout(DEADLINE, state.runtime_snapshot())
            .await
            .expect("runtime_snapshot 超时");

        assert_eq!(snapshot.runtime_id, "primary");
    }

    #[tokio::test]
    async fn runtime_summaries_return_without_hanging() {
        let state = bootstrap_test_state().await;

        let summaries = timeout(DEADLINE, state.runtime_summaries())
            .await
            .expect("runtime_summaries 超时");

        assert!(!summaries.is_empty());
        assert_eq!(summaries[0].runtime_id, "primary");
    }

    #[tokio::test]
    async fn health_handler_returns_without_hanging() {
        let state = bootstrap_test_state().await;

        let pending = health_handler(State(Arc::clone(&state)));
        let _ = timeout(DEADLINE, pending)
            .await
            .expect("/health handler 超时");
    }

    #[tokio::test]
    async fn hello_payload_returns_without_hanging() {
        let state = bootstrap_test_state().await;

        let outcome = timeout(DEADLINE, state.hello_payload("device-test", None))
            .await
            .expect("hello_payload 超时");
        let (runtime, runtimes, ..) = outcome.expect("hello_payload 返回错误");

        assert_eq!(runtime.runtime_id, "primary");
        assert!(!runtimes.is_empty());
        assert_eq!(runtimes[0].runtime_id, "primary");
    }

    #[tokio::test]
    async fn list_runtimes_request_returns_without_hanging() {
        let state = bootstrap_test_state().await;

        let outcome = timeout(DEADLINE, state.handle_request("list_runtimes", json!({})))
            .await
            .expect("list_runtimes 超时");
        let response = outcome.expect("list_runtimes 返回错误");

        let runtimes = response["runtimes"].as_array().expect("runtimes 应为数组");
        assert!(!runtimes.is_empty());
        assert_eq!(
            runtimes[0]["runtimeId"],
            Value::String("primary".to_string())
        );
    }

    #[tokio::test]
    async fn get_runtime_status_request_returns_without_hanging() {
        let state = bootstrap_test_state().await;

        let request = state.handle_request("get_runtime_status", json!({ "runtimeId": "primary" }));
        let response = timeout(DEADLINE, request)
            .await
            .expect("get_runtime_status 超时")
            .expect("get_runtime_status 返回错误");

        assert_eq!(
            response["runtime"]["runtimeId"],
            Value::String("primary".to_string())
        );
    }
}