varlink 13.0.0

Client and server support for the varlink protocol.
Documentation
//! Async client support for varlink using Tokio
//!
//! This module provides async versions of the varlink client functionality,
//! using the sans-io state machines for protocol handling and Tokio for I/O.
//!
//! # Example
//!
//! ```no_run
//! # #[cfg(feature = "tokio")]
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! use varlink::AsyncConnection;
//! use std::sync::Arc;
//!
//! // Connect to a varlink service
//! let connection = AsyncConnection::with_address("tcp:127.0.0.1:9999").await?;
//!
//! // Create a client wrapper (usually generated by varlink_generator)
//! // With AsyncMethodCall:
//! // let reply = varlink::AsyncMethodCall::new(
//! //     connection.clone(),
//! //     "org.example.ping.Ping",
//! //     PingArgs { ping: "Hello".to_string() }
//! // ).call().await?;
//!
//! # Ok(())
//! # }
//! ```

use crate::error::*;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::marker::PhantomData;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::sync::RwLock;

/// Marker trait for async byte streams usable as a varlink transport.
///
/// The trait declares no methods of its own; it only bundles the bounds
/// `AsyncReadExt + AsyncWriteExt + Send + Sync + Unpin` under one name.
/// It is implemented below for Tokio's `TcpStream` and, on Unix targets,
/// for `UnixStream`.
pub trait AsyncVarlinkStream: AsyncReadExt + AsyncWriteExt + Send + Sync + Unpin {}

// TCP transport.
impl AsyncVarlinkStream for TcpStream {}
// Unix domain socket transport (only compiled on Unix targets).
#[cfg(unix)]
impl AsyncVarlinkStream for UnixStream {}

/// Concrete transport behind an async varlink connection.
///
/// One variant per supported socket family. The `UNIX` variant only exists
/// on Unix targets, so any `match` on this enum must gate that arm with
/// `#[cfg(unix)]` as well.
pub enum AsyncStream {
    /// A connected Tokio TCP stream (`tcp:` addresses).
    TCP(TcpStream),
    /// A connected Tokio Unix domain socket stream (`unix:` addresses).
    #[cfg(unix)]
    UNIX(UnixStream),
}

impl AsyncStream {
    /// Shut down the underlying stream, whichever transport it is.
    ///
    /// Both variants forward to `AsyncWriteExt::shutdown`, so TCP and Unix
    /// domain sockets are treated the same way and the peer is notified
    /// explicitly instead of only when the socket is eventually dropped.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the underlying stream's shutdown.
    // No caller exists in the visible part of this module yet, hence the lint exemption.
    #[allow(dead_code)]
    async fn shutdown(&mut self) -> std::io::Result<()> {
        match self {
            AsyncStream::TCP(stream) => stream.shutdown().await,
            #[cfg(unix)]
            AsyncStream::UNIX(stream) => stream.shutdown().await,
        }
    }
}

