1use super::types::{IOInput, IOOutput};
59use orcs_types::ChannelId;
60use tokio::sync::mpsc;
61
/// Channel capacity, in messages, used for both the input and the output
/// channel when a port is created through [`IOPort::with_defaults`].
pub const DEFAULT_BUFFER_SIZE: usize = 64;
64
/// One end of a bidirectional, bounded message link.
///
/// The port receives [`IOInput`] values sent through an [`IOInputHandle`] and
/// sends [`IOOutput`] values that are read through an [`IOOutputHandle`].
/// All three objects are created together by [`IOPort::new`].
pub struct IOPort {
    /// Identifier shared by the port and both of its handles.
    channel_id: ChannelId,
    /// Receiving half of the input channel (fed by `IOInputHandle`).
    input_rx: mpsc::Receiver<IOInput>,
    /// Sending half of the output channel (read by `IOOutputHandle`).
    output_tx: mpsc::Sender<IOOutput>,
}
77
78impl IOPort {
79 #[must_use]
91 pub fn new(channel_id: ChannelId, buffer_size: usize) -> (Self, IOInputHandle, IOOutputHandle) {
92 let (input_tx, input_rx) = mpsc::channel(buffer_size);
93 let (output_tx, output_rx) = mpsc::channel(buffer_size);
94
95 let port = Self {
96 channel_id,
97 input_rx,
98 output_tx,
99 };
100
101 let input_handle = IOInputHandle {
102 tx: input_tx,
103 channel_id,
104 };
105
106 let output_handle = IOOutputHandle {
107 rx: output_rx,
108 channel_id,
109 };
110
111 (port, input_handle, output_handle)
112 }
113
114 #[must_use]
116 pub fn with_defaults(channel_id: ChannelId) -> (Self, IOInputHandle, IOOutputHandle) {
117 Self::new(channel_id, DEFAULT_BUFFER_SIZE)
118 }
119
120 #[must_use]
122 pub fn channel_id(&self) -> ChannelId {
123 self.channel_id
124 }
125
126 pub async fn recv(&mut self) -> Option<IOInput> {
130 self.input_rx.recv().await
131 }
132
133 #[must_use]
137 pub fn try_recv(&mut self) -> Option<IOInput> {
138 self.input_rx.try_recv().ok()
139 }
140
141 pub async fn send(&self, output: IOOutput) -> Result<(), mpsc::error::SendError<IOOutput>> {
147 self.output_tx.send(output).await
148 }
149
150 pub fn try_send(&self, output: IOOutput) -> Result<(), mpsc::error::TrySendError<IOOutput>> {
156 self.output_tx.try_send(output)
157 }
158
159 pub fn drain_input(&mut self) -> Vec<IOInput> {
163 let mut inputs = Vec::new();
164 while let Some(input) = self.try_recv() {
165 inputs.push(input);
166 }
167 inputs
168 }
169
170 #[must_use]
172 pub fn is_output_closed(&self) -> bool {
173 self.output_tx.is_closed()
174 }
175}
176
177impl std::fmt::Debug for IOPort {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("IOPort")
180 .field("channel_id", &self.channel_id)
181 .finish_non_exhaustive()
182 }
183}
184
/// Sending side of a port's input channel.
///
/// Values sent through this handle are received by the [`IOPort`] created
/// alongside it. The handle is `Clone`; every clone feeds the same port.
#[derive(Clone)]
pub struct IOInputHandle {
    /// Sending half of the input channel.
    tx: mpsc::Sender<IOInput>,
    /// Identifier shared with the owning port.
    channel_id: ChannelId,
}
193
194impl IOInputHandle {
195 #[must_use]
197 pub fn channel_id(&self) -> ChannelId {
198 self.channel_id
199 }
200
201 pub async fn send(&self, input: IOInput) -> Result<(), mpsc::error::SendError<IOInput>> {
207 self.tx.send(input).await
208 }
209
210 pub fn try_send(&self, input: IOInput) -> Result<(), mpsc::error::TrySendError<IOInput>> {
216 self.tx.try_send(input)
217 }
218
219 pub async fn send_line(
227 &self,
228 text: impl Into<String>,
229 ) -> Result<(), mpsc::error::SendError<IOInput>> {
230 self.send(IOInput::line(text)).await
231 }
232
233 #[must_use]
235 pub fn is_closed(&self) -> bool {
236 self.tx.is_closed()
237 }
238}
239
240impl std::fmt::Debug for IOInputHandle {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 f.debug_struct("IOInputHandle")
243 .field("channel_id", &self.channel_id)
244 .finish_non_exhaustive()
245 }
246}
247
/// Receiving side of a port's output channel.
///
/// Values sent by the [`IOPort`] created alongside this handle are read here.
pub struct IOOutputHandle {
    /// Receiving half of the output channel.
    rx: mpsc::Receiver<IOOutput>,
    /// Identifier shared with the owning port.
    channel_id: ChannelId,
}
253
254impl IOOutputHandle {
255 #[must_use]
257 pub fn channel_id(&self) -> ChannelId {
258 self.channel_id
259 }
260
261 pub async fn recv(&mut self) -> Option<IOOutput> {
265 self.rx.recv().await
266 }
267
268 #[must_use]
272 pub fn try_recv(&mut self) -> Option<IOOutput> {
273 self.rx.try_recv().ok()
274 }
275
276 pub fn drain(&mut self) -> Vec<IOOutput> {
278 let mut outputs = Vec::new();
279 while let Some(output) = self.try_recv() {
280 outputs.push(output);
281 }
282 outputs
283 }
284}
285
286impl std::fmt::Debug for IOOutputHandle {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 f.debug_struct("IOOutputHandle")
289 .field("channel_id", &self.channel_id)
290 .finish_non_exhaustive()
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Input sent through the handle reaches the port, and output sent by
    /// the port reaches the output handle.
    #[tokio::test]
    async fn port_basic_io() {
        let id = ChannelId::new();
        let (mut port, input_handle, mut output_handle) = IOPort::new(id, 8);

        // Input direction: handle -> port.
        let sent_input = input_handle.send(IOInput::line("hello")).await;
        sent_input.expect("send input should succeed");

        let received = port.recv().await;
        let input = received.expect("should receive input");
        assert_eq!(input.as_line(), Some("hello"));

        // Output direction: port -> handle.
        let sent_output = port.send(IOOutput::info("received")).await;
        sent_output.expect("send output should succeed");

        let delivered = output_handle.recv().await;
        let output = delivered.expect("should receive output");
        assert!(matches!(output, IOOutput::Print { .. }));
    }

    /// A freshly created port has no pending input.
    #[tokio::test]
    async fn port_try_recv_empty() {
        let id = ChannelId::new();
        let (mut port, _input_handle, _output_handle) = IOPort::new(id, 8);

        let pending = port.try_recv();
        assert!(pending.is_none());
    }

    /// `drain_input` returns all queued inputs in send order and leaves the
    /// queue empty.
    #[tokio::test]
    async fn port_drain_input() {
        let id = ChannelId::new();
        let (mut port, input_handle, _output_handle) = IOPort::new(id, 8);

        let sends = [
            ("one", "send first input should succeed"),
            ("two", "send second input should succeed"),
            ("three", "send third input should succeed"),
        ];
        for (text, failure) in sends {
            input_handle.send(IOInput::line(text)).await.expect(failure);
        }

        let inputs = port.drain_input();
        assert_eq!(inputs.len(), 3);
        for (input, expected) in inputs.iter().zip(["one", "two", "three"]) {
            assert_eq!(input.as_line(), Some(expected));
        }

        // Nothing may be left behind after a drain.
        assert!(port.try_recv().is_none());
    }

    /// `send_line` delivers the text as a line input.
    #[tokio::test]
    async fn input_handle_send_line() {
        let id = ChannelId::new();
        let (mut port, input_handle, _output_handle) = IOPort::new(id, 8);

        let sent = input_handle.send_line("test").await;
        sent.expect("send_line should succeed");

        let received = port.recv().await;
        let input = received.expect("should receive input line");
        assert_eq!(input.as_line(), Some("test"));
    }

    /// `drain` on the output handle returns everything the port has sent.
    #[tokio::test]
    async fn output_handle_drain() {
        let id = ChannelId::new();
        let (port, _input_handle, mut output_handle) = IOPort::new(id, 8);

        let first = port.send(IOOutput::info("one")).await;
        first.expect("send info output should succeed");
        let second = port.send(IOOutput::warn("two")).await;
        second.expect("send warn output should succeed");

        let outputs = output_handle.drain();
        assert_eq!(outputs.len(), 2);
    }

    /// The port and both handles report the identifier passed to `new`.
    #[tokio::test]
    async fn port_channel_id() {
        let id = ChannelId::new();
        let (port, input_handle, output_handle) = IOPort::new(id, 8);

        assert_eq!(port.channel_id(), id);
        assert_eq!(input_handle.channel_id(), id);
        assert_eq!(output_handle.channel_id(), id);
    }

    /// Dropping the port is observed by the input handle as a closed channel.
    #[tokio::test]
    async fn port_closed_detection() {
        let id = ChannelId::new();
        let (port, input_handle, _output_handle) = IOPort::new(id, 8);

        // Everything is alive, so neither direction is closed yet.
        assert!(!input_handle.is_closed());
        assert!(!port.is_output_closed());

        // Dropping the port drops the input receiver it owns.
        drop(port);

        assert!(input_handle.is_closed());
    }

    /// Clones of the input handle feed the same port.
    #[tokio::test]
    async fn input_handle_clone() {
        let id = ChannelId::new();
        let (mut port, input_handle, _output_handle) = IOPort::new(id, 8);

        let second_handle = input_handle.clone();

        let first_send = input_handle.send_line("from handle 1").await;
        first_send.expect("send from handle 1 should succeed");
        let second_send = second_handle.send_line("from handle 2").await;
        second_send.expect("send from handle 2 should succeed");

        let inputs = port.drain_input();
        assert_eq!(inputs.len(), 2);
    }

    /// The `Debug` implementations can be formatted without panicking.
    #[test]
    fn port_debug() {
        let id = ChannelId::new();
        let (port, input_handle, output_handle) = IOPort::new(id, 8);

        let _ = format!("{:?}", &port);
        let _ = format!("{:?}", &input_handle);
        let _ = format!("{:?}", &output_handle);
    }

    /// `with_defaults` keeps the supplied channel identifier.
    #[tokio::test]
    async fn port_with_defaults() {
        let id = ChannelId::new();
        let (port, _input_handle, _output_handle) = IOPort::with_defaults(id);

        assert_eq!(port.channel_id(), id);
    }
}