use core::fmt::Debug;
#[cfg(feature = "std")]
use alloc::collections::VecDeque;
use alloc::vec::Vec;
use serde::Serialize;
use super::{BUFFER_SIZE, Call, Reply, socket::WriteHalf};
#[cfg(feature = "std")]
use std::os::fd::OwnedFd;
/// The write half of a connection.
///
/// Outgoing messages are serialized into `buffer`, each followed by a NUL byte,
/// and are handed to `socket` when the connection is flushed.
#[derive(Debug)]
pub struct WriteConnection<Write: WriteHalf> {
    /// The socket write half that buffered bytes are written to.
    pub(super) socket: Write,
    /// Staging buffer for serialized messages; only `buffer[..pos]` holds queued data.
    pub(super) buffer: Vec<u8>,
    /// Number of bytes currently queued at the start of `buffer`.
    pub(super) pos: usize,
    /// Connection identifier, returned by `id()` and included in trace output.
    id: usize,
    /// File descriptors attached to queued messages, in the order the messages were queued.
    #[cfg(feature = "std")]
    pending_fds: VecDeque<MessageFds>,
}
impl<Write: WriteHalf> WriteConnection<Write> {
    /// Create a write connection over `socket`, identified by `id` in trace output.
    ///
    /// The write buffer starts as `BUFFER_SIZE` zero bytes and grows on demand.
    pub(super) fn new(socket: Write, id: usize) -> Self {
        Self {
            socket,
            id,
            buffer: alloc::vec![0; BUFFER_SIZE],
            pos: 0,
            #[cfg(feature = "std")]
            pending_fds: VecDeque::new(),
        }
    }

    /// The identifier of this connection.
    #[inline]
    pub fn id(&self) -> usize {
        self.id
    }

    /// Serialize `call`, queue it and immediately flush everything buffered.
    ///
    /// With the `std` feature, `fds` are sent together with this message's bytes.
    ///
    /// # Errors
    ///
    /// Fails if serialization fails, the buffer would exceed its maximum size,
    /// or the socket write fails.
    pub async fn send_call<Method>(
        &mut self,
        call: &Call<Method>,
        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
    ) -> crate::Result<()>
    where
        Method: Serialize + Debug,
    {
        trace!("connection {}: sending call: {:?}", self.id, call);
        #[cfg(feature = "std")]
        {
            self.write(call, fds).await
        }
        #[cfg(not(feature = "std"))]
        {
            self.write(call).await
        }
    }

    /// Serialize `reply`, queue it and immediately flush everything buffered.
    ///
    /// With the `std` feature, `fds` are sent together with this message's bytes.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::send_call`].
    pub async fn send_reply<Params>(
        &mut self,
        reply: &Reply<Params>,
        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
    ) -> crate::Result<()>
    where
        Params: Serialize + Debug,
    {
        trace!("connection {}: sending reply: {:?}", self.id, reply);
        #[cfg(feature = "std")]
        {
            self.write(reply, fds).await
        }
        #[cfg(not(feature = "std"))]
        {
            self.write(reply).await
        }
    }

    /// Serialize `error`, queue it and immediately flush everything buffered.
    ///
    /// With the `std` feature, `fds` are sent together with this message's bytes.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::send_call`].
    pub async fn send_error<ReplyError>(
        &mut self,
        error: &ReplyError,
        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
    ) -> crate::Result<()>
    where
        ReplyError: Serialize + Debug,
    {
        trace!("connection {}: sending error: {:?}", self.id, error);
        #[cfg(feature = "std")]
        {
            self.write(error, fds).await
        }
        #[cfg(not(feature = "std"))]
        {
            self.write(error).await
        }
    }

    /// Serialize `call` into the write buffer without sending it.
    ///
    /// The message goes out on the next [`Self::flush`]. With the `std` feature,
    /// `fds` are sent together with this message's bytes.
    ///
    /// # Errors
    ///
    /// Fails if serialization fails or the buffer would exceed its maximum size.
    pub fn enqueue_call<Method>(
        &mut self,
        call: &Call<Method>,
        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
    ) -> crate::Result<()>
    where
        Method: Serialize + Debug,
    {
        trace!("connection {}: enqueuing call: {:?}", self.id, call);
        #[cfg(feature = "std")]
        {
            self.enqueue(call, fds)
        }
        #[cfg(not(feature = "std"))]
        {
            self.enqueue(call)
        }
    }

    /// Write every queued message to the socket.
    ///
    /// Messages that carry file descriptors are written in their own socket write,
    /// so each descriptor set travels with exactly the bytes of its own message.
    /// Bytes between such messages are written without descriptors.
    ///
    /// The buffer is emptied whether or not the writes succeed. After a failed
    /// write it is unknown how much reached the peer. Keeping the data would make
    /// the next flush re-send messages that were already written (FD-carrying ones
    /// without their FDs), so the queued data is discarded instead.
    ///
    /// NOTE(review): if this future is dropped before completing, the buffer is
    /// left partially consumed; callers should treat the connection as unusable.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the socket.
    pub async fn flush(&mut self) -> crate::Result<()> {
        if self.pos == 0 {
            return Ok(());
        }
        let result = self.write_buffered().await;
        self.pos = 0;
        #[cfg(feature = "std")]
        self.pending_fds.clear();
        result
    }

