1pub mod buffer;
7pub mod ipc_trait;
8pub mod protocol;
9pub mod pty_manager;
10pub mod state;
11
12#[cfg(test)]
13mod tests;
14
15pub use buffer::{ScrollbackBuffer, ScrollbackLine};
17pub use ipc_trait::{DirectChannel, IpcChannel};
18pub use protocol::{ControlMessage, TerminalInput, TerminalMetadata, TerminalOutput};
19pub use pty_manager::{PtyHandle, PtyManager};
20pub use state::TerminalState;
21
22use std::collections::HashMap;
23use std::sync::Arc;
24use tokio::sync::RwLock;
25
/// Byte-slice convenience operations layered on top of a PTY handle.
///
/// Both methods report failures as a boxed [`std::error::Error`] so callers
/// can propagate them with `?` regardless of the underlying error type.
// `async_fn_in_trait` flags `async fn` in public traits because callers cannot
// name auto-trait bounds (such as `Send`) on the returned futures; the lint is
// silenced for this trait.
#[allow(async_fn_in_trait)]
pub trait PtyHandleExt {
    /// Writes `data` to the PTY.
    ///
    /// # Errors
    ///
    /// Returns the underlying failure, boxed, if the bytes could not be sent.
    async fn write(&self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>>;

    /// Terminates the PTY.
    ///
    /// # Errors
    ///
    /// Returns the underlying failure, boxed, if termination did not succeed.
    async fn kill(&self) -> Result<(), Box<dyn std::error::Error>>;
}
32
33impl PtyHandleExt for PtyHandle {
34 async fn write(&self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
35 self.send_input(bytes::Bytes::from(data.to_vec()))
36 .await
37 .map_err(|e| e.into())
38 }
39
40 async fn kill(&self) -> Result<(), Box<dyn std::error::Error>> {
41 self.shutdown().await.map_err(|e| e.into())
42 }
43}
44
45pub struct TerminalStreamManager {
47 pty_manager: Arc<PtyManager>,
48 ipc_channel: Arc<dyn IpcChannel>,
49 terminals: Arc<RwLock<HashMap<String, TerminalState>>>,
50 pty_handles: Arc<RwLock<HashMap<String, PtyHandle>>>,
51 scrollback_buffers: Arc<RwLock<HashMap<String, ScrollbackBuffer>>>,
52}
53
54impl TerminalStreamManager {
55 pub fn with_ipc_channel(ipc_channel: Arc<dyn IpcChannel>) -> Self {
57 let pty_manager = Arc::new(PtyManager::new());
58
59 Self {
60 pty_manager,
61 ipc_channel,
62 terminals: Arc::new(RwLock::new(HashMap::new())),
63 pty_handles: Arc::new(RwLock::new(HashMap::new())),
64 scrollback_buffers: Arc::new(RwLock::new(HashMap::new())),
65 }
66 }
67
68 pub async fn create_terminal(
70 &self,
71 terminal_id: String,
72 shell: Option<String>,
73 rows: u16,
74 cols: u16,
75 ) -> Result<PtyHandle, Box<dyn std::error::Error>> {
76 let pty_handle = self
78 .pty_manager
79 .create_pty(terminal_id.clone(), shell, rows, cols)
80 .await?;
81
82 let terminal_state = TerminalState::new(terminal_id.clone(), rows, cols);
84 self.terminals
85 .write()
86 .await
87 .insert(terminal_id.clone(), terminal_state);
88
89 self.pty_handles
91 .write()
92 .await
93 .insert(terminal_id.clone(), pty_handle.clone());
94
95 let scrollback_buffer = ScrollbackBuffer::with_default_size();
97 self.scrollback_buffers
98 .write()
99 .await
100 .insert(terminal_id.clone(), scrollback_buffer);
101
102 self.ipc_channel
104 .start_streaming(terminal_id, pty_handle.clone())
105 .await?;
106
107 Ok(pty_handle)
108 }
109
110 pub async fn send_input(
112 &self,
113 terminal_id: &str,
114 input: TerminalInput,
115 ) -> Result<(), Box<dyn std::error::Error>> {
116 let handles = self.pty_handles.read().await;
117 let pty_handle = handles
118 .get(terminal_id)
119 .ok_or_else(|| format!("Terminal {terminal_id} not found"))?;
120
121 match input {
122 TerminalInput::Text(text) => {
123 pty_handle.write(text.as_bytes()).await?;
124 }
125 TerminalInput::Binary(data) => {
126 pty_handle.write(&data).await?;
127 }
128 TerminalInput::SpecialKey(key) => {
129 pty_handle.write(key.as_bytes()).await?;
132 }
133 }
134
135 Ok(())
136 }
137
138 pub async fn send_control(
140 &self,
141 terminal_id: &str,
142 control: ControlMessage,
143 ) -> Result<(), Box<dyn std::error::Error>> {
144 match control {
145 ControlMessage::Resize { rows, cols } => {
146 let handles = self.pty_handles.read().await;
147 if let Some(pty_handle) = handles.get(terminal_id) {
148 pty_handle.resize(rows, cols).await?;
149
150 let mut terminals = self.terminals.write().await;
152 if let Some(state) = terminals.get_mut(terminal_id) {
153 state.resize(rows, cols);
154 }
155 }
156 }
157 _ => {
158 }
160 }
161
162 Ok(())
163 }
164
165 pub async fn kill_terminal(&self, terminal_id: &str) -> Result<(), Box<dyn std::error::Error>> {
167 self.ipc_channel.stop_streaming(terminal_id).await?;
169
170 if let Some(handle) = self.pty_handles.write().await.remove(terminal_id) {
172 handle.kill().await?;
173 }
174
175 self.terminals.write().await.remove(terminal_id);
177 self.scrollback_buffers.write().await.remove(terminal_id);
178
179 Ok(())
180 }
181
182 pub async fn list_terminals(&self) -> Vec<String> {
184 self.terminals.read().await.keys().cloned().collect()
185 }
186}