noir_compute/
config.rs

1//! Configuration types used to initialize the [`StreamContext`](crate::StreamContext).
2//!
3//! See the documentation of [`RuntimeConfig`] for more details.
4
5use std::fmt::{Display, Formatter};
6use std::path::Path;
7use std::path::PathBuf;
8use std::str::FromStr;
9
10use anyhow::{bail, Result};
11#[cfg(feature = "clap")]
12use clap::Parser;
13use serde::{Deserialize, Serialize};
14
15use crate::runner::spawn_remote_workers;
16use crate::scheduler::HostId;
17use crate::CoordUInt;
18
/// Name of the environment variable through which the runner tells a process its host id.
///
/// When the variable is not set, the process will have to spawn the other processes by itself.
pub const HOST_ID_ENV_VAR: &str = "NOIR_HOST_ID";
/// Name of the environment variable through which the runner passes the content of the
/// configuration file, so that the file is not required to be present on all the hosts.
pub const CONFIG_ENV_VAR: &str = "NOIR_CONFIG";
25
26/// The runtime configuration of the environment,
27///
28/// This configuration selects which runtime to use for this execution. The runtime is either local
29/// (i.e. parallelism is achieved only using threads), or remote (i.e. using both threads locally
30/// and remote workers).
31///
32/// In a remote execution the current binary is copied using scp to a remote host and then executed
33/// using ssh. The configuration of the remote environment should be specified via a YAML
34/// configuration file.
35///
36/// ## Local environment
37///
38/// ```
39/// # use noir_compute::{StreamContext, RuntimeConfig};
40/// let config = RuntimeConfig::local(1);
41/// let env = StreamContext::new(config);
42/// ```
43///
44/// ## Remote environment
45///
46/// ```
47/// # use noir_compute::{StreamContext, RuntimeConfig};
48/// # use std::fs::File;
49/// # use std::io::Write;
50/// let config = r"
51/// hosts:
52///   - address: host1
53///     base_port: 9500
54///     num_cores: 16
55///   - address: host2
56///     base_port: 9500
57///     num_cores: 24
58/// ";
59/// let mut file = File::create("config.yaml").unwrap();
60/// file.write_all(config.as_bytes());
61///
62/// let config = RuntimeConfig::remote("config.yaml").expect("cannot read config file");
63/// let env = StreamContext::new(config);
64/// ```
65///
66/// ## From command line arguments
67/// This reads from `std::env::args()` and reads the most common options (`--local`, `--remote`,
68/// `--verbose`). All the unparsed options will be returned into `args`. You can use `--help` to see
69/// their docs.
70///
71/// ```no_run
72/// # use noir_compute::{RuntimeConfig, StreamContext};
73/// let (config, args) = RuntimeConfig::from_args();
74/// let env = StreamContext::new(config);
75/// ```
76#[derive(Debug, Clone, Eq, PartialEq)]
77pub enum RuntimeConfig {
78    /// Use only local threads.
79    Local(LocalConfig),
80    /// Use both local threads and remote workers.
81    Remote(RemoteConfig),
82}
83
84// #[derive(Debug, Clone, Eq, PartialEq)]
85// pub struct RuntimeConfig {
86//     /// Which runtime to use for the environment.
87//     pub runtime: RuntimeConfig,
88//     /// In a remote execution this field represents the identifier of the host, i.e. the index
89//     /// inside the host list in the config.
90//     pub host_id: Option<HostId>,
91// }
92
93/// This environment uses only local threads.
94#[derive(Debug, Clone, Eq, PartialEq)]
95pub struct LocalConfig {
96    /// The number of CPU cores of this host.
97    ///
98    /// A thread will be spawned for each core, for each block in the job graph.
99    pub num_cores: CoordUInt,
100}
101
102/// This environment uses local threads and remote hosts.
103#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
104pub struct RemoteConfig {
105    /// The identifier for this host.
106    #[serde(skip)]
107    pub host_id: Option<HostId>,
108    /// The set of remote hosts to use.
109    pub hosts: Vec<HostConfig>,
110    /// If specified some debug information will be stored inside this directory.
111    pub tracing_dir: Option<PathBuf>,
112    /// Remove remote binaries after execution
113    #[serde(default)]
114    pub cleanup_executable: bool,
115}
116
117/// The configuration of a single remote host.
118#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
119pub struct HostConfig {
120    /// The IP address or domain name to use for connecting to this remote host.
121    ///
122    /// This must be reachable from all the hosts in the cluster.
123    pub address: String,
124    /// The first port to use for inter-host communication.
125    ///
126    /// This port and the following ones will be bound by the host, one for each connection between
127    /// blocks of the job graph..
128    pub base_port: u16,
129    /// The number of cores of the remote host.
130    ///
131    /// This is the same as `LocalRuntimeConfig::num_cores`.
132    pub num_cores: CoordUInt,
133    /// The configuration to use to connect via SSH to the remote host.
134    #[serde(default)]
135    pub ssh: SSHConfig,
136    /// If specified the remote worker will be spawned under `perf`, and its output will be stored
137    /// at this location.
138    pub perf_path: Option<PathBuf>,
139}
140
141/// The information used to connect to a remote host via SSH.
142#[derive(Clone, Serialize, Deserialize, Derivative, Eq, PartialEq)]
143#[derivative(Default)]
144#[allow(clippy::upper_case_acronyms)]
145pub struct SSHConfig {
146    /// The SSH port this host listens to.
147    #[derivative(Default(value = "22"))]
148    #[serde(default = "ssh_default_port")]
149    pub ssh_port: u16,
150    /// The username of the remote host. Defaulted to the local username.
151    pub username: Option<String>,
152    /// The password of the remote host. If not specified ssh-agent will be used for the connection.
153    pub password: Option<String>,
154    /// The path to the private key to use for authenticating to the remote host.
155    pub key_file: Option<PathBuf>,
156    /// The passphrase for decrypting the private SSH key.
157    pub key_passphrase: Option<String>,
158}
159
160impl std::fmt::Debug for SSHConfig {
161    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
162        f.debug_struct("RemoteHostSSHConfig")
163            .field("ssh_port", &self.ssh_port)
164            .field("username", &self.username)
165            .field("password", &self.password.as_ref().map(|_| "REDACTED"))
166            .field("key_file", &self.key_file)
167            .field(
168                "key_passphrase",
169                &self.key_passphrase.as_ref().map(|_| "REDACTED"),
170            )
171            .finish()
172    }
173}
174
// Command line options parsed by `RuntimeConfig::from_args`.
//
// NOTE: plain `//` comments are used for these notes on purpose: clap derives the help text of
// the command from the `///` doc comments, so the doc comments below are user-facing strings.
//
// The mutual exclusion between `--remote` and `--local` is not declared to clap; it is checked
// after parsing by `CommandLineOptions::validate`.
#[cfg(feature = "clap")]
#[derive(Parser, Debug)]
#[clap(name = "noir", about = "Network of Operators In Rust", trailing_var_arg = true)]
pub struct CommandLineOptions {
    /// Path to the configuration file for remote execution.
    ///
    /// When this is specified the execution will be remote. This conflicts with `--local`.
    #[clap(short, long)]
    remote: Option<PathBuf>,

