Skip to main content

mvm_core/
protocol.rs

1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3
4use crate::tenant::TenantNet;
5
/// Filesystem location of the Unix domain socket used for agentd <-> hostd IPC
/// when no other path is configured.
pub const HOSTD_SOCKET_PATH: &str = "/run/mvm/hostd.sock";
8
/// Upper bound, in bytes, on a single hostd IPC frame payload: 2^20 bytes (1 MiB).
const MAX_FRAME_SIZE: usize = 1 << 20;
11
12// ============================================================================
13// Request/Response types
14// ============================================================================
15
16/// Request from agentd to hostd (privileged executor).
17///
18/// Each variant maps to exactly one privileged operation. The agentd
19/// (unprivileged) decides WHAT to do; hostd (privileged) decides HOW.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub enum HostdRequest {
22    /// Start an existing instance (TAP, cgroup, jailer, FC launch).
23    StartInstance {
24        tenant_id: String,
25        pool_id: String,
26        instance_id: String,
27    },
28    /// Stop a running instance (kill FC, teardown cgroup, TAP).
29    StopInstance {
30        tenant_id: String,
31        pool_id: String,
32        instance_id: String,
33    },
34    /// Snapshot and suspend an instance.
35    SleepInstance {
36        tenant_id: String,
37        pool_id: String,
38        instance_id: String,
39        force: bool,
40        #[serde(default)]
41        drain_timeout_secs: Option<u64>,
42    },
43    /// Restore an instance from snapshot.
44    WakeInstance {
45        tenant_id: String,
46        pool_id: String,
47        instance_id: String,
48    },
49    /// Destroy an instance and optionally wipe volumes.
50    DestroyInstance {
51        tenant_id: String,
52        pool_id: String,
53        instance_id: String,
54        wipe_volumes: bool,
55    },
56    /// Create per-tenant bridge and NAT rules.
57    SetupNetwork { tenant_id: String, net: TenantNet },
58    /// Tear down per-tenant bridge and NAT rules.
59    TeardownNetwork { tenant_id: String, net: TenantNet },
60    /// Health check.
61    Ping,
62}
63
64/// Response from hostd to agentd.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub enum HostdResponse {
67    /// Operation succeeded.
68    Ok,
69    /// Error with description.
70    Error { message: String },
71    /// Pong response to Ping.
72    Pong,
73}
74
75// ============================================================================
76// Frame protocol (length-prefixed JSON over Unix socket)
77// ============================================================================
78
79/// Read a length-prefixed JSON frame from a tokio AsyncRead.
80pub async fn read_frame<R: tokio::io::AsyncReadExt + Unpin>(reader: &mut R) -> Result<Vec<u8>> {
81    let mut len_buf = [0u8; 4];
82    reader
83        .read_exact(&mut len_buf)
84        .await
85        .with_context(|| "Failed to read frame length")?;
86    let len = u32::from_be_bytes(len_buf) as usize;
87
88    if len > MAX_FRAME_SIZE {
89        anyhow::bail!("Frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
90    }
91
92    let mut buf = vec![0u8; len];
93    reader
94        .read_exact(&mut buf)
95        .await
96        .with_context(|| "Failed to read frame body")?;
97
98    Ok(buf)
99}
100
101/// Write a length-prefixed JSON frame to a tokio AsyncWrite.
102pub async fn write_frame<W: tokio::io::AsyncWriteExt + Unpin>(
103    writer: &mut W,
104    data: &[u8],
105) -> Result<()> {
106    let len = (data.len() as u32).to_be_bytes();
107    writer
108        .write_all(&len)
109        .await
110        .with_context(|| "Failed to write frame length")?;
111    writer
112        .write_all(data)
113        .await
114        .with_context(|| "Failed to write frame body")?;
115    writer
116        .flush()
117        .await
118        .with_context(|| "Failed to flush frame")?;
119    Ok(())
120}
121
122/// Serialize and send a request.
123pub async fn send_request<W: tokio::io::AsyncWriteExt + Unpin>(
124    writer: &mut W,
125    req: &HostdRequest,
126) -> Result<()> {
127    let data = serde_json::to_vec(req).with_context(|| "Failed to serialize request")?;
128    write_frame(writer, &data).await
129}
130
131/// Read and deserialize a request.
132pub async fn recv_request<R: tokio::io::AsyncReadExt + Unpin>(
133    reader: &mut R,
134) -> Result<HostdRequest> {
135    let data = read_frame(reader).await?;
136    serde_json::from_slice(&data).with_context(|| "Failed to deserialize request")
137}
138
139/// Serialize and send a response.
140pub async fn send_response<W: tokio::io::AsyncWriteExt + Unpin>(
141    writer: &mut W,
142    resp: &HostdResponse,
143) -> Result<()> {
144    let data = serde_json::to_vec(resp).with_context(|| "Failed to serialize response")?;
145    write_frame(writer, &data).await
146}
147
148/// Read and deserialize a response.
149pub async fn recv_response<R: tokio::io::AsyncReadExt + Unpin>(
150    reader: &mut R,
151) -> Result<HostdResponse> {
152    let data = read_frame(reader).await?;
153    serde_json::from_slice(&data).with_context(|| "Failed to deserialize response")
154}
155
156// ============================================================================
157// Tests
158// ============================================================================
159
#[cfg(test)]
mod tests {
    //! Unit tests for the hostd IPC message types and the length-prefixed
    //! frame codec. This includes the error paths: oversized, truncated and
    //! malformed frames.

    use super::*;
    use crate::tenant::TenantNet;

    #[test]
    fn test_hostd_request_start_roundtrip() {
        let req = HostdRequest::StartInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        match parsed {
            HostdRequest::StartInstance {
                tenant_id,
                pool_id,
                instance_id,
            } => {
                assert_eq!(tenant_id, "acme");
                assert_eq!(pool_id, "workers");
                assert_eq!(instance_id, "i-abc123");
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_hostd_request_stop_roundtrip() {
        let req = HostdRequest::StopInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdRequest::StopInstance { .. }));
    }

    #[test]
    fn test_hostd_request_sleep_roundtrip() {
        let req = HostdRequest::SleepInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
            force: true,
            drain_timeout_secs: Some(30),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        match parsed {
            HostdRequest::SleepInstance {
                force,
                drain_timeout_secs,
                ..
            } => {
                assert!(force);
                assert_eq!(drain_timeout_secs, Some(30));
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_sleep_request_missing_drain_timeout_defaults_to_none() {
        // Older senders may omit `drain_timeout_secs`; `#[serde(default)]`
        // must make that decode as `None` rather than fail.
        let json = r#"{"SleepInstance":{"tenant_id":"acme","pool_id":"workers","instance_id":"i-abc123","force":false}}"#;
        let parsed: HostdRequest = serde_json::from_str(json).unwrap();
        match parsed {
            HostdRequest::SleepInstance {
                force,
                drain_timeout_secs,
                ..
            } => {
                assert!(!force);
                assert_eq!(drain_timeout_secs, None);
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_hostd_request_wake_roundtrip() {
        let req = HostdRequest::WakeInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdRequest::WakeInstance { .. }));
    }

    #[test]
    fn test_hostd_request_destroy_roundtrip() {
        let req = HostdRequest::DestroyInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
            wipe_volumes: true,
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        match parsed {
            HostdRequest::DestroyInstance { wipe_volumes, .. } => assert!(wipe_volumes),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_hostd_request_setup_network_roundtrip() {
        let net = TenantNet::new(3, "10.240.3.0/24", "10.240.3.1");
        let req = HostdRequest::SetupNetwork {
            tenant_id: "acme".to_string(),
            net: net.clone(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        match parsed {
            HostdRequest::SetupNetwork { tenant_id, net: n } => {
                assert_eq!(tenant_id, "acme");
                assert_eq!(n.tenant_net_id, 3);
                assert_eq!(n.ipv4_subnet, "10.240.3.0/24");
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_hostd_request_teardown_network_roundtrip() {
        let net = TenantNet::new(3, "10.240.3.0/24", "10.240.3.1");
        let req = HostdRequest::TeardownNetwork {
            tenant_id: "acme".to_string(),
            net,
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdRequest::TeardownNetwork { .. }));
    }

    #[test]
    fn test_hostd_request_ping_roundtrip() {
        let req = HostdRequest::Ping;
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdRequest::Ping));
    }

    #[test]
    fn test_hostd_response_ok_roundtrip() {
        let resp = HostdResponse::Ok;
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: HostdResponse = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdResponse::Ok));
    }

    #[test]
    fn test_hostd_response_error_roundtrip() {
        let resp = HostdResponse::Error {
            message: "instance not found".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: HostdResponse = serde_json::from_str(&json).unwrap();
        match parsed {
            HostdResponse::Error { message } => assert_eq!(message, "instance not found"),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_hostd_response_pong_roundtrip() {
        let resp = HostdResponse::Pong;
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: HostdResponse = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdResponse::Pong));
    }

    #[test]
    fn test_all_request_variants_serialize() {
        let net = TenantNet::new(1, "10.240.1.0/24", "10.240.1.1");
        let variants: Vec<HostdRequest> = vec![
            HostdRequest::StartInstance {
                tenant_id: "t".to_string(),
                pool_id: "p".to_string(),
                instance_id: "i".to_string(),
            },
            HostdRequest::StopInstance {
                tenant_id: "t".to_string(),
                pool_id: "p".to_string(),
                instance_id: "i".to_string(),
            },
            HostdRequest::SleepInstance {
                tenant_id: "t".to_string(),
                pool_id: "p".to_string(),
                instance_id: "i".to_string(),
                force: false,
                drain_timeout_secs: None,
            },
            HostdRequest::WakeInstance {
                tenant_id: "t".to_string(),
                pool_id: "p".to_string(),
                instance_id: "i".to_string(),
            },
            HostdRequest::DestroyInstance {
                tenant_id: "t".to_string(),
                pool_id: "p".to_string(),
                instance_id: "i".to_string(),
                wipe_volumes: false,
            },
            HostdRequest::SetupNetwork {
                tenant_id: "t".to_string(),
                net: net.clone(),
            },
            HostdRequest::TeardownNetwork {
                tenant_id: "t".to_string(),
                net,
            },
            HostdRequest::Ping,
        ];

        for req in &variants {
            let json = serde_json::to_string(req).unwrap();
            let _: HostdRequest = serde_json::from_str(&json).unwrap();
        }
    }

    #[test]
    fn test_all_response_variants_serialize() {
        let variants: Vec<HostdResponse> = vec![
            HostdResponse::Ok,
            HostdResponse::Error {
                message: "err".to_string(),
            },
            HostdResponse::Pong,
        ];

        for resp in &variants {
            let json = serde_json::to_string(resp).unwrap();
            let _: HostdResponse = serde_json::from_str(&json).unwrap();
        }
    }

    #[test]
    fn test_socket_path_constant() {
        assert_eq!(HOSTD_SOCKET_PATH, "/run/mvm/hostd.sock");
    }

    #[tokio::test]
    async fn test_frame_roundtrip() {
        let data = b"hello hostd";
        let mut buf = Vec::new();
        write_frame(&mut buf, data).await.unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        let read_back = read_frame(&mut cursor).await.unwrap();
        assert_eq!(read_back, data);
    }

    #[tokio::test]
    async fn test_frame_empty_payload_roundtrip() {
        let mut buf = Vec::new();
        write_frame(&mut buf, b"").await.unwrap();
        // Only the zero length prefix is written.
        assert_eq!(buf, [0u8; 4]);

        let mut cursor = std::io::Cursor::new(buf);
        let read_back = read_frame(&mut cursor).await.unwrap();
        assert!(read_back.is_empty());
    }

    #[tokio::test]
    async fn test_frame_length_prefix_is_big_endian() {
        let mut buf = Vec::new();
        write_frame(&mut buf, b"abc").await.unwrap();
        assert_eq!(buf, [0, 0, 0, 3, b'a', b'b', b'c']);
    }

    #[tokio::test]
    async fn test_read_frame_rejects_oversized_length() {
        // A header declaring MAX_FRAME_SIZE + 1 bytes must be refused before
        // any body is read or allocated.
        let header = (MAX_FRAME_SIZE as u32 + 1).to_be_bytes();
        let mut cursor = std::io::Cursor::new(header.to_vec());
        let err = read_frame(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("Frame too large"));
    }

    #[tokio::test]
    async fn test_read_frame_rejects_truncated_header() {
        let mut cursor = std::io::Cursor::new(vec![0u8, 0]);
        assert!(read_frame(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn test_read_frame_rejects_truncated_body() {
        let mut buf = 10u32.to_be_bytes().to_vec();
        buf.extend_from_slice(b"short");
        let mut cursor = std::io::Cursor::new(buf);
        assert!(read_frame(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn test_recv_request_rejects_invalid_json() {
        let mut buf = Vec::new();
        write_frame(&mut buf, b"not json").await.unwrap();
        let mut cursor = std::io::Cursor::new(buf);
        assert!(recv_request(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn test_request_send_recv_roundtrip() {
        let req = HostdRequest::Ping;
        let mut buf = Vec::new();
        send_request(&mut buf, &req).await.unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        let parsed = recv_request(&mut cursor).await.unwrap();
        assert!(matches!(parsed, HostdRequest::Ping));
    }

    #[tokio::test]
    async fn test_response_send_recv_roundtrip() {
        let resp = HostdResponse::Ok;
        let mut buf = Vec::new();
        send_response(&mut buf, &resp).await.unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        let parsed = recv_response(&mut cursor).await.unwrap();
        assert!(matches!(parsed, HostdResponse::Ok));
    }
}