kratart/
channel.rs

use std::{
    collections::{HashMap, HashSet},
    sync::atomic::{fence, Ordering},
    time::Duration,
};

use anyhow::{anyhow, Result};
use log::{debug, error};
use tokio::{
    select,
    sync::mpsc::{channel, Receiver, Sender},
    task::JoinHandle,
    time::sleep,
};
use xenevtchn::EventChannelService;
use xengnt::{sys::GrantRef, GrantTab, MappedMemory};
use xenstore::{XsdClient, XsdInterface};
18
/// Capacity of each per-backend input queue (bytes destined for a single guest console).
const SINGLE_CHANNEL_QUEUE_LEN: usize = 100;
/// Capacity of the service-wide input and output queues shared by every guest.
const GROUPED_CHANNEL_QUEUE_LEN: usize = 1000;
21
/// Shared-memory layout of a console ring page, mirroring Xen's
/// `struct xencons_interface`.
///
/// `input` carries bytes from this backend to the guest (the backend advances
/// `in_prod`), and `output` carries bytes from the guest to the backend (the backend
/// advances `out_cons`). Indices are free-running `u32`s reduced with `SIZE - 1`
/// when indexing, which is why both ring sizes must be powers of two.
///
/// The struct only contains plain integers and byte arrays, so it is automatically
/// `Send`; no `unsafe impl` is required.
#[repr(C)]
struct XenConsoleInterface {
    input: [u8; XenConsoleInterface::INPUT_SIZE],
    output: [u8; XenConsoleInterface::OUTPUT_SIZE],
    in_cons: u32,
    in_prod: u32,
    out_cons: u32,
    out_prod: u32,
}

impl XenConsoleInterface {
    /// Size in bytes of the backend-to-guest ring.
    const INPUT_SIZE: usize = 1024;
    /// Size in bytes of the guest-to-backend ring.
    const OUTPUT_SIZE: usize = 2048;
}