/// Open an async connection to the varlink service at `address`.
///
/// Recognised address schemes:
/// - `tcp:host:port` - TCP connection; everything after `tcp:` is passed to
///   `TcpStream::connect` as-is
/// - `unix:/path/to/socket` - file-based Unix domain socket (Unix targets only)
/// - `unix:@name` - abstract Unix socket (Linux and Android only)
///
/// For `unix:` addresses, anything from the first `;` onwards is ignored when
/// connecting.
///
/// On success returns the connected stream together with an owned copy of
/// the complete, unmodified `address` string.
///
/// # Errors
///
/// - `ErrorKind::Io` carrying the `std::io::ErrorKind` when connecting fails
/// - `ErrorKind::InvalidAddress` for an unknown scheme, for `unix:` on
///   non-Unix targets, and for abstract sockets on Unix targets other than
///   Linux/Android
pub async fn async_varlink_connect<S: AsRef<str>>(address: S) -> Result<(AsyncStream, String)> {
    // Owned copy that is handed back to the caller unchanged on success.
    let full_address = address.as_ref().to_owned();

    if let Some(host_port) = full_address.strip_prefix("tcp:") {
        let connected = TcpStream::connect(host_port).await;
        let tcp = connected.map_err(|e| context!(ErrorKind::Io(e.kind())))?;
        Ok((AsyncStream::TCP(tcp), full_address))
    } else if let Some(target) = full_address.strip_prefix("unix:") {
        #[cfg(unix)]
        {
            match target.strip_prefix('@') {
                Some(abstract_name) => {
                    // Abstract namespace sockets exist on Linux/Android only.
                    #[cfg(any(target_os = "linux", target_os = "android"))]
                    {
                        // Drop any `;`-separated suffix before connecting.
                        let name = abstract_name.split(';').next().unwrap_or(abstract_name);
                        // A leading NUL byte marks the path as an abstract socket name.
                        let socket_path = format!("\0{}", name);
                        let connected = UnixStream::connect(socket_path).await;
                        let unix = connected.map_err(|e| context!(ErrorKind::Io(e.kind())))?;
                        Ok((AsyncStream::UNIX(unix), full_address))
                    }
                    #[cfg(not(any(target_os = "linux", target_os = "android")))]
                    {
                        let _ = abstract_name;
                        Err(context!(ErrorKind::InvalidAddress))
                    }
                }
                None => {
                    // Regular socket file; drop any `;`-separated suffix first.
                    let path = target.split(';').next().unwrap_or(target);
                    let connected = UnixStream::connect(path).await;
                    let unix = connected.map_err(|e| context!(ErrorKind::Io(e.kind())))?;
                    Ok((AsyncStream::UNIX(unix), full_address))
                }
            }
        }
        #[cfg(not(unix))]
        {
            let _ = target;
            Err(context!(ErrorKind::InvalidAddress))
        }
    } else {
        Err(context!(ErrorKind::InvalidAddress))
    }
}

/// An async client connection to a varlink service.
///
/// Instances are created through `AsyncConnection::with_address`, which
/// returns them wrapped in an `Arc` so one connection can be handed to
/// several `AsyncMethodCall`s.
pub struct AsyncConnection {
    // The complete address string the connection was opened with,
    // including its scheme prefix.
    address: String,
    // The transport, behind an async `RwLock`. `send` and `recv` treat
    // `None` as a closed connection and fail with `ConnectionClosed`.
    // NOTE(review): nothing in the visible part of this file ever stores
    // `None` here — confirm whether a close path exists elsewhere.
    stream: Arc<RwLock<Option<AsyncStream>>>,
}

impl AsyncConnection {
    /// Connect to the service at the given varlink URI.
    ///
    /// The address is resolved by `async_varlink_connect`; the accepted
    /// formats are:
    /// - `tcp:host:port` - TCP connection
    /// - `unix:/path/to/socket` - Unix domain socket
    /// - `unix:@abstract` - Abstract Unix socket (Linux only)
    ///
    /// The connection is returned inside an `Arc` so it can be shared
    /// between method calls.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `async_varlink_connect` reports.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # #[cfg(feature = "tokio")]
    /// # use varlink::AsyncConnection;
    /// # #[tokio::main]
    /// # async fn main() -> varlink::Result<()> {
    /// let connection = AsyncConnection::with_address("tcp:127.0.0.1:9999").await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn with_address<S: AsRef<str>>(address: S) -> Result<Arc<Self>> {
        let (transport, resolved) = async_varlink_connect(address).await?;
        let connection = Self {
            address: resolved,
            stream: Arc::new(RwLock::new(Some(transport))),
        };
        Ok(Arc::new(connection))
    }

    /// Return an owned copy of the address this connection was opened with.
    pub fn address(&self) -> String {
        self.address.to_owned()
    }
}

/// Explicit but empty destructor for `AsyncConnection`.
///
/// The body does nothing; the `address` and `stream` fields are dropped by
/// the compiler-generated drop glue after this function returns.
// NOTE(review): because a `Drop` impl exists, fields cannot be moved out of
// an `AsyncConnection` by destructuring. If nothing depends on that, this
// impl could be removed — confirm before doing so.
impl Drop for AsyncConnection {
    fn drop(&mut self) {
        // Stream cleanup happens automatically
    }
}

