1use std::{collections::HashMap, pin::Pin, sync::Arc};
10use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
11use tokio_stream::Stream;
12use tonic::{Request, Response, Status};
13use tracing::{debug, info, warn};
14
/// Rust bindings for the `alien_bindings` protobuf packages.
pub mod alien_bindings {
    /// Code generated from the `alien_bindings.control` protobuf package,
    /// pulled in at compile time via `tonic::include_proto!`.
    pub mod control {
        tonic::include_proto!("alien_bindings.control");

        /// Encoded file descriptor set for the control package.
        ///
        /// NOTE(review): presumably consumed by a gRPC reflection service —
        /// no consumer is visible in this file; confirm against callers.
        pub const FILE_DESCRIPTOR_SET: &[u8] =
            tonic::include_file_descriptor_set!("alien_bindings.control_descriptor");
    }
}
23
24use alien_bindings::control::{
25 control_service_server::{ControlService, ControlServiceServer},
26 RegisterEventHandlerRequest, RegisterEventHandlerResponse, RegisterHttpServerRequest,
27 RegisterHttpServerResponse, SendTaskResultRequest, SendTaskResultResponse, Task,
28 WaitForTasksRequest,
29};
30
/// An event handler announced by the application through
/// `register_event_handler`.
#[derive(Debug, Clone)]
pub struct HandlerRegistration {
    /// Kind of handler being registered (the tests in this file use `"storage"`).
    pub handler_type: String,
    /// Name of the resource the handler is attached to (the tests use `"uploads"`).
    pub resource_name: String,
}
37
38#[derive(Debug)]
40pub struct ControlState {
41 http_port: Option<u16>,
43 handlers: HashMap<(String, String), HandlerRegistration>,
45 http_ready_tx: Option<tokio::sync::oneshot::Sender<u16>>,
47}
48
49impl Default for ControlState {
50 fn default() -> Self {
51 Self {
52 http_port: None,
53 handlers: HashMap::new(),
54 http_ready_tx: None,
55 }
56 }
57}
58
/// gRPC control server through which an application registers itself,
/// receives tasks and reports task results.
///
/// Cloning is cheap: the state and the result-channel map are behind `Arc`s,
/// so all clones observe the same data.
#[derive(Clone)]
pub struct ControlGrpcServer {
    /// Registration state (HTTP port, event handlers, pending HTTP waiter).
    state: Arc<RwLock<ControlState>>,
    /// Broadcast sender used to push tasks to every `wait_for_tasks` stream.
    task_tx: broadcast::Sender<Task>,
    /// Per-task senders, keyed by task id, used to hand a reported result
    /// back to the `send_task` call that is waiting for it.
    result_channels: Arc<Mutex<HashMap<String, mpsc::Sender<Result<TaskResult, String>>>>>,
}
69
/// Outcome of a task as reported back by the application.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// `true` when the task completed successfully.
    pub success: bool,
    /// Payload returned by a successful task; empty for failures.
    pub response_data: Vec<u8>,
    /// Error code for a failed task; `None` on success.
    pub error_code: Option<String>,
    /// Human-readable error description for a failed task; `None` on success.
    pub error_message: Option<String>,
}

impl TaskResult {
    /// Builds a successful result carrying `data` as its payload.
    pub fn success(data: Vec<u8>) -> Self {
        let response_data = data;
        TaskResult {
            success: true,
            response_data,
            error_code: None,
            error_message: None,
        }
    }

