//! Bridge state module (codex_mobile_bridge/state/mod.rs): owns runtimes,
//! storage, the client event stream, and staged turn-input files.
1mod events;
2mod helpers;
3mod runtime;
4mod threads;
5mod timeline;
6
7#[cfg(test)]
8mod tests;
9
10use std::collections::HashMap;
11use std::fs;
12use std::path::{Path, PathBuf};
13use std::sync::{Arc, Mutex};
14use std::time::{Duration as StdDuration, Instant};
15
16use anyhow::{Result, bail};
17use serde_json::{Value, json};
18use tokio::sync::{RwLock, broadcast, mpsc};
19use tokio::time::Duration;
20use tracing::warn;
21
22use self::events::run_app_server_event_loop;
23use self::runtime::ManagedRuntime;
24use self::threads::seed_workspaces;
25use crate::app_server::AppServerInbound;
26use crate::bridge_protocol::{
27    PendingServerRequestRecord, PersistedEvent, RuntimeStatusSnapshot, RuntimeSummary,
28    WorkspaceRecord, require_payload,
29};
30use crate::config::Config;
31use crate::storage::Storage;
32
33pub struct BridgeState {
34    token: String,
35    storage: Storage,
36    runtimes: RwLock<HashMap<String, Arc<ManagedRuntime>>>,
37    primary_runtime_id: String,
38    runtime_limit: usize,
39    staging_root: PathBuf,
40    inbound_tx: mpsc::UnboundedSender<AppServerInbound>,
41    events_tx: broadcast::Sender<PersistedEvent>,
42    staged_turn_inputs: Mutex<HashMap<String, Vec<PathBuf>>>,
43    timeout_warning_tracker: Mutex<HashMap<String, Instant>>,
44}
45
46impl BridgeState {
47    pub async fn bootstrap(config: Config) -> Result<Arc<Self>> {
48        let storage = Storage::open(config.db_path.clone())?;
49        seed_workspaces(&storage, &config.workspace_roots)?;
50        let staging_root = staging_root_from_db_path(&config.db_path);
51        prepare_staging_root(&staging_root)?;
52
53        let primary_runtime = storage.ensure_primary_runtime(
54            config
55                .codex_home
56                .as_ref()
57                .map(|path| path.to_string_lossy().to_string()),
58            config.codex_binary.clone(),
59        )?;
60
61        let (events_tx, _) = broadcast::channel(512);
62        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
63
64        let mut runtime_map = HashMap::new();
65        for record in storage.list_runtimes()? {
66            let runtime = Arc::new(Self::build_runtime(record, inbound_tx.clone()));
67            runtime_map.insert(runtime.record.runtime_id.clone(), runtime);
68        }
69
70        let state = Arc::new(Self {
71            token: config.token,
72            storage,
73            runtimes: RwLock::new(runtime_map),
74            primary_runtime_id: primary_runtime.runtime_id,
75            runtime_limit: config.runtime_limit.max(1),
76            staging_root,
77            inbound_tx: inbound_tx.clone(),
78            events_tx,
79            staged_turn_inputs: Mutex::new(HashMap::new()),
80            timeout_warning_tracker: Mutex::new(HashMap::new()),
81        });
82
83        tokio::spawn(run_app_server_event_loop(Arc::clone(&state), inbound_rx));
84
85        for summary in state.runtime_summaries().await {
86            if summary.auto_start {
87                let runtime_id = summary.runtime_id.clone();
88                let state_ref = Arc::clone(&state);
89                tokio::spawn(async move {
90                    if let Err(error) = state_ref.start_existing_runtime(&runtime_id).await {
91                        let _ = state_ref
92                            .emit_runtime_degraded(
93                                &runtime_id,
94                                format!("自动启动 runtime 失败: {error}"),
95                            )
96                            .await;
97                    }
98                });
99            }
100        }
101
102        Ok(state)
103    }
104
105    pub fn subscribe_events(&self) -> broadcast::Receiver<PersistedEvent> {
106        self.events_tx.subscribe()
107    }
108
109    pub fn config_token(&self) -> &str {
110        &self.token
111    }
112
113    pub async fn hello_payload(
114        &self,
115        device_id: &str,
116        provided_ack_seq: Option<i64>,
117    ) -> Result<(
118        RuntimeStatusSnapshot,
119        Vec<RuntimeSummary>,
120        Vec<WorkspaceRecord>,
121        Vec<PendingServerRequestRecord>,
122        Vec<PersistedEvent>,
123    )> {
124        let fallback_ack = self.storage.get_mobile_session_ack(device_id)?.unwrap_or(0);
125        let last_ack_seq = provided_ack_seq.unwrap_or(fallback_ack);
126        self.storage
127            .save_mobile_session_ack(device_id, last_ack_seq)?;
128
129        let runtime = self.runtime_snapshot_for_client().await;
130        let runtimes = self.runtime_summaries_for_client().await;
131        let workspaces = self.storage.list_workspaces()?;
132        let pending_requests = self.storage.list_pending_requests()?;
133        let replay_events = self.storage.replay_events_after(last_ack_seq)?;
134
135        Ok((
136            runtime,
137            runtimes,
138            workspaces,
139            pending_requests,
140            replay_events,
141        ))
142    }
143
144    pub fn ack_events(&self, device_id: &str, last_seq: i64) -> Result<()> {
145        self.storage.save_mobile_session_ack(device_id, last_seq)
146    }
147
148    pub async fn handle_request(&self, action: &str, payload: Value) -> Result<Value> {
149        match action {
150            "get_runtime_status" => self.get_runtime_status(require_payload(payload)?).await,
151            "list_runtimes" => Ok(json!({ "runtimes": self.runtime_summaries_for_client().await })),
152            "start_runtime" => self.start_runtime(require_payload(payload)?).await,
153            "stop_runtime" => self.stop_runtime(require_payload(payload)?).await,
154            "restart_runtime" => self.restart_runtime(require_payload(payload)?).await,
155            "prune_runtime" => self.prune_runtime(require_payload(payload)?).await,
156            "list_workspaces" => Ok(json!({ "workspaces": self.storage.list_workspaces()? })),
157            "upsert_workspace" => self.upsert_workspace(require_payload(payload)?).await,
158            "list_threads" => self.list_threads(require_payload(payload)?).await,
159            "start_thread" => self.start_thread(require_payload(payload)?).await,
160            "read_thread" => self.read_thread(require_payload(payload)?).await,
161            "resume_thread" => self.resume_thread(require_payload(payload)?).await,
162            "update_thread" => self.update_thread(require_payload(payload)?).await,
163            "archive_thread" => self.archive_thread(require_payload(payload)?).await,
164            "unarchive_thread" => self.unarchive_thread(require_payload(payload)?).await,
165            "stage_input_image" => self.stage_input_image(require_payload(payload)?).await,
166            "send_turn" => self.send_turn(require_payload(payload)?).await,
167            "interrupt_turn" => self.interrupt_turn(require_payload(payload)?).await,
168            "respond_pending_request" => {
169                self.respond_pending_request(require_payload(payload)?)
170                    .await
171            }
172            _ => bail!("未知 action: {action}"),
173        }
174    }
175
176    fn log_timeout_warning(&self, key: &str, message: &str) {
177        let now = Instant::now();
178        let mut tracker = self
179            .timeout_warning_tracker
180            .lock()
181            .expect("timeout warning tracker poisoned");
182        let should_log = tracker
183            .get(key)
184            .map(|last| now.duration_since(*last) >= CLIENT_TIMEOUT_WARN_COOLDOWN)
185            .unwrap_or(true);
186        if should_log {
187            tracker.insert(key.to_string(), now);
188            warn!("{message}");
189        }
190    }
191
192    fn emit_event(
193        &self,
194        event_type: &str,
195        runtime_id: Option<&str>,
196        thread_id: Option<&str>,
197        payload: Value,
198    ) -> Result<()> {
199        let event = self
200            .storage
201            .append_event(event_type, runtime_id, thread_id, &payload)?;
202        let _ = self.events_tx.send(event);
203        Ok(())
204    }
205
206    pub(super) fn staging_root(&self) -> &Path {
207        &self.staging_root
208    }
209
210    pub(super) fn register_staged_turn_inputs(&self, turn_id: &str, paths: Vec<PathBuf>) {
211        if paths.is_empty() {
212            return;
213        }
214        let mut staged_turn_inputs = self
215            .staged_turn_inputs
216            .lock()
217            .expect("staged turn inputs poisoned");
218        staged_turn_inputs.insert(turn_id.to_string(), paths);
219    }
220
221    pub(super) fn cleanup_staged_turn_inputs(&self, turn_id: &str) -> Result<()> {
222        let paths = self
223            .staged_turn_inputs
224            .lock()
225            .expect("staged turn inputs poisoned")
226            .remove(turn_id)
227            .unwrap_or_default();
228        self.cleanup_staged_paths(paths)
229    }
230
231    pub(super) fn cleanup_staged_paths<I>(&self, paths: I) -> Result<()>
232    where
233        I: IntoIterator<Item = PathBuf>,
234    {
235        for path in paths {
236            self.remove_staged_path(&path)?;
237        }
238        Ok(())
239    }
240
241    fn remove_staged_path(&self, path: &Path) -> Result<()> {
242        if !path.starts_with(&self.staging_root) {
243            bail!("拒绝清理 staging 目录之外的文件: {}", path.display());
244        }
245        if path.exists() {
246            fs::remove_file(path)?;
247        }
248        Ok(())
249    }
250}
251
/// Derives the staging directory as a `staged-inputs` sibling of the
/// database file; falls back to the current directory when the database
/// path has no parent.
fn staging_root_from_db_path(db_path: &Path) -> PathBuf {
    db_path
        .parent()
        .unwrap_or_else(|| Path::new("."))
        .join("staged-inputs")
}
258
259fn prepare_staging_root(staging_root: &Path) -> Result<()> {
260    fs::create_dir_all(staging_root)?;
261    for entry in fs::read_dir(staging_root)? {
262        let path = entry?.path();
263        if path.is_dir() {
264            fs::remove_dir_all(&path)?;
265        } else {
266            fs::remove_file(&path)?;
267        }
268    }
269    Ok(())
270}
271
272pub(super) const CLIENT_STATUS_TIMEOUT: Duration = Duration::from_millis(400);
273pub(super) const CLIENT_TIMEOUT_WARN_COOLDOWN: StdDuration = StdDuration::from_secs(30);