cf_core_alpha/
session.rs

1use std::{collections::HashMap, time::Duration};
2
3use anyhow::Context;
4use cf_pty_process_alpha::{unix::UnixPtySystem, Child, PtySystem};
5use serde::{Deserialize, Serialize};
6use tokio::{
7    io::{AsyncReadExt, AsyncWriteExt},
8    sync::mpsc::{self, Sender},
9    time,
10};
11
/// Size in bytes of the buffer used for each read from a session's PTY.
const CHUNK_LEN: usize = 4096;
13
14pub struct Capabilities {
15    pub shell: bool,
16    pub actions: Option<Vec<Action>>,
17}
18
19/// Outgoing messages sent to the Relay service.
20#[derive(Serialize, Deserialize, Debug)]
21#[serde(tag = "t", content = "c")]
22pub enum TxMessage {
23    #[serde(rename = "tty_output")]
24    Output { session_id: String, data: String },
25    #[serde(rename = "session_created")]
26    SessionCreated { session_id: String },
27    #[serde(rename = "capabilities")]
28    Capabilities {
29        shell: bool,
30        actions: Option<Vec<Action>>,
31    },
32}
33
34#[derive(Serialize, Deserialize, Debug)]
35pub struct Action {
36    pub name: String,
37}
38
/// Messages delivered by [`Manager`] to a session's input task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecvMessage {
    /// Input for the session's process; a trailing newline is appended before it is written.
    Command { data: String },
    /// Request to kill the session's process.
    End,
}
44
/// Program and arguments launched for every new session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandName {
    /// Program to execute.
    pub name: String,
    /// Arguments passed to the program, in order.
    pub args: Vec<String>,
}
50
51pub struct Manager {
52    tx: Sender<TxMessage>,
53    command: CommandName,
54    sessions: HashMap<String, Sender<RecvMessage>>,
55}
56
57impl Manager {
58    pub fn new(tx: Sender<TxMessage>, command: CommandName) -> Self {
59        Self {
60            tx,
61            command,
62            sessions: HashMap::new(),
63        }
64    }
65
66    pub async fn send_command(
67        &mut self,
68        session_id: String,
69        command: String,
70    ) -> anyhow::Result<()> {
71        let proc_tx = self
72            .sessions
73            .get(&session_id)
74            .context("session not found")?;
75
76        proc_tx.send(RecvMessage::Command { data: command }).await?;
77        Ok(())
78    }
79
80    pub async fn create_session(&mut self, id: String) -> anyhow::Result<()> {
81        let proc_tx = spawn_process(id.clone(), self.command.clone(), self.tx.clone()).await?;
82        self.sessions.insert(id.clone(), proc_tx);
83
84        Ok(())
85    }
86
87    pub async fn end_session(&mut self, id: String) -> anyhow::Result<()> {
88        let session = self.sessions.get(&id).context("session not found")?;
89        session.send(RecvMessage::End).await?;
90        Ok(())
91    }
92}
93
94async fn spawn_process(
95    session_id: String,
96    command: CommandName,
97    tx: Sender<TxMessage>,
98) -> anyhow::Result<Sender<RecvMessage>> {
99    let mut cmd = tokio::process::Command::new(command.name.clone());
100    cmd.args(command.args);
101    let (proc_tx, mut proc_rx) = mpsc::channel(100);
102
103    let mut instance = UnixPtySystem::spawn(
104        cmd,
105        cf_pty_process_alpha::PtySystemOptions { raw_mode: false },
106    )?;
107
108    tx.send(TxMessage::SessionCreated {
109        session_id: session_id.clone(),
110    })
111    .await?;
112
113    let mut write = instance.write;
114    let mut read = instance.read;
115
116    tokio::spawn(async move {
117        loop {
118            match proc_rx.recv().await {
119                Some(RecvMessage::Command { data }) => {
120                    let with_newline = format!("{data}\n");
121                    write.write(with_newline.as_bytes()).await.unwrap();
122                }
123                Some(RecvMessage::End) => {
124                    // try and kill the child process with best effort
125                    if let Err(e) = instance.child.kill().await {
126                        println!("error ending session: {}", e)
127                    };
128                }
129                None => break,
130            };
131        }
132    });
133
134    tokio::spawn(async move {
135        let mut buffer = vec![0u8; CHUNK_LEN];
136        while let Ok(read) = read.read(buffer.as_mut_slice()).await {
137            if read == 0 {
138                println!("Received {} bytes", read);
139                break;
140            }
141
142            println!("Received {} bytes", read);
143
144            let mut buf = vec![0; read];
145            buf.copy_from_slice(&buffer[0..read]);
146
147            let data_str = std::str::from_utf8(&buf).unwrap();
148
149            println!("data: {data_str}");
150            let msg = TxMessage::Output {
151                data: data_str.to_owned(),
152                session_id: session_id.clone(),
153            };
154            if let Err(e) = tx.send(msg).await {
155                println!("mpsc send error: {e}")
156            }
157
158            time::sleep(Duration::from_micros(150)).await;
159        }
160    });
161
162    return Ok(proc_tx);
163}