codex-mobile-bridge 0.3.3

Remote bridge and service manager for codex-mobile.
Documentation
use super::*;

/// Axum handler for the bridge WebSocket endpoint.
///
/// The request is authenticated with `authorize` before upgrading. On failure the client
/// receives `401 Unauthorized` with the rejection reason as the body. On success the
/// connection is upgraded and handed to `handle_socket` together with the shared state.
pub(super) async fn ws_handler(
    State(state): State<Arc<BridgeState>>,
    Query(query): Query<WsQuery>,
    headers: HeaderMap,
    ws: WebSocketUpgrade,
) -> Response {
    if let Err(reason) = authorize(&state, &query, &headers) {
        return (StatusCode::UNAUTHORIZED, reason).into_response();
    }

    ws.on_upgrade(move |socket| handle_socket(state, socket))
        .into_response()
}

/// Checks that the connecting client presented the bridge's configured token.
///
/// The token is taken from the `token` query parameter when present. Otherwise it is taken
/// from an `Authorization: Bearer <token>` header. The scheme name is matched
/// case-insensitively, as HTTP authentication schemes are case-insensitive (RFC 7235).
///
/// # Errors
/// Returns `"missing token"` when neither source supplies a token. Returns `"invalid token"`
/// when the supplied token does not equal `state.config_token()`.
fn authorize(
    state: &BridgeState,
    query: &WsQuery,
    headers: &HeaderMap,
) -> Result<(), &'static str> {
    let token = query
        .token
        .as_deref()
        .or_else(|| bearer_token(headers))
        .ok_or("missing token")?;

    // A plain `==` stops at the first differing byte, so response timing could reveal how much
    // of a guessed token was right. Compare every byte instead.
    if constant_time_eq(token.as_bytes(), state.config_token().as_bytes()) {
        Ok(())
    } else {
        Err("invalid token")
    }
}

/// Extracts the credential from an `Authorization: <scheme> <credential>` header whose scheme
/// is `Bearer` (compared ASCII case-insensitively). Returns `None` when the header is missing,
/// is not valid visible ASCII, has no space separator, or uses a different scheme.
fn bearer_token(headers: &HeaderMap) -> Option<&str> {
    let value = headers
        .get(axum::http::header::AUTHORIZATION)?
        .to_str()
        .ok()?;
    let (scheme, credential) = value.split_once(' ')?;
    scheme.eq_ignore_ascii_case("Bearer").then_some(credential)
}

/// Byte-wise equality whose running time does not depend on where the inputs first differ.
///
/// Inputs of different lengths are rejected immediately, so only the length is observable.
/// This is a best-effort guard against timing side channels; no dedicated crate is used.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    if left.len() != right.len() {
        return false;
    }
    let difference = left
        .iter()
        .zip(right)
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    difference == 0
}

