use core::{fmt::Debug, str::from_utf8_unchecked};
use crate::{Result, varlink_service};
use super::{
BUFFER_SIZE, Call, MAX_BUFFER_SIZE,
reply::{self, Reply},
socket::ReadHalf,
};
#[cfg(feature = "std")]
use alloc::collections::VecDeque;
use alloc::vec::Vec;
use serde::Deserialize;
use serde_json::Deserializer;
#[cfg(feature = "std")]
use std::os::fd::OwnedFd;
#[cfg(feature = "std")]
type RecvResult<T> = (T, Vec<OwnedFd>);
#[cfg(not(feature = "std"))]
type RecvResult<T> = T;
/// The receiving side of a connection.
///
/// Reads NUL-terminated JSON messages from the socket's read half into an internal buffer and
/// deserializes them, either as method calls or as replies.
#[derive(Debug)]
pub struct ReadConnection<Read: ReadHalf> {
/// The read half of the socket that messages are received from.
socket: Read,
/// Number of bytes at the start of `buffer` that hold data received from the socket.
read_pos: usize,
/// Offset in `buffer` of the next message to deserialize. When it is `0`, the next receive
/// reads from the socket before deserializing.
msg_pos: usize,
/// Receive buffer. Allocated with `BUFFER_SIZE` bytes and grown in `BUFFER_SIZE` steps as
/// needed, up to a limit based on `MAX_BUFFER_SIZE`.
buffer: Vec<u8>,
/// Identifier of this connection, returned by `id()` and included in trace output.
id: usize,
/// File descriptors received from the socket, one entry per read that returned any, in arrival
/// order. Each successfully deserialized message takes the front entry, if there is one.
#[cfg(feature = "std")]
pending_fds: VecDeque<Vec<OwnedFd>>,
}
impl<Read: ReadHalf> ReadConnection<Read> {
    /// Creates a new read connection over `socket`.
    ///
    /// `id` is returned by [`Self::id`] and included in trace output.
    pub(super) fn new(socket: Read, id: usize) -> Self {
        Self {
            socket,
            read_pos: 0,
            msg_pos: 0,
            id,
            buffer: alloc::vec![0; BUFFER_SIZE],
            #[cfg(feature = "std")]
            pending_fds: VecDeque::new(),
        }
    }

    /// The ID this connection was created with.
    #[inline]
    pub fn id(&self) -> usize {
        self.id
    }

    /// Receives the next message and interprets it as a reply to a method call.
    ///
    /// The message is matched, in this order, against a varlink service error, the
    /// method-specific `ReplyError`, and a [`Reply`] carrying `ReplyParams`. The result may
    /// borrow from the connection's internal buffer.
    ///
    /// On success this yields `Ok(reply)` for a reply and `Err(error)` for a method-specific
    /// error. With the `std` feature, that result is paired with the file descriptors attributed
    /// to the message.
    ///
    /// # Errors
    ///
    /// Returns [`crate::Error::VarlinkService`] if the peer sent a varlink service error, or any
    /// error from reading or deserializing the message.
    pub async fn receive_reply<'r, ReplyParams, ReplyError>(
        &'r mut self,
    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
    where
        ReplyParams: Deserialize<'r> + Debug,
        ReplyError: Deserialize<'r> + Debug,
    {
        #[derive(Debug, Deserialize)]
        #[serde(untagged)]
        enum ReplyMsg<'m, ReplyParams, ReplyError> {
            #[serde(borrow)]
            Varlink(varlink_service::Error<'m>),
            Error(ReplyError),
            Reply(Reply<ReplyParams>),
        }

        let recv_result = self
            .read_message::<ReplyMsg<'_, ReplyParams, ReplyError>>()
            .await?;
        #[cfg(feature = "std")]
        let (msg, fds) = recv_result;
        #[cfg(not(feature = "std"))]
        let msg = recv_result;

        let result = match msg {
            ReplyMsg::Varlink(e) => Err(crate::Error::VarlinkService(e.into())),
            ReplyMsg::Error(e) => Ok(Err(e)),
            ReplyMsg::Reply(reply) => Ok(Ok(reply)),
        };
        #[cfg(feature = "std")]
        return result.map(|r| (r, fds));
        #[cfg(not(feature = "std"))]
        return result;
    }

    /// Receives the next message and deserializes it as a method call.
    ///
    /// The call may borrow from the connection's internal buffer. With the `std` feature, it is
    /// paired with the file descriptors attributed to the message.
    ///
    /// # Errors
    ///
    /// Any error from reading or deserializing the message.
    pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
    where
        Method: Deserialize<'m> + Debug,
    {
        self.read_message::<Call<Method>>().await
    }

