//! HTTP `/health` endpoint and token-authenticated `/ws` WebSocket endpoint of the bridge.
//!
//! Source path: `codex_mobile_bridge/server.rs`.

1use std::sync::Arc;
2
3use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
4use axum::extract::{Query, State};
5use axum::http::{HeaderMap, StatusCode};
6use axum::response::{IntoResponse, Response};
7use axum::routing::get;
8use axum::{Json, Router};
9use futures_util::{SinkExt, StreamExt};
10use serde::Deserialize;
11use serde_json::{Value, json};
12use tracing::{info, warn};
13
14use crate::bridge_protocol::{
15    ApiError, ClientEnvelope, RuntimeStatusSnapshot, RuntimeSummary, ServerEnvelope,
16    error_response, event_envelope, ok_response,
17};
18use crate::state::BridgeState;
19
20#[derive(Debug, Deserialize, Default)]
21struct WsQuery {
22    token: Option<String>,
23}
24
25pub fn build_router(state: Arc<BridgeState>) -> Router {
26    Router::new()
27        .route("/health", get(health_handler))
28        .route("/ws", get(ws_handler))
29        .with_state(state)
30}
31
32async fn health_handler(State(state): State<Arc<BridgeState>>) -> Json<Value> {
33    let runtime = state.runtime_snapshot_for_client().await;
34    let runtimes = state.runtime_summaries_for_client().await;
35    Json(build_health_payload(&runtime, &runtimes))
36}
37
38fn build_health_payload(runtime: &RuntimeStatusSnapshot, runtimes: &[RuntimeSummary]) -> Value {
39    let primary_runtime_id = runtimes
40        .iter()
41        .find(|item| item.is_primary)
42        .map(|item| item.runtime_id.clone());
43
44    json!({
45        "ok": true,
46        "bridgeVersion": crate::BRIDGE_VERSION,
47        "buildHash": crate::BRIDGE_BUILD_HASH,
48        "protocolVersion": crate::BRIDGE_PROTOCOL_VERSION,
49        "runtimeCount": runtimes.len(),
50        "primaryRuntimeId": primary_runtime_id,
51        "runtime": runtime,
52    })
53}
54
55async fn ws_handler(
56    State(state): State<Arc<BridgeState>>,
57    Query(query): Query<WsQuery>,
58    headers: HeaderMap,
59    ws: WebSocketUpgrade,
60) -> Response {
61    match authorize(&state, &query, &headers) {
62        Ok(()) => ws
63            .on_upgrade(move |socket| handle_socket(state, socket))
64            .into_response(),
65        Err(error) => (StatusCode::UNAUTHORIZED, error).into_response(),
66    }
67}
68
69fn authorize(
70    state: &BridgeState,
71    query: &WsQuery,
72    headers: &HeaderMap,
73) -> Result<(), &'static str> {
74    let token = query
75        .token
76        .clone()
77        .or_else(|| {
78            headers
79                .get(axum::http::header::AUTHORIZATION)
80                .and_then(|value| value.to_str().ok())
81                .and_then(|value| value.strip_prefix("Bearer "))
82                .map(ToOwned::to_owned)
83        })
84        .ok_or("missing token")?;
85
86    if token == state.config_token() {
87        Ok(())
88    } else {
89        Err("invalid token")
90    }
91}
92
93async fn handle_socket(state: Arc<BridgeState>, socket: WebSocket) {
94    let (mut sender, mut receiver) = socket.split();
95    let mut event_rx = state.subscribe_events();
96    let mut device_id: Option<String> = None;
97
98    loop {
99        tokio::select! {
100            incoming = receiver.next() => {
101                let Some(incoming) = incoming else {
102                    info!(
103                        "bridge ws 对端已断开 device_id={}",
104                        device_id.as_deref().unwrap_or("<pending>")
105                    );
106                    break;
107                };
108
109                let Ok(message) = incoming else {
110                    warn!(
111                        "bridge ws 接收失败 device_id={}: {:?}",
112                        device_id.as_deref().unwrap_or("<pending>"),
113                        incoming.err()
114                    );
115                    break;
116                };
117
118                match handle_incoming_message(&state, &mut sender, &mut device_id, message).await {
119                    Ok(should_continue) if should_continue => {}
120                    Ok(_) => break,
121                    Err(error) => {
122                        warn!(
123                            "bridge ws 处理消息失败 device_id={}: {error}",
124                            device_id.as_deref().unwrap_or("<pending>")
125                        );
126                        break;
127                    }
128                }
129            }
130            broadcast_result = event_rx.recv(), if device_id.is_some() => {
131                match broadcast_result {
132                    Ok(event) => {
133                        if send_json(&mut sender, &event_envelope(event)).await.is_err() {
134                            break;
135                        }
136                    }
137                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
138                        let envelope = ServerEnvelope::Response {
139                            request_id: "system".to_string(),
140                            success: false,
141                            data: None,
142                            error: Some(ApiError::new("lagged", "事件流丢失,请重新连接")),
143                        };
144                        let _ = send_json(&mut sender, &envelope).await;
145                        break;
146                    }
147                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
148                }
149            }
150        }
151    }
152}
153
154async fn handle_incoming_message(
155    state: &BridgeState,
156    sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
157    device_id: &mut Option<String>,
158    message: Message,
159) -> anyhow::Result<bool> {
160    let text = match message {
161        Message::Text(text) => text,
162        Message::Close(frame) => {
163            let detail = frame
164                .as_ref()
165                .map(|close| format!("code={} reason={}", close.code, close.reason))
166                .unwrap_or_else(|| "no close frame".to_string());
167            info!(
168                "bridge ws 收到 close 帧 device_id={}: {detail}",
169                device_id.as_deref().unwrap_or("<pending>")
170            );
171            return Ok(false);
172        }
173        _ => return Ok(true),
174    };
175
176    let envelope = parse_client_envelope(&text).map_err(|error| {
177        anyhow::anyhow!(
178            "解析客户端消息失败: {error}; payload={}",
179            truncate_text(&text, 240)
180        )
181    })?;
182    match envelope {
183        ClientEnvelope::Hello {
184            device_id: next_device_id,
185            last_ack_seq,
186        } => {
187            info!(
188                "bridge ws 收到 hello device_id={} last_ack_seq={last_ack_seq:?}",
189                next_device_id
190            );
191            let (
192                runtime,
193                runtimes,
194                directory_bookmarks,
195                directory_history,
196                pending_requests,
197                replay_events,
198            ) = state.hello_payload(&next_device_id, last_ack_seq).await?;
199            *device_id = Some(next_device_id);
200            let connected_device_id = device_id.as_deref().unwrap_or("<pending>");
201
202            send_json(
203                sender,
204                &ServerEnvelope::Hello {
205                    bridge_version: crate::BRIDGE_VERSION.to_string(),
206                    protocol_version: crate::BRIDGE_PROTOCOL_VERSION,
207                    runtime,
208                    runtimes,
209                    directory_bookmarks,
210                    directory_history,
211                    pending_requests,
212                },
213            )
214            .await?;
215
216            info!(
217                "bridge ws hello 已完成 device_id={} replay_events={}",
218                connected_device_id,
219                replay_events.len()
220            );
221            for event in replay_events {
222                send_json(sender, &event_envelope(event)).await?;
223            }
224        }
225        ClientEnvelope::Request {
226            request_id,
227            action,
228            payload,
229        } => {
230            if device_id.is_none() {
231                send_json(
232                    sender,
233                    &error_response(
234                        request_id,
235                        ApiError::new("handshake_required", "请先发送 hello"),
236                    ),
237                )
238                .await?;
239                return Ok(true);
240            }
241
242            let response = match state.handle_request(&action, payload).await {
243                Ok(data) => ok_response(request_id, data),
244                Err(error) => error_response(
245                    request_id,
246                    ApiError::new("request_failed", error.to_string()),
247                ),
248            };
249            send_json(sender, &response).await?;
250        }
251        ClientEnvelope::AckEvents { last_seq } => {
252            if let Some(device_id) = device_id.as_deref() {
253                state.ack_events(device_id, last_seq)?;
254            }
255        }
256        ClientEnvelope::Ping => {
257            send_json(
258                sender,
259                &ServerEnvelope::Pong {
260                    server_time_ms: crate::bridge_protocol::now_millis(),
261                },
262            )
263            .await?;
264        }
265    }
266
267    Ok(true)
268}
269
270fn parse_client_envelope(text: &str) -> anyhow::Result<ClientEnvelope> {
271    match serde_json::from_str::<ClientEnvelope>(text) {
272        Ok(envelope) => Ok(envelope),
273        Err(primary_error) => {
274            let nested_payload = serde_json::from_str::<String>(text).map_err(|_| primary_error)?;
275            serde_json::from_str::<ClientEnvelope>(&nested_payload).map_err(Into::into)
276        }
277    }
278}
279
/// Returns the first `max_chars` characters of `text`, appending `…` when anything was cut.
///
/// Counts `char`s (Unicode scalar values), so the cut never splits a UTF-8 sequence.
fn truncate_text(text: &str, max_chars: usize) -> String {
    // The byte offset of the (max_chars)-th char, if it exists, is exactly where to cut.
    match text.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}…", &text[..cut]),
        None => text.to_owned(),
    }
}

