pipa/runtime/
io_reactor.rs1use std::collections::HashMap;
2use std::os::unix::io::RawFd;
3
4#[cfg(feature = "fetch")]
5use crate::http::request::{HttpRequest, RequestEvent};
6#[cfg(feature = "fetch")]
7use crate::http::ws::conn::{WsConnection, WsEvent};
8use crate::util::iomux::Poller;
9
10use super::context::JSContext;
11use super::extension::MacroTaskExtension;
12
/// A unit of I/O work owned by the reactor.
///
/// Every variant is gated behind the `fetch` feature, so the enum has no
/// variants when that feature is disabled.
pub enum ReactorTask {
    /// An HTTP request tracked by the reactor.
    #[cfg(feature = "fetch")]
    Http(HttpRequest),

    /// A WebSocket connection tracked by the reactor.
    #[cfg(feature = "fetch")]
    Ws(WsConnection),
}
19
/// What advancing a single reactor task produced.
pub enum AdvanceResult {
    /// Nothing to hand back to the caller.
    None,

    /// An HTTP request that is being handed back together with the event
    /// that ended it.
    #[cfg(feature = "fetch")]
    HttpDone(HttpRequest, RequestEvent),

    /// An event produced by a WebSocket connection.
    #[cfg(feature = "fetch")]
    WsEvent(WsEvent),
}
27
28impl ReactorTask {
29 pub fn fd(&self) -> Option<RawFd> {
30 match self {
31 #[cfg(feature = "fetch")]
32 ReactorTask::Http(r) => r.fd(),
33 #[cfg(feature = "fetch")]
34 ReactorTask::Ws(w) => w.fd(),
35 }
36 }
37
38 pub fn wants_read(&self) -> bool {
39 match self {
40 #[cfg(feature = "fetch")]
41 ReactorTask::Http(r) => r.wants_read(),
42 _ => true,
43 }
44 }
45
46 pub fn wants_write(&self) -> bool {
47 match self {
48 #[cfg(feature = "fetch")]
49 ReactorTask::Http(r) => r.wants_write(),
50 _ => false,
51 }
52 }
53
54 pub fn task_type(&self) -> &str {
55 match self {
56 #[cfg(feature = "fetch")]
57 ReactorTask::Http(_) => "http",
58 #[cfg(feature = "fetch")]
59 ReactorTask::Ws(_) => "ws",
60 }
61 }
62}
63
64pub struct IoReactor {
65 poller: Poller,
66 tasks: HashMap<u64, ReactorTask>,
67 next_id: u64,
68}
69
70impl IoReactor {
71 pub fn new() -> Result<Self, String> {
72 let poller = Poller::new()?;
73 Ok(IoReactor {
74 poller,
75 tasks: HashMap::new(),
76 next_id: 1,
77 })
78 }
79
80 pub fn register(&mut self, fd: RawFd, task: ReactorTask) -> Result<u64, String> {
81 let id = self.next_id;
82 self.next_id += 1;
83
84 self.poller
85 .register(fd, task.wants_read(), task.wants_write())?;
86
87 self.tasks.insert(id, task);
88 Ok(id)
89 }
90
91 fn update_registration(&mut self, id: u64) -> Result<(), String> {
92 let task = self
93 .tasks
94 .get(&id)
95 .ok_or_else(|| format!("task {id} not found"))?;
96
97 if let Some(fd) = task.fd() {
98 self.poller
99 .modify(fd, task.wants_read(), task.wants_write())?;
100 }
101 Ok(())
102 }
103
104 fn unregister(&mut self, id: u64) -> Option<ReactorTask> {
105 let task = self.tasks.remove(&id);
106 if let Some(ref t) = task {
107 if let Some(fd) = t.fd() {
108 let _ = self.poller.unregister(fd);
109 }
110 }
111 task
112 }
113
114 #[cfg(feature = "fetch")]
115 pub fn get_http_mut(&mut self, id: u64) -> Option<&mut HttpRequest> {
116 match self.tasks.get_mut(&id) {
117 Some(ReactorTask::Http(r)) => Some(r),
118 _ => None,
119 }
120 }
121
122 #[cfg(feature = "fetch")]
123 pub fn get_ws_mut(&mut self, id: u64) -> Option<&mut WsConnection> {
124 match self.tasks.get_mut(&id) {
125 Some(ReactorTask::Ws(w)) => Some(w),
126 _ => None,
127 }
128 }
129
130 fn poll(&mut self, timeout_ms: i32) -> Result<Vec<(u64, String, PollEvent)>, String> {
131 if self.tasks.is_empty() {
132 return Ok(Vec::new());
133 }
134
135 let events = self.poller.wait(timeout_ms)?;
136 let mut results = Vec::new();
137
138 for event in &events {
139 let ids_to_check: Vec<u64> = self
140 .tasks
141 .iter()
142 .filter(|(_, t)| t.fd() == Some(event.fd))
143 .map(|(id, _)| *id)
144 .collect();
145
146 for id in ids_to_check {
147 let task_type = self.tasks.get(&id).map(|t| t.task_type().to_string());
148 if event.error {
149 results.push((id, task_type.unwrap_or_default(), PollEvent::Error));
150 } else if event.readable {
151 results.push((id, task_type.unwrap_or_default(), PollEvent::Readable));
152 } else if event.writable {
153 results.push((id, task_type.unwrap_or_default(), PollEvent::Writable));
154 }
155 }
156 }
157
158 Ok(results)
159 }
160
161 #[cfg(feature = "fetch")]
162 fn advance_http(&mut self, id: u64) -> Result<AdvanceResult, String> {
163 let task = self
164 .tasks
165 .get_mut(&id)
166 .ok_or_else(|| format!("http task {id} not found"))?;
167 match task {
168 ReactorTask::Http(req) => {
169 let event = req.try_advance()?;
170 match &event {
171 RequestEvent::NeedRead | RequestEvent::NeedWrite => {
172 self.update_registration(id)?;
173 }
174 RequestEvent::Complete(_) | RequestEvent::Error(_) => {
175 let mut taken = HttpRequest::dummy();
176 std::mem::swap(req, &mut taken);
177 self.unregister(id);
178 return Ok(AdvanceResult::HttpDone(taken, event));
179 }
180 }
181 Ok(AdvanceResult::None)
182 }
183 ReactorTask::Ws(_) => Err(format!("task {id} is not http")),
184 }
185 }
186
187 #[cfg(feature = "fetch")]
188 fn advance_ws(&mut self, id: u64) -> Result<AdvanceResult, String> {
189 let task = self
190 .tasks
191 .get_mut(&id)
192 .ok_or_else(|| format!("ws task {id} not found"))?;
193 match task {
194 ReactorTask::Ws(ws) => {
195 let event = ws.try_advance()?;
196 match &event {
197 Some(ev) => {
198 if matches!(ev, WsEvent::Close(_, _) | WsEvent::Error(_)) {
199 self.unregister(id);
200 }
201 Ok(AdvanceResult::WsEvent(ev.clone()))
202 }
203 None => {
204 self.update_registration(id)?;
205 Ok(AdvanceResult::None)
206 }
207 }
208 }
209 ReactorTask::Http(_) => Err(format!("task {id} is not ws")),
210 }
211 }
212
213 pub fn is_empty(&self) -> bool {
214 self.tasks.is_empty()
215 }
216}
217
218impl MacroTaskExtension for IoReactor {
219 fn tick(&mut self, _ctx: &mut JSContext) -> Result<bool, String> {
220 if self.is_empty() {
221 return Ok(false);
222 }
223
224 let events = self.poll(0)?;
225 for (id, task_type, _) in &events {
226 match task_type.as_str() {
227 #[cfg(feature = "fetch")]
228 "http" => {
229 let _ = self.advance_http(*id);
230 let _ = self.update_registration(*id);
231 }
232 #[cfg(feature = "fetch")]
233 "ws" => {
234 let _ = self.advance_ws(*id);
235 let _ = self.update_registration(*id);
236 }
237 _ => {}
238 }
239 }
240
241 Ok(!events.is_empty())
242 }
243
244 fn has_pending(&self) -> bool {
245 !self.is_empty()
246 }
247}
248
/// Kind of readiness reported for a task by a poll pass.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PollEvent {
    /// The descriptor can be read.
    Readable,
    /// The descriptor can be written.
    Writable,
    /// The poller flagged an error on the descriptor.
    Error,
}