    /// Builds a failed result with the given error `code` and `message`
    /// and an empty payload.
    pub fn error(code: impl Into<String>, message: impl Into<String>) -> Self {
        let error_code: Option<String> = Some(code.into());
        let error_message: Option<String> = Some(message.into());
        TaskResult {
            success: false,
            response_data: Vec::new(),
            error_code,
            error_message,
        }
    }
}
104
105impl ControlGrpcServer {
106 pub fn new() -> Self {
107 let (task_tx, _) = broadcast::channel(1024);
108 Self {
109 state: Arc::new(RwLock::new(ControlState::default())),
110 task_tx,
111 result_channels: Arc::new(Mutex::new(HashMap::new())),
112 }
113 }
114
115 pub async fn get_http_port(&self) -> Option<u16> {
117 self.state.read().await.http_port
118 }
119
120 pub async fn has_handler(&self, handler_type: &str, resource_name: &str) -> bool {
122 let state = self.state.read().await;
123 state
124 .handlers
125 .contains_key(&(handler_type.to_string(), resource_name.to_string()))
126 }
127
128 pub async fn get_handlers(&self) -> Vec<HandlerRegistration> {
130 let state = self.state.read().await;
131 state.handlers.values().cloned().collect()
132 }
133
134 pub async fn wait_for_http_server(&self) -> Option<u16> {
136 {
138 let state = self.state.read().await;
139 if let Some(port) = state.http_port {
140 return Some(port);
141 }
142 }
143
144 let (tx, rx) = tokio::sync::oneshot::channel();
146 {
147 let mut state = self.state.write().await;
148 if let Some(port) = state.http_port {
150 return Some(port);
151 }
152 state.http_ready_tx = Some(tx);
153 }
154
155 rx.await.ok()
157 }
158
159 pub async fn send_task(
163 &self,
164 task: Task,
165 timeout: std::time::Duration,
166 ) -> Result<TaskResult, String> {
167 let task_id = task.task_id.clone();
168
169 let (result_tx, mut result_rx) = mpsc::channel(1);
171 {
172 let mut channels = self.result_channels.lock().await;
173 channels.insert(task_id.clone(), result_tx);
174 }
175
176 self.task_tx
178 .send(task)
179 .map_err(|e| format!("Failed to send task: {}", e))?;
180
181 let result = tokio::time::timeout(timeout, result_rx.recv())
183 .await
184 .map_err(|_| "Task result timeout".to_string())?
185 .ok_or_else(|| "Result channel closed".to_string())?;
186
187 {
189 let mut channels = self.result_channels.lock().await;
190 channels.remove(&task_id);
191 }
192
193 result
194 }
195
196 pub fn into_service(self) -> ControlServiceServer<Self> {
198 ControlServiceServer::new(self)
199 }
200}
201
202impl Default for ControlGrpcServer {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208#[tonic::async_trait]
209impl ControlService for ControlGrpcServer {
210 async fn register_http_server(
211 &self,
212 request: Request<RegisterHttpServerRequest>,
213 ) -> Result<Response<RegisterHttpServerResponse>, Status> {
214 let req = request.into_inner();
215 let port = req.port as u16;
216
217 info!(port = port, "Application registered HTTP server");
218
219 let mut state = self.state.write().await;
220 state.http_port = Some(port);
221
222 if let Some(tx) = state.http_ready_tx.take() {
224 let _ = tx.send(port);
225 }
226
227 Ok(Response::new(RegisterHttpServerResponse { success: true }))
228 }
229
230 async fn register_event_handler(
231 &self,
232 request: Request<RegisterEventHandlerRequest>,
233 ) -> Result<Response<RegisterEventHandlerResponse>, Status> {
234 let req = request.into_inner();
235
236 info!(
237 handler_type = %req.handler_type,
238 resource_name = %req.resource_name,
239 "Application registered event handler"
240 );
241
242 let registration = HandlerRegistration {
243 handler_type: req.handler_type.clone(),
244 resource_name: req.resource_name.clone(),
245 };
246
247 let mut state = self.state.write().await;
248 state
249 .handlers
250 .insert((req.handler_type, req.resource_name), registration);
251
252 Ok(Response::new(RegisterEventHandlerResponse {
253 success: true,
254 }))
255 }
256
257 type WaitForTasksStream = Pin<Box<dyn Stream<Item = Result<Task, Status>> + Send>>;
258
259 async fn wait_for_tasks(
260 &self,
261 request: Request<WaitForTasksRequest>,
262 ) -> Result<Response<Self::WaitForTasksStream>, Status> {
263 let req = request.into_inner();
264 debug!(application_id = %req.application_id, "Application waiting for tasks");
265
266 let mut task_rx = self.task_tx.subscribe();
267
268 let stream = async_stream::stream! {
269 loop {
270 match task_rx.recv().await {
271 Ok(task) => {
272 yield Ok(task);
273 }
274 Err(broadcast::error::RecvError::Lagged(n)) => {
275 warn!(skipped = n, "Task stream lagged, some tasks may have been dropped");
276 continue;
277 }
278 Err(broadcast::error::RecvError::Closed) => {
279 debug!("Task channel closed, ending stream");
280 break;
281 }
282 }
283 }
284 };
285
286 Ok(Response::new(Box::pin(stream)))
287 }
288
289 async fn send_task_result(
290 &self,
291 request: Request<SendTaskResultRequest>,
292 ) -> Result<Response<SendTaskResultResponse>, Status> {
293 let req = request.into_inner();
294 let task_id = req.task_id;
295
296 debug!(task_id = %task_id, "Received task result");
297
298 let result = match req.result {
299 Some(alien_bindings::control::send_task_result_request::Result::Success(s)) => {
300 Ok(TaskResult::success(s.response_data))
301 }
302 Some(alien_bindings::control::send_task_result_request::Result::Error(e)) => {
303 Ok(TaskResult::error(e.code, e.message))
304 }
305 None => Err("No result in response".to_string()),
306 };
307
308 let channels = self.result_channels.lock().await;
310 if let Some(tx) = channels.get(&task_id) {
311 let _ = tx.send(result).await;
312 }
313
314 Ok(Response::new(SendTaskResultResponse { acknowledged: true }))
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering an HTTP server makes its port visible through `get_http_port`.
    #[tokio::test]
    async fn test_register_http_server() {
        let control = ControlGrpcServer::new();
        assert_eq!(control.get_http_port().await, None);

        let reply = control
            .register_http_server(Request::new(RegisterHttpServerRequest { port: 8080 }))
            .await
            .unwrap()
            .into_inner();

        assert!(reply.success);
        assert_eq!(control.get_http_port().await, Some(8080));
    }

    /// A handler is reported by `has_handler` only after it has been registered.
    #[tokio::test]
    async fn test_register_event_handler() {
        let control = ControlGrpcServer::new();
        let already_registered = control.has_handler("storage", "uploads").await;
        assert!(!already_registered);

        let registration = RegisterEventHandlerRequest {
            handler_type: "storage".to_string(),
            resource_name: "uploads".to_string(),
        };
        let reply = control
            .register_event_handler(Request::new(registration))
            .await
            .unwrap()
            .into_inner();

        assert!(reply.success);
        assert!(control.has_handler("storage", "uploads").await);
    }

    /// A task blocked in `wait_for_http_server` is woken with the registered port.
    #[tokio::test]
    async fn test_wait_for_http_server() {
        let control = ControlGrpcServer::new();

        let waiter = tokio::spawn({
            let control = control.clone();
            async move { control.wait_for_http_server().await }
        });

        // Give the spawned task a moment to start waiting before registering.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        control
            .register_http_server(Request::new(RegisterHttpServerRequest { port: 3000 }))
            .await
            .unwrap();

        let observed = waiter.await.unwrap();
        assert_eq!(observed, Some(3000));
    }
}