1use std::path::Path;
26
27use vm_pool_protocol::{
28 LogLine, ServiceCommand, ServiceEvent, VmCommand, VmConfig, VmId,
29};
30use thiserror::Error;
31use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
32use tokio::net::UnixStream;
33use tokio::sync::mpsc;
34use tracing::debug;
35
36#[derive(Debug, Error)]
37pub enum ClientError {
38 #[error("connection failed: {0}")]
39 Connect(#[from] std::io::Error),
40 #[error("JSON error: {0}")]
41 Json(#[from] serde_json::Error),
42 #[error("connection closed")]
43 Closed,
44 #[error("service error: {0}")]
45 Service(String),
46 #[error("unexpected response: {0:?}")]
47 UnexpectedResponse(ServiceEvent),
48}
49
/// Pool occupancy as reported by the service in reply to a status query.
///
/// All fields are plain counts, so the type is `Copy` and can be used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PoolStatus {
    /// Size of the pool (the configured maximum number of VMs).
    pub total: usize,
    /// Number of VMs not currently allocated.
    pub available: usize,
    /// Number of VMs currently allocated to clients.
    pub allocated: usize,
}
57
58pub struct Client {
60 cmd_tx: mpsc::Sender<String>,
62 resp_rx: mpsc::Receiver<ServiceEvent>,
64}
65
66impl Client {
67 pub async fn connect(path: impl AsRef<Path>) -> Result<Self, ClientError> {
69 let stream = UnixStream::connect(path.as_ref()).await?;
70 let (reader, writer) = stream.into_split();
71
72 let (cmd_tx, mut cmd_rx) = mpsc::channel::<String>(64);
73 let (resp_tx, resp_rx) = mpsc::channel::<ServiceEvent>(64);
74
75 tokio::spawn(async move {
77 let mut writer = writer;
78 while let Some(line) = cmd_rx.recv().await {
79 if writer.write_all(line.as_bytes()).await.is_err() {
80 break;
81 }
82 if writer.write_all(b"\n").await.is_err() {
83 break;
84 }
85 let _ = writer.flush().await;
86 }
87 });
88
89 tokio::spawn(async move {
91 let mut reader = BufReader::new(reader);
92 let mut line = String::new();
93 loop {
94 line.clear();
95 match reader.read_line(&mut line).await {
96 Ok(0) => break,
97 Ok(_) => {
98 if let Ok(event) = serde_json::from_str::<ServiceEvent>(line.trim()) {
99 if resp_tx.send(event).await.is_err() {
100 break;
101 }
102 }
103 }
104 Err(_) => break,
105 }
106 }
107 });
108
109 Ok(Self { cmd_tx, resp_rx })
110 }
111
112 async fn request(&mut self, command: ServiceCommand) -> Result<ServiceEvent, ClientError> {
114 let json = serde_json::to_string(&command)?;
115 debug!("sending: {}", json);
116 self.cmd_tx
117 .send(json)
118 .await
119 .map_err(|_| ClientError::Closed)?;
120
121 self.resp_rx.recv().await.ok_or(ClientError::Closed)
122 }
123
124 fn check_error(event: ServiceEvent) -> Result<ServiceEvent, ClientError> {
126 match event {
127 ServiceEvent::Error { message } => Err(ClientError::Service(message)),
128 other => Ok(other),
129 }
130 }
131
132 pub async fn status(&mut self) -> Result<PoolStatus, ClientError> {
134 let resp = self.request(ServiceCommand::Status).await?;
135 match Self::check_error(resp)? {
136 ServiceEvent::PoolStatus {
137 total,
138 available,
139 allocated,
140 } => Ok(PoolStatus {
141 total,
142 available,
143 allocated,
144 }),
145 other => Err(ClientError::UnexpectedResponse(other)),
146 }
147 }
148
149 pub async fn allocate(
151 &mut self,
152 image: &str,
153 config: VmConfig,
154 ) -> Result<VmId, ClientError> {
155 let resp = self
156 .request(ServiceCommand::Allocate {
157 image: image.to_string(),
158 config,
159 })
160 .await?;
161 match Self::check_error(resp)? {
162 ServiceEvent::VmAllocated { vm_id, .. } => Ok(vm_id),
163 other => Err(ClientError::UnexpectedResponse(other)),
164 }
165 }
166
167 pub async fn deallocate(&mut self, vm_id: &VmId) -> Result<(), ClientError> {
169 let resp = self
170 .request(ServiceCommand::Deallocate {
171 vm_id: vm_id.clone(),
172 })
173 .await?;
174 match Self::check_error(resp)? {
175 ServiceEvent::VmStopped { .. } => Ok(()),
176 other => Err(ClientError::UnexpectedResponse(other)),
177 }
178 }
179
180 pub async fn send_command(
182 &mut self,
183 vm_id: &VmId,
184 command: VmCommand,
185 ) -> Result<ServiceEvent, ClientError> {
186 let resp = self
187 .request(ServiceCommand::Send {
188 vm_id: vm_id.clone(),
189 command,
190 })
191 .await?;
192 Self::check_error(resp)
193 }
194
195 pub async fn snapshot(&mut self, vm_id: &VmId, name: &str) -> Result<(), ClientError> {
197 let resp = self
198 .request(ServiceCommand::Snapshot {
199 vm_id: vm_id.clone(),
200 name: name.to_string(),
201 })
202 .await?;
203 match Self::check_error(resp)? {
204 ServiceEvent::VmStopped { .. } => Ok(()),
205 other => Err(ClientError::UnexpectedResponse(other)),
206 }
207 }
208
209 pub async fn restore(&mut self, vm_id: &VmId, snapshot: &str) -> Result<(), ClientError> {
211 let resp = self
212 .request(ServiceCommand::Restore {
213 vm_id: vm_id.clone(),
214 snapshot: snapshot.to_string(),
215 })
216 .await?;
217 match Self::check_error(resp)? {
218 ServiceEvent::VmReady { .. } => Ok(()),
219 other => Err(ClientError::UnexpectedResponse(other)),
220 }
221 }
222
223 pub async fn tail_logs(
225 &mut self,
226 vm_id: &VmId,
227 lines: usize,
228 ) -> Result<Vec<LogLine>, ClientError> {
229 let resp = self
230 .request(ServiceCommand::TailLogs {
231 vm_id: vm_id.clone(),
232 lines,
233 })
234 .await?;
235 match Self::check_error(resp)? {
236 ServiceEvent::LogTail { lines, .. } => Ok(lines),
237 other => Err(ClientError::UnexpectedResponse(other)),
238 }
239 }
240
241 pub async fn subscribe_logs(
243 &mut self,
244 vm_id: Option<&VmId>,
245 ) -> Result<(), ClientError> {
246 let resp = self
247 .request(ServiceCommand::SubscribeLogs {
248 vm_id: vm_id.cloned(),
249 })
250 .await?;
251 match Self::check_error(resp)? {
252 ServiceEvent::LogsSubscribed { .. } => Ok(()),
253 other => Err(ClientError::UnexpectedResponse(other)),
254 }
255 }
256
257 pub async fn unsubscribe_logs(&mut self) -> Result<(), ClientError> {
259 let resp = self.request(ServiceCommand::UnsubscribeLogs).await?;
260 match Self::check_error(resp)? {
261 ServiceEvent::LogsSubscribed { .. } => Ok(()),
262 other => Err(ClientError::UnexpectedResponse(other)),
263 }
264 }
265
266 pub async fn next_event(&mut self) -> Option<ServiceEvent> {
269 self.resp_rx.recv().await
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use vm_pool_manager::PoolConfig;
    use vm_pool_service::{Service, ServiceConfig};

    /// Starts a service with `pool` settings in a fresh temp dir and connects a client to it.
    ///
    /// Keep the returned `TempDir` alive for the whole test. Dropping it removes the
    /// directory that holds the socket and the snapshots.
    async fn start_service(pool: PoolConfig) -> (Client, Arc<Service>, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let socket_path = dir.path().join("test.sock");

        let config = ServiceConfig {
            socket_path: socket_path.clone(),
            snapshot_dir: dir.path().join("snapshots"),
            pool,
        };

        let service = Service::new(config).await.unwrap();
        let svc = service.clone();
        tokio::spawn(async move { svc.run().await });

        // The socket may not exist yet right after `run` is spawned; poll for it
        // (up to 50 x 10 ms) before connecting.
        for _ in 0..50 {
            if socket_path.exists() {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }

        let client = Client::connect(&socket_path).await.unwrap();
        (client, service, dir)
    }

    /// Client connected to a service whose pool holds three VMs.
    async fn test_client() -> (Client, Arc<Service>, tempfile::TempDir) {
        start_service(PoolConfig {
            max_vms: 3,
            health_check_interval: 300,
            vm_timeout: 7200,
        })
        .await
    }

    #[tokio::test]
    async fn client_status() {
        let (mut client, _svc, _dir) = test_client().await;

        let status = client.status().await.unwrap();
        assert_eq!(status.total, 3);
        assert_eq!(status.available, 3);
        assert_eq!(status.allocated, 0);
    }

    #[tokio::test]
    async fn client_allocate_and_deallocate() {
        let (mut client, _svc, _dir) = test_client().await;

        let vm_id = client
            .allocate("agent:v1", VmConfig::default())
            .await
            .unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.allocated, 1);

        client.deallocate(&vm_id).await.unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.allocated, 0);
    }

    #[tokio::test]
    async fn client_allocate_error() {
        // An empty pool cannot satisfy any allocation.
        let (mut client, _svc, _dir) = start_service(PoolConfig {
            max_vms: 0,
            health_check_interval: 300,
            vm_timeout: 7200,
        })
        .await;

        let result = client.allocate("agent:v1", VmConfig::default()).await;
        assert!(matches!(result, Err(ClientError::Service(_))));
    }

    #[tokio::test]
    async fn client_tail_logs() {
        let (mut client, _svc, _dir) = test_client().await;

        let vm_id = VmId::new("vm-nonexistent");
        let logs = client.tail_logs(&vm_id, 10).await.unwrap();
        assert!(logs.is_empty());
    }

    #[tokio::test]
    async fn client_subscribe_unsubscribe() {
        let (mut client, _svc, _dir) = test_client().await;

        client.subscribe_logs(None).await.unwrap();
        client.unsubscribe_logs().await.unwrap();
    }

    #[tokio::test]
    async fn client_full_lifecycle() {
        let (mut client, _svc, _dir) = test_client().await;

        let status = client.status().await.unwrap();
        assert_eq!(status.available, 3);

        let vm1 = client
            .allocate("agent:v1", VmConfig::default())
            .await
            .unwrap();
        let vm2 = client
            .allocate("automation:v1", VmConfig::default())
            .await
            .unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.allocated, 2);
        assert_eq!(status.available, 1);

        client.deallocate(&vm1).await.unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.allocated, 1);
        assert_eq!(status.available, 2);

        client.deallocate(&vm2).await.unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.allocated, 0);
    }
}