    /// Reads the next NUL-terminated message and deserializes it as `M`.
    ///
    /// Messages are deserialized in place from the internal buffer, so `M` may borrow from it.
    /// A message is consumed even if it fails to deserialize, so the following call starts at
    /// the next message rather than somewhere inside the bad one.
    ///
    /// # Errors
    ///
    /// Any error from [`Self::read_from_socket`], or a deserialization error if the message is
    /// not a valid `M` or has trailing non-whitespace data before its terminator.
    async fn read_message<'m, M>(&'m mut self) -> Result<RecvResult<M>>
    where
        M: Deserialize<'m> + Debug,
    {
        self.read_from_socket().await?;

        // Find the frame boundary before deserializing. Raw NUL bytes cannot occur inside valid
        // JSON, so the first NUL ends the message, and `read_from_socket` guarantees that the
        // unread data `buffer[msg_pos..read_pos]` ends with one.
        let start = self.msg_pos;
        let Some(len) = self.buffer[start..self.read_pos]
            .iter()
            .position(|&byte| byte == b'\0')
        else {
            return Err(crate::Error::UnexpectedEof);
        };
        let null_index = start + len;

        // Advance past the terminator *before* deserializing, so a malformed message cannot
        // leave `msg_pos` pointing into the middle of itself.
        if null_index + 1 == self.read_pos {
            // That was the last buffered message; the next call reads from the socket again.
            self.read_pos = 0;
            self.msg_pos = 0;
        } else {
            self.msg_pos = null_index + 1;
        }

        let bytes = &self.buffer[start..null_index];
        let mut deserializer = Deserializer::from_slice(bytes);
        let msg = M::deserialize(&mut deserializer)?;
        // Only whitespace may follow the value inside the frame.
        deserializer.end()?;

        trace!("connection {}: received a message: {}", self.id, {
            // serde_json does not UTF-8-validate the contents of strings it skips (such as
            // ignored fields), so the raw bytes are not guaranteed to be UTF-8. Log only the
            // prefix that is.
            let valid_len = core::str::from_utf8(bytes)
                .map_or_else(|e| e.valid_up_to(), str::len);
            // SAFETY: `from_utf8` verified that `bytes[..valid_len]` is valid UTF-8.
            unsafe { from_utf8_unchecked(&bytes[..valid_len]) }
        });

        #[cfg(feature = "std")]
        let msg = (msg, self.pending_fds.pop_front().unwrap_or_default());
        Ok(msg)
    }

    /// Makes sure at least one complete, NUL-terminated message is buffered.
    ///
    /// Returns immediately if messages from an earlier read are still buffered (`msg_pos > 0`).
    /// Otherwise, reads from the socket until the received data ends with a NUL byte, growing
    /// the buffer by `BUFFER_SIZE` whenever it is full and more data is still needed. With the
    /// `std` feature, non-empty sets of file descriptors received by a read are queued in
    /// `pending_fds`.
    ///
    /// Received bytes are accounted in `read_pos` as soon as each read completes. If this future
    /// is dropped mid-message, the next call resumes appending after them. This assumes
    /// `ReadHalf::read` itself loses no data when cancelled (TODO: confirm).
    ///
    /// # Errors
    ///
    /// - [`crate::Error::UnexpectedEof`] if a read returns 0 bytes.
    /// - [`crate::Error::BufferOverflow`] if the buffer is full without ending a message and is
    ///   already at least `MAX_BUFFER_SIZE` bytes long.
    /// - Any error returned by the socket read.
    async fn read_from_socket(&mut self) -> Result<()> {
        if self.msg_pos > 0 {
            return Ok(());
        }

        loop {
            #[cfg(feature = "std")]
            let (bytes_read, fds) = self.socket.read(&mut self.buffer[self.read_pos..]).await?;
            #[cfg(not(feature = "std"))]
            let bytes_read = self.socket.read(&mut self.buffer[self.read_pos..]).await?;
            if bytes_read == 0 {
                return Err(crate::Error::UnexpectedEof);
            }
            self.read_pos += bytes_read;
            #[cfg(feature = "std")]
            if !fds.is_empty() {
                self.pending_fds.push_back(fds);
            }

            // Messages are NUL-terminated: once the data ends with a NUL, every buffered
            // message is complete.
            if self.buffer[self.read_pos - 1] == b'\0' {
                return Ok(());
            }

            // Grow only once more data is known to be needed, so a message that exactly fills
            // the buffer is not rejected as an overflow.
            if self.read_pos == self.buffer.len() {
                if self.read_pos >= MAX_BUFFER_SIZE {
                    return Err(crate::Error::BufferOverflow);
                }
                self.buffer.extend(core::iter::repeat_n(0, BUFFER_SIZE));
            }
        }
    }

    /// Returns a shared reference to the underlying read half of the socket.
    pub fn read_half(&self) -> &Read {
        &self.socket
    }
}