1use futures_util::{SinkExt, StreamExt};
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9use tokio::sync::{mpsc, watch};
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11use uuid::Uuid;
12
/// Messages sent by the data plane to the control plane over the WebSocket.
///
/// Serialized as an internally tagged JSON object: the `type` field holds the
/// snake_case variant name (e.g. `{"type":"register",...}`) and the variant's
/// fields sit alongside it.
///
/// Note: the derived `Debug` output of `Register` includes `api_key`; avoid
/// formatting that variant with `{:?}` in logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DataPlaneMessage {
    /// Sent by the client as the first message after the WebSocket connects,
    /// identifying the project and carrying the API key.
    Register {
        /// Project this data plane registers under.
        project_id: Uuid,
        /// API key sent to the control plane with the registration.
        api_key: String,
        /// Optional display name for the data plane; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Optional artifact id; omitted from the JSON when `None`.
        /// NOTE(review): presumably the artifact already loaded at startup — confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        artifact_id: Option<Uuid>,
        /// Free-form metadata; deserializes to JSON `null` when the field is absent.
        #[serde(default)]
        metadata: serde_json::Value,
    },
    /// Periodic liveness report, sent at the interval announced in `Registered`.
    Heartbeat {
        /// Optional artifact id; omitted from the JSON when `None`.
        /// NOTE(review): presumably the artifact currently being served — confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        artifact_id: Option<Uuid>,
        /// Uptime in seconds as reported by the data plane.
        uptime_secs: u64,
        /// Request count reported by the data plane.
        requests_total: u64,
    },
    /// Reports the outcome of downloading an artifact.
    ArtifactDownloaded {
        /// Artifact the report refers to.
        artifact_id: Uuid,
        /// Whether the download succeeded.
        success: bool,
        /// Failure description; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        error: Option<String>,
    },
}
43
/// Messages received by the data plane from the control plane over the WebSocket.
///
/// Deserialized from an internally tagged JSON object whose `type` field holds the
/// snake_case variant name (e.g. `{"type":"heartbeat_ack"}`). A `type` value that
/// matches no variant fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ControlPlaneMessage {
    /// Successful reply to `Register`.
    Registered {
        /// Identifier the control plane assigned to this data plane.
        data_plane_id: Uuid,
        /// Interval, in seconds, at which the client sends `Heartbeat` messages.
        heartbeat_interval_secs: u32,
    },
    /// Rejected reply to `Register`, with a human-readable reason.
    RegistrationFailed { reason: String },
    /// Announces an artifact the data plane can download.
    ArtifactAvailable {
        /// Identifier of the artifact.
        artifact_id: Uuid,
        /// URL the artifact can be fetched from.
        download_url: String,
        /// Expected SHA-256 digest of the artifact.
        /// NOTE(review): encoding (hex?) is not visible here — confirm.
        sha256: String,
    },
    /// Acknowledges a `Heartbeat`.
    HeartbeatAck,
    /// Asks the data plane to disconnect, with a reason.
    Disconnect { reason: String },
    /// Error report from the control plane.
    Error { message: String },
}
68
69#[derive(Clone)]
71pub struct ControlPlaneConfig {
72 pub control_plane_url: String,
73 pub project_id: Uuid,
74 pub api_key: String,
75 pub data_plane_name: Option<String>,
76}
77
78#[derive(Debug, Clone)]
80pub struct ArtifactNotification {
81 pub artifact_id: Uuid,
82 pub download_url: String,
83 pub sha256: String,
84}
85
86#[derive(Debug, Clone)]
88pub struct ArtifactDownloadedResponse {
89 pub artifact_id: Uuid,
90 pub success: bool,
91 pub error: Option<String>,
92}
93
94pub struct ControlPlaneClient {
96 config: ControlPlaneConfig,
97}
98
99impl ControlPlaneClient {
100 pub fn new(config: ControlPlaneConfig) -> Self {
102 Self { config }
103 }
104
105 pub fn start(
108 self,
109 shutdown_rx: watch::Receiver<bool>,
110 ) -> (
111 mpsc::Receiver<ArtifactNotification>,
112 mpsc::Sender<ArtifactDownloadedResponse>,
113 ) {
114 let (artifact_tx, artifact_rx) = mpsc::channel::<ArtifactNotification>(16);
115 let (response_tx, response_rx) = mpsc::channel::<ArtifactDownloadedResponse>(16);
116
117 tokio::spawn(async move {
118 self.connection_loop(shutdown_rx, artifact_tx, response_rx)
119 .await;
120 });
121
122 (artifact_rx, response_tx)
123 }
124
125 async fn connection_loop(
127 &self,
128 mut shutdown_rx: watch::Receiver<bool>,
129 artifact_tx: mpsc::Sender<ArtifactNotification>,
130 mut response_rx: mpsc::Receiver<ArtifactDownloadedResponse>,
131 ) {
132 const INITIAL_BACKOFF_MS: u64 = 1000;
133 const MAX_BACKOFF_MS: u64 = 60000;
134 const BACKOFF_MULTIPLIER: f64 = 2.0;
135
136 let mut backoff_ms = INITIAL_BACKOFF_MS;
137
138 loop {
139 if *shutdown_rx.borrow() {
141 tracing::info!("Control plane client shutting down");
142 return;
143 }
144
145 tracing::info!(url = %self.config.control_plane_url, "Connecting to control plane");
146
147 match self
148 .try_connect(&mut shutdown_rx, &artifact_tx, &mut response_rx)
149 .await
150 {
151 Ok(()) => {
152 return;
154 }
155 Err(e) => {
156 tracing::warn!(
157 error = %e,
158 backoff_ms = backoff_ms,
159 "Control plane connection failed, will retry"
160 );
161 }
162 }
163
164 tokio::select! {
166 _ = shutdown_rx.changed() => {
167 if *shutdown_rx.borrow() {
168 return;
169 }
170 }
171 _ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {}
172 }
173
174 backoff_ms =
176 ((backoff_ms as f64) * BACKOFF_MULTIPLIER).min(MAX_BACKOFF_MS as f64) as u64;
177 }
178 }
179
180 async fn try_connect(
182 &self,
183 shutdown_rx: &mut watch::Receiver<bool>,
184 artifact_tx: &mpsc::Sender<ArtifactNotification>,
185 response_rx: &mut mpsc::Receiver<ArtifactDownloadedResponse>,
186 ) -> Result<(), String> {
187 let (ws_stream, _response) = connect_async(&self.config.control_plane_url)
189 .await
190 .map_err(|e| format!("WebSocket connection failed: {}", e))?;
191
192 let (mut sender, mut receiver) = ws_stream.split();
193
194 let register_msg = DataPlaneMessage::Register {
196 project_id: self.config.project_id,
197 api_key: self.config.api_key.clone(),
198 name: self.config.data_plane_name.clone(),
199 artifact_id: None, metadata: serde_json::json!({}),
201 };
202
203 let register_json = serde_json::to_string(®ister_msg)
204 .map_err(|e| format!("Failed to serialize register message: {}", e))?;
205
206 sender
207 .send(Message::Text(register_json.into()))
208 .await
209 .map_err(|e| format!("Failed to send register message: {}", e))?;
210
211 let registration_response = tokio::time::timeout(Duration::from_secs(30), receiver.next())
213 .await
214 .map_err(|_| "Registration timeout")?
215 .ok_or("Connection closed before registration")?
216 .map_err(|e| format!("WebSocket error: {}", e))?;
217
218 let heartbeat_interval_secs = match registration_response {
219 Message::Text(text) => {
220 let msg: ControlPlaneMessage = serde_json::from_str(&text)
221 .map_err(|e| format!("Failed to parse registration response: {}", e))?;
222
223 match msg {
224 ControlPlaneMessage::Registered {
225 data_plane_id,
226 heartbeat_interval_secs,
227 } => {
228 tracing::info!(
229 data_plane_id = %data_plane_id,
230 heartbeat_interval_secs,
231 "Registered with control plane"
232 );
233 heartbeat_interval_secs
234 }
235 ControlPlaneMessage::RegistrationFailed { reason } => {
236 return Err(format!("Registration failed: {}", reason));
237 }
238 other => {
239 return Err(format!("Unexpected registration response: {:?}", other));
240 }
241 }
242 }
243 other => {
244 return Err(format!("Unexpected message type: {:?}", other));
245 }
246 };
247
248 let mut heartbeat_interval =
250 tokio::time::interval(Duration::from_secs(heartbeat_interval_secs as u64));
251 let start_time = std::time::Instant::now();
252
253 loop {
255 tokio::select! {
256 _ = shutdown_rx.changed() => {
258 if *shutdown_rx.borrow() {
259 tracing::info!("Disconnecting from control plane");
260 let _ = sender.close().await;
261 return Ok(());
262 }
263 }
264
265 _ = heartbeat_interval.tick() => {
267 let heartbeat = DataPlaneMessage::Heartbeat {
268 artifact_id: None, uptime_secs: start_time.elapsed().as_secs(),
270 requests_total: 0, };
272
273 let json = serde_json::to_string(&heartbeat)
274 .map_err(|e| format!("Failed to serialize heartbeat: {}", e))?;
275
276 if let Err(e) = sender.send(Message::Text(json.into())).await {
277 return Err(format!("Failed to send heartbeat: {}", e));
278 }
279
280 tracing::debug!("Heartbeat sent");
281 }
282
283 Some(response) = response_rx.recv() => {
285 let msg = DataPlaneMessage::ArtifactDownloaded {
286 artifact_id: response.artifact_id,
287 success: response.success,
288 error: response.error,
289 };
290
291 let json = serde_json::to_string(&msg)
292 .map_err(|e| format!("Failed to serialize artifact downloaded: {}", e))?;
293
294 if let Err(e) = sender.send(Message::Text(json.into())).await {
295 tracing::warn!(error = %e, "Failed to send artifact downloaded response");
296 } else {
297 tracing::info!(
298 artifact_id = %response.artifact_id,
299 success = response.success,
300 "Sent artifact downloaded response to control plane"
301 );
302 }
303 }
304
305 result = receiver.next() => {
307 match result {
308 Some(Ok(Message::Text(text))) => {
309 match serde_json::from_str::<ControlPlaneMessage>(&text) {
310 Ok(msg) => {
311 if let Err(e) = self.handle_message(msg, artifact_tx, &mut sender).await {
312 tracing::warn!(error = %e, "Error handling control plane message");
313 }
314 }
315 Err(e) => {
316 tracing::warn!(error = %e, "Failed to parse control plane message");
317 }
318 }
319 }
320 Some(Ok(Message::Ping(data))) => {
321 let _ = sender.send(Message::Pong(data)).await;
322 }
323 Some(Ok(Message::Close(_))) | None => {
324 return Err("Connection closed by control plane".to_string());
325 }
326 Some(Err(e)) => {
327 return Err(format!("WebSocket error: {}", e));
328 }
329 _ => {}
330 }
331 }
332 }
333 }
334 }
335
336 async fn handle_message(
338 &self,
339 msg: ControlPlaneMessage,
340 artifact_tx: &mpsc::Sender<ArtifactNotification>,
341 _sender: &mut futures_util::stream::SplitSink<
342 tokio_tungstenite::WebSocketStream<
343 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
344 >,
345 Message,
346 >,
347 ) -> Result<(), String> {
348 match msg {
349 ControlPlaneMessage::HeartbeatAck => {
350 tracing::debug!("Heartbeat acknowledged");
351 }
352 ControlPlaneMessage::ArtifactAvailable {
353 artifact_id,
354 download_url,
355 sha256,
356 } => {
357 tracing::info!(
358 artifact_id = %artifact_id,
359 download_url = %download_url,
360 "New artifact available"
361 );
362
363 if let Err(e) = artifact_tx
365 .send(ArtifactNotification {
366 artifact_id,
367 download_url,
368 sha256,
369 })
370 .await
371 {
372 tracing::warn!(error = %e, "Failed to send artifact notification");
373 }
374 }
375 ControlPlaneMessage::Disconnect { reason } => {
376 tracing::info!(reason = %reason, "Disconnecting at control plane request");
377 return Err(format!("Disconnected by control plane: {}", reason));
378 }
379 ControlPlaneMessage::Error { message } => {
380 tracing::warn!(message = %message, "Error from control plane");
381 }
382 ControlPlaneMessage::Registered { .. }
384 | ControlPlaneMessage::RegistrationFailed { .. } => {
385 tracing::warn!("Unexpected registration message after already registered");
386 }
387 }
388
389 Ok(())
390 }
391}
392
#[cfg(test)]
mod tests {
    //! Wire-format tests for the data-plane / control-plane protocol messages.

