Skip to main content

alien_bindings/grpc/
control_service.rs

1//! Control service for runtime-application communication.
2//!
3//! This service handles:
4//! - HTTP server registration
5//! - Event handler registration  
6//! - Task streaming from runtime to app
7//! - Task result submission
8
9use std::{collections::HashMap, pin::Pin, sync::Arc};
10use tokio::sync::{broadcast, mpsc, Mutex, Notify, RwLock};
11use tokio_stream::Stream;
12use tonic::{Request, Response, Status};
13use tracing::{debug, info, warn};
14
/// Generated gRPC bindings, nested so the module path mirrors the
/// `alien_bindings.control` protobuf package name.
pub mod alien_bindings {
    /// Items generated by tonic for the `alien_bindings.control` package.
    pub mod control {
        tonic::include_proto!("alien_bindings.control");

        /// Encoded file descriptor set for the control package, embedded at
        /// compile time by `tonic::include_file_descriptor_set!`.
        // NOTE(review): presumably consumed by a gRPC reflection service —
        // confirm against the code that reads this constant.
        pub const FILE_DESCRIPTOR_SET: &[u8] =
            tonic::include_file_descriptor_set!("alien_bindings.control_descriptor");
    }
}
23
24use alien_bindings::control::{
25    control_service_server::{ControlService, ControlServiceServer},
26    RegisterEventHandlerRequest, RegisterEventHandlerResponse, RegisterHttpServerRequest,
27    RegisterHttpServerResponse, SendTaskResultRequest, SendTaskResultResponse, Task,
28    WaitForTasksRequest,
29};
30
/// Describes one event handler that the application has registered.
///
/// The pair (`handler_type`, `resource_name`) is also what the control state
/// uses as the lookup key for a registration.
#[derive(Clone, Debug)]
pub struct HandlerRegistration {
    /// Kind of handler, exactly as supplied in the registration request.
    pub handler_type: String,
    /// Name of the resource the handler was registered for.
    pub resource_name: String,
}
37
38/// Tracks registered handlers and HTTP server port
39#[derive(Debug)]
40pub struct ControlState {
41    /// Registered HTTP server port (if any)
42    http_port: Option<u16>,
43    /// Registered event handlers: (handler_type, resource_name) -> registration
44    handlers: HashMap<(String, String), HandlerRegistration>,
45    /// Sender for notifying when HTTP server is registered
46    http_ready_tx: Option<tokio::sync::oneshot::Sender<u16>>,
47}
48
49impl Default for ControlState {
50    fn default() -> Self {
51        Self {
52            http_port: None,
53            handlers: HashMap::new(),
54            http_ready_tx: None,
55        }
56    }
57}
58
/// Control gRPC server implementation.
///
/// The type derives `Clone`, and every field is either an `Arc` or a
/// broadcast sender, so clones operate on the same underlying state and
/// channels rather than on independent copies.
#[derive(Clone)]
pub struct ControlGrpcServer {
    /// Shared registration state (HTTP port, event handlers, pending
    /// HTTP-ready waiter).
    state: Arc<RwLock<ControlState>>,
    /// Task sender - runtime sends tasks here; each `wait_for_tasks` stream
    /// subscribes to this broadcast channel.
    task_tx: broadcast::Sender<Task>,
    /// Result channels - keyed by task_id. `send_task` inserts an entry and
    /// `send_task_result` uses it to hand the result back to the waiting caller.
    result_channels: Arc<Mutex<HashMap<String, mpsc::Sender<Result<TaskResult, String>>>>>,
    /// Signalled via `notify_one()` each time `wait_for_tasks` creates a new
    /// task stream subscription.
    task_subscriber_notify: Arc<Notify>,
}
71
/// Outcome the application reported for a single task.
#[derive(Clone, Debug)]
pub struct TaskResult {
    /// `true` when the application processed the task successfully.
    pub success: bool,
    /// Payload returned by the application; [`TaskResult::error`] leaves it empty.
    pub response_data: Vec<u8>,
    /// Machine-readable failure code; [`TaskResult::success`] leaves it `None`.
    pub error_code: Option<String>,
    /// Human-readable failure description; [`TaskResult::success`] leaves it `None`.
    pub error_message: Option<String>,
}

impl TaskResult {
    /// Builds a successful result carrying `data` as the response payload.
    pub fn success(data: Vec<u8>) -> Self {
        TaskResult {
            success: true,
            error_code: None,
            error_message: None,
            response_data: data,
        }
    }

