// codex_mobile_bridge/state/mod.rs

mod directories;
mod events;
mod helpers;
mod runtime;
mod threads;
mod timeline;

#[cfg(test)]
mod tests;

use std::collections::HashMap;
use std::fs;
use std::io::ErrorKind;
use std::path::{Component, Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration as StdDuration, Instant};

use anyhow::{Result, bail};
use serde_json::{Value, json};
use tokio::sync::{RwLock, broadcast, mpsc};
use tokio::time::Duration;
use tracing::warn;

use self::directories::seed_directory_bookmarks;
use self::events::run_app_server_event_loop;
use self::runtime::ManagedRuntime;
use crate::app_server::AppServerInbound;
use crate::bridge_protocol::{
    DirectoryBookmarkRecord, DirectoryHistoryRecord, PendingServerRequestRecord, PersistedEvent,
    RuntimeStatusSnapshot, RuntimeSummary, require_payload,
};
use crate::config::Config;
use crate::storage::Storage;

34pub struct BridgeState {
35    token: String,
36    storage: Storage,
37    runtimes: RwLock<HashMap<String, Arc<ManagedRuntime>>>,
38    primary_runtime_id: String,
39    runtime_limit: usize,
40    staging_root: PathBuf,
41    inbound_tx: mpsc::UnboundedSender<AppServerInbound>,
42    events_tx: broadcast::Sender<PersistedEvent>,
43    staged_turn_inputs: Mutex<HashMap<String, Vec<PathBuf>>>,
44    timeout_warning_tracker: Mutex<HashMap<String, Instant>>,
45}
46
47impl BridgeState {
48    pub async fn bootstrap(config: Config) -> Result<Arc<Self>> {
49        let storage = Storage::open(config.db_path.clone())?;
50        seed_directory_bookmarks(&storage, &config.directory_bookmarks)?;
51        let staging_root = staging_root_from_db_path(&config.db_path);
52        prepare_staging_root(&staging_root)?;
53
54        let primary_runtime = storage.ensure_primary_runtime(
55            config
56                .codex_home
57                .as_ref()
58                .map(|path| path.to_string_lossy().to_string()),
59            config.codex_binary.clone(),
60        )?;
61
62        let (events_tx, _) = broadcast::channel(512);
63        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
64
65        let mut runtime_map = HashMap::new();
66        for record in storage.list_runtimes()? {
67            let runtime = Arc::new(Self::build_runtime(record, inbound_tx.clone()));
68            runtime_map.insert(runtime.record.runtime_id.clone(), runtime);
69        }
70
71        let state = Arc::new(Self {
72            token: config.token,
73            storage,
74            runtimes: RwLock::new(runtime_map),
75            primary_runtime_id: primary_runtime.runtime_id,
76            runtime_limit: config.runtime_limit.max(1),
77            staging_root,
78            inbound_tx: inbound_tx.clone(),
79            events_tx,
80            staged_turn_inputs: Mutex::new(HashMap::new()),
81            timeout_warning_tracker: Mutex::new(HashMap::new()),
82        });
83
84        tokio::spawn(run_app_server_event_loop(Arc::clone(&state), inbound_rx));
85
86        for summary in state.runtime_summaries().await {
87            if summary.auto_start {
88                let runtime_id = summary.runtime_id.clone();
89                let state_ref = Arc::clone(&state);
90                tokio::spawn(async move {
91                    if let Err(error) = state_ref.start_existing_runtime(&runtime_id).await {
92                        let _ = state_ref
93                            .emit_runtime_degraded(
94                                &runtime_id,
95                                format!("自动启动 runtime 失败: {error}"),
96                            )
97                            .await;
98                    }
99                });
100            }
101        }
102
103        Ok(state)
104    }
105
106    pub fn subscribe_events(&self) -> broadcast::Receiver<PersistedEvent> {
107        self.events_tx.subscribe()
108    }
109
110    pub fn config_token(&self) -> &str {
111        &self.token
112    }
113
114    pub async fn hello_payload(
115        &self,
116        device_id: &str,
117        provided_ack_seq: Option<i64>,
118    ) -> Result<(
119        RuntimeStatusSnapshot,
120        Vec<RuntimeSummary>,
121        Vec<DirectoryBookmarkRecord>,
122        Vec<DirectoryHistoryRecord>,
123        Vec<PendingServerRequestRecord>,
124        Vec<PersistedEvent>,
125    )> {
126        let fallback_ack = self.storage.get_mobile_session_ack(device_id)?.unwrap_or(0);
127        let last_ack_seq = provided_ack_seq.unwrap_or(fallback_ack);
128        self.storage
129            .save_mobile_session_ack(device_id, last_ack_seq)?;
130
131        let runtime = self.runtime_snapshot_for_client().await;
132        let runtimes = self.runtime_summaries_for_client().await;
133        let directory_bookmarks = self.storage.list_directory_bookmarks()?;
134        let directory_history = self.storage.list_directory_history(20)?;
135        let pending_requests = self.storage.list_pending_requests()?;
136        let replay_events = self.storage.replay_events_after(last_ack_seq)?;
137
138        Ok((
139            runtime,
140            runtimes,
141            directory_bookmarks,
142            directory_history,
143            pending_requests,
144            replay_events,
145        ))
146    }
147
148    pub fn ack_events(&self, device_id: &str, last_seq: i64) -> Result<()> {
149        self.storage.save_mobile_session_ack(device_id, last_seq)
150    }
151
152    pub async fn handle_request(&self, action: &str, payload: Value) -> Result<Value> {
153        match action {
154            "get_runtime_status" => self.get_runtime_status(require_payload(payload)?).await,
155            "list_runtimes" => Ok(json!({ "runtimes": self.runtime_summaries_for_client().await })),
156            "start_runtime" => self.start_runtime(require_payload(payload)?).await,
157            "stop_runtime" => self.stop_runtime(require_payload(payload)?).await,
158            "restart_runtime" => self.restart_runtime(require_payload(payload)?).await,
159            "prune_runtime" => self.prune_runtime(require_payload(payload)?).await,
160            "read_directory" => self.read_directory(require_payload(payload)?).await,
161            "create_directory_bookmark" => {
162                self.create_directory_bookmark(require_payload(payload)?)
163                    .await
164            }
165            "remove_directory_bookmark" => {
166                self.remove_directory_bookmark(require_payload(payload)?)
167                    .await
168            }
169            "list_threads" => self.list_threads(require_payload(payload)?).await,
170            "start_thread" => self.start_thread(require_payload(payload)?).await,
171            "read_thread" => self.read_thread(require_payload(payload)?).await,
172            "resume_thread" => self.resume_thread(require_payload(payload)?).await,
173            "update_thread" => self.update_thread(require_payload(payload)?).await,
174            "archive_thread" => self.archive_thread(require_payload(payload)?).await,
175            "unarchive_thread" => self.unarchive_thread(require_payload(payload)?).await,
176            "stage_input_image" => self.stage_input_image(require_payload(payload)?).await,
177            "send_turn" => self.send_turn(require_payload(payload)?).await,
178            "interrupt_turn" => self.interrupt_turn(require_payload(payload)?).await,
179            "respond_pending_request" => {
180                self.respond_pending_request(require_payload(payload)?)
181                    .await
182            }
183            _ => bail!("未知 action: {action}"),
184        }
185    }
186
187    fn log_timeout_warning(&self, key: &str, message: &str) {
188        if self.should_emit_rate_limited_notice(key) {
189            warn!("{message}");
190        }
191    }
192
193    pub(super) fn should_emit_rate_limited_notice(&self, key: &str) -> bool {
194        let now = Instant::now();
195        let mut tracker = self
196            .timeout_warning_tracker
197            .lock()
198            .expect("timeout warning tracker poisoned");
199        let should_emit = tracker
200            .get(key)
201            .map(|last| now.duration_since(*last) >= CLIENT_TIMEOUT_WARN_COOLDOWN)
202            .unwrap_or(true);
203        if should_emit {
204            tracker.insert(key.to_string(), now);
205        }
206        should_emit
207    }
208
209    fn emit_event(
210        &self,
211        event_type: &str,
212        runtime_id: Option<&str>,
213        thread_id: Option<&str>,
214        payload: Value,
215    ) -> Result<()> {
216        let event = self
217            .storage
218            .append_event(event_type, runtime_id, thread_id, &payload)?;
219        let _ = self.events_tx.send(event);
220        Ok(())
221    }
222
223    pub(super) fn staging_root(&self) -> &Path {
224        &self.staging_root
225    }
226
227    pub(super) fn register_staged_turn_inputs(&self, turn_id: &str, paths: Vec<PathBuf>) {
228        if paths.is_empty() {
229            return;
230        }
231        let mut staged_turn_inputs = self
232            .staged_turn_inputs
233            .lock()
234            .expect("staged turn inputs poisoned");
235        staged_turn_inputs.insert(turn_id.to_string(), paths);
236    }
237
238    pub(super) fn cleanup_staged_turn_inputs(&self, turn_id: &str) -> Result<()> {
239        let paths = self
240            .staged_turn_inputs
241            .lock()
242            .expect("staged turn inputs poisoned")
243            .remove(turn_id)
244            .unwrap_or_default();
245        self.cleanup_staged_paths(paths)
246    }
247
248    pub(super) fn cleanup_staged_paths<I>(&self, paths: I) -> Result<()>
249    where
250        I: IntoIterator<Item = PathBuf>,
251    {
252        for path in paths {
253            self.remove_staged_path(&path)?;
254        }
255        Ok(())
256    }
257
258    fn remove_staged_path(&self, path: &Path) -> Result<()> {
259        if !path.starts_with(&self.staging_root) {
260            bail!("拒绝清理 staging 目录之外的文件: {}", path.display());
261        }
262        if path.exists() {
263            fs::remove_file(path)?;
264        }
265        Ok(())
266    }
267}
268
/// Returns the `staged-inputs` directory that sits next to the database file.
///
/// When `db_path` has no parent (the filesystem root or an empty path), the directory
/// is resolved against `.` instead.
fn staging_root_from_db_path(db_path: &Path) -> PathBuf {
    let db_dir = match db_path.parent() {
        Some(dir) => dir,
        None => Path::new("."),
    };
    db_dir.join("staged-inputs")
}

/// Ensures `staging_root` exists and is empty.
///
/// `bootstrap` calls this, so files staged by an earlier run are discarded.
/// Directory entries (as reported by `Path::is_dir`, which follows symlinks) are removed
/// recursively; all other entries are removed as files. The first I/O error aborts.
fn prepare_staging_root(staging_root: &Path) -> Result<()> {
    fs::create_dir_all(staging_root)?;
    fs::read_dir(staging_root)?.try_for_each(|entry| {
        let stale = entry?.path();
        if stale.is_dir() {
            fs::remove_dir_all(&stale)
        } else {
            fs::remove_file(&stale)
        }
    })?;
    Ok(())
}

289pub(super) const CLIENT_STATUS_TIMEOUT: Duration = Duration::from_millis(400);
290pub(super) const CLIENT_TIMEOUT_WARN_COOLDOWN: StdDuration = StdDuration::from_secs(30);