1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::Instant;
4
5use astrid_events::ipc::{IpcMessage, IpcPayload};
6use astrid_events::kernel_api::{KernelRequest, KernelResponse};
7use tracing::{debug, info, warn};
8
9#[must_use]
19pub(crate) fn spawn_kernel_router(kernel: Arc<crate::Kernel>) -> tokio::task::JoinHandle<()> {
20 drop(spawn_connection_tracker(Arc::clone(&kernel)));
22
23 let mut receiver = kernel.event_bus.subscribe_topic("astrid.v1.request.*");
24
25 tokio::spawn(async move {
26 let mut rate_limiter = ManagementRateLimiter::new();
27
28 while let Some(event) = receiver.recv().await {
29 let astrid_events::AstridEvent::Ipc { message, .. } = &*event else {
30 continue;
31 };
32
33 let IpcPayload::RawJson(val) = &message.payload else {
35 continue;
36 };
37
38 match serde_json::from_value::<KernelRequest>(val.clone()) {
39 Ok(req) => {
40 let (method, limit) = rate_limit_for_request(&req);
41 if let Some(max) = limit
42 && !rate_limiter.check(method, max)
43 {
44 warn!(
45 security_event = true,
46 method = method,
47 "Rate limited kernel management request"
48 );
49 let response_topic =
50 message.topic.replace("kernel.request.", "kernel.response.");
51 publish_response(
52 &kernel,
53 response_topic,
54 KernelResponse::Error(format!(
55 "Rate limited: max {max} {method} requests per minute"
56 )),
57 );
58 continue;
59 }
60 handle_request(&kernel, message.topic.clone(), req).await;
61 },
62 Err(e) => {
63 warn!(error = %e, topic = %message.topic, "Failed to parse KernelRequest from IPC");
64 },
65 }
66 }
67 })
68}
69
70fn spawn_connection_tracker(kernel: Arc<crate::Kernel>) -> tokio::task::JoinHandle<()> {
76 let mut receiver = kernel.event_bus.subscribe_topic("client.v1.*");
77
78 tokio::spawn(async move {
79 while let Some(event) = receiver.recv().await {
80 let astrid_events::AstridEvent::Ipc { message, .. } = &*event else {
81 continue;
82 };
83 match &message.payload {
84 IpcPayload::Disconnect { reason } => {
85 kernel.connection_closed();
86 debug!(reason = ?reason, "Client disconnected");
87 },
88 IpcPayload::Connect => {
89 kernel.connection_opened();
90 debug!("New client connection accepted");
91 },
92 _ => {},
93 }
94 }
95 })
96}
97
98#[expect(clippy::too_many_lines)]
99async fn handle_request(kernel: &Arc<crate::Kernel>, topic: String, req: KernelRequest) {
100 let response_topic = if let Some(suffix) = topic.strip_prefix("astrid.v1.request.") {
101 format!("astrid.v1.response.{suffix}")
102 } else {
103 topic.clone()
104 };
105
106 let res = match req {
107 KernelRequest::InstallCapsule { source, workspace } => {
108 info!(source = %source, workspace, "Kernel received install request");
109 KernelResponse::Error(
112 "Installation logic not yet implemented in kernel router".to_string(),
113 )
114 },
115 KernelRequest::ApproveCapability {
116 request_id,
117 signature: _,
118 } => {
119 info!(request_id = %request_id, "Kernel received capability approval");
120 KernelResponse::Error("Approval logic not yet implemented in kernel router".to_string())
121 },
122 KernelRequest::ListCapsules => {
123 let reg = kernel.capsules.read().await;
124 let mut list = Vec::new();
125 for c in reg.list() {
126 list.push(c.to_string());
127 }
128 KernelResponse::Success(serde_json::json!(list))
129 },
130 KernelRequest::GetCommands => {
131 let reg = kernel.capsules.read().await;
132 let mut commands = Vec::new();
133 for c in reg.values() {
134 for cmd in &c.manifest().commands {
135 commands.push(astrid_events::kernel_api::CommandInfo {
136 name: cmd.name.clone(),
137 description: cmd
138 .description
139 .clone()
140 .unwrap_or_else(|| "No description".to_string()),
141 provider_capsule: c.id().to_string(),
142 });
143 }
144 }
145 info!(
146 count = commands.len(),
147 capsules = reg.len(),
148 "GetCommands: returning {} commands from {} capsules",
149 commands.len(),
150 reg.len()
151 );
152 KernelResponse::Commands(commands)
153 },
154 KernelRequest::ReloadCapsules => {
155 {
158 let reg = kernel.capsules.read().await;
159 let failed_ids: Vec<_> = reg
160 .list()
161 .into_iter()
162 .filter(|id| {
163 reg.get(id).is_some_and(|c| {
164 matches!(c.state(), astrid_capsule::capsule::CapsuleState::Failed(_))
165 })
166 })
167 .cloned()
168 .collect();
169 drop(reg);
170
171 let mut reg = kernel.capsules.write().await;
172 for id in failed_ids {
173 let _ = reg.unregister(&id);
174 }
175 }
176
177 kernel.load_all_capsules().await;
178 KernelResponse::Success(serde_json::json!({"status": "reloaded"}))
179 },
180 KernelRequest::GetCapsuleMetadata => {
181 let reg = kernel.capsules.read().await;
182 let mut entries = Vec::new();
183 for capsule in reg.values() {
184 let manifest = capsule.manifest();
185 entries.push(astrid_events::kernel_api::CapsuleMetadataEntry {
186 name: manifest.package.name.clone(),
187 llm_providers: manifest
188 .llm_providers
189 .iter()
190 .map(|p| astrid_events::kernel_api::LlmProviderInfo {
191 id: p.id.clone(),
192 description: p.description.clone().unwrap_or_default(),
193 capabilities: p.capabilities.clone(),
194 })
195 .collect(),
196 interceptor_events: manifest
197 .interceptors
198 .iter()
199 .map(|i| i.event.clone())
200 .collect(),
201 });
202 }
203 KernelResponse::CapsuleMetadata(entries)
204 },
205 };
206
207 publish_response(kernel, response_topic, res);
208}
209
210fn publish_response(kernel: &Arc<crate::Kernel>, response_topic: String, res: KernelResponse) {
211 if let Ok(val) = serde_json::to_value(res) {
212 let msg = IpcMessage::new(
213 response_topic,
214 IpcPayload::RawJson(val),
215 kernel.session_id.0,
216 );
217 let _ = kernel.event_bus.publish(astrid_events::AstridEvent::Ipc {
218 metadata: astrid_events::EventMetadata::new("kernel_router"),
219 message: msg,
220 });
221 }
222}
223
/// Sliding-window rate limiter for kernel management requests.
///
/// Each method name has its own queue of admission timestamps. A call is
/// admitted only while fewer than the configured maximum admissions were
/// recorded in the last 60 seconds.
struct ManagementRateLimiter {
    /// Admission instants per method name, oldest at the front.
    buckets: HashMap<&'static str, VecDeque<Instant>>,
}

