use std::sync::Arc;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::extract::{Query, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::get;
use axum::{Json, Router};
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use serde_json::{Value, json};
use tracing::{info, warn};
use crate::bridge_protocol::{
ApiError, ClientEnvelope, RuntimeStatusSnapshot, RuntimeSummary, ServerEnvelope,
error_response, event_envelope, ok_response,
};
use crate::state::BridgeState;
/// Query-string parameters accepted by the `/ws` upgrade endpoint.
#[derive(Default, Debug, Deserialize)]
struct WsQuery {
    /// Shared bridge token; when it is absent, `authorize` looks at the `Authorization` header.
    token: Option<String>,
}
/// Builds the bridge HTTP router.
///
/// Exposes `GET /health` (status probe) and `GET /ws` (authenticated WebSocket upgrade), both
/// sharing the same `BridgeState`.
pub fn build_router(state: Arc<BridgeState>) -> Router {
    let routes = Router::new()
        .route("/health", get(health_handler))
        .route("/ws", get(ws_handler));
    routes.with_state(state)
}
/// `GET /health`: reports bridge metadata plus the primary runtime snapshot and runtime count.
async fn health_handler(State(state): State<Arc<BridgeState>>) -> Json<Value> {
    let snapshot = state.runtime_snapshot_for_client().await;
    let summaries = state.runtime_summaries_for_client().await;
    let payload = build_health_payload(&snapshot, &summaries);
    Json(payload)
}
/// Assembles the JSON body returned by `/health`.
///
/// `primaryRuntimeId` is the id of the first summary flagged `is_primary`, or `null` when no
/// summary carries that flag.
fn build_health_payload(runtime: &RuntimeStatusSnapshot, runtimes: &[RuntimeSummary]) -> Value {
    let primary_runtime_id = runtimes
        .iter()
        .find_map(|summary| summary.is_primary.then(|| summary.runtime_id.clone()));
    let runtime_count = runtimes.len();
    json!({
        "ok": true,
        "bridgeVersion": crate::BRIDGE_VERSION,
        "buildHash": crate::BRIDGE_BUILD_HASH,
        "protocolVersion": crate::BRIDGE_PROTOCOL_VERSION,
        "runtimeCount": runtime_count,
        "primaryRuntimeId": primary_runtime_id,
        "runtime": runtime,
    })
}
/// `GET /ws`: authenticates the caller, then upgrades the connection to a WebSocket.
///
/// Authentication failures are answered with `401 Unauthorized` and the short reason text
/// produced by `authorize`; no upgrade happens in that case.
async fn ws_handler(
    State(state): State<Arc<BridgeState>>,
    Query(query): Query<WsQuery>,
    headers: HeaderMap,
    ws: WebSocketUpgrade,
) -> Response {
    if let Err(reason) = authorize(&state, &query, &headers) {
        return (StatusCode::UNAUTHORIZED, reason).into_response();
    }
    ws.on_upgrade(move |socket| handle_socket(state, socket))
        .into_response()
}
/// Checks the shared bridge token presented by a WebSocket client.
///
/// The token is taken from the `token` query parameter when it is present and non-empty.
/// Otherwise it comes from an `Authorization: Bearer <token>` header. It is then compared with
/// the configured token in constant time, so response timing does not reveal how many leading
/// bytes of a guess were right.
///
/// # Errors
/// - `"missing token"` when neither source yields a non-empty token.
/// - `"invalid token"` when the presented token differs from the configured one.
fn authorize(
    state: &BridgeState,
    query: &WsQuery,
    headers: &HeaderMap,
) -> Result<(), &'static str> {
    // An empty `?token=` must not shadow a valid header. An empty presented token must also
    // never match, even if the configured token happens to be empty.
    let token = query
        .token
        .as_deref()
        .filter(|token| !token.is_empty())
        .or_else(|| bearer_token(headers))
        .ok_or("missing token")?;
    let expected = state.config_token();
    if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
        Ok(())
    } else {
        Err("invalid token")
    }
}

