Skip to main content

alien_bindings/grpc/
control_service.rs

//! Control service for runtime-application communication.
//!
//! This service handles:
//! - HTTP server registration
//! - Event handler registration
//! - Task streaming from runtime to app
//! - Task result submission
8
9use std::{collections::HashMap, pin::Pin, sync::Arc};
10use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
11use tokio_stream::Stream;
12use tonic::{Request, Response, Status};
13use tracing::{debug, info, warn};
14
15pub mod alien_bindings {
16    pub mod control {
17        tonic::include_proto!("alien_bindings.control");
18
19        pub const FILE_DESCRIPTOR_SET: &[u8] =
20            tonic::include_file_descriptor_set!("alien_bindings.control_descriptor");
21    }
22}
23
24use alien_bindings::control::{
25    control_service_server::{ControlService, ControlServiceServer},
26    RegisterEventHandlerRequest, RegisterEventHandlerResponse, RegisterHttpServerRequest,
27    RegisterHttpServerResponse, SendTaskResultRequest, SendTaskResultResponse, Task,
28    WaitForTasksRequest,
29};
30
/// An event handler the application registered through `RegisterEventHandler`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so registrations can be compared (e.g. in tests)
/// or collected into sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HandlerRegistration {
    /// Kind of event source the handler consumes, as sent by the application (e.g. `"storage"`).
    pub handler_type: String,
    /// Name of the resource the handler is bound to (e.g. `"uploads"`).
    pub resource_name: String,
}
37
38/// Tracks registered handlers and HTTP server port
39#[derive(Debug)]
40pub struct ControlState {
41    /// Registered HTTP server port (if any)
42    http_port: Option<u16>,
43    /// Registered event handlers: (handler_type, resource_name) -> registration
44    handlers: HashMap<(String, String), HandlerRegistration>,
45    /// Sender for notifying when HTTP server is registered
46    http_ready_tx: Option<tokio::sync::oneshot::Sender<u16>>,
47}
48
49impl Default for ControlState {
50    fn default() -> Self {
51        Self {
52            http_port: None,
53            handlers: HashMap::new(),
54            http_ready_tx: None,
55        }
56    }
57}
58
59/// Control gRPC server implementation
60#[derive(Clone)]
61pub struct ControlGrpcServer {
62    /// Shared state
63    state: Arc<RwLock<ControlState>>,
64    /// Task sender - runtime sends tasks here
65    task_tx: broadcast::Sender<Task>,
66    /// Result channels - keyed by task_id
67    result_channels: Arc<Mutex<HashMap<String, mpsc::Sender<Result<TaskResult, String>>>>>,
68}
69
/// Outcome the application reported for a single task.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// `true` when the application processed the task without error.
    pub success: bool,
    /// Payload returned by the application; empty when built via [`TaskResult::error`].
    pub response_data: Vec<u8>,
    /// Application-supplied error code; `None` when built via [`TaskResult::success`].
    pub error_code: Option<String>,
    /// Human-readable error description; `None` when built via [`TaskResult::success`].
    pub error_message: Option<String>,
}

impl TaskResult {
    /// Builds a successful result whose payload is `data`.
    pub fn success(data: Vec<u8>) -> Self {
        TaskResult {
            success: true,
            response_data: data,
            error_code: None,
            error_message: None,
        }
    }

