Skip to main content

pipa/runtime/
io_reactor.rs

1use std::collections::HashMap;
2use std::os::unix::io::RawFd;
3
4#[cfg(feature = "fetch")]
5use crate::http::request::{HttpRequest, RequestEvent};
6#[cfg(feature = "fetch")]
7use crate::http::ws::conn::{WsConnection, WsEvent};
8use crate::util::iomux::Poller;
9
10use super::context::JSContext;
11use super::extension::MacroTaskExtension;
12
/// A unit of network I/O owned and driven by the `IoReactor`.
///
/// Both variants exist only when the `fetch` feature is enabled; without it the enum has no
/// variants at all.
pub enum ReactorTask {
    /// An HTTP request being driven by the reactor.
    #[cfg(feature = "fetch")]
    Http(HttpRequest),
    /// A WebSocket connection being driven by the reactor.
    #[cfg(feature = "fetch")]
    Ws(WsConnection),
}
19
/// Outcome of driving one reactor task forward by a single step.
pub enum AdvanceResult {
    /// Nothing to report yet; the task is still pending.
    None,
    /// An HTTP request reached a final event; ownership of the request is handed back
    /// together with that event.
    #[cfg(feature = "fetch")]
    HttpDone(HttpRequest, RequestEvent),
    /// A WebSocket connection produced an event.
    #[cfg(feature = "fetch")]
    WsEvent(WsEvent),
}
27
28impl ReactorTask {
29    pub fn fd(&self) -> Option<RawFd> {
30        match self {
31            #[cfg(feature = "fetch")]
32            ReactorTask::Http(r) => r.fd(),
33            #[cfg(feature = "fetch")]
34            ReactorTask::Ws(w) => w.fd(),
35        }
36    }
37
38    pub fn wants_read(&self) -> bool {
39        match self {
40            #[cfg(feature = "fetch")]
41            ReactorTask::Http(r) => r.wants_read(),
42            _ => true,
43        }
44    }
45
46    pub fn wants_write(&self) -> bool {
47        match self {
48            #[cfg(feature = "fetch")]
49            ReactorTask::Http(r) => r.wants_write(),
50            _ => false,
51        }
52    }
53
54    pub fn task_type(&self) -> &str {
55        match self {
56            #[cfg(feature = "fetch")]
57            ReactorTask::Http(_) => "http",
58            #[cfg(feature = "fetch")]
59            ReactorTask::Ws(_) => "ws",
60        }
61    }
62}
63
64pub struct IoReactor {
65    poller: Poller,
66    tasks: HashMap<u64, ReactorTask>,
67    next_id: u64,
68}
69
70impl IoReactor {
71    pub fn new() -> Result<Self, String> {
72        let poller = Poller::new()?;
73        Ok(IoReactor {
74            poller,
75            tasks: HashMap::new(),
76            next_id: 1,
77        })
78    }
79
80    pub fn register(&mut self, fd: RawFd, task: ReactorTask) -> Result<u64, String> {
81        let id = self.next_id;
82        self.next_id += 1;
83
84        self.poller
85            .register(fd, task.wants_read(), task.wants_write())?;
86
87        self.tasks.insert(id, task);
88        Ok(id)
89    }
90
91    fn update_registration(&mut self, id: u64) -> Result<(), String> {
92        let task = self
93            .tasks
94            .get(&id)
95            .ok_or_else(|| format!("task {id} not found"))?;
96
97        if let Some(fd) = task.fd() {
98            self.poller
99                .modify(fd, task.wants_read(), task.wants_write())?;
100        }
101        Ok(())
102    }
103
104    fn unregister(&mut self, id: u64) -> Option<ReactorTask> {
105        let task = self.tasks.remove(&id);
106        if let Some(ref t) = task {
107            if let Some(fd) = t.fd() {
108                let _ = self.poller.unregister(fd);
109            }
110        }
111        task
112    }
113
114    #[cfg(feature = "fetch")]
115    pub fn get_http_mut(&mut self, id: u64) -> Option<&mut HttpRequest> {
116        match self.tasks.get_mut(&id) {
117            Some(ReactorTask::Http(r)) => Some(r),
118            _ => None,
119        }
120    }
121
122    #[cfg(feature = "fetch")]
123    pub fn get_ws_mut(&mut self, id: u64) -> Option<&mut WsConnection> {
124        match self.tasks.get_mut(&id) {
125            Some(ReactorTask::Ws(w)) => Some(w),
126            _ => None,
127        }
128    }
129
130    fn poll(&mut self, timeout_ms: i32) -> Result<Vec<(u64, String, PollEvent)>, String> {
131        if self.tasks.is_empty() {
132            return Ok(Vec::new());
133        }
134
135        let events = self.poller.wait(timeout_ms)?;
136        let mut results = Vec::new();
137
138        for event in &events {
139            let ids_to_check: Vec<u64> = self
140                .tasks
141                .iter()
142                .filter(|(_, t)| t.fd() == Some(event.fd))
143                .map(|(id, _)| *id)
144                .collect();
145
146            for id in ids_to_check {
147                let task_type = self.tasks.get(&id).map(|t| t.task_type().to_string());
148                if event.error {
149                    results.push((id, task_type.unwrap_or_default(), PollEvent::Error));
150                } else if event.readable {
151                    results.push((id, task_type.unwrap_or_default(), PollEvent::Readable));
152                } else if event.writable {
153                    results.push((id, task_type.unwrap_or_default(), PollEvent::Writable));
154                }
155            }
156        }
157
158        Ok(results)
159    }
160
161    #[cfg(feature = "fetch")]
162    fn advance_http(&mut self, id: u64) -> Result<AdvanceResult, String> {
163        let task = self
164            .tasks
165            .get_mut(&id)
166            .ok_or_else(|| format!("http task {id} not found"))?;
167        match task {
168            ReactorTask::Http(req) => {
169                let event = req.try_advance()?;
170                match &event {
171                    RequestEvent::NeedRead | RequestEvent::NeedWrite => {
172                        self.update_registration(id)?;
173                    }
174                    RequestEvent::Complete(_) | RequestEvent::Error(_) => {
175                        let mut taken = HttpRequest::dummy();
176                        std::mem::swap(req, &mut taken);
177                        self.unregister(id);
178                        return Ok(AdvanceResult::HttpDone(taken, event));
179                    }
180                }
181                Ok(AdvanceResult::None)
182            }
183            ReactorTask::Ws(_) => Err(format!("task {id} is not http")),
184        }
185    }
186
187    #[cfg(feature = "fetch")]
188    fn advance_ws(&mut self, id: u64) -> Result<AdvanceResult, String> {
189        let task = self
190            .tasks
191            .get_mut(&id)
192            .ok_or_else(|| format!("ws task {id} not found"))?;
193        match task {
194            ReactorTask::Ws(ws) => {
195                let event = ws.try_advance()?;
196                match &event {
197                    Some(ev) => {
198                        if matches!(ev, WsEvent::Close(_, _) | WsEvent::Error(_)) {
199                            self.unregister(id);
200                        }
201                        Ok(AdvanceResult::WsEvent(ev.clone()))
202                    }
203                    None => {
204                        self.update_registration(id)?;
205                        Ok(AdvanceResult::None)
206                    }
207                }
208            }
209            ReactorTask::Http(_) => Err(format!("task {id} is not ws")),
210        }
211    }
212
213    pub fn is_empty(&self) -> bool {
214        self.tasks.is_empty()
215    }
216}
217
218impl MacroTaskExtension for IoReactor {
219    fn tick(&mut self, _ctx: &mut JSContext) -> Result<bool, String> {
220        if self.is_empty() {
221            return Ok(false);
222        }
223
224        let events = self.poll(0)?;
225        for (id, task_type, _) in &events {
226            match task_type.as_str() {
227                #[cfg(feature = "fetch")]
228                "http" => {
229                    let _ = self.advance_http(*id);
230                    let _ = self.update_registration(*id);
231                }
232                #[cfg(feature = "fetch")]
233                "ws" => {
234                    let _ = self.advance_ws(*id);
235                    let _ = self.update_registration(*id);
236                }
237                _ => {}
238            }
239        }
240
241        Ok(!events.is_empty())
242    }
243
244    fn has_pending(&self) -> bool {
245        !self.is_empty()
246    }
247}
248
/// Kind of readiness reported for a task's file descriptor.
///
/// When the poller flags several conditions at once, a single kind is reported with
/// priority `Error` > `Readable` > `Writable`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollEvent {
    /// The descriptor has data to read.
    Readable,
    /// The descriptor can accept writes.
    Writable,
    /// The poller flagged an error condition on the descriptor.
    Error,
}