Skip to main content

geckodriver_librs/
lib.rs

1#![forbid(unsafe_code)]
2
3extern crate chrono;
4#[macro_use]
5extern crate clap;
6#[macro_use]
7extern crate lazy_static;
8extern crate hyper;
9extern crate marionette as marionette_rs;
10extern crate mozdevice;
11extern crate mozprofile;
12extern crate mozrunner;
13extern crate mozversion;
14extern crate regex;
15extern crate serde;
16#[macro_use]
17extern crate serde_derive;
18extern crate serde_json;
19extern crate tempfile;
20pub extern crate url;
21extern crate uuid;
22extern crate webdriver;
23extern crate yaml_rust;
24extern crate zip;
25
26#[macro_use]
27pub extern crate log;
28
29use std::env;
30use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
31use std::path::PathBuf;
32use std::process::ExitCode;
33
34use std::str::FromStr;
35
36use clap::{Arg, ArgAction, Command};
37
/// Unwrap an `Option`, or return early from the enclosing function with
/// `Err(WebDriverError::new($err_type, $err_msg))` when it is `None`.
///
/// `WebDriverError` is not imported in this file; it is resolved at each call
/// site, so it must be in scope wherever the macro is expanded. The macro is
/// defined before the `mod` declarations below so that, by textual macro
/// scoping, those modules can use it.
macro_rules! try_opt {
    ($expr:expr, $err_type:expr, $err_msg:expr) => {{
        match $expr {
            Some(x) => x,
            None => return Err(WebDriverError::new($err_type, $err_msg)),
        }
    }};
}
46
47mod android;
48mod browser;
49mod build;
50mod capabilities;
51mod command;
52mod logging;
53pub mod marionette;
54mod prefs;
55
56#[cfg(test)]
57pub mod test;
58
59use crate::command::extension_routes;
60use crate::logging::Level;
61use crate::marionette::{MarionetteHandler, MarionetteSettings};
62use anyhow::{bail, Result as ProgramResult};
63use clap::ArgMatches;
64use mozdevice::AndroidStorageInput;
65use url::{Host, Url};
66
67const EXIT_USAGE: u8 = 64;
68const EXIT_UNAVAILABLE: u8 = 69;
69
/// What the binary should do, as decided by `parse_args` from the command line.
// `Server` carries settings and vectors while `Help` and `Version` carry
// nothing; a single value is built per run, so the size difference is
// accepted instead of boxing the payload.
#[allow(clippy::large_enum_variant)]
enum Operation {
    /// Print the usage text (`--help`).
    Help,
    /// Print version and licence information (`--version`).
    Version,
    /// Start the WebDriver server.
    Server {
        /// Explicit log level; `None` makes `inner_main` call `logging::init` instead.
        log_level: Option<Level>,
        /// Whether long log lines are truncated (cleared by `--log-no-truncate`).
        log_truncate: bool,
        /// Resolved socket address the WebDriver server listens on.
        address: SocketAddr,
        /// Hostnames to allow (from `--allow-hosts` or derived from `--host`).
        allow_hosts: Vec<Host>,
        /// Request origins to allow (from `--allow-origins`).
        allow_origins: Vec<Url>,
        /// Settings handed to the `MarionetteHandler`.
        settings: MarionetteSettings,
        /// `--enable-crash-reporter` was given; deprecated, only used to log a warning here.
        deprecated_enable_crash_reporter: bool,
        /// `--android-storage` was given; its presence only triggers a deprecation warning here.
        deprecated_storage_arg: bool,
    },
}
85
/// Configuration for running geckodriver as a library through [`start`].
pub struct GeckodriverSettings {
    /// Settings passed to `MarionetteHandler::new`.
    pub marionette: marionette::MarionetteSettings,
    /// Socket address the WebDriver server listens on.
    pub address: SocketAddr,
    /// Hostnames the server allows.
    pub allowed_hosts: Vec<Host>,
    /// Request origins the server allows.
    pub allowed_origins: Vec<Url>,
    /// Explicit log level; when `None`, `logging::init` is used instead.
    pub log_level: Option<Level>,
    /// Whether long log lines are truncated.
    pub log_truncate: bool,
}
94
95impl Default for GeckodriverSettings {
96    fn default() -> Self {
97        Self {
98            marionette: marionette::MarionetteSettings::default(),
99            address: std::net::SocketAddr::new("127.0.0.1".parse().unwrap(), 4444),
100            allowed_hosts: vec![Host::Ipv4("127.0.0.1".parse().unwrap())],
101            allowed_origins: vec![],
102            log_level: Some(Level::Info),
103            log_truncate: false,
104        }
105    }
106}
107
108pub fn start(settings: GeckodriverSettings) -> std::io::Result<()> {
109    if let Some(ref level) = settings.log_level {
110        logging::init_with_level(*level, false).unwrap();
111    } else {
112        logging::init(settings.log_truncate).unwrap();
113    }
114
115    let result = webdriver::server::start(
116        settings.address,
117        settings.allowed_hosts,
118        settings.allowed_origins,
119        marionette::MarionetteHandler::new(settings.marionette),
120        extension_routes(),
121    );
122
123    match result {
124        Ok(listener) => {
125            info!("Listening on {}", listener.socket);
126            Ok(())
127        }
128        Err(error) => Err(error),
129    }
130}
131
132/// Get a socket address from the provided host and port
133///
134/// # Arguments
135/// * `webdriver_host` - The hostname on which the server will listen
136/// * `webdriver_port` - The port on which the server will listen
137///
138/// When the host and port resolve to multiple addresses, prefer
139/// IPv4 addresses vs IPv6.
140fn server_address(webdriver_host: &str, webdriver_port: u16) -> ProgramResult<SocketAddr> {
141    let mut socket_addrs = match format!("{}:{}", webdriver_host, webdriver_port).to_socket_addrs()
142    {
143        Ok(addrs) => addrs.collect::<Vec<_>>(),
144        Err(e) => bail!("{}: {}:{}", e, webdriver_host, webdriver_port),
145    };
146    if socket_addrs.is_empty() {
147        bail!(
148            "Unable to resolve host: {}:{}",
149            webdriver_host,
150            webdriver_port
151        )
152    }
153    // Prefer ipv4 address
154    socket_addrs.sort_by(|a, b| {
155        let a_val = i32::from(!a.ip().is_ipv4());
156        let b_val = i32::from(!b.ip().is_ipv4());
157        a_val.partial_cmp(&b_val).expect("Comparison failed")
158    });
159    Ok(socket_addrs.remove(0))
160}
161
162/// Parse a given string into a Host
163fn parse_hostname(webdriver_host: &str) -> Result<Host, url::ParseError> {
164    let host_str = if let Ok(ip_addr) = IpAddr::from_str(webdriver_host) {
165        // In this case we have an IP address as the host
166        if ip_addr.is_ipv6() {
167            // Convert to quoted form
168            format!("[{}]", &webdriver_host)
169        } else {
170            webdriver_host.into()
171        }
172    } else {
173        webdriver_host.into()
174    };
175
176    Host::parse(&host_str)
177}
178
179/// Get a list of default hostnames to allow
180///
181/// This only covers domain names, not IP addresses, since IP adresses
182/// are always accepted.
183fn get_default_allowed_hosts(ip: IpAddr) -> Vec<Host> {
184    let localhost_is_loopback = ("localhost".to_string(), 80)
185        .to_socket_addrs()
186        .map(|addr_iter| {
187            addr_iter
188                .map(|addr| addr.ip())
189                .filter(|ip| ip.is_loopback())
190        })
191        .iter()
192        .len()
193        > 0;
194    if ip.is_loopback() && localhost_is_loopback {
195        vec![Host::parse("localhost").unwrap()]
196    } else {
197        vec![]
198    }
199}
200
201fn get_allowed_hosts(host: Host, allow_hosts: Option<clap::parser::ValuesRef<Host>>) -> Vec<Host> {
202    allow_hosts
203        .map(|hosts| hosts.cloned().collect())
204        .unwrap_or_else(|| match host {
205            Host::Domain(_) => {
206                vec![host.clone()]
207            }
208            Host::Ipv4(ip) => get_default_allowed_hosts(IpAddr::V4(ip)),
209            Host::Ipv6(ip) => get_default_allowed_hosts(IpAddr::V6(ip)),
210        })
211}
212
213fn get_allowed_origins(allow_origins: Option<clap::parser::ValuesRef<Url>>) -> Vec<Url> {
214    allow_origins.into_iter().flatten().cloned().collect()
215}
216
217fn parse_args(args: &ArgMatches) -> ProgramResult<Operation> {
218    if args.get_flag("help") {
219        return Ok(Operation::Help);
220    } else if args.get_flag("version") {
221        return Ok(Operation::Version);
222    }
223
224    let log_level = if let Some(log_level) = args.get_one::<String>("log_level") {
225        Level::from_str(log_level).ok()
226    } else {
227        Some(match args.get_count("verbosity") {
228            0 => Level::Info,
229            1 => Level::Debug,
230            _ => Level::Trace,
231        })
232    };
233
234    let webdriver_host = args.get_one::<String>("webdriver_host").unwrap();
235    let webdriver_port = {
236        let s = args.get_one::<String>("webdriver_port").unwrap();
237        match u16::from_str(s) {
238            Ok(n) => n,
239            Err(e) => bail!("invalid --port: {}: {}", e, s),
240        }
241    };
242
243    let android_storage = args
244        .get_one::<String>("android_storage")
245        .and_then(|arg| AndroidStorageInput::from_str(arg).ok())
246        .unwrap_or(AndroidStorageInput::Auto);
247
248    let binary = args.get_one::<String>("binary").map(PathBuf::from);
249
250    let profile_root = args.get_one::<String>("profile_root").map(PathBuf::from);
251
252    // Try to create a temporary directory on startup to check that the directory exists and is writable
253    {
254        let tmp_dir = if let Some(ref tmp_root) = profile_root {
255            tempfile::tempdir_in(tmp_root)
256        } else {
257            tempfile::tempdir()
258        };
259        if tmp_dir.is_err() {
260            bail!("Unable to write to temporary directory; consider --profile-root with a writeable directory")
261        }
262    }
263
264    let marionette_host = args.get_one::<String>("marionette_host").unwrap();
265    let marionette_port = match args.get_one::<String>("marionette_port") {
266        Some(s) => match u16::from_str(s) {
267            Ok(n) => Some(n),
268            Err(e) => bail!("invalid --marionette-port: {}", e),
269        },
270        None => None,
271    };
272
273    // For Android the port on the device must be the same as the one on the
274    // host. For now default to 9222, which is the default for --remote-debugging-port.
275    let websocket_port = match args.get_one::<String>("websocket_port") {
276        Some(s) => match u16::from_str(s) {
277            Ok(n) => n,
278            Err(e) => bail!("invalid --websocket-port: {}", e),
279        },
280        None => 9222,
281    };
282
283    let host = match parse_hostname(webdriver_host) {
284        Ok(name) => name,
285        Err(e) => bail!("invalid --host {}: {}", webdriver_host, e),
286    };
287
288    let allow_hosts = get_allowed_hosts(host, args.get_many("allow_hosts"));
289
290    let allow_origins = get_allowed_origins(args.get_many("allow_origins"));
291
292    let address = server_address(webdriver_host, webdriver_port)?;
293
294    let settings = MarionetteSettings {
295        binary,
296        profile_root,
297        connect_existing: args.get_flag("connect_existing"),
298        host: marionette_host.into(),
299        port: marionette_port,
300        websocket_port,
301        allow_hosts: allow_hosts.clone(),
302        allow_origins: allow_origins.clone(),
303        jsdebugger: args.get_flag("jsdebugger"),
304        android_storage,
305        system_access: args.get_flag("allow_system_access"),
306    };
307    Ok(Operation::Server {
308        log_level,
309        log_truncate: !args.get_flag("log_no_truncate"),
310        allow_hosts,
311        allow_origins,
312        address,
313        settings,
314        deprecated_enable_crash_reporter: args.get_flag("enable_crash_reporter"),
315        deprecated_storage_arg: args.contains_id("android_storage"),
316    })
317}
318
319fn inner_main(operation: Operation, cmd: &mut Command) -> ProgramResult<()> {
320    match operation {
321        Operation::Help => print_help(cmd),
322        Operation::Version => print_version(),
323
324        Operation::Server {
325            log_level,
326            log_truncate,
327            address,
328            allow_hosts,
329            allow_origins,
330            settings,
331            deprecated_enable_crash_reporter,
332            deprecated_storage_arg,
333        } => {
334            if let Some(ref level) = log_level {
335                logging::init_with_level(*level, log_truncate).unwrap();
336            } else {
337                logging::init(log_truncate).unwrap();
338            }
339
340            if deprecated_storage_arg {
341                warn!("--android-storage argument is deprecated and will be removed soon.");
342            };
343
344            if deprecated_enable_crash_reporter {
345                warn!("--enable-crash-reporter argument is deprecated and will be removed in the next version.");
346            }
347
348            let handler = MarionetteHandler::new(settings);
349            let listening = webdriver::server::start(
350                address,
351                allow_hosts,
352                allow_origins,
353                handler,
354                extension_routes(),
355            )?;
356            info!("Listening on {}", listening.socket);
357        }
358    }
359
360    Ok(())
361}
362
363pub fn bin_main() -> ExitCode {
364    let mut cmd = make_command();
365
366    let args = match cmd.try_get_matches_from_mut(env::args()) {
367        Ok(args) => args,
368        Err(e) => {
369            // Clap already says "error:" and don't repeat help.
370            eprintln!("{}: {}", get_program_name(), e);
371            return ExitCode::from(EXIT_USAGE);
372        }
373    };
374
375    let operation = match parse_args(&args) {
376        Ok(op) => op,
377        Err(e) => {
378            eprintln!("{}: error: {}", get_program_name(), e);
379            print_help(&mut cmd);
380            return ExitCode::from(EXIT_USAGE);
381        }
382    };
383
384    if let Err(e) = inner_main(operation, &mut cmd) {
385        eprintln!("{}: error: {}", get_program_name(), e);
386        print_help(&mut cmd);
387        return ExitCode::from(EXIT_UNAVAILABLE);
388    }
389
390    ExitCode::SUCCESS
391}
392
/// Build the clap [`Command`] describing geckodriver's command-line interface.
///
/// The automatic `--help`/`--version` flags are disabled; the `help` and
/// `version` arguments declared below are handled explicitly by `parse_args`.
/// Port values are accepted as plain strings here and validated in
/// `parse_args`.
///
/// NOTE(review): the order of the `.arg(...)` calls presumably determines the
/// order of entries in the printed help text — confirm before reordering.
fn make_command() -> Command {
    Command::new(format!("geckodriver {}", build::build_info()))
        .disable_help_flag(true)
        .disable_version_flag(true)
        .about("WebDriver implementation for Firefox")
        .arg(
            Arg::new("allow_hosts")
                .long("allow-hosts")
                .num_args(1..)
                .value_parser(clap::builder::ValueParser::new(Host::parse))
                .value_name("ALLOW_HOSTS")
                .help("List of hostnames to allow. By default the value of --host is allowed, and in addition if that's a well known local address, other variations on well known local addresses are allowed. If --allow-hosts is provided only exactly those hosts are allowed."),
        )
        .arg(
            Arg::new("allow_origins")
                .long("allow-origins")
                .num_args(1..)
                .value_parser(clap::builder::ValueParser::new(Url::parse))
                .value_name("ALLOW_ORIGINS")
                .help("List of request origins to allow. These must be formatted as scheme://host:port. By default any request with an origin header is rejected. If --allow-origins is provided then only exactly those origins are allowed."),
        )
        .arg(
            Arg::new("allow_system_access")
                .long("allow-system-access")
                .action(ArgAction::SetTrue)
                .help("Enable privileged access to the application's parent process"),
        )
        // Deprecated: parse_args records whether it was given so that
        // inner_main can log a warning.
        .arg(
            Arg::new("android_storage")
                .long("android-storage")
                .value_parser(["auto", "app", "internal", "sdcard"])
                .value_name("ANDROID_STORAGE")
                .help("Selects storage location to be used for test data (deprecated)."),
        )
        .arg(
            Arg::new("binary")
                .short('b')
                .long("binary")
                .num_args(1)
                .value_name("BINARY")
                .help("Path to the Firefox binary"),
        )
        // Attaching to an existing instance needs to know which Marionette
        // port to connect to.
        .arg(
            Arg::new("connect_existing")
                .long("connect-existing")
                .requires("marionette_port")
                .action(ArgAction::SetTrue)
                .help("Connect to an existing Firefox instance"),
        )
        // Deprecated: only used by inner_main to log a warning.
        .arg(
            Arg::new("enable_crash_reporter")
                .long("enable-crash-reporter")
                .action(ArgAction::SetTrue)
                .help("Enable the Firefox crash reporter for diagnostic purposes (deprecated)"),
        )
        .arg(
            Arg::new("help")
                .short('h')
                .long("help")
                .action(ArgAction::SetTrue)
                .help("Prints this message"),
        )
        .arg(
            Arg::new("webdriver_host")
                .long("host")
                .num_args(1)
                .value_name("HOST")
                .default_value("127.0.0.1")
                .help("Host IP to use for WebDriver server"),
        )
        .arg(
            Arg::new("jsdebugger")
                .long("jsdebugger")
                .action(ArgAction::SetTrue)
                .help("Attach browser toolbox debugger for Firefox"),
        )
        .arg(
            Arg::new("log_level")
                .long("log")
                .num_args(1)
                .value_name("LEVEL")
                .value_parser(["fatal", "error", "warn", "info", "config", "debug", "trace"])
                .help("Set Gecko log level"),
        )
        .arg(
            Arg::new("log_no_truncate")
                .long("log-no-truncate")
                .action(ArgAction::SetTrue)
                .help("Disable truncation of long log lines"),
        )
        .arg(
            Arg::new("marionette_host")
                .long("marionette-host")
                .num_args(1)
                .value_name("HOST")
                .default_value("127.0.0.1")
                .help("Host to use to connect to Gecko"),
        )
        .arg(
            Arg::new("marionette_port")
                .long("marionette-port")
                .num_args(1)
                .value_name("PORT")
                .help("Port to use to connect to Gecko [default: system-allocated port]"),
        )
        .arg(
            Arg::new("webdriver_port")
                .short('p')
                .long("port")
                .num_args(1)
                .value_name("PORT")
                .default_value("4444")
                .help("Port to use for WebDriver server"),
        )
        .arg(
            Arg::new("profile_root")
                .long("profile-root")
                .num_args(1)
                .value_name("PROFILE_ROOT")
                .help("Directory in which to create profiles. Defaults to the system temporary directory."),
        )
        // Occurrences of -v are counted by parse_args: 0 => Info, 1 => Debug,
        // 2+ => Trace. Mutually exclusive with an explicit --log level.
        .arg(
            Arg::new("verbosity")
                .conflicts_with("log_level")
                .short('v')
                .action(ArgAction::Count)
                .help("Log level verbosity (-v for debug and -vv for trace level)"),
        )
        .arg(
            Arg::new("version")
                .short('V')
                .long("version")
                .action(ArgAction::SetTrue)
                .help("Prints version and copying information"),
        )
        // parse_args falls back to 9222 when this is absent.
        .arg(
            Arg::new("websocket_port")
                .long("websocket-port")
                .num_args(1)
                .value_name("PORT")
                .conflicts_with("connect_existing")
                .help("Port to use to connect to WebDriver BiDi [default: 9222]"),
        )
}
537
/// Name under which this program was invoked (`argv[0]`), used to prefix
/// error messages.
///
/// Falls back to `"geckodriver"` when the OS supplies an empty argument
/// vector. A non-UTF-8 `argv[0]` is converted lossily. Previously both cases
/// panicked (`unwrap` / `env::args`) while an unrelated error was being
/// reported.
fn get_program_name() -> String {
    env::args_os()
        .next()
        .map(|name| name.to_string_lossy().into_owned())
        .unwrap_or_else(|| "geckodriver".to_owned())
}
541
542fn print_help(cmd: &mut Command) {
543    cmd.print_help().ok();
544    println!();
545}
546
547fn print_version() {
548    println!("geckodriver {}", build::build_info());
549    println!();
550    println!("The source code of this program is available from");
551    println!("testing/geckodriver in https://hg.mozilla.org/mozilla-central.");
552    println!();
553    println!("This program is subject to the terms of the Mozilla Public License 2.0.");
554    println!("You can obtain a copy of the license at https://mozilla.org/MPL/2.0/.");
555}