use std::io::{self, ErrorKind};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::mpsc::{Receiver, Sender};
use lutra_bin::{Encode, rr};
use crate::binary::messages;
/// Sending half of the client.
///
/// Owns the channel carrying `messages::ClientMessage` values together with
/// the counters from which program and request ids are drawn. The counters sit
/// behind `Arc`, so every clone of a `ClientSender` draws ids from the same
/// sequence.
#[derive(Clone, Debug)]
pub struct ClientSender {
    /// Outgoing channel of client messages.
    tx: Sender<messages::ClientMessage>,
    /// Id that the next `prepare` call will hand out; shared between clones.
    next_program_id: Arc<AtomicU32>,
    /// Id that the next `execute` call will hand out; shared between clones.
    next_request_id: Arc<AtomicU32>,
}

/// Receiving half of the client.
///
/// Owns the channel on which `messages::ServerMessage` values arrive.
#[derive(Debug)]
pub struct ClientReceiver {
    /// Incoming channel of server messages.
    rx: Receiver<messages::ServerMessage>,
}

/// A sender/receiver pair bundled into one handle.
///
/// Use `split` to take the two halves apart.
#[derive(Debug)]
pub struct Client {
    /// Half used to issue prepare / execute / release messages.
    sender: ClientSender,
    /// Half used to wait for responses.
    receiver: ClientReceiver,
}
impl ClientSender {
    /// Creates a sender over `tx` with both id counters starting at zero.
    pub fn new(tx: Sender<messages::ClientMessage>) -> Self {
        let next_program_id = Arc::new(AtomicU32::new(0));
        let next_request_id = Arc::new(AtomicU32::new(0));
        Self {
            tx,
            next_program_id,
            next_request_id,
        }
    }

    /// Pushes `msg` onto the outgoing channel.
    ///
    /// A failed send (the receiving end of the channel is gone) is reported
    /// as `ErrorKind::BrokenPipe`.
    fn dispatch(&self, msg: messages::ClientMessage) -> io::Result<()> {
        match self.tx.send(msg) {
            Ok(()) => Ok(()),
            Err(_) => Err(io::Error::new(
                ErrorKind::BrokenPipe,
                "runner disconnected",
            )),
        }
    }

    /// Sends a `Prepare` message carrying the encoded `program` and returns
    /// the program id allocated for it.
    ///
    /// `Ok` only means the message was handed to the channel; nothing is
    /// awaited from the other side.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::BrokenPipe` when the channel is closed.
    pub fn prepare(&self, program: &rr::Program) -> io::Result<u32> {
        // Relaxed: only the atomicity of the increment is needed to hand out
        // distinct ids. `fetch_add` wraps around on overflow.
        let id = self.next_program_id.fetch_add(1, Ordering::Relaxed);
        let payload = messages::Prepare {
            program_id: id,
            program: program.encode(),
        };
        self.dispatch(messages::ClientMessage::Prepare(payload))?;
        Ok(id)
    }

    /// Sends an `Execute` message for `program_id` with a copy of `input` and
    /// returns the request id allocated for it.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::BrokenPipe` when the channel is closed.
    pub fn execute(&self, program_id: u32, input: &[u8]) -> io::Result<u32> {
        // Relaxed: see `prepare`; the counter is only used to mint ids.
        let id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
        let payload = messages::Execute {
            program_id,
            request_id: id,
            input: input.to_vec(),
        };
        self.dispatch(messages::ClientMessage::Execute(payload))?;
        Ok(id)
    }

    /// Sends a `Release` message for `program_id`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::BrokenPipe` when the channel is closed.
    pub fn release(&self, program_id: u32) -> io::Result<()> {
        let payload = messages::Release { program_id };
        self.dispatch(messages::ClientMessage::Release(payload))
    }
}
impl ClientReceiver {
    /// Wraps `rx` as the receiving half of a client.
    pub fn new(rx: Receiver<messages::ServerMessage>) -> Self {
        Self { rx }
    }