/// Returns the non-empty credentials of an `Authorization: Bearer …` header, if any.
///
/// The scheme name is matched case-insensitively, as HTTP authentication schemes are.
fn bearer_token(headers: &HeaderMap) -> Option<&str> {
    let value = headers
        .get(axum::http::header::AUTHORIZATION)?
        .to_str()
        .ok()?;
    let (scheme, credentials) = value.split_once(' ')?;
    let credentials = credentials.trim();
    (scheme.eq_ignore_ascii_case("Bearer") && !credentials.is_empty()).then_some(credentials)
}

/// Byte-string equality that does not stop at the first mismatch.
///
/// Only the length is allowed to leak. For equal lengths, every byte pair is always inspected.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    if left.len() != right.len() {
        return false;
    }
    let diff = left
        .iter()
        .zip(right)
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    // `black_box` discourages the optimizer from turning the fold back into an early exit.
    std::hint::black_box(diff) == 0
}
/// Drives one authenticated WebSocket session until either side stops.
///
/// Two sources are multiplexed:
/// - client frames, handled by `handle_incoming_message`;
/// - bridge events from `state.subscribe_events()`, forwarded to the client.
///
/// Events are forwarded only after a `hello` has set `device_id`. The session ends when any of
/// these happens:
/// - the peer disconnects;
/// - a receive or send fails;
/// - a close frame arrives;
/// - message handling errors;
/// - the event subscription lags or closes.
///
/// Every exit path is logged.
async fn handle_socket(state: Arc<BridgeState>, socket: WebSocket) {
    let (mut sender, mut receiver) = socket.split();
    // Subscribed up front; the `if device_id.is_some()` guard below keeps the receiver unpolled
    // until the handshake. NOTE(review): events buffered before `hello` may overlap with the
    // replay sent during `hello` — presumably clients de-duplicate by sequence; confirm.
    let mut event_rx = state.subscribe_events();
    let mut device_id: Option<String> = None;
    loop {
        tokio::select! {
            incoming = receiver.next() => {
                let message = match incoming {
                    Some(Ok(message)) => message,
                    Some(Err(error)) => {
                        warn!(
                            "bridge ws 接收失败 device_id={}: {:?}",
                            device_id.as_deref().unwrap_or("<pending>"),
                            error
                        );
                        break;
                    }
                    None => {
                        info!(
                            "bridge ws 对端已断开 device_id={}",
                            device_id.as_deref().unwrap_or("<pending>")
                        );
                        break;
                    }
                };
                match handle_incoming_message(&state, &mut sender, &mut device_id, message).await {
                    Ok(true) => {}
                    Ok(false) => break,
                    Err(error) => {
                        warn!(
                            "bridge ws 处理消息失败 device_id={}: {error}",
                            device_id.as_deref().unwrap_or("<pending>")
                        );
                        break;
                    }
                }
            }
            broadcast_result = event_rx.recv(), if device_id.is_some() => {
                match broadcast_result {
                    Ok(event) => {
                        if let Err(error) = send_json(&mut sender, &event_envelope(event)).await {
                            warn!(
                                "bridge ws 推送事件失败 device_id={}: {error}",
                                device_id.as_deref().unwrap_or("<pending>")
                            );
                            break;
                        }
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
                        // Events were dropped for this subscriber. Tell the client to reconnect
                        // (best effort) instead of silently delivering a gapped stream.
                        warn!(
                            "bridge ws 事件流落后 device_id={} skipped={skipped}",
                            device_id.as_deref().unwrap_or("<pending>")
                        );
                        let envelope = ServerEnvelope::Response {
                            request_id: "system".to_string(),
                            success: false,
                            data: None,
                            error: Some(ApiError::new("lagged", "事件流丢失,请重新连接")),
                        };
                        let _ = send_json(&mut sender, &envelope).await;
                        break;
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                        info!(
                            "bridge ws 事件广播已关闭 device_id={}",
                            device_id.as_deref().unwrap_or("<pending>")
                        );
                        break;
                    }
                }
            }
        }
    }
}
async fn handle_incoming_message(
state: &BridgeState,
sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
device_id: &mut Option<String>,
message: Message,
) -> anyhow::Result<bool> {
let text = match message {
Message::Text(text) => text,
Message::Close(frame) => {
let detail = frame
.as_ref()
.map(|close| format!("code={} reason={}", close.code, close.reason))
.unwrap_or_else(|| "no close frame".to_string());
info!(
"bridge ws 收到 close 帧 device_id={}: {detail}",
device_id.as_deref().unwrap_or("<pending>")
);
return Ok(false);
}
_ => return Ok(true),
};
let envelope = parse_client_envelope(&text).map_err(|error| {
anyhow::anyhow!(
"解析客户端消息失败: {error}; payload={}",
truncate_text(&text, 240)
)
})?;
match envelope {
ClientEnvelope::Hello {
device_id: next_device_id,
last_ack_seq,
} => {
info!(
"bridge ws 收到 hello device_id={} last_ack_seq={last_ack_seq:?}",
next_device_id
);
let (runtime, runtimes, workspaces, pending_requests, replay_events) =
state.hello_payload(&next_device_id, last_ack_seq).await?;
*device_id = Some(next_device_id);
let connected_device_id = device_id.as_deref().unwrap_or("<pending>");
send_json(
sender,
&ServerEnvelope::Hello {
runtime,
runtimes,
workspaces,
pending_requests,
},
)
.await?;
info!(
"bridge ws hello 已完成 device_id={} replay_events={}",
connected_device_id,
replay_events.len()
);
for event in replay_events {
send_json(sender, &event_envelope(event)).await?;
}
}
ClientEnvelope::Request {
request_id,
action,
payload,
} => {
if device_id.is_none() {
send_json(
sender,
&error_response(
request_id,
ApiError::new("handshake_required", "请先发送 hello"),
),
)
.await?;
return Ok(true);
}
let response = match state.handle_request(&action, payload).await {
Ok(data) => ok_response(request_id, data),
Err(error) => error_response(
request_id,
ApiError::new("request_failed", error.to_string()),
),
};
send_json(sender, &response).await?;
}
ClientEnvelope::AckEvents { last_seq } => {
if let Some(device_id) = device_id.as_deref() {
state.ack_events(device_id, last_seq)?;
}
}
ClientEnvelope::Ping => {
send_json(
sender,
&ServerEnvelope::Pong {
server_time_ms: crate::bridge_protocol::now_millis(),
},
)
.await?;
}
}
Ok(true)
}
/// Decodes a client envelope from a text frame.
///
/// The frame is first parsed directly. If that fails and the frame is itself a JSON string
/// (a double-encoded payload), that string's contents are parsed instead. When the frame is
/// not a JSON string, the original direct-parse error is reported.
fn parse_client_envelope(text: &str) -> anyhow::Result<ClientEnvelope> {
    let primary_error = match serde_json::from_str::<ClientEnvelope>(text) {
        Ok(envelope) => return Ok(envelope),
        Err(error) => error,
    };
    let Ok(inner) = serde_json::from_str::<String>(text) else {
        return Err(primary_error.into());
    };
    Ok(serde_json::from_str::<ClientEnvelope>(&inner)?)
}
/// Keeps at most `max_chars` Unicode scalar values of `text`.
///
/// When characters were dropped, `…` is appended to mark the cut. The cut is made on a char
/// boundary, so multi-byte UTF-8 is never split.
fn truncate_text(text: &str, max_chars: usize) -> String {
    // The character at index `max_chars`, if any, is the first one past the limit.
    match text.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}…", &text[..cut]),
        None => text.to_owned(),
    }
}
async fn send_json(
sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
envelope: &ServerEnvelope,
) -> anyhow::Result<()> {
let text = serde_json::to_string(envelope)?;
sender.send(Message::Text(text.into())).await?;
Ok(())
}
#[cfg(test)]
mod tests {
    use std::env;
    use std::fs;
    use std::path::PathBuf;
    use std::sync::Arc;
    use axum::extract::State;
    use serde_json::{Value, json};
    use tokio::time::{Duration, timeout};
    use uuid::Uuid;
    use super::build_health_payload;
    use super::health_handler;
    use super::parse_client_envelope;
    use crate::bridge_protocol::{
        ClientEnvelope, RuntimeRecord, RuntimeStatusSnapshot, RuntimeSummary,
    };
    use crate::config::Config;
    use crate::state::BridgeState;