292async fn send_json(
293    sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
294    envelope: &ServerEnvelope,
295) -> anyhow::Result<()> {
296    let text = serde_json::to_string(envelope)?;
297    sender.send(Message::Text(text.into())).await?;
298    Ok(())
299}
300
/// Unit tests for this module. They cover three areas:
/// - the `/health` payload shape;
/// - client envelope decoding, both plain and double-encoded;
/// - timeout guards that catch hangs in the `BridgeState` calls the handlers rely on.
#[cfg(test)]
mod tests {
    use std::env;
    use std::fs;
    use std::path::PathBuf;
    use std::sync::Arc;

    use axum::extract::State;
    use serde_json::{Value, json};
    use tokio::time::{Duration, timeout};
    use uuid::Uuid;

    use super::build_health_payload;
    use super::health_handler;
    use super::parse_client_envelope;
    use crate::bridge_protocol::{
        ClientEnvelope, RuntimeRecord, RuntimeStatusSnapshot, RuntimeSummary,
    };
    use crate::config::Config;
    use crate::state::BridgeState;

    #[test]
    fn build_health_payload_contains_bridge_metadata_and_primary_runtime() {
        let runtime = RuntimeStatusSnapshot {
            runtime_id: "primary".to_string(),
            status: "running".to_string(),
            codex_home: Some("/srv/codex-home".to_string()),
            user_agent: Some("codex-mobile".to_string()),
            platform_family: Some("linux".to_string()),
            platform_os: Some("ubuntu".to_string()),
            last_error: None,
            pid: Some(4242),
            updated_at_ms: 1234,
        };
        let runtime_record = RuntimeRecord {
            runtime_id: "primary".to_string(),
            display_name: "Primary".to_string(),
            codex_home: Some("/srv/codex-home".to_string()),
            codex_binary: "codex".to_string(),
            is_primary: true,
            auto_start: true,
            created_at_ms: 1000,
            updated_at_ms: 1000,
        };
        let runtimes = vec![RuntimeSummary::from_parts(&runtime_record, runtime.clone())];

        let payload = build_health_payload(&runtime, &runtimes);

        assert_eq!(payload["ok"], Value::Bool(true));
        assert_eq!(
            payload["bridgeVersion"],
            Value::String(crate::BRIDGE_VERSION.to_string())
        );
        assert_eq!(
            payload["buildHash"],
            Value::String(crate::BRIDGE_BUILD_HASH.to_string())
        );
        assert_eq!(
            payload["protocolVersion"],
            Value::Number(crate::BRIDGE_PROTOCOL_VERSION.into())
        );
        assert_eq!(payload["runtimeCount"], Value::Number(1.into()));
        assert_eq!(
            payload["primaryRuntimeId"],
            Value::String("primary".to_string())
        );
        assert_eq!(
            payload["runtime"]["runtimeId"],
            Value::String("primary".to_string())
        );
        assert_eq!(
            payload["runtime"]["status"],
            Value::String("running".to_string())
        );
    }

