Skip to main content

dagger_sdk/core/
cli_session.rs

1use std::{fs::canonicalize, path::Path, process::Stdio, sync::Arc};
2
3use eyre::Context;
4use tokio::{io::AsyncBufReadExt, sync::broadcast};
5
6use crate::core::{config::Config, connect_params::ConnectParams};
7
8#[derive(Clone, Debug)]
9pub struct CliSession {
10    inner: Arc<InnerCliSession>,
11}
12
13impl Default for CliSession {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl CliSession {
20    pub fn new() -> Self {
21        Self {
22            inner: Arc::new(InnerCliSession {}),
23        }
24    }
25
26    pub async fn connect(
27        &self,
28        config: &Config,
29        cli_path: &Path,
30    ) -> eyre::Result<(ConnectParams, DaggerSessionProc)> {
31        self.inner.connect(config, cli_path).await
32    }
33}
34
35pub struct DaggerSessionProc {
36    shutdown: broadcast::Sender<()>,
37
38    inner: tokio::sync::Mutex<tokio::process::Child>,
39}
40
41impl DaggerSessionProc {
42    pub fn subscribe_shutdown(&self) -> broadcast::Receiver<()> {
43        self.shutdown.subscribe()
44    }
45
46    pub async fn shutdown(&self) -> eyre::Result<()> {
47        let mut proc = self.inner.lock().await;
48
49        tracing::trace!("waiting for dagger subprocess to shutdown");
50
51        tracing::trace!("sending shutdown signal");
52        if let Err(e) = self.shutdown.send(()) {
53            tracing::warn!("failed to send shutdown signal: {}", e);
54        }
55
56        tracing::trace!("closing stdin");
57        proc.wait().await.context("failed to shutdown session")?;
58
59        tracing::trace!("dagger subprocess shutdown");
60
61        Ok(())
62    }
63}
64
65impl From<tokio::process::Child> for DaggerSessionProc {
66    fn from(value: tokio::process::Child) -> Self {
67        let (tx, _) = broadcast::channel::<()>(1);
68
69        Self {
70            inner: tokio::sync::Mutex::new(value),
71            shutdown: tx,
72        }
73    }
74}
75
/// Field-less implementation type behind [`CliSession`]; all work happens in its methods.
#[derive(Debug)]
struct InnerCliSession {}

79impl InnerCliSession {
80    pub async fn connect(
81        &self,
82        config: &Config,
83        cli_path: &Path,
84    ) -> eyre::Result<(ConnectParams, DaggerSessionProc)> {
85        let proc = self.start(config, cli_path)?;
86        let params = self.get_conn(proc, config).await?;
87
88        Ok(params)
89    }
90
91    fn start(&self, config: &Config, cli_path: &Path) -> eyre::Result<tokio::process::Child> {
92        let mut args: Vec<String> = vec!["session".into()];
93        if let Some(workspace) = &config.workdir_path {
94            let abs_path = canonicalize(workspace)?;
95            args.extend(["--workdir".into(), abs_path.to_string_lossy().to_string()])
96        }
97        if let Some(config_path) = &config.config_path {
98            let abs_path = canonicalize(config_path)?;
99            args.extend(["--project".into(), abs_path.to_string_lossy().to_string()])
100        }
101        if config.load_workspace_modules {
102            args.push("--load-workspace-modules".into());
103        }
104
105        args.extend(["--label".into(), "dagger.io/sdk.name:rust".into()]);
106        args.extend([
107            "--label".into(),
108            format!("dagger.io/sdk.version:{}", env!("CARGO_PKG_VERSION")),
109        ]);
110
111        let proc = tokio::process::Command::new(
112            cli_path
113                .to_str()
114                .ok_or(eyre::anyhow!("could not get string from path"))?,
115        )
116        .args(args.as_slice())
117        .stdin(Stdio::piped())
118        .stdout(Stdio::piped())
119        .stderr(Stdio::piped())
120        .spawn()?;
121
122        //TODO: Add retry mechanism
123
124        Ok(proc)
125    }
126
127    async fn get_conn(
128        &self,
129        mut proc: tokio::process::Child,
130        config: &Config,
131    ) -> eyre::Result<(ConnectParams, DaggerSessionProc)> {
132        let stdout = proc
133            .stdout
134            .take()
135            .ok_or(eyre::anyhow!("could not acquire stdout from child process"))?;
136
137        let stderr = proc
138            .stderr
139            .take()
140            .ok_or(eyre::anyhow!("could not acquire stderr from child process"))?;
141
142        let session: DaggerSessionProc = proc.into();
143
144        let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
145        let logger = config.logger.as_ref().map(|p| p.clone());
146        let mut rx = session.subscribe_shutdown();
147
148        tokio::spawn(async move {
149            let mut stdout_bufr = tokio::io::BufReader::new(stdout).lines();
150            loop {
151                tokio::select! {
152                    line = stdout_bufr.next_line() => {
153                        if let Ok(Some(line)) = line {
154                            if let Ok(conn) = serde_json::from_str::<ConnectParams>(&line) {
155                                sender.send(conn).await.unwrap();
156                                continue;
157                            }
158
159                            if let Some(logger) = &logger {
160                                logger.stdout(&line).unwrap();
161                            }
162                        }
163                    },
164                    _ = rx.recv() => {
165                        drop(stdout_bufr);
166                        tracing::trace!("shutting down stdout");
167                        break;
168                    },
169                };
170            }
171
172            tracing::trace!("closing stdout for dagger session");
173        });
174
175        let mut rx = session.subscribe_shutdown();
176        let logger = config.logger.as_ref().map(|p| p.clone());
177        tokio::spawn(async move {
178            let mut stderr_bufr = tokio::io::BufReader::new(stderr).lines();
179            loop {
180                tokio::select! {
181                    line = stderr_bufr.next_line() => {
182                        if let Ok(Some(line)) = line {
183                            if let Some(logger) = &logger {
184                                logger.stderr(&line).unwrap();
185                            }
186                        }
187                    },
188                    _ = rx.recv() => {
189                        drop(stderr_bufr);
190                        tracing::trace!("shutting down stderr");
191                        break;
192                    },
193                };
194            }
195
196            tracing::trace!("closing stderr for dagger session");
197        });
198
199        let conn = receiver.recv().await.ok_or(eyre::anyhow!(
200            "could not receive ok signal from dagger-engine"
201        ))?;
202
203        Ok((conn, session))
204    }
205}