Skip to main content

selium_messaging/
driver.rs

1use std::sync::Arc;
2
3use selium_abi::{ChannelBackpressure, IoFrame};
4use selium_kernel::{
5    drivers::{channel::ChannelCapability, io::IoCapability},
6    guest_data::{GuestError, GuestUint},
7};
8use tokio::io::AsyncWriteExt;
9
10use crate::{Channel, ChannelError, StrongReader, StrongWriter, WeakReader, WeakWriter};
11
/// Runtime driver for channel hostcalls.
///
/// A field-less unit type, so every instance is interchangeable. Its
/// `ChannelCapability` implementation creates, drains, and terminates
/// channels, and downgrades strong reader/writer handles to weak ones.
#[derive(Clone)]
pub struct ChannelDriver;
15
/// Runtime driver for strong read/write hostcalls.
///
/// A field-less unit type whose `IoCapability` implementation hands out
/// `StrongReader` / `StrongWriter` handles for an `Arc<Channel>`.
// NOTE(review): unlike `ChannelDriver`, this type does not derive `Clone` —
// confirm that is intentional.
pub struct ChannelStrongIoDriver;
18
/// Runtime driver for weak read/write hostcalls.
///
/// A field-less unit type whose `IoCapability` implementation hands out
/// `WeakReader` / `WeakWriter` handles for an `Arc<Channel>`.
// NOTE(review): unlike `ChannelDriver`, this type does not derive `Clone` —
// confirm that is intentional.
pub struct ChannelWeakIoDriver;
21
22impl ChannelDriver {
23    /// Create a new channel driver instance
24    pub fn new() -> Arc<Self> {
25        Arc::new(Self)
26    }
27}
28
29impl ChannelCapability for ChannelDriver {
30    type Channel = Arc<Channel>;
31    type StrongReader = StrongReader;
32    type WeakReader = WeakReader;
33    type StrongWriter = StrongWriter;
34    type WeakWriter = WeakWriter;
35    type Error = ChannelError;
36
37    fn create(
38        &self,
39        size: GuestUint,
40        backpressure: ChannelBackpressure,
41    ) -> Result<Self::Channel, Self::Error> {
42        let backpressure = match backpressure {
43            ChannelBackpressure::Park => crate::Backpressure::Park,
44            ChannelBackpressure::Drop => crate::Backpressure::Drop,
45        };
46        Ok(Channel::with_parameters(size as usize, backpressure))
47    }
48
49    fn delete(&self, channel: Self::Channel) -> Result<(), Self::Error> {
50        channel.terminate()
51    }
52
53    fn drain(&self, channel: &Self::Channel) -> Result<(), Self::Error> {
54        channel.drain()
55    }
56
57    fn downgrade_writer(
58        &self,
59        writer: Self::StrongWriter,
60    ) -> Result<Self::WeakWriter, Self::Error> {
61        Ok(writer.downgrade())
62    }
63
64    fn downgrade_reader(
65        &self,
66        reader: Self::StrongReader,
67    ) -> Result<Self::WeakReader, Self::Error> {
68        Ok(reader.downgrade())
69    }
70
71    fn ptr(&self, channel: &Self::Channel) -> String {
72        format!("{:p}", Arc::as_ptr(channel))
73    }
74}
75
76impl ChannelStrongIoDriver {
77    /// Create a new channel strong I/O driver instance
78    pub fn new() -> Arc<Self> {
79        Arc::new(Self)
80    }
81}
82
83impl IoCapability for ChannelStrongIoDriver {
84    type Handle = Arc<Channel>;
85    type Reader = StrongReader;
86    type Writer = StrongWriter;
87    type Error = ChannelError;
88
89    fn new_reader(&self, handle: &Self::Handle) -> Result<Self::Reader, Self::Error> {
90        Ok(handle.new_strong_reader())
91    }
92
93    fn new_writer(&self, handle: &Self::Handle) -> Result<Self::Writer, Self::Error> {
94        Ok(handle.new_strong_writer())
95    }
96
97    async fn read(&self, reader: &mut Self::Reader, len: usize) -> Result<IoFrame, Self::Error> {
98        let (id, buf) = reader.read_frame(len).await?;
99        Ok(IoFrame {
100            writer_id: id,
101            payload: buf,
102        })
103    }
104
105    async fn write(&self, writer: &mut Self::Writer, bytes: &[u8]) -> Result<(), Self::Error> {
106        let mut offset = 0;
107        while offset < bytes.len() {
108            let written = writer.write(&bytes[offset..]).await?;
109            if written == 0 {
110                if offset == 0 {
111                    return Ok(());
112                }
113                return Err(ChannelError::Io("write stalled mid-frame".to_string()));
114            }
115            offset += written;
116        }
117        Ok(())
118    }
119}
120
121impl ChannelWeakIoDriver {
122    /// Create a new channel weak I/O driver instance
123    pub fn new() -> Arc<Self> {
124        Arc::new(Self)
125    }
126}
127
128impl IoCapability for ChannelWeakIoDriver {
129    type Handle = Arc<Channel>;
130    type Reader = WeakReader;
131    type Writer = WeakWriter;
132    type Error = ChannelError;
133
134    fn new_reader(&self, handle: &Self::Handle) -> Result<Self::Reader, Self::Error> {
135        Ok(handle.new_weak_reader())
136    }
137
138    fn new_writer(&self, handle: &Self::Handle) -> Result<Self::Writer, Self::Error> {
139        Ok(handle.new_weak_writer())
140    }
141
142    async fn read(&self, reader: &mut Self::Reader, len: usize) -> Result<IoFrame, Self::Error> {
143        let (id, buf) = reader.read_frame(len).await?;
144        Ok(IoFrame {
145            writer_id: id,
146            payload: buf,
147        })
148    }
149
150    async fn write(&self, writer: &mut Self::Writer, bytes: &[u8]) -> Result<(), Self::Error> {
151        let mut offset = 0;
152        while offset < bytes.len() {
153            let written = writer.write(&bytes[offset..]).await?;
154            if written == 0 {
155                if offset == 0 {
156                    return Ok(());
157                }
158                return Err(ChannelError::Io("write stalled mid-frame".to_string()));
159            }
160            offset += written;
161        }
162        Ok(())
163    }
164}
165
166impl From<ChannelError> for GuestError {
167    fn from(value: ChannelError) -> Self {
168        GuestError::Subsystem(value.to_string())
169    }
170}