1use std::{collections::HashMap, pin::Pin, sync::Arc};
10use tokio::sync::{broadcast, mpsc, Mutex, Notify, RwLock};
11use tokio_stream::Stream;
12use tonic::{Request, Response, Status};
13use tracing::{debug, info, warn};
14
15pub mod alien_bindings {
16 pub mod control {
17 tonic::include_proto!("alien_bindings.control");
18
19 pub const FILE_DESCRIPTOR_SET: &[u8] =
20 tonic::include_file_descriptor_set!("alien_bindings.control_descriptor");
21 }
22}
23
24use alien_bindings::control::{
25 control_service_server::{ControlService, ControlServiceServer},
26 RegisterEventHandlerRequest, RegisterEventHandlerResponse, RegisterHttpServerRequest,
27 RegisterHttpServerResponse, SendTaskResultRequest, SendTaskResultResponse, Task,
28 WaitForTasksRequest,
29};
30
/// An event handler the application announced through `register_event_handler`.
///
/// Handlers are stored keyed by the `(handler_type, resource_name)` pair.
#[derive(Debug, Clone)]
pub struct HandlerRegistration {
    /// Category of handler, e.g. `"storage"`.
    pub handler_type: String,
    /// Name of the resource the handler is bound to, e.g. `"uploads"`.
    pub resource_name: String,
}
37
38#[derive(Debug)]
40pub struct ControlState {
41 http_port: Option<u16>,
43 handlers: HashMap<(String, String), HandlerRegistration>,
45 http_ready_tx: Option<tokio::sync::oneshot::Sender<u16>>,
47}
48
49impl Default for ControlState {
50 fn default() -> Self {
51 Self {
52 http_port: None,
53 handlers: HashMap::new(),
54 http_ready_tx: None,
55 }
56 }
57}
58
59#[derive(Clone)]
61pub struct ControlGrpcServer {
62 state: Arc<RwLock<ControlState>>,
64 task_tx: broadcast::Sender<Task>,
66 result_channels: Arc<Mutex<HashMap<String, mpsc::Sender<Result<TaskResult, String>>>>>,
68 task_subscriber_notify: Arc<Notify>,
70}
71
/// Outcome of a task as reported back by the application.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// `true` when the application reported success.
    pub success: bool,
    /// Payload produced by the task; always empty for results built with [`TaskResult::error`].
    pub response_data: Vec<u8>,
    /// Application-supplied error code; `None` for successful results.
    pub error_code: Option<String>,
    /// Human-readable error description; `None` for successful results.
    pub error_message: Option<String>,
}

impl TaskResult {
    /// Builds a successful result carrying `data` and no error details.
    pub fn success(data: Vec<u8>) -> Self {
        TaskResult {
            response_data: data,
            success: true,
            error_message: None,
            error_code: None,
        }
    }

