//! astrid_kernel/kernel_router.rs
//!
//! Kernel management router: serves `astrid.v1.request.*` management
//! commands and tracks client connection lifecycle events.

use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Instant;

use astrid_events::ipc::{IpcMessage, IpcPayload};
use astrid_events::kernel_api::{KernelRequest, KernelResponse};
use tracing::{debug, info, warn};

9/// Spawns background tasks for the kernel management API and connection tracking.
10///
11/// Two listeners:
12/// 1. `astrid.v1.request.*` - handles management commands (list capsules, reload, etc.)
13/// 2. `client.v1.disconnect` - decrements the active connection counter on graceful disconnect.
14///
15/// Connection *increment* happens when the WASM proxy capsule accepts a socket
16/// connection (it publishes a `client.v1.connected` event). For ungraceful disconnects,
17/// the idle monitor uses `EventBus::subscriber_count()` as a secondary signal.
18#[must_use]
19pub(crate) fn spawn_kernel_router(kernel: Arc<crate::Kernel>) -> tokio::task::JoinHandle<()> {
20    // Spawn the connection tracker as a sibling task.
21    drop(spawn_connection_tracker(Arc::clone(&kernel)));
22
23    let mut receiver = kernel.event_bus.subscribe_topic("astrid.v1.request.*");
24
25    tokio::spawn(async move {
26        let mut rate_limiter = ManagementRateLimiter::new();
27
28        while let Some(event) = receiver.recv().await {
29            let astrid_events::AstridEvent::Ipc { message, .. } = &*event else {
30                continue;
31            };
32
33            // Only process standard IPC messages that contain JSON payloads.
34            let IpcPayload::RawJson(val) = &message.payload else {
35                continue;
36            };
37
38            match serde_json::from_value::<KernelRequest>(val.clone()) {
39                Ok(req) => {
40                    let (method, limit) = rate_limit_for_request(&req);
41                    if let Some(max) = limit
42                        && !rate_limiter.check(method, max)
43                    {
44                        warn!(
45                            security_event = true,
46                            method = method,
47                            "Rate limited kernel management request"
48                        );
49                        let response_topic =
50                            message.topic.replace("kernel.request.", "kernel.response.");
51                        publish_response(
52                            &kernel,
53                            response_topic,
54                            KernelResponse::Error(format!(
55                                "Rate limited: max {max} {method} requests per minute"
56                            )),
57                        );
58                        continue;
59                    }
60                    handle_request(&kernel, message.topic.clone(), req).await;
61                },
62                Err(e) => {
63                    warn!(error = %e, topic = %message.topic, "Failed to parse KernelRequest from IPC");
64                },
65            }
66        }
67    })
68}
69
70/// Tracks client connection lifecycle events.
71///
72/// Listens on `client.v1.*` topics:
73/// - `client.v1.connected` - a new socket connection was accepted.
74/// - `client.v1.disconnect` - a client sent a graceful disconnect.
75fn spawn_connection_tracker(kernel: Arc<crate::Kernel>) -> tokio::task::JoinHandle<()> {
76    let mut receiver = kernel.event_bus.subscribe_topic("client.v1.*");
77
78    tokio::spawn(async move {
79        while let Some(event) = receiver.recv().await {
80            let astrid_events::AstridEvent::Ipc { message, .. } = &*event else {
81                continue;
82            };
83            match &message.payload {
84                IpcPayload::Disconnect { reason } => {
85                    kernel.connection_closed();
86                    debug!(reason = ?reason, "Client disconnected");
87                },
88                IpcPayload::Connect => {
89                    kernel.connection_opened();
90                    debug!("New client connection accepted");
91                },
92                _ => {},
93            }
94        }
95    })
96}
97
98#[expect(clippy::too_many_lines)]
99async fn handle_request(kernel: &Arc<crate::Kernel>, topic: String, req: KernelRequest) {
100    let response_topic = if let Some(suffix) = topic.strip_prefix("astrid.v1.request.") {
101        format!("astrid.v1.response.{suffix}")
102    } else {
103        topic.clone()
104    };
105
106    let res = match req {
107        KernelRequest::InstallCapsule { source, workspace } => {
108            info!(source = %source, workspace, "Kernel received install request");
109            // Here the kernel would verify identity, parse the capsule, and potentially
110            // return ApprovalRequired if it needs dangerous capabilities!
111            KernelResponse::Error(
112                "Installation logic not yet implemented in kernel router".to_string(),
113            )
114        },
115        KernelRequest::ApproveCapability {
116            request_id,
117            signature: _,
118        } => {
119            info!(request_id = %request_id, "Kernel received capability approval");
120            KernelResponse::Error("Approval logic not yet implemented in kernel router".to_string())
121        },
122        KernelRequest::ListCapsules => {
123            let reg = kernel.capsules.read().await;
124            let mut list = Vec::new();
125            for c in reg.list() {
126                list.push(c.to_string());
127            }
128            KernelResponse::Success(serde_json::json!(list))
129        },
130        KernelRequest::GetCommands => {
131            let reg = kernel.capsules.read().await;
132            let mut commands = Vec::new();
133            for c in reg.values() {
134                for cmd in &c.manifest().commands {
135                    commands.push(astrid_events::kernel_api::CommandInfo {
136                        name: cmd.name.clone(),
137                        description: cmd
138                            .description
139                            .clone()
140                            .unwrap_or_else(|| "No description".to_string()),
141                        provider_capsule: c.id().to_string(),
142                    });
143                }
144            }
145            info!(
146                count = commands.len(),
147                capsules = reg.len(),
148                "GetCommands: returning {} commands from {} capsules",
149                commands.len(),
150                reg.len()
151            );
152            KernelResponse::Commands(commands)
153        },
154        KernelRequest::ReloadCapsules => {
155            // Unregister capsules in a Failed state so they can be re-loaded
156            // with fresh configuration (e.g. after onboarding writes .env.json).
157            {
158                let reg = kernel.capsules.read().await;
159                let failed_ids: Vec<_> = reg
160                    .list()
161                    .into_iter()
162                    .filter(|id| {
163                        reg.get(id).is_some_and(|c| {
164                            matches!(c.state(), astrid_capsule::capsule::CapsuleState::Failed(_))
165                        })
166                    })
167                    .cloned()
168                    .collect();
169                drop(reg);
170
171                let mut reg = kernel.capsules.write().await;
172                for id in failed_ids {
173                    let _ = reg.unregister(&id);
174                }
175            }
176
177            kernel.load_all_capsules().await;
178            KernelResponse::Success(serde_json::json!({"status": "reloaded"}))
179        },
180        KernelRequest::GetCapsuleMetadata => {
181            let reg = kernel.capsules.read().await;
182            let mut entries = Vec::new();
183            for capsule in reg.values() {
184                let manifest = capsule.manifest();
185                entries.push(astrid_events::kernel_api::CapsuleMetadataEntry {
186                    name: manifest.package.name.clone(),
187                    llm_providers: manifest
188                        .llm_providers
189                        .iter()
190                        .map(|p| astrid_events::kernel_api::LlmProviderInfo {
191                            id: p.id.clone(),
192                            description: p.description.clone().unwrap_or_default(),
193                            capabilities: p.capabilities.clone(),
194                        })
195                        .collect(),
196                    interceptor_events: manifest
197                        .interceptors
198                        .iter()
199                        .map(|i| i.event.clone())
200                        .collect(),
201                });
202            }
203            KernelResponse::CapsuleMetadata(entries)
204        },
205    };
206
207    publish_response(kernel, response_topic, res);
208}
209
210fn publish_response(kernel: &Arc<crate::Kernel>, response_topic: String, res: KernelResponse) {
211    if let Ok(val) = serde_json::to_value(res) {
212        let msg = IpcMessage::new(
213            response_topic,
214            IpcPayload::RawJson(val),
215            kernel.session_id.0,
216        );
217        let _ = kernel.event_bus.publish(astrid_events::AstridEvent::Ipc {
218            metadata: astrid_events::EventMetadata::new("kernel_router"),
219            message: msg,
220        });
221    }
222}
223
// ---------------------------------------------------------------------------
// Management API rate limiting
// ---------------------------------------------------------------------------