    /// The health payload carries the build constants, the runtime count, the primary runtime
    /// id and the serialized runtime snapshot.
    #[test]
    fn build_health_payload_contains_bridge_metadata_and_primary_runtime() {
        let runtime = RuntimeStatusSnapshot {
            runtime_id: "primary".to_string(),
            status: "running".to_string(),
            codex_home: Some("/srv/codex-home".to_string()),
            user_agent: Some("codex-mobile".to_string()),
            platform_family: Some("linux".to_string()),
            platform_os: Some("ubuntu".to_string()),
            last_error: None,
            pid: Some(4242),
            updated_at_ms: 1234,
        };
        let runtime_record = RuntimeRecord {
            runtime_id: "primary".to_string(),
            display_name: "Primary".to_string(),
            codex_home: Some("/srv/codex-home".to_string()),
            codex_binary: "codex".to_string(),
            is_primary: true,
            auto_start: true,
            created_at_ms: 1000,
            updated_at_ms: 1000,
        };
        let runtimes = vec![RuntimeSummary::from_parts(&runtime_record, runtime.clone())];
        let payload = build_health_payload(&runtime, &runtimes);
        assert_eq!(payload["ok"], Value::Bool(true));
        assert_eq!(
            payload["bridgeVersion"],
            Value::String(crate::BRIDGE_VERSION.to_string())
        );
        assert_eq!(
            payload["buildHash"],
            Value::String(crate::BRIDGE_BUILD_HASH.to_string())
        );
        assert_eq!(
            payload["protocolVersion"],
            Value::Number(crate::BRIDGE_PROTOCOL_VERSION.into())
        );
        assert_eq!(payload["runtimeCount"], Value::Number(1.into()));
        assert_eq!(
            payload["primaryRuntimeId"],
            Value::String("primary".to_string())
        );
        assert_eq!(
            payload["runtime"]["runtimeId"],
            Value::String("primary".to_string())
        );
        assert_eq!(
            payload["runtime"]["status"],
            Value::String("running".to_string())
        );
    }

