Skip to main content

barbacane_lib/
control_plane.rs

1//! Control plane client for connected mode.
2//!
3//! This module handles WebSocket communication with the control plane,
4//! including registration, heartbeat, and artifact notifications.
5
6use futures_util::{SinkExt, StreamExt};
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9use tokio::sync::{mpsc, watch};
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11use uuid::Uuid;
12
13/// Messages sent from data plane to control plane.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(tag = "type", rename_all = "snake_case")]
16pub enum DataPlaneMessage {
17    /// Initial registration with authentication.
18    Register {
19        project_id: Uuid,
20        api_key: String,
21        #[serde(skip_serializing_if = "Option::is_none")]
22        name: Option<String>,
23        #[serde(skip_serializing_if = "Option::is_none")]
24        artifact_id: Option<Uuid>,
25        #[serde(default)]
26        metadata: serde_json::Value,
27    },
28    /// Periodic heartbeat.
29    Heartbeat {
30        #[serde(skip_serializing_if = "Option::is_none")]
31        artifact_id: Option<Uuid>,
32        uptime_secs: u64,
33        requests_total: u64,
34    },
35    /// Acknowledgment of artifact download.
36    ArtifactDownloaded {
37        artifact_id: Uuid,
38        success: bool,
39        #[serde(skip_serializing_if = "Option::is_none")]
40        error: Option<String>,
41    },
42}
43
44/// Messages sent from control plane to data plane.
45#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(tag = "type", rename_all = "snake_case")]
47pub enum ControlPlaneMessage {
48    /// Registration successful.
49    Registered {
50        data_plane_id: Uuid,
51        heartbeat_interval_secs: u32,
52    },
53    /// Registration failed.
54    RegistrationFailed { reason: String },
55    /// New artifact available for download.
56    ArtifactAvailable {
57        artifact_id: Uuid,
58        download_url: String,
59        sha256: String,
60    },
61    /// Heartbeat acknowledgment.
62    HeartbeatAck,
63    /// Request disconnect.
64    Disconnect { reason: String },
65    /// Error message.
66    Error { message: String },
67}
68
69/// Configuration for the control plane client.
70#[derive(Clone)]
71pub struct ControlPlaneConfig {
72    pub control_plane_url: String,
73    pub project_id: Uuid,
74    pub api_key: String,
75    pub data_plane_name: Option<String>,
76}
77
78/// Notification that a new artifact is available.
79#[derive(Debug, Clone)]
80pub struct ArtifactNotification {
81    pub artifact_id: Uuid,
82    pub download_url: String,
83    pub sha256: String,
84}
85
86/// Response to send back to the control plane after downloading an artifact.
87#[derive(Debug, Clone)]
88pub struct ArtifactDownloadedResponse {
89    pub artifact_id: Uuid,
90    pub success: bool,
91    pub error: Option<String>,
92}
93
94/// Control plane client that maintains connection and handles messages.
95pub struct ControlPlaneClient {
96    config: ControlPlaneConfig,
97}
98
99impl ControlPlaneClient {
100    /// Create a new control plane client.
101    pub fn new(config: ControlPlaneConfig) -> Self {
102        Self { config }
103    }
104
105    /// Start the connection loop in a background task.
106    /// Returns a receiver for artifact notifications and a sender for download responses.
107    pub fn start(
108        self,
109        shutdown_rx: watch::Receiver<bool>,
110    ) -> (
111        mpsc::Receiver<ArtifactNotification>,
112        mpsc::Sender<ArtifactDownloadedResponse>,
113    ) {
114        let (artifact_tx, artifact_rx) = mpsc::channel::<ArtifactNotification>(16);
115        let (response_tx, response_rx) = mpsc::channel::<ArtifactDownloadedResponse>(16);
116
117        tokio::spawn(async move {
118            self.connection_loop(shutdown_rx, artifact_tx, response_rx)
119                .await;
120        });
121
122        (artifact_rx, response_tx)
123    }
124
125    /// Main connection loop with reconnection logic.
126    async fn connection_loop(
127        &self,
128        mut shutdown_rx: watch::Receiver<bool>,
129        artifact_tx: mpsc::Sender<ArtifactNotification>,
130        mut response_rx: mpsc::Receiver<ArtifactDownloadedResponse>,
131    ) {
132        const INITIAL_BACKOFF_MS: u64 = 1000;
133        const MAX_BACKOFF_MS: u64 = 60000;
134        const BACKOFF_MULTIPLIER: f64 = 2.0;
135
136        let mut backoff_ms = INITIAL_BACKOFF_MS;
137
138        loop {
139            // Check for shutdown
140            if *shutdown_rx.borrow() {
141                tracing::info!("Control plane client shutting down");
142                return;
143            }
144
145            tracing::info!(url = %self.config.control_plane_url, "Connecting to control plane");
146
147            match self
148                .try_connect(&mut shutdown_rx, &artifact_tx, &mut response_rx)
149                .await
150            {
151                Ok(()) => {
152                    // Connection was cleanly closed (e.g., shutdown)
153                    return;
154                }
155                Err(e) => {
156                    tracing::warn!(
157                        error = %e,
158                        backoff_ms = backoff_ms,
159                        "Control plane connection failed, will retry"
160                    );
161                }
162            }
163
164            // Wait before reconnecting (or abort if shutdown)
165            tokio::select! {
166                _ = shutdown_rx.changed() => {
167                    if *shutdown_rx.borrow() {
168                        return;
169                    }
170                }
171                _ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {}
172            }
173
174            // Increase backoff for next attempt
175            backoff_ms =
176                ((backoff_ms as f64) * BACKOFF_MULTIPLIER).min(MAX_BACKOFF_MS as f64) as u64;
177        }
178    }
179
180    /// Attempt to connect and handle messages.
181    async fn try_connect(
182        &self,
183        shutdown_rx: &mut watch::Receiver<bool>,
184        artifact_tx: &mpsc::Sender<ArtifactNotification>,
185        response_rx: &mut mpsc::Receiver<ArtifactDownloadedResponse>,
186    ) -> Result<(), String> {
187        // Connect to WebSocket
188        let (ws_stream, _response) = connect_async(&self.config.control_plane_url)
189            .await
190            .map_err(|e| format!("WebSocket connection failed: {}", e))?;
191
192        let (mut sender, mut receiver) = ws_stream.split();
193
194        // Send registration message
195        let register_msg = DataPlaneMessage::Register {
196            project_id: self.config.project_id,
197            api_key: self.config.api_key.clone(),
198            name: self.config.data_plane_name.clone(),
199            artifact_id: None, // TODO: pass current artifact ID
200            metadata: serde_json::json!({}),
201        };
202
203        let register_json = serde_json::to_string(&register_msg)
204            .map_err(|e| format!("Failed to serialize register message: {}", e))?;
205
206        sender
207            .send(Message::Text(register_json.into()))
208            .await
209            .map_err(|e| format!("Failed to send register message: {}", e))?;
210
211        // Wait for registration response
212        let registration_response = tokio::time::timeout(Duration::from_secs(30), receiver.next())
213            .await
214            .map_err(|_| "Registration timeout")?
215            .ok_or("Connection closed before registration")?
216            .map_err(|e| format!("WebSocket error: {}", e))?;
217
218        let heartbeat_interval_secs = match registration_response {
219            Message::Text(text) => {
220                let msg: ControlPlaneMessage = serde_json::from_str(&text)
221                    .map_err(|e| format!("Failed to parse registration response: {}", e))?;
222
223                match msg {
224                    ControlPlaneMessage::Registered {
225                        data_plane_id,
226                        heartbeat_interval_secs,
227                    } => {
228                        tracing::info!(
229                            data_plane_id = %data_plane_id,
230                            heartbeat_interval_secs,
231                            "Registered with control plane"
232                        );
233                        heartbeat_interval_secs
234                    }
235                    ControlPlaneMessage::RegistrationFailed { reason } => {
236                        return Err(format!("Registration failed: {}", reason));
237                    }
238                    other => {
239                        return Err(format!("Unexpected registration response: {:?}", other));
240                    }
241                }
242            }
243            other => {
244                return Err(format!("Unexpected message type: {:?}", other));
245            }
246        };
247
248        // Start heartbeat timer
249        let mut heartbeat_interval =
250            tokio::time::interval(Duration::from_secs(heartbeat_interval_secs as u64));
251        let start_time = std::time::Instant::now();
252
253        // Main message loop
254        loop {
255            tokio::select! {
256                // Shutdown signal
257                _ = shutdown_rx.changed() => {
258                    if *shutdown_rx.borrow() {
259                        tracing::info!("Disconnecting from control plane");
260                        let _ = sender.close().await;
261                        return Ok(());
262                    }
263                }
264
265                // Heartbeat timer
266                _ = heartbeat_interval.tick() => {
267                    let heartbeat = DataPlaneMessage::Heartbeat {
268                        artifact_id: None, // TODO: pass current artifact ID
269                        uptime_secs: start_time.elapsed().as_secs(),
270                        requests_total: 0, // TODO: pass actual metrics
271                    };
272
273                    let json = serde_json::to_string(&heartbeat)
274                        .map_err(|e| format!("Failed to serialize heartbeat: {}", e))?;
275
276                    if let Err(e) = sender.send(Message::Text(json.into())).await {
277                        return Err(format!("Failed to send heartbeat: {}", e));
278                    }
279
280                    tracing::debug!("Heartbeat sent");
281                }
282
283                // Artifact download response from main loop
284                Some(response) = response_rx.recv() => {
285                    let msg = DataPlaneMessage::ArtifactDownloaded {
286                        artifact_id: response.artifact_id,
287                        success: response.success,
288                        error: response.error,
289                    };
290
291                    let json = serde_json::to_string(&msg)
292                        .map_err(|e| format!("Failed to serialize artifact downloaded: {}", e))?;
293
294                    if let Err(e) = sender.send(Message::Text(json.into())).await {
295                        tracing::warn!(error = %e, "Failed to send artifact downloaded response");
296                    } else {
297                        tracing::info!(
298                            artifact_id = %response.artifact_id,
299                            success = response.success,
300                            "Sent artifact downloaded response to control plane"
301                        );
302                    }
303                }
304
305                // Incoming messages
306                result = receiver.next() => {
307                    match result {
308                        Some(Ok(Message::Text(text))) => {
309                            match serde_json::from_str::<ControlPlaneMessage>(&text) {
310                                Ok(msg) => {
311                                    if let Err(e) = self.handle_message(msg, artifact_tx, &mut sender).await {
312                                        tracing::warn!(error = %e, "Error handling control plane message");
313                                    }
314                                }
315                                Err(e) => {
316                                    tracing::warn!(error = %e, "Failed to parse control plane message");
317                                }
318                            }
319                        }
320                        Some(Ok(Message::Ping(data))) => {
321                            let _ = sender.send(Message::Pong(data)).await;
322                        }
323                        Some(Ok(Message::Close(_))) | None => {
324                            return Err("Connection closed by control plane".to_string());
325                        }
326                        Some(Err(e)) => {
327                            return Err(format!("WebSocket error: {}", e));
328                        }
329                        _ => {}
330                    }
331                }
332            }
333        }
334    }
335
336    /// Handle a message from the control plane.
337    async fn handle_message(
338        &self,
339        msg: ControlPlaneMessage,
340        artifact_tx: &mpsc::Sender<ArtifactNotification>,
341        _sender: &mut futures_util::stream::SplitSink<
342            tokio_tungstenite::WebSocketStream<
343                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
344            >,
345            Message,
346        >,
347    ) -> Result<(), String> {
348        match msg {
349            ControlPlaneMessage::HeartbeatAck => {
350                tracing::debug!("Heartbeat acknowledged");
351            }
352            ControlPlaneMessage::ArtifactAvailable {
353                artifact_id,
354                download_url,
355                sha256,
356            } => {
357                tracing::info!(
358                    artifact_id = %artifact_id,
359                    download_url = %download_url,
360                    "New artifact available"
361                );
362
363                // Notify the main loop about the new artifact
364                if let Err(e) = artifact_tx
365                    .send(ArtifactNotification {
366                        artifact_id,
367                        download_url,
368                        sha256,
369                    })
370                    .await
371                {
372                    tracing::warn!(error = %e, "Failed to send artifact notification");
373                }
374            }
375            ControlPlaneMessage::Disconnect { reason } => {
376                tracing::info!(reason = %reason, "Disconnecting at control plane request");
377                return Err(format!("Disconnected by control plane: {}", reason));
378            }
379            ControlPlaneMessage::Error { message } => {
380                tracing::warn!(message = %message, "Error from control plane");
381            }
382            // These shouldn't happen after registration
383            ControlPlaneMessage::Registered { .. }
384            | ControlPlaneMessage::RegistrationFailed { .. } => {
385                tracing::warn!("Unexpected registration message after already registered");
386            }
387        }
388
389        Ok(())
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed identifier shared by the serialization tests.
    const SAMPLE_UUID: &str = "550e8400-e29b-41d4-a716-446655440000";

    /// Parse [`SAMPLE_UUID`] into a `Uuid`.
    fn sample_uuid() -> Uuid {
        Uuid::parse_str(SAMPLE_UUID).unwrap()
    }

    /// Serialize an outgoing message to its JSON text.
    fn encode(msg: &DataPlaneMessage) -> String {
        serde_json::to_string(msg).unwrap()
    }

    /// Deserialize an incoming message from JSON text.
    fn decode(json: &str) -> ControlPlaneMessage {
        serde_json::from_str(json).unwrap()
    }

    /// Assert that every fragment occurs somewhere in `json`.
    fn assert_contains_all(json: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(json.contains(fragment));
        }
    }

    #[test]
    fn test_data_plane_message_register_serialization() {
        let json = encode(&DataPlaneMessage::Register {
            project_id: sample_uuid(),
            api_key: "test-key".to_string(),
            name: Some("my-data-plane".to_string()),
            artifact_id: None,
            metadata: serde_json::json!({"version": "1.0"}),
        });

        assert_contains_all(
            &json,
            &[
                "\"type\":\"register\"",
                "\"project_id\":",
                "\"api_key\":\"test-key\"",
                "\"name\":\"my-data-plane\"",
            ],
        );
    }

    #[test]
    fn test_data_plane_message_heartbeat_serialization() {
        let json = encode(&DataPlaneMessage::Heartbeat {
            artifact_id: Some(sample_uuid()),
            uptime_secs: 3600,
            requests_total: 1000,
        });

        assert_contains_all(
            &json,
            &[
                "\"type\":\"heartbeat\"",
                "\"uptime_secs\":3600",
                "\"requests_total\":1000",
            ],
        );
    }

    #[test]
    fn test_data_plane_message_artifact_downloaded_success() {
        let json = encode(&DataPlaneMessage::ArtifactDownloaded {
            artifact_id: sample_uuid(),
            success: true,
            error: None,
        });

        assert_contains_all(
            &json,
            &["\"type\":\"artifact_downloaded\"", "\"success\":true"],
        );
        // A `None` error must not be serialized at all.
        assert!(!json.contains("\"error\":"));
    }

    #[test]
    fn test_data_plane_message_artifact_downloaded_failure() {
        let json = encode(&DataPlaneMessage::ArtifactDownloaded {
            artifact_id: sample_uuid(),
            success: false,
            error: Some("checksum mismatch".to_string()),
        });

        assert_contains_all(
            &json,
            &[
                "\"type\":\"artifact_downloaded\"",
                "\"success\":false",
                "\"error\":\"checksum mismatch\"",
            ],
        );
    }

    #[test]
    fn test_control_plane_message_registered_deserialization() {
        let parsed = decode(
            r#"{
            "type": "registered",
            "data_plane_id": "550e8400-e29b-41d4-a716-446655440000",
            "heartbeat_interval_secs": 30
        }"#,
        );

        if let ControlPlaneMessage::Registered {
            data_plane_id,
            heartbeat_interval_secs,
        } = parsed
        {
            assert_eq!(data_plane_id, sample_uuid());
            assert_eq!(heartbeat_interval_secs, 30);
        } else {
            panic!("Expected Registered message");
        }
    }

    #[test]
    fn test_control_plane_message_artifact_available_deserialization() {
        let parsed = decode(
            r#"{
            "type": "artifact_available",
            "artifact_id": "550e8400-e29b-41d4-a716-446655440000",
            "download_url": "http://localhost:9090/artifacts/123/download",
            "sha256": "abc123def456"
        }"#,
        );

        if let ControlPlaneMessage::ArtifactAvailable {
            artifact_id,
            download_url,
            sha256,
        } = parsed
        {
            assert_eq!(artifact_id, sample_uuid());
            assert_eq!(download_url, "http://localhost:9090/artifacts/123/download");
            assert_eq!(sha256, "abc123def456");
        } else {
            panic!("Expected ArtifactAvailable message");
        }
    }

    #[test]
    fn test_control_plane_message_disconnect_deserialization() {
        let parsed = decode(
            r#"{
            "type": "disconnect",
            "reason": "server shutting down"
        }"#,
        );

        if let ControlPlaneMessage::Disconnect { reason } = parsed {
            assert_eq!(reason, "server shutting down");
        } else {
            panic!("Expected Disconnect message");
        }
    }

    #[test]
    fn test_artifact_downloaded_response_creation() {
        let artifact_id = Uuid::new_v4();

        let succeeded = ArtifactDownloadedResponse {
            artifact_id,
            success: true,
            error: None,
        };
        assert!(succeeded.success);
        assert!(succeeded.error.is_none());

        let failed = ArtifactDownloadedResponse {
            artifact_id,
            success: false,
            error: Some("download failed".to_string()),
        };
        assert!(!failed.success);
        assert_eq!(failed.error.as_deref(), Some("download failed"));
    }

    #[test]
    fn test_artifact_notification_creation() {
        let notification = ArtifactNotification {
            artifact_id: Uuid::new_v4(),
            download_url: "http://example.com/artifact.bca".to_string(),
            sha256: "abc123".to_string(),
        };

        assert!(!notification.download_url.is_empty());
        assert!(!notification.sha256.is_empty());
    }

    #[test]
    fn test_control_plane_config_creation() {
        let config = ControlPlaneConfig {
            control_plane_url: "ws://localhost:9090/ws/data-plane".to_string(),
            project_id: Uuid::new_v4(),
            api_key: "test-api-key".to_string(),
            data_plane_name: Some("test-plane".to_string()),
        };

        assert!(config.control_plane_url.starts_with("ws://"));
        assert_eq!(config.api_key, "test-api-key");
        assert_eq!(config.data_plane_name.as_deref(), Some("test-plane"));
    }
}