1use futures_util::{SinkExt, StreamExt};
7use serde::{Deserialize, Serialize};
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{mpsc, watch};
12use tokio_tungstenite::{connect_async, tungstenite::Message};
13use uuid::Uuid;
14
/// Messages sent from this data plane to the control plane over the WebSocket.
///
/// Serialized as internally tagged JSON: the snake_case variant name is written
/// to a `"type"` field, e.g. `{"type":"heartbeat","uptime_secs":...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DataPlaneMessage {
    /// First message sent on every new connection.
    Register {
        /// Project this data plane belongs to.
        project_id: Uuid,
        /// API key sent with the registration.
        api_key: String,
        /// Optional human-readable name; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Artifact the data plane starts with, if any; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        artifact_id: Option<Uuid>,
        /// Free-form metadata; deserializes to JSON `null` when the field is absent.
        #[serde(default)]
        metadata: serde_json::Value,
    },
    /// Periodic liveness report, sent once per heartbeat interval.
    Heartbeat {
        /// Currently loaded artifact id; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        artifact_id: Option<Uuid>,
        /// Hash of the currently loaded artifact; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        artifact_hash: Option<String>,
        /// Seconds since the current session was registered.
        uptime_secs: u64,
        /// Request counter reported to the control plane.
        requests_total: u64,
    },
    /// Outcome of downloading an artifact announced via `artifact_available`.
    ArtifactDownloaded {
        /// Id of the artifact that was downloaded (or attempted).
        artifact_id: Uuid,
        /// Whether the download succeeded.
        success: bool,
        /// Failure description; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        error: Option<String>,
    },
}
47
/// Messages received from the control plane over the WebSocket.
///
/// Uses the same internally tagged JSON encoding as [`DataPlaneMessage`]:
/// the snake_case variant name is read from the `"type"` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ControlPlaneMessage {
    /// Successful reply to `register`.
    Registered {
        /// Id the control plane assigned to this data plane.
        data_plane_id: Uuid,
        /// How often the data plane should send `heartbeat` messages, in seconds.
        heartbeat_interval_secs: u32,
    },
    /// Rejected reply to `register`, with a human-readable reason.
    RegistrationFailed { reason: String },
    /// Announces a new artifact for the data plane to download.
    ArtifactAvailable {
        /// Id of the announced artifact.
        artifact_id: Uuid,
        /// URL the artifact can be downloaded from.
        download_url: String,
        /// SHA-256 value sent with the artifact (presumably the expected digest — confirm).
        sha256: String,
    },
    /// Reply to `heartbeat`; `drift_detected` is mirrored into the client's drift flag.
    HeartbeatAck { drift_detected: bool },
    /// Request from the control plane to end the session.
    Disconnect { reason: String },
    /// Error report from the control plane; the client only logs it.
    Error { message: String },
}
72
73#[derive(Clone)]
75pub struct ControlPlaneConfig {
76 pub control_plane_url: String,
77 pub project_id: Uuid,
78 pub api_key: String,
79 pub data_plane_name: Option<String>,
80 pub initial_artifact_id: Option<Uuid>,
81}
82
/// An `artifact_available` announcement from the control plane, forwarded on
/// the receiver returned by [`ControlPlaneClient::start`].
#[derive(Debug, Clone)]
pub struct ArtifactNotification {
    /// Id of the announced artifact.
    pub artifact_id: Uuid,
    /// Download URL supplied by the control plane.
    pub download_url: String,
    /// SHA-256 value supplied by the control plane; not validated by this module.
    pub sha256: String,
}
90
/// A download result submitted on the sender returned by
/// [`ControlPlaneClient::start`]; relayed to the control plane as an
/// `artifact_downloaded` message.
#[derive(Debug, Clone)]
pub struct ArtifactDownloadedResponse {
    /// Id of the artifact the result refers to.
    pub artifact_id: Uuid,
    /// Whether the download succeeded.
    pub success: bool,
    /// Failure description, if any.
    pub error: Option<String>,
}
98
/// How a single connection attempt (`try_connect`) ended; the connection loop
/// chooses its retry delay based on this.
enum ConnectOutcome {
    /// Shutdown was requested; the client stops.
    Shutdown,
    /// A registered session ended (closed, errored, or a send failed).
    ConnectionLost(String),
    /// Connecting or registering failed before a session was established.
    ConnectionFailed(String),
}
108
/// WebSocket client that registers this data plane with the control plane,
/// sends heartbeats, and relays artifact notifications and download results.
pub struct ControlPlaneClient {
    /// Settings used for every connection attempt.
    config: ControlPlaneConfig,
}
113
114impl ControlPlaneClient {
115 pub fn new(config: ControlPlaneConfig) -> Self {
117 Self { config }
118 }
119
120 pub fn start(
123 self,
124 shutdown_rx: watch::Receiver<bool>,
125 artifact_hash_rx: watch::Receiver<Option<String>>,
126 drift_flag: Arc<AtomicBool>,
127 ) -> (
128 mpsc::Receiver<ArtifactNotification>,
129 mpsc::Sender<ArtifactDownloadedResponse>,
130 ) {
131 let (artifact_tx, artifact_rx) = mpsc::channel::<ArtifactNotification>(16);
132 let (response_tx, response_rx) = mpsc::channel::<ArtifactDownloadedResponse>(16);
133
134 tokio::spawn(async move {
135 self.connection_loop(
136 shutdown_rx,
137 artifact_tx,
138 response_rx,
139 artifact_hash_rx,
140 drift_flag,
141 )
142 .await;
143 });
144
145 (artifact_rx, response_tx)
146 }
147
148 async fn connection_loop(
150 &self,
151 mut shutdown_rx: watch::Receiver<bool>,
152 artifact_tx: mpsc::Sender<ArtifactNotification>,
153 mut response_rx: mpsc::Receiver<ArtifactDownloadedResponse>,
154 artifact_hash_rx: watch::Receiver<Option<String>>,
155 drift_flag: Arc<AtomicBool>,
156 ) {
157 const INITIAL_BACKOFF_MS: u64 = 1000;
158 const MAX_BACKOFF_MS: u64 = 60000;
159 const BACKOFF_MULTIPLIER: f64 = 2.0;
160
161 let mut backoff_ms = INITIAL_BACKOFF_MS;
162
163 loop {
164 if *shutdown_rx.borrow() {
166 tracing::info!("Control plane client shutting down");
167 return;
168 }
169
170 tracing::info!(url = %self.config.control_plane_url, "Connecting to control plane");
171
172 match self
173 .try_connect(
174 &mut shutdown_rx,
175 &artifact_tx,
176 &mut response_rx,
177 &artifact_hash_rx,
178 &drift_flag,
179 )
180 .await
181 {
182 ConnectOutcome::Shutdown => {
183 return;
184 }
185 ConnectOutcome::ConnectionLost(e) => {
186 tracing::warn!(
188 error = %e,
189 "Control plane connection lost, reconnecting immediately"
190 );
191 backoff_ms = INITIAL_BACKOFF_MS;
192 }
193 ConnectOutcome::ConnectionFailed(e) => {
194 tracing::warn!(
195 error = %e,
196 backoff_ms = backoff_ms,
197 "Control plane connection failed, will retry"
198 );
199 }
200 }
201
202 tokio::select! {
204 _ = shutdown_rx.changed() => {
205 if *shutdown_rx.borrow() {
206 return;
207 }
208 }
209 _ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {}
210 }
211
212 backoff_ms =
214 ((backoff_ms as f64) * BACKOFF_MULTIPLIER).min(MAX_BACKOFF_MS as f64) as u64;
215 }
216 }
217
218 async fn try_connect(
220 &self,
221 shutdown_rx: &mut watch::Receiver<bool>,
222 artifact_tx: &mpsc::Sender<ArtifactNotification>,
223 response_rx: &mut mpsc::Receiver<ArtifactDownloadedResponse>,
224 artifact_hash_rx: &watch::Receiver<Option<String>>,
225 drift_flag: &Arc<AtomicBool>,
226 ) -> ConnectOutcome {
227 let (ws_stream, _response) = match connect_async(&self.config.control_plane_url).await {
229 Ok(conn) => conn,
230 Err(e) => {
231 return ConnectOutcome::ConnectionFailed(format!(
232 "WebSocket connection failed: {}",
233 e
234 ))
235 }
236 };
237
238 let (mut sender, mut receiver) = ws_stream.split();
239
240 let register_msg = DataPlaneMessage::Register {
242 project_id: self.config.project_id,
243 api_key: self.config.api_key.clone(),
244 name: self.config.data_plane_name.clone(),
245 artifact_id: self.config.initial_artifact_id,
246 metadata: serde_json::json!({}),
247 };
248
249 let register_json = match serde_json::to_string(®ister_msg) {
250 Ok(j) => j,
251 Err(e) => {
252 return ConnectOutcome::ConnectionFailed(format!(
253 "Failed to serialize register message: {}",
254 e
255 ))
256 }
257 };
258
259 if let Err(e) = sender.send(Message::Text(register_json.into())).await {
260 return ConnectOutcome::ConnectionFailed(format!(
261 "Failed to send register message: {}",
262 e
263 ));
264 }
265
266 let registration_response =
268 match tokio::time::timeout(Duration::from_secs(30), receiver.next()).await {
269 Ok(Some(Ok(msg))) => msg,
270 Ok(Some(Err(e))) => {
271 return ConnectOutcome::ConnectionFailed(format!("WebSocket error: {}", e))
272 }
273 Ok(None) => {
274 return ConnectOutcome::ConnectionFailed(
275 "Connection closed before registration".to_string(),
276 )
277 }
278 Err(_) => {
279 return ConnectOutcome::ConnectionFailed("Registration timeout".to_string())
280 }
281 };
282
283 let heartbeat_interval_secs = match registration_response {
284 Message::Text(text) => {
285 let msg: ControlPlaneMessage = match serde_json::from_str(&text) {
286 Ok(m) => m,
287 Err(e) => {
288 return ConnectOutcome::ConnectionFailed(format!(
289 "Failed to parse registration response: {}",
290 e
291 ))
292 }
293 };
294
295 match msg {
296 ControlPlaneMessage::Registered {
297 data_plane_id,
298 heartbeat_interval_secs,
299 } => {
300 tracing::info!(
301 data_plane_id = %data_plane_id,
302 heartbeat_interval_secs,
303 "Registered with control plane"
304 );
305 heartbeat_interval_secs
306 }
307 ControlPlaneMessage::RegistrationFailed { reason } => {
308 return ConnectOutcome::ConnectionFailed(format!(
309 "Registration failed: {}",
310 reason
311 ));
312 }
313 other => {
314 return ConnectOutcome::ConnectionFailed(format!(
315 "Unexpected registration response: {:?}",
316 other
317 ));
318 }
319 }
320 }
321 other => {
322 return ConnectOutcome::ConnectionFailed(format!(
323 "Unexpected message type: {:?}",
324 other
325 ));
326 }
327 };
328
329 let mut heartbeat_interval =
331 tokio::time::interval(Duration::from_secs(heartbeat_interval_secs as u64));
332 let start_time = std::time::Instant::now();
333
334 loop {
337 tokio::select! {
338 _ = shutdown_rx.changed() => {
340 if *shutdown_rx.borrow() {
341 tracing::info!("Disconnecting from control plane");
342 let _ = sender.close().await;
343 return ConnectOutcome::Shutdown;
344 }
345 }
346
347 _ = heartbeat_interval.tick() => {
349 let heartbeat = DataPlaneMessage::Heartbeat {
350 artifact_id: None, artifact_hash: artifact_hash_rx.borrow().clone(),
352 uptime_secs: start_time.elapsed().as_secs(),
353 requests_total: 0, };
355
356 let json = match serde_json::to_string(&heartbeat) {
357 Ok(j) => j,
358 Err(e) => {
359 tracing::error!(error = %e, "Failed to serialize heartbeat");
360 continue;
361 }
362 };
363
364 if let Err(e) = sender.send(Message::Text(json.into())).await {
365 return ConnectOutcome::ConnectionLost(format!(
366 "Failed to send heartbeat: {}", e
367 ));
368 }
369
370 tracing::debug!("Heartbeat sent");
371 }
372
373 Some(response) = response_rx.recv() => {
375 let msg = DataPlaneMessage::ArtifactDownloaded {
376 artifact_id: response.artifact_id,
377 success: response.success,
378 error: response.error,
379 };
380
381 let json = match serde_json::to_string(&msg) {
382 Ok(j) => j,
383 Err(e) => {
384 tracing::error!(error = %e, "Failed to serialize artifact downloaded");
385 continue;
386 }
387 };
388
389 if let Err(e) = sender.send(Message::Text(json.into())).await {
390 tracing::warn!(error = %e, "Failed to send artifact downloaded response");
391 } else {
392 tracing::info!(
393 artifact_id = %response.artifact_id,
394 success = response.success,
395 "Sent artifact downloaded response to control plane"
396 );
397 }
398 }
399
400 result = receiver.next() => {
402 match result {
403 Some(Ok(Message::Text(text))) => {
404 match serde_json::from_str::<ControlPlaneMessage>(&text) {
405 Ok(msg) => {
406 if let Err(e) = self.handle_message(msg, artifact_tx, &mut sender, drift_flag).await {
407 tracing::warn!(error = %e, "Error handling control plane message");
408 }
409 }
410 Err(e) => {
411 tracing::warn!(error = %e, "Failed to parse control plane message");
412 }
413 }
414 }
415 Some(Ok(Message::Ping(data))) => {
416 let _ = sender.send(Message::Pong(data)).await;
417 }
418 Some(Ok(Message::Close(_))) | None => {
419 return ConnectOutcome::ConnectionLost(
420 "Connection closed by control plane".to_string()
421 );
422 }
423 Some(Err(e)) => {
424 return ConnectOutcome::ConnectionLost(format!(
425 "WebSocket error: {}", e
426 ));
427 }
428 _ => {}
429 }
430 }
431 }
432 }
433 }
434
435 async fn handle_message(
437 &self,
438 msg: ControlPlaneMessage,
439 artifact_tx: &mpsc::Sender<ArtifactNotification>,
440 _sender: &mut futures_util::stream::SplitSink<
441 tokio_tungstenite::WebSocketStream<
442 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
443 >,
444 Message,
445 >,
446 drift_flag: &Arc<AtomicBool>,
447 ) -> Result<(), String> {
448 match msg {
449 ControlPlaneMessage::HeartbeatAck { drift_detected } => {
450 drift_flag.store(drift_detected, Ordering::Relaxed);
451 if drift_detected {
452 tracing::warn!("Control plane detected configuration drift");
453 }
454 tracing::debug!(drift_detected, "Heartbeat acknowledged");
455 }
456 ControlPlaneMessage::ArtifactAvailable {
457 artifact_id,
458 download_url,
459 sha256,
460 } => {
461 tracing::info!(
462 artifact_id = %artifact_id,
463 download_url = %download_url,
464 "New artifact available"
465 );
466
467 if let Err(e) = artifact_tx
469 .send(ArtifactNotification {
470 artifact_id,
471 download_url,
472 sha256,
473 })
474 .await
475 {
476 tracing::warn!(error = %e, "Failed to send artifact notification");
477 }
478 }
479 ControlPlaneMessage::Disconnect { reason } => {
480 tracing::info!(reason = %reason, "Disconnecting at control plane request");
481 return Err(format!("Disconnected by control plane: {}", reason));
482 }
483 ControlPlaneMessage::Error { message } => {
484 tracing::warn!(message = %message, "Error from control plane");
485 }
486 ControlPlaneMessage::Registered { .. }
488 | ControlPlaneMessage::RegistrationFailed { .. } => {
489 tracing::warn!("Unexpected registration message after already registered");
490 }
491 }
492
493 Ok(())
494 }
495}
496
// Unit tests for the JSON wire format of the protocol messages and for the plain
// data types. None of these tests open a network connection.
#[cfg(test)]
mod tests {
    use super::*;

