use alloc::sync::Arc;
use ax_kspin::SpinNoIrq;
use axpoll::{IoEvents, PollSet};
use ringbuf::{
Cons, HeapRb, Prod,
traits::{Consumer, Producer},
};
use super::{
Tty,
terminal::{
Terminal,
ldisc::{ProcessMode, TtyConfig, TtyRead, TtyWrite},
},
};
const PTY_BUF_SIZE: usize = 4096;
pub type PtyDriver = Tty<PtyReader, PtyWriter>;
type Buffer = Arc<HeapRb<u8>>;
pub struct PtyReader(Cons<Buffer>);
impl PtyReader {
pub fn new(buffer: Buffer) -> Self {
Self(Cons::new(buffer))
}
}
impl TtyRead for PtyReader {
fn read(&mut self, buf: &mut [u8]) -> usize {
self.0.pop_slice(buf)
}
}
#[derive(Clone)]
pub struct PtyWriter(Arc<SpinNoIrq<Prod<Buffer>>>, Arc<PollSet>);
impl PtyWriter {
pub fn new(buffer: Buffer, poll_rx: Arc<PollSet>) -> Self {
Self(Arc::new(SpinNoIrq::new(Prod::new(buffer))), poll_rx)
}
}
impl TtyWrite for PtyWriter {
fn write(&self, buf: &[u8]) {
let read = self.0.lock().push_slice(buf);
unsafe { self.1.wake(IoEvents::IN) };
if read < buf.len() {
warn!("Discarding {} bytes written to pty", buf.len() - read);
}
}
}
pub(crate) fn create_pty_pair() -> (Arc<PtyDriver>, Arc<PtyDriver>) {
let master_to_slave = Arc::new(HeapRb::new(PTY_BUF_SIZE));
let slave_to_master = Arc::new(HeapRb::new(PTY_BUF_SIZE));
let poll_rx_slave = Arc::new(PollSet::new());
let poll_rx_master = Arc::new(PollSet::new());
let terminal = Arc::new(Terminal::default());
let master = Tty::new(
terminal.clone(),
TtyConfig {
reader: PtyReader::new(slave_to_master.clone()),
writer: PtyWriter::new(master_to_slave.clone(), poll_rx_slave.clone()),
process_mode: ProcessMode::Passive(poll_rx_master.clone()),
},
);
let slave = Tty::new(
terminal,
TtyConfig {
reader: PtyReader::new(master_to_slave),
writer: PtyWriter::new(slave_to_master, poll_rx_master),
process_mode: ProcessMode::InterruptDriven(poll_rx_slave),
},
);
(master, slave)
}
#[cfg(test)]
mod tests {
use axpoll::{IoEvents, Pollable};
use crate::pseudofs::DeviceOps;
#[test]
fn pty_preserves_mouse_escape_reports() {
let (master, slave) = super::create_pty_pair();
let report = b"\x1b[<0;1;1M";
assert_eq!(slave.write_at(report, 0), Ok(report.len()));
assert!(master.poll().contains(IoEvents::IN));
let mut buf = [0; 16];
let read = master.read_at(&mut buf, 0).unwrap();
assert_eq!(&buf[..read], report);
}
}