// codex_runtime/runtime/core/mod.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
3use std::sync::Arc;
4use std::sync::RwLock;
5
6use crate::plugin::{BlockReason, HookContext, HookReport};
7use arc_swap::ArcSwapOption;
8use serde_json::Value;
9use tokio::sync::{broadcast, mpsc, oneshot, Mutex, Notify};
10use tokio::task::JoinHandle;
11use tokio::time::Duration;
12
13#[cfg(test)]
14use crate::runtime::approvals::TimeoutAction;
15use crate::runtime::approvals::{ServerRequest, ServerRequestConfig};
16use crate::runtime::errors::{RpcError, RuntimeError};
17use crate::runtime::events::{Envelope, JsonRpcId};
18use crate::runtime::hooks::{HookKernel, PreHookDecision, RuntimeHookConfig};
19use crate::runtime::metrics::{RuntimeMetrics, RuntimeMetricsSnapshot};
20use crate::runtime::runtime_validation::validate_runtime_capacities;
21#[cfg(test)]
22use crate::runtime::state::ConnectionState;
23use crate::runtime::state::{RuntimeState, StateProjectionLimits};
24use crate::runtime::transport::{StdioProcessSpec, StdioTransport, StdioTransportConfig};
25
26type PendingResult = Result<Value, RpcError>;
27
28mod approval;
29mod config;
30mod dispatch;
31pub(crate) mod io_policy;
32mod lifecycle;
33mod rpc;
34mod rpc_io;
35mod state_projection;
36mod supervisor;
37
38pub use config::{InitializeCapabilities, RestartPolicy, RuntimeConfig, SupervisorConfig};
39use dispatch::event_sink_loop;
40use lifecycle::{shutdown_runtime, spawn_connection_generation};
41use state_projection::state_snapshot_arc;
42use supervisor::start_supervisor_task;
43
44#[derive(Clone, Debug, PartialEq, Eq)]
45struct PendingServerRequestEntry {
46    rpc_id: JsonRpcId,
47    rpc_key: String,
48    method: String,
49    created_at_millis: i64,
50    deadline_millis: i64,
51}
52
/// Lock-free flags and counters shared by every clone of the runtime handle.
struct RuntimeCounters {
    /// Reported by `Runtime::is_initialized`; starts out `false`.
    initialized: AtomicBool,
    /// Starts out `false`.
    /// NOTE(review): presumably set once shutdown begins — confirm in `lifecycle`.
    shutting_down: AtomicBool,
    /// Starts at 0, matching the generation number passed to the first
    /// `spawn_connection_generation` call.
    generation: AtomicU64,
    /// Starts at 1.
    /// NOTE(review): presumably the id source for outbound RPC calls — confirm in `rpc`.
    next_rpc_id: AtomicU64,
    /// Monotonic sequence counter; starts at 0 and feeds
    /// `Runtime::next_hook_correlation_id`.
    next_seq: AtomicU64,
}
60
61struct RuntimeSpec {
62    process: StdioProcessSpec,
63    transport_cfg: StdioTransportConfig,
64    initialize_params: Value,
65    supervisor_cfg: SupervisorConfig,
66    rpc_response_timeout: Duration,
67    server_request_cfg: ServerRequestConfig,
68    state_projection_limits: StateProjectionLimits,
69}
70
71struct RuntimeIo {
72    pending: Mutex<HashMap<u64, oneshot::Sender<PendingResult>>>,
73    outbound_tx: ArcSwapOption<mpsc::Sender<Value>>,
74    live_tx: broadcast::Sender<Envelope>,
75    pending_server_requests: Mutex<HashMap<String, PendingServerRequestEntry>>,
76    server_request_tx: mpsc::Sender<ServerRequest>,
77    server_request_rx: Mutex<Option<mpsc::Receiver<ServerRequest>>>,
78    event_sink_tx: Option<mpsc::Sender<Envelope>>,
79    transport_closed_signal: Notify,
80    shutdown_signal: Notify,
81}
82
83struct RuntimeTasks {
84    event_sink_task: Mutex<Option<JoinHandle<()>>>,
85    supervisor_task: Mutex<Option<JoinHandle<()>>>,
86    dispatcher_task: Mutex<Option<JoinHandle<()>>>,
87    transport: Mutex<Option<StdioTransport>>,
88}
89
90struct RuntimeSnapshots {
91    state: RwLock<Arc<RuntimeState>>,
92    initialize_result: RwLock<Option<Value>>,
93}
94
95#[derive(Clone)]
96pub struct Runtime {
97    inner: Arc<RuntimeInner>,
98}
99
100struct RuntimeInner {
101    counters: RuntimeCounters,
102    spec: RuntimeSpec,
103    io: RuntimeIo,
104    tasks: RuntimeTasks,
105    snapshots: RuntimeSnapshots,
106    metrics: Arc<RuntimeMetrics>,
107    hooks: HookKernel,
108}
109
110impl Runtime {
111    pub async fn spawn_local(cfg: RuntimeConfig) -> Result<Self, RuntimeError> {
112        let RuntimeConfig {
113            process,
114            hooks,
115            transport,
116            supervisor,
117            rpc_response_timeout,
118            server_requests,
119            initialize_params,
120            live_channel_capacity,
121            server_request_channel_capacity,
122            event_sink,
123            event_sink_channel_capacity,
124            state_projection_limits,
125        } = cfg;
126
127        validate_runtime_capacities(
128            live_channel_capacity,
129            server_request_channel_capacity,
130            event_sink.is_some(),
131            event_sink_channel_capacity,
132            rpc_response_timeout,
133        )?;
134        crate::runtime::runtime_validation::validate_state_projection_limits(
135            &state_projection_limits,
136        )?;
137
138        let (live_tx, _) = broadcast::channel(live_channel_capacity);
139        let (server_request_tx, server_request_rx) = mpsc::channel(server_request_channel_capacity);
140        let metrics = Arc::new(RuntimeMetrics::new(now_millis()));
141        let (event_sink_tx, event_sink_task) = match event_sink {
142            Some(sink) => {
143                let (tx, rx) = mpsc::channel(event_sink_channel_capacity);
144                let task = tokio::spawn(event_sink_loop(sink, Arc::clone(&metrics), rx));
145                (Some(tx), Some(task))
146            }
147            None => (None, None),
148        };
149
150        let runtime = Self {
151            inner: Arc::new(RuntimeInner {
152                counters: RuntimeCounters {
153                    initialized: AtomicBool::new(false),
154                    shutting_down: AtomicBool::new(false),
155                    generation: AtomicU64::new(0),
156                    next_rpc_id: AtomicU64::new(1),
157                    next_seq: AtomicU64::new(0),
158                },
159                spec: RuntimeSpec {
160                    process,
161                    transport_cfg: transport,
162                    initialize_params,
163                    supervisor_cfg: supervisor,
164                    rpc_response_timeout,
165                    server_request_cfg: server_requests,
166                    state_projection_limits,
167                },
168                io: RuntimeIo {
169                    pending: Mutex::new(HashMap::new()),
170                    outbound_tx: ArcSwapOption::new(None),
171                    live_tx,
172                    pending_server_requests: Mutex::new(HashMap::new()),
173                    server_request_tx,
174                    server_request_rx: Mutex::new(Some(server_request_rx)),
175                    event_sink_tx,
176                    transport_closed_signal: Notify::new(),
177                    shutdown_signal: Notify::new(),
178                },
179                tasks: RuntimeTasks {
180                    event_sink_task: Mutex::new(event_sink_task),
181                    supervisor_task: Mutex::new(None),
182                    dispatcher_task: Mutex::new(None),
183                    transport: Mutex::new(None),
184                },
185                snapshots: RuntimeSnapshots {
186                    state: RwLock::new(Arc::new(RuntimeState::default())),
187                    initialize_result: RwLock::new(None),
188                },
189                metrics,
190                hooks: HookKernel::new(hooks),
191            }),
192        };
193
194        spawn_connection_generation(&runtime.inner, 0).await?;
195        start_supervisor_task(&runtime.inner).await;
196
197        Ok(runtime)
198    }
199
200    pub fn subscribe_live(&self) -> broadcast::Receiver<Envelope> {
201        self.inner.io.live_tx.subscribe()
202    }
203
204    pub fn is_initialized(&self) -> bool {
205        self.inner.counters.initialized.load(Ordering::Acquire)
206    }
207
208    pub fn state_snapshot(&self) -> Arc<RuntimeState> {
209        state_snapshot_arc(&self.inner)
210    }
211
212    pub fn initialize_result_snapshot(&self) -> Option<Value> {
213        match self.inner.snapshots.initialize_result.read() {
214            Ok(guard) => guard.clone(),
215            Err(poisoned) => poisoned.into_inner().clone(),
216        }
217    }
218
219    pub fn server_user_agent(&self) -> Option<String> {
220        self.initialize_result_snapshot()
221            .and_then(|value| value.get("userAgent").cloned())
222            .and_then(|value| value.as_str().map(ToOwned::to_owned))
223    }
224
225    pub fn metrics_snapshot(&self) -> RuntimeMetricsSnapshot {
226        self.inner.metrics.snapshot(now_millis())
227    }
228
229    pub(crate) fn record_detached_task_init_failed(&self) {
230        self.inner.metrics.record_detached_task_init_failed();
231    }
232
233    /// Return latest hook report snapshot (last completed hook-enabled call wins).
234    /// Allocation: clones report payload. Complexity: O(i), i = issue count.
235    pub fn hook_report_snapshot(&self) -> HookReport {
236        self.inner.hooks.report_snapshot()
237    }
238
239    /// Register additional lifecycle hooks into running runtime.
240    /// Duplicate hook names are ignored.
241    /// Allocation: O(n) for dedup snapshot. Complexity: O(n + m), n=existing, m=incoming.
242    pub fn register_hooks(&self, hooks: RuntimeHookConfig) {
243        self.inner.hooks.register(hooks);
244    }
245
246    pub(crate) fn hooks_enabled(&self) -> bool {
247        self.inner.hooks.is_enabled()
248    }
249
250    /// True when at least one pre-tool-use hook is registered.
251    /// Allocation: one Vec clone. Complexity: O(n), n = hook count.
252    pub(crate) fn has_pre_tool_use_hooks(&self) -> bool {
253        self.inner.hooks.has_pre_tool_use_hooks()
254    }
255
256    pub(crate) fn has_pre_tool_use_hooks_with(
257        &self,
258        scoped_hooks: Option<&RuntimeHookConfig>,
259    ) -> bool {
260        self.has_pre_tool_use_hooks()
261            || scoped_hooks.is_some_and(|hooks| hooks.has_pre_tool_use_hooks())
262    }
263
264    pub(crate) fn register_thread_scoped_pre_tool_use_hooks(
265        &self,
266        thread_id: &str,
267        scoped_hooks: Option<&RuntimeHookConfig>,
268    ) {
269        let Some(scoped_hooks) = scoped_hooks else {
270            return;
271        };
272        self.inner
273            .hooks
274            .register_thread_scoped_pre_tool_use_hooks(thread_id, &scoped_hooks.pre_tool_use_hooks);
275    }
276
277    pub(crate) fn clear_thread_scoped_pre_tool_use_hooks(&self, thread_id: &str) {
278        self.inner
279            .hooks
280            .clear_thread_scoped_pre_tool_use_hooks(thread_id);
281    }
282
283    pub(crate) fn hooks_enabled_with(&self, scoped_hooks: Option<&RuntimeHookConfig>) -> bool {
284        self.hooks_enabled() || scoped_hooks.is_some_and(|hooks| !hooks.is_empty())
285    }
286
287    pub(crate) fn next_hook_correlation_id(&self) -> String {
288        let seq = self.inner.counters.next_seq.fetch_add(1, Ordering::AcqRel) + 1;
289        format!("hk-{seq}")
290    }
291
292    pub(crate) fn publish_hook_report(&self, report: HookReport) {
293        self.inner.hooks.set_latest_report(report);
294    }
295
296    /// Run pre-hooks. Returns `Err(BlockReason)` if any hook blocks.
297    /// Allocation: O(n) decisions vec, n = hook count.
298    pub(crate) async fn run_pre_hooks_with(
299        &self,
300        ctx: &HookContext,
301        report: &mut HookReport,
302        scoped_hooks: Option<&RuntimeHookConfig>,
303    ) -> Result<Vec<PreHookDecision>, BlockReason> {
304        self.inner
305            .hooks
306            .run_pre_with(ctx, report, scoped_hooks)
307            .await
308    }
309
310    pub(crate) async fn run_post_hooks_with(
311        &self,
312        ctx: &HookContext,
313        report: &mut HookReport,
314        scoped_hooks: Option<&RuntimeHookConfig>,
315    ) {
316        self.inner
317            .hooks
318            .run_post_with(ctx, report, scoped_hooks)
319            .await;
320    }
321
322    pub async fn shutdown(&self) -> Result<(), RuntimeError> {
323        shutdown_runtime(&self.inner).await
324    }
325}
326
327use super::now_millis;
328
#[cfg(test)]
mod tests;