Skip to main content

barbacane_lib/
control_plane.rs

1//! Control plane client for connected mode.
2//!
3//! This module handles WebSocket communication with the control plane,
4//! including registration, heartbeat, and artifact notifications.
5
6use futures_util::{SinkExt, StreamExt};
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9use tokio::sync::{mpsc, watch};
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11use uuid::Uuid;
12
13/// Messages sent from data plane to control plane.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(tag = "type", rename_all = "snake_case")]
16pub enum DataPlaneMessage {
17    /// Initial registration with authentication.
18    Register {
19        project_id: Uuid,
20        api_key: String,
21        #[serde(skip_serializing_if = "Option::is_none")]
22        name: Option<String>,
23        #[serde(skip_serializing_if = "Option::is_none")]
24        artifact_id: Option<Uuid>,
25        #[serde(default)]
26        metadata: serde_json::Value,
27    },
28    /// Periodic heartbeat.
29    Heartbeat {
30        #[serde(skip_serializing_if = "Option::is_none")]
31        artifact_id: Option<Uuid>,
32        uptime_secs: u64,
33        requests_total: u64,
34    },
35    /// Acknowledgment of artifact download.
36    ArtifactDownloaded {
37        artifact_id: Uuid,
38        success: bool,
39        #[serde(skip_serializing_if = "Option::is_none")]
40        error: Option<String>,
41    },
42}
43
44/// Messages sent from control plane to data plane.
45#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(tag = "type", rename_all = "snake_case")]
47pub enum ControlPlaneMessage {
48    /// Registration successful.
49    Registered {
50        data_plane_id: Uuid,
51        heartbeat_interval_secs: u32,
52    },
53    /// Registration failed.
54    RegistrationFailed { reason: String },
55    /// New artifact available for download.
56    ArtifactAvailable {
57        artifact_id: Uuid,
58        download_url: String,
59        sha256: String,
60    },
61    /// Heartbeat acknowledgment.
62    HeartbeatAck,
63    /// Request disconnect.
64    Disconnect { reason: String },
65    /// Error message.
66    Error { message: String },
67}
68
69/// Configuration for the control plane client.
70#[derive(Clone)]
71pub struct ControlPlaneConfig {
72    pub control_plane_url: String,
73    pub project_id: Uuid,
74    pub api_key: String,
75    pub data_plane_name: Option<String>,
76    pub initial_artifact_id: Option<Uuid>,
77}
78
79/// Notification that a new artifact is available.
80#[derive(Debug, Clone)]
81pub struct ArtifactNotification {
82    pub artifact_id: Uuid,
83    pub download_url: String,
84    pub sha256: String,
85}
86
87/// Response to send back to the control plane after downloading an artifact.
88#[derive(Debug, Clone)]
89pub struct ArtifactDownloadedResponse {
90    pub artifact_id: Uuid,
91    pub success: bool,
92    pub error: Option<String>,
93}
94
/// How a single connection attempt ended; drives the reconnect policy of the
/// connection loop.
enum ConnectOutcome {
    /// The shutdown signal fired, so the loop must stop.
    Shutdown,
    /// Registration succeeded and the link dropped afterwards; the backoff is
    /// reset. Carries a description of the cause.
    ConnectionLost(String),
    /// The attempt failed before registration completed; the backoff keeps
    /// growing. Carries a description of the cause.
    ConnectionFailed(String),
}
104
105/// Control plane client that maintains connection and handles messages.
106pub struct ControlPlaneClient {
107    config: ControlPlaneConfig,
108}
109
110impl ControlPlaneClient {
111    /// Create a new control plane client.
112    pub fn new(config: ControlPlaneConfig) -> Self {
113        Self { config }
114    }
115
116    /// Start the connection loop in a background task.
117    /// Returns a receiver for artifact notifications and a sender for download responses.
118    pub fn start(
119        self,
120        shutdown_rx: watch::Receiver<bool>,
121    ) -> (
122        mpsc::Receiver<ArtifactNotification>,
123        mpsc::Sender<ArtifactDownloadedResponse>,
124    ) {
125        let (artifact_tx, artifact_rx) = mpsc::channel::<ArtifactNotification>(16);
126        let (response_tx, response_rx) = mpsc::channel::<ArtifactDownloadedResponse>(16);
127
128        tokio::spawn(async move {
129            self.connection_loop(shutdown_rx, artifact_tx, response_rx)
130                .await;
131        });
132
133        (artifact_rx, response_tx)
134    }
135
136    /// Main connection loop with reconnection logic.
137    async fn connection_loop(
138        &self,
139        mut shutdown_rx: watch::Receiver<bool>,
140        artifact_tx: mpsc::Sender<ArtifactNotification>,
141        mut response_rx: mpsc::Receiver<ArtifactDownloadedResponse>,
142    ) {
143        const INITIAL_BACKOFF_MS: u64 = 1000;
144        const MAX_BACKOFF_MS: u64 = 60000;
145        const BACKOFF_MULTIPLIER: f64 = 2.0;
146
147        let mut backoff_ms = INITIAL_BACKOFF_MS;
148
149        loop {
150            // Check for shutdown
151            if *shutdown_rx.borrow() {
152                tracing::info!("Control plane client shutting down");
153                return;
154            }
155
156            tracing::info!(url = %self.config.control_plane_url, "Connecting to control plane");
157
158            match self
159                .try_connect(&mut shutdown_rx, &artifact_tx, &mut response_rx)
160                .await
161            {
162                ConnectOutcome::Shutdown => {
163                    return;
164                }
165                ConnectOutcome::ConnectionLost(e) => {
166                    // Was registered, connection dropped — reset backoff for fast reconnect
167                    tracing::warn!(
168                        error = %e,
169                        "Control plane connection lost, reconnecting immediately"
170                    );
171                    backoff_ms = INITIAL_BACKOFF_MS;
172                }
173                ConnectOutcome::ConnectionFailed(e) => {
174                    tracing::warn!(
175                        error = %e,
176                        backoff_ms = backoff_ms,
177                        "Control plane connection failed, will retry"
178                    );
179                }
180            }
181
182            // Wait before reconnecting (or abort if shutdown)
183            tokio::select! {
184                _ = shutdown_rx.changed() => {
185                    if *shutdown_rx.borrow() {
186                        return;
187                    }
188                }
189                _ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {}
190            }
191
192            // Increase backoff for next attempt, capped at MAX_BACKOFF_MS
193            backoff_ms =
194                ((backoff_ms as f64) * BACKOFF_MULTIPLIER).min(MAX_BACKOFF_MS as f64) as u64;
195        }
196    }
197
198    /// Attempt to connect and handle messages.
199    async fn try_connect(
200        &self,
201        shutdown_rx: &mut watch::Receiver<bool>,
202        artifact_tx: &mpsc::Sender<ArtifactNotification>,
203        response_rx: &mut mpsc::Receiver<ArtifactDownloadedResponse>,
204    ) -> ConnectOutcome {
205        // Connect to WebSocket
206        let (ws_stream, _response) = match connect_async(&self.config.control_plane_url).await {
207            Ok(conn) => conn,
208            Err(e) => {
209                return ConnectOutcome::ConnectionFailed(format!(
210                    "WebSocket connection failed: {}",
211                    e
212                ))
213            }
214        };
215
216        let (mut sender, mut receiver) = ws_stream.split();
217
218        // Send registration message
219        let register_msg = DataPlaneMessage::Register {
220            project_id: self.config.project_id,
221            api_key: self.config.api_key.clone(),
222            name: self.config.data_plane_name.clone(),
223            artifact_id: self.config.initial_artifact_id,
224            metadata: serde_json::json!({}),
225        };
226
227        let register_json = match serde_json::to_string(&register_msg) {
228            Ok(j) => j,
229            Err(e) => {
230                return ConnectOutcome::ConnectionFailed(format!(
231                    "Failed to serialize register message: {}",
232                    e
233                ))
234            }
235        };
236
237        if let Err(e) = sender.send(Message::Text(register_json.into())).await {
238            return ConnectOutcome::ConnectionFailed(format!(
239                "Failed to send register message: {}",
240                e
241            ));
242        }
243
244        // Wait for registration response
245        let registration_response =
246            match tokio::time::timeout(Duration::from_secs(30), receiver.next()).await {
247                Ok(Some(Ok(msg))) => msg,
248                Ok(Some(Err(e))) => {
249                    return ConnectOutcome::ConnectionFailed(format!("WebSocket error: {}", e))
250                }
251                Ok(None) => {
252                    return ConnectOutcome::ConnectionFailed(
253                        "Connection closed before registration".to_string(),
254                    )
255                }
256                Err(_) => {
257                    return ConnectOutcome::ConnectionFailed("Registration timeout".to_string())
258                }
259            };
260
261        let heartbeat_interval_secs = match registration_response {
262            Message::Text(text) => {
263                let msg: ControlPlaneMessage = match serde_json::from_str(&text) {
264                    Ok(m) => m,
265                    Err(e) => {
266                        return ConnectOutcome::ConnectionFailed(format!(
267                            "Failed to parse registration response: {}",
268                            e
269                        ))
270                    }
271                };
272
273                match msg {
274                    ControlPlaneMessage::Registered {
275                        data_plane_id,
276                        heartbeat_interval_secs,
277                    } => {
278                        tracing::info!(
279                            data_plane_id = %data_plane_id,
280                            heartbeat_interval_secs,
281                            "Registered with control plane"
282                        );
283                        heartbeat_interval_secs
284                    }
285                    ControlPlaneMessage::RegistrationFailed { reason } => {
286                        return ConnectOutcome::ConnectionFailed(format!(
287                            "Registration failed: {}",
288                            reason
289                        ));
290                    }
291                    other => {
292                        return ConnectOutcome::ConnectionFailed(format!(
293                            "Unexpected registration response: {:?}",
294                            other
295                        ));
296                    }
297                }
298            }
299            other => {
300                return ConnectOutcome::ConnectionFailed(format!(
301                    "Unexpected message type: {:?}",
302                    other
303                ));
304            }
305        };
306
307        // Start heartbeat timer
308        let mut heartbeat_interval =
309            tokio::time::interval(Duration::from_secs(heartbeat_interval_secs as u64));
310        let start_time = std::time::Instant::now();
311
312        // Main message loop — we are now registered, so any disconnect
313        // should trigger reconnection (ConnectionLost), not give up.
314        loop {
315            tokio::select! {
316                // Shutdown signal
317                _ = shutdown_rx.changed() => {
318                    if *shutdown_rx.borrow() {
319                        tracing::info!("Disconnecting from control plane");
320                        let _ = sender.close().await;
321                        return ConnectOutcome::Shutdown;
322                    }
323                }
324
325                // Heartbeat timer
326                _ = heartbeat_interval.tick() => {
327                    let heartbeat = DataPlaneMessage::Heartbeat {
328                        artifact_id: None, // TODO: pass current artifact ID
329                        uptime_secs: start_time.elapsed().as_secs(),
330                        requests_total: 0, // TODO: pass actual metrics
331                    };
332
333                    let json = match serde_json::to_string(&heartbeat) {
334                        Ok(j) => j,
335                        Err(e) => {
336                            tracing::error!(error = %e, "Failed to serialize heartbeat");
337                            continue;
338                        }
339                    };
340
341                    if let Err(e) = sender.send(Message::Text(json.into())).await {
342                        return ConnectOutcome::ConnectionLost(format!(
343                            "Failed to send heartbeat: {}", e
344                        ));
345                    }
346
347                    tracing::debug!("Heartbeat sent");
348                }
349
350                // Artifact download response from main loop
351                Some(response) = response_rx.recv() => {
352                    let msg = DataPlaneMessage::ArtifactDownloaded {
353                        artifact_id: response.artifact_id,
354                        success: response.success,
355                        error: response.error,
356                    };
357
358                    let json = match serde_json::to_string(&msg) {
359                        Ok(j) => j,
360                        Err(e) => {
361                            tracing::error!(error = %e, "Failed to serialize artifact downloaded");
362                            continue;
363                        }
364                    };
365
366                    if let Err(e) = sender.send(Message::Text(json.into())).await {
367                        tracing::warn!(error = %e, "Failed to send artifact downloaded response");
368                    } else {
369                        tracing::info!(
370                            artifact_id = %response.artifact_id,
371                            success = response.success,
372                            "Sent artifact downloaded response to control plane"
373                        );
374                    }
375                }
376
377                // Incoming messages
378                result = receiver.next() => {
379                    match result {
380                        Some(Ok(Message::Text(text))) => {
381                            match serde_json::from_str::<ControlPlaneMessage>(&text) {
382                                Ok(msg) => {
383                                    if let Err(e) = self.handle_message(msg, artifact_tx, &mut sender).await {
384                                        tracing::warn!(error = %e, "Error handling control plane message");
385                                    }
386                                }
387                                Err(e) => {
388                                    tracing::warn!(error = %e, "Failed to parse control plane message");
389                                }
390                            }
391                        }
392                        Some(Ok(Message::Ping(data))) => {
393                            let _ = sender.send(Message::Pong(data)).await;
394                        }
395                        Some(Ok(Message::Close(_))) | None => {
396                            return ConnectOutcome::ConnectionLost(
397                                "Connection closed by control plane".to_string()
398                            );
399                        }
400                        Some(Err(e)) => {
401                            return ConnectOutcome::ConnectionLost(format!(
402                                "WebSocket error: {}", e
403                            ));
404                        }
405                        _ => {}
406                    }
407                }
408            }
409        }
410    }
411
412    /// Handle a message from the control plane.
413    async fn handle_message(
414        &self,
415        msg: ControlPlaneMessage,
416        artifact_tx: &mpsc::Sender<ArtifactNotification>,
417        _sender: &mut futures_util::stream::SplitSink<
418            tokio_tungstenite::WebSocketStream<
419                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
420            >,
421            Message,
422        >,
423    ) -> Result<(), String> {
424        match msg {
425            ControlPlaneMessage::HeartbeatAck => {
426                tracing::debug!("Heartbeat acknowledged");
427            }
428            ControlPlaneMessage::ArtifactAvailable {
429                artifact_id,
430                download_url,
431                sha256,
432            } => {
433                tracing::info!(
434                    artifact_id = %artifact_id,
435                    download_url = %download_url,
436                    "New artifact available"
437                );
438
439                // Notify the main loop about the new artifact
440                if let Err(e) = artifact_tx
441                    .send(ArtifactNotification {
442                        artifact_id,
443                        download_url,
444                        sha256,
445                    })
446                    .await
447                {
448                    tracing::warn!(error = %e, "Failed to send artifact notification");
449                }
450            }
451            ControlPlaneMessage::Disconnect { reason } => {
452                tracing::info!(reason = %reason, "Disconnecting at control plane request");
453                return Err(format!("Disconnected by control plane: {}", reason));
454            }
455            ControlPlaneMessage::Error { message } => {
456                tracing::warn!(message = %message, "Error from control plane");
457            }
458            // These shouldn't happen after registration
459            ControlPlaneMessage::Registered { .. }
460            | ControlPlaneMessage::RegistrationFailed { .. } => {
461                tracing::warn!("Unexpected registration message after already registered");
462            }
463        }
464
465        Ok(())
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// UUID literal reused by the tests that need a stable identifier.
    const FIXED_UUID: &str = "550e8400-e29b-41d4-a716-446655440000";

    /// Parse [`FIXED_UUID`] into a `Uuid`.
    fn fixed_uuid() -> Uuid {
        Uuid::parse_str(FIXED_UUID).unwrap()
    }

    /// Serialize an outgoing message to its JSON wire form.
    fn to_json(msg: &DataPlaneMessage) -> String {
        serde_json::to_string(msg).unwrap()
    }

    /// Deserialize an incoming message from its JSON wire form.
    fn from_json(json: &str) -> ControlPlaneMessage {
        serde_json::from_str(json).unwrap()
    }

    // `Register` carries the tag and the populated fields.
    #[test]
    fn test_data_plane_message_register_serialization() {
        let json = to_json(&DataPlaneMessage::Register {
            project_id: fixed_uuid(),
            api_key: "test-key".to_string(),
            name: Some("my-data-plane".to_string()),
            artifact_id: None,
            metadata: serde_json::json!({"version": "1.0"}),
        });

        for fragment in [
            "\"type\":\"register\"",
            "\"project_id\":",
            "\"api_key\":\"test-key\"",
            "\"name\":\"my-data-plane\"",
        ] {
            assert!(json.contains(fragment));
        }
    }

    // `Heartbeat` carries the tag and both counters.
    #[test]
    fn test_data_plane_message_heartbeat_serialization() {
        let json = to_json(&DataPlaneMessage::Heartbeat {
            artifact_id: Some(fixed_uuid()),
            uptime_secs: 3600,
            requests_total: 1000,
        });

        for fragment in [
            "\"type\":\"heartbeat\"",
            "\"uptime_secs\":3600",
            "\"requests_total\":1000",
        ] {
            assert!(json.contains(fragment));
        }
    }

    // A successful download report omits the `error` field entirely.
    #[test]
    fn test_data_plane_message_artifact_downloaded_success() {
        let json = to_json(&DataPlaneMessage::ArtifactDownloaded {
            artifact_id: fixed_uuid(),
            success: true,
            error: None,
        });

        assert!(json.contains("\"type\":\"artifact_downloaded\""));
        assert!(json.contains("\"success\":true"));
        // None should be skipped
        assert!(!json.contains("\"error\":"));
    }

    // A failed download report includes the error text.
    #[test]
    fn test_data_plane_message_artifact_downloaded_failure() {
        let json = to_json(&DataPlaneMessage::ArtifactDownloaded {
            artifact_id: fixed_uuid(),
            success: false,
            error: Some("checksum mismatch".to_string()),
        });

        for fragment in [
            "\"type\":\"artifact_downloaded\"",
            "\"success\":false",
            "\"error\":\"checksum mismatch\"",
        ] {
            assert!(json.contains(fragment));
        }
    }

    // `registered` parses into `Registered` with both fields intact.
    #[test]
    fn test_control_plane_message_registered_deserialization() {
        let parsed = from_json(
            r#"{
            "type": "registered",
            "data_plane_id": "550e8400-e29b-41d4-a716-446655440000",
            "heartbeat_interval_secs": 30
        }"#,
        );

        let ControlPlaneMessage::Registered {
            data_plane_id,
            heartbeat_interval_secs,
        } = parsed
        else {
            panic!("Expected Registered message");
        };

        assert_eq!(data_plane_id, fixed_uuid());
        assert_eq!(heartbeat_interval_secs, 30);
    }

    // `artifact_available` parses into `ArtifactAvailable` with all fields intact.
    #[test]
    fn test_control_plane_message_artifact_available_deserialization() {
        let parsed = from_json(
            r#"{
            "type": "artifact_available",
            "artifact_id": "550e8400-e29b-41d4-a716-446655440000",
            "download_url": "http://localhost:9090/artifacts/123/download",
            "sha256": "abc123def456"
        }"#,
        );

        let ControlPlaneMessage::ArtifactAvailable {
            artifact_id,
            download_url,
            sha256,
        } = parsed
        else {
            panic!("Expected ArtifactAvailable message");
        };

        assert_eq!(artifact_id, fixed_uuid());
        assert_eq!(download_url, "http://localhost:9090/artifacts/123/download");
        assert_eq!(sha256, "abc123def456");
    }

    // `disconnect` parses into `Disconnect` and keeps the reason.
    #[test]
    fn test_control_plane_message_disconnect_deserialization() {
        let parsed = from_json(
            r#"{
            "type": "disconnect",
            "reason": "server shutting down"
        }"#,
        );

        let ControlPlaneMessage::Disconnect { reason } = parsed else {
            panic!("Expected Disconnect message");
        };

        assert_eq!(reason, "server shutting down");
    }

    // Both the success and the failure shape of the response can be built.
    #[test]
    fn test_artifact_downloaded_response_creation() {
        let artifact_id = Uuid::new_v4();

        let ok = ArtifactDownloadedResponse {
            artifact_id,
            success: true,
            error: None,
        };
        assert!(ok.success);
        assert!(ok.error.is_none());

        let failed = ArtifactDownloadedResponse {
            artifact_id,
            success: false,
            error: Some("download failed".to_string()),
        };
        assert!(!failed.success);
        assert_eq!(failed.error.as_deref(), Some("download failed"));
    }

    // A notification keeps the URL and checksum it was built with.
    #[test]
    fn test_artifact_notification_creation() {
        let ArtifactNotification {
            download_url,
            sha256,
            ..
        } = ArtifactNotification {
            artifact_id: Uuid::new_v4(),
            download_url: "http://example.com/artifact.bca".to_string(),
            sha256: "abc123".to_string(),
        };

        assert!(!download_url.is_empty());
        assert!(!sha256.is_empty());
    }

    // A config keeps the values it was built with.
    #[test]
    fn test_control_plane_config_creation() {
        let cfg = ControlPlaneConfig {
            control_plane_url: "ws://localhost:9090/ws/data-plane".to_string(),
            project_id: Uuid::new_v4(),
            api_key: "test-api-key".to_string(),
            data_plane_name: Some("test-plane".to_string()),
            initial_artifact_id: None,
        };

        assert!(cfg.control_plane_url.starts_with("ws://"));
        assert_eq!(cfg.api_key, "test-api-key");
        assert_eq!(cfg.data_plane_name.as_deref(), Some("test-plane"));
    }
}