    /// Number of cores in the local execution.
    ///
    /// When this is specified the execution will be local. This conflicts with `--remote`.
    #[clap(short, long)]
    local: Option<CoordUInt>,

    /// The rest of the arguments.
    args: Vec<String>,
}
198
199impl RuntimeConfig {
200    /// Build the configuration from the specified args list.
201    #[cfg(feature = "clap")]
202    pub fn from_args() -> (RuntimeConfig, Vec<String>) {
203        let opt: CommandLineOptions = CommandLineOptions::parse();
204        opt.validate();
205
206        let mut args = opt.args;
207        args.insert(0, std::env::args().next().unwrap());
208
209        if let Some(num_cores) = opt.local {
210            (Self::local(num_cores), args)
211        } else if let Some(remote) = opt.remote {
212            (Self::remote(remote).unwrap(), args)
213        } else {
214            unreachable!("Invalid configuration")
215        }
216    }
217
218    /// Local environment that avoid using the network and runs concurrently using only threads.
219    pub fn local(num_cores: CoordUInt) -> RuntimeConfig {
220        RuntimeConfig::Local(LocalConfig { num_cores })
221    }
222
223    /// Remote environment based on the provided configuration file.
224    ///
225    /// The behaviour of this changes if this process is the "runner" process (ie the one that will
226    /// execute via ssh the other workers) or a worker process.
227    /// If it's the runner, the configuration file is read. If it's a worker, the configuration is
228    /// read directly from the environment variable and not from the file (remote hosts may not have
229    /// the configuration file).
230    pub fn remote<P: AsRef<Path>>(config: P) -> Result<RuntimeConfig> {
231        let mut config = if let Some(config) = RuntimeConfig::config_from_env() {
232            config
233        } else {
234            log::info!("reading config from: {}", config.as_ref().display());
235            let content = std::fs::read_to_string(config)?;
236            serde_yaml::from_str(&content)?
237        };
238
239        // validate the configuration
240        for (host_id, host) in config.hosts.iter().enumerate() {
241            if host.ssh.password.is_some() && host.ssh.key_file.is_some() {
242                bail!("Malformed configuration: cannot specify both password and key file on host {}: {}", host_id, host.address);
243            }
244        }
245
246        config.host_id = RuntimeConfig::host_id_from_env(config.hosts.len().try_into().unwrap());
247        log::debug!("runtime configuration: {config:#?}");
248        Ok(RuntimeConfig::Remote(config))
249    }
250
251    /// Extract the host id from the environment variable, if present.
252    fn host_id_from_env(num_hosts: CoordUInt) -> Option<HostId> {
253        let host_id = match std::env::var(HOST_ID_ENV_VAR) {
254            Ok(host_id) => host_id,
255            Err(_) => return None,
256        };
257        let host_id = match HostId::from_str(&host_id) {
258            Ok(host_id) => host_id,
259            Err(e) => panic!("Invalid value for environment {HOST_ID_ENV_VAR}: {e:?}"),
260        };
261        if host_id >= num_hosts {
262            panic!(
263                "Invalid value for environment {}: value too large, max possible is {}",
264                HOST_ID_ENV_VAR,
265                num_hosts - 1
266            );
267        }
268        Some(host_id)
269    }
270
271    /// Extract the configuration from the environment, if it's present.
272    fn config_from_env() -> Option<RemoteConfig> {
273        match std::env::var(CONFIG_ENV_VAR) {
274            Ok(config) => {
275                info!("reading remote config from env {}", CONFIG_ENV_VAR);
276                let config: RemoteConfig =
277                    serde_yaml::from_str(&config).expect("Invalid configuration from environment");
278                Some(config)
279            }
280            Err(_) => None,
281        }
282    }
283
284    /// Spawn the remote workers via SSH and exit if this is the process that should spawn. If this
285    /// is already a spawned process nothing is done.
286    pub fn spawn_remote_workers(&self) {
287        match &self {
288            RuntimeConfig::Local(_) => {}
289            #[cfg(feature = "ssh")]
290            RuntimeConfig::Remote(remote) => {
291                spawn_remote_workers(remote.clone());
292            }
293            #[cfg(not(feature = "ssh"))]
294            RuntimeConfig::Remote(_) => {
295                panic!("spawn_remote_workers() requires the `ssh` feature for remote configs.");
296            }
297        }
298    }
299
300    pub fn host_id(&self) -> Option<HostId> {
301        match self {
302            RuntimeConfig::Local(_) => Some(0),
303            RuntimeConfig::Remote(remote) => remote.host_id,
304        }
305    }
306}
307
308impl Display for HostConfig {
309    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
310        write!(f, "[{}:{}-]", self.address, self.base_port)
311    }
312}
313
#[cfg(feature = "clap")]
impl CommandLineOptions {
    /// Ensure that the parsed options describe a usable configuration.
    ///
    /// # Panics
    ///
    /// Panics unless exactly one of `--remote` and `--local` is specified, and when the number
    /// of local cores is zero.
    fn validate(&self) {
        let is_remote = self.remote.is_some();
        let is_local = self.local.is_some();
        // The two flags agree only when both or neither are set: both cases are rejected.
        if is_remote == is_local {
            panic!("Use one of --remote or --local");
        }
        if matches!(self.local, Some(threads) if threads == 0) {
            panic!("The number of cores should be positive");
        }
    }
}
328
/// Port used for SSH when the configuration does not specify one; referenced by the serde
/// `default` attribute of `SSHConfig::ssh_port`.
fn ssh_default_port() -> u16 {
    const DEFAULT_SSH_PORT: u16 = 22;
    DEFAULT_SSH_PORT
}