impl ManagementRateLimiter {
    /// Creates a limiter with no recorded admissions.
    fn new() -> Self {
        let buckets = HashMap::new();
        Self { buckets }
    }

    /// Decides whether a `method` call may proceed under `max_per_minute`.
    ///
    /// First evicts timestamps that are 60 seconds old or older. Then:
    /// - if fewer than `max_per_minute` remain, records the current instant and
    ///   returns `true`;
    /// - otherwise returns `false` without recording anything.
    ///
    /// A limit of `0` rejects every call.
    fn check(&mut self, method: &'static str, max_per_minute: u32) -> bool {
        const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);

        let now = Instant::now();
        let history = self.buckets.entry(method).or_default();

        // Admissions are appended with `Instant::now()`, so expired entries
        // sit contiguously at the front of the queue.
        while history
            .front()
            .is_some_and(|&stamp| now.saturating_duration_since(stamp) >= WINDOW)
        {
            history.pop_front();
        }

        let admitted = history.len() < max_per_minute as usize;
        if admitted {
            history.push_back(now);
        }
        admitted
    }
}
266
267fn rate_limit_for_request(req: &KernelRequest) -> (&'static str, Option<u32>) {
270 match req {
271 KernelRequest::ReloadCapsules => ("ReloadCapsules", Some(5)),
272 KernelRequest::InstallCapsule { .. } => ("InstallCapsule", Some(10)),
273 KernelRequest::ApproveCapability { .. } => ("ApproveCapability", Some(10)),
274 KernelRequest::ListCapsules => ("ListCapsules", None),
276 KernelRequest::GetCommands => ("GetCommands", None),
277 KernelRequest::GetCapsuleMetadata => ("GetCapsuleMetadata", None),
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    const RELOAD: &str = "ReloadCapsules";
    const INSTALL: &str = "InstallCapsule";

    /// Asserts that `count` consecutive checks of `method` are all admitted.
    fn admit_all(limiter: &mut ManagementRateLimiter, method: &'static str, max: u32, count: usize) {
        for attempt in 0..count {
            assert!(limiter.check(method, max), "check #{attempt} for {method} should pass");
        }
    }

    /// Moves up to `count` of the oldest recorded timestamps for `method` to
    /// 61 seconds ago, just outside the 60-second window.
    fn backdate_oldest(limiter: &mut ManagementRateLimiter, method: &str, count: usize) {
        let stale = Instant::now() - std::time::Duration::from_secs(61);
        if let Some(timestamps) = limiter.buckets.get_mut(method) {
            timestamps.iter_mut().take(count).for_each(|ts| *ts = stale);
        }
    }

    #[test]
    fn rate_limiter_allows_within_limit() {
        let mut limiter = ManagementRateLimiter::new();
        admit_all(&mut limiter, RELOAD, 5, 5);
        assert!(!limiter.check(RELOAD, 5));
    }

    #[test]
    fn rate_limiter_independent_buckets() {
        let mut limiter = ManagementRateLimiter::new();
        admit_all(&mut limiter, RELOAD, 5, 5);
        assert!(!limiter.check(RELOAD, 5));

        // Exhausting one method's budget must not affect another method.
        assert!(limiter.check(INSTALL, 10));
    }

    #[test]
    fn rate_limiter_sliding_window_eviction() {
        let mut limiter = ManagementRateLimiter::new();
        admit_all(&mut limiter, RELOAD, 5, 5);
        assert!(!limiter.check(RELOAD, 5));

        // Once every recorded admission has aged out, calls are admitted again.
        backdate_oldest(&mut limiter, RELOAD, usize::MAX);
        assert!(limiter.check(RELOAD, 5));
    }

    #[test]
    fn rate_limiter_sliding_window_prevents_boundary_burst() {
        let mut limiter = ManagementRateLimiter::new();
        admit_all(&mut limiter, RELOAD, 5, 5);

        // Only three of the five slots age out, so exactly three more calls fit.
        backdate_oldest(&mut limiter, RELOAD, 3);
        admit_all(&mut limiter, RELOAD, 5, 3);
        assert!(!limiter.check(RELOAD, 5));
    }

    #[test]
    fn rate_limit_for_request_returns_correct_limits() {
        assert_eq!(
            rate_limit_for_request(&KernelRequest::ReloadCapsules),
            ("ReloadCapsules", Some(5))
        );
        assert_eq!(
            rate_limit_for_request(&KernelRequest::ListCapsules),
            ("ListCapsules", None)
        );
    }
}