/// Async method call builder
///
/// This struct is returned by the generated client methods and provides
/// async methods to execute the call.
///
/// Type parameters:
/// - `MRequest`: the method's input parameters, serialized to JSON when sent
/// - `MReply`: the type each reply's parameters are deserialized into
/// - `MError`: the caller-facing error type; crate errors are converted
///   into it through `From<Error>`
pub struct AsyncMethodCall<MRequest, MReply, MError>
where
    MRequest: Serialize,
    MReply: DeserializeOwned,
    MError: From<Error>,
{
    // Shared connection the request is written to and replies are read from.
    connection: Arc<AsyncConnection>,
    // Pending call parameters; taken (left as `None`) by the first send.
    request: Option<MRequest>,
    // Pending method name; taken (left as `None`) by the first send.
    method: Option<String>,
    // Set to true when the request is sent with the `more` flag, then
    // overwritten on every received reply from that reply's `continues` flag.
    continues: bool,
    // Zero-sized markers tying the reply and error types to this call.
    phantom_reply: PhantomData<MReply>,
    phantom_error: PhantomData<MError>,
}

impl<MRequest, MReply, MError> AsyncMethodCall<MRequest, MReply, MError>
where
    MRequest: Serialize,
    MReply: DeserializeOwned,
    MError: From<Error>,
{
    /// Create a new async method call
    pub fn new<S: Into<String>>(
        connection: Arc<AsyncConnection>,
        method: S,
        parameters: MRequest,
    ) -> Self {
        AsyncMethodCall {
            connection,
            request: Some(parameters),
            method: Some(method.into()),
            continues: false,
            phantom_reply: PhantomData,
            phantom_error: PhantomData,
        }
    }

    /// Serialize the pending request and write it to the connection.
    ///
    /// `oneway`, `more` and `upgrade` set the corresponding request flag to
    /// `Some(true)` when true and leave it untouched otherwise. Sending with
    /// `more` also marks this call as expecting further replies.
    ///
    /// The stored method name and parameters are consumed by this call, so a
    /// second send on the same `AsyncMethodCall` fails.
    ///
    /// # Errors
    ///
    /// - `ErrorKind::MethodCalledAlready` if the method or parameters were
    ///   already consumed
    /// - a mapped serialization error if the parameters cannot be turned
    ///   into JSON, or whatever `serialize_request` reports
    /// - `ErrorKind::ConnectionClosed` if the stream slot is empty or the
    ///   write/flush fails
    async fn send(
        &mut self,
        oneway: bool,
        more: bool,
        upgrade: bool,
    ) -> std::result::Result<(), MError> {
        use crate::Request;

        // Both slots are emptied unconditionally, even if only one was set.
        let pending_method = self.method.take();
        let pending_params = self.request.take();

        let mut req = if let (Some(method), Some(params)) = (pending_method, pending_params) {
            let value = serde_json::to_value(params).map_err(map_context!())?;
            Request::create(method, Some(value))
        } else {
            return Err(MError::from(context!(ErrorKind::MethodCalledAlready)));
        };

        if oneway {
            req.oneway = Some(true);
        }
        if more {
            req.more = Some(true);
            self.continues = true;
        }
        if upgrade {
            req.upgrade = Some(true);
        }

        // Wire encoding is delegated to the sans-io protocol layer.
        let payload = crate::sansio::protocol::serialize_request(&req)?;

        // Exclusive access to the transport for the duration of the write.
        let mut guard = self.connection.stream.write().await;
        let stream = match guard.as_mut() {
            Some(stream) => stream,
            None => return Err(MError::from(context!(ErrorKind::ConnectionClosed))),
        };

        // Write the whole frame, then flush; the first failure wins.
        let io_result = match stream {
            AsyncStream::TCP(tcp) => match tcp.write_all(&payload).await {
                Ok(()) => tcp.flush().await,
                Err(e) => Err(e),
            },
            #[cfg(unix)]
            AsyncStream::UNIX(unix) => match unix.write_all(&payload).await {
                Ok(()) => unix.flush().await,
                Err(e) => Err(e),
            },
        };

        // Any I/O failure is reported as a closed connection.
        io_result.map_err(|_| MError::from(context!(ErrorKind::ConnectionClosed)))
    }

    /// Receive a reply from the server
    pub async fn recv(&mut self) -> std::result::Result<MReply, MError> {
        let mut buf = Vec::new();

        // Get stream and read
        let stream_lock = self.connection.stream.clone();
        let mut stream_guard = stream_lock.write().await;
        let stream = stream_guard
            .as_mut()
            .ok_or_else(|| MError::from(context!(ErrorKind::ConnectionClosed)))?;

        // Read until null terminator
        let n = match stream {
            AsyncStream::TCP(s) => {
                let mut reader = BufReader::new(s);
                reader
                    .read_until(0, &mut buf)
                    .await
                    .map_err(|_| MError::from(context!(ErrorKind::ConnectionClosed)))?
            }
            #[cfg(unix)]
            AsyncStream::UNIX(s) => {
                let mut reader = BufReader::new(s);
                reader
                    .read_until(0, &mut buf)
                    .await
                    .map_err(|_| MError::from(context!(ErrorKind::ConnectionClosed)))?
            }
        };

        if n == 0 || buf.is_empty() {
            return Err(MError::from(context!(ErrorKind::ConnectionClosed)));
        }

        // Parse reply using sans-io
        use crate::sansio::types::ParseResult;
        let reply: crate::Reply = match crate::sansio::protocol::parse_message(&buf) {
            ParseResult::Complete { message, .. } => {
                crate::sansio::protocol::parse_reply(&message)?
            }
            ParseResult::Incomplete { .. } => {
                return Err(MError::from(context!(ErrorKind::ConnectionClosed)));
            }
            ParseResult::Invalid { error } => {
                return Err(MError::from(context!(ErrorKind::InvalidParameter(error))));
            }
        };

        // Check for continues
        match reply.continues {
            Some(true) => self.continues = true,
            _ => self.continues = false,
        }

        // Check for error
        if reply.error.is_some() {
            return Err(MError::from(context!(ErrorKind::from(reply))));
        }

        // Parse reply parameters
        match reply {
            crate::Reply {
                parameters: Some(p),
                ..
            } => {
                let mreply: MReply = serde_json::from_value(p).map_err(map_context!())?;
                Ok(mreply)
            }
            crate::Reply {
                parameters: None, ..
            } => {
                let mreply: MReply =
                    serde_json::from_value(serde_json::Value::Object(serde_json::Map::new()))
                        .map_err(map_context!())?;
                Ok(mreply)
            }
        }
    }

    /// Send the request with no flags set and wait for exactly one reply.
    ///
    /// # Errors
    ///
    /// Returns the first error from sending the request or from receiving
    /// and decoding the reply.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # #[cfg(feature = "tokio")]
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// # use varlink::AsyncConnection;
    /// let connection = AsyncConnection::with_address("tcp:127.0.0.1:9999").await?;
    /// // With generated client:
    /// // let mut client = org_example_ping::VarlinkClient::new(connection);
    /// // let reply = client.ping("Hello").call().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn call(&mut self) -> std::result::Result<MReply, MError> {
        let (oneway, more, upgrade) = (false, false, false);
        self.send(oneway, more, upgrade).await?;
        self.recv().await
    }

    /// Send the request with the `more` flag set, asking for multiple replies.
    ///
    /// No reply is read here. On success `self` is handed back so the caller
    /// can fetch the replies one by one with `recv()`.
    pub async fn more(&mut self) -> std::result::Result<&mut Self, MError> {
        let (oneway, more, upgrade) = (false, true, false);
        self.send(oneway, more, upgrade).await?;
        Ok(self)
    }

    /// Send the request with the `oneway` flag set and return immediately.
    ///
    /// No reply is read (fire-and-forget); only send errors are reported.
    pub async fn oneway(&mut self) -> std::result::Result<(), MError> {
        let (oneway, more, upgrade) = (true, false, false);
        self.send(oneway, more, upgrade).await
    }

    /// Send the request with the `upgrade` flag set and wait for one reply.
    ///
    /// Apart from the flag, this behaves like `call`: the request is sent and
    /// a single reply is received and decoded.
    pub async fn upgrade(&mut self) -> std::result::Result<MReply, MError> {
        let (oneway, more, upgrade) = (false, false, true);
        self.send(oneway, more, upgrade).await?;
        self.recv().await
    }
}