    /// Hand `buffer[..pos]` to the socket, pairing each FD-carrying message with its FDs.
    ///
    /// Pops entries from `pending_fds` as it goes but never touches `pos`; resetting
    /// the queue state is left to [`Self::flush`].
    async fn write_buffered(&mut self) -> crate::Result<()> {
        // Everything in `buffer[..sent_pos]` has already been handed to the socket.
        #[allow(unused_mut)] // only reassigned when FD passing is compiled in
        let mut sent_pos = 0;
        #[cfg(feature = "std")]
        {
            while let Some(pending) = self.pending_fds.pop_front() {
                let msg_start = pending.offset;
                let msg_end = msg_start + pending.len;
                if sent_pos < msg_start {
                    trace!(
                        "connection {}: flushing {} bytes before FD message",
                        self.id,
                        msg_start - sent_pos
                    );
                    self.socket
                        .write(&self.buffer[sent_pos..msg_start], &[] as &[OwnedFd])
                        .await?;
                }
                let fds = &pending.fds;
                trace!(
                    "connection {}: flushing {} bytes with {} FDs",
                    self.id,
                    pending.len,
                    fds.len()
                );
                self.socket
                    .write(&self.buffer[msg_start..msg_end], fds)
                    .await?;
                sent_pos = msg_end;
            }
        }
        if sent_pos < self.pos {
            trace!(
                "connection {}: flushing {} bytes",
                self.id,
                self.pos - sent_pos
            );
            #[cfg(feature = "std")]
            {
                self.socket
                    .write(&self.buffer[sent_pos..self.pos], &[] as &[OwnedFd])
                    .await?;
            }
            #[cfg(not(feature = "std"))]
            {
                self.socket.write(&self.buffer[sent_pos..self.pos]).await?;
            }
        }
        Ok(())
    }

    /// Borrow the underlying socket write half.
    pub fn write_half(&self) -> &Write {
        &self.socket
    }

    /// Queue `value` (and, with `std`, its `fds`) and then flush the whole buffer.
    ///
    /// # Errors
    ///
    /// Any error from [`Self::enqueue`] or [`Self::flush`].
    pub(super) async fn write<T>(
        &mut self,
        value: &T,
        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
    ) -> crate::Result<()>
    where
        T: Serialize + ?Sized + Debug,
    {
        #[cfg(feature = "std")]
        {
            self.enqueue(value, fds)?;
        }
        #[cfg(not(feature = "std"))]
        {
            self.enqueue(value)?;
        }
        self.flush().await
    }

    /// Serialize `value` into the buffer as one NUL-terminated message.
    ///
    /// The buffer is grown as needed. With the `std` feature, a non-empty `fds` is
    /// recorded against this message's byte range so [`Self::flush`] can send it
    /// with those bytes.
    ///
    /// # Errors
    ///
    /// - `crate::Error::BufferOverflow` if the message does not fit even at the
    ///   maximum buffer size.
    /// - `crate::Error::Json` if the value contains a map key that is not a string.
    pub(super) fn enqueue<T>(
        &mut self,
        value: &T,
        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
    ) -> crate::Result<()>
    where
        T: Serialize + ?Sized + Debug,
    {
        #[cfg(feature = "std")]
        let start_pos = self.pos;
        let len = loop {
            match crate::json_ser::to_slice(value, &mut self.buffer[self.pos..]) {
                Ok(len) => break len,
                Err(crate::json_ser::Error::BufferTooSmall) => self.grow_buffer()?,
                Err(crate::json_ser::Error::KeyMustBeAString) => {
                    return Err(crate::Error::Json(serde::ser::Error::custom(
                        "key must be a string",
                    )));
                }
            }
        };
        // Make room for the NUL byte that terminates every queued message.
        if self.pos + len == self.buffer.len() {
            self.grow_buffer()?;
        }
        self.buffer[self.pos + len] = b'\0';
        self.pos += len + 1;
        #[cfg(feature = "std")]
        if !fds.is_empty() {
            self.pending_fds.push_back(MessageFds {
                offset: start_pos,
                // The terminator belongs to the message, so it is sent with the FDs.
                len: len + 1,
                fds,
            });
        }
        Ok(())
    }

    /// Grow the buffer by `BUFFER_SIZE` zero bytes, never past `MAX_BUFFER_SIZE`.
    ///
    /// # Errors
    ///
    /// `crate::Error::BufferOverflow` if the buffer is already at the maximum size.
    fn grow_buffer(&mut self) -> crate::Result<()> {
        let current = self.buffer.len();
        if current >= super::MAX_BUFFER_SIZE {
            return Err(crate::Error::BufferOverflow);
        }
        // Clamp so the cap holds even if it is not a multiple of the growth step.
        let new_len = (current + BUFFER_SIZE).min(super::MAX_BUFFER_SIZE);
        self.buffer.resize(new_len, 0);
        Ok(())
    }
}
/// File descriptors to be sent together with one queued message.
#[cfg(feature = "std")]
#[derive(Debug)]
struct MessageFds {
    /// Descriptors passed alongside the message bytes.
    fds: Vec<OwnedFd>,
    /// Byte offset of the message within the connection's write buffer.
    offset: usize,
    /// Length of the message in bytes, including its trailing NUL terminator.
    len: usize,
}