Skip to main content

codex_mobile_bridge/state/mod.rs

1mod events;
2mod helpers;
3mod runtime;
4mod threads;
5mod timeline;
6
7#[cfg(test)]
8mod tests;
9
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex};
12use std::time::{Duration as StdDuration, Instant};
13
14use anyhow::{Result, bail};
15use serde_json::{Value, json};
16use tokio::sync::{RwLock, broadcast, mpsc};
17use tokio::time::Duration;
18use tracing::warn;
19
20use self::events::run_app_server_event_loop;
21use self::runtime::ManagedRuntime;
22use self::threads::seed_workspaces;
23use crate::app_server::AppServerInbound;
24use crate::bridge_protocol::{
25    PendingServerRequestRecord, PersistedEvent, RuntimeStatusSnapshot, RuntimeSummary,
26    WorkspaceRecord, require_payload,
27};
28use crate::config::Config;
29use crate::storage::Storage;
30
31pub struct BridgeState {
32    token: String,
33    storage: Storage,
34    runtimes: RwLock<HashMap<String, Arc<ManagedRuntime>>>,
35    primary_runtime_id: String,
36    runtime_limit: usize,
37    inbound_tx: mpsc::UnboundedSender<AppServerInbound>,
38    events_tx: broadcast::Sender<PersistedEvent>,
39    timeout_warning_tracker: Mutex<HashMap<String, Instant>>,
40}
41
42impl BridgeState {
43    pub async fn bootstrap(config: Config) -> Result<Arc<Self>> {
44        let storage = Storage::open(config.db_path.clone())?;
45        seed_workspaces(&storage, &config.workspace_roots)?;
46
47        let primary_runtime = storage.ensure_primary_runtime(
48            config
49                .codex_home
50                .as_ref()
51                .map(|path| path.to_string_lossy().to_string()),
52            config.codex_binary.clone(),
53        )?;
54
55        let (events_tx, _) = broadcast::channel(512);
56        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
57
58        let mut runtime_map = HashMap::new();
59        for record in storage.list_runtimes()? {
60            let runtime = Arc::new(Self::build_runtime(record, inbound_tx.clone()));
61            runtime_map.insert(runtime.record.runtime_id.clone(), runtime);
62        }
63
64        let state = Arc::new(Self {
65            token: config.token,
66            storage,
67            runtimes: RwLock::new(runtime_map),
68            primary_runtime_id: primary_runtime.runtime_id,
69            runtime_limit: config.runtime_limit.max(1),
70            inbound_tx: inbound_tx.clone(),
71            events_tx,
72            timeout_warning_tracker: Mutex::new(HashMap::new()),
73        });
74
75        tokio::spawn(run_app_server_event_loop(Arc::clone(&state), inbound_rx));
76
77        for summary in state.runtime_summaries().await {
78            if summary.auto_start {
79                let runtime_id = summary.runtime_id.clone();
80                let state_ref = Arc::clone(&state);
81                tokio::spawn(async move {
82                    if let Err(error) = state_ref.start_existing_runtime(&runtime_id).await {
83                        let _ = state_ref
84                            .emit_runtime_degraded(
85                                &runtime_id,
86                                format!("自动启动 runtime 失败: {error}"),
87                            )
88                            .await;
89                    }
90                });
91            }
92        }
93
94        Ok(state)
95    }
96
97    pub fn subscribe_events(&self) -> broadcast::Receiver<PersistedEvent> {
98        self.events_tx.subscribe()
99    }
100
101    pub fn config_token(&self) -> &str {
102        &self.token
103    }
104
105    pub async fn hello_payload(
106        &self,
107        device_id: &str,
108        provided_ack_seq: Option<i64>,
109    ) -> Result<(
110        RuntimeStatusSnapshot,
111        Vec<RuntimeSummary>,
112        Vec<WorkspaceRecord>,
113        Vec<PendingServerRequestRecord>,
114        Vec<PersistedEvent>,
115    )> {
116        let fallback_ack = self.storage.get_mobile_session_ack(device_id)?.unwrap_or(0);
117        let last_ack_seq = provided_ack_seq.unwrap_or(fallback_ack);
118        self.storage
119            .save_mobile_session_ack(device_id, last_ack_seq)?;
120
121        let runtime = self.runtime_snapshot_for_client().await;
122        let runtimes = self.runtime_summaries_for_client().await;
123        let workspaces = self.storage.list_workspaces()?;
124        let pending_requests = self.storage.list_pending_requests()?;
125        let replay_events = self.storage.replay_events_after(last_ack_seq)?;
126
127        Ok((
128            runtime,
129            runtimes,
130            workspaces,
131            pending_requests,
132            replay_events,
133        ))
134    }
135
136    pub fn ack_events(&self, device_id: &str, last_seq: i64) -> Result<()> {
137        self.storage.save_mobile_session_ack(device_id, last_seq)
138    }
139
140    pub async fn handle_request(&self, action: &str, payload: Value) -> Result<Value> {
141        match action {
142            "get_runtime_status" => self.get_runtime_status(require_payload(payload)?).await,
143            "list_runtimes" => Ok(json!({ "runtimes": self.runtime_summaries_for_client().await })),
144            "start_runtime" => self.start_runtime(require_payload(payload)?).await,
145            "stop_runtime" => self.stop_runtime(require_payload(payload)?).await,
146            "restart_runtime" => self.restart_runtime(require_payload(payload)?).await,
147            "prune_runtime" => self.prune_runtime(require_payload(payload)?).await,
148            "list_workspaces" => Ok(json!({ "workspaces": self.storage.list_workspaces()? })),
149            "upsert_workspace" => self.upsert_workspace(require_payload(payload)?).await,
150            "list_threads" => self.list_threads(require_payload(payload)?).await,
151            "start_thread" => self.start_thread(require_payload(payload)?).await,
152            "read_thread" => self.read_thread(require_payload(payload)?).await,
153            "resume_thread" => self.resume_thread(require_payload(payload)?).await,
154            "update_thread" => self.update_thread(require_payload(payload)?).await,
155            "archive_thread" => self.archive_thread(require_payload(payload)?).await,
156            "unarchive_thread" => self.unarchive_thread(require_payload(payload)?).await,
157            "send_turn" => self.send_turn(require_payload(payload)?).await,
158            "interrupt_turn" => self.interrupt_turn(require_payload(payload)?).await,
159            "respond_pending_request" => {
160                self.respond_pending_request(require_payload(payload)?)
161                    .await
162            }
163            _ => bail!("未知 action: {action}"),
164        }
165    }
166
167    fn log_timeout_warning(&self, key: &str, message: &str) {
168        let now = Instant::now();
169        let mut tracker = self
170            .timeout_warning_tracker
171            .lock()
172            .expect("timeout warning tracker poisoned");
173        let should_log = tracker
174            .get(key)
175            .map(|last| now.duration_since(*last) >= CLIENT_TIMEOUT_WARN_COOLDOWN)
176            .unwrap_or(true);
177        if should_log {
178            tracker.insert(key.to_string(), now);
179            warn!("{message}");
180        }
181    }
182
183    fn emit_event(
184        &self,
185        event_type: &str,
186        runtime_id: Option<&str>,
187        thread_id: Option<&str>,
188        payload: Value,
189    ) -> Result<()> {
190        let event = self
191            .storage
192            .append_event(event_type, runtime_id, thread_id, &payload)?;
193        let _ = self.events_tx.send(event);
194        Ok(())
195    }
196}
197
198pub(super) const CLIENT_STATUS_TIMEOUT: Duration = Duration::from_millis(400);
199pub(super) const CLIENT_TIMEOUT_WARN_COOLDOWN: StdDuration = StdDuration::from_secs(30);