// astrid_kernel/kernel_router.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::Instant;
4
5use astrid_events::ipc::{IpcMessage, IpcPayload};
6use astrid_events::kernel_api::{KernelRequest, KernelResponse};
7use tracing::{debug, info, warn};
8
9/// Spawns background tasks for the kernel management API and connection tracking.
10///
11/// Two listeners:
12/// 1. `astrid.v1.request.*` - handles management commands (list capsules, reload, etc.)
13/// 2. `client.v1.disconnect` - decrements the active connection counter on graceful disconnect.
14///
15/// Connection *increment* happens when the WASM proxy capsule accepts a socket
16/// connection (it publishes a `client.v1.connected` event). For ungraceful disconnects,
17/// the idle monitor uses `EventBus::subscriber_count()` as a secondary signal.
18#[must_use]
19pub(crate) fn spawn_kernel_router(kernel: Arc<crate::Kernel>) -> tokio::task::JoinHandle<()> {
20    // Spawn the connection tracker as a sibling task.
21    drop(spawn_connection_tracker(Arc::clone(&kernel)));
22
23    let mut receiver = kernel.event_bus.subscribe_topic("astrid.v1.request.*");
24
25    tokio::spawn(async move {
26        let mut rate_limiter = ManagementRateLimiter::new();
27
28        while let Some(event) = receiver.recv().await {
29            let astrid_events::AstridEvent::Ipc { message, .. } = &*event else {
30                continue;
31            };
32
33            // Only process standard IPC messages that contain JSON payloads.
34            let IpcPayload::RawJson(val) = &message.payload else {
35                continue;
36            };
37
38            match serde_json::from_value::<KernelRequest>(val.clone()) {
39                Ok(req) => {
40                    let (method, limit) = rate_limit_for_request(&req);
41                    if let Some(max) = limit
42                        && !rate_limiter.check(method, max)
43                    {
44                        warn!(
45                            security_event = true,
46                            method = method,
47                            "Rate limited kernel management request"
48                        );
49                        let response_topic =
50                            message.topic.replace("kernel.request.", "kernel.response.");
51                        publish_response(
52                            &kernel,
53                            response_topic,
54                            KernelResponse::Error(format!(
55                                "Rate limited: max {max} {method} requests per minute"
56                            )),
57                        );
58                        continue;
59                    }
60                    handle_request(&kernel, message.topic.clone(), req).await;
61                },
62                Err(e) => {
63                    warn!(error = %e, topic = %message.topic, "Failed to parse KernelRequest from IPC");
64                },
65            }
66        }
67    })
68}
69
70/// Tracks client connection lifecycle events.
71///
72/// Listens on `client.v1.*` topics:
73/// - `client.v1.connected` - a new socket connection was accepted.
74/// - `client.v1.disconnect` - a client sent a graceful disconnect.
75fn spawn_connection_tracker(kernel: Arc<crate::Kernel>) -> tokio::task::JoinHandle<()> {
76    let mut receiver = kernel.event_bus.subscribe_topic("client.v1.*");
77
78    tokio::spawn(async move {
79        while let Some(event) = receiver.recv().await {
80            let astrid_events::AstridEvent::Ipc { message, .. } = &*event else {
81                continue;
82            };
83            match &message.payload {
84                IpcPayload::Disconnect { reason } => {
85                    kernel.connection_closed();
86                    debug!(reason = ?reason, "Client disconnected");
87                },
88                IpcPayload::Connect => {
89                    kernel.connection_opened();
90                    debug!("New client connection accepted");
91                },
92                _ => {},
93            }
94        }
95    })
96}
97
/// Dispatches a single parsed [`KernelRequest`] and publishes the resulting
/// [`KernelResponse`] on the matching response topic.
///
/// The response topic mirrors the request topic: `astrid.v1.request.<suffix>`
/// becomes `astrid.v1.response.<suffix>`; any other topic is echoed unchanged.
/// All arms produce a response except `Shutdown`, which publishes its own
/// confirmation and returns early.
#[expect(clippy::too_many_lines)]
async fn handle_request(kernel: &Arc<crate::Kernel>, topic: String, req: KernelRequest) {
    // Map the request topic onto its response topic (see doc comment above).
    let response_topic = if let Some(suffix) = topic.strip_prefix("astrid.v1.request.") {
        format!("astrid.v1.response.{suffix}")
    } else {
        topic.clone()
    };

    let res = match req {
        // Not implemented yet: acknowledged with an explicit error response.
        KernelRequest::InstallCapsule { source, workspace } => {
            info!(source = %source, workspace, "Kernel received install request");
            // Here the kernel would verify identity, parse the capsule, and potentially
            // return ApprovalRequired if it needs dangerous capabilities!
            KernelResponse::Error(
                "Installation logic not yet implemented in kernel router".to_string(),
            )
        },
        // Not implemented yet: acknowledged with an explicit error response.
        KernelRequest::ApproveCapability {
            request_id,
            signature: _,
        } => {
            info!(request_id = %request_id, "Kernel received capability approval");
            KernelResponse::Error("Approval logic not yet implemented in kernel router".to_string())
        },
        // Read-only: snapshot of registered capsule ids under a read lock.
        KernelRequest::ListCapsules => {
            let reg = kernel.capsules.read().await;
            let mut list = Vec::new();
            for c in reg.list() {
                list.push(c.to_string());
            }
            KernelResponse::Success(serde_json::json!(list))
        },
        // Read-only: flatten every capsule manifest's command list into one vec.
        KernelRequest::GetCommands => {
            let reg = kernel.capsules.read().await;
            let mut commands = Vec::new();
            for c in reg.values() {
                for cmd in &c.manifest().commands {
                    commands.push(astrid_events::kernel_api::CommandInfo {
                        name: cmd.name.clone(),
                        description: cmd
                            .description
                            .clone()
                            .unwrap_or_else(|| "No description".to_string()),
                        provider_capsule: c.id().to_string(),
                    });
                }
            }
            info!(
                count = commands.len(),
                capsules = reg.len(),
                "GetCommands: returning {} commands from {} capsules",
                commands.len(),
                reg.len()
            );
            KernelResponse::Commands(commands)
        },
        KernelRequest::ReloadCapsules => {
            // Unregister capsules in a Failed state so they can be re-loaded
            // with fresh configuration (e.g. after onboarding writes .env.json).
            {
                // First pass under a read lock: collect ids of failed capsules.
                let reg = kernel.capsules.read().await;
                let failed_ids: Vec<_> = reg
                    .list()
                    .into_iter()
                    .filter(|id| {
                        reg.get(id).is_some_and(|c| {
                            matches!(c.state(), astrid_capsule::capsule::CapsuleState::Failed(_))
                        })
                    })
                    .cloned()
                    .collect();
                drop(reg);

                // Second pass under a write lock: evict them. NOTE(review): another
                // writer could act between the two locks — presumed benign; confirm.
                let mut reg = kernel.capsules.write().await;
                for id in failed_ids {
                    let _ = reg.unregister(&id);
                }
            }

            kernel.load_all_capsules().await;
            KernelResponse::Success(serde_json::json!({"status": "reloaded"}))
        },
        KernelRequest::Shutdown { reason } => {
            info!(
                reason = reason.as_deref().unwrap_or("none"),
                "Kernel received shutdown request via management API"
            );
            // Publish response before signaling shutdown so the client gets confirmation.
            publish_response(
                kernel,
                response_topic.clone(),
                KernelResponse::Success(serde_json::json!({"status": "shutting_down"})),
            );
            // Signal the daemon's main loop to exit gracefully.
            let _ = kernel.shutdown_tx.send(true);
            // Return early — the daemon will call kernel.shutdown() from its main loop.
            return;
        },
        // Read-only: pid, uptime, version, connection count, loaded capsule ids.
        KernelRequest::GetStatus => {
            let uptime = kernel.boot_time.elapsed().as_secs();
            let reg = kernel.capsules.read().await;
            let loaded: Vec<String> = reg.list().iter().map(ToString::to_string).collect();
            let status = astrid_events::kernel_api::DaemonStatus {
                pid: std::process::id(),
                uptime_secs: uptime,
                version: env!("CARGO_PKG_VERSION").to_string(),
                ephemeral: false, // The kernel doesn't know; daemon sets this via response override if needed
                // Saturate rather than truncate if the count somehow exceeds u32.
                connected_clients: u32::try_from(kernel.connection_count()).unwrap_or(u32::MAX),
                loaded_capsules: loaded,
            };
            KernelResponse::Status(status)
        },
        // Read-only: per-capsule name plus the interceptor events it registers.
        KernelRequest::GetCapsuleMetadata => {
            let reg = kernel.capsules.read().await;
            let mut entries = Vec::new();
            for capsule in reg.values() {
                let manifest = capsule.manifest();
                entries.push(astrid_events::kernel_api::CapsuleMetadataEntry {
                    name: manifest.package.name.clone(),
                    interceptor_events: manifest
                        .interceptors
                        .iter()
                        .map(|i| i.event.clone())
                        .collect(),
                });
            }
            KernelResponse::CapsuleMetadata(entries)
        },
    };

    publish_response(kernel, response_topic, res);
}
230
231fn publish_response(kernel: &Arc<crate::Kernel>, response_topic: String, res: KernelResponse) {
232    if let Ok(val) = serde_json::to_value(res) {
233        let msg = IpcMessage::new(
234            response_topic,
235            IpcPayload::RawJson(val),
236            kernel.session_id.0,
237        );
238        let _ = kernel.event_bus.publish(astrid_events::AstridEvent::Ipc {
239            metadata: astrid_events::EventMetadata::new("kernel_router"),
240            message: msg,
241        });
242    }
243}
244
245// ---------------------------------------------------------------------------
246// Management API rate limiting
247// ---------------------------------------------------------------------------
248
/// Sliding-window rate limiter for management API requests.
///
/// Keeps one timestamp queue per request type; entries older than 60 seconds
/// are evicted before each check, which avoids the 2x burst a fixed-window
/// design permits at window boundaries. Owned exclusively by the router task,
/// so no synchronization is required.
struct ManagementRateLimiter {
    // Per-method FIFO of admitted-request timestamps, oldest at the front.
    buckets: HashMap<&'static str, VecDeque<Instant>>,
}

