Skip to main content

codex_mobile_bridge/state/
mod.rs

1mod directories;
2mod events;
3mod helpers;
4mod management;
5mod runtime;
6mod threads;
7mod timeline;
8
9#[cfg(test)]
10mod tests;
11
12use std::collections::HashMap;
13use std::fs;
14use std::path::{Path, PathBuf};
15use std::sync::{Arc, Mutex};
16use std::time::{Duration as StdDuration, Instant};
17
18use anyhow::{Result, bail};
19use serde_json::{Value, json};
20use tokio::sync::{RwLock, broadcast, mpsc};
21use tokio::time::Duration;
22use tracing::warn;
23
24use self::directories::seed_directory_bookmarks;
25use self::events::run_app_server_event_loop;
26use self::runtime::ManagedRuntime;
27use crate::app_server::AppServerInbound;
28use crate::bridge_protocol::{
29    DirectoryBookmarkRecord, DirectoryHistoryRecord, PendingServerRequestRecord, PersistedEvent,
30    RuntimeStatusSnapshot, RuntimeSummary, require_payload,
31};
32use crate::config::Config;
33use crate::storage::Storage;
34
/// Shared state of the mobile bridge, created once by [`BridgeState::bootstrap`]
/// and handed out behind an `Arc`.
pub struct BridgeState {
    /// Token taken from the configuration; exposed through `config_token`.
    token: String,
    /// Storage handle opened from the configured database path.
    storage: Storage,
    /// Managed runtimes keyed by runtime id, loaded from storage at bootstrap.
    runtimes: RwLock<HashMap<String, Arc<ManagedRuntime>>>,
    /// Id of the runtime record returned by `ensure_primary_runtime` at bootstrap.
    primary_runtime_id: String,
    /// Configured runtime limit, clamped to at least 1 at bootstrap.
    runtime_limit: usize,
    /// Directory holding staged input files; emptied at bootstrap, and the only
    /// location `remove_staged_path` agrees to delete from.
    staging_root: PathBuf,
    /// Sender half of the inbound app-server channel; a clone is handed to every
    /// runtime built at bootstrap, and the receiver feeds the event loop task.
    inbound_tx: mpsc::UnboundedSender<AppServerInbound>,
    /// Broadcast channel on which persisted events are published to subscribers.
    events_tx: broadcast::Sender<PersistedEvent>,
    /// Staged file paths registered per turn id, awaiting cleanup.
    staged_turn_inputs: Mutex<HashMap<String, Vec<PathBuf>>>,
    /// Last time a rate-limited notice was emitted, per notice key.
    timeout_warning_tracker: Mutex<HashMap<String, Instant>>,
    /// Event sequence number initialised from storage at bootstrap and advanced
    /// by `emit_event`; presumably read by the external event relay — confirm.
    external_event_cursor: Mutex<i64>,
}
48
49impl BridgeState {
50    pub async fn bootstrap(config: Config) -> Result<Arc<Self>> {
51        let storage = Storage::open(config.db_path.clone())?;
52        seed_directory_bookmarks(&storage, &config.directory_bookmarks)?;
53        let staging_root = staging_root_from_db_path(&config.db_path);
54        prepare_staging_root(&staging_root)?;
55
56        let primary_runtime = storage.ensure_primary_runtime(
57            config
58                .codex_home
59                .as_ref()
60                .map(|path| path.to_string_lossy().to_string()),
61            config.codex_binary.clone(),
62        )?;
63
64        let (events_tx, _) = broadcast::channel(512);
65        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
66
67        let mut runtime_map = HashMap::new();
68        for record in storage.list_runtimes()? {
69            let runtime = Arc::new(Self::build_runtime(record, inbound_tx.clone()));
70            runtime_map.insert(runtime.record.runtime_id.clone(), runtime);
71        }
72
73        let state = Arc::new(Self {
74            token: config.token,
75            storage,
76            runtimes: RwLock::new(runtime_map),
77            primary_runtime_id: primary_runtime.runtime_id,
78            runtime_limit: config.runtime_limit.max(1),
79            staging_root,
80            inbound_tx: inbound_tx.clone(),
81            events_tx,
82            staged_turn_inputs: Mutex::new(HashMap::new()),
83            timeout_warning_tracker: Mutex::new(HashMap::new()),
84            external_event_cursor: Mutex::new(0),
85        });
86
87        *state
88            .external_event_cursor
89            .lock()
90            .expect("external event cursor poisoned") = state.storage.latest_event_seq()?;
91
92        tokio::spawn(run_app_server_event_loop(Arc::clone(&state), inbound_rx));
93        tokio::spawn(management::run_external_event_relay(Arc::clone(&state)));
94
95        for summary in state.runtime_summaries().await {
96            if summary.auto_start {
97                let runtime_id = summary.runtime_id.clone();
98                let state_ref = Arc::clone(&state);
99                tokio::spawn(async move {
100                    if let Err(error) = state_ref.start_existing_runtime(&runtime_id).await {
101                        let _ = state_ref
102                            .emit_runtime_degraded(
103                                &runtime_id,
104                                format!("自动启动 runtime 失败: {error}"),
105                            )
106                            .await;
107                    }
108                });
109            }
110        }
111
112        Ok(state)
113    }
114
115    pub fn subscribe_events(&self) -> broadcast::Receiver<PersistedEvent> {
116        self.events_tx.subscribe()
117    }
118
119    pub fn config_token(&self) -> &str {
120        &self.token
121    }
122
123    pub async fn hello_payload(
124        &self,
125        device_id: &str,
126        provided_ack_seq: Option<i64>,
127    ) -> Result<(
128        RuntimeStatusSnapshot,
129        Vec<RuntimeSummary>,
130        Vec<DirectoryBookmarkRecord>,
131        Vec<DirectoryHistoryRecord>,
132        Vec<PendingServerRequestRecord>,
133        Vec<PersistedEvent>,
134    )> {
135        let fallback_ack = self.storage.get_mobile_session_ack(device_id)?.unwrap_or(0);
136        let last_ack_seq = provided_ack_seq.unwrap_or(fallback_ack);
137        self.storage
138            .save_mobile_session_ack(device_id, last_ack_seq)?;
139
140        let runtime = self.runtime_snapshot_for_client().await;
141        let runtimes = self.runtime_summaries_for_client().await;
142        let directory_bookmarks = self.storage.list_directory_bookmarks()?;
143        let directory_history = self.storage.list_directory_history(20)?;
144        let pending_requests = self.storage.list_pending_requests()?;
145        let replay_events = self.storage.replay_events_after(last_ack_seq)?;
146
147        Ok((
148            runtime,
149            runtimes,
150            directory_bookmarks,
151            directory_history,
152            pending_requests,
153            replay_events,
154        ))
155    }
156
157    pub fn ack_events(&self, device_id: &str, last_seq: i64) -> Result<()> {
158        self.storage.save_mobile_session_ack(device_id, last_seq)
159    }
160
    /// Dispatches a client request to the handler named by `action`.
    ///
    /// Most handlers receive the payload decoded by `require_payload`; a payload
    /// it rejects fails the request before the handler runs. `list_runtimes`
    /// and `inspect_remote_state` ignore the payload.
    ///
    /// # Errors
    /// Returns the handler's error, the `require_payload` error, or an
    /// "unknown action" error when `action` matches no arm.
    pub async fn handle_request(&self, action: &str, payload: Value) -> Result<Value> {
        match action {
            // Runtime lifecycle.
            "get_runtime_status" => self.get_runtime_status(require_payload(payload)?).await,
            "list_runtimes" => Ok(json!({ "runtimes": self.runtime_summaries_for_client().await })),
            "start_runtime" => self.start_runtime(require_payload(payload)?).await,
            "stop_runtime" => self.stop_runtime(require_payload(payload)?).await,
            "restart_runtime" => self.restart_runtime(require_payload(payload)?).await,
            "prune_runtime" => self.prune_runtime(require_payload(payload)?).await,
            // Directory browsing and bookmarks.
            "read_directory" => self.read_directory(require_payload(payload)?).await,
            "create_directory_bookmark" => {
                self.create_directory_bookmark(require_payload(payload)?)
                    .await
            }
            "remove_directory_bookmark" => {
                self.remove_directory_bookmark(require_payload(payload)?)
                    .await
            }
            // Threads.
            "list_threads" => self.list_threads(require_payload(payload)?).await,
            "start_thread" => self.start_thread(require_payload(payload)?).await,
            "read_thread" => self.read_thread(require_payload(payload)?).await,
            "resume_thread" => self.resume_thread(require_payload(payload)?).await,
            "update_thread" => self.update_thread(require_payload(payload)?).await,
            "archive_thread" => self.archive_thread(require_payload(payload)?).await,
            "unarchive_thread" => self.unarchive_thread(require_payload(payload)?).await,
            // Turns and pending server requests.
            "stage_input_image" => self.stage_input_image(require_payload(payload)?).await,
            "send_turn" => self.send_turn(require_payload(payload)?).await,
            "interrupt_turn" => self.interrupt_turn(require_payload(payload)?).await,
            "respond_pending_request" => {
                self.respond_pending_request(require_payload(payload)?)
                    .await
            }
            // Bridge management and inspection.
            "start_bridge_management" => {
                self.start_bridge_management(require_payload(payload)?)
                    .await
            }
            "read_bridge_management" => {
                self.read_bridge_management(require_payload(payload)?).await
            }
            "inspect_remote_state" => self.inspect_remote_state().await,
            _ => bail!("未知 action: {action}"),
        }
    }
203
204    fn log_timeout_warning(&self, key: &str, message: &str) {
205        if self.should_emit_rate_limited_notice(key) {
206            warn!("{message}");
207        }
208    }
209
210    pub(super) fn should_emit_rate_limited_notice(&self, key: &str) -> bool {
211        let now = Instant::now();
212        let mut tracker = self
213            .timeout_warning_tracker
214            .lock()
215            .expect("timeout warning tracker poisoned");
216        let should_emit = tracker
217            .get(key)
218            .map(|last| now.duration_since(*last) >= CLIENT_TIMEOUT_WARN_COOLDOWN)
219            .unwrap_or(true);
220        if should_emit {
221            tracker.insert(key.to_string(), now);
222        }
223        should_emit
224    }
225
226    fn emit_event(
227        &self,
228        event_type: &str,
229        runtime_id: Option<&str>,
230        thread_id: Option<&str>,
231        payload: Value,
232    ) -> Result<()> {
233        let event = self
234            .storage
235            .append_event(event_type, runtime_id, thread_id, &payload)?;
236        *self
237            .external_event_cursor
238            .lock()
239            .expect("external event cursor poisoned") = event.seq;
240        let _ = self.events_tx.send(event);
241        Ok(())
242    }
243
244    pub(super) fn staging_root(&self) -> &Path {
245        &self.staging_root
246    }
247
248    pub(super) fn register_staged_turn_inputs(&self, turn_id: &str, paths: Vec<PathBuf>) {
249        if paths.is_empty() {
250            return;
251        }
252        let mut staged_turn_inputs = self
253            .staged_turn_inputs
254            .lock()
255            .expect("staged turn inputs poisoned");
256        staged_turn_inputs.insert(turn_id.to_string(), paths);
257    }
258
259    pub(super) fn cleanup_staged_turn_inputs(&self, turn_id: &str) -> Result<()> {
260        let paths = self
261            .staged_turn_inputs
262            .lock()
263            .expect("staged turn inputs poisoned")
264            .remove(turn_id)
265            .unwrap_or_default();
266        self.cleanup_staged_paths(paths)
267    }
268
269    pub(super) fn cleanup_staged_paths<I>(&self, paths: I) -> Result<()>
270    where
271        I: IntoIterator<Item = PathBuf>,
272    {
273        for path in paths {
274            self.remove_staged_path(&path)?;
275        }
276        Ok(())
277    }
278
279    fn remove_staged_path(&self, path: &Path) -> Result<()> {
280        if !path.starts_with(&self.staging_root) {
281            bail!("拒绝清理 staging 目录之外的文件: {}", path.display());
282        }
283        if path.exists() {
284            fs::remove_file(path)?;
285        }
286        Ok(())
287    }
288}
289
/// Derives the staging directory for a given database file path.
///
/// The result is `staged-inputs` inside the database file's parent directory.
/// When the path has no parent (empty path or a filesystem root), the current
/// directory `.` is used as the base.
fn staging_root_from_db_path(db_path: &Path) -> PathBuf {
    let base_dir = match db_path.parent() {
        Some(dir) => dir,
        None => Path::new("."),
    };
    base_dir.join("staged-inputs")
}
296
/// Makes sure the staging directory exists and is empty.
///
/// The directory (and any missing parents) is created, then every entry found
/// in it is deleted: directories recursively, everything else as a file.
///
/// # Errors
/// Returns the first I/O error hit while creating, listing, or deleting;
/// entries after the failing one are left in place.
fn prepare_staging_root(staging_root: &Path) -> Result<()> {
    fs::create_dir_all(staging_root)?;
    fs::read_dir(staging_root)?.try_for_each(|entry| -> std::io::Result<()> {
        let leftover = entry?.path();
        if leftover.is_dir() {
            fs::remove_dir_all(&leftover)
        } else {
            fs::remove_file(&leftover)
        }
    })?;
    Ok(())
}
309
/// 400 ms timeout shared with sibling modules; presumably bounds how long a
/// client-facing status query may wait — confirm against its users.
pub(super) const CLIENT_STATUS_TIMEOUT: Duration = Duration::from_millis(400);
/// Minimum interval (30 s) between two rate-limited notices for the same key,
/// as enforced by `should_emit_rate_limited_notice`.
pub(super) const CLIENT_TIMEOUT_WARN_COOLDOWN: StdDuration = StdDuration::from_secs(30);