Skip to main content

astrid_kernel/
kernel_router.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::Instant;
4
5use astrid_events::ipc::{IpcMessage, IpcPayload};
6use astrid_events::kernel_api::{KernelRequest, KernelResponse};
7use tracing::{debug, info, warn};
8
9/// Spawns background tasks for the kernel management API and connection tracking.
10///
11/// Two listeners:
12/// 1. `astrid.v1.request.*` - handles management commands (list capsules, reload, etc.)
13/// 2. `client.v1.disconnect` - decrements the active connection counter on graceful disconnect.
14///
15/// Connection *increment* happens when the WASM proxy capsule accepts a socket
16/// connection (it publishes a `client.v1.connected` event). For ungraceful disconnects,
17/// the idle monitor uses `EventBus::subscriber_count()` as a secondary signal.
18#[must_use]
19pub(crate) fn spawn_kernel_router(kernel: Arc<crate::Kernel>) -> tokio::task::JoinHandle<()> {
20    // Spawn the connection tracker as a sibling task.
21    drop(spawn_connection_tracker(Arc::clone(&kernel)));
22
23    let mut receiver = kernel.event_bus.subscribe_topic("astrid.v1.request.*");
24
25    tokio::spawn(async move {
26        let mut rate_limiter = ManagementRateLimiter::new();
27
28        while let Some(event) = receiver.recv().await {
29            let astrid_events::AstridEvent::Ipc { message, .. } = &*event else {
30                continue;
31            };
32
33            // Only process standard IPC messages that contain JSON payloads.
34            let IpcPayload::RawJson(val) = &message.payload else {
35                continue;
36            };
37
38            match serde_json::from_value::<KernelRequest>(val.clone()) {
39                Ok(req) => {
40                    let (method, limit) = rate_limit_for_request(&req);
41                    if let Some(max) = limit
42                        && !rate_limiter.check(method, max)
43                    {
44                        warn!(
45                            security_event = true,
46                            method = method,
47                            "Rate limited kernel management request"
48                        );
49                        let response_topic =
50                            message.topic.replace("kernel.request.", "kernel.response.");
51                        publish_response(
52                            &kernel,
53                            response_topic,
54                            KernelResponse::Error(format!(
55                                "Rate limited: max {max} {method} requests per minute"
56                            )),
57                        );
58                        continue;
59                    }
60                    handle_request(&kernel, message.topic.clone(), req).await;
61                },
62                Err(e) => {
63                    warn!(error = %e, topic = %message.topic, "Failed to parse KernelRequest from IPC");
64                },
65            }
66        }
67    })
68}
69
70/// Tracks client connection lifecycle events.
71///
72/// Listens on `client.v1.*` topics:
73/// - `client.v1.connected` - a new socket connection was accepted.
74/// - `client.v1.disconnect` - a client sent a graceful disconnect.
75fn spawn_connection_tracker(kernel: Arc<crate::Kernel>) -> tokio::task::JoinHandle<()> {
76    let mut receiver = kernel.event_bus.subscribe_topic("client.v1.*");
77
78    tokio::spawn(async move {
79        while let Some(event) = receiver.recv().await {
80            let astrid_events::AstridEvent::Ipc { message, .. } = &*event else {
81                continue;
82            };
83            match &message.payload {
84                IpcPayload::Disconnect { reason } => {
85                    kernel.connection_closed();
86                    debug!(reason = ?reason, "Client disconnected");
87                },
88                IpcPayload::Connect => {
89                    kernel.connection_opened();
90                    debug!("New client connection accepted");
91                },
92                _ => {},
93            }
94        }
95    })
96}
97
98#[expect(clippy::too_many_lines)]
99async fn handle_request(kernel: &Arc<crate::Kernel>, topic: String, req: KernelRequest) {
100    let response_topic = if let Some(suffix) = topic.strip_prefix("astrid.v1.request.") {
101        format!("astrid.v1.response.{suffix}")
102    } else {
103        topic.clone()
104    };
105
106    let res = match req {
107        KernelRequest::InstallCapsule { source, workspace } => {
108            info!(source = %source, workspace, "Kernel received install request");
109            // Here the kernel would verify identity, parse the capsule, and potentially
110            // return ApprovalRequired if it needs dangerous capabilities!
111            KernelResponse::Error(
112                "Installation logic not yet implemented in kernel router".to_string(),
113            )
114        },
115        KernelRequest::ApproveCapability {
116            request_id,
117            signature: _,
118        } => {
119            info!(request_id = %request_id, "Kernel received capability approval");
120            KernelResponse::Error("Approval logic not yet implemented in kernel router".to_string())
121        },
122        KernelRequest::ListCapsules => {
123            let reg = kernel.capsules.read().await;
124            let mut list = Vec::new();
125            for c in reg.list() {
126                list.push(c.to_string());
127            }
128            KernelResponse::Success(serde_json::json!(list))
129        },
130        KernelRequest::GetCommands => {
131            let reg = kernel.capsules.read().await;
132            let mut commands = Vec::new();
133            for c in reg.values() {
134                for cmd in &c.manifest().commands {
135                    commands.push(astrid_events::kernel_api::CommandInfo {
136                        name: cmd.name.clone(),
137                        description: cmd
138                            .description
139                            .clone()
140                            .unwrap_or_else(|| "No description".to_string()),
141                        provider_capsule: c.id().to_string(),
142                    });
143                }
144            }
145            info!(
146                count = commands.len(),
147                capsules = reg.len(),
148                "GetCommands: returning {} commands from {} capsules",
149                commands.len(),
150                reg.len()
151            );
152            KernelResponse::Commands(commands)
153        },
154        KernelRequest::ReloadCapsules => {
155            // Unregister capsules in a Failed state so they can be re-loaded
156            // with fresh configuration (e.g. after onboarding writes .env.json).
157            {
158                let reg = kernel.capsules.read().await;
159                let failed_ids: Vec<_> = reg
160                    .list()
161                    .into_iter()
162                    .filter(|id| {
163                        reg.get(id).is_some_and(|c| {
164                            matches!(c.state(), astrid_capsule::capsule::CapsuleState::Failed(_))
165                        })
166                    })
167                    .cloned()
168                    .collect();
169                drop(reg);
170
171                let mut reg = kernel.capsules.write().await;
172                for id in failed_ids {
173                    let _ = reg.unregister(&id);
174                }
175            }
176
177            kernel.load_all_capsules().await;
178            KernelResponse::Success(serde_json::json!({"status": "reloaded"}))
179        },
180        KernelRequest::Shutdown { reason } => {
181            info!(
182                reason = reason.as_deref().unwrap_or("none"),
183                "Kernel received shutdown request via management API"
184            );
185            // Publish response before signaling shutdown so the client gets confirmation.
186            publish_response(
187                kernel,
188                response_topic.clone(),
189                KernelResponse::Success(serde_json::json!({"status": "shutting_down"})),
190            );
191            // Signal the daemon's main loop to exit gracefully.
192            let _ = kernel.shutdown_tx.send(true);
193            // Return early — the daemon will call kernel.shutdown() from its main loop.
194            return;
195        },
196        KernelRequest::GetStatus => {
197            let uptime = kernel.boot_time.elapsed().as_secs();
198            let reg = kernel.capsules.read().await;
199            let loaded: Vec<String> = reg.list().iter().map(ToString::to_string).collect();
200            let status = astrid_events::kernel_api::DaemonStatus {
201                pid: std::process::id(),
202                uptime_secs: uptime,
203                version: env!("CARGO_PKG_VERSION").to_string(),
204                ephemeral: false, // The kernel doesn't know; daemon sets this via response override if needed
205                connected_clients: u32::try_from(kernel.connection_count()).unwrap_or(u32::MAX),
206                loaded_capsules: loaded,
207            };
208            KernelResponse::Status(status)
209        },
210        KernelRequest::GetCapsuleMetadata => {
211            let reg = kernel.capsules.read().await;
212            let mut entries = Vec::new();
213            for capsule in reg.values() {
214                let manifest = capsule.manifest();
215                entries.push(astrid_events::kernel_api::CapsuleMetadataEntry {
216                    name: manifest.package.name.clone(),
217                    llm_providers: manifest
218                        .llm_providers
219                        .iter()
220                        .map(|p| astrid_events::kernel_api::LlmProviderInfo {
221                            id: p.id.clone(),
222                            description: p.description.clone().unwrap_or_default(),
223                            capabilities: p.capabilities.clone(),
224                        })
225                        .collect(),
226                    interceptor_events: manifest
227                        .interceptors
228                        .iter()
229                        .map(|i| i.event.clone())
230                        .collect(),
231                });
232            }
233            KernelResponse::CapsuleMetadata(entries)
234        },
235    };
236
237    publish_response(kernel, response_topic, res);
238}
239
240fn publish_response(kernel: &Arc<crate::Kernel>, response_topic: String, res: KernelResponse) {
241    if let Ok(val) = serde_json::to_value(res) {
242        let msg = IpcMessage::new(
243            response_topic,
244            IpcPayload::RawJson(val),
245            kernel.session_id.0,
246        );
247        let _ = kernel.event_bus.publish(astrid_events::AstridEvent::Ipc {
248            metadata: astrid_events::EventMetadata::new("kernel_router"),
249            message: msg,
250        });
251    }
252}
253
254// ---------------------------------------------------------------------------
255// Management API rate limiting
256// ---------------------------------------------------------------------------
257
/// Per-method sliding-window rate limiter for the management API.
///
/// Each method label owns a queue of the instants at which requests were
/// admitted. Entries that are 60 seconds old or older are discarded on every
/// check, so the limit applies to any rolling 60-second span rather than to
/// fixed calendar windows.
/// Takes `&mut self`, so it is meant to be owned by a single caller.
struct ManagementRateLimiter {
    buckets: HashMap<&'static str, VecDeque<Instant>>,
}

