// viiper_client/async_client.rs

// This file is auto-generated by VIIPER codegen. DO NOT EDIT.
2
3use crate::error::{ProblemJson, ViiperError};
4use crate::types::*;
5use std::net::SocketAddr;
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use tokio::net::TcpStream;
8
/// A full-duplex connection that is either raw TCP or the encrypted stream
/// produced by the password handshake.
#[cfg(feature = "async")]
pub enum AsyncStreamWrapper {
    /// Unencrypted TCP socket.
    Plain(tokio::net::TcpStream),
    /// Socket wrapped in `crate::auth::AsyncEncryptedStream`.
    Encrypted(crate::auth::AsyncEncryptedStream),
}
15
/// The receiving half of a split connection, plain or encrypted.
#[cfg(feature = "async")]
pub enum AsyncReadWrapper {
    /// Read half of a raw TCP socket.
    Plain(tokio::net::tcp::OwnedReadHalf),
    /// Read half of an encrypted connection.
    Encrypted(crate::auth::AsyncEncryptedRead),
}
22
/// The sending half of a split connection, plain or encrypted.
#[cfg(feature = "async")]
pub enum AsyncWriteWrapper {
    /// Write half of a raw TCP socket.
    Plain(tokio::net::tcp::OwnedWriteHalf),
    /// Write half of an encrypted connection.
    Encrypted(crate::auth::AsyncEncryptedWrite),
}
29
#[cfg(feature = "async")]
impl AsyncRead for AsyncStreamWrapper {
    // Both variants wrap `Unpin` streams, so the pin can be dropped and the
    // inner stream re-pinned for the delegated call.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        match self.get_mut() {
            Self::Plain(tcp) => std::pin::Pin::new(tcp).poll_read(cx, buf),
            Self::Encrypted(enc) => std::pin::Pin::new(enc).poll_read(cx, buf),
        }
    }
}
43
#[cfg(feature = "async")]
impl AsyncRead for AsyncReadWrapper {
    // Forward the read to whichever half this wrapper holds; both are `Unpin`.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        match self.get_mut() {
            Self::Plain(half) => std::pin::Pin::new(half).poll_read(cx, buf),
            Self::Encrypted(half) => std::pin::Pin::new(half).poll_read(cx, buf),
        }
    }
}
57
#[cfg(feature = "async")]
impl AsyncWrite for AsyncStreamWrapper {
    // Every method simply delegates to the wrapped stream. Both variants are
    // `Unpin`, so `Pin::get_mut` is available and the inner stream is re-pinned.

    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        match self.get_mut() {
            Self::Plain(tcp) => std::pin::Pin::new(tcp).poll_write(cx, buf),
            Self::Encrypted(enc) => std::pin::Pin::new(enc).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            Self::Plain(tcp) => std::pin::Pin::new(tcp).poll_flush(cx),
            Self::Encrypted(enc) => std::pin::Pin::new(enc).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            Self::Plain(tcp) => std::pin::Pin::new(tcp).poll_shutdown(cx),
            Self::Encrypted(enc) => std::pin::Pin::new(enc).poll_shutdown(cx),
        }
    }
}
91
#[cfg(feature = "async")]
impl AsyncWrite for AsyncWriteWrapper {
    // Pure delegation to the wrapped write half (plain or encrypted); both
    // halves are `Unpin`, so the wrapper can be unpinned and the half re-pinned.

    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        match self.get_mut() {
            Self::Plain(half) => std::pin::Pin::new(half).poll_write(cx, buf),
            Self::Encrypted(half) => std::pin::Pin::new(half).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            Self::Plain(half) => std::pin::Pin::new(half).poll_flush(cx),
            Self::Encrypted(half) => std::pin::Pin::new(half).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            Self::Plain(half) => std::pin::Pin::new(half).poll_shutdown(cx),
            Self::Encrypted(half) => std::pin::Pin::new(half).poll_shutdown(cx),
        }
    }
}
125
/// Asynchronous client for the VIIPER management API.
///
/// The client holds only the server address and an optional password; every
/// request opens its own connection.
#[cfg(feature = "async")]
pub struct AsyncViiperClient {
    /// Address of the VIIPER server.
    addr: std::net::SocketAddr,
    /// Password for the encrypted handshake; `None` means no authentication.
    password: Option<String>,
}
132
#[cfg(feature = "async")]
impl AsyncViiperClient {
    /// Create a new async VIIPER client connecting to the specified address.
    ///
    /// Requests made through this client use no password handshake.
    pub fn new(addr: SocketAddr) -> Self {
        Self { addr, password: None }
    }