    use super::*;

    #[test]
    fn test_data_plane_message_register_serialization() {
        let msg = DataPlaneMessage::Register {
            project_id: Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap(),
            api_key: "test-key".to_string(),
            name: Some("my-data-plane".to_string()),
            artifact_id: None,
            metadata: serde_json::json!({"version": "1.0"}),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"register\""));
        assert!(json.contains("\"project_id\":"));
        assert!(json.contains("\"api_key\":\"test-key\""));
        assert!(json.contains("\"name\":\"my-data-plane\""));
    }

    #[test]
    fn test_data_plane_message_register_omits_unset_optional_fields() {
        let msg = DataPlaneMessage::Register {
            project_id: Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap(),
            api_key: "test-key".to_string(),
            name: None,
            artifact_id: None,
            metadata: serde_json::json!({}),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(!json.contains("\"name\":"));
        assert!(!json.contains("\"artifact_id\":"));
        assert!(json.contains("\"metadata\":{}"));
    }

    #[test]
    fn test_data_plane_message_register_metadata_defaults_to_null() {
        // `metadata` is `#[serde(default)]` and the optionals may be absent entirely.
        let json = r#"{
            "type": "register",
            "project_id": "550e8400-e29b-41d4-a716-446655440000",
            "api_key": "test-key"
        }"#;

        let msg: DataPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            DataPlaneMessage::Register {
                name,
                artifact_id,
                metadata,
                ..
            } => {
                assert!(name.is_none());
                assert!(artifact_id.is_none());
                assert!(metadata.is_null());
            }
            _ => panic!("Expected Register message"),
        }
    }

    #[test]
    fn test_data_plane_message_heartbeat_serialization() {
        let msg = DataPlaneMessage::Heartbeat {
            artifact_id: Some(Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap()),
            uptime_secs: 3600,
            requests_total: 1000,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"heartbeat\""));
        assert!(json.contains("\"uptime_secs\":3600"));
        assert!(json.contains("\"requests_total\":1000"));
    }

    #[test]
    fn test_data_plane_message_artifact_downloaded_success() {
        let artifact_id = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        let msg = DataPlaneMessage::ArtifactDownloaded {
            artifact_id,
            success: true,
            error: None,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"artifact_downloaded\""));
        assert!(json.contains("\"success\":true"));
        // `error: None` must be skipped entirely, not serialized as null.
        assert!(!json.contains("\"error\":"));
    }

    #[test]
    fn test_data_plane_message_artifact_downloaded_failure() {
        let artifact_id = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        let msg = DataPlaneMessage::ArtifactDownloaded {
            artifact_id,
            success: false,
            error: Some("checksum mismatch".to_string()),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"artifact_downloaded\""));
        assert!(json.contains("\"success\":false"));
        assert!(json.contains("\"error\":\"checksum mismatch\""));
    }

    #[test]
    fn test_control_plane_message_registered_deserialization() {
        let json = r#"{
            "type": "registered",
            "data_plane_id": "550e8400-e29b-41d4-a716-446655440000",
            "heartbeat_interval_secs": 30
        }"#;

        let msg: ControlPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            ControlPlaneMessage::Registered {
                data_plane_id,
                heartbeat_interval_secs,
            } => {
                assert_eq!(
                    data_plane_id,
                    Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap()
                );
                assert_eq!(heartbeat_interval_secs, 30);
            }
            _ => panic!("Expected Registered message"),
        }
    }

    #[test]
    fn test_control_plane_message_registration_failed_deserialization() {
        let json = r#"{"type": "registration_failed", "reason": "invalid api key"}"#;

        let msg: ControlPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            ControlPlaneMessage::RegistrationFailed { reason } => {
                assert_eq!(reason, "invalid api key");
            }
            _ => panic!("Expected RegistrationFailed message"),
        }
    }

    #[test]
    fn test_control_plane_message_heartbeat_ack_deserialization() {
        let msg: ControlPlaneMessage =
            serde_json::from_str(r#"{"type": "heartbeat_ack"}"#).unwrap();
        assert!(matches!(msg, ControlPlaneMessage::HeartbeatAck));
    }

    #[test]
    fn test_control_plane_message_artifact_available_deserialization() {
        let json = r#"{
            "type": "artifact_available",
            "artifact_id": "550e8400-e29b-41d4-a716-446655440000",
            "download_url": "http://localhost:9090/artifacts/123/download",
            "sha256": "abc123def456"
        }"#;

        let msg: ControlPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            ControlPlaneMessage::ArtifactAvailable {
                artifact_id,
                download_url,
                sha256,
            } => {
                assert_eq!(
                    artifact_id,
                    Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap()
                );
                assert_eq!(download_url, "http://localhost:9090/artifacts/123/download");
                assert_eq!(sha256, "abc123def456");
            }
            _ => panic!("Expected ArtifactAvailable message"),
        }
    }

    #[test]
    fn test_control_plane_message_disconnect_deserialization() {
        let json = r#"{
            "type": "disconnect",
            "reason": "server shutting down"
        }"#;

        let msg: ControlPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            ControlPlaneMessage::Disconnect { reason } => {
                assert_eq!(reason, "server shutting down");
            }
            _ => panic!("Expected Disconnect message"),
        }
    }

    #[test]
    fn test_control_plane_message_error_deserialization() {
        let json = r#"{"type": "error", "message": "quota exceeded"}"#;

        let msg: ControlPlaneMessage = serde_json::from_str(json).unwrap();
        match msg {
            ControlPlaneMessage::Error { message } => {
                assert_eq!(message, "quota exceeded");
            }
            _ => panic!("Expected Error message"),
        }
    }

    #[test]
    fn test_control_plane_message_unknown_type_is_rejected() {
        let result =
            serde_json::from_str::<ControlPlaneMessage>(r#"{"type": "not_a_real_message"}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_artifact_downloaded_response_creation() {
        let artifact_id = Uuid::new_v4();

        let success_response = ArtifactDownloadedResponse {
            artifact_id,
            success: true,
            error: None,
        };
        assert!(success_response.success);
        assert!(success_response.error.is_none());

        let failure_response = ArtifactDownloadedResponse {
            artifact_id,
            success: false,
            error: Some("download failed".to_string()),
        };
        assert!(!failure_response.success);
        assert_eq!(failure_response.error.as_deref(), Some("download failed"));
    }

    #[test]
    fn test_artifact_notification_creation() {
        let notification = ArtifactNotification {
            artifact_id: Uuid::new_v4(),
            download_url: "http://example.com/artifact.bca".to_string(),
            sha256: "abc123".to_string(),
        };

        assert!(!notification.download_url.is_empty());
        assert!(!notification.sha256.is_empty());
    }

    #[test]
    fn test_control_plane_config_creation() {
        let config = ControlPlaneConfig {
            control_plane_url: "ws://localhost:9090/ws/data-plane".to_string(),
            project_id: Uuid::new_v4(),
            api_key: "test-api-key".to_string(),
            data_plane_name: Some("test-plane".to_string()),
        };

        assert!(config.control_plane_url.starts_with("ws://"));
        assert_eq!(config.api_key, "test-api-key");
        assert_eq!(config.data_plane_name.as_deref(), Some("test-plane"));
    }
}