    /// Builds a failed result with the given error `code` and `message` and an empty payload.
    pub fn error(code: impl Into<String>, message: impl Into<String>) -> Self {
        let error_code = Some(code.into());
        let error_message = Some(message.into());
        TaskResult {
            success: false,
            response_data: Vec::new(),
            error_code,
            error_message,
        }
    }
}
106
107impl ControlGrpcServer {
108 pub fn new() -> Self {
109 let (task_tx, _) = broadcast::channel(1024);
110 Self {
111 state: Arc::new(RwLock::new(ControlState::default())),
112 task_tx,
113 task_subscriber_notify: Arc::new(Notify::new()),
114 result_channels: Arc::new(Mutex::new(HashMap::new())),
115 }
116 }
117
118 pub async fn get_http_port(&self) -> Option<u16> {
120 self.state.read().await.http_port
121 }
122
123 pub async fn has_registered_handlers(&self) -> bool {
125 !self.state.read().await.handlers.is_empty()
126 }
127
128 pub async fn has_handler(&self, handler_type: &str, resource_name: &str) -> bool {
130 let state = self.state.read().await;
131 state
132 .handlers
133 .contains_key(&(handler_type.to_string(), resource_name.to_string()))
134 }
135
136 pub async fn get_handlers(&self) -> Vec<HandlerRegistration> {
138 let state = self.state.read().await;
139 state.handlers.values().cloned().collect()
140 }
141
142 pub async fn wait_for_http_server(&self) -> Option<u16> {
144 {
146 let state = self.state.read().await;
147 if let Some(port) = state.http_port {
148 return Some(port);
149 }
150 }
151
152 let (tx, rx) = tokio::sync::oneshot::channel();
154 {
155 let mut state = self.state.write().await;
156 if let Some(port) = state.http_port {
158 return Some(port);
159 }
160 state.http_ready_tx = Some(tx);
161 }
162
163 rx.await.ok()
165 }
166
167 pub async fn wait_for_task_subscriber(&self) {
170 if self.task_tx.receiver_count() > 0 {
171 return;
172 }
173 self.task_subscriber_notify.notified().await;
177 }
178
179 pub async fn send_task(
183 &self,
184 task: Task,
185 timeout: std::time::Duration,
186 ) -> Result<TaskResult, String> {
187 let task_id = task.task_id.clone();
188
189 let (result_tx, mut result_rx) = mpsc::channel(1);
191 {
192 let mut channels = self.result_channels.lock().await;
193 channels.insert(task_id.clone(), result_tx);
194 }
195
196 let receiver_count = self
198 .task_tx
199 .send(task)
200 .map_err(|e| format!("Failed to send task: {}", e))?;
201
202 debug!(task_id = %task_id, receiver_count = receiver_count, "Task broadcast to subscribers, waiting for result");
203
204 let result = tokio::time::timeout(timeout, result_rx.recv())
206 .await
207 .map_err(|_| {
208 warn!(task_id = %task_id, timeout_secs = timeout.as_secs(), "Task result timeout — app never sent result");
209 "Task result timeout".to_string()
210 })?
211 .ok_or_else(|| {
212 warn!(task_id = %task_id, "Result channel closed without sending result");
213 "Result channel closed".to_string()
214 })?;
215
216 debug!(task_id = %task_id, success = result.as_ref().map(|r| r.success).unwrap_or(false), "Received task result from app");
217
218 {
220 let mut channels = self.result_channels.lock().await;
221 channels.remove(&task_id);
222 }
223
224 result
225 }
226
227 pub fn into_service(self) -> ControlServiceServer<Self> {
229 ControlServiceServer::new(self)
230 }
231}
232
233impl Default for ControlGrpcServer {
234 fn default() -> Self {
235 Self::new()
236 }
237}
238
239#[tonic::async_trait]
240impl ControlService for ControlGrpcServer {
241 async fn register_http_server(
242 &self,
243 request: Request<RegisterHttpServerRequest>,
244 ) -> Result<Response<RegisterHttpServerResponse>, Status> {
245 let req = request.into_inner();
246 let port = req.port as u16;
247
248 info!(port = port, "Application registered HTTP server");
249
250 let mut state = self.state.write().await;
251 state.http_port = Some(port);
252
253 if let Some(tx) = state.http_ready_tx.take() {
255 let _ = tx.send(port);
256 }
257
258 Ok(Response::new(RegisterHttpServerResponse { success: true }))
259 }
260
261 async fn register_event_handler(
262 &self,
263 request: Request<RegisterEventHandlerRequest>,
264 ) -> Result<Response<RegisterEventHandlerResponse>, Status> {
265 let req = request.into_inner();
266
267 info!(
268 handler_type = %req.handler_type,
269 resource_name = %req.resource_name,
270 "Application registered event handler"
271 );
272
273 let registration = HandlerRegistration {
274 handler_type: req.handler_type.clone(),
275 resource_name: req.resource_name.clone(),
276 };
277
278 let mut state = self.state.write().await;
279 state
280 .handlers
281 .insert((req.handler_type, req.resource_name), registration);
282
283 Ok(Response::new(RegisterEventHandlerResponse {
284 success: true,
285 }))
286 }
287
288 type WaitForTasksStream = Pin<Box<dyn Stream<Item = Result<Task, Status>> + Send>>;
289
290 async fn wait_for_tasks(
291 &self,
292 request: Request<WaitForTasksRequest>,
293 ) -> Result<Response<Self::WaitForTasksStream>, Status> {
294 let req = request.into_inner();
295 debug!(application_id = %req.application_id, "Application waiting for tasks");
296
297 let mut task_rx = self.task_tx.subscribe();
298 self.task_subscriber_notify.notify_one();
299
300 let stream = async_stream::stream! {
301 loop {
302 match task_rx.recv().await {
303 Ok(task) => {
304 yield Ok(task);
305 }
306 Err(broadcast::error::RecvError::Lagged(n)) => {
307 warn!(skipped = n, "Task stream lagged, some tasks may have been dropped");
308 continue;
309 }
310 Err(broadcast::error::RecvError::Closed) => {
311 debug!("Task channel closed, ending stream");
312 break;
313 }
314 }
315 }
316 };
317
318 Ok(Response::new(Box::pin(stream)))
319 }
320
321 async fn send_task_result(
322 &self,
323 request: Request<SendTaskResultRequest>,
324 ) -> Result<Response<SendTaskResultResponse>, Status> {
325 let req = request.into_inner();
326 let task_id = req.task_id;
327
328 let (result, result_desc) = match req.result {
329 Some(alien_bindings::control::send_task_result_request::Result::Success(ref s)) => {
330 let desc = format!("success, response_data_len={}", s.response_data.len());
331 (Ok(TaskResult::success(s.response_data.clone())), desc)
332 }
333 Some(alien_bindings::control::send_task_result_request::Result::Error(ref e)) => {
334 let desc = format!("error, code={}, message={}", e.code, e.message);
335 (
336 Ok(TaskResult::error(e.code.clone(), e.message.clone())),
337 desc,
338 )
339 }
340 None => (Err("No result in response".to_string()), "none".to_string()),
341 };
342
343 debug!(task_id = %task_id, result = %result_desc, "Received task result from app via gRPC");
344
345 let channels = self.result_channels.lock().await;
347 if let Some(tx) = channels.get(&task_id) {
348 if let Err(e) = tx.send(result).await {
349 warn!(task_id = %task_id, "Failed to send result to waiting channel: {:?}", e);
350 } else {
351 debug!(task_id = %task_id, "Result forwarded to send_task caller");
352 }
353 } else {
354 warn!(task_id = %task_id, "No waiting channel found for task result (task may have already timed out)");
355 }
356
357 Ok(Response::new(SendTaskResultResponse { acknowledged: true }))
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering an HTTP server makes its port observable through `get_http_port`.
    #[tokio::test]
    async fn test_register_http_server() {
        let server = ControlGrpcServer::new();
        assert_eq!(server.get_http_port().await, None, "no port before registration");

        let response = server
            .register_http_server(Request::new(RegisterHttpServerRequest { port: 8080 }))
            .await
            .unwrap()
            .into_inner();
        assert!(response.success);

        assert_eq!(server.get_http_port().await, Some(8080));
    }

    /// A registered event handler is found by `has_handler` under its (type, resource) key.
    #[tokio::test]
    async fn test_register_event_handler() {
        let server = ControlGrpcServer::new();
        let (kind, resource) = ("storage", "uploads");
        assert!(!server.has_handler(kind, resource).await);

        let request = Request::new(RegisterEventHandlerRequest {
            handler_type: kind.to_owned(),
            resource_name: resource.to_owned(),
        });
        let accepted = server
            .register_event_handler(request)
            .await
            .unwrap()
            .into_inner()
            .success;
        assert!(accepted);

        assert!(server.has_handler(kind, resource).await);
    }

    /// A caller blocked in `wait_for_http_server` is woken with the registered port.
    #[tokio::test]
    async fn test_wait_for_http_server() {
        let server = ControlGrpcServer::new();

        let waiter = {
            let server = server.clone();
            tokio::spawn(async move { server.wait_for_http_server().await })
        };

        // Give the spawned waiter a chance to start waiting before the port is registered.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        server
            .register_http_server(Request::new(RegisterHttpServerRequest { port: 3000 }))
            .await
            .unwrap();

        assert_eq!(waiter.await.unwrap(), Some(3000));
    }
}