impl ManagementRateLimiter {
    /// Creates a limiter with no recorded requests.
    fn new() -> Self {
        let buckets = HashMap::new();
        Self { buckets }
    }

    /// Decides whether a `method` request may proceed under `max_per_minute`.
    ///
    /// Returns `true` and records the request when fewer than `max_per_minute`
    /// requests are currently inside the window; returns `false` (recording
    /// nothing) otherwise.
    fn check(&mut self, method: &'static str, max_per_minute: u32) -> bool {
        const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);

        let now = Instant::now();
        let bucket = self.buckets.entry(method).or_default();

        // Timestamps are pushed in chronological order, so expired entries
        // are always at the front of the queue.
        while bucket
            .front()
            .is_some_and(|&stamp| now.saturating_duration_since(stamp) >= WINDOW)
        {
            bucket.pop_front();
        }

        let allowed = bucket.len() < max_per_minute as usize;
        if allowed {
            bucket.push_back(now);
        }
        allowed
    }
}
296
297/// Return the rate limit label and max-per-minute for a request type.
298/// Returns `None` for the limit if the request type is not rate-limited.
299fn rate_limit_for_request(req: &KernelRequest) -> (&'static str, Option<u32>) {
300    match req {
301        KernelRequest::ReloadCapsules => ("ReloadCapsules", Some(5)),
302        KernelRequest::InstallCapsule { .. } => ("InstallCapsule", Some(10)),
303        KernelRequest::ApproveCapability { .. } => ("ApproveCapability", Some(10)),
304        // Read-only operations are cheap - no rate limit.
305        KernelRequest::ListCapsules => ("ListCapsules", None),
306        KernelRequest::GetCommands => ("GetCommands", None),
307        KernelRequest::GetCapsuleMetadata => ("GetCapsuleMetadata", None),
308        KernelRequest::Shutdown { .. } => ("Shutdown", Some(1)),
309        KernelRequest::GetStatus => ("GetStatus", None),
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Bucket label used by the limiter tests.
    const RELOAD: &str = "ReloadCapsules";
    /// Per-minute cap used together with [`RELOAD`].
    const RELOAD_MAX: u32 = 5;

    /// Issues `RELOAD_MAX` requests and asserts that every one is admitted.
    fn fill_reload_bucket(limiter: &mut ManagementRateLimiter) {
        for _ in 0..RELOAD_MAX {
            assert!(limiter.check(RELOAD, RELOAD_MAX));
        }
    }

    /// Rewrites the first `count` recorded timestamps of the reload bucket to
    /// 61 seconds ago, i.e. outside the 60-second window. Does nothing if the
    /// bucket does not exist.
    fn expire_oldest(limiter: &mut ManagementRateLimiter, count: usize) {
        if let Some(timestamps) = limiter.buckets.get_mut(RELOAD) {
            let past = Instant::now() - std::time::Duration::from_secs(61);
            for ts in timestamps.iter_mut().take(count) {
                *ts = past;
            }
        }
    }

    #[test]
    fn rate_limiter_allows_within_limit() {
        let mut limiter = ManagementRateLimiter::new();
        fill_reload_bucket(&mut limiter);
        // One request over the cap must be rejected.
        assert!(!limiter.check(RELOAD, RELOAD_MAX));
    }

    #[test]
    fn rate_limiter_independent_buckets() {
        let mut limiter = ManagementRateLimiter::new();
        // Exhaust the reload bucket.
        fill_reload_bucket(&mut limiter);
        assert!(!limiter.check(RELOAD, RELOAD_MAX));

        // A different method is unaffected.
        assert!(limiter.check("InstallCapsule", 10));
    }

    #[test]
    fn rate_limiter_sliding_window_eviction() {
        let mut limiter = ManagementRateLimiter::new();
        fill_reload_bucket(&mut limiter);
        assert!(!limiter.check(RELOAD, RELOAD_MAX));

        // Push every recorded timestamp out of the window to simulate expiry.
        expire_oldest(&mut limiter, usize::MAX);

        // With the old entries evicted, requests are admitted again.
        assert!(limiter.check(RELOAD, RELOAD_MAX));
    }

    #[test]
    fn rate_limiter_sliding_window_prevents_boundary_burst() {
        let mut limiter = ManagementRateLimiter::new();
        fill_reload_bucket(&mut limiter);

        // Only 3 of the 5 timestamps leave the window, so only 3 slots free up.
        expire_oldest(&mut limiter, 3);

        // Exactly 3 more requests fit, not a full fresh quota of 5.
        for _ in 0..3 {
            assert!(limiter.check(RELOAD, RELOAD_MAX));
        }
        assert!(!limiter.check(RELOAD, RELOAD_MAX));
    }

    #[test]
    fn rate_limit_for_request_returns_correct_limits() {
        let (reload_name, reload_limit) = rate_limit_for_request(&KernelRequest::ReloadCapsules);
        assert_eq!(reload_name, "ReloadCapsules");
        assert_eq!(reload_limit, Some(5));

        let (list_name, list_limit) = rate_limit_for_request(&KernelRequest::ListCapsules);
        assert_eq!(list_name, "ListCapsules");
        assert_eq!(list_limit, None);
    }
}