    /// A directly encoded hello envelope parses on the first attempt.
    #[test]
    fn parse_client_envelope_accepts_plain_hello_payload() {
        let envelope = parse_client_envelope(
            r#"{"kind":"hello","device_id":"device-alpha","last_ack_seq":7}"#,
        )
        .expect("hello payload 应可解析");
        match envelope {
            ClientEnvelope::Hello {
                device_id,
                last_ack_seq,
            } => {
                assert_eq!(device_id, "device-alpha");
                assert_eq!(last_ack_seq, Some(7));
            }
            _ => panic!("应解析为 hello"),
        }
    }

    /// A hello envelope wrapped in a JSON string is unwrapped by the fallback path.
    #[test]
    fn parse_client_envelope_accepts_double_encoded_hello_payload() {
        let envelope = parse_client_envelope(
            r#""{\"kind\":\"hello\",\"device_id\":\"device-beta\",\"last_ack_seq\":9}""#,
        )
        .expect("双重编码 hello payload 应可解析");
        match envelope {
            ClientEnvelope::Hello {
                device_id,
                last_ack_seq,
            } => {
                assert_eq!(device_id, "device-beta");
                assert_eq!(last_ack_seq, Some(9));
            }
            _ => panic!("应解析为 hello"),
        }
    }

    #[tokio::test]
    async fn runtime_snapshot_returns_without_hanging() {
        let state = bootstrap_test_state().await;
        let snapshot = timeout(Duration::from_secs(2), state.runtime_snapshot())
            .await
            .expect("runtime_snapshot 超时");
        assert_eq!(snapshot.runtime_id, "primary");
    }