    /// Create a new async VIIPER client with password authentication.
    ///
    /// An empty `password` explicitly means no authentication and is stored as `None`.
    pub fn new_with_password(addr: SocketAddr, password: String) -> Self {
        let password = if password.is_empty() { None } else { Some(password) };
        Self { addr, password }
    }

    /// Perform one request/response exchange with the management API.
    ///
    /// The exchange proceeds as follows:
    /// 1. Open a new TCP connection with `TCP_NODELAY` set.
    /// 2. Run the password handshake if a password is configured.
    /// 3. Send the frame `<path>[ <payload>]\0`.
    /// 4. Read until the server closes the connection.
    ///
    /// Trailing `'\n'` characters are stripped before the body is decoded as JSON.
    ///
    /// # Errors
    /// * I/O or handshake failures while connecting, writing, or reading.
    /// * [`ViiperError::UnexpectedResponse`] if the response is not valid UTF-8.
    /// * [`ViiperError::Protocol`] if the body starts with `{"status":`, i.e. the
    ///   server answered with a problem object.
    /// * A JSON decoding error if the body does not deserialize into `T`.
    async fn do_request<T: for<'de> serde::Deserialize<'de>>(
        &self,
        path: &str,
        payload: Option<&str>,
    ) -> Result<T, ViiperError> {
        let tcp_stream = TcpStream::connect(self.addr).await?;
        tcp_stream.set_nodelay(true)?;

        let mut stream = match self.password.as_deref() {
            Some(pwd) => AsyncStreamWrapper::Encrypted(
                crate::auth::perform_handshake_async(tcp_stream, pwd).await?,
            ),
            None => AsyncStreamWrapper::Plain(tcp_stream),
        };

        // Assemble the whole frame first and send it with one write. With
        // TCP_NODELAY, separate small writes can each leave as their own segment.
        let payload_len = payload.map_or(0, |p| p.len() + 1);
        let mut request = Vec::with_capacity(path.len() + payload_len + 1);
        request.extend_from_slice(path.as_bytes());
        if let Some(p) = payload {
            request.push(b' ');
            request.extend_from_slice(p.as_bytes());
        }
        request.push(b'\0');
        stream.write_all(&request).await?;
        // Make sure nothing is left buffered in the transport before we block
        // waiting for the server's reply.
        stream.flush().await?;

        // The server signals the end of the response by closing the connection.
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await?;

        let response = String::from_utf8(buf)
            .map_err(|_| ViiperError::UnexpectedResponse("invalid UTF-8".into()))?
            .trim_end_matches('\n')
            .to_string();

        if response.starts_with("{\"status\":") {
            let problem: ProblemJson = serde_json::from_str(&response)?;
            return Err(ViiperError::Protocol(problem));
        }

        serde_json::from_str(&response).map_err(Into::into)
    }

    /// Ping: ping -> PingResponse
    pub async fn ping(&self) -> Result<PingResponse, ViiperError> {
        self.do_request("ping", None).await
    }

    /// BusList: bus/list -> BusListResponse
    pub async fn bus_list(&self) -> Result<BusListResponse, ViiperError> {
        self.do_request("bus/list", None).await
    }

    /// BusCreate: bus/create -> BusCreateResponse
    ///
    /// `uint32`, when given, is sent in decimal as the request payload.
    pub async fn bus_create(&self, uint32: Option<u32>) -> Result<BusCreateResponse, ViiperError> {
        let payload = uint32.map(|v| v.to_string());
        self.do_request("bus/create", payload.as_deref()).await
    }

    /// BusRemove: bus/remove -> BusRemoveResponse
    ///
    /// `uint32`, when given, is sent in decimal as the request payload.
    pub async fn bus_remove(&self, uint32: Option<u32>) -> Result<BusRemoveResponse, ViiperError> {
        let payload = uint32.map(|v| v.to_string());
        self.do_request("bus/remove", payload.as_deref()).await
    }

    /// BusDevicesList: bus/{id}/list -> DevicesListResponse
    pub async fn bus_devices_list(&self, id: u32) -> Result<DevicesListResponse, ViiperError> {
        let path = format!("bus/{}/list", id);
        self.do_request(&path, None).await
    }

    /// BusDeviceAdd: bus/{id}/add -> Device
    ///
    /// The request is serialized to JSON and sent as the payload.
    pub async fn bus_device_add(&self, id: u32, device_create_request: &DeviceCreateRequest) -> Result<Device, ViiperError> {
        let path = format!("bus/{}/add", id);
        let payload = serde_json::to_string(device_create_request)?;
        self.do_request(&path, Some(&payload)).await
    }