    // `register` carries the internal "type" tag plus the identifying fields.
    #[test]
    fn test_data_plane_message_register_serialization() {
        let msg = DataPlaneMessage::Register {
            project_id: Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap(),
            api_key: "test-key".to_string(),
            name: Some("my-data-plane".to_string()),
            artifact_id: None,
            metadata: serde_json::json!({"version": "1.0"}),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"register\""));
        assert!(json.contains("\"project_id\":"));
        assert!(json.contains("\"api_key\":\"test-key\""));
        assert!(json.contains("\"name\":\"my-data-plane\""));
    }

    // `heartbeat` serializes counters and the optional artifact hash.
    #[test]
    fn test_data_plane_message_heartbeat_serialization() {
        let msg = DataPlaneMessage::Heartbeat {
            artifact_id: Some(Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap()),
            artifact_hash: Some("sha256:abc123".to_string()),
            uptime_secs: 3600,
            requests_total: 1000,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"heartbeat\""));
        assert!(json.contains("\"uptime_secs\":3600"));
        assert!(json.contains("\"requests_total\":1000"));
        assert!(json.contains("\"artifact_hash\":\"sha256:abc123\""));
    }

    // A successful `artifact_downloaded` omits the `error` field entirely.
    #[test]
    fn test_data_plane_message_artifact_downloaded_success() {
        let artifact_id = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        let msg = DataPlaneMessage::ArtifactDownloaded {
            artifact_id,
            success: true,
            error: None,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"artifact_downloaded\""));
        assert!(json.contains("\"success\":true"));
        // `error` is skipped when `None` (skip_serializing_if).
        assert!(!json.contains("\"error\":"));
    }

    // A failed `artifact_downloaded` includes the error text.
    #[test]
    fn test_data_plane_message_artifact_downloaded_failure() {
        let artifact_id = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        let msg = DataPlaneMessage::ArtifactDownloaded {
            artifact_id,
            success: false,
            error: Some("checksum mismatch".to_string()),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"artifact_downloaded\""));
        assert!(json.contains("\"success\":false"));
        assert!(json.contains("\"error\":\"checksum mismatch\""));
    }

    // The `registered` reply parses into the id and heartbeat interval.
    #[test]
    fn test_control_plane_message_registered_deserialization() {
        let json = r#"{
            "type": "registered",
            "data_plane_id": "550e8400-e29b-41d4-a716-446655440000",
            "heartbeat_interval_secs": 30
        }"#;

        let msg: ControlPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            ControlPlaneMessage::Registered {
                data_plane_id,
                heartbeat_interval_secs,
            } => {
                assert_eq!(
                    data_plane_id,
                    Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap()
                );
                assert_eq!(heartbeat_interval_secs, 30);
            }
            _ => panic!("Expected Registered message"),
        }
    }

    // `artifact_available` parses all three payload fields.
    #[test]
    fn test_control_plane_message_artifact_available_deserialization() {
        let json = r#"{
            "type": "artifact_available",
            "artifact_id": "550e8400-e29b-41d4-a716-446655440000",
            "download_url": "http://localhost:9090/artifacts/123/download",
            "sha256": "abc123def456"
        }"#;

        let msg: ControlPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            ControlPlaneMessage::ArtifactAvailable {
                artifact_id,
                download_url,
                sha256,
            } => {
                assert_eq!(
                    artifact_id,
                    Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap()
                );
                assert_eq!(download_url, "http://localhost:9090/artifacts/123/download");
                assert_eq!(sha256, "abc123def456");
            }
            _ => panic!("Expected ArtifactAvailable message"),
        }
    }

    // `disconnect` parses the reason string.
    #[test]
    fn test_control_plane_message_disconnect_deserialization() {
        let json = r#"{
            "type": "disconnect",
            "reason": "server shutting down"
        }"#;

        let msg: ControlPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            ControlPlaneMessage::Disconnect { reason } => {
                assert_eq!(reason, "server shutting down");
            }
            _ => panic!("Expected Disconnect message"),
        }
    }

    // Plain construction/field access of the download-result type.
    #[test]
    fn test_artifact_downloaded_response_creation() {
        let artifact_id = Uuid::new_v4();

        let success_response = ArtifactDownloadedResponse {
            artifact_id,
            success: true,
            error: None,
        };
        assert!(success_response.success);
        assert!(success_response.error.is_none());

        let failure_response = ArtifactDownloadedResponse {
            artifact_id,
            success: false,
            error: Some("download failed".to_string()),
        };
        assert!(!failure_response.success);
        assert_eq!(failure_response.error.as_deref(), Some("download failed"));
    }

    // Plain construction/field access of the notification type.
    #[test]
    fn test_artifact_notification_creation() {
        let notification = ArtifactNotification {
            artifact_id: Uuid::new_v4(),
            download_url: "http://example.com/artifact.bca".to_string(),
            sha256: "abc123".to_string(),
        };

        assert!(!notification.download_url.is_empty());
        assert!(!notification.sha256.is_empty());
    }

    // Plain construction/field access of the config type.
    #[test]
    fn test_control_plane_config_creation() {
        let config = ControlPlaneConfig {
            control_plane_url: "ws://localhost:9090/ws/data-plane".to_string(),
            project_id: Uuid::new_v4(),
            api_key: "test-api-key".to_string(),
            data_plane_name: Some("test-plane".to_string()),
            initial_artifact_id: None,
        };

        assert!(config.control_plane_url.starts_with("ws://"));
        assert_eq!(config.api_key, "test-api-key");
        assert_eq!(config.data_plane_name.as_deref(), Some("test-plane"));
    }

    // `heartbeat_ack` with drift parses to `drift_detected == true`.
    #[test]
    fn test_heartbeat_ack_with_drift_serialization() {
        let json = r#"{"type":"heartbeat_ack","drift_detected":true}"#;
        let msg: ControlPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            ControlPlaneMessage::HeartbeatAck { drift_detected } => {
                assert!(drift_detected);
            }
            _ => panic!("Expected HeartbeatAck message"),
        }
    }

    // `heartbeat_ack` without drift parses to `drift_detected == false`.
    #[test]
    fn test_heartbeat_ack_without_drift_serialization() {
        let json = r#"{"type":"heartbeat_ack","drift_detected":false}"#;
        let msg: ControlPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            ControlPlaneMessage::HeartbeatAck { drift_detected } => {
                assert!(!drift_detected);
            }
            _ => panic!("Expected HeartbeatAck message"),
        }
    }

    // Heartbeat with a hash round-trips through JSON.
    #[test]
    fn test_heartbeat_with_artifact_hash_serialization() {
        let msg = DataPlaneMessage::Heartbeat {
            artifact_id: None,
            artifact_hash: Some("sha256:abc123def".to_string()),
            uptime_secs: 120,
            requests_total: 50,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"artifact_hash\":\"sha256:abc123def\""));

        // Deserialize back to check the round trip.
        let deserialized: DataPlaneMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            DataPlaneMessage::Heartbeat {
                artifact_hash,
                uptime_secs,
                ..
            } => {
                assert_eq!(artifact_hash, Some("sha256:abc123def".to_string()));
                assert_eq!(uptime_secs, 120);
            }
            _ => panic!("Expected Heartbeat message"),
        }
    }

    // A `None` hash is omitted from the heartbeat JSON.
    #[test]
    fn test_heartbeat_without_artifact_hash() {
        let msg = DataPlaneMessage::Heartbeat {
            artifact_id: None,
            artifact_hash: None,
            uptime_secs: 0,
            requests_total: 0,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(
            !json.contains("artifact_hash"),
            "artifact_hash should be omitted when None"
        );
    }

    // NOTE(review): despite its name, this test only checks AtomicBool
    // store/load; it never feeds a `heartbeat_ack` through `handle_message`,
    // so it would not catch a regression in how the client updates the flag.
    #[test]
    fn test_drift_flag_updated_by_heartbeat_ack() {
        let drift_flag = Arc::new(AtomicBool::new(false));

        drift_flag.store(true, Ordering::Relaxed);
        assert!(drift_flag.load(Ordering::Relaxed));

        drift_flag.store(false, Ordering::Relaxed);
        assert!(!drift_flag.load(Ordering::Relaxed));
    }
}