    /// Blocks until the next server message arrives.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::BrokenPipe` when the channel has been closed.
    pub fn recv(&self) -> io::Result<messages::ServerMessage> {
        match self.rx.recv() {
            Ok(msg) => Ok(msg),
            Err(_) => Err(io::Error::new(
                ErrorKind::BrokenPipe,
                "runner disconnected",
            )),
        }
    }

    /// Non-blocking receive.
    ///
    /// Forwards the channel's own `TryRecvError` unchanged, so the caller can
    /// tell an empty channel from a disconnected one.
    pub fn try_recv(&self) -> Result<messages::ServerMessage, std::sync::mpsc::TryRecvError> {
        self.rx.try_recv()
    }

    /// Blocks until the response tagged with `request_id` arrives and returns
    /// its result.
    ///
    /// Responses carrying any other request id are consumed and dropped while
    /// waiting; they cannot be retrieved by a later call.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::BrokenPipe` if the channel closes before a
    /// matching response is seen.
    pub fn recv_response(&self, request_id: u32) -> io::Result<messages::Result> {
        loop {
            match self.recv()? {
                messages::ServerMessage::Response(resp) if resp.request_id == request_id => {
                    return Ok(resp.result);
                }
                // A response to some other request: discard it and keep waiting.
                messages::ServerMessage::Response(_) => {}
            }
        }
    }
}
impl Client {
    /// Builds a client from the outgoing (`tx`) and incoming (`rx`) channel
    /// ends.
    pub fn new(tx: Sender<messages::ClientMessage>, rx: Receiver<messages::ServerMessage>) -> Self {
        let sender = ClientSender::new(tx);
        let receiver = ClientReceiver::new(rx);
        Self { sender, receiver }
    }

    /// Consumes the client and returns its sending and receiving halves.
    pub fn split(self) -> (ClientSender, ClientReceiver) {
        let Self { sender, receiver } = self;
        (sender, receiver)
    }

    /// Delegates to `ClientSender::prepare`.
    pub fn prepare(&self, program: &rr::Program) -> io::Result<u32> {
        ClientSender::prepare(&self.sender, program)
    }

    /// Delegates to `ClientSender::execute`.
    pub fn execute(&self, program_id: u32, input: &[u8]) -> io::Result<u32> {
        ClientSender::execute(&self.sender, program_id, input)
    }

    /// Delegates to `ClientSender::release`.
    pub fn release(&self, program_id: u32) -> io::Result<()> {
        ClientSender::release(&self.sender, program_id)
    }

    /// Delegates to `ClientReceiver::recv_response`.
    pub fn recv_response(&self, request_id: u32) -> io::Result<messages::Result> {
        ClientReceiver::recv_response(&self.receiver, request_id)
    }

    /// Prepares `program`, executes it once with `input`, waits for the
    /// response, releases the program, and returns the output bytes.
    ///
    /// # Errors
    ///
    /// * Any channel failure along the way is returned as-is. A failure in
    ///   `execute` or while waiting for the response returns before the
    ///   release message is sent.
    /// * A failure to send the release message is returned even though a
    ///   response was already received; that response is dropped.
    /// * An error result from the other side becomes an `io::Error` whose
    ///   text is `"<code>: <message>"` when a code is present and just the
    ///   message otherwise.
    pub fn run_once(&self, program: &rr::Program, input: &[u8]) -> io::Result<Vec<u8>> {
        let program_id = self.sender.prepare(program)?;
        let request_id = self.sender.execute(program_id, input)?;
        let outcome = self.receiver.recv_response(request_id)?;
        // The release is sent before the outcome is inspected, so it happens
        // for successful and failed executions alike.
        self.sender.release(program_id)?;

        let failure = match outcome {
            messages::Result::Ok(output) => return Ok(output),
            messages::Result::Err(failure) => failure,
        };
        let text = if let Some(code) = &failure.code {
            format!("{}: {}", code, failure.message)
        } else {
            failure.message.clone()
        };
        Err(io::Error::other(text))
    }
}