    /// BusDeviceRemove: bus/{id}/remove -> DeviceRemoveResponse
    ///
    /// `string`, when given, is sent verbatim as the payload. It must not
    /// contain a NUL byte, which would terminate the request frame early.
    pub async fn bus_device_remove(&self, id: u32, string: Option<&str>) -> Result<DeviceRemoveResponse, ViiperError> {
        let path = format!("bus/{}/remove", id);
        self.do_request(&path, string).await
    }

    /// Connect to a device stream for sending input and receiving output.
    ///
    /// Uses this client's address and password (if any).
    pub async fn connect_device(&self, bus_id: u32, dev_id: &str) -> Result<AsyncDeviceStream, ViiperError> {
        AsyncDeviceStream::connect(self.addr, bus_id, dev_id, self.password.as_deref()).await
    }
}
238
/// An async connected device stream for bidirectional communication.
///
/// The read and write halves sit behind separate async mutexes, so input can be
/// sent while another task is reading output.
#[cfg(feature = "async")]
pub struct AsyncDeviceStream {
    /// Read half; shared with the task started by `on_output`, if any.
    read_stream: std::sync::Arc<tokio::sync::Mutex<AsyncReadWrapper>>,
    /// Write half used by the `send*` methods.
    write_stream: std::sync::Arc<tokio::sync::Mutex<AsyncWriteWrapper>>,
    /// Present once `on_output` has spawned its task; cancelled on drop.
    cancel_token: Option<tokio_util::sync::CancellationToken>,
    /// Shared with the output task so the callback is looked up when that task
    /// ends rather than when it starts.
    disconnect_callback: std::sync::Arc<std::sync::Mutex<Option<Box<dyn FnOnce() + Send + 'static>>>>,
}

#[cfg(feature = "async")]
impl AsyncDeviceStream {
    /// Open a stream to device `dev_id` on bus `bus_id` at `addr`.
    ///
    /// When `password` is `Some`, the encrypted handshake runs before the
    /// connection is split into read and write halves. The selector
    /// `bus/{bus_id}/{dev_id}\0` is then written and flushed.
    ///
    /// # Errors
    /// Connection, handshake, and I/O failures are returned as [`ViiperError`].
    pub async fn connect(addr: SocketAddr, bus_id: u32, dev_id: &str, password: Option<&str>) -> Result<Self, ViiperError> {
        let tcp_stream = TcpStream::connect(addr).await?;
        tcp_stream.set_nodelay(true)?;

        let (read_stream, mut write_stream) = if let Some(pwd) = password {
            let encrypted = crate::auth::perform_handshake_async(tcp_stream, pwd).await?;
            let (read_half, write_half) = encrypted.into_split();
            (AsyncReadWrapper::Encrypted(read_half), AsyncWriteWrapper::Encrypted(write_half))
        } else {
            let (read_half, write_half) = tcp_stream.into_split();
            (AsyncReadWrapper::Plain(read_half), AsyncWriteWrapper::Plain(write_half))
        };

        let handshake = format!("bus/{}/{}\0", bus_id, dev_id);
        write_stream.write_all(handshake.as_bytes()).await?;
        // Ensure the selector is actually sent even if the transport buffers.
        write_stream.flush().await?;

        Ok(Self {
            read_stream: std::sync::Arc::new(tokio::sync::Mutex::new(read_stream)),
            write_stream: std::sync::Arc::new(tokio::sync::Mutex::new(write_stream)),
            cancel_token: None,
            disconnect_callback: std::sync::Arc::new(std::sync::Mutex::new(None)),
        })
    }

    /// Send a device input to the device.
    ///
    /// The input is encoded with `DeviceInput::to_bytes`, then written and flushed.
    pub async fn send<T: crate::wire::DeviceInput>(
        &self,
        input: &T,
    ) -> Result<(), ViiperError> {
        let bytes = input.to_bytes();
        self.send_raw(&bytes).await
    }

    /// Send a device input to the device with a timeout.
    ///
    /// The timeout covers both waiting for the write lock (another task may be
    /// sending) and the write itself.
    ///
    /// If the timeout fires mid-write, part of the encoded input may already
    /// have been sent, because `write_all` is not cancellation safe.
    ///
    /// # Arguments
    /// * `input` - The device input to send
    /// * `timeout` - Timeout duration for the operation
    ///
    /// # Errors
    /// [`ViiperError::Timeout`] if the deadline passes; otherwise any I/O error.
    pub async fn send_timeout<T: crate::wire::DeviceInput>(
        &self,
        input: &T,
        timeout: std::time::Duration,
    ) -> Result<(), ViiperError> {
        let bytes = input.to_bytes();
        tokio::time::timeout(timeout, self.send_raw(&bytes))
            .await
            .map_err(|_| ViiperError::Timeout)?
    }