/// Sliding window rate limiter for management API requests.
///
/// Keeps one timestamp per admitted request and discards entries once they
/// age past 60 seconds, so window-boundary bursts cannot exceed the limit
/// (a fixed-window counter would allow up to 2x in a short span).
/// Owned exclusively by the router task; no synchronization required.
struct ManagementRateLimiter {
    // Per-method queues of admission times, oldest at the front.
    buckets: HashMap<&'static str, VecDeque<Instant>>,
}

impl ManagementRateLimiter {
    /// Creates an empty limiter with no recorded requests.
    fn new() -> Self {
        Self {
            buckets: HashMap::new(),
        }
    }

    /// Admits a request of the given type if it is under the per-minute
    /// budget, recording its timestamp. Returns `true` when admitted,
    /// `false` when rate-limited.
    fn check(&mut self, method: &'static str, max_per_minute: u32) -> bool {
        const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);

        let now = Instant::now();
        let bucket = self.buckets.entry(method).or_default();

        // Discard every timestamp that has aged out of the sliding window.
        bucket.retain(|&stamp| now.saturating_duration_since(stamp) < WINDOW);

        if (bucket.len() as u64) < u64::from(max_per_minute) {
            bucket.push_back(now);
            true
        } else {
            false
        }
    }
}
266
267/// Return the rate limit label and max-per-minute for a request type.
268/// Returns `None` for the limit if the request type is not rate-limited.
269fn rate_limit_for_request(req: &KernelRequest) -> (&'static str, Option<u32>) {
270    match req {
271        KernelRequest::ReloadCapsules => ("ReloadCapsules", Some(5)),
272        KernelRequest::InstallCapsule { .. } => ("InstallCapsule", Some(10)),
273        KernelRequest::ApproveCapability { .. } => ("ApproveCapability", Some(10)),
274        // Read-only operations are cheap - no rate limit.
275        KernelRequest::ListCapsules => ("ListCapsules", None),
276        KernelRequest::GetCommands => ("GetCommands", None),
277        KernelRequest::GetCapsuleMetadata => ("GetCapsuleMetadata", None),
278    }
279}
280
281#[cfg(test)]
282mod tests {
283    use super::*;
284
285    #[test]
286    fn rate_limiter_allows_within_limit() {
287        let mut limiter = ManagementRateLimiter::new();
288        for _ in 0..5 {
289            assert!(limiter.check("ReloadCapsules", 5));
290        }
291        // 6th should be rejected
292        assert!(!limiter.check("ReloadCapsules", 5));
293    }
294
295    #[test]
296    fn rate_limiter_independent_buckets() {
297        let mut limiter = ManagementRateLimiter::new();
298        // Fill ReloadCapsules
299        for _ in 0..5 {
300            assert!(limiter.check("ReloadCapsules", 5));
301        }
302        assert!(!limiter.check("ReloadCapsules", 5));
303
304        // InstallCapsule should still be allowed
305        assert!(limiter.check("InstallCapsule", 10));
306    }
307
308    #[test]
309    fn rate_limiter_sliding_window_eviction() {
310        let mut limiter = ManagementRateLimiter::new();
311        // Fill the bucket
312        for _ in 0..5 {
313            assert!(limiter.check("ReloadCapsules", 5));
314        }
315        assert!(!limiter.check("ReloadCapsules", 5));
316
317        // Manually set all timestamps to 61 seconds ago to simulate expiry.
318        if let Some(timestamps) = limiter.buckets.get_mut("ReloadCapsules") {
319            let past = Instant::now() - std::time::Duration::from_secs(61);
320            for ts in timestamps.iter_mut() {
321                *ts = past;
322            }
323        }
324
325        // Should be allowed again after old entries are evicted
326        assert!(limiter.check("ReloadCapsules", 5));
327    }
328
329    #[test]
330    fn rate_limiter_sliding_window_prevents_boundary_burst() {
331        let mut limiter = ManagementRateLimiter::new();
332        // Fill 5 requests
333        for _ in 0..5 {
334            assert!(limiter.check("ReloadCapsules", 5));
335        }
336
337        // Move only 3 of the 5 timestamps to the past (beyond 60s window).
338        // This simulates partial window expiry - only 3 slots should free up.
339        if let Some(timestamps) = limiter.buckets.get_mut("ReloadCapsules") {
340            let past = Instant::now() - std::time::Duration::from_secs(61);
341            for ts in timestamps.iter_mut().take(3) {
342                *ts = past;
343            }
344        }
345
346        // Should allow exactly 3 more (the evicted slots), not 5
347        for _ in 0..3 {
348            assert!(limiter.check("ReloadCapsules", 5));
349        }
350        assert!(!limiter.check("ReloadCapsules", 5));
351    }
352
353    #[test]
354    fn rate_limit_for_request_returns_correct_limits() {
355        let (name, limit) = rate_limit_for_request(&KernelRequest::ReloadCapsules);
356        assert_eq!(name, "ReloadCapsules");
357        assert_eq!(limit, Some(5));
358
359        let (name, limit) = rate_limit_for_request(&KernelRequest::ListCapsules);
360        assert_eq!(name, "ListCapsules");
361        assert_eq!(limit, None);
362    }
363}