1#![forbid(unsafe_code)]
2
3use hashes::get_expected_driver_hash;
30use http_req::request;
31use paths::get_cache_dir;
32use pico_common::Driver;
33use pico_driver::Resolution;
34use ring::digest::{Context, Digest, SHA256};
35use std::{fs, io::Read, path::Path};
36use thiserror::Error;
37
38mod hashes;
39mod paths;
40
41#[derive(Error, Debug)]
42pub enum DriverDownloadError {
43 #[error("IO Error: {0}")]
44 IoError(#[from] std::io::Error),
45
46 #[error("HTTP request Error: {0}")]
47 HttpRequestError(#[from] http_req::error::Error),
48
49 #[error("HTTP response Error: {0}")]
50 HttpResponseError(http_req::response::StatusCode),
51
52 #[error("Invalid driver hash")]
53 HashMismatch,
54}
55
56pub fn cache_resolution() -> Resolution {
58 Resolution::Custom(get_cache_dir())
59}
60
61pub fn download_drivers<P: AsRef<Path>>(
63 drivers: &[Driver],
64 path: P,
65) -> Result<(), DriverDownloadError> {
66 let driver_dir = path.as_ref().to_path_buf();
67
68 let required_files = [drivers, &Driver::get_dependencies_for_platform()].concat();
69
70 fs::create_dir_all(&driver_dir)?;
71
72 for driver in required_files {
73 let file_path = driver_dir.join(&driver.get_binary_name());
74
75 if file_path.exists() {
76 match driver {
77 Driver::PicoIPP | Driver::IOMP5 => {
78 continue;
79 }
80 _ => {
81 fs::remove_file(&file_path)?;
82 }
83 }
84 }
85
86 download_driver(driver, &driver_dir)?;
87 }
88
89 Ok(())
90}
91
92pub fn download_drivers_to_cache(drivers: &[Driver]) -> Result<(), DriverDownloadError> {
99 download_drivers(drivers, get_cache_dir())
100}
101
102fn sha256_digest_for_file<P: AsRef<Path>>(path: P) -> Result<Digest, DriverDownloadError> {
103 let mut src_file = fs::File::open(&path)?;
104
105 let mut context = Context::new(&SHA256);
106 let mut buffer = [0; 1024];
107
108 loop {
109 let count = src_file.read(&mut buffer)?;
110 if count == 0 {
111 break;
112 }
113 context.update(&buffer[..count]);
114 }
115
116 Ok(context.finish())
117}
118
119fn download_driver(driver: Driver, dst_dir: &Path) -> Result<(), DriverDownloadError> {
120 let name = driver.get_binary_name();
121
122 let url = format!(
123 "https://pico-drivers.s3.eu-west-2.amazonaws.com/{}/{}/{}",
124 std::env::consts::OS,
125 std::env::consts::ARCH,
126 name
127 );
128
129 let dst_temp_path = dst_dir.join(name.to_string() + ".temp");
130
131 let mut dst_file = fs::File::create(&dst_temp_path)?;
132 let response = request::get(url, &mut dst_file)?;
133
134 if response.status_code().is_success() {
135 let computed_hash = format!("{:?}", sha256_digest_for_file(&dst_temp_path)?);
136
137 let expected_hash = get_expected_driver_hash(driver);
138
139 if computed_hash == expected_hash {
140 let dst_path = dst_dir.join(name);
141 fs::copy(&dst_temp_path, dst_path)?;
142 fs::remove_file(dst_temp_path)?;
143
144 Ok(())
145 } else {
146 Err(DriverDownloadError::HashMismatch)
147 }
148 } else {
149 Err(DriverDownloadError::HttpResponseError(
150 response.status_code(),
151 ))
152 }
153}