    /// Builds a failed result with the given error `code` and `message` and no payload.
    pub fn error(code: impl Into<String>, message: impl Into<String>) -> Self {
        let error_code = Some(code.into());
        let error_message = Some(message.into());
        TaskResult {
            success: false,
            response_data: Vec::new(),
            error_code,
            error_message,
        }
    }
}
104
105impl ControlGrpcServer {
106    pub fn new() -> Self {
107        let (task_tx, _) = broadcast::channel(1024);
108        Self {
109            state: Arc::new(RwLock::new(ControlState::default())),
110            task_tx,
111            result_channels: Arc::new(Mutex::new(HashMap::new())),
112        }
113    }
114
115    /// Get the registered HTTP port (if any)
116    pub async fn get_http_port(&self) -> Option<u16> {
117        self.state.read().await.http_port
118    }
119
120    /// Check if a handler is registered
121    pub async fn has_handler(&self, handler_type: &str, resource_name: &str) -> bool {
122        let state = self.state.read().await;
123        state
124            .handlers
125            .contains_key(&(handler_type.to_string(), resource_name.to_string()))
126    }
127
128    /// Get all registered handlers
129    pub async fn get_handlers(&self) -> Vec<HandlerRegistration> {
130        let state = self.state.read().await;
131        state.handlers.values().cloned().collect()
132    }
133
134    /// Wait for HTTP server to be registered
135    pub async fn wait_for_http_server(&self) -> Option<u16> {
136        // Check if already registered
137        {
138            let state = self.state.read().await;
139            if let Some(port) = state.http_port {
140                return Some(port);
141            }
142        }
143
144        // Create a oneshot channel and store sender
145        let (tx, rx) = tokio::sync::oneshot::channel();
146        {
147            let mut state = self.state.write().await;
148            // Double-check in case it was registered while we were waiting for write lock
149            if let Some(port) = state.http_port {
150                return Some(port);
151            }
152            state.http_ready_tx = Some(tx);
153        }
154
155        // Wait for registration
156        rx.await.ok()
157    }
158
159    /// Send a task to the application and wait for the result.
160    /// This is used for all task types - the runtime must wait for the app to process
161    /// before acknowledging to the platform (storage/cron/queue) or submitting responses (commands).
162    pub async fn send_task(
163        &self,
164        task: Task,
165        timeout: std::time::Duration,
166    ) -> Result<TaskResult, String> {
167        let task_id = task.task_id.clone();
168
169        // Create result channel
170        let (result_tx, mut result_rx) = mpsc::channel(1);
171        {
172            let mut channels = self.result_channels.lock().await;
173            channels.insert(task_id.clone(), result_tx);
174        }
175
176        // Send the task
177        self.task_tx
178            .send(task)
179            .map_err(|e| format!("Failed to send task: {}", e))?;
180
181        // Wait for result with timeout
182        let result = tokio::time::timeout(timeout, result_rx.recv())
183            .await
184            .map_err(|_| "Task result timeout".to_string())?
185            .ok_or_else(|| "Result channel closed".to_string())?;
186
187        // Clean up channel
188        {
189            let mut channels = self.result_channels.lock().await;
190            channels.remove(&task_id);
191        }
192
193        result
194    }
195
196    /// Convert to tonic service
197    pub fn into_service(self) -> ControlServiceServer<Self> {
198        ControlServiceServer::new(self)
199    }
200}
201
202impl Default for ControlGrpcServer {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208#[tonic::async_trait]
209impl ControlService for ControlGrpcServer {
210    async fn register_http_server(
211        &self,
212        request: Request<RegisterHttpServerRequest>,
213    ) -> Result<Response<RegisterHttpServerResponse>, Status> {
214        let req = request.into_inner();
215        let port = req.port as u16;
216
217        info!(port = port, "Application registered HTTP server");
218
219        let mut state = self.state.write().await;
220        state.http_port = Some(port);
221
222        // Notify any waiters
223        if let Some(tx) = state.http_ready_tx.take() {
224            let _ = tx.send(port);
225        }
226
227        Ok(Response::new(RegisterHttpServerResponse { success: true }))
228    }
229
230    async fn register_event_handler(
231        &self,
232        request: Request<RegisterEventHandlerRequest>,
233    ) -> Result<Response<RegisterEventHandlerResponse>, Status> {
234        let req = request.into_inner();
235
236        info!(
237            handler_type = %req.handler_type,
238            resource_name = %req.resource_name,
239            "Application registered event handler"
240        );
241
242        let registration = HandlerRegistration {
243            handler_type: req.handler_type.clone(),
244            resource_name: req.resource_name.clone(),
245        };
246
247        let mut state = self.state.write().await;
248        state
249            .handlers
250            .insert((req.handler_type, req.resource_name), registration);
251
252        Ok(Response::new(RegisterEventHandlerResponse {
253            success: true,
254        }))
255    }
256
257    type WaitForTasksStream = Pin<Box<dyn Stream<Item = Result<Task, Status>> + Send>>;
258
259    async fn wait_for_tasks(
260        &self,
261        request: Request<WaitForTasksRequest>,
262    ) -> Result<Response<Self::WaitForTasksStream>, Status> {
263        let req = request.into_inner();
264        debug!(application_id = %req.application_id, "Application waiting for tasks");
265
266        let mut task_rx = self.task_tx.subscribe();
267
268        let stream = async_stream::stream! {
269            loop {
270                match task_rx.recv().await {
271                    Ok(task) => {
272                        yield Ok(task);
273                    }
274                    Err(broadcast::error::RecvError::Lagged(n)) => {
275                        warn!(skipped = n, "Task stream lagged, some tasks may have been dropped");
276                        continue;
277                    }
278                    Err(broadcast::error::RecvError::Closed) => {
279                        debug!("Task channel closed, ending stream");
280                        break;
281                    }
282                }
283            }
284        };
285
286        Ok(Response::new(Box::pin(stream)))
287    }
288
289    async fn send_task_result(
290        &self,
291        request: Request<SendTaskResultRequest>,
292    ) -> Result<Response<SendTaskResultResponse>, Status> {
293        let req = request.into_inner();
294        let task_id = req.task_id;
295
296        debug!(task_id = %task_id, "Received task result");
297
298        let result = match req.result {
299            Some(alien_bindings::control::send_task_result_request::Result::Success(s)) => {
300                Ok(TaskResult::success(s.response_data))
301            }
302            Some(alien_bindings::control::send_task_result_request::Result::Error(e)) => {
303                Ok(TaskResult::error(e.code, e.message))
304            }
305            None => Err("No result in response".to_string()),
306        };
307
308        // Send to waiting channel if any
309        let channels = self.result_channels.lock().await;
310        if let Some(tx) = channels.get(&task_id) {
311            let _ = tx.send(result).await;
312        }
313
314        Ok(Response::new(SendTaskResultResponse { acknowledged: true }))
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Registering a port makes it visible through `get_http_port`.
    #[tokio::test]
    async fn test_register_http_server() {
        let server = ControlGrpcServer::new();
        assert_eq!(server.get_http_port().await, None);

        let request = Request::new(RegisterHttpServerRequest { port: 8080 });
        let reply = server
            .register_http_server(request)
            .await
            .unwrap()
            .into_inner();

        assert!(reply.success);
        assert_eq!(server.get_http_port().await, Some(8080));
    }

    /// Registering a handler makes `has_handler` report it.
    #[tokio::test]
    async fn test_register_event_handler() {
        let server = ControlGrpcServer::new();
        let (kind, resource) = ("storage", "uploads");

        assert!(!server.has_handler(kind, resource).await);

        let request = Request::new(RegisterEventHandlerRequest {
            handler_type: kind.to_string(),
            resource_name: resource.to_string(),
        });
        let reply = server
            .register_event_handler(request)
            .await
            .unwrap()
            .into_inner();

        assert!(reply.success);
        assert!(server.has_handler(kind, resource).await);
    }

    /// A caller blocked in `wait_for_http_server` is woken by a later registration.
    #[tokio::test]
    async fn test_wait_for_http_server() {
        let server = ControlGrpcServer::new();

        let waiter = {
            let server = server.clone();
            tokio::spawn(async move { server.wait_for_http_server().await })
        };

        // Give the spawned waiter a moment to start waiting before registering.
        tokio::time::sleep(Duration::from_millis(10)).await;

        server
            .register_http_server(Request::new(RegisterHttpServerRequest { port: 3000 }))
            .await
            .unwrap();

        assert_eq!(waiter.await.unwrap(), Some(3000));
    }
}