    /// Builds a failed result with the given error `code` and `message` and
    /// an empty response payload.
    pub fn error(code: impl Into<String>, message: impl Into<String>) -> Self {
        let error_code = Some(code.into());
        let error_message = Some(message.into());
        TaskResult {
            success: false,
            response_data: Vec::new(),
            error_code,
            error_message,
        }
    }
}
106
107impl ControlGrpcServer {
108    pub fn new() -> Self {
109        let (task_tx, _) = broadcast::channel(1024);
110        Self {
111            state: Arc::new(RwLock::new(ControlState::default())),
112            task_tx,
113            task_subscriber_notify: Arc::new(Notify::new()),
114            result_channels: Arc::new(Mutex::new(HashMap::new())),
115        }
116    }
117
118    /// Get the registered HTTP port (if any)
119    pub async fn get_http_port(&self) -> Option<u16> {
120        self.state.read().await.http_port
121    }
122
123    /// Check if any event handlers have been registered by the application.
124    pub async fn has_registered_handlers(&self) -> bool {
125        !self.state.read().await.handlers.is_empty()
126    }
127
128    /// Check if a handler is registered
129    pub async fn has_handler(&self, handler_type: &str, resource_name: &str) -> bool {
130        let state = self.state.read().await;
131        state
132            .handlers
133            .contains_key(&(handler_type.to_string(), resource_name.to_string()))
134    }
135
136    /// Get all registered handlers
137    pub async fn get_handlers(&self) -> Vec<HandlerRegistration> {
138        let state = self.state.read().await;
139        state.handlers.values().cloned().collect()
140    }
141
142    /// Wait for HTTP server to be registered
143    pub async fn wait_for_http_server(&self) -> Option<u16> {
144        // Check if already registered
145        {
146            let state = self.state.read().await;
147            if let Some(port) = state.http_port {
148                return Some(port);
149            }
150        }
151
152        // Create a oneshot channel and store sender
153        let (tx, rx) = tokio::sync::oneshot::channel();
154        {
155            let mut state = self.state.write().await;
156            // Double-check in case it was registered while we were waiting for write lock
157            if let Some(port) = state.http_port {
158                return Some(port);
159            }
160            state.http_ready_tx = Some(tx);
161        }
162
163        // Wait for registration
164        rx.await.ok()
165    }
166
167    /// Wait for at least one application to subscribe to the task stream.
168    /// Returns immediately if there's already a subscriber.
169    pub async fn wait_for_task_subscriber(&self) {
170        if self.task_tx.receiver_count() > 0 {
171            return;
172        }
173        // notify_one() stores a permit when no one is waiting, so even if
174        // the app subscribes between our check above and this await, the
175        // stored permit makes notified() return immediately.
176        self.task_subscriber_notify.notified().await;
177    }
178
179    /// Send a task to the application and wait for the result.
180    /// This is used for all task types - the runtime must wait for the app to process
181    /// before acknowledging to the platform (storage/cron/queue) or submitting responses (commands).
182    pub async fn send_task(
183        &self,
184        task: Task,
185        timeout: std::time::Duration,
186    ) -> Result<TaskResult, String> {
187        let task_id = task.task_id.clone();
188
189        // Create result channel
190        let (result_tx, mut result_rx) = mpsc::channel(1);
191        {
192            let mut channels = self.result_channels.lock().await;
193            channels.insert(task_id.clone(), result_tx);
194        }
195
196        // Send the task
197        let receiver_count = self
198            .task_tx
199            .send(task)
200            .map_err(|e| format!("Failed to send task: {}", e))?;
201
202        debug!(task_id = %task_id, receiver_count = receiver_count, "Task broadcast to subscribers, waiting for result");
203
204        // Wait for result with timeout
205        let result = tokio::time::timeout(timeout, result_rx.recv())
206            .await
207            .map_err(|_| {
208                warn!(task_id = %task_id, timeout_secs = timeout.as_secs(), "Task result timeout — app never sent result");
209                "Task result timeout".to_string()
210            })?
211            .ok_or_else(|| {
212                warn!(task_id = %task_id, "Result channel closed without sending result");
213                "Result channel closed".to_string()
214            })?;
215
216        debug!(task_id = %task_id, success = result.as_ref().map(|r| r.success).unwrap_or(false), "Received task result from app");
217
218        // Clean up channel
219        {
220            let mut channels = self.result_channels.lock().await;
221            channels.remove(&task_id);
222        }
223
224        result
225    }
226
227    /// Convert to tonic service
228    pub fn into_service(self) -> ControlServiceServer<Self> {
229        ControlServiceServer::new(self)
230    }
231}
232
233impl Default for ControlGrpcServer {
234    fn default() -> Self {
235        Self::new()
236    }
237}
238
239#[tonic::async_trait]
240impl ControlService for ControlGrpcServer {
241    async fn register_http_server(
242        &self,
243        request: Request<RegisterHttpServerRequest>,
244    ) -> Result<Response<RegisterHttpServerResponse>, Status> {
245        let req = request.into_inner();
246        let port = req.port as u16;
247
248        info!(port = port, "Application registered HTTP server");
249
250        let mut state = self.state.write().await;
251        state.http_port = Some(port);
252
253        // Notify any waiters
254        if let Some(tx) = state.http_ready_tx.take() {
255            let _ = tx.send(port);
256        }
257
258        Ok(Response::new(RegisterHttpServerResponse { success: true }))
259    }
260
261    async fn register_event_handler(
262        &self,
263        request: Request<RegisterEventHandlerRequest>,
264    ) -> Result<Response<RegisterEventHandlerResponse>, Status> {
265        let req = request.into_inner();
266
267        info!(
268            handler_type = %req.handler_type,
269            resource_name = %req.resource_name,
270            "Application registered event handler"
271        );
272
273        let registration = HandlerRegistration {
274            handler_type: req.handler_type.clone(),
275            resource_name: req.resource_name.clone(),
276        };
277
278        let mut state = self.state.write().await;
279        state
280            .handlers
281            .insert((req.handler_type, req.resource_name), registration);
282
283        Ok(Response::new(RegisterEventHandlerResponse {
284            success: true,
285        }))
286    }
287
288    type WaitForTasksStream = Pin<Box<dyn Stream<Item = Result<Task, Status>> + Send>>;
289
290    async fn wait_for_tasks(
291        &self,
292        request: Request<WaitForTasksRequest>,
293    ) -> Result<Response<Self::WaitForTasksStream>, Status> {
294        let req = request.into_inner();
295        debug!(application_id = %req.application_id, "Application waiting for tasks");
296
297        let mut task_rx = self.task_tx.subscribe();
298        self.task_subscriber_notify.notify_one();
299
300        let stream = async_stream::stream! {
301            loop {
302                match task_rx.recv().await {
303                    Ok(task) => {
304                        yield Ok(task);
305                    }
306                    Err(broadcast::error::RecvError::Lagged(n)) => {
307                        warn!(skipped = n, "Task stream lagged, some tasks may have been dropped");
308                        continue;
309                    }
310                    Err(broadcast::error::RecvError::Closed) => {
311                        debug!("Task channel closed, ending stream");
312                        break;
313                    }
314                }
315            }
316        };
317
318        Ok(Response::new(Box::pin(stream)))
319    }
320
321    async fn send_task_result(
322        &self,
323        request: Request<SendTaskResultRequest>,
324    ) -> Result<Response<SendTaskResultResponse>, Status> {
325        let req = request.into_inner();
326        let task_id = req.task_id;
327
328        let (result, result_desc) = match req.result {
329            Some(alien_bindings::control::send_task_result_request::Result::Success(ref s)) => {
330                let desc = format!("success, response_data_len={}", s.response_data.len());
331                (Ok(TaskResult::success(s.response_data.clone())), desc)
332            }
333            Some(alien_bindings::control::send_task_result_request::Result::Error(ref e)) => {
334                let desc = format!("error, code={}, message={}", e.code, e.message);
335                (
336                    Ok(TaskResult::error(e.code.clone(), e.message.clone())),
337                    desc,
338                )
339            }
340            None => (Err("No result in response".to_string()), "none".to_string()),
341        };
342
343        debug!(task_id = %task_id, result = %result_desc, "Received task result from app via gRPC");
344
345        // Send to waiting channel if any
346        let channels = self.result_channels.lock().await;
347        if let Some(tx) = channels.get(&task_id) {
348            if let Err(e) = tx.send(result).await {
349                warn!(task_id = %task_id, "Failed to send result to waiting channel: {:?}", e);
350            } else {
351                debug!(task_id = %task_id, "Result forwarded to send_task caller");
352            }
353        } else {
354            warn!(task_id = %task_id, "No waiting channel found for task result (task may have already timed out)");
355        }
356
357        Ok(Response::new(SendTaskResultResponse { acknowledged: true }))
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering an HTTP server reports success and stores the port.
    #[tokio::test]
    async fn test_register_http_server() {
        let server = ControlGrpcServer::new();
        assert!(server.get_http_port().await.is_none());

        let reply = server
            .register_http_server(Request::new(RegisterHttpServerRequest { port: 8080 }))
            .await
            .unwrap()
            .into_inner();

        assert!(reply.success);
        assert_eq!(server.get_http_port().await, Some(8080));
    }

    /// A handler is only visible through `has_handler` after it was registered.
    #[tokio::test]
    async fn test_register_event_handler() {
        let server = ControlGrpcServer::new();
        assert!(!server.has_handler("storage", "uploads").await);

        let registration = RegisterEventHandlerRequest {
            handler_type: "storage".to_string(),
            resource_name: "uploads".to_string(),
        };
        let reply = server
            .register_event_handler(Request::new(registration))
            .await
            .unwrap()
            .into_inner();

        assert!(reply.success);
        assert!(server.has_handler("storage", "uploads").await);
    }

    /// A pending `wait_for_http_server` call resolves once the port is registered.
    #[tokio::test]
    async fn test_wait_for_http_server() {
        let server = ControlGrpcServer::new();

        // The waiter runs on a clone, which shares state with `server`.
        let waiter = {
            let shared = server.clone();
            tokio::spawn(async move { shared.wait_for_http_server().await })
        };

        // Give the wait task time to start
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        server
            .register_http_server(Request::new(RegisterHttpServerRequest { port: 3000 }))
            .await
            .unwrap();

        // Wait task should complete with the port
        let observed = waiter.await.unwrap();
        assert_eq!(observed, Some(3000));
    }
}
413}