    #[test]
    fn parse_client_envelope_accepts_plain_hello_payload() {
        let envelope = parse_client_envelope(
            r#"{"kind":"hello","device_id":"device-alpha","last_ack_seq":7}"#,
        )
        .expect("hello payload 应可解析");

        match envelope {
            ClientEnvelope::Hello {
                device_id,
                last_ack_seq,
            } => {
                assert_eq!(device_id, "device-alpha");
                assert_eq!(last_ack_seq, Some(7));
            }
            _ => panic!("应解析为 hello"),
        }
    }

    #[test]
    fn parse_client_envelope_accepts_double_encoded_hello_payload() {
        // The whole envelope is wrapped in a JSON string literal.
        let envelope = parse_client_envelope(
            r#""{\"kind\":\"hello\",\"device_id\":\"device-beta\",\"last_ack_seq\":9}""#,
        )
        .expect("双重编码 hello payload 应可解析");

        match envelope {
            ClientEnvelope::Hello {
                device_id,
                last_ack_seq,
            } => {
                assert_eq!(device_id, "device-beta");
                assert_eq!(last_ack_seq, Some(9));
            }
            _ => panic!("应解析为 hello"),
        }
    }

    #[tokio::test]
    async fn runtime_snapshot_returns_without_hanging() {
        let state = bootstrap_test_state().await;

        let snapshot = timeout(Duration::from_secs(2), state.runtime_snapshot())
            .await
            .expect("runtime_snapshot 超时");

        assert_eq!(snapshot.runtime_id, "primary");
    }

