Skip to main content

vm_pool_client/
lib.rs

1//! Client library for communicating with the vm-pool service.
2//!
3//! Provides a high-level async API over the Unix socket protocol.
4//!
5//! # Example
6//!
7//! ```no_run
8//! # async fn example() -> Result<(), vm_pool_client::ClientError> {
9//! use vm_pool_client::Client;
10//! use vm_pool_protocol::VmConfig;
11//!
12//! let mut client = Client::connect("/tmp/vm-pool.sock").await?;
13//!
14//! let status = client.status().await?;
15//! println!("available: {}", status.available);
16//!
17//! let vm_id = client.allocate("agent:v1.0.0", VmConfig::default()).await?;
18//! println!("allocated: {}", vm_id);
19//!
20//! client.deallocate(&vm_id).await?;
21//! # Ok(())
22//! # }
23//! ```
24
25use std::path::Path;
26
27use vm_pool_protocol::{
28    LogLine, ServiceCommand, ServiceEvent, VmCommand, VmConfig, VmId,
29};
30use thiserror::Error;
31use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
32use tokio::net::UnixStream;
33use tokio::sync::mpsc;
34use tracing::debug;
35
36#[derive(Debug, Error)]
37pub enum ClientError {
38    #[error("connection failed: {0}")]
39    Connect(#[from] std::io::Error),
40    #[error("JSON error: {0}")]
41    Json(#[from] serde_json::Error),
42    #[error("connection closed")]
43    Closed,
44    #[error("service error: {0}")]
45    Service(String),
46    #[error("unexpected response: {0:?}")]
47    UnexpectedResponse(ServiceEvent),
48}
49
/// Snapshot of the pool counters as reported by the service.
///
/// Built by [`Client::status`] from the fields of a
/// `ServiceEvent::PoolStatus` event, copied over unchanged.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PoolStatus {
    /// The `total` counter from the service's status event.
    pub total: usize,
    /// The `available` counter from the service's status event.
    pub available: usize,
    /// The `allocated` counter from the service's status event.
    pub allocated: usize,
}
57
58/// Client for communicating with the vm-pool service.
59pub struct Client {
60    /// Channel for sending serialized commands to the writer task.
61    cmd_tx: mpsc::Sender<String>,
62    /// Channel for receiving parsed responses from the reader task.
63    resp_rx: mpsc::Receiver<ServiceEvent>,
64}
65
66impl Client {
67    /// Connect to the vm-pool service at the given Unix socket path.
68    pub async fn connect(path: impl AsRef<Path>) -> Result<Self, ClientError> {
69        let stream = UnixStream::connect(path.as_ref()).await?;
70        let (reader, writer) = stream.into_split();
71
72        let (cmd_tx, mut cmd_rx) = mpsc::channel::<String>(64);
73        let (resp_tx, resp_rx) = mpsc::channel::<ServiceEvent>(64);
74
75        // Writer task
76        tokio::spawn(async move {
77            let mut writer = writer;
78            while let Some(line) = cmd_rx.recv().await {
79                if writer.write_all(line.as_bytes()).await.is_err() {
80                    break;
81                }
82                if writer.write_all(b"\n").await.is_err() {
83                    break;
84                }
85                let _ = writer.flush().await;
86            }
87        });
88
89        // Reader task
90        tokio::spawn(async move {
91            let mut reader = BufReader::new(reader);
92            let mut line = String::new();
93            loop {
94                line.clear();
95                match reader.read_line(&mut line).await {
96                    Ok(0) => break,
97                    Ok(_) => {
98                        if let Ok(event) = serde_json::from_str::<ServiceEvent>(line.trim()) {
99                            if resp_tx.send(event).await.is_err() {
100                                break;
101                            }
102                        }
103                    }
104                    Err(_) => break,
105                }
106            }
107        });
108
109        Ok(Self { cmd_tx, resp_rx })
110    }
111
112    /// Send a command and wait for the next response.
113    async fn request(&mut self, command: ServiceCommand) -> Result<ServiceEvent, ClientError> {
114        let json = serde_json::to_string(&command)?;
115        debug!("sending: {}", json);
116        self.cmd_tx
117            .send(json)
118            .await
119            .map_err(|_| ClientError::Closed)?;
120
121        self.resp_rx.recv().await.ok_or(ClientError::Closed)
122    }
123
124    /// Convert a ServiceEvent::Error into a ClientError, or return the event.
125    fn check_error(event: ServiceEvent) -> Result<ServiceEvent, ClientError> {
126        match event {
127            ServiceEvent::Error { message } => Err(ClientError::Service(message)),
128            other => Ok(other),
129        }
130    }
131
132    /// Get pool status.
133    pub async fn status(&mut self) -> Result<PoolStatus, ClientError> {
134        let resp = self.request(ServiceCommand::Status).await?;
135        match Self::check_error(resp)? {
136            ServiceEvent::PoolStatus {
137                total,
138                available,
139                allocated,
140            } => Ok(PoolStatus {
141                total,
142                available,
143                allocated,
144            }),
145            other => Err(ClientError::UnexpectedResponse(other)),
146        }
147    }
148
149    /// Allocate a new VM. Returns the VM ID.
150    pub async fn allocate(
151        &mut self,
152        image: &str,
153        config: VmConfig,
154    ) -> Result<VmId, ClientError> {
155        let resp = self
156            .request(ServiceCommand::Allocate {
157                image: image.to_string(),
158                config,
159            })
160            .await?;
161        match Self::check_error(resp)? {
162            ServiceEvent::VmAllocated { vm_id, .. } => Ok(vm_id),
163            other => Err(ClientError::UnexpectedResponse(other)),
164        }
165    }
166
167    /// Deallocate a VM.
168    pub async fn deallocate(&mut self, vm_id: &VmId) -> Result<(), ClientError> {
169        let resp = self
170            .request(ServiceCommand::Deallocate {
171                vm_id: vm_id.clone(),
172            })
173            .await?;
174        match Self::check_error(resp)? {
175            ServiceEvent::VmStopped { .. } => Ok(()),
176            other => Err(ClientError::UnexpectedResponse(other)),
177        }
178    }
179
180    /// Send a command to a VM.
181    pub async fn send_command(
182        &mut self,
183        vm_id: &VmId,
184        command: VmCommand,
185    ) -> Result<ServiceEvent, ClientError> {
186        let resp = self
187            .request(ServiceCommand::Send {
188                vm_id: vm_id.clone(),
189                command,
190            })
191            .await?;
192        Self::check_error(resp)
193    }
194
195    /// Save a snapshot of a VM.
196    pub async fn snapshot(&mut self, vm_id: &VmId, name: &str) -> Result<(), ClientError> {
197        let resp = self
198            .request(ServiceCommand::Snapshot {
199                vm_id: vm_id.clone(),
200                name: name.to_string(),
201            })
202            .await?;
203        match Self::check_error(resp)? {
204            ServiceEvent::VmStopped { .. } => Ok(()),
205            other => Err(ClientError::UnexpectedResponse(other)),
206        }
207    }
208
209    /// Restore a VM from a snapshot.
210    pub async fn restore(&mut self, vm_id: &VmId, snapshot: &str) -> Result<(), ClientError> {
211        let resp = self
212            .request(ServiceCommand::Restore {
213                vm_id: vm_id.clone(),
214                snapshot: snapshot.to_string(),
215            })
216            .await?;
217        match Self::check_error(resp)? {
218            ServiceEvent::VmReady { .. } => Ok(()),
219            other => Err(ClientError::UnexpectedResponse(other)),
220        }
221    }
222
223    /// Tail log lines from a VM.
224    pub async fn tail_logs(
225        &mut self,
226        vm_id: &VmId,
227        lines: usize,
228    ) -> Result<Vec<LogLine>, ClientError> {
229        let resp = self
230            .request(ServiceCommand::TailLogs {
231                vm_id: vm_id.clone(),
232                lines,
233            })
234            .await?;
235        match Self::check_error(resp)? {
236            ServiceEvent::LogTail { lines, .. } => Ok(lines),
237            other => Err(ClientError::UnexpectedResponse(other)),
238        }
239    }
240
241    /// Subscribe to logs from a specific VM (or all VMs if None).
242    pub async fn subscribe_logs(
243        &mut self,
244        vm_id: Option<&VmId>,
245    ) -> Result<(), ClientError> {
246        let resp = self
247            .request(ServiceCommand::SubscribeLogs {
248                vm_id: vm_id.cloned(),
249            })
250            .await?;
251        match Self::check_error(resp)? {
252            ServiceEvent::LogsSubscribed { .. } => Ok(()),
253            other => Err(ClientError::UnexpectedResponse(other)),
254        }
255    }
256
257    /// Unsubscribe from log streaming.
258    pub async fn unsubscribe_logs(&mut self) -> Result<(), ClientError> {
259        let resp = self.request(ServiceCommand::UnsubscribeLogs).await?;
260        match Self::check_error(resp)? {
261            ServiceEvent::LogsSubscribed { .. } => Ok(()),
262            other => Err(ClientError::UnexpectedResponse(other)),
263        }
264    }
265
266    /// Receive the next event (for streaming/subscriptions).
267    /// Returns None if the connection is closed.
268    pub async fn next_event(&mut self) -> Option<ServiceEvent> {
269        self.resp_rx.recv().await
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use vm_pool_manager::PoolConfig;
    use vm_pool_service::{Service, ServiceConfig};

    /// Start a service with the given pool configuration on a temp socket
    /// and return a connected client.
    ///
    /// The service handle and the temp directory are returned so the caller
    /// keeps them alive for the duration of the test.
    async fn test_client_with(pool: PoolConfig) -> (Client, Arc<Service>, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let socket_path = dir.path().join("test.sock");

        let config = ServiceConfig {
            socket_path: socket_path.clone(),
            snapshot_dir: dir.path().join("snapshots"),
            pool,
        };

        let service = Service::new(config).await.unwrap();
        let svc = service.clone();

        // Run service in background
        tokio::spawn(async move { svc.run().await });

        // Wait (up to ~500ms) for the socket file to appear; if it never
        // does, the `connect` below fails the test.
        for _ in 0..50 {
            if socket_path.exists() {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }

        let client = Client::connect(&socket_path).await.unwrap();
        (client, service, dir)
    }

    /// Start a service with a three-VM pool and return a connected client.
    async fn test_client() -> (Client, Arc<Service>, tempfile::TempDir) {
        test_client_with(PoolConfig {
            max_vms: 3,
            health_check_interval: 300,
            vm_timeout: 7200,
        })
        .await
    }

    /// A fresh pool reports all VMs as available.
    #[tokio::test]
    async fn client_status() {
        let (mut client, _svc, _dir) = test_client().await;

        let status = client.status().await.unwrap();
        assert_eq!(status.total, 3);
        assert_eq!(status.available, 3);
        assert_eq!(status.allocated, 0);
    }

    /// Allocating and then deallocating a VM is reflected in the status.
    #[tokio::test]
    async fn client_allocate_and_deallocate() {
        let (mut client, _svc, _dir) = test_client().await;

        let vm_id = client
            .allocate("agent:v1", VmConfig::default())
            .await
            .unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.allocated, 1);

        client.deallocate(&vm_id).await.unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.allocated, 0);
    }

    /// Allocation against an empty pool surfaces as a service error.
    #[tokio::test]
    async fn client_allocate_error() {
        let (mut client, _svc, _dir) = test_client_with(PoolConfig {
            max_vms: 0, // No VMs allowed
            health_check_interval: 300,
            vm_timeout: 7200,
        })
        .await;

        let result = client.allocate("agent:v1", VmConfig::default()).await;
        assert!(matches!(result, Err(ClientError::Service(_))));
    }

    /// Tailing logs of an unknown VM yields an empty list.
    #[tokio::test]
    async fn client_tail_logs() {
        let (mut client, _svc, _dir) = test_client().await;

        let vm_id = VmId::new("vm-nonexistent");
        let logs = client.tail_logs(&vm_id, 10).await.unwrap();
        assert!(logs.is_empty());
    }

    /// Subscribing and unsubscribing both get acknowledged.
    #[tokio::test]
    async fn client_subscribe_unsubscribe() {
        let (mut client, _svc, _dir) = test_client().await;

        client.subscribe_logs(None).await.unwrap();
        client.unsubscribe_logs().await.unwrap();
    }

    /// Status counters track a sequence of allocations and deallocations.
    #[tokio::test]
    async fn client_full_lifecycle() {
        let (mut client, _svc, _dir) = test_client().await;

        // Check initial status
        let status = client.status().await.unwrap();
        assert_eq!(status.available, 3);

        // Allocate two VMs
        let vm1 = client
            .allocate("agent:v1", VmConfig::default())
            .await
            .unwrap();
        let vm2 = client
            .allocate("automation:v1", VmConfig::default())
            .await
            .unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.allocated, 2);
        assert_eq!(status.available, 1);

        // Deallocate one
        client.deallocate(&vm1).await.unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.allocated, 1);
        assert_eq!(status.available, 2);

        // Deallocate the other
        client.deallocate(&vm2).await.unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.allocated, 0);
    }
}