/// Drives one authenticated WebSocket connection until either side stops.
///
/// Every client frame is passed to `handle_incoming_message`. Events from the bridge's
/// broadcast channel are forwarded only after the client has completed the `hello` handshake,
/// i.e. once `device_id` is set. The loop ends when any of the following happens:
/// - the peer disconnects or a receive fails;
/// - message handling returns an error or asks to stop;
/// - forwarding an event fails;
/// - the event stream lags (after a best-effort "lagged" notice) or closes.
///
/// NOTE(review): the broadcast subscription is created before the handshake but only drained
/// after it. Events broadcast in between stay buffered in the receiver. They may duplicate
/// replayed events or trigger `Lagged` right after `hello` — confirm this is acceptable.
async fn handle_socket(state: Arc<BridgeState>, socket: WebSocket) {
    let (mut outbound, mut inbound) = socket.split();
    let mut events = state.subscribe_events();
    let mut device_id: Option<String> = None;

    loop {
        tokio::select! {
            frame = inbound.next() => {
                let message = match frame {
                    None => {
                        info!(
                            "bridge ws 对端已断开 device_id={}",
                            device_id.as_deref().unwrap_or("<pending>")
                        );
                        break;
                    }
                    Some(Err(receive_error)) => {
                        // Logged as an `Option` so the output keeps its `Some(..)` shape.
                        warn!(
                            "bridge ws 接收失败 device_id={}: {:?}",
                            device_id.as_deref().unwrap_or("<pending>"),
                            Some(receive_error)
                        );
                        break;
                    }
                    Some(Ok(message)) => message,
                };

                let outcome =
                    handle_incoming_message(&state, &mut outbound, &mut device_id, message).await;
                match outcome {
                    Ok(true) => {}
                    Ok(false) => break,
                    Err(error) => {
                        warn!(
                            "bridge ws 处理消息失败 device_id={}: {error}",
                            device_id.as_deref().unwrap_or("<pending>")
                        );
                        break;
                    }
                }
            }
            received = events.recv(), if device_id.is_some() => {
                match received {
                    Ok(event) => {
                        let delivered =
                            send_json(&mut outbound, &event_envelope(event)).await.is_ok();
                        if !delivered {
                            break;
                        }
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                        let notice = ServerEnvelope::Response {
                            request_id: "system".to_string(),
                            success: false,
                            data: None,
                            error: Some(ApiError::new("lagged", "事件流丢失,请重新连接")),
                        };
                        // Best effort: the connection is closed right after either way.
                        let _ = send_json(&mut outbound, &notice).await;
                        break;
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                }
            }
        }
    }
}

/// Handles a single frame received from the client.
///
/// Only text frames are parsed as `ClientEnvelope`s. A close frame is logged and ends the
/// session; any other frame kind is ignored. Returns `Ok(true)` to keep the connection open
/// and `Ok(false)` after a close frame.
///
/// Envelope handling:
/// * `Hello` — loads the handshake data via `state.hello_payload`, records the device id,
///   sends the `Hello` reply, then sends each replayed event.
/// * `Request` — answered with `handshake_required` before `hello`. Otherwise dispatched to
///   `state.handle_request` and answered with an ok or `request_failed` response that reuses
///   the request id.
/// * `AckEvents` — passed to `state.ack_events` once a device id is known, ignored before.
/// * `Ping` — answered with a `Pong` carrying the current server time.
///
/// # Errors
/// Fails when the text cannot be parsed (the message includes the first 240 chars of the
/// payload), when `hello_payload` or `ack_events` fails, or when a reply cannot be serialized
/// or sent.
async fn handle_incoming_message(
    state: &BridgeState,
    sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
    device_id: &mut Option<String>,
    message: Message,
) -> anyhow::Result<bool> {
    let text = match message {
        Message::Text(text) => text,
        Message::Close(close_frame) => {
            let detail = match &close_frame {
                Some(close) => format!("code={} reason={}", close.code, close.reason),
                None => "no close frame".to_string(),
            };
            info!(
                "bridge ws 收到 close 帧 device_id={}: {detail}",
                device_id.as_deref().unwrap_or("<pending>")
            );
            return Ok(false);
        }
        // Non-text frames carry no envelope; ignore them and keep the session open.
        _ => return Ok(true),
    };

    let envelope = match parse_client_envelope(&text) {
        Ok(envelope) => envelope,
        Err(error) => {
            return Err(anyhow::anyhow!(
                "解析客户端消息失败: {error}; payload={}",
                truncate_text(&text, 240)
            ));
        }
    };

    match envelope {
        ClientEnvelope::Hello {
            device_id: hello_device_id,
            last_ack_seq,
        } => {
            info!(
                "bridge ws 收到 hello device_id={} last_ack_seq={last_ack_seq:?}",
                hello_device_id
            );
            let (
                runtime,
                runtimes,
                directory_bookmarks,
                directory_history,
                pending_requests,
                replay_events,
            ) = state.hello_payload(&hello_device_id, last_ack_seq).await?;
            // The device id is recorded before the reply is sent, so it is set even if sending fails.
            let connected_device_id = device_id.insert(hello_device_id).as_str();

            let hello = ServerEnvelope::Hello {
                bridge_version: crate::BRIDGE_VERSION.to_string(),
                protocol_version: crate::BRIDGE_PROTOCOL_VERSION,
                runtime,
                runtimes,
                directory_bookmarks,
                directory_history,
                pending_requests,
            };
            send_json(sender, &hello).await?;

            let replay_count = replay_events.len();
            info!(
                "bridge ws hello 已完成 device_id={} replay_events={}",
                connected_device_id, replay_count
            );
            for event in replay_events {
                send_json(sender, &event_envelope(event)).await?;
            }
        }
        ClientEnvelope::Request {
            request_id,
            action,
            payload,
        } => {
            let response = if device_id.is_none() {
                error_response(
                    request_id,
                    ApiError::new("handshake_required", "请先发送 hello"),
                )
            } else {
                match state.handle_request(&action, payload).await {
                    Ok(data) => ok_response(request_id, data),
                    Err(error) => error_response(
                        request_id,
                        ApiError::new("request_failed", error.to_string()),
                    ),
                }
            };
            send_json(sender, &response).await?;
        }
        ClientEnvelope::AckEvents { last_seq } => {
            // Acks that arrive before `hello` have no device to attribute them to.
            if let Some(acked_device) = device_id.as_deref() {
                state.ack_events(acked_device, last_seq)?;
            }
        }
        ClientEnvelope::Ping => {
            let pong = ServerEnvelope::Pong {
                server_time_ms: crate::bridge_protocol::now_millis(),
            };
            send_json(sender, &pong).await?;
        }
    }

    Ok(true)
}

/// Parses one client text frame into a `ClientEnvelope`.
///
/// Two encodings are accepted: the envelope JSON itself, or a JSON string literal whose
/// content is the envelope JSON (double-encoded).
///
/// # Errors
/// If `text` is not a valid envelope and is also not a JSON string, the original envelope
/// parse error is returned. If it is a JSON string whose content fails to parse, that inner
/// parse error is returned.
pub(super) fn parse_client_envelope(text: &str) -> anyhow::Result<ClientEnvelope> {
    let primary_error = match serde_json::from_str::<ClientEnvelope>(text) {
        Ok(envelope) => return Ok(envelope),
        Err(error) => error,
    };

    let Ok(inner_json) = serde_json::from_str::<String>(text) else {
        return Err(primary_error.into());
    };
    Ok(serde_json::from_str::<ClientEnvelope>(&inner_json)?)
}

/// Shortens `text` to at most `max_chars` characters (Unicode scalar values) for log output.
///
/// When anything is cut off, a trailing `…` is appended so truncated payloads can be told
/// apart from short ones. Text that already fits is returned unchanged.
fn truncate_text(text: &str, max_chars: usize) -> String {
    match text.char_indices().nth(max_chars) {
        // `cut` is the byte offset where the first non-fitting character starts, so slicing
        // there always lands on a UTF-8 boundary.
        Some((cut, _)) => format!("{}…", &text[..cut]),
        None => text.to_owned(),
    }
}

/// Serializes `envelope` to JSON and sends it to the client as a single text frame.
///
/// # Errors
/// Fails if serialization fails or if the WebSocket sink rejects the frame.
async fn send_json(
    sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
    envelope: &ServerEnvelope,
) -> anyhow::Result<()> {
    let payload = serde_json::to_string(envelope)?;
    let frame = Message::Text(payload.into());
    sender.send(frame).await.map_err(Into::into)
}