1use super::AppState;
17use axum::{
18 extract::{
19 Query, State, WebSocketUpgrade,
20 ws::{Message, WebSocket},
21 },
22 http::{HeaderMap, header},
23 response::IntoResponse,
24};
25use futures_util::{SinkExt, StreamExt};
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29use std::sync::Arc;
30use tokio::sync::{mpsc, oneshot};
31
/// Prefix marking a `Sec-WebSocket-Protocol` entry that carries an auth token,
/// as in `bearer.<token>`. `extract_node_ws_token` consults such entries when no
/// usable `Authorization: Bearer` header is present.
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";

/// Application subprotocol echoed back in the upgrade response when the client
/// lists it in its `Sec-WebSocket-Protocol` header.
const WS_NODE_PROTOCOL: &str = "construct.nodes.v1";
37
/// A single capability a node advertises in its `register` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCapability {
    /// Capability name. Presumably the value callers put in
    /// `NodeInvocation::capability` — TODO confirm against callers.
    pub name: String,
    /// Human-readable description of the capability.
    pub description: String,
    /// Argument description; appears to be a JSON-Schema-style object. When the
    /// field is missing, it defaults to `{"type": "object", "properties": {}}`
    /// via `default_capability_parameters`.
    #[serde(default = "default_capability_parameters")]
    pub parameters: serde_json::Value,
}
46
47fn default_capability_parameters() -> serde_json::Value {
48 serde_json::json!({
49 "type": "object",
50 "properties": {}
51 })
52}
53
/// Registry entry describing one connected node.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Identifier the node registered under; also the registry key.
    pub node_id: String,
    /// Capabilities the node advertised when it registered.
    pub capabilities: Vec<NodeCapability>,
    /// Queue feeding the node's WebSocket connection. Invocations sent here are
    /// forwarded to the node as `invoke` messages by `handle_node_socket`.
    pub invoke_tx: mpsc::Sender<NodeInvocation>,
}
62
/// A request to run a capability on a node, queued through `NodeInfo::invoke_tx`.
#[derive(Debug)]
pub struct NodeInvocation {
    /// Correlates the node's `result` message with this request.
    pub call_id: String,
    /// Name of the capability to run.
    pub capability: String,
    /// Arguments forwarded as-is in the `invoke` message.
    pub args: serde_json::Value,
    /// Receives the node's result. It is dropped without a value if the request
    /// cannot be serialized or sent, or if the connection closes first.
    /// NOTE(review): nothing in this file times out a pending invocation; the
    /// receiver waits until the node replies or disconnects.
    pub response_tx: oneshot::Sender<NodeInvocationResult>,
}
71
/// Outcome of a node invocation, built from the node's `result` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInvocationResult {
    /// Whether the node reports that the capability ran successfully.
    pub success: bool,
    /// Output text returned by the node.
    pub output: String,
    /// Optional error detail supplied by the node.
    pub error: Option<String>,
}
79
80#[derive(Debug, Default, Clone)]
82pub struct NodeRegistry {
83 nodes: Arc<RwLock<HashMap<String, NodeInfo>>>,
84 max_nodes: usize,
85}
86
87impl NodeRegistry {
88 pub fn new(max_nodes: usize) -> Self {
90 Self {
91 nodes: Arc::new(RwLock::new(HashMap::new())),
92 max_nodes,
93 }
94 }
95
96 pub fn register(&self, info: NodeInfo) -> bool {
98 let mut nodes = self.nodes.write();
99 if nodes.len() >= self.max_nodes && !nodes.contains_key(&info.node_id) {
100 return false;
101 }
102 nodes.insert(info.node_id.clone(), info);
103 true
104 }
105
106 pub fn unregister(&self, node_id: &str) {
108 self.nodes.write().remove(node_id);
109 }
110
111 pub fn node_ids(&self) -> Vec<String> {
113 self.nodes.read().keys().cloned().collect()
114 }
115
116 pub fn all_capabilities(&self) -> Vec<(String, String, NodeCapability)> {
118 let nodes = self.nodes.read();
119 let mut caps = Vec::new();
120 for info in nodes.values() {
121 for cap in &info.capabilities {
122 caps.push((info.node_id.clone(), cap.name.clone(), cap.clone()));
123 }
124 }
125 caps
126 }
127
128 pub fn invoke_tx(&self, node_id: &str) -> Option<mpsc::Sender<NodeInvocation>> {
130 self.nodes.read().get(node_id).map(|n| n.invoke_tx.clone())
131 }
132
133 pub fn contains(&self, node_id: &str) -> bool {
135 self.nodes.read().contains_key(node_id)
136 }
137
138 pub fn len(&self) -> usize {
140 self.nodes.read().len()
141 }
142
143 pub fn is_empty(&self) -> bool {
145 self.nodes.read().is_empty()
146 }
147}
148
/// Messages a node sends to the gateway.
///
/// Each message is a JSON object tagged by a snake_case `"type"` field, e.g.
/// `{"type":"register","node_id":"phone-1","capabilities":[...]}` or
/// `{"type":"result","call_id":"abc","success":true,"output":"..."}`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum NodeMessage {
    /// Announces (or re-announces) the node's id and capabilities.
    Register {
        node_id: String,
        capabilities: Vec<NodeCapability>,
    },
    /// Reply to an earlier `invoke`, matched to it by `call_id`.
    Result {
        call_id: String,
        success: bool,
        output: String,
        // Optional on the wire; a missing `error` field deserializes as `None`.
        #[serde(default)]
        error: Option<String>,
    },
}
165
/// Messages the gateway sends to a node. Like `NodeMessage`, each is tagged by a
/// snake_case `"type"` field.
///
/// NOTE(review): only `Invoke` is constructed outside the tests in this file.
/// `Registered` and `Error` are never sent; confirm whether nodes expect an
/// acknowledgement after `register`.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum GatewayMessage {
    /// Acknowledges a successful registration.
    Registered {
        node_id: String,
        capabilities_count: usize,
    },
    /// Reports a problem to the node.
    Error {
        message: String,
    },
    /// Asks the node to run `capability` with `args`; the node answers with a
    /// `result` message carrying the same `call_id`.
    Invoke {
        call_id: String,
        capability: String,
        args: serde_json::Value,
    },
}
183
/// Query parameters accepted on the node WebSocket endpoint.
#[derive(Deserialize)]
pub struct NodeWsQuery {
    /// Fallback auth token (`?token=`). It is consulted only when neither the
    /// `Authorization` header nor a `bearer.` subprotocol entry supplies one.
    /// NOTE(review): tokens in query strings can end up in proxy/access logs.
    pub token: Option<String>,
}
189
190fn extract_node_ws_token<'a>(
192 headers: &'a HeaderMap,
193 query_token: Option<&'a str>,
194) -> Option<&'a str> {
195 if let Some(t) = headers
197 .get(header::AUTHORIZATION)
198 .and_then(|v| v.to_str().ok())
199 .and_then(|auth| auth.strip_prefix("Bearer "))
200 {
201 if !t.is_empty() {
202 return Some(t);
203 }
204 }
205
206 if let Some(t) = headers
208 .get("sec-websocket-protocol")
209 .and_then(|v| v.to_str().ok())
210 .and_then(|protos| {
211 protos
212 .split(',')
213 .map(|p| p.trim())
214 .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
215 })
216 {
217 if !t.is_empty() {
218 return Some(t);
219 }
220 }
221
222 if let Some(t) = query_token {
224 if !t.is_empty() {
225 return Some(t);
226 }
227 }
228
229 None
230}
231
232pub async fn handle_ws_nodes(
234 State(state): State<AppState>,
235 Query(params): Query<NodeWsQuery>,
236 headers: HeaderMap,
237 ws: WebSocketUpgrade,
238) -> impl IntoResponse {
239 let nodes_config = state.config.lock().nodes.clone();
241 if let Some(ref expected_token) = nodes_config.auth_token {
242 let token = extract_node_ws_token(&headers, params.token.as_deref()).unwrap_or("");
243 if token != expected_token {
244 return (
245 axum::http::StatusCode::UNAUTHORIZED,
246 "Unauthorized — provide a valid node auth token",
247 )
248 .into_response();
249 }
250 }
251
252 if nodes_config.auth_token.is_none() && state.pairing.require_pairing() {
254 let token = extract_node_ws_token(&headers, params.token.as_deref()).unwrap_or("");
255 if !state.pairing.is_authenticated(token) {
256 return (
257 axum::http::StatusCode::UNAUTHORIZED,
258 "Unauthorized — provide Authorization header or ?token= query param",
259 )
260 .into_response();
261 }
262 }
263
264 let ws = if headers
266 .get("sec-websocket-protocol")
267 .and_then(|v| v.to_str().ok())
268 .map_or(false, |protos| {
269 protos.split(',').any(|p| p.trim() == WS_NODE_PROTOCOL)
270 }) {
271 ws.protocols([WS_NODE_PROTOCOL])
272 } else {
273 ws
274 };
275
276 let registry = state.node_registry.clone();
277 ws.on_upgrade(move |socket| handle_node_socket(socket, registry))
278 .into_response()
279}
280
281async fn handle_node_socket(socket: WebSocket, registry: Arc<NodeRegistry>) {
282 let (mut sender, mut receiver) = socket.split();
283 let mut registered_node_id: Option<String> = None;
284
285 let (invoke_tx, mut invoke_rx) = mpsc::channel::<NodeInvocation>(32);
287
288 let pending: Arc<RwLock<HashMap<String, oneshot::Sender<NodeInvocationResult>>>> =
290 Arc::new(RwLock::new(HashMap::new()));
291
292 let pending_clone = Arc::clone(&pending);
293
294 let send_task = tokio::spawn(async move {
296 while let Some(invocation) = invoke_rx.recv().await {
297 let msg = GatewayMessage::Invoke {
298 call_id: invocation.call_id.clone(),
299 capability: invocation.capability,
300 args: invocation.args,
301 };
302 if let Ok(json) = serde_json::to_string(&msg) {
303 if sender.send(Message::Text(json.into())).await.is_err() {
304 break;
305 }
306 pending_clone
307 .write()
308 .insert(invocation.call_id, invocation.response_tx);
309 }
310 }
311 });
312
313 while let Some(msg) = receiver.next().await {
315 let text = match msg {
316 Ok(Message::Text(text)) => text,
317 Ok(Message::Close(_)) | Err(_) => break,
318 _ => continue,
319 };
320
321 let parsed: serde_json::Value = match serde_json::from_str(&text) {
322 Ok(v) => v,
323 Err(_) => continue,
324 };
325
326 let node_msg: NodeMessage = match serde_json::from_value(parsed) {
328 Ok(m) => m,
329 Err(_) => continue,
330 };
331
332 match node_msg {
333 NodeMessage::Register {
334 node_id,
335 capabilities,
336 } => {
337 if node_id.is_empty() || node_id.len() > 128 {
339 tracing::warn!("Node registration rejected: invalid node_id length");
340 continue;
341 }
342
343 let caps_count = capabilities.len();
344 let info = NodeInfo {
345 node_id: node_id.clone(),
346 capabilities,
347 invoke_tx: invoke_tx.clone(),
348 };
349
350 if registry.register(info) {
351 tracing::info!("Node registered: {node_id} with {caps_count} capabilities");
352 registered_node_id = Some(node_id.clone());
353
354 } else {
361 tracing::warn!(
362 "Node registration rejected: registry at capacity for {node_id}"
363 );
364 }
365 }
366 NodeMessage::Result {
367 call_id,
368 success,
369 output,
370 error,
371 } => {
372 if let Some(tx) = pending.write().remove(&call_id) {
373 let _ = tx.send(NodeInvocationResult {
374 success,
375 output,
376 error,
377 });
378 }
379 }
380 }
381 }
382
383 if let Some(node_id) = registered_node_id {
385 registry.unregister(&node_id);
386 tracing::info!("Node disconnected and unregistered: {node_id}");
387 }
388
389 send_task.abort();
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn node_registry_register_and_unregister() {
        let registry = NodeRegistry::new(10);
        let (tx, _rx) = mpsc::channel(1);

        let info = NodeInfo {
            node_id: "test-node".to_string(),
            capabilities: vec![NodeCapability {
                name: "ping".to_string(),
                description: "Ping test".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }],
            invoke_tx: tx,
        };

        assert!(registry.register(info));
        assert!(registry.contains("test-node"));
        assert_eq!(registry.len(), 1);

        registry.unregister("test-node");
        assert!(!registry.contains("test-node"));
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn node_registry_capacity_limit() {
        let registry = NodeRegistry::new(2);

        for i in 0..2 {
            let (tx, _rx) = mpsc::channel(1);
            let info = NodeInfo {
                node_id: format!("node-{i}"),
                capabilities: vec![],
                invoke_tx: tx,
            };
            assert!(registry.register(info));
        }

        let (tx, _rx) = mpsc::channel(1);
        let info = NodeInfo {
            node_id: "node-overflow".to_string(),
            capabilities: vec![],
            invoke_tx: tx,
        };
        assert!(!registry.register(info));
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn node_registry_re_register_same_id() {
        let registry = NodeRegistry::new(2);
        let (tx1, _rx1) = mpsc::channel(1);
        let (tx2, _rx2) = mpsc::channel(1);

        let info1 = NodeInfo {
            node_id: "node-1".to_string(),
            capabilities: vec![NodeCapability {
                name: "old".to_string(),
                description: "Old cap".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }],
            invoke_tx: tx1,
        };
        assert!(registry.register(info1));

        let info2 = NodeInfo {
            node_id: "node-1".to_string(),
            capabilities: vec![NodeCapability {
                name: "new".to_string(),
                description: "New cap".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }],
            invoke_tx: tx2,
        };
        assert!(registry.register(info2));
        assert_eq!(registry.len(), 1);

        let caps = registry.all_capabilities();
        assert_eq!(caps.len(), 1);
        assert_eq!(caps[0].2.name, "new");
    }

    #[test]
    fn node_registry_re_register_allowed_at_capacity() {
        // Replacing an existing id must succeed even when the registry is full.
        let registry = NodeRegistry::new(1);
        let (tx1, _rx1) = mpsc::channel(1);
        let (tx2, _rx2) = mpsc::channel(1);

        assert!(registry.register(NodeInfo {
            node_id: "only".to_string(),
            capabilities: vec![],
            invoke_tx: tx1,
        }));
        assert!(registry.register(NodeInfo {
            node_id: "only".to_string(),
            capabilities: vec![],
            invoke_tx: tx2,
        }));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn node_registry_all_capabilities() {
        let registry = NodeRegistry::new(10);
        let (tx1, _rx1) = mpsc::channel(1);
        let (tx2, _rx2) = mpsc::channel(1);

        registry.register(NodeInfo {
            node_id: "phone-1".to_string(),
            capabilities: vec![
                NodeCapability {
                    name: "camera.snap".to_string(),
                    description: "Take a photo".to_string(),
                    parameters: serde_json::json!({"type": "object", "properties": {}}),
                },
                NodeCapability {
                    name: "gps.location".to_string(),
                    description: "Get GPS location".to_string(),
                    parameters: serde_json::json!({"type": "object", "properties": {}}),
                },
            ],
            invoke_tx: tx1,
        });

        registry.register(NodeInfo {
            node_id: "sensor-1".to_string(),
            capabilities: vec![NodeCapability {
                name: "temp.read".to_string(),
                description: "Read temperature".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }],
            invoke_tx: tx2,
        });

        let caps = registry.all_capabilities();
        assert_eq!(caps.len(), 3);
    }

    #[test]
    fn node_registry_is_empty() {
        let registry = NodeRegistry::new(10);
        assert!(registry.is_empty());

        let (tx, _rx) = mpsc::channel(1);
        registry.register(NodeInfo {
            node_id: "n".to_string(),
            capabilities: vec![],
            invoke_tx: tx,
        });
        assert!(!registry.is_empty());
    }

    #[test]
    fn node_capability_deserialize() {
        let json = r#"{"name":"camera.snap","description":"Take a photo"}"#;
        let cap: NodeCapability = serde_json::from_str(json).unwrap();
        assert_eq!(cap.name, "camera.snap");
        assert_eq!(cap.description, "Take a photo");
        assert_eq!(cap.parameters["type"], "object");
    }

    #[test]
    fn node_message_register_deserialize() {
        let json = r#"{"type":"register","node_id":"phone-1","capabilities":[{"name":"camera.snap","description":"Take a photo","parameters":{"type":"object","properties":{"resolution":{"type":"string"}}}}]}"#;
        let msg: NodeMessage = serde_json::from_str(json).unwrap();
        match msg {
            NodeMessage::Register {
                node_id,
                capabilities,
            } => {
                assert_eq!(node_id, "phone-1");
                assert_eq!(capabilities.len(), 1);
                assert_eq!(capabilities[0].name, "camera.snap");
            }
            NodeMessage::Result { .. } => panic!("Expected Register message"),
        }
    }

    #[test]
    fn node_message_result_deserialize() {
        let json = r#"{"type":"result","call_id":"abc-123","success":true,"output":"photo taken"}"#;
        let msg: NodeMessage = serde_json::from_str(json).unwrap();
        match msg {
            NodeMessage::Result {
                call_id,
                success,
                output,
                error,
            } => {
                assert_eq!(call_id, "abc-123");
                assert!(success);
                assert_eq!(output, "photo taken");
                assert!(error.is_none());
            }
            NodeMessage::Register { .. } => panic!("Expected Result message"),
        }
    }

    #[test]
    fn node_message_unknown_type_rejected() {
        let json = r#"{"type":"bogus"}"#;
        assert!(serde_json::from_str::<NodeMessage>(json).is_err());
    }

    #[test]
    fn gateway_message_serialize() {
        let msg = GatewayMessage::Registered {
            node_id: "phone-1".to_string(),
            capabilities_count: 3,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"registered\""));
        assert!(json.contains("\"node_id\":\"phone-1\""));
        assert!(json.contains("\"capabilities_count\":3"));
    }

    #[test]
    fn gateway_error_message_serialize() {
        let msg = GatewayMessage::Error {
            message: "bad".to_string(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"type":"error","message":"bad"}"#);
    }

    #[test]
    fn gateway_invoke_message_serialize() {
        let msg = GatewayMessage::Invoke {
            call_id: "call-1".to_string(),
            capability: "camera.snap".to_string(),
            args: serde_json::json!({"resolution": "1080p"}),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"invoke\""));
        assert!(json.contains("\"capability\":\"camera.snap\""));
    }

    #[test]
    fn extract_node_ws_token_from_header() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer node_tok_123".parse().unwrap());
        assert_eq!(extract_node_ws_token(&headers, None), Some("node_tok_123"));
    }

    #[test]
    fn extract_node_ws_token_from_subprotocol() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "sec-websocket-protocol",
            "construct.nodes.v1, bearer.node_tok_789".parse().unwrap(),
        );
        assert_eq!(extract_node_ws_token(&headers, None), Some("node_tok_789"));
    }

    #[test]
    fn extract_node_ws_token_prefers_authorization_header() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer from_header".parse().unwrap());
        headers.insert("sec-websocket-protocol", "bearer.from_proto".parse().unwrap());
        assert_eq!(
            extract_node_ws_token(&headers, Some("from_query")),
            Some("from_header")
        );
    }

    #[test]
    fn extract_node_ws_token_skips_empty_sources() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer ".parse().unwrap());
        assert_eq!(extract_node_ws_token(&headers, Some("")), None);
        assert_eq!(extract_node_ws_token(&headers, Some("q")), Some("q"));
    }

    #[test]
    fn extract_node_ws_token_from_query() {
        let headers = HeaderMap::new();
        assert_eq!(
            extract_node_ws_token(&headers, Some("node_tok_456")),
            Some("node_tok_456")
        );
    }

    #[test]
    fn extract_node_ws_token_none_when_empty() {
        let headers = HeaderMap::new();
        assert_eq!(extract_node_ws_token(&headers, None), None);
    }
}
622}