1use futures_util::{SinkExt, StreamExt};
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9use tokio::sync::{mpsc, watch};
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11use uuid::Uuid;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(tag = "type", rename_all = "snake_case")]
16pub enum DataPlaneMessage {
17 Register {
19 project_id: Uuid,
20 api_key: String,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 name: Option<String>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 artifact_id: Option<Uuid>,
25 #[serde(default)]
26 metadata: serde_json::Value,
27 },
28 Heartbeat {
30 #[serde(skip_serializing_if = "Option::is_none")]
31 artifact_id: Option<Uuid>,
32 uptime_secs: u64,
33 requests_total: u64,
34 },
35 ArtifactDownloaded {
37 artifact_id: Uuid,
38 success: bool,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 error: Option<String>,
41 },
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(tag = "type", rename_all = "snake_case")]
47pub enum ControlPlaneMessage {
48 Registered {
50 data_plane_id: Uuid,
51 heartbeat_interval_secs: u32,
52 },
53 RegistrationFailed { reason: String },
55 ArtifactAvailable {
57 artifact_id: Uuid,
58 download_url: String,
59 sha256: String,
60 },
61 HeartbeatAck,
63 Disconnect { reason: String },
65 Error { message: String },
67}
68
69#[derive(Clone)]
71pub struct ControlPlaneConfig {
72 pub control_plane_url: String,
73 pub project_id: Uuid,
74 pub api_key: String,
75 pub data_plane_name: Option<String>,
76 pub initial_artifact_id: Option<Uuid>,
77}
78
79#[derive(Debug, Clone)]
81pub struct ArtifactNotification {
82 pub artifact_id: Uuid,
83 pub download_url: String,
84 pub sha256: String,
85}
86
87#[derive(Debug, Clone)]
89pub struct ArtifactDownloadedResponse {
90 pub artifact_id: Uuid,
91 pub success: bool,
92 pub error: Option<String>,
93}
94
/// How a single connection attempt ended; drives the reconnect/backoff decision.
enum ConnectOutcome {
    /// Connecting or registering did not succeed; the string describes why.
    ConnectionFailed(String),
    /// A session that had completed registration ended; the string describes why.
    ConnectionLost(String),
    /// Shutdown was requested, so no further connection attempts should be made.
    Shutdown,
}
104
105pub struct ControlPlaneClient {
107 config: ControlPlaneConfig,
108}
109
110impl ControlPlaneClient {
111 pub fn new(config: ControlPlaneConfig) -> Self {
113 Self { config }
114 }
115
116 pub fn start(
119 self,
120 shutdown_rx: watch::Receiver<bool>,
121 ) -> (
122 mpsc::Receiver<ArtifactNotification>,
123 mpsc::Sender<ArtifactDownloadedResponse>,
124 ) {
125 let (artifact_tx, artifact_rx) = mpsc::channel::<ArtifactNotification>(16);
126 let (response_tx, response_rx) = mpsc::channel::<ArtifactDownloadedResponse>(16);
127
128 tokio::spawn(async move {
129 self.connection_loop(shutdown_rx, artifact_tx, response_rx)
130 .await;
131 });
132
133 (artifact_rx, response_tx)
134 }
135
136 async fn connection_loop(
138 &self,
139 mut shutdown_rx: watch::Receiver<bool>,
140 artifact_tx: mpsc::Sender<ArtifactNotification>,
141 mut response_rx: mpsc::Receiver<ArtifactDownloadedResponse>,
142 ) {
143 const INITIAL_BACKOFF_MS: u64 = 1000;
144 const MAX_BACKOFF_MS: u64 = 60000;
145 const BACKOFF_MULTIPLIER: f64 = 2.0;
146
147 let mut backoff_ms = INITIAL_BACKOFF_MS;
148
149 loop {
150 if *shutdown_rx.borrow() {
152 tracing::info!("Control plane client shutting down");
153 return;
154 }
155
156 tracing::info!(url = %self.config.control_plane_url, "Connecting to control plane");
157
158 match self
159 .try_connect(&mut shutdown_rx, &artifact_tx, &mut response_rx)
160 .await
161 {
162 ConnectOutcome::Shutdown => {
163 return;
164 }
165 ConnectOutcome::ConnectionLost(e) => {
166 tracing::warn!(
168 error = %e,
169 "Control plane connection lost, reconnecting immediately"
170 );
171 backoff_ms = INITIAL_BACKOFF_MS;
172 }
173 ConnectOutcome::ConnectionFailed(e) => {
174 tracing::warn!(
175 error = %e,
176 backoff_ms = backoff_ms,
177 "Control plane connection failed, will retry"
178 );
179 }
180 }
181
182 tokio::select! {
184 _ = shutdown_rx.changed() => {
185 if *shutdown_rx.borrow() {
186 return;
187 }
188 }
189 _ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {}
190 }
191
192 backoff_ms =
194 ((backoff_ms as f64) * BACKOFF_MULTIPLIER).min(MAX_BACKOFF_MS as f64) as u64;
195 }
196 }
197
198 async fn try_connect(
200 &self,
201 shutdown_rx: &mut watch::Receiver<bool>,
202 artifact_tx: &mpsc::Sender<ArtifactNotification>,
203 response_rx: &mut mpsc::Receiver<ArtifactDownloadedResponse>,
204 ) -> ConnectOutcome {
205 let (ws_stream, _response) = match connect_async(&self.config.control_plane_url).await {
207 Ok(conn) => conn,
208 Err(e) => {
209 return ConnectOutcome::ConnectionFailed(format!(
210 "WebSocket connection failed: {}",
211 e
212 ))
213 }
214 };
215
216 let (mut sender, mut receiver) = ws_stream.split();
217
218 let register_msg = DataPlaneMessage::Register {
220 project_id: self.config.project_id,
221 api_key: self.config.api_key.clone(),
222 name: self.config.data_plane_name.clone(),
223 artifact_id: self.config.initial_artifact_id,
224 metadata: serde_json::json!({}),
225 };
226
227 let register_json = match serde_json::to_string(®ister_msg) {
228 Ok(j) => j,
229 Err(e) => {
230 return ConnectOutcome::ConnectionFailed(format!(
231 "Failed to serialize register message: {}",
232 e
233 ))
234 }
235 };
236
237 if let Err(e) = sender.send(Message::Text(register_json.into())).await {
238 return ConnectOutcome::ConnectionFailed(format!(
239 "Failed to send register message: {}",
240 e
241 ));
242 }
243
244 let registration_response =
246 match tokio::time::timeout(Duration::from_secs(30), receiver.next()).await {
247 Ok(Some(Ok(msg))) => msg,
248 Ok(Some(Err(e))) => {
249 return ConnectOutcome::ConnectionFailed(format!("WebSocket error: {}", e))
250 }
251 Ok(None) => {
252 return ConnectOutcome::ConnectionFailed(
253 "Connection closed before registration".to_string(),
254 )
255 }
256 Err(_) => {
257 return ConnectOutcome::ConnectionFailed("Registration timeout".to_string())
258 }
259 };
260
261 let heartbeat_interval_secs = match registration_response {
262 Message::Text(text) => {
263 let msg: ControlPlaneMessage = match serde_json::from_str(&text) {
264 Ok(m) => m,
265 Err(e) => {
266 return ConnectOutcome::ConnectionFailed(format!(
267 "Failed to parse registration response: {}",
268 e
269 ))
270 }
271 };
272
273 match msg {
274 ControlPlaneMessage::Registered {
275 data_plane_id,
276 heartbeat_interval_secs,
277 } => {
278 tracing::info!(
279 data_plane_id = %data_plane_id,
280 heartbeat_interval_secs,
281 "Registered with control plane"
282 );
283 heartbeat_interval_secs
284 }
285 ControlPlaneMessage::RegistrationFailed { reason } => {
286 return ConnectOutcome::ConnectionFailed(format!(
287 "Registration failed: {}",
288 reason
289 ));
290 }
291 other => {
292 return ConnectOutcome::ConnectionFailed(format!(
293 "Unexpected registration response: {:?}",
294 other
295 ));
296 }
297 }
298 }
299 other => {
300 return ConnectOutcome::ConnectionFailed(format!(
301 "Unexpected message type: {:?}",
302 other
303 ));
304 }
305 };
306
307 let mut heartbeat_interval =
309 tokio::time::interval(Duration::from_secs(heartbeat_interval_secs as u64));
310 let start_time = std::time::Instant::now();
311
312 loop {
315 tokio::select! {
316 _ = shutdown_rx.changed() => {
318 if *shutdown_rx.borrow() {
319 tracing::info!("Disconnecting from control plane");
320 let _ = sender.close().await;
321 return ConnectOutcome::Shutdown;
322 }
323 }
324
325 _ = heartbeat_interval.tick() => {
327 let heartbeat = DataPlaneMessage::Heartbeat {
328 artifact_id: None, uptime_secs: start_time.elapsed().as_secs(),
330 requests_total: 0, };
332
333 let json = match serde_json::to_string(&heartbeat) {
334 Ok(j) => j,
335 Err(e) => {
336 tracing::error!(error = %e, "Failed to serialize heartbeat");
337 continue;
338 }
339 };
340
341 if let Err(e) = sender.send(Message::Text(json.into())).await {
342 return ConnectOutcome::ConnectionLost(format!(
343 "Failed to send heartbeat: {}", e
344 ));
345 }
346
347 tracing::debug!("Heartbeat sent");
348 }
349
350 Some(response) = response_rx.recv() => {
352 let msg = DataPlaneMessage::ArtifactDownloaded {
353 artifact_id: response.artifact_id,
354 success: response.success,
355 error: response.error,
356 };
357
358 let json = match serde_json::to_string(&msg) {
359 Ok(j) => j,
360 Err(e) => {
361 tracing::error!(error = %e, "Failed to serialize artifact downloaded");
362 continue;
363 }
364 };
365
366 if let Err(e) = sender.send(Message::Text(json.into())).await {
367 tracing::warn!(error = %e, "Failed to send artifact downloaded response");
368 } else {
369 tracing::info!(
370 artifact_id = %response.artifact_id,
371 success = response.success,
372 "Sent artifact downloaded response to control plane"
373 );
374 }
375 }
376
377 result = receiver.next() => {
379 match result {
380 Some(Ok(Message::Text(text))) => {
381 match serde_json::from_str::<ControlPlaneMessage>(&text) {
382 Ok(msg) => {
383 if let Err(e) = self.handle_message(msg, artifact_tx, &mut sender).await {
384 tracing::warn!(error = %e, "Error handling control plane message");
385 }
386 }
387 Err(e) => {
388 tracing::warn!(error = %e, "Failed to parse control plane message");
389 }
390 }
391 }
392 Some(Ok(Message::Ping(data))) => {
393 let _ = sender.send(Message::Pong(data)).await;
394 }
395 Some(Ok(Message::Close(_))) | None => {
396 return ConnectOutcome::ConnectionLost(
397 "Connection closed by control plane".to_string()
398 );
399 }
400 Some(Err(e)) => {
401 return ConnectOutcome::ConnectionLost(format!(
402 "WebSocket error: {}", e
403 ));
404 }
405 _ => {}
406 }
407 }
408 }
409 }
410 }
411
412 async fn handle_message(
414 &self,
415 msg: ControlPlaneMessage,
416 artifact_tx: &mpsc::Sender<ArtifactNotification>,
417 _sender: &mut futures_util::stream::SplitSink<
418 tokio_tungstenite::WebSocketStream<
419 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
420 >,
421 Message,
422 >,
423 ) -> Result<(), String> {
424 match msg {
425 ControlPlaneMessage::HeartbeatAck => {
426 tracing::debug!("Heartbeat acknowledged");
427 }
428 ControlPlaneMessage::ArtifactAvailable {
429 artifact_id,
430 download_url,
431 sha256,
432 } => {
433 tracing::info!(
434 artifact_id = %artifact_id,
435 download_url = %download_url,
436 "New artifact available"
437 );
438
439 if let Err(e) = artifact_tx
441 .send(ArtifactNotification {
442 artifact_id,
443 download_url,
444 sha256,
445 })
446 .await
447 {
448 tracing::warn!(error = %e, "Failed to send artifact notification");
449 }
450 }
451 ControlPlaneMessage::Disconnect { reason } => {
452 tracing::info!(reason = %reason, "Disconnecting at control plane request");
453 return Err(format!("Disconnected by control plane: {}", reason));
454 }
455 ControlPlaneMessage::Error { message } => {
456 tracing::warn!(message = %message, "Error from control plane");
457 }
458 ControlPlaneMessage::Registered { .. }
460 | ControlPlaneMessage::RegistrationFailed { .. } => {
461 tracing::warn!("Unexpected registration message after already registered");
462 }
463 }
464
465 Ok(())
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed UUID shared by the wire-format fixtures below.
    const SAMPLE_ID: &str = "550e8400-e29b-41d4-a716-446655440000";

    fn sample_uuid() -> Uuid {
        Uuid::parse_str(SAMPLE_ID).expect("SAMPLE_ID is a valid UUID literal")
    }

    /// Serializes `msg` exactly as it would go on the wire, then parses it back into a
    /// generic JSON value. Tests can then assert exact field values instead of substrings.
    fn to_json_value(msg: &DataPlaneMessage) -> serde_json::Value {
        let json = serde_json::to_string(msg).expect("DataPlaneMessage serializes");
        serde_json::from_str(&json).expect("serialized output is valid JSON")
    }

    #[test]
    fn test_data_plane_message_register_serialization() {
        let msg = DataPlaneMessage::Register {
            project_id: sample_uuid(),
            api_key: "test-key".to_string(),
            name: Some("my-data-plane".to_string()),
            artifact_id: None,
            metadata: serde_json::json!({"version": "1.0"}),
        };

        let value = to_json_value(&msg);
        assert_eq!(value["type"], "register");
        assert_eq!(value["project_id"], SAMPLE_ID);
        assert_eq!(value["api_key"], "test-key");
        assert_eq!(value["name"], "my-data-plane");
        assert_eq!(value["metadata"]["version"], "1.0");
        assert!(
            value.get("artifact_id").is_none(),
            "a None artifact_id must be omitted"
        );
    }

    #[test]
    fn test_data_plane_message_register_defaults_when_fields_absent() {
        let json = format!(
            r#"{{"type":"register","project_id":"{}","api_key":"k"}}"#,
            SAMPLE_ID
        );

        let msg: DataPlaneMessage = serde_json::from_str(&json).unwrap();
        match msg {
            DataPlaneMessage::Register {
                name,
                artifact_id,
                metadata,
                ..
            } => {
                assert!(name.is_none());
                assert!(artifact_id.is_none());
                assert!(metadata.is_null());
            }
            other => panic!("Expected Register message, got {:?}", other),
        }
    }

    #[test]
    fn test_data_plane_message_heartbeat_serialization() {
        let msg = DataPlaneMessage::Heartbeat {
            artifact_id: Some(sample_uuid()),
            uptime_secs: 3600,
            requests_total: 1000,
        };

        let value = to_json_value(&msg);
        assert_eq!(value["type"], "heartbeat");
        assert_eq!(value["artifact_id"], SAMPLE_ID);
        assert_eq!(value["uptime_secs"], 3600);
        assert_eq!(value["requests_total"], 1000);
    }

    #[test]
    fn test_data_plane_message_heartbeat_omits_missing_artifact_id() {
        let msg = DataPlaneMessage::Heartbeat {
            artifact_id: None,
            uptime_secs: 0,
            requests_total: 0,
        };

        let value = to_json_value(&msg);
        assert_eq!(value["type"], "heartbeat");
        assert!(value.get("artifact_id").is_none());
    }

    #[test]
    fn test_data_plane_message_artifact_downloaded_success() {
        let msg = DataPlaneMessage::ArtifactDownloaded {
            artifact_id: sample_uuid(),
            success: true,
            error: None,
        };

        let value = to_json_value(&msg);
        assert_eq!(value["type"], "artifact_downloaded");
        assert_eq!(value["artifact_id"], SAMPLE_ID);
        assert_eq!(value["success"], true);
        assert!(value.get("error").is_none(), "a None error must be omitted");
    }

    #[test]
    fn test_data_plane_message_artifact_downloaded_failure() {
        let msg = DataPlaneMessage::ArtifactDownloaded {
            artifact_id: sample_uuid(),
            success: false,
            error: Some("checksum mismatch".to_string()),
        };

        let value = to_json_value(&msg);
        assert_eq!(value["type"], "artifact_downloaded");
        assert_eq!(value["success"], false);
        assert_eq!(value["error"], "checksum mismatch");
    }

    #[test]
    fn test_control_plane_message_registered_deserialization() {
        let json = r#"{
            "type": "registered",
            "data_plane_id": "550e8400-e29b-41d4-a716-446655440000",
            "heartbeat_interval_secs": 30
        }"#;

        match serde_json::from_str::<ControlPlaneMessage>(json).unwrap() {
            ControlPlaneMessage::Registered {
                data_plane_id,
                heartbeat_interval_secs,
            } => {
                assert_eq!(data_plane_id, sample_uuid());
                assert_eq!(heartbeat_interval_secs, 30);
            }
            other => panic!("Expected Registered message, got {:?}", other),
        }
    }

    #[test]
    fn test_control_plane_message_registration_failed_deserialization() {
        let json = r#"{"type": "registration_failed", "reason": "invalid api key"}"#;

        match serde_json::from_str::<ControlPlaneMessage>(json).unwrap() {
            ControlPlaneMessage::RegistrationFailed { reason } => {
                assert_eq!(reason, "invalid api key");
            }
            other => panic!("Expected RegistrationFailed message, got {:?}", other),
        }
    }

    #[test]
    fn test_control_plane_message_artifact_available_deserialization() {
        let json = r#"{
            "type": "artifact_available",
            "artifact_id": "550e8400-e29b-41d4-a716-446655440000",
            "download_url": "http://localhost:9090/artifacts/123/download",
            "sha256": "abc123def456"
        }"#;

        match serde_json::from_str::<ControlPlaneMessage>(json).unwrap() {
            ControlPlaneMessage::ArtifactAvailable {
                artifact_id,
                download_url,
                sha256,
            } => {
                assert_eq!(artifact_id, sample_uuid());
                assert_eq!(download_url, "http://localhost:9090/artifacts/123/download");
                assert_eq!(sha256, "abc123def456");
            }
            other => panic!("Expected ArtifactAvailable message, got {:?}", other),
        }
    }

    #[test]
    fn test_control_plane_message_heartbeat_ack_deserialization() {
        let msg: ControlPlaneMessage = serde_json::from_str(r#"{"type": "heartbeat_ack"}"#).unwrap();
        assert!(matches!(msg, ControlPlaneMessage::HeartbeatAck));
    }

    #[test]
    fn test_control_plane_message_disconnect_deserialization() {
        let json = r#"{
            "type": "disconnect",
            "reason": "server shutting down"
        }"#;

        match serde_json::from_str::<ControlPlaneMessage>(json).unwrap() {
            ControlPlaneMessage::Disconnect { reason } => {
                assert_eq!(reason, "server shutting down");
            }
            other => panic!("Expected Disconnect message, got {:?}", other),
        }
    }

    #[test]
    fn test_control_plane_message_unknown_type_is_rejected() {
        let result = serde_json::from_str::<ControlPlaneMessage>(r#"{"type": "bogus"}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_artifact_downloaded_response_creation() {
        let artifact_id = Uuid::new_v4();

        let success_response = ArtifactDownloadedResponse {
            artifact_id,
            success: true,
            error: None,
        };
        assert!(success_response.success);
        assert!(success_response.error.is_none());

        let failure_response = ArtifactDownloadedResponse {
            artifact_id,
            success: false,
            error: Some("download failed".to_string()),
        };
        assert!(!failure_response.success);
        assert_eq!(failure_response.error.as_deref(), Some("download failed"));
    }

    #[test]
    fn test_artifact_notification_creation() {
        let notification = ArtifactNotification {
            artifact_id: Uuid::new_v4(),
            download_url: "http://example.com/artifact.bca".to_string(),
            sha256: "abc123".to_string(),
        };

        assert!(!notification.download_url.is_empty());
        assert!(!notification.sha256.is_empty());
    }

    #[test]
    fn test_control_plane_config_creation() {
        let config = ControlPlaneConfig {
            control_plane_url: "ws://localhost:9090/ws/data-plane".to_string(),
            project_id: Uuid::new_v4(),
            api_key: "test-api-key".to_string(),
            data_plane_name: Some("test-plane".to_string()),
            initial_artifact_id: None,
        };

        assert!(config.control_plane_url.starts_with("ws://"));
        assert_eq!(config.api_key, "test-api-key");
        assert_eq!(config.data_plane_name.as_deref(), Some("test-plane"));
    }
}