    /// Register a callback to receive device output asynchronously.
    ///
    /// The callback receives a shared reference to the read half and must read
    /// the exact number of bytes expected. It is invoked repeatedly on a spawned
    /// Tokio task until it returns an error or this stream is dropped. At that
    /// point the disconnect callback registered via `on_disconnect` (if any,
    /// whenever it was registered) is run once.
    ///
    /// Only one callback can be registered at a time. While the task runs,
    /// `read_raw` / `read_exact` compete with it for the same read lock.
    ///
    /// # Errors
    /// - A callback was already registered.
    /// - No Tokio runtime is running on the current thread.
    pub fn on_output<F, Fut>(&mut self, callback: F) -> Result<(), ViiperError>
    where
        F: Fn(std::sync::Arc<tokio::sync::Mutex<AsyncReadWrapper>>) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = std::io::Result<()>> + Send + 'static,
    {
        if self.cancel_token.is_some() {
            return Err(ViiperError::UnexpectedResponse("Output callback already registered".into()));
        }

        // `tokio::spawn` would panic outside a runtime; report it as an error instead.
        let runtime = tokio::runtime::Handle::try_current().map_err(|_| {
            ViiperError::UnexpectedResponse("on_output requires a running Tokio runtime".into())
        })?;

        let stream = std::sync::Arc::clone(&self.read_stream);
        let disconnect = std::sync::Arc::clone(&self.disconnect_callback);
        let cancel_token = tokio_util::sync::CancellationToken::new();
        let cancel_clone = cancel_token.clone();

        runtime.spawn(async move {
            loop {
                tokio::select! {
                    _ = cancel_clone.cancelled() => break,
                    result = callback(std::sync::Arc::clone(&stream)) => {
                        if result.is_err() {
                            break;
                        }
                    }
                }
            }
            // Take the callback only now, so one registered after `on_output`
            // is still honored. A poisoned lock still yields the stored value.
            // The guard is released at the end of this statement, before the
            // callback runs.
            let on_disconnect = disconnect
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .take();
            if let Some(cb) = on_disconnect {
                cb();
            }
        });

        self.cancel_token = Some(cancel_token);
        Ok(())
    }

    /// Register a callback to run once when the output task started by
    /// `on_output` ends (callback error or drop of this stream).
    ///
    /// Replaces any previously registered callback. It may be registered before
    /// or after `on_output`. It is not run if `on_output` is never called, or if
    /// that task has already finished.
    ///
    /// # Errors
    /// Returns an error if the callback mutex is poisoned.
    pub fn on_disconnect<F>(&mut self, callback: F) -> Result<(), ViiperError>
    where
        F: FnOnce() + Send + 'static,
    {
        let Ok(mut guard) = self.disconnect_callback.lock() else {
            return Err(ViiperError::UnexpectedResponse("Disconnect callback mutex poisoned".into()));
        };
        *guard = Some(Box::new(callback));
        Ok(())
    }

    /// Send raw bytes to the device.
    ///
    /// The bytes are written in full and then flushed, while holding the write lock.
    pub async fn send_raw(&self, data: &[u8]) -> Result<(), ViiperError> {
        let mut stream = self.write_stream.lock().await;
        stream.write_all(data).await?;
        stream.flush().await?;
        Ok(())
    }

    /// Read raw bytes from the device.
    ///
    /// Performs a single read and returns how many bytes were read
    /// (`0` indicates end of stream).
    pub async fn read_raw(&self, buf: &mut [u8]) -> Result<usize, ViiperError> {
        let mut stream = self.read_stream.lock().await;
        stream.read(buf).await.map_err(Into::into)
    }

    /// Read exact number of bytes from the device.
    ///
    /// Fills `buf` completely, or fails (e.g. on EOF before `buf.len()` bytes).
    pub async fn read_exact(&self, buf: &mut [u8]) -> Result<(), ViiperError> {
        let mut stream = self.read_stream.lock().await;
        stream.read_exact(buf).await?;
        Ok(())
    }
}

#[cfg(feature = "async")]
impl Drop for AsyncDeviceStream {
    // Cancel the output task, if any. That ends its loop, which then runs the
    // disconnect callback.
    fn drop(&mut self) {
        if let Some(token) = &self.cancel_token {
            token.cancel();
        }
    }
}