impl ManagementRateLimiter {
    /// Creates an empty limiter with no recorded requests.
    fn new() -> Self {
        Self {
            buckets: HashMap::new(),
        }
    }

    /// Admits and records a request of type `method` unless `max_per_minute`
    /// requests already sit inside the 60-second window.
    /// Returns `true` when allowed, `false` when rate-limited.
    fn check(&mut self, method: &'static str, max_per_minute: u32) -> bool {
        const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);
        let now = Instant::now();
        let bucket = self.buckets.entry(method).or_default();

        // Drop entries that have aged out of the sliding window.
        bucket.retain(|&stamp| now.saturating_duration_since(stamp) < WINDOW);

        let admitted = bucket.len() < max_per_minute as usize;
        if admitted {
            bucket.push_back(now);
        }
        admitted
    }
}
287
288/// Return the rate limit label and max-per-minute for a request type.
289/// Returns `None` for the limit if the request type is not rate-limited.
290fn rate_limit_for_request(req: &KernelRequest) -> (&'static str, Option<u32>) {
291    match req {
292        KernelRequest::ReloadCapsules => ("ReloadCapsules", Some(5)),
293        KernelRequest::InstallCapsule { .. } => ("InstallCapsule", Some(10)),
294        KernelRequest::ApproveCapability { .. } => ("ApproveCapability", Some(10)),
295        // Read-only operations are cheap - no rate limit.
296        KernelRequest::ListCapsules => ("ListCapsules", None),
297        KernelRequest::GetCommands => ("GetCommands", None),
298        KernelRequest::GetCapsuleMetadata => ("GetCapsuleMetadata", None),
299        KernelRequest::Shutdown { .. } => ("Shutdown", Some(1)),
300        KernelRequest::GetStatus => ("GetStatus", None),
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Requests under the cap are admitted; the first one over it is refused.
    #[test]
    fn rate_limiter_allows_within_limit() {
        let mut limiter = ManagementRateLimiter::new();
        (0..5).for_each(|_| assert!(limiter.check("ReloadCapsules", 5)));
        // The 6th request inside the window must be rejected.
        assert!(!limiter.check("ReloadCapsules", 5));
    }

    /// Exhausting one method's bucket must not affect another method's.
    #[test]
    fn rate_limiter_independent_buckets() {
        let mut limiter = ManagementRateLimiter::new();
        // Saturate the ReloadCapsules bucket.
        (0..5).for_each(|_| assert!(limiter.check("ReloadCapsules", 5)));
        assert!(!limiter.check("ReloadCapsules", 5));

        // InstallCapsule has its own bucket and is still admitted.
        assert!(limiter.check("InstallCapsule", 10));
    }

    /// Fully-expired entries are evicted, freeing the whole bucket again.
    #[test]
    fn rate_limiter_sliding_window_eviction() {
        let mut limiter = ManagementRateLimiter::new();
        (0..5).for_each(|_| assert!(limiter.check("ReloadCapsules", 5)));
        assert!(!limiter.check("ReloadCapsules", 5));

        // Backdate every timestamp past the 60s window to simulate expiry.
        let past = Instant::now() - std::time::Duration::from_secs(61);
        if let Some(timestamps) = limiter.buckets.get_mut("ReloadCapsules") {
            timestamps.iter_mut().for_each(|ts| *ts = past);
        }

        // With all old entries evicted, a new request is admitted.
        assert!(limiter.check("ReloadCapsules", 5));
    }

    /// Partial expiry frees exactly as many slots as entries that aged out.
    #[test]
    fn rate_limiter_sliding_window_prevents_boundary_burst() {
        let mut limiter = ManagementRateLimiter::new();
        (0..5).for_each(|_| assert!(limiter.check("ReloadCapsules", 5)));

        // Backdate only 3 of the 5 timestamps beyond the 60s window.
        // Only those 3 slots should free up — not the whole bucket.
        let past = Instant::now() - std::time::Duration::from_secs(61);
        if let Some(timestamps) = limiter.buckets.get_mut("ReloadCapsules") {
            timestamps.iter_mut().take(3).for_each(|ts| *ts = past);
        }

        // Exactly 3 more requests are admitted; the 4th is refused.
        (0..3).for_each(|_| assert!(limiter.check("ReloadCapsules", 5)));
        assert!(!limiter.check("ReloadCapsules", 5));
    }

    /// Spot-check the label/limit table for a capped and an uncapped request.
    #[test]
    fn rate_limit_for_request_returns_correct_limits() {
        let (name, limit) = rate_limit_for_request(&KernelRequest::ReloadCapsules);
        assert_eq!(name, "ReloadCapsules");
        assert_eq!(limit, Some(5));

        let (name, limit) = rate_limit_for_request(&KernelRequest::ListCapsules);
        assert_eq!(name, "ListCapsules");
        assert_eq!(limit, None);
    }
}