    #[tokio::test]
    async fn runtime_summaries_return_without_hanging() {
        let state = bootstrap_test_state().await;

        let summaries = timeout(Duration::from_secs(2), state.runtime_summaries())
            .await
            .expect("runtime_summaries 超时");

        assert!(!summaries.is_empty());
        assert_eq!(summaries[0].runtime_id, "primary");
    }

    #[tokio::test]
    async fn health_handler_returns_without_hanging() {
        let state = bootstrap_test_state().await;

        let axum::Json(payload) = timeout(
            Duration::from_secs(2),
            health_handler(State(Arc::clone(&state))),
        )
        .await
        .expect("/health handler 超时");

        // These fields come straight from build-time constants, so they are deterministic.
        assert_eq!(payload["ok"], Value::Bool(true));
        assert_eq!(
            payload["protocolVersion"],
            Value::Number(crate::BRIDGE_PROTOCOL_VERSION.into())
        );
    }

    #[tokio::test]
    async fn hello_payload_returns_without_hanging() {
        let state = bootstrap_test_state().await;

        let (runtime, runtimes, ..) = timeout(
            Duration::from_secs(2),
            state.hello_payload("device-test", None),
        )
        .await
        .expect("hello_payload 超时")
        .expect("hello_payload 返回错误");

        assert_eq!(runtime.runtime_id, "primary");
        assert!(!runtimes.is_empty());
        assert_eq!(runtimes[0].runtime_id, "primary");
    }

    #[tokio::test]
    async fn list_runtimes_request_returns_without_hanging() {
        let state = bootstrap_test_state().await;

        let response = timeout(
            Duration::from_secs(2),
            state.handle_request("list_runtimes", json!({})),
        )
        .await
        .expect("list_runtimes 超时")
        .expect("list_runtimes 返回错误");

        let runtimes = response["runtimes"].as_array().expect("runtimes 应为数组");
        assert!(!runtimes.is_empty());
        assert_eq!(
            runtimes[0]["runtimeId"],
            Value::String("primary".to_string())
        );
    }

    #[tokio::test]
    async fn get_runtime_status_request_returns_without_hanging() {
        let state = bootstrap_test_state().await;

        let response = timeout(
            Duration::from_secs(2),
            state.handle_request("get_runtime_status", json!({ "runtimeId": "primary" })),
        )
        .await
        .expect("get_runtime_status 超时")
        .expect("get_runtime_status 返回错误");

        assert_eq!(
            response["runtime"]["runtimeId"],
            Value::String("primary".to_string())
        );
    }

    /// Boots a `BridgeState` whose `db_path` is `bridge.db` inside a fresh, uniquely named
    /// temp directory, so tests do not share state.
    ///
    /// `codex_binary` points at the system `true` binary, presumably so that any runtime the
    /// bridge starts exits immediately. TODO confirm against `BridgeState::bootstrap`.
    /// NOTE(review): the temp directory is never removed after the test.
    async fn bootstrap_test_state() -> Arc<BridgeState> {
        let base_dir = env::temp_dir().join(format!("codex-mobile-bridge-test-{}", Uuid::new_v4()));
        fs::create_dir_all(&base_dir).expect("创建测试目录失败");
        let db_path = base_dir.join("bridge.db");

        let config = Config {
            listen_addr: "127.0.0.1:0".to_string(),
            token: "test-token".to_string(),
            runtime_limit: 4,
            db_path,
            codex_home: None,
            codex_binary: resolve_true_binary(),
            directory_bookmarks: Vec::new(),
        };

        BridgeState::bootstrap(config)
            .await
            .expect("bootstrap 测试 BridgeState 失败")
    }

    /// Returns the first existing absolute path to `true`. If none of the candidates
    /// exists, it falls back to the bare name `true`.
    fn resolve_true_binary() -> String {
        for candidate in ["/usr/bin/true", "/bin/true"] {
            if PathBuf::from(candidate).exists() {
                return candidate.to_string();
            }
        }
        "true".to_string()
    }
}