1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3
4use crate::tenant::TenantNet;
5
/// Filesystem path of the hostd socket.
pub const HOSTD_SOCKET_PATH: &str = "/run/mvm/hostd.sock";

/// Largest payload, in bytes, accepted in a single frame (1 MiB).
///
/// `read_frame` refuses any frame whose length prefix is larger than this.
const MAX_FRAME_SIZE: usize = 1 << 20;
11
/// Request messages for hostd.
///
/// A request is encoded as JSON by [`send_request`] and decoded by
/// [`recv_request`]. Serde derives the JSON from the variant and field names
/// below, so renaming either is a wire-format change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HostdRequest {
    /// Start the instance addressed by `tenant_id` / `pool_id` / `instance_id`.
    StartInstance {
        tenant_id: String,
        pool_id: String,
        instance_id: String,
    },
    /// Stop the instance addressed by `tenant_id` / `pool_id` / `instance_id`.
    StopInstance {
        tenant_id: String,
        pool_id: String,
        instance_id: String,
    },
    /// Sleep the instance addressed by `tenant_id` / `pool_id` / `instance_id`.
    SleepInstance {
        tenant_id: String,
        pool_id: String,
        instance_id: String,
        /// NOTE(review): presumably skips any graceful drain — confirm against
        /// the hostd handler, which is not in this file.
        force: bool,
        /// Optional drain timeout (seconds, per the field name). With
        /// `#[serde(default)]` the field may be omitted from the JSON, in which
        /// case it decodes as `None`.
        #[serde(default)]
        drain_timeout_secs: Option<u64>,
    },
    /// Wake the instance addressed by `tenant_id` / `pool_id` / `instance_id`.
    WakeInstance {
        tenant_id: String,
        pool_id: String,
        instance_id: String,
    },
    /// Destroy the instance addressed by `tenant_id` / `pool_id` / `instance_id`.
    DestroyInstance {
        tenant_id: String,
        pool_id: String,
        instance_id: String,
        /// NOTE(review): presumably also removes the instance's volumes when
        /// `true` — confirm against the hostd handler.
        wipe_volumes: bool,
    },
    /// Set up the network described by `net` for `tenant_id`.
    SetupNetwork { tenant_id: String, net: TenantNet },
    /// Tear down the network described by `net` for `tenant_id`.
    TeardownNetwork { tenant_id: String, net: TenantNet },
    /// Liveness probe carrying no data; presumably answered with
    /// `HostdResponse::Pong` (confirm in the hostd handler).
    Ping,
}
63
/// Reply messages from hostd.
///
/// A reply is encoded as JSON by [`send_response`] and decoded by
/// [`recv_response`]. Serde derives the JSON from the variant and field names
/// below, so renaming either is a wire-format change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HostdResponse {
    /// Success acknowledgement with no payload.
    Ok,
    /// Failure, with a human-readable `message` describing it.
    Error { message: String },
    /// Reply carrying no data; presumably the answer to `HostdRequest::Ping`
    /// (confirm in the hostd handler).
    Pong,
}
74
75pub async fn read_frame<R: tokio::io::AsyncReadExt + Unpin>(reader: &mut R) -> Result<Vec<u8>> {
81 let mut len_buf = [0u8; 4];
82 reader
83 .read_exact(&mut len_buf)
84 .await
85 .with_context(|| "Failed to read frame length")?;
86 let len = u32::from_be_bytes(len_buf) as usize;
87
88 if len > MAX_FRAME_SIZE {
89 anyhow::bail!("Frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
90 }
91
92 let mut buf = vec![0u8; len];
93 reader
94 .read_exact(&mut buf)
95 .await
96 .with_context(|| "Failed to read frame body")?;
97
98 Ok(buf)
99}
100
101pub async fn write_frame<W: tokio::io::AsyncWriteExt + Unpin>(
103 writer: &mut W,
104 data: &[u8],
105) -> Result<()> {
106 let len = (data.len() as u32).to_be_bytes();
107 writer
108 .write_all(&len)
109 .await
110 .with_context(|| "Failed to write frame length")?;
111 writer
112 .write_all(data)
113 .await
114 .with_context(|| "Failed to write frame body")?;
115 writer
116 .flush()
117 .await
118 .with_context(|| "Failed to flush frame")?;
119 Ok(())
120}
121
122pub async fn send_request<W: tokio::io::AsyncWriteExt + Unpin>(
124 writer: &mut W,
125 req: &HostdRequest,
126) -> Result<()> {
127 let data = serde_json::to_vec(req).with_context(|| "Failed to serialize request")?;
128 write_frame(writer, &data).await
129}
130
131pub async fn recv_request<R: tokio::io::AsyncReadExt + Unpin>(
133 reader: &mut R,
134) -> Result<HostdRequest> {
135 let data = read_frame(reader).await?;
136 serde_json::from_slice(&data).with_context(|| "Failed to deserialize request")
137}
138
139pub async fn send_response<W: tokio::io::AsyncWriteExt + Unpin>(
141 writer: &mut W,
142 resp: &HostdResponse,
143) -> Result<()> {
144 let data = serde_json::to_vec(resp).with_context(|| "Failed to serialize response")?;
145 write_frame(writer, &data).await
146}
147
148pub async fn recv_response<R: tokio::io::AsyncReadExt + Unpin>(
150 reader: &mut R,
151) -> Result<HostdResponse> {
152 let data = read_frame(reader).await?;
153 serde_json::from_slice(&data).with_context(|| "Failed to deserialize response")
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::TenantNet;

    #[test]
    fn test_hostd_request_start_roundtrip() {
        let req = HostdRequest::StartInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        match parsed {
            HostdRequest::StartInstance {
                tenant_id,
                pool_id,
                instance_id,
            } => {
                assert_eq!(tenant_id, "acme");
                assert_eq!(pool_id, "workers");
                assert_eq!(instance_id, "i-abc123");
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_hostd_request_stop_roundtrip() {
        let req = HostdRequest::StopInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdRequest::StopInstance { .. }));
    }

    #[test]
    fn test_hostd_request_sleep_roundtrip() {
        let req = HostdRequest::SleepInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
            force: true,
            drain_timeout_secs: Some(30),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        match parsed {
            HostdRequest::SleepInstance {
                force,
                drain_timeout_secs,
                ..
            } => {
                assert!(force);
                assert_eq!(drain_timeout_secs, Some(30));
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_hostd_request_wake_roundtrip() {
        let req = HostdRequest::WakeInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdRequest::WakeInstance { .. }));
    }

    #[test]
    fn test_hostd_request_destroy_roundtrip() {
        let req = HostdRequest::DestroyInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
            wipe_volumes: true,
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        match parsed {
            HostdRequest::DestroyInstance { wipe_volumes, .. } => assert!(wipe_volumes),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_hostd_request_setup_network_roundtrip() {
        let net = TenantNet::new(3, "10.240.3.0/24", "10.240.3.1");
        let req = HostdRequest::SetupNetwork {
            tenant_id: "acme".to_string(),
            net: net.clone(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        match parsed {
            HostdRequest::SetupNetwork { tenant_id, net: n } => {
                assert_eq!(tenant_id, "acme");
                assert_eq!(n.tenant_net_id, 3);
                assert_eq!(n.ipv4_subnet, "10.240.3.0/24");
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_hostd_request_teardown_network_roundtrip() {
        let net = TenantNet::new(3, "10.240.3.0/24", "10.240.3.1");
        let req = HostdRequest::TeardownNetwork {
            tenant_id: "acme".to_string(),
            net,
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdRequest::TeardownNetwork { .. }));
    }

    #[test]
    fn test_hostd_request_ping_roundtrip() {
        let req = HostdRequest::Ping;
        let json = serde_json::to_string(&req).unwrap();
        let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdRequest::Ping));
    }

    #[test]
    fn test_hostd_response_ok_roundtrip() {
        let resp = HostdResponse::Ok;
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: HostdResponse = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdResponse::Ok));
    }

    #[test]
    fn test_hostd_response_error_roundtrip() {
        let resp = HostdResponse::Error {
            message: "instance not found".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: HostdResponse = serde_json::from_str(&json).unwrap();
        match parsed {
            HostdResponse::Error { message } => assert_eq!(message, "instance not found"),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_hostd_response_pong_roundtrip() {
        let resp = HostdResponse::Pong;
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: HostdResponse = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, HostdResponse::Pong));
    }

    #[test]
    fn test_all_request_variants_serialize() {
        let net = TenantNet::new(1, "10.240.1.0/24", "10.240.1.1");
        let variants: Vec<HostdRequest> = vec![
            HostdRequest::StartInstance {
                tenant_id: "t".to_string(),
                pool_id: "p".to_string(),
                instance_id: "i".to_string(),
            },
            HostdRequest::StopInstance {
                tenant_id: "t".to_string(),
                pool_id: "p".to_string(),
                instance_id: "i".to_string(),
            },
            HostdRequest::SleepInstance {
                tenant_id: "t".to_string(),
                pool_id: "p".to_string(),
                instance_id: "i".to_string(),
                force: false,
                drain_timeout_secs: None,
            },
            HostdRequest::WakeInstance {
                tenant_id: "t".to_string(),
                pool_id: "p".to_string(),
                instance_id: "i".to_string(),
            },
            HostdRequest::DestroyInstance {
                tenant_id: "t".to_string(),
                pool_id: "p".to_string(),
                instance_id: "i".to_string(),
                wipe_volumes: false,
            },
            HostdRequest::SetupNetwork {
                tenant_id: "t".to_string(),
                net: net.clone(),
            },
            HostdRequest::TeardownNetwork {
                tenant_id: "t".to_string(),
                net,
            },
            HostdRequest::Ping,
        ];

        for req in &variants {
            let json = serde_json::to_string(req).unwrap();
            let parsed: HostdRequest = serde_json::from_str(&json).unwrap();
            // Compare re-encoded JSON values (not strings): this catches fields
            // dropped or altered by the round trip, independent of map key order.
            assert_eq!(
                serde_json::to_value(&parsed).unwrap(),
                serde_json::to_value(req).unwrap()
            );
        }
    }

    #[test]
    fn test_all_response_variants_serialize() {
        let variants: Vec<HostdResponse> = vec![
            HostdResponse::Ok,
            HostdResponse::Error {
                message: "err".to_string(),
            },
            HostdResponse::Pong,
        ];

        for resp in &variants {
            let json = serde_json::to_string(resp).unwrap();
            let parsed: HostdResponse = serde_json::from_str(&json).unwrap();
            assert_eq!(
                serde_json::to_value(&parsed).unwrap(),
                serde_json::to_value(resp).unwrap()
            );
        }
    }

    #[test]
    fn test_socket_path_constant() {
        assert_eq!(HOSTD_SOCKET_PATH, "/run/mvm/hostd.sock");
    }

    #[tokio::test]
    async fn test_frame_roundtrip() {
        let data = b"hello hostd";
        let mut buf = Vec::new();
        write_frame(&mut buf, data).await.unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        let read_back = read_frame(&mut cursor).await.unwrap();
        assert_eq!(read_back, data);
    }

    #[tokio::test]
    async fn test_request_send_recv_roundtrip() {
        let req = HostdRequest::Ping;
        let mut buf = Vec::new();
        send_request(&mut buf, &req).await.unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        let parsed = recv_request(&mut cursor).await.unwrap();
        assert!(matches!(parsed, HostdRequest::Ping));
    }

    #[tokio::test]
    async fn test_response_send_recv_roundtrip() {
        let resp = HostdResponse::Ok;
        let mut buf = Vec::new();
        send_response(&mut buf, &resp).await.unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        let parsed = recv_response(&mut cursor).await.unwrap();
        assert!(matches!(parsed, HostdResponse::Ok));
    }

    #[tokio::test]
    async fn test_request_send_recv_preserves_fields() {
        let req = HostdRequest::SleepInstance {
            tenant_id: "acme".to_string(),
            pool_id: "workers".to_string(),
            instance_id: "i-abc123".to_string(),
            force: false,
            drain_timeout_secs: None,
        };
        let mut buf: Vec<u8> = Vec::new();
        send_request(&mut buf, &req).await.unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        match recv_request(&mut cursor).await.unwrap() {
            HostdRequest::SleepInstance {
                instance_id,
                force,
                drain_timeout_secs,
                ..
            } => {
                assert_eq!(instance_id, "i-abc123");
                assert!(!force);
                assert_eq!(drain_timeout_secs, None);
            }
            other => panic!("Wrong variant: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_frame_wire_format() {
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, b"abc").await.unwrap();
        // 4-byte big-endian length prefix, then the payload bytes verbatim.
        assert_eq!(buf, [0, 0, 0, 3, b'a', b'b', b'c']);
    }

    #[tokio::test]
    async fn test_empty_frame_roundtrip() {
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, b"").await.unwrap();
        assert_eq!(buf, [0u8; 4]);

        let mut cursor = std::io::Cursor::new(buf);
        assert!(read_frame(&mut cursor).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_multiple_frames_in_sequence() {
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, b"first").await.unwrap();
        write_frame(&mut buf, b"second").await.unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        assert_eq!(read_frame(&mut cursor).await.unwrap(), b"first");
        assert_eq!(read_frame(&mut cursor).await.unwrap(), b"second");
        // The stream is exhausted, so another read must fail.
        assert!(read_frame(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn test_frame_at_max_size_roundtrip() {
        let data = vec![0xABu8; MAX_FRAME_SIZE];
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, &data).await.unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        let read_back = read_frame(&mut cursor).await.unwrap();
        // Plain `assert!` so a failure doesn't dump 1 MiB of bytes.
        assert!(read_back == data, "max-size frame did not round-trip");
    }

    #[tokio::test]
    async fn test_read_frame_rejects_oversized_length() {
        let prefix = u32::try_from(MAX_FRAME_SIZE + 1).unwrap().to_be_bytes();
        let mut cursor = std::io::Cursor::new(prefix.to_vec());
        let err = read_frame(&mut cursor).await.unwrap_err();
        assert!(
            err.to_string().contains("Frame too large"),
            "unexpected error: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_read_frame_truncated_length_fails() {
        let mut cursor = std::io::Cursor::new(vec![0u8, 0]);
        assert!(read_frame(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn test_read_frame_truncated_body_fails() {
        // The prefix promises 10 bytes but only 5 follow.
        let mut buf = 10u32.to_be_bytes().to_vec();
        buf.extend_from_slice(b"short");
        let mut cursor = std::io::Cursor::new(buf);
        assert!(read_frame(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn test_recv_request_rejects_invalid_json() {
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, b"not json").await.unwrap();
        let mut cursor = std::io::Cursor::new(buf);
        assert!(recv_request(&mut cursor).await.is_err());
    }
}
422}