    #[tokio::test]
    async fn runtime_summaries_return_without_hanging() {
        let state = bootstrap_test_state().await;
        let summaries = timeout(Duration::from_secs(2), state.runtime_summaries())
            .await
            .expect("runtime_summaries 超时");
        assert!(!summaries.is_empty());
        assert_eq!(summaries[0].runtime_id, "primary");
    }

    /// `/health` must complete promptly and return a well-formed payload, not merely any value.
    #[tokio::test]
    async fn health_handler_returns_without_hanging() {
        let state = bootstrap_test_state().await;
        let payload = timeout(
            Duration::from_secs(2),
            health_handler(State(Arc::clone(&state))),
        )
        .await
        .expect("/health handler 超时")
        .0;
        assert_eq!(payload["ok"], Value::Bool(true));
        assert_eq!(
            payload["bridgeVersion"],
            Value::String(crate::BRIDGE_VERSION.to_string())
        );
        assert!(payload["runtimeCount"].is_u64());
    }

    #[tokio::test]
    async fn hello_payload_returns_without_hanging() {
        let state = bootstrap_test_state().await;
        let (runtime, runtimes, ..) = timeout(
            Duration::from_secs(2),
            state.hello_payload("device-test", None),
        )
        .await
        .expect("hello_payload 超时")
        .expect("hello_payload 返回错误");
        assert_eq!(runtime.runtime_id, "primary");
        assert!(!runtimes.is_empty());
        assert_eq!(runtimes[0].runtime_id, "primary");
    }

    #[tokio::test]
    async fn list_runtimes_request_returns_without_hanging() {
        let state = bootstrap_test_state().await;
        let response = timeout(
            Duration::from_secs(2),
            state.handle_request("list_runtimes", json!({})),
        )
        .await
        .expect("list_runtimes 超时")
        .expect("list_runtimes 返回错误");
        let runtimes = response["runtimes"].as_array().expect("runtimes 应为数组");
        assert!(!runtimes.is_empty());
        assert_eq!(
            runtimes[0]["runtimeId"],
            Value::String("primary".to_string())
        );
    }

    #[tokio::test]
    async fn get_runtime_status_request_returns_without_hanging() {
        let state = bootstrap_test_state().await;
        let response = timeout(
            Duration::from_secs(2),
            state.handle_request("get_runtime_status", json!({ "runtimeId": "primary" })),
        )
        .await
        .expect("get_runtime_status 超时")
        .expect("get_runtime_status 返回错误");
        assert_eq!(
            response["runtime"]["runtimeId"],
            Value::String("primary".to_string())
        );
    }

    /// Bootstraps a `BridgeState` backed by a fresh database in a unique temp directory.
    ///
    /// The codex binary is pointed at a `true` executable so any spawned runtime does nothing.
    /// NOTE(review): the temp directory is never removed after the test finishes.
    async fn bootstrap_test_state() -> Arc<BridgeState> {
        let base_dir = env::temp_dir().join(format!("codex-mobile-bridge-test-{}", Uuid::new_v4()));
        fs::create_dir_all(&base_dir).expect("创建测试目录失败");
        let db_path = base_dir.join("bridge.db");
        let config = Config {
            listen_addr: "127.0.0.1:0".to_string(),
            token: "test-token".to_string(),
            runtime_limit: 4,
            db_path,
            codex_home: None,
            codex_binary: resolve_true_binary(),
            workspace_roots: Vec::new(),
        };
        BridgeState::bootstrap(config)
            .await
            .expect("bootstrap 测试 BridgeState 失败")
    }

    /// Locates a `true` executable, falling back to a bare `true` resolved via `PATH`.
    fn resolve_true_binary() -> String {
        for candidate in ["/usr/bin/true", "/bin/true"] {
            if PathBuf::from(candidate).exists() {
                return candidate.to_string();
            }
        }
        "true".to_string()
    }
}