Skip to main content

construct/gateway/
nodes.rs

1//! WebSocket endpoint for dynamic node discovery and capability advertisement.
2//!
3//! External processes/devices connect to `/ws/nodes` and advertise their
4//! capabilities at runtime. The gateway exposes these as dynamically available
5//! tools to the agent.
6//!
7//! ## Protocol
8//!
9//! ```text
10//! Node -> Gateway: {"type":"register","node_id":"phone-1","capabilities":[{"name":"camera.snap","description":"Take a photo","parameters":{...}}]}
11//! Gateway -> Node: {"type":"registered","node_id":"phone-1","capabilities_count":1}
12//! Gateway -> Node: {"type":"invoke","call_id":"uuid","capability":"camera.snap","args":{...}}
13//! Node -> Gateway: {"type":"result","call_id":"uuid","success":true,"output":"..."}
14//! ```
15
16use super::AppState;
17use axum::{
18    extract::{
19        Query, State, WebSocketUpgrade,
20        ws::{Message, WebSocket},
21    },
22    http::{HeaderMap, header},
23    response::IntoResponse,
24};
25use futures_util::{SinkExt, StreamExt};
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29use std::sync::Arc;
30use tokio::sync::{mpsc, oneshot};
31
32/// Prefix used in `Sec-WebSocket-Protocol` to carry a bearer token.
33const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
34
35/// The sub-protocol we support for node connections.
36const WS_NODE_PROTOCOL: &str = "construct.nodes.v1";
37
38/// A single capability advertised by a node.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct NodeCapability {
41    pub name: String,
42    pub description: String,
43    #[serde(default = "default_capability_parameters")]
44    pub parameters: serde_json::Value,
45}
46
47fn default_capability_parameters() -> serde_json::Value {
48    serde_json::json!({
49        "type": "object",
50        "properties": {}
51    })
52}
53
54/// Tracks a connected node and its capabilities.
55#[derive(Debug, Clone)]
56pub struct NodeInfo {
57    pub node_id: String,
58    pub capabilities: Vec<NodeCapability>,
59    /// Channel to send invocation requests to the node's WebSocket handler.
60    pub invoke_tx: mpsc::Sender<NodeInvocation>,
61}
62
63/// An invocation request sent to a node.
64#[derive(Debug)]
65pub struct NodeInvocation {
66    pub call_id: String,
67    pub capability: String,
68    pub args: serde_json::Value,
69    pub response_tx: oneshot::Sender<NodeInvocationResult>,
70}
71
72/// The result of a node invocation.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct NodeInvocationResult {
75    pub success: bool,
76    pub output: String,
77    pub error: Option<String>,
78}
79
80/// Registry of all connected nodes and their capabilities.
81#[derive(Debug, Default, Clone)]
82pub struct NodeRegistry {
83    nodes: Arc<RwLock<HashMap<String, NodeInfo>>>,
84    max_nodes: usize,
85}
86
87impl NodeRegistry {
88    /// Create a new registry with the given capacity limit.
89    pub fn new(max_nodes: usize) -> Self {
90        Self {
91            nodes: Arc::new(RwLock::new(HashMap::new())),
92            max_nodes,
93        }
94    }
95
96    /// Register a node with its capabilities. Returns false if at capacity.
97    pub fn register(&self, info: NodeInfo) -> bool {
98        let mut nodes = self.nodes.write();
99        if nodes.len() >= self.max_nodes && !nodes.contains_key(&info.node_id) {
100            return false;
101        }
102        nodes.insert(info.node_id.clone(), info);
103        true
104    }
105
106    /// Remove a node from the registry.
107    pub fn unregister(&self, node_id: &str) {
108        self.nodes.write().remove(node_id);
109    }
110
111    /// List all registered node IDs.
112    pub fn node_ids(&self) -> Vec<String> {
113        self.nodes.read().keys().cloned().collect()
114    }
115
116    /// Get all capabilities across all nodes, keyed by prefixed tool name.
117    pub fn all_capabilities(&self) -> Vec<(String, String, NodeCapability)> {
118        let nodes = self.nodes.read();
119        let mut caps = Vec::new();
120        for info in nodes.values() {
121            for cap in &info.capabilities {
122                caps.push((info.node_id.clone(), cap.name.clone(), cap.clone()));
123            }
124        }
125        caps
126    }
127
128    /// Get the invocation sender for a specific node.
129    pub fn invoke_tx(&self, node_id: &str) -> Option<mpsc::Sender<NodeInvocation>> {
130        self.nodes.read().get(node_id).map(|n| n.invoke_tx.clone())
131    }
132
133    /// Check if a node is registered.
134    pub fn contains(&self, node_id: &str) -> bool {
135        self.nodes.read().contains_key(node_id)
136    }
137
138    /// Number of registered nodes.
139    pub fn len(&self) -> usize {
140        self.nodes.read().len()
141    }
142
143    /// Whether the registry is empty.
144    pub fn is_empty(&self) -> bool {
145        self.nodes.read().is_empty()
146    }
147}
148
149/// Messages received from a node.
150#[derive(Debug, Deserialize)]
151#[serde(tag = "type", rename_all = "snake_case")]
152enum NodeMessage {
153    Register {
154        node_id: String,
155        capabilities: Vec<NodeCapability>,
156    },
157    Result {
158        call_id: String,
159        success: bool,
160        output: String,
161        #[serde(default)]
162        error: Option<String>,
163    },
164}
165
166/// Messages sent to a node.
167#[derive(Debug, Serialize)]
168#[serde(tag = "type", rename_all = "snake_case")]
169enum GatewayMessage {
170    Registered {
171        node_id: String,
172        capabilities_count: usize,
173    },
174    Error {
175        message: String,
176    },
177    Invoke {
178        call_id: String,
179        capability: String,
180        args: serde_json::Value,
181    },
182}
183
184/// Query parameters for the `/ws/nodes` endpoint.
185#[derive(Deserialize)]
186pub struct NodeWsQuery {
187    pub token: Option<String>,
188}
189
190/// Extract a bearer token from WebSocket-compatible sources.
191fn extract_node_ws_token<'a>(
192    headers: &'a HeaderMap,
193    query_token: Option<&'a str>,
194) -> Option<&'a str> {
195    // 1. Authorization header
196    if let Some(t) = headers
197        .get(header::AUTHORIZATION)
198        .and_then(|v| v.to_str().ok())
199        .and_then(|auth| auth.strip_prefix("Bearer "))
200    {
201        if !t.is_empty() {
202            return Some(t);
203        }
204    }
205
206    // 2. Sec-WebSocket-Protocol: bearer.<token>
207    if let Some(t) = headers
208        .get("sec-websocket-protocol")
209        .and_then(|v| v.to_str().ok())
210        .and_then(|protos| {
211            protos
212                .split(',')
213                .map(|p| p.trim())
214                .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
215        })
216    {
217        if !t.is_empty() {
218            return Some(t);
219        }
220    }
221
222    // 3. ?token= query parameter
223    if let Some(t) = query_token {
224        if !t.is_empty() {
225            return Some(t);
226        }
227    }
228
229    None
230}
231
232/// GET /ws/nodes — WebSocket upgrade for node connections
233pub async fn handle_ws_nodes(
234    State(state): State<AppState>,
235    Query(params): Query<NodeWsQuery>,
236    headers: HeaderMap,
237    ws: WebSocketUpgrade,
238) -> impl IntoResponse {
239    // Auth: check node auth token if configured
240    let nodes_config = state.config.lock().nodes.clone();
241    if let Some(ref expected_token) = nodes_config.auth_token {
242        let token = extract_node_ws_token(&headers, params.token.as_deref()).unwrap_or("");
243        if token != expected_token {
244            return (
245                axum::http::StatusCode::UNAUTHORIZED,
246                "Unauthorized — provide a valid node auth token",
247            )
248                .into_response();
249        }
250    }
251
252    // Fall back to pairing auth if no node-specific token
253    if nodes_config.auth_token.is_none() && state.pairing.require_pairing() {
254        let token = extract_node_ws_token(&headers, params.token.as_deref()).unwrap_or("");
255        if !state.pairing.is_authenticated(token) {
256            return (
257                axum::http::StatusCode::UNAUTHORIZED,
258                "Unauthorized — provide Authorization header or ?token= query param",
259            )
260                .into_response();
261        }
262    }
263
264    // Echo sub-protocol if client requests it
265    let ws = if headers
266        .get("sec-websocket-protocol")
267        .and_then(|v| v.to_str().ok())
268        .map_or(false, |protos| {
269            protos.split(',').any(|p| p.trim() == WS_NODE_PROTOCOL)
270        }) {
271        ws.protocols([WS_NODE_PROTOCOL])
272    } else {
273        ws
274    };
275
276    let registry = state.node_registry.clone();
277    ws.on_upgrade(move |socket| handle_node_socket(socket, registry))
278        .into_response()
279}
280
281async fn handle_node_socket(socket: WebSocket, registry: Arc<NodeRegistry>) {
282    let (mut sender, mut receiver) = socket.split();
283    let mut registered_node_id: Option<String> = None;
284
285    // Channel for forwarding invocations to this node
286    let (invoke_tx, mut invoke_rx) = mpsc::channel::<NodeInvocation>(32);
287
288    // Pending invocation responses keyed by call_id
289    let pending: Arc<RwLock<HashMap<String, oneshot::Sender<NodeInvocationResult>>>> =
290        Arc::new(RwLock::new(HashMap::new()));
291
292    let pending_clone = Arc::clone(&pending);
293
294    // Task to forward invocations to the node via WebSocket
295    let send_task = tokio::spawn(async move {
296        while let Some(invocation) = invoke_rx.recv().await {
297            let msg = GatewayMessage::Invoke {
298                call_id: invocation.call_id.clone(),
299                capability: invocation.capability,
300                args: invocation.args,
301            };
302            if let Ok(json) = serde_json::to_string(&msg) {
303                if sender.send(Message::Text(json.into())).await.is_err() {
304                    break;
305                }
306                pending_clone
307                    .write()
308                    .insert(invocation.call_id, invocation.response_tx);
309            }
310        }
311    });
312
313    // Process incoming messages from node
314    while let Some(msg) = receiver.next().await {
315        let text = match msg {
316            Ok(Message::Text(text)) => text,
317            Ok(Message::Close(_)) | Err(_) => break,
318            _ => continue,
319        };
320
321        let parsed: serde_json::Value = match serde_json::from_str(&text) {
322            Ok(v) => v,
323            Err(_) => continue,
324        };
325
326        // Try to parse as NodeMessage
327        let node_msg: NodeMessage = match serde_json::from_value(parsed) {
328            Ok(m) => m,
329            Err(_) => continue,
330        };
331
332        match node_msg {
333            NodeMessage::Register {
334                node_id,
335                capabilities,
336            } => {
337                // Validate node_id
338                if node_id.is_empty() || node_id.len() > 128 {
339                    tracing::warn!("Node registration rejected: invalid node_id length");
340                    continue;
341                }
342
343                let caps_count = capabilities.len();
344                let info = NodeInfo {
345                    node_id: node_id.clone(),
346                    capabilities,
347                    invoke_tx: invoke_tx.clone(),
348                };
349
350                if registry.register(info) {
351                    tracing::info!("Node registered: {node_id} with {caps_count} capabilities");
352                    registered_node_id = Some(node_id.clone());
353
354                    // Send ack — we can't use `sender` here since it's moved
355                    // into the send task. Instead, send ack via the invoke channel
356                    // pattern isn't ideal. We'll use a workaround: send the ack
357                    // through a special invocation that the send task converts to
358                    // a registered message. For simplicity, we just log and the
359                    // ack is implicit in the protocol.
360                } else {
361                    tracing::warn!(
362                        "Node registration rejected: registry at capacity for {node_id}"
363                    );
364                }
365            }
366            NodeMessage::Result {
367                call_id,
368                success,
369                output,
370                error,
371            } => {
372                if let Some(tx) = pending.write().remove(&call_id) {
373                    let _ = tx.send(NodeInvocationResult {
374                        success,
375                        output,
376                        error,
377                    });
378                }
379            }
380        }
381    }
382
383    // Cleanup: unregister node on disconnect
384    if let Some(node_id) = registered_node_id {
385        registry.unregister(&node_id);
386        tracing::info!("Node disconnected and unregistered: {node_id}");
387    }
388
389    send_task.abort();
390}
391
392#[cfg(test)]
393mod tests {
394    use super::*;
395
396    #[test]
397    fn node_registry_register_and_unregister() {
398        let registry = NodeRegistry::new(10);
399        let (tx, _rx) = mpsc::channel(1);
400
401        let info = NodeInfo {
402            node_id: "test-node".to_string(),
403            capabilities: vec![NodeCapability {
404                name: "ping".to_string(),
405                description: "Ping test".to_string(),
406                parameters: serde_json::json!({"type": "object", "properties": {}}),
407            }],
408            invoke_tx: tx,
409        };
410
411        assert!(registry.register(info));
412        assert!(registry.contains("test-node"));
413        assert_eq!(registry.len(), 1);
414
415        registry.unregister("test-node");
416        assert!(!registry.contains("test-node"));
417        assert_eq!(registry.len(), 0);
418    }
419
420    #[test]
421    fn node_registry_capacity_limit() {
422        let registry = NodeRegistry::new(2);
423
424        for i in 0..2 {
425            let (tx, _rx) = mpsc::channel(1);
426            let info = NodeInfo {
427                node_id: format!("node-{i}"),
428                capabilities: vec![],
429                invoke_tx: tx,
430            };
431            assert!(registry.register(info));
432        }
433
434        let (tx, _rx) = mpsc::channel(1);
435        let info = NodeInfo {
436            node_id: "node-overflow".to_string(),
437            capabilities: vec![],
438            invoke_tx: tx,
439        };
440        assert!(!registry.register(info));
441        assert_eq!(registry.len(), 2);
442    }
443
444    #[test]
445    fn node_registry_re_register_same_id() {
446        let registry = NodeRegistry::new(2);
447        let (tx1, _rx1) = mpsc::channel(1);
448        let (tx2, _rx2) = mpsc::channel(1);
449
450        let info1 = NodeInfo {
451            node_id: "node-1".to_string(),
452            capabilities: vec![NodeCapability {
453                name: "old".to_string(),
454                description: "Old cap".to_string(),
455                parameters: serde_json::json!({"type": "object", "properties": {}}),
456            }],
457            invoke_tx: tx1,
458        };
459        assert!(registry.register(info1));
460
461        let info2 = NodeInfo {
462            node_id: "node-1".to_string(),
463            capabilities: vec![NodeCapability {
464                name: "new".to_string(),
465                description: "New cap".to_string(),
466                parameters: serde_json::json!({"type": "object", "properties": {}}),
467            }],
468            invoke_tx: tx2,
469        };
470        // Re-registering same node_id should succeed (update)
471        assert!(registry.register(info2));
472        assert_eq!(registry.len(), 1);
473
474        let caps = registry.all_capabilities();
475        assert_eq!(caps.len(), 1);
476        assert_eq!(caps[0].2.name, "new");
477    }
478
479    #[test]
480    fn node_registry_all_capabilities() {
481        let registry = NodeRegistry::new(10);
482        let (tx1, _rx1) = mpsc::channel(1);
483        let (tx2, _rx2) = mpsc::channel(1);
484
485        registry.register(NodeInfo {
486            node_id: "phone-1".to_string(),
487            capabilities: vec![
488                NodeCapability {
489                    name: "camera.snap".to_string(),
490                    description: "Take a photo".to_string(),
491                    parameters: serde_json::json!({"type": "object", "properties": {}}),
492                },
493                NodeCapability {
494                    name: "gps.location".to_string(),
495                    description: "Get GPS location".to_string(),
496                    parameters: serde_json::json!({"type": "object", "properties": {}}),
497                },
498            ],
499            invoke_tx: tx1,
500        });
501
502        registry.register(NodeInfo {
503            node_id: "sensor-1".to_string(),
504            capabilities: vec![NodeCapability {
505                name: "temp.read".to_string(),
506                description: "Read temperature".to_string(),
507                parameters: serde_json::json!({"type": "object", "properties": {}}),
508            }],
509            invoke_tx: tx2,
510        });
511
512        let caps = registry.all_capabilities();
513        assert_eq!(caps.len(), 3);
514    }
515
516    #[test]
517    fn node_registry_is_empty() {
518        let registry = NodeRegistry::new(10);
519        assert!(registry.is_empty());
520
521        let (tx, _rx) = mpsc::channel(1);
522        registry.register(NodeInfo {
523            node_id: "n".to_string(),
524            capabilities: vec![],
525            invoke_tx: tx,
526        });
527        assert!(!registry.is_empty());
528    }
529
530    #[test]
531    fn node_capability_deserialize() {
532        let json = r#"{"name":"camera.snap","description":"Take a photo"}"#;
533        let cap: NodeCapability = serde_json::from_str(json).unwrap();
534        assert_eq!(cap.name, "camera.snap");
535        assert_eq!(cap.description, "Take a photo");
536        // Default parameters
537        assert_eq!(cap.parameters["type"], "object");
538    }
539
540    #[test]
541    fn node_message_register_deserialize() {
542        let json = r#"{"type":"register","node_id":"phone-1","capabilities":[{"name":"camera.snap","description":"Take a photo","parameters":{"type":"object","properties":{"resolution":{"type":"string"}}}}]}"#;
543        let msg: NodeMessage = serde_json::from_str(json).unwrap();
544        match msg {
545            NodeMessage::Register {
546                node_id,
547                capabilities,
548            } => {
549                assert_eq!(node_id, "phone-1");
550                assert_eq!(capabilities.len(), 1);
551                assert_eq!(capabilities[0].name, "camera.snap");
552            }
553            NodeMessage::Result { .. } => panic!("Expected Register message"),
554        }
555    }
556
557    #[test]
558    fn node_message_result_deserialize() {
559        let json = r#"{"type":"result","call_id":"abc-123","success":true,"output":"photo taken"}"#;
560        let msg: NodeMessage = serde_json::from_str(json).unwrap();
561        match msg {
562            NodeMessage::Result {
563                call_id,
564                success,
565                output,
566                error,
567            } => {
568                assert_eq!(call_id, "abc-123");
569                assert!(success);
570                assert_eq!(output, "photo taken");
571                assert!(error.is_none());
572            }
573            NodeMessage::Register { .. } => panic!("Expected Result message"),
574        }
575    }
576
577    #[test]
578    fn gateway_message_serialize() {
579        let msg = GatewayMessage::Registered {
580            node_id: "phone-1".to_string(),
581            capabilities_count: 3,
582        };
583        let json = serde_json::to_string(&msg).unwrap();
584        assert!(json.contains("\"type\":\"registered\""));
585        assert!(json.contains("\"node_id\":\"phone-1\""));
586        assert!(json.contains("\"capabilities_count\":3"));
587    }
588
589    #[test]
590    fn gateway_invoke_message_serialize() {
591        let msg = GatewayMessage::Invoke {
592            call_id: "call-1".to_string(),
593            capability: "camera.snap".to_string(),
594            args: serde_json::json!({"resolution": "1080p"}),
595        };
596        let json = serde_json::to_string(&msg).unwrap();
597        assert!(json.contains("\"type\":\"invoke\""));
598        assert!(json.contains("\"capability\":\"camera.snap\""));
599    }
600
601    #[test]
602    fn extract_node_ws_token_from_header() {
603        let mut headers = HeaderMap::new();
604        headers.insert("authorization", "Bearer node_tok_123".parse().unwrap());
605        assert_eq!(extract_node_ws_token(&headers, None), Some("node_tok_123"));
606    }
607
608    #[test]
609    fn extract_node_ws_token_from_query() {
610        let headers = HeaderMap::new();
611        assert_eq!(
612            extract_node_ws_token(&headers, Some("node_tok_456")),
613            Some("node_tok_456")
614        );
615    }
616
617    #[test]
618    fn extract_node_ws_token_none_when_empty() {
619        let headers = HeaderMap::new();
620        assert_eq!(extract_node_ws_token(&headers, None), None);
621    }
622}