// The ring code masks indices with `SIZE - 1`, and this layout must match the page the
// guest shares, so both invariants are enforced at compile time.
const _: () = {
    assert!(XenConsoleInterface::INPUT_SIZE.is_power_of_two());
    assert!(XenConsoleInterface::OUTPUT_SIZE.is_power_of_two());
    assert!(
        std::mem::size_of::<XenConsoleInterface>()
            == XenConsoleInterface::INPUT_SIZE + XenConsoleInterface::OUTPUT_SIZE + 4 * 4
    );
};
38
39pub struct ChannelService {
40    typ: String,
41    use_reserved_ref: Option<u64>,
42    backends: HashMap<u32, ChannelBackend>,
43    evtchn: EventChannelService,
44    store: XsdClient,
45    gnttab: GrantTab,
46    input_receiver: Receiver<(u32, Vec<u8>)>,
47    pub input_sender: Sender<(u32, Vec<u8>)>,
48    output_sender: Sender<(u32, Option<Vec<u8>>)>,
49}
50
51impl ChannelService {
52    pub async fn new(
53        typ: String,
54        use_reserved_ref: Option<u64>,
55    ) -> Result<(
56        ChannelService,
57        Sender<(u32, Vec<u8>)>,
58        Receiver<(u32, Option<Vec<u8>>)>,
59    )> {
60        let (input_sender, input_receiver) = channel(GROUPED_CHANNEL_QUEUE_LEN);
61        let (output_sender, output_receiver) = channel(GROUPED_CHANNEL_QUEUE_LEN);
62
63        debug!("opening xenevtchn");
64        let evtchn = EventChannelService::open().await?;
65        debug!("opening xenstore");
66        let store = XsdClient::open().await?;
67        debug!("opening xengnt");
68        let gnttab = GrantTab::open()?;
69
70        Ok((
71            ChannelService {
72                typ,
73                use_reserved_ref,
74                backends: HashMap::new(),
75                evtchn,
76                store,
77                gnttab,
78                input_sender: input_sender.clone(),
79                input_receiver,
80                output_sender,
81            },
82            input_sender,
83            output_receiver,
84        ))
85    }
86
87    pub async fn launch(mut self) -> Result<JoinHandle<()>> {
88        Ok(tokio::task::spawn(async move {
89            if let Err(error) = self.process().await {
90                error!("channel processor failed: {}", error);
91            }
92        }))
93    }
94
95    async fn process(&mut self) -> Result<()> {
96        self.scan_all_backends().await?;
97        let mut watch_handle = self
98            .store
99            .create_watch("/local/domain/0/backend/console")
100            .await?;
101        self.store.bind_watch(&watch_handle).await?;
102        loop {
103            select! {
104                x = watch_handle.receiver.recv() => match x {
105                    Some(_) => {
106                        self.scan_all_backends().await?;
107                    }
108
109                    None => {
110                        break;
111                    }
112                },
113
114                x = self.input_receiver.recv() => match x {
115                    Some((domid, data)) => {
116                        if let Some(backend) = self.backends.get_mut(&domid) {
117                            let _ = backend.sender.try_send(data);
118                        }
119                    },
120
121                    None => {
122                        break;
123                    }
124                }
125            }
126        }
127        Ok(())
128    }
129
130    pub async fn send(&mut self, domid: u32, message: Vec<u8>) -> Result<()> {
131        if let Some(backend) = self.backends.get(&domid) {
132            backend.sender.send(message).await?;
133        }
134        Ok(())
135    }
136
137    async fn ensure_backend_exists(&mut self, domid: u32, id: u32, path: String) -> Result<()> {
138        if self.backends.contains_key(&domid) {
139            return Ok(());
140        }
141        let Some(frontend_path) = self.store.read_string(format!("{}/frontend", path)).await?
142        else {
143            return Ok(());
144        };
145        let Some(typ) = self
146            .store
147            .read_string(format!("{}/type", frontend_path))
148            .await?
149        else {
150            return Ok(());
151        };
152
153        if typ != self.typ {
154            return Ok(());
155        }
156
157        let backend = ChannelBackend::new(
158            path.clone(),
159            frontend_path.clone(),
160            domid,
161            id,
162            self.store.clone(),
163            self.evtchn.clone(),
164            self.gnttab.clone(),
165            self.output_sender.clone(),
166            self.use_reserved_ref,
167        )
168        .await?;
169        self.backends.insert(domid, backend);
170        Ok(())
171    }
172
173    async fn scan_all_backends(&mut self) -> Result<()> {
174        let domains = self.store.list("/local/domain/0/backend/console").await?;
175        let mut seen: Vec<u32> = Vec::new();
176        for domid_string in &domains {
177            let domid = domid_string.parse::<u32>()?;
178            let domid_path = format!("/local/domain/0/backend/console/{}", domid);
179            for id_string in self.store.list(&domid_path).await? {
180                let id = id_string.parse::<u32>()?;
181                let console_path = format!(
182                    "/local/domain/0/backend/console/{}/{}",
183                    domid_string, id_string
184                );
185                self.ensure_backend_exists(domid, id, console_path).await?;
186                seen.push(domid);
187            }
188        }
189
190        let mut gone: Vec<u32> = Vec::new();
191        for backend in self.backends.keys() {
192            if !seen.contains(backend) {
193                gone.push(*backend);
194            }
195        }
196
197        for item in gone {
198            if let Some(backend) = self.backends.remove(&item) {
199                drop(backend);
200            }
201        }
202
203        Ok(())
204    }
205}
206
207pub struct ChannelBackend {
208    pub domid: u32,
209    pub id: u32,
210    pub sender: Sender<Vec<u8>>,
211    raw_sender: Sender<(u32, Option<Vec<u8>>)>,
212    task: JoinHandle<()>,
213}
214
215impl Drop for ChannelBackend {
216    fn drop(&mut self) {
217        self.task.abort();
218        let _ = self.raw_sender.try_send((self.domid, None));
219        debug!(
220            "destroyed channel backend for domain {} channel {}",
221            self.domid, self.id
222        );
223    }
224}
225
226impl ChannelBackend {
227    #[allow(clippy::too_many_arguments)]
228    pub async fn new(
229        backend: String,
230        frontend: String,
231        domid: u32,
232        id: u32,
233        store: XsdClient,
234        evtchn: EventChannelService,
235        gnttab: GrantTab,
236        output_sender: Sender<(u32, Option<Vec<u8>>)>,
237        use_reserved_ref: Option<u64>,
238    ) -> Result<ChannelBackend> {
239        let processor = KrataChannelBackendProcessor {
240            backend,
241            frontend,
242            domid,
243            id,
244            store,
245            evtchn,
246            gnttab,
247            use_reserved_ref,
248        };
249
250        let (input_sender, input_receiver) = channel(SINGLE_CHANNEL_QUEUE_LEN);
251
252        let task = processor
253            .launch(output_sender.clone(), input_receiver)
254            .await?;
255        Ok(ChannelBackend {
256            domid,
257            id,
258            task,
259            raw_sender: output_sender,
260            sender: input_sender,
261        })
262    }
263}
264
265#[derive(Clone)]
266pub struct KrataChannelBackendProcessor {
267    use_reserved_ref: Option<u64>,
268    backend: String,
269    frontend: String,
270    id: u32,
271    domid: u32,
272    store: XsdClient,
273    evtchn: EventChannelService,
274    gnttab: GrantTab,
275}
276
277impl KrataChannelBackendProcessor {
278    async fn init(&self) -> Result<()> {
279        self.store
280            .write_string(format!("{}/state", self.backend), "3")
281            .await?;
282        debug!(
283            "created channel backend for domain {} channel {}",
284            self.domid, self.id
285        );
286        Ok(())
287    }
288
289    async fn on_frontend_state_change(&self) -> Result<bool> {
290        let state = self
291            .store
292            .read_string(format!("{}/state", self.backend))
293            .await?
294            .unwrap_or("0".to_string())
295            .parse::<u32>()?;
296        if state == 3 {
297            return Ok(true);
298        }
299        Ok(false)
300    }
301
302    async fn on_self_state_change(&self) -> Result<bool> {
303        let state = self
304            .store
305            .read_string(format!("{}/state", self.backend))
306            .await?
307            .unwrap_or("0".to_string())
308            .parse::<u32>()?;
309        if state == 5 {
310            return Ok(true);
311        }
312        Ok(false)
313    }
314
315    async fn launch(
316        &self,
317        output_sender: Sender<(u32, Option<Vec<u8>>)>,
318        input_receiver: Receiver<Vec<u8>>,
319    ) -> Result<JoinHandle<()>> {
320        let owned = self.clone();
321        Ok(tokio::task::spawn(async move {
322            if let Err(error) = owned.processor(output_sender, input_receiver).await {
323                error!("failed to process krata channel: {}", error);
324            }
325            let _ = owned
326                .store
327                .write_string(format!("{}/state", owned.backend), "6")
328                .await;
329        }))
330    }
331
332    async fn processor(
333        &self,
334        sender: Sender<(u32, Option<Vec<u8>>)>,
335        mut receiver: Receiver<Vec<u8>>,
336    ) -> Result<()> {
337        self.init().await?;
338        let mut frontend_state_change = self
339            .store
340            .create_watch(format!("{}/state", self.frontend))
341            .await?;
342        self.store.bind_watch(&frontend_state_change).await?;
343
344        let (ring_ref, port) = loop {
345            match frontend_state_change.receiver.recv().await {
346                Some(_) => {
347                    if self.on_frontend_state_change().await? {
348                        let mut tries = 0;
349                        let (ring_ref, port) = loop {
350                            let ring_ref = self
351                                .store
352                                .read_string(format!("{}/ring-ref", self.frontend))
353                                .await?;
354                            let port = self
355                                .store
356                                .read_string(format!("{}/port", self.frontend))
357                                .await?;
358
359                            if (ring_ref.is_none() || port.is_none()) && tries < 40 {
360                                tries += 1;
361                                self.store
362                                    .write_string(format!("{}/state", self.backend), "4")
363                                    .await?;
364                                sleep(Duration::from_millis(250)).await;
365                                continue;
366                            }
367                            break (ring_ref, port);
368                        };
369
370                        if ring_ref.is_none() || port.is_none() {
371                            return Err(anyhow!("frontend did not give ring-ref and port"));
372                        }
373
374                        let Ok(mut ring_ref) = ring_ref.unwrap().parse::<u64>() else {
375                            return Err(anyhow!("frontend gave invalid ring-ref"));
376                        };
377
378                        let Ok(port) = port.unwrap().parse::<u32>() else {
379                            return Err(anyhow!("frontend gave invalid port"));
380                        };
381
382                        ring_ref = self.use_reserved_ref.unwrap_or(ring_ref);
383                        debug!(
384                            "channel backend for domain {} channel {}: ring-ref={} port={}",
385                            self.domid, self.id, ring_ref, port,
386                        );
387                        break (ring_ref, port);
388                    }
389                }
390
391                None => {
392                    return Ok(());
393                }
394            }
395        };
396
397        self.store
398            .write_string(format!("{}/state", self.backend), "4")
399            .await?;
400        let memory = self
401            .gnttab
402            .map_grant_refs(
403                vec![GrantRef {
404                    domid: self.domid,
405                    reference: ring_ref as u32,
406                }],
407                true,
408                true,
409            )
410            .map_err(|e| {
411                anyhow!(
412                    "failed to map grant ref {} for domid {}: {}",
413                    ring_ref,
414                    self.domid,
415                    e
416                )
417            })?;
418        let mut channel = self.evtchn.bind(self.domid, port).await?;
419        unsafe {
420            let buffer = self.read_output_buffer(channel.local_port, &memory).await?;
421            if !buffer.is_empty() {
422                sender.send((self.domid, Some(buffer))).await?;
423            }
424        };
425
426        let mut self_state_change = self
427            .store
428            .create_watch(format!("{}/state", self.backend))
429            .await?;
430        self.store.bind_watch(&self_state_change).await?;
431        loop {
432            select! {
433                x = self_state_change.receiver.recv() => match x {
434                    Some(_) => {
435                        match self.on_self_state_change().await {
436                            Err(error) => {
437                                error!("failed to process state change for domain {} channel {}: {}", self.domid, self.id, error);
438                            },
439
440                            Ok(stop) => {
441                                if stop {
442                                    break;
443                                }
444                            }
445                        }
446                    },
447
448                    None => {
449                        break;
450                    }
451                },
452
453                x = receiver.recv() => match x {
454                    Some(data) => {
455                        let mut index = 0;
456                        loop {
457                            if index >= data.len() {
458                                break;
459                            }
460                            let interface = memory.ptr() as *mut XenConsoleInterface;
461                            let cons = unsafe { (*interface).in_cons };
462                            let mut prod = unsafe { (*interface).in_prod };
463                            fence(Ordering::Release);
464                            let space = (prod - cons) as usize;
465                            if space > XenConsoleInterface::INPUT_SIZE {
466                                error!("channel for domid {} has an invalid input space of {}", self.domid, space);
467                            }
468                            let free = XenConsoleInterface::INPUT_SIZE.wrapping_sub(space);
469                            if free == 0 {
470                                sleep(Duration::from_micros(100)).await;
471                                continue;
472                            }
473                            let want = data.len().min(free);
474                            let buffer = &data[index..want];
475                            for b in buffer {
476                                unsafe { (*interface).input[prod as usize & (XenConsoleInterface::INPUT_SIZE - 1)] = *b; };
477                                prod = prod.wrapping_add(1);
478                            }
479                            fence(Ordering::Release);
480                            unsafe { (*interface).in_prod = prod; };
481                            self.evtchn.notify(channel.local_port).await?;
482                            index += want;
483                        }
484                    },
485
486                    None => {
487                        break;
488                    }
489                },
490
491                x = channel.receiver.recv() => match x {
492                    Some(_) => {
493                        unsafe {
494                            let buffer = self.read_output_buffer(channel.local_port, &memory).await?;
495                            if !buffer.is_empty() {
496                                sender.send((self.domid, Some(buffer))).await?;
497                            }
498                        };
499                        channel.unmask().await?;
500                    },
501
502                    None => {
503                        break;
504                    }
505                }
506            }
507        }
508        Ok(())
509    }
510
511    async unsafe fn read_output_buffer<'a>(
512        &self,
513        local_port: u32,
514        memory: &MappedMemory<'a>,
515    ) -> Result<Vec<u8>> {
516        let interface = memory.ptr() as *mut XenConsoleInterface;
517        let mut cons = (*interface).out_cons;
518        let prod = (*interface).out_prod;
519        fence(Ordering::Release);
520        let size = prod.wrapping_sub(cons);
521        let mut data: Vec<u8> = Vec::new();
522        if size == 0 || size as usize > XenConsoleInterface::OUTPUT_SIZE {
523            return Ok(data);
524        }
525        loop {
526            if cons == prod {
527                break;
528            }
529            data.push((*interface).output[cons as usize & (XenConsoleInterface::OUTPUT_SIZE - 1)]);
530            cons = cons.wrapping_add(1);
531        }
532        fence(Ordering::AcqRel);
533        (*interface).out_cons = cons;
534        self.evtchn.notify(local_port